mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-06-16 21:54:14 +00:00
feat(core): implement sni-router for stealth turn
- add `orama-sni-router` binary to build process - introduce `cmd/sni-router` for TLS-level SNI routing - add documentation for stealth turn deployment architecture
This commit is contained in:
parent
54852076f9
commit
0379dc39f1
@ -63,7 +63,7 @@ test-e2e-quick:
|
||||
|
||||
.PHONY: build clean test deps tidy fmt vet lint install-hooks push-devnet push-testnet rollout-devnet rollout-testnet release
|
||||
|
||||
VERSION := 0.120.0
|
||||
VERSION := 0.121.0
|
||||
COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo unknown)
|
||||
DATE ?= $(shell date -u +%Y-%m-%dT%H:%M:%SZ)
|
||||
LDFLAGS := -X 'main.version=$(VERSION)' -X 'main.commit=$(COMMIT)' -X 'main.date=$(DATE)'
|
||||
@ -80,6 +80,7 @@ build: deps
|
||||
go build -ldflags "$(LDFLAGS) -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildVersion=$(VERSION)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildCommit=$(COMMIT)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildTime=$(DATE)'" -o bin/gateway ./cmd/gateway
|
||||
go build -ldflags "$(LDFLAGS)" -o bin/sfu ./cmd/sfu
|
||||
go build -ldflags "$(LDFLAGS)" -o bin/turn ./cmd/turn
|
||||
go build -ldflags "$(LDFLAGS)" -o bin/orama-sni-router ./cmd/sni-router
|
||||
@echo "Build complete! Run ./bin/orama version"
|
||||
|
||||
# Cross-compile CLI for Linux (only binary needed locally; VPS builds everything else from source)
|
||||
|
||||
242
core/cmd/sni-router/main.go
Normal file
242
core/cmd/sni-router/main.go
Normal file
@ -0,0 +1,242 @@
|
||||
// Command sni-router is a TLS-level Server Name Indication router.
|
||||
//
|
||||
// It listens on a public TCP port (typically :443), peeks at the TLS
|
||||
// ClientHello SNI on each connection, and forwards the raw stream to
|
||||
// a configured backend. It does NOT terminate TLS — encrypted bytes
|
||||
// pass through verbatim. This lets one port serve multiple TLS-speaking
|
||||
// backends (HTTPS for the gateway, TURN-over-TLS for stealth WebRTC).
|
||||
//
|
||||
// See pkg/sniproxy for the underlying library.
|
||||
//
|
||||
// Configuration: YAML file at --config (defaults to ~/.orama/sni-router.yaml).
|
||||
//
|
||||
// Example sni-router.yaml:
|
||||
//
|
||||
// listen: ":443"
|
||||
// client_hello_timeout: 5s
|
||||
// backend_dial_timeout: 5s
|
||||
// max_concurrent_conns: 10000
|
||||
// fallback:
|
||||
// name: caddy
|
||||
// addr: "127.0.0.1:8443"
|
||||
// routes:
|
||||
// - match: "cdn.example.com"
|
||||
// backend:
|
||||
// name: turn-tls
|
||||
// addr: "127.0.0.1:5349"
|
||||
// - match: "turn.example.com"
|
||||
// backend:
|
||||
// name: turn-tls
|
||||
// addr: "127.0.0.1:5349"
|
||||
// - match: "*.ns-myapp.example.com"
|
||||
// backend:
|
||||
// name: gateway
|
||||
// addr: "127.0.0.1:8443"
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/config"
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
"github.com/DeBrosOfficial/network/pkg/sniproxy"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
version = "dev"
|
||||
commit = "unknown"
|
||||
)
|
||||
|
||||
// On-disk (YAML) shapes for the router configuration. They are decoded by
// parseConfig, checked by validateConfig, and converted to their sniproxy
// counterparts by toBackend / toRoutes.
type (
	// yamlBackend mirrors sniproxy.Backend for YAML decoding.
	yamlBackend struct {
		// Name is copied verbatim into sniproxy.Backend.Name.
		Name string `yaml:"name"`
		// Network is optional; toBackend substitutes "tcp" when it is empty.
		Network string `yaml:"network"`
		// Addr is the backend address; validateConfig rejects an empty value.
		Addr string `yaml:"addr"`
	}

	// yamlRoute mirrors sniproxy.Route for YAML decoding.
	yamlRoute struct {
		// Match is the SNI pattern for this route (e.g. "cdn.example.com"
		// or "*.ns-myapp.example.com"); validateConfig requires it.
		Match string `yaml:"match"`
		// Backend is where connections matching Match are forwarded.
		Backend yamlBackend `yaml:"backend"`
	}

	// yamlConfig is the on-disk configuration shape.
	yamlConfig struct {
		// Listen is the TCP address handed to net.Listen (e.g. ":443").
		Listen string `yaml:"listen"`
		// ClientHelloTimeout, BackendDialTimeout and MaxConcurrentConns are
		// passed through unchanged to sniproxy.Config.
		ClientHelloTimeout time.Duration `yaml:"client_hello_timeout"`
		BackendDialTimeout time.Duration `yaml:"backend_dial_timeout"`
		MaxConcurrentConns int           `yaml:"max_concurrent_conns"`
		// Fallback receives connections whose SNI matches no route.
		Fallback yamlBackend `yaml:"fallback"`
		// Routes is the list of SNI match rules loaded into the router.
		Routes []yamlRoute `yaml:"routes"`
	}
)
|
||||
|
||||
func main() {
|
||||
logger, err := logging.NewColoredLogger(logging.ComponentSNI, true)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to init logger: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
logger.ComponentInfo(logging.ComponentSNI, "Starting SNI router",
|
||||
zap.String("version", version),
|
||||
zap.String("commit", commit))
|
||||
|
||||
cfg := parseConfig(logger)
|
||||
|
||||
router := sniproxy.NewRouter(toBackend(cfg.Fallback))
|
||||
router.Replace(toRoutes(cfg.Routes), toBackend(cfg.Fallback))
|
||||
|
||||
srv := sniproxy.NewServer(router, sniproxy.Config{
|
||||
ClientHelloTimeout: cfg.ClientHelloTimeout,
|
||||
BackendDialTimeout: cfg.BackendDialTimeout,
|
||||
MaxConcurrentConns: cfg.MaxConcurrentConns,
|
||||
}, logger.Logger)
|
||||
|
||||
ln, err := net.Listen("tcp", cfg.Listen)
|
||||
if err != nil {
|
||||
logger.ComponentError(logging.ComponentSNI, "Failed to listen",
|
||||
zap.String("addr", cfg.Listen), zap.Error(err))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
logger.ComponentInfo(logging.ComponentSNI, "SNI router listening",
|
||||
zap.String("addr", cfg.Listen),
|
||||
zap.Int("routes", len(cfg.Routes)),
|
||||
zap.String("fallback", cfg.Fallback.Addr),
|
||||
)
|
||||
|
||||
// Run Serve in a goroutine so the main goroutine can wait on signals.
|
||||
serveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
serveErrCh <- srv.Serve(ln)
|
||||
}()
|
||||
|
||||
// Wait for termination signal or unrecoverable Serve error.
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case sig := <-quit:
|
||||
logger.ComponentInfo(logging.ComponentSNI, "Shutdown signal received",
|
||||
zap.String("signal", sig.String()))
|
||||
case err := <-serveErrCh:
|
||||
logger.ComponentError(logging.ComponentSNI, "Serve returned",
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
// Stop accepting new connections, then drain in-flight ones.
|
||||
_ = ln.Close()
|
||||
srv.Close()
|
||||
|
||||
logger.ComponentInfo(logging.ComponentSNI, "SNI router shutdown complete")
|
||||
}
|
||||
|
||||
func parseConfig(logger *logging.ColoredLogger) yamlConfig {
|
||||
configFlag := flag.String("config", "", "Config file path (absolute or filename in ~/.orama)")
|
||||
flag.Parse()
|
||||
|
||||
var configPath string
|
||||
var err error
|
||||
if *configFlag != "" {
|
||||
if filepath.IsAbs(*configFlag) {
|
||||
configPath = *configFlag
|
||||
} else {
|
||||
configPath, err = config.DefaultPath(*configFlag)
|
||||
if err != nil {
|
||||
logger.ComponentError(logging.ComponentSNI, "Failed to determine config path",
|
||||
zap.Error(err))
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
configPath, err = config.DefaultPath("sni-router.yaml")
|
||||
if err != nil {
|
||||
logger.ComponentError(logging.ComponentSNI, "Failed to determine config path",
|
||||
zap.Error(err))
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
logger.ComponentError(logging.ComponentSNI, "Config file not found",
|
||||
zap.String("path", configPath), zap.Error(err))
|
||||
fmt.Fprintf(os.Stderr, "\nConfig file not found at %s\n", configPath)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var y yamlConfig
|
||||
if err := config.DecodeStrict(strings.NewReader(string(data)), &y); err != nil {
|
||||
logger.ComponentError(logging.ComponentSNI, "Failed to parse SNI router config",
|
||||
zap.Error(err))
|
||||
fmt.Fprintf(os.Stderr, "Configuration parse error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if errs := validateConfig(&y); len(errs) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\nSNI router configuration errors (%d):\n", len(errs))
|
||||
for _, e := range errs {
|
||||
fmt.Fprintf(os.Stderr, " - %s\n", e)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nPlease fix the configuration and try again.\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
logger.ComponentInfo(logging.ComponentSNI, "Loaded SNI router configuration",
|
||||
zap.String("path", configPath),
|
||||
)
|
||||
|
||||
return y
|
||||
}
|
||||
|
||||
// validateConfig returns a non-empty slice of human-readable errors on misconfig.
|
||||
func validateConfig(y *yamlConfig) []string {
|
||||
var errs []string
|
||||
if y.Listen == "" {
|
||||
errs = append(errs, "listen: required (e.g. \":443\")")
|
||||
}
|
||||
if y.Fallback.Addr == "" {
|
||||
errs = append(errs, "fallback.addr: required (where to send unmatched SNIs, typically Caddy)")
|
||||
}
|
||||
for i, r := range y.Routes {
|
||||
if r.Match == "" {
|
||||
errs = append(errs, fmt.Sprintf("routes[%d].match: required", i))
|
||||
}
|
||||
if r.Backend.Addr == "" {
|
||||
errs = append(errs, fmt.Sprintf("routes[%d].backend.addr: required", i))
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
func toBackend(b yamlBackend) sniproxy.Backend {
|
||||
network := b.Network
|
||||
if network == "" {
|
||||
network = "tcp"
|
||||
}
|
||||
return sniproxy.Backend{
|
||||
Name: b.Name,
|
||||
Network: network,
|
||||
Addr: b.Addr,
|
||||
}
|
||||
}
|
||||
|
||||
func toRoutes(in []yamlRoute) []sniproxy.Route {
|
||||
out := make([]sniproxy.Route, len(in))
|
||||
for i, r := range in {
|
||||
out[i] = sniproxy.Route{
|
||||
Match: r.Match,
|
||||
Backend: toBackend(r.Backend),
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
187
core/docs/STEALTH_TURN.md
Normal file
187
core/docs/STEALTH_TURN.md
Normal file
@ -0,0 +1,187 @@
|
||||
# Stealth TURN Deployment Guide
|
||||
|
||||
## What this is
|
||||
|
||||
A TLS-level SNI router that lets Orama serve TURN-over-TLS on `:443`,
|
||||
sharing the port with Caddy HTTPS. From a network observer's
|
||||
perspective, TURN traffic is indistinguishable from ordinary HTTPS —
|
||||
useful for users in regions that block standard VoIP ports (UAE, Saudi
|
||||
Arabia, China, Iran).
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Internet
|
||||
│
|
||||
▼
|
||||
TCP :443
|
||||
│
|
||||
┌─────────┴─────────┐
|
||||
│ orama-sni-router │ peeks SNI, forwards bytes
|
||||
└─────────┬─────────┘
|
||||
│
|
||||
┌───────────────┼────────────────┐
|
||||
▼ ▼
|
||||
cdn.<base> *.<base>, <base>
|
||||
turn.<base> (everything else)
|
||||
│ │
|
||||
▼ ▼
|
||||
Pion TURN-TLS Caddy
|
||||
127.0.0.1:5349 127.0.0.1:8443
|
||||
(existing) (moved from :443)
|
||||
```
|
||||
|
||||
The router does **not** terminate TLS. It reads the unencrypted TLS
|
||||
ClientHello (first ~5 KB), inspects the SNI extension, and dials the
|
||||
matching backend. Encrypted bytes pass through verbatim.
|
||||
|
||||
## Components
|
||||
|
||||
- **Library:** `pkg/sniproxy/` — ClientHello parser, route table, TCP server
|
||||
- **Binary:** `cmd/sni-router/` (built as `bin/orama-sni-router`)
|
||||
- **Systemd unit:** `systemd/orama-sni-router.service`
|
||||
- **Config:** `~/.orama/sni-router.yaml`
|
||||
|
||||
## Deployment cutover
|
||||
|
||||
⚠️ **This change touches production `:443`. Stage on one node first, watch for 24h, then roll out.**
|
||||
|
||||
### 1. Reconfigure Caddy to listen on `:8443`
|
||||
|
||||
Update wherever the Caddy config is generated (`pkg/environments/production/installers/caddy.go`)
|
||||
so Caddy binds `:8443` (HTTPS) and `:8080` (HTTP) instead of `:443` and `:80`.
|
||||
|
||||
Drop `CAP_NET_BIND_SERVICE` from Caddy's systemd unit — it no longer needs privileged ports.
|
||||
|
||||
### 2. Provision the cert SAN for `cdn.<base-domain>`
|
||||
|
||||
Caddy's automatic Let's Encrypt flow needs to issue a cert covering
|
||||
`cdn.<base-domain>` and `cdn.ns-*.<base-domain>` so Pion TURN can read it
|
||||
on startup. Add these names to Caddy's TLS config block.
|
||||
|
||||
### 3. Drop `sni-router.yaml` config
|
||||
|
||||
Example for a single-namespace node:
|
||||
|
||||
```yaml
|
||||
listen: ":443"
|
||||
client_hello_timeout: 5s
|
||||
backend_dial_timeout: 5s
|
||||
max_concurrent_conns: 10000
|
||||
fallback:
|
||||
name: caddy
|
||||
addr: "127.0.0.1:8443"
|
||||
routes:
|
||||
- match: "cdn.example.com"
|
||||
backend:
|
||||
name: turn-tls
|
||||
addr: "127.0.0.1:5349"
|
||||
- match: "turn.example.com"
|
||||
backend:
|
||||
name: turn-tls
|
||||
addr: "127.0.0.1:5349"
|
||||
```
|
||||
|
||||
For multi-namespace, add per-namespace TURN backends (each namespace's
|
||||
TURN-TLS port is allocated by `pkg/namespace`):
|
||||
|
||||
```yaml
|
||||
- match: "cdn.ns-myapp.example.com"
|
||||
backend: { name: "turn-myapp", addr: "127.0.0.1:5349" }
|
||||
- match: "cdn.ns-other.example.com"
|
||||
backend: { name: "turn-other", addr: "127.0.0.1:5350" }
|
||||
```
|
||||
|
||||
### 4. Deploy + start in order
|
||||
|
||||
```bash
|
||||
# Install binary
|
||||
sudo cp bin-linux/orama-sni-router /opt/orama/bin/
|
||||
|
||||
# Install service
|
||||
sudo cp systemd/orama-sni-router.service /etc/systemd/system/
|
||||
sudo systemctl daemon-reload
|
||||
|
||||
# Stop Caddy briefly (it's about to lose :443)
|
||||
sudo systemctl stop caddy
|
||||
|
||||
# Start the SNI router (it takes :443)
|
||||
sudo systemctl enable --now orama-sni-router
|
||||
|
||||
# Restart Caddy on its new port
|
||||
sudo systemctl start caddy
|
||||
|
||||
# Verify
|
||||
curl -v https://cdn.<base>:443 # should hit TURN backend (TLS handshake will fail; that's fine)
|
||||
curl -v https://<base>:443 # should hit Caddy (normal HTTPS response)
|
||||
```
|
||||
|
||||
### 5. Enable stealth in the gateway
|
||||
|
||||
Once the SNI router is live, tell the gateway to advertise the stealth URI:
|
||||
|
||||
```go
|
||||
// in gateway dependencies / startup
|
||||
webrtcHandlers.SetStealthCDNDomain("cdn.<base-domain>")
|
||||
```
|
||||
|
||||
The credentials handler will start including `turns:cdn.<base-domain>:443`
|
||||
in `POST /v1/webrtc/turn/credentials` responses automatically.
|
||||
|
||||
### 6. Monitor
|
||||
|
||||
```bash
|
||||
journalctl -u orama-sni-router.service -f
|
||||
journalctl -u caddy.service -f
|
||||
```
|
||||
|
||||
Watch for:
|
||||
- `Connection limit reached` warnings (bump `max_concurrent_conns`)
|
||||
- `backend dial failed` warnings (Caddy isn't listening on `:8443`, or TURN isn't on `:5349`)
|
||||
- `ClientHello peek failed` debugs (curious clients sending non-TLS to `:443` — usually port scanners)
|
||||
|
||||
## Rollback
|
||||
|
||||
If anything is wrong:
|
||||
|
||||
```bash
|
||||
sudo systemctl disable --now orama-sni-router   # stop it AND keep it from grabbing :443 again on reboot
|
||||
# Reconfigure Caddy back to :443 and restart
|
||||
sudo systemctl restart caddy
|
||||
```
|
||||
|
||||
Caddy reclaiming `:443` from the disabled router is the fastest way back to
|
||||
the previous topology.
|
||||
|
||||
## Known gaps
|
||||
|
||||
- **Dynamic route source:** today's router reads YAML once at startup. To
|
||||
pick up new namespaces without restart, implement a `RouteSource` that
|
||||
polls `pkg/namespace` for active TURN deployments. The library is
|
||||
already designed for `Router.Replace` to be called concurrently.
|
||||
- **TLS cert hot-reload:** Pion TURN reads the cert once at startup. When
|
||||
Caddy renews `cdn.<base-domain>`, Pion needs to be restarted to pick up
|
||||
the new cert. A small file-watcher service (or a periodic restart in
|
||||
off-peak hours) handles this for now.
|
||||
|
||||
## What clients see
|
||||
|
||||
Once enabled, the credentials response gains one entry:
|
||||
|
||||
```json
|
||||
{
|
||||
"username": "...",
|
||||
"password": "...",
|
||||
"ttl": 600,
|
||||
"uris": [
|
||||
"turn:turn.example.com:3478?transport=udp",
|
||||
"turn:turn.example.com:3478?transport=tcp",
|
||||
"turns:turn.example.com:5349",
|
||||
"turns:cdn.example.com:443"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Browsers iterate ICE candidates; users in restricted regions will silently
|
||||
succeed via the `:443` URI when others fail. No client-side change is
|
||||
required.
|
||||
28
core/migrations/021_pubsub_trigger_patterns.sql
Normal file
28
core/migrations/021_pubsub_trigger_patterns.sql
Normal file
@ -0,0 +1,28 @@
|
||||
-- =============================================================================
|
||||
-- 021_pubsub_trigger_patterns.sql
|
||||
--
|
||||
-- Add `topic_pattern` column alongside the existing `topic` column to
|
||||
-- function_pubsub_triggers. The new column may contain SQLite GLOB
|
||||
-- patterns (e.g. "presence:*") in addition to exact topic names.
|
||||
--
|
||||
-- This is intentionally ADDITIVE rather than a column rename to remain
|
||||
-- safe under rolling upgrades:
|
||||
-- - Old binaries continue reading `topic` and keep working.
|
||||
-- - New binaries read `topic_pattern` (which is back-filled from
|
||||
-- `topic` for existing rows) and write BOTH columns.
|
||||
-- A future migration can DROP COLUMN topic once every node is on the
|
||||
-- new release.
|
||||
-- =============================================================================
|
||||
|
||||
ALTER TABLE function_pubsub_triggers
|
||||
ADD COLUMN topic_pattern TEXT NOT NULL DEFAULT '';
|
||||
|
||||
UPDATE function_pubsub_triggers
|
||||
SET topic_pattern = topic
|
||||
WHERE topic_pattern = '';
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_function_pubsub_triggers_function
|
||||
ON function_pubsub_triggers(function_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_function_pubsub_triggers_enabled
|
||||
ON function_pubsub_triggers(enabled);
|
||||
20
core/migrations/022_aggregation_windows.sql
Normal file
20
core/migrations/022_aggregation_windows.sql
Normal file
@ -0,0 +1,20 @@
|
||||
-- =============================================================================
|
||||
-- 022_aggregation_windows.sql
|
||||
--
|
||||
-- Add per-trigger aggregation parameters to function_pubsub_triggers.
|
||||
--
|
||||
-- aggregation_window_ms = 0 means "no aggregation, invoke once per event"
|
||||
-- (the existing behaviour). Any positive value enables buffering of events
|
||||
-- in-memory on the dispatching node; the function is invoked once per
|
||||
-- window with a batched payload.
|
||||
--
|
||||
-- aggregation_max_batch_size caps the per-window batch. When the buffer
|
||||
-- reaches this size, the dispatcher flushes immediately even if the
|
||||
-- window timer hasn't fired yet.
|
||||
-- =============================================================================
|
||||
|
||||
ALTER TABLE function_pubsub_triggers
|
||||
ADD COLUMN aggregation_window_ms INTEGER NOT NULL DEFAULT 0;
|
||||
|
||||
ALTER TABLE function_pubsub_triggers
|
||||
ADD COLUMN aggregation_max_batch_size INTEGER NOT NULL DEFAULT 100;
|
||||
33
core/migrations/023_push_devices.sql
Normal file
33
core/migrations/023_push_devices.sql
Normal file
@ -0,0 +1,33 @@
|
||||
-- =============================================================================
|
||||
-- 023_push_devices.sql
|
||||
--
|
||||
-- Per-namespace, per-user push notification device registry.
|
||||
--
|
||||
-- token_encrypted is AES-256-GCM ciphertext (prefix 'enc:') derived via
|
||||
-- pkg/secrets. Tokens are sensitive — they let the holder spam a user's
|
||||
-- device — so they are never returned via any API or written to logs.
|
||||
--
|
||||
-- provider matches a registered push.PushProvider name:
|
||||
-- 'ntfy', 'expo', 'apns', 'fcm' (future), ...
|
||||
-- =============================================================================
|
||||
|
||||
CREATE TABLE IF NOT EXISTS push_devices (
|
||||
id TEXT PRIMARY KEY,
|
||||
namespace TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
token_encrypted TEXT NOT NULL,
|
||||
platform TEXT,
|
||||
app_version TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
last_seen INTEGER,
|
||||
UNIQUE(namespace, user_id, device_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_push_devices_user
|
||||
ON push_devices(namespace, user_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_push_devices_provider
|
||||
ON push_devices(provider);
|
||||
@ -157,6 +157,7 @@ func (b *Builder) buildOramaBinaries() error {
|
||||
{Name: "identity", Package: "./cmd/identity/"},
|
||||
{Name: "sfu", Package: "./cmd/sfu/"},
|
||||
{Name: "turn", Package: "./cmd/turn/"},
|
||||
{Name: "orama-sni-router", Package: "./cmd/sni-router/"},
|
||||
}
|
||||
|
||||
for _, bin := range binaries {
|
||||
|
||||
@ -171,7 +171,7 @@ rm -f /tmp/orama-*.sh /tmp/network-source.tar.gz /tmp/orama-*.tar.gz
|
||||
# Nuclear: remove binaries
|
||||
if [ -n "$NUCLEAR" ]; then
|
||||
rm -f /usr/local/bin/orama /usr/local/bin/orama-node /usr/local/bin/gateway
|
||||
rm -f /usr/local/bin/identity /usr/local/bin/sfu /usr/local/bin/turn
|
||||
rm -f /usr/local/bin/identity /usr/local/bin/sfu /usr/local/bin/turn /usr/local/bin/orama-sni-router
|
||||
rm -f /usr/local/bin/olric-server /usr/local/bin/ipfs /usr/local/bin/ipfs-cluster-service
|
||||
rm -f /usr/local/bin/rqlited /usr/local/bin/coredns
|
||||
rm -f /usr/bin/caddy
|
||||
|
||||
@ -47,10 +47,29 @@ type DatabaseClient interface {
|
||||
type PubSubClient interface {
|
||||
Subscribe(ctx context.Context, topic string, handler MessageHandler) error
|
||||
Publish(ctx context.Context, topic string, data []byte) error
|
||||
// PublishBatch publishes multiple messages in parallel, one per topic.
|
||||
// See pubsub.Manager.PublishBatch for semantics (fail-fast vs. best-effort).
|
||||
PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error
|
||||
// PublishSame sends the same payload to every topic in parallel.
|
||||
PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error
|
||||
Unsubscribe(ctx context.Context, topic string) error
|
||||
ListTopics(ctx context.Context) ([]string, error)
|
||||
}
|
||||
|
||||
// Batch-publish value types. Both mirror their pkg/pubsub namesakes so that
// client callers do not have to import pkg/pubsub themselves.
type (
	// TopicMessage is one entry in a batch publish: a payload and the topic
	// it is addressed to.
	TopicMessage struct {
		Topic string
		Data  []byte
	}

	// PublishBatchOptions controls batch publish behavior.
	// Mirrors pubsub.PublishBatchOptions; both fields are forwarded
	// unchanged to the underlying pubsub adapter.
	PublishBatchOptions struct {
		BestEffort     bool
		MaxConcurrency int
	}
)
|
||||
|
||||
// NetworkInfo provides network status and peer information
|
||||
type NetworkInfo interface {
|
||||
GetPeers(ctx context.Context) ([]PeerInfo, error)
|
||||
|
||||
@ -4,13 +4,13 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||
pkgpubsub "github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||
)
|
||||
|
||||
// pubSubBridge bridges between our PubSubClient interface and the pubsub package
|
||||
type pubSubBridge struct {
|
||||
client *Client
|
||||
adapter *pubsub.ClientAdapter
|
||||
adapter *pkgpubsub.ClientAdapter
|
||||
}
|
||||
|
||||
func (p *pubSubBridge) Subscribe(ctx context.Context, topic string, handler MessageHandler) error {
|
||||
@ -31,6 +31,26 @@ func (p *pubSubBridge) Publish(ctx context.Context, topic string, data []byte) e
|
||||
return p.adapter.Publish(ctx, topic, data)
|
||||
}
|
||||
|
||||
func (p *pubSubBridge) PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error {
|
||||
if err := p.client.requireAccess(ctx); err != nil {
|
||||
return fmt.Errorf("authentication required: %w - run CLI commands to authenticate automatically", err)
|
||||
}
|
||||
pkgMsgs := make([]pkgpubsub.TopicMessage, len(msgs))
|
||||
for i, m := range msgs {
|
||||
pkgMsgs[i] = pkgpubsub.TopicMessage{Topic: m.Topic, Data: m.Data}
|
||||
}
|
||||
pkgOpts := pkgpubsub.PublishBatchOptions{BestEffort: opts.BestEffort, MaxConcurrency: opts.MaxConcurrency}
|
||||
return p.adapter.PublishBatch(ctx, pkgMsgs, pkgOpts)
|
||||
}
|
||||
|
||||
func (p *pubSubBridge) PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error {
|
||||
if err := p.client.requireAccess(ctx); err != nil {
|
||||
return fmt.Errorf("authentication required: %w - run CLI commands to authenticate automatically", err)
|
||||
}
|
||||
pkgOpts := pkgpubsub.PublishBatchOptions{BestEffort: opts.BestEffort, MaxConcurrency: opts.MaxConcurrency}
|
||||
return p.adapter.PublishSame(ctx, topics, data, pkgOpts)
|
||||
}
|
||||
|
||||
func (p *pubSubBridge) Unsubscribe(ctx context.Context, topic string) error {
|
||||
if err := p.client.requireAccess(ctx); err != nil {
|
||||
return fmt.Errorf("authentication required: %w - run CLI commands to authenticate automatically", err)
|
||||
|
||||
@ -56,4 +56,15 @@ type Config struct {
|
||||
SFUPort int // Local SFU signaling port to proxy WebSocket connections to
|
||||
TURNDomain string // TURN server domain for credential generation
|
||||
TURNSecret string // HMAC-SHA1 shared secret for TURN credential generation
|
||||
|
||||
// StealthCDNDomain, when set, makes the WebRTC credentials handler
|
||||
// advertise turns:<StealthCDNDomain>:443 (served by the SNI router).
|
||||
StealthCDNDomain string
|
||||
|
||||
// Push notification configuration. Push is enabled when at least one
|
||||
// provider URL/token is set. Tokens stored in the push_devices table
|
||||
// are encrypted at rest via pkg/secrets using the cluster secret.
|
||||
NtfyBaseURL string // ntfy server URL (e.g. "http://localhost:8080")
|
||||
NtfyAuthToken string // optional bearer token for ntfy
|
||||
ExpoAccessToken string // optional Expo access token
|
||||
}
|
||||
|
||||
@ -19,6 +19,9 @@ import (
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
"github.com/DeBrosOfficial/network/pkg/olric"
|
||||
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||
"github.com/DeBrosOfficial/network/pkg/push"
|
||||
pushexpo "github.com/DeBrosOfficial/network/pkg/push/providers/expo"
|
||||
pushntfy "github.com/DeBrosOfficial/network/pkg/push/providers/ntfy"
|
||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless/hostfunctions"
|
||||
@ -63,6 +66,11 @@ type Dependencies struct {
|
||||
// PubSub trigger dispatcher (used to wire into PubSubHandlers)
|
||||
PubSubDispatcher *triggers.PubSubDispatcher
|
||||
|
||||
// Push notification dispatcher (nil when push isn't configured —
|
||||
// hostfunc + HTTP handlers degrade to no-op / 503).
|
||||
PushDispatcher *push.PushDispatcher
|
||||
PushDeviceStore push.PushDeviceStore
|
||||
|
||||
// Authentication service
|
||||
AuthService *auth.Service
|
||||
}
|
||||
@ -412,6 +420,19 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
|
||||
secretsMgr = smImpl
|
||||
}
|
||||
|
||||
// Initialize push notification dispatcher if any provider is configured.
|
||||
// Devices are stored encrypted in RQLite (see migration 023). Providers
|
||||
// are registered based on gateway config; missing config = provider absent.
|
||||
pushDispatcher, pushStore, err := buildPushDispatcher(cfg, deps.ORMClient, logger)
|
||||
if err != nil {
|
||||
// Non-fatal: log and continue. Functions calling push_send will get nil
|
||||
// (silent no-op) and HTTP /v1/push/* endpoints return 503.
|
||||
logger.ComponentWarn(logging.ComponentGeneral,
|
||||
"push notifications disabled (init failed)", zap.Error(err))
|
||||
}
|
||||
deps.PushDispatcher = pushDispatcher
|
||||
deps.PushDeviceStore = pushStore
|
||||
|
||||
// Create host functions provider (allows functions to call Orama services)
|
||||
hostFuncsCfg := hostfunctions.HostFunctionsConfig{
|
||||
IPFSAPIURL: cfg.IPFSAPIURL,
|
||||
@ -424,6 +445,7 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
|
||||
pubsubAdapter, // pubsub adapter for serverless functions
|
||||
deps.ServerlessWSMgr,
|
||||
secretsMgr,
|
||||
pushDispatcher, // may be nil — PushSend hostfunc handles that
|
||||
hostFuncsCfg,
|
||||
logger.Logger,
|
||||
)
|
||||
@ -686,3 +708,40 @@ func injectRQLiteAuth(dsn, username, password string) string {
|
||||
}
|
||||
return dsn
|
||||
}
|
||||
|
||||
// buildPushDispatcher constructs a push.PushDispatcher + device store with
|
||||
// all enabled providers. Returns (nil, nil, nil) when no provider is
|
||||
// configured — that's a supported state, not an error. Returns (nil, nil,
|
||||
// err) on hard init failures (e.g. cluster secret missing for the
|
||||
// encrypted device store).
|
||||
func buildPushDispatcher(cfg *Config, db rqlite.Client, logger *logging.ColoredLogger) (*push.PushDispatcher, push.PushDeviceStore, error) {
|
||||
if cfg.NtfyBaseURL == "" && cfg.ExpoAccessToken == "" {
|
||||
// No providers configured — push is disabled.
|
||||
return nil, nil, nil
|
||||
}
|
||||
if cfg.ClusterSecret == "" {
|
||||
// Devices are encrypted at rest using a cluster-secret-derived key.
|
||||
// Without it we can't store anything safely.
|
||||
return nil, nil, fmt.Errorf("push enabled but ClusterSecret is empty")
|
||||
}
|
||||
store, err := push.NewRqliteDeviceStore(db, cfg.ClusterSecret, logger.Logger)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("init push device store: %w", err)
|
||||
}
|
||||
d := push.New(store, logger.Logger)
|
||||
if cfg.NtfyBaseURL != "" {
|
||||
d.Register(pushntfy.New(pushntfy.Config{
|
||||
BaseURL: cfg.NtfyBaseURL,
|
||||
AuthToken: cfg.NtfyAuthToken,
|
||||
}, logger.Logger))
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "push provider registered: ntfy",
|
||||
zap.String("base_url", cfg.NtfyBaseURL))
|
||||
}
|
||||
if cfg.ExpoAccessToken != "" {
|
||||
d.Register(pushexpo.New(pushexpo.Config{
|
||||
AccessToken: cfg.ExpoAccessToken,
|
||||
}, logger.Logger))
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "push provider registered: expo")
|
||||
}
|
||||
return d, store, nil
|
||||
}
|
||||
|
||||
@ -28,6 +28,7 @@ import (
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/handlers/cache"
|
||||
deploymentshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/deployments"
|
||||
pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub"
|
||||
pushhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/push"
|
||||
serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless"
|
||||
enrollhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/enroll"
|
||||
joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join"
|
||||
@ -43,6 +44,7 @@ import (
|
||||
"github.com/DeBrosOfficial/network/pkg/olric"
|
||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless/triggers"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@ -83,13 +85,15 @@ type Gateway struct {
|
||||
mu sync.RWMutex
|
||||
presenceMu sync.RWMutex
|
||||
pubsubHandlers *pubsubhandlers.PubSubHandlers
|
||||
pushHandlers *pushhandlers.Handlers
|
||||
|
||||
// Serverless function engine
|
||||
serverlessEngine *serverless.Engine
|
||||
serverlessRegistry *serverless.Registry
|
||||
serverlessInvoker *serverless.Invoker
|
||||
serverlessWSMgr *serverless.WSManager
|
||||
serverlessHandlers *serverlesshandlers.ServerlessHandlers
|
||||
serverlessEngine *serverless.Engine
|
||||
serverlessRegistry *serverless.Registry
|
||||
serverlessInvoker *serverless.Invoker
|
||||
serverlessWSMgr *serverless.WSManager
|
||||
serverlessHandlers *serverlesshandlers.ServerlessHandlers
|
||||
pubsubDispatcher *triggers.PubSubDispatcher
|
||||
|
||||
// Authentication service
|
||||
authService *auth.Service
|
||||
@ -342,11 +346,20 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
|
||||
|
||||
// Wire PubSub trigger dispatch if serverless is available
|
||||
if deps.PubSubDispatcher != nil {
|
||||
gw.pubsubDispatcher = deps.PubSubDispatcher
|
||||
gw.pubsubHandlers.SetOnPublish(func(ctx context.Context, namespace, topic string, data []byte) {
|
||||
deps.PubSubDispatcher.Dispatch(ctx, namespace, topic, data, 0)
|
||||
})
|
||||
}
|
||||
|
||||
// Push notification handlers — disabled when no provider is configured.
|
||||
// The handlers themselves return 503 if dispatcher/store is nil; we
|
||||
// register them unconditionally so the routes always exist with a
|
||||
// predictable shape.
|
||||
if deps.PushDispatcher != nil {
|
||||
gw.pushHandlers = pushhandlers.NewHandlers(deps.PushDispatcher, deps.PushDeviceStore, logger)
|
||||
}
|
||||
|
||||
if cfg.WebRTCEnabled && cfg.SFUPort > 0 {
|
||||
gw.webrtcHandlers = webrtchandlers.NewWebRTCHandlers(
|
||||
logger,
|
||||
|
||||
@ -21,10 +21,12 @@ import (
|
||||
|
||||
// mockPubSubClient implements client.PubSubClient for testing
|
||||
type mockPubSubClient struct {
|
||||
PublishFunc func(ctx context.Context, topic string, data []byte) error
|
||||
SubscribeFunc func(ctx context.Context, topic string, handler client.MessageHandler) error
|
||||
UnsubscribeFunc func(ctx context.Context, topic string) error
|
||||
ListTopicsFunc func(ctx context.Context) ([]string, error)
|
||||
PublishFunc func(ctx context.Context, topic string, data []byte) error
|
||||
PublishBatchFunc func(ctx context.Context, msgs []client.TopicMessage, opts client.PublishBatchOptions) error
|
||||
PublishSameFunc func(ctx context.Context, topics []string, data []byte, opts client.PublishBatchOptions) error
|
||||
SubscribeFunc func(ctx context.Context, topic string, handler client.MessageHandler) error
|
||||
UnsubscribeFunc func(ctx context.Context, topic string) error
|
||||
ListTopicsFunc func(ctx context.Context) ([]string, error)
|
||||
}
|
||||
|
||||
func (m *mockPubSubClient) Publish(ctx context.Context, topic string, data []byte) error {
|
||||
@ -34,6 +36,20 @@ func (m *mockPubSubClient) Publish(ctx context.Context, topic string, data []byt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPubSubClient) PublishBatch(ctx context.Context, msgs []client.TopicMessage, opts client.PublishBatchOptions) error {
|
||||
if m.PublishBatchFunc != nil {
|
||||
return m.PublishBatchFunc(ctx, msgs, opts)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPubSubClient) PublishSame(ctx context.Context, topics []string, data []byte, opts client.PublishBatchOptions) error {
|
||||
if m.PublishSameFunc != nil {
|
||||
return m.PublishSameFunc(ctx, topics, data, opts)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPubSubClient) Subscribe(ctx context.Context, topic string, handler client.MessageHandler) error {
|
||||
if m.SubscribeFunc != nil {
|
||||
return m.SubscribeFunc(ctx, topic, handler)
|
||||
|
||||
156
core/pkg/gateway/handlers/pubsub/publish_batch_handler_test.go
Normal file
156
core/pkg/gateway/handlers/pubsub/publish_batch_handler_test.go
Normal file
@ -0,0 +1,156 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/client"
|
||||
)
|
||||
|
||||
func TestPublishBatchHandler_invalid_method(t *testing.T) {
|
||||
h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
|
||||
|
||||
req := withNamespace(httptest.NewRequest(http.MethodGet, "/v1/pubsub/publish-batch", nil), "ns")
|
||||
rr := httptest.NewRecorder()
|
||||
h.PublishBatchHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("expected 405, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishBatchHandler_missing_namespace(t *testing.T) {
|
||||
h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
|
||||
|
||||
body, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{{Topic: "a", DataB64: "AA=="}}})
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body))
|
||||
rr := httptest.NewRecorder()
|
||||
h.PublishBatchHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusForbidden {
|
||||
t.Errorf("expected 403, got %d (body: %s)", rr.Code, rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishBatchHandler_empty_messages_rejected(t *testing.T) {
|
||||
h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
|
||||
|
||||
body, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{}})
|
||||
req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "ns")
|
||||
rr := httptest.NewRecorder()
|
||||
h.PublishBatchHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for empty messages, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishBatchHandler_oversize_batch_rejected(t *testing.T) {
|
||||
h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
|
||||
|
||||
entries := make([]PublishBatchEntry, MaxPublishBatchSize+1)
|
||||
for i := range entries {
|
||||
entries[i] = PublishBatchEntry{Topic: "t", DataB64: "AA=="}
|
||||
}
|
||||
body, _ := json.Marshal(PublishBatchRequest{Messages: entries})
|
||||
req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "ns")
|
||||
rr := httptest.NewRecorder()
|
||||
h.PublishBatchHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for oversize batch, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishBatchHandler_invalid_base64_rejected(t *testing.T) {
|
||||
h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
|
||||
|
||||
body, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{
|
||||
{Topic: "good", DataB64: base64.StdEncoding.EncodeToString([]byte("ok"))},
|
||||
{Topic: "bad", DataB64: "!!!not-base64"},
|
||||
}})
|
||||
req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "ns")
|
||||
rr := httptest.NewRecorder()
|
||||
h.PublishBatchHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for invalid base64, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishBatchHandler_missing_topic_rejected(t *testing.T) {
|
||||
h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
|
||||
|
||||
body, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{
|
||||
{Topic: "", DataB64: base64.StdEncoding.EncodeToString([]byte("x"))},
|
||||
}})
|
||||
req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "ns")
|
||||
rr := httptest.NewRecorder()
|
||||
h.PublishBatchHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for missing topic, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishBatchHandler_happy_calls_PublishBatch(t *testing.T) {
|
||||
var (
|
||||
called int32
|
||||
gotMessages []client.TopicMessage
|
||||
mu sync.Mutex
|
||||
)
|
||||
mock := &mockPubSubClient{
|
||||
PublishBatchFunc: func(ctx context.Context, msgs []client.TopicMessage, opts client.PublishBatchOptions) error {
|
||||
atomic.AddInt32(&called, 1)
|
||||
mu.Lock()
|
||||
gotMessages = append(gotMessages, msgs...)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
h := newTestHandlers(&mockNetworkClient{pubsub: mock})
|
||||
|
||||
entries := []PublishBatchEntry{
|
||||
{Topic: "a", DataB64: base64.StdEncoding.EncodeToString([]byte("data-a"))},
|
||||
{Topic: "b", DataB64: base64.StdEncoding.EncodeToString([]byte("data-b"))},
|
||||
}
|
||||
body, _ := json.Marshal(PublishBatchRequest{Messages: entries})
|
||||
req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "test-ns")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.PublishBatchHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d (body: %s)", rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
// PublishBatch is invoked from a goroutine; give it a moment to run.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for atomic.LoadInt32(&called) == 0 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("PublishBatch was not called within 2s")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(gotMessages) != 2 {
|
||||
t.Fatalf("expected 2 messages forwarded, got %d", len(gotMessages))
|
||||
}
|
||||
if gotMessages[0].Topic != "a" || string(gotMessages[0].Data) != "data-a" {
|
||||
t.Errorf("unexpected first message: %+v", gotMessages[0])
|
||||
}
|
||||
if gotMessages[1].Topic != "b" || string(gotMessages[1].Data) != "data-b" {
|
||||
t.Errorf("unexpected second message: %+v", gotMessages[1])
|
||||
}
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/client"
|
||||
@ -12,6 +13,10 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// MaxPublishBatchSize is the maximum number of messages allowed in a single
|
||||
// /v1/pubsub/publish-batch request. Mirrors pubsub.MaxBatchSize.
|
||||
const MaxPublishBatchSize = pubsub.MaxBatchSize
|
||||
|
||||
// PublishHandler handles POST /v1/pubsub/publish {topic, data_base64}
|
||||
func (p *PubSubHandlers) PublishHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if p.client == nil {
|
||||
@ -39,9 +44,133 @@ func (p *PubSubHandlers) PublishHandler(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for local websocket subscribers FIRST and deliver directly
|
||||
p.deliverLocal(ns, body.Topic, data)
|
||||
|
||||
// Publish to libp2p asynchronously for cross-node delivery.
|
||||
go func() {
|
||||
publishCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns)
|
||||
if err := p.client.PubSub().Publish(ctx, body.Topic, data); err != nil {
|
||||
p.logger.ComponentWarn("gateway", "async libp2p publish failed",
|
||||
zap.String("topic", body.Topic),
|
||||
zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
|
||||
}
|
||||
|
||||
// PublishBatchRequest is the request body for POST /v1/pubsub/publish-batch.
type PublishBatchRequest struct {
	// Messages lists the entries to publish. PublishBatchHandler rejects an
	// empty list and any list longer than MaxPublishBatchSize.
	Messages []PublishBatchEntry `json:"messages"`
	// BestEffort is copied into the publish options handed to the pubsub
	// client. NOTE(review): its exact effect is defined by that client, not
	// in this file — confirm before relying on it.
	BestEffort bool `json:"best_effort,omitempty"`
}
|
||||
|
||||
// PublishBatchEntry is one message in a batch publish request.
type PublishBatchEntry struct {
	// Topic is the destination topic; PublishBatchHandler rejects an empty value.
	Topic string `json:"topic"`
	// DataB64 is the payload in standard base64. The decoded payload may not
	// exceed MaxPerMessageBytes.
	DataB64 string `json:"data_base64"`
}
|
||||
|
||||
// PublishBatchResponse is the response body for /v1/pubsub/publish-batch.
|
||||
//
|
||||
// libp2p delivery is asynchronous and not awaited here, mirroring the
|
||||
// single-publish handler's fire-and-forget contract. Per-topic failures
|
||||
// are not surfaced via this response — operators should consult logs /
|
||||
// metrics for delivery health.
|
||||
type PublishBatchResponse struct {
|
||||
Status string `json:"status"` // always "ok" — request was accepted
|
||||
}
|
||||
|
||||
// MaxPerMessageBytes caps an individual message payload inside a batch.
|
||||
// Mirrors the 1MB cap on /v1/pubsub/publish.
|
||||
const MaxPerMessageBytes = 1 << 20
|
||||
|
||||
// PublishBatchHandler handles POST /v1/pubsub/publish-batch.
|
||||
// Accepts up to MaxPublishBatchSize messages and publishes them in parallel,
|
||||
// preserving namespace isolation. Local subscribers receive messages
|
||||
// immediately; libp2p delivery is async.
|
||||
func (p *PubSubHandlers) PublishBatchHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if p.client == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
ns := resolveNamespaceFromRequest(r)
|
||||
if ns == "" {
|
||||
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||
return
|
||||
}
|
||||
|
||||
// Limit body size: MaxPublishBatchSize messages * ~1MB each = up to ~100MB.
|
||||
// Cap conservatively at 16MB to discourage huge payloads.
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 16<<20)
|
||||
|
||||
var body PublishBatchRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid body: expected {messages:[{topic,data_base64}]}")
|
||||
return
|
||||
}
|
||||
if len(body.Messages) == 0 {
|
||||
writeError(w, http.StatusBadRequest, "messages required")
|
||||
return
|
||||
}
|
||||
if len(body.Messages) > MaxPublishBatchSize {
|
||||
writeError(w, http.StatusBadRequest, "too many messages: max is 100 per batch")
|
||||
return
|
||||
}
|
||||
|
||||
// Decode all messages up-front so we can fail fast on bad input.
|
||||
decoded := make([]pubsub.TopicMessage, 0, len(body.Messages))
|
||||
for i, m := range body.Messages {
|
||||
if m.Topic == "" {
|
||||
writeError(w, http.StatusBadRequest, "message missing topic at index "+strconv.Itoa(i))
|
||||
return
|
||||
}
|
||||
data, err := base64.StdEncoding.DecodeString(m.DataB64)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid base64 data at index "+strconv.Itoa(i))
|
||||
return
|
||||
}
|
||||
if len(data) > MaxPerMessageBytes {
|
||||
writeError(w, http.StatusBadRequest, "message too large at index "+strconv.Itoa(i))
|
||||
return
|
||||
}
|
||||
decoded = append(decoded, pubsub.TopicMessage{Topic: m.Topic, Data: data})
|
||||
}
|
||||
|
||||
// Deliver locally + dispatch triggers per topic synchronously (fast in-process).
|
||||
for _, msg := range decoded {
|
||||
p.deliverLocal(ns, msg.Topic, msg.Data)
|
||||
}
|
||||
|
||||
// Async libp2p batch publish, similar to PublishHandler's approach.
|
||||
go func() {
|
||||
publishCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns)
|
||||
opts := pubsub.PublishBatchOptions{BestEffort: body.BestEffort}
|
||||
err := p.client.PubSub().PublishBatch(ctx, toClientMessages(decoded), clientOpts(opts))
|
||||
if err != nil {
|
||||
p.logger.ComponentWarn("gateway", "async libp2p batch publish failed",
|
||||
zap.Int("messages", len(decoded)),
|
||||
zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
writeJSON(w, http.StatusOK, PublishBatchResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// deliverLocal handles local-subscriber delivery and fires PubSub triggers.
|
||||
// It does NOT publish to libp2p — callers handle that themselves (single
|
||||
// or batched) so this helper stays focused on in-process fan-out.
|
||||
func (p *PubSubHandlers) deliverLocal(ns, topic string, data []byte) {
|
||||
p.mu.RLock()
|
||||
localSubs := p.getLocalSubscribers(body.Topic, ns)
|
||||
localSubs := p.getLocalSubscribers(topic, ns)
|
||||
p.mu.RUnlock()
|
||||
|
||||
localDeliveryCount := 0
|
||||
@ -50,48 +179,38 @@ func (p *PubSubHandlers) PublishHandler(w http.ResponseWriter, r *http.Request)
|
||||
select {
|
||||
case sub.msgChan <- data:
|
||||
localDeliveryCount++
|
||||
p.logger.ComponentDebug("gateway", "delivered to local subscriber",
|
||||
zap.String("topic", body.Topic))
|
||||
default:
|
||||
// Drop if buffer full
|
||||
p.logger.ComponentWarn("gateway", "local subscriber buffer full, dropping message",
|
||||
zap.String("topic", body.Topic))
|
||||
zap.String("topic", topic))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.logger.ComponentInfo("gateway", "pubsub publish: processing message",
|
||||
zap.String("topic", body.Topic),
|
||||
zap.String("topic", topic),
|
||||
zap.String("namespace", ns),
|
||||
zap.Int("data_len", len(data)),
|
||||
zap.Int("local_subscribers", len(localSubs)),
|
||||
zap.Int("local_delivered", localDeliveryCount))
|
||||
|
||||
// Fire PubSub triggers for serverless functions (non-blocking)
|
||||
// Fire PubSub triggers for serverless functions (non-blocking).
|
||||
if p.onPublish != nil {
|
||||
go p.onPublish(context.Background(), ns, body.Topic, data)
|
||||
go p.onPublish(context.Background(), ns, topic, data)
|
||||
}
|
||||
}
|
||||
|
||||
// Publish to libp2p asynchronously for cross-node delivery
|
||||
// This prevents blocking the HTTP response if libp2p network is slow
|
||||
go func() {
|
||||
publishCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
// toClientMessages converts pubsub.TopicMessage to client.TopicMessage for
|
||||
// passing through the PubSubClient interface.
|
||||
func toClientMessages(msgs []pubsub.TopicMessage) []client.TopicMessage {
|
||||
out := make([]client.TopicMessage, len(msgs))
|
||||
for i, m := range msgs {
|
||||
out[i] = client.TopicMessage{Topic: m.Topic, Data: m.Data}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns)
|
||||
if err := p.client.PubSub().Publish(ctx, body.Topic, data); err != nil {
|
||||
p.logger.ComponentWarn("gateway", "async libp2p publish failed",
|
||||
zap.String("topic", body.Topic),
|
||||
zap.Error(err))
|
||||
} else {
|
||||
p.logger.ComponentDebug("gateway", "async libp2p publish succeeded",
|
||||
zap.String("topic", body.Topic))
|
||||
}
|
||||
}()
|
||||
|
||||
// Return immediately after local delivery
|
||||
// Local WebSocket subscribers already received the message
|
||||
writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
|
||||
func clientOpts(o pubsub.PublishBatchOptions) client.PublishBatchOptions {
|
||||
return client.PublishBatchOptions{BestEffort: o.BestEffort, MaxConcurrency: o.MaxConcurrency}
|
||||
}
|
||||
|
||||
// TopicsHandler lists topics within the caller's namespace
|
||||
|
||||
291
core/pkg/gateway/handlers/push/handlers.go
Normal file
291
core/pkg/gateway/handlers/push/handlers.go
Normal file
@ -0,0 +1,291 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/push"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// validProviders is the allowlist for the `provider` field on RegisterDevice.
|
||||
// Keep in sync with what the dispatcher actually has registered at startup.
|
||||
var validProviders = map[string]struct{}{
|
||||
"ntfy": {},
|
||||
"expo": {},
|
||||
"apns": {}, // future — accepted at registration so apps can pre-flight
|
||||
}
|
||||
|
||||
// MaxTokenBytes caps the device-token length to prevent abuse.
|
||||
// Real ntfy topic paths and Expo tokens are well under this.
|
||||
const MaxTokenBytes = 512
|
||||
|
||||
// RegisterDeviceHandler handles POST /v1/push/devices.
|
||||
//
|
||||
// The caller must be authenticated; their JWT subject (Sub) is used as the
|
||||
// user_id. API-key callers are allowed only if the body explicitly carries
|
||||
// a user_id — currently rejected to keep the surface small.
|
||||
func (h *Handlers) RegisterDeviceHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if h.store == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "push: device store not configured")
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
ns := resolveNamespace(r)
|
||||
if ns == "" {
|
||||
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||
return
|
||||
}
|
||||
userID := resolveCallerUserID(r)
|
||||
if userID == "" {
|
||||
// We require a JWT-authenticated user to bind the device to.
|
||||
// API-key-only callers can't register devices on behalf of users.
|
||||
writeError(w, http.StatusUnauthorized, "user authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 4096)
|
||||
var body RegisterDeviceRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid body")
|
||||
return
|
||||
}
|
||||
body.DeviceID = strings.TrimSpace(body.DeviceID)
|
||||
body.Provider = strings.TrimSpace(body.Provider)
|
||||
body.Token = strings.TrimSpace(body.Token)
|
||||
|
||||
if body.DeviceID == "" {
|
||||
writeError(w, http.StatusBadRequest, "device_id required")
|
||||
return
|
||||
}
|
||||
if _, ok := validProviders[body.Provider]; !ok {
|
||||
writeError(w, http.StatusBadRequest, "unknown provider: "+body.Provider)
|
||||
return
|
||||
}
|
||||
if body.Token == "" {
|
||||
writeError(w, http.StatusBadRequest, "token required")
|
||||
return
|
||||
}
|
||||
if len(body.Token) > MaxTokenBytes {
|
||||
writeError(w, http.StatusBadRequest, "token too long")
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
dev := push.PushDevice{
|
||||
Namespace: ns,
|
||||
UserID: userID,
|
||||
DeviceID: body.DeviceID,
|
||||
Provider: body.Provider,
|
||||
Token: body.Token,
|
||||
Platform: body.Platform,
|
||||
AppVer: body.AppVersion,
|
||||
LastSeen: now,
|
||||
}
|
||||
if err := h.store.Upsert(boundCtx(r), dev); err != nil {
|
||||
h.logger.ComponentWarn("push", "device upsert failed",
|
||||
zap.String("namespace", ns),
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err))
|
||||
writeError(w, http.StatusInternalServerError, "registration failed")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, RegisterDeviceResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// ListDevicesHandler handles GET /v1/push/devices.
|
||||
//
|
||||
// Returns the caller's own devices; tokens are NEVER included in the
|
||||
// response. Other namespaces / other users are inaccessible.
|
||||
func (h *Handlers) ListDevicesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if h.store == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "push: device store not configured")
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
ns := resolveNamespace(r)
|
||||
if ns == "" {
|
||||
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||
return
|
||||
}
|
||||
userID := resolveCallerUserID(r)
|
||||
if userID == "" {
|
||||
writeError(w, http.StatusUnauthorized, "user authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
devs, err := h.store.ListForUser(boundCtx(r), ns, userID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "list failed")
|
||||
return
|
||||
}
|
||||
views := make([]PushDeviceView, len(devs))
|
||||
for i, d := range devs {
|
||||
views[i] = PushDeviceView{
|
||||
ID: d.ID,
|
||||
DeviceID: d.DeviceID,
|
||||
Provider: d.Provider,
|
||||
Platform: d.Platform,
|
||||
AppVersion: d.AppVer,
|
||||
CreatedAt: d.CreatedAt,
|
||||
UpdatedAt: d.UpdatedAt,
|
||||
LastSeen: d.LastSeen,
|
||||
}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{"devices": views})
|
||||
}
|
||||
|
||||
// DeleteDeviceHandler handles DELETE /v1/push/devices/{id}.
|
||||
//
|
||||
// `{id}` is the database row ID returned at registration / by ListDevices.
|
||||
// Only devices belonging to the caller (matched by namespace + user_id +
|
||||
// the device ID lookup) can be deleted.
|
||||
func (h *Handlers) DeleteDeviceHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if h.store == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "push: device store not configured")
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodDelete {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
ns := resolveNamespace(r)
|
||||
if ns == "" {
|
||||
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||
return
|
||||
}
|
||||
userID := resolveCallerUserID(r)
|
||||
if userID == "" {
|
||||
writeError(w, http.StatusUnauthorized, "user authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
id := extractIDFromPath(r.URL.Path, "/v1/push/devices/")
|
||||
if id == "" {
|
||||
writeError(w, http.StatusBadRequest, "device id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
// Authorization check: confirm the device belongs to the caller.
|
||||
devs, err := h.store.ListForUser(boundCtx(r), ns, userID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "lookup failed")
|
||||
return
|
||||
}
|
||||
owns := false
|
||||
for _, d := range devs {
|
||||
if d.ID == id {
|
||||
owns = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !owns {
|
||||
// 404, not 403 — don't leak whether the ID exists in another scope.
|
||||
writeError(w, http.StatusNotFound, "not found")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.Delete(boundCtx(r), ns, id); err != nil {
|
||||
h.logger.ComponentWarn("push", "device delete failed",
|
||||
zap.String("namespace", ns),
|
||||
zap.String("device_row_id", id),
|
||||
zap.Error(err))
|
||||
writeError(w, http.StatusInternalServerError, "delete failed")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
// SendHandler handles POST /v1/push/send.
|
||||
//
|
||||
// SECURITY: this endpoint sends arbitrary push messages to any user_id
|
||||
// in the caller's namespace. It MUST be gated to a small set of trusted
|
||||
// callers — typically only the namespace's own serverless functions
|
||||
// (which can send via the WASM `push_send` hostfunc directly without
|
||||
// going through HTTP) and the namespace operator.
|
||||
//
|
||||
// The current implementation accepts any JWT-authenticated caller within
|
||||
// the namespace. **Add an explicit allow-list or admin-scope check before
|
||||
// exposing this in production.** The WASM hostfunc bypasses this issue
|
||||
// because trigger registration already gates which functions exist.
|
||||
func (h *Handlers) SendHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if h.dispatcher == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "push: dispatcher not configured")
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
ns := resolveNamespace(r)
|
||||
if ns == "" {
|
||||
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||
return
|
||||
}
|
||||
if resolveCallerUserID(r) == "" {
|
||||
writeError(w, http.StatusUnauthorized, "user authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 64*1024) // generous for Data payloads
|
||||
var body SendRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid body")
|
||||
return
|
||||
}
|
||||
body.UserID = strings.TrimSpace(body.UserID)
|
||||
if body.UserID == "" {
|
||||
writeError(w, http.StatusBadRequest, "user_id required")
|
||||
return
|
||||
}
|
||||
|
||||
msg := push.PushMessage{
|
||||
Title: body.Title,
|
||||
Body: body.Body,
|
||||
Channel: body.Channel,
|
||||
Priority: pickPriority(body.Priority),
|
||||
Badge: body.Badge,
|
||||
Sound: body.Sound,
|
||||
Data: body.Data,
|
||||
}
|
||||
if err := h.dispatcher.SendToUser(boundCtx(r), ns, body.UserID, msg); err != nil {
|
||||
// Treat as non-fatal: some devices may have failed but others may
|
||||
// have succeeded. Surface as 502 to signal partial trouble; logs
|
||||
// have the per-device detail.
|
||||
h.logger.ComponentWarn("push", "send to user partially failed",
|
||||
zap.String("namespace", ns),
|
||||
zap.String("user_id", body.UserID),
|
||||
zap.Error(err))
|
||||
writeError(w, http.StatusBadGateway, "one or more devices failed")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, SendResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// extractIDFromPath returns the first path segment that follows `prefix` in
// urlPath, or the empty string when urlPath does not start with `prefix`.
//
// Everything from the first '?', '#' or '/' after the prefix is discarded, so
// both stray query/fragment text and any further path segments are dropped.
// Used because the gateway routes with the standard `net/http` mux, which
// does not extract path parameters.
func extractIDFromPath(urlPath, prefix string) string {
	if !strings.HasPrefix(urlPath, prefix) {
		return ""
	}
	tail := strings.TrimPrefix(urlPath, prefix)
	end := strings.IndexAny(tail, "?#/")
	if end < 0 {
		return tail
	}
	return tail[:end]
}
|
||||
330
core/pkg/gateway/handlers/push/handlers_test.go
Normal file
330
core/pkg/gateway/handlers/push/handlers_test.go
Normal file
@ -0,0 +1,330 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
authsvc "github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
"github.com/DeBrosOfficial/network/pkg/push"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// fakeStore is an in-memory PushDeviceStore for tests.
type fakeStore struct {
	// devices holds the rows added by Upsert and removed by Delete.
	devices []push.PushDevice
	// upsertFn, when non-nil, replaces Upsert's default behavior.
	upsertFn func(push.PushDevice) error
	// deleteFn, when non-nil, replaces Delete's default behavior.
	deleteFn func(ns, id string) error
	// listErr, when non-nil, is returned by ListForUser instead of a result.
	listErr error
}
|
||||
|
||||
func (s *fakeStore) Upsert(ctx context.Context, dev push.PushDevice) error {
|
||||
if s.upsertFn != nil {
|
||||
return s.upsertFn(dev)
|
||||
}
|
||||
if dev.ID == "" {
|
||||
dev.ID = "row-" + dev.DeviceID
|
||||
}
|
||||
s.devices = append(s.devices, dev)
|
||||
return nil
|
||||
}
|
||||
func (s *fakeStore) Delete(ctx context.Context, ns, id string) error {
|
||||
if s.deleteFn != nil {
|
||||
return s.deleteFn(ns, id)
|
||||
}
|
||||
for i, d := range s.devices {
|
||||
if d.ID == id && d.Namespace == ns {
|
||||
s.devices = append(s.devices[:i], s.devices[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New("not found")
|
||||
}
|
||||
func (s *fakeStore) ListForUser(ctx context.Context, ns, userID string) ([]push.PushDevice, error) {
|
||||
if s.listErr != nil {
|
||||
return nil, s.listErr
|
||||
}
|
||||
out := []push.PushDevice{}
|
||||
for _, d := range s.devices {
|
||||
if d.Namespace == ns && d.UserID == userID {
|
||||
out = append(out, d)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// withAuth populates the namespace + JWT claims (caller user ID).
|
||||
func withAuth(r *http.Request, namespace, userID string) *http.Request {
|
||||
ctx := r.Context()
|
||||
if namespace != "" {
|
||||
ctx = context.WithValue(ctx, ctxkeys.NamespaceOverride, namespace)
|
||||
}
|
||||
if userID != "" {
|
||||
ctx = context.WithValue(ctx, ctxkeys.JWT, &authsvc.JWTClaims{Sub: userID, Namespace: namespace})
|
||||
}
|
||||
return r.WithContext(ctx)
|
||||
}
|
||||
|
||||
func newHandlers(store push.PushDeviceStore, dispatcher *push.PushDispatcher) *Handlers {
|
||||
logger := &logging.ColoredLogger{Logger: zap.NewNop()}
|
||||
return NewHandlers(dispatcher, store, logger)
|
||||
}
|
||||
|
||||
// --- RegisterDeviceHandler ---
|
||||
|
||||
func TestRegister_happy_path(t *testing.T) {
|
||||
store := &fakeStore{}
|
||||
h := newHandlers(store, nil)
|
||||
|
||||
body, _ := json.Marshal(RegisterDeviceRequest{
|
||||
DeviceID: "iphone-abc",
|
||||
Provider: "ntfy",
|
||||
Token: "ns/myapp/user-1",
|
||||
Platform: "ios",
|
||||
})
|
||||
req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(body)), "myapp", "user-1")
|
||||
rr := httptest.NewRecorder()
|
||||
h.RegisterDeviceHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d (body: %s)", rr.Code, rr.Body.String())
|
||||
}
|
||||
if len(store.devices) != 1 {
|
||||
t.Fatalf("expected 1 device stored, got %d", len(store.devices))
|
||||
}
|
||||
d := store.devices[0]
|
||||
if d.Namespace != "myapp" || d.UserID != "user-1" || d.Token != "ns/myapp/user-1" {
|
||||
t.Errorf("unexpected device: %+v", d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_unauthenticated_rejected(t *testing.T) {
|
||||
h := newHandlers(&fakeStore{}, nil)
|
||||
body, _ := json.Marshal(RegisterDeviceRequest{DeviceID: "x", Provider: "ntfy", Token: "t"})
|
||||
|
||||
// No JWT in context.
|
||||
req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(body)), "ns", "")
|
||||
rr := httptest.NewRecorder()
|
||||
h.RegisterDeviceHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_unknown_provider_rejected(t *testing.T) {
|
||||
h := newHandlers(&fakeStore{}, nil)
|
||||
body, _ := json.Marshal(RegisterDeviceRequest{DeviceID: "x", Provider: "weirdmail", Token: "t"})
|
||||
req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(body)), "ns", "u")
|
||||
rr := httptest.NewRecorder()
|
||||
h.RegisterDeviceHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_oversize_token_rejected(t *testing.T) {
|
||||
h := newHandlers(&fakeStore{}, nil)
|
||||
huge := make([]byte, MaxTokenBytes+1)
|
||||
for i := range huge {
|
||||
huge[i] = 'a'
|
||||
}
|
||||
body, _ := json.Marshal(RegisterDeviceRequest{DeviceID: "x", Provider: "ntfy", Token: string(huge)})
|
||||
req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(body)), "ns", "u")
|
||||
rr := httptest.NewRecorder()
|
||||
h.RegisterDeviceHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_no_store_returns_503(t *testing.T) {
|
||||
h := newHandlers(nil, nil)
|
||||
req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader([]byte(`{}`))), "ns", "u")
|
||||
rr := httptest.NewRecorder()
|
||||
h.RegisterDeviceHandler(rr, req)
|
||||
if rr.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("expected 503, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ListDevicesHandler ---
|
||||
|
||||
func TestList_returns_only_callers_devices_without_tokens(t *testing.T) {
|
||||
store := &fakeStore{
|
||||
devices: []push.PushDevice{
|
||||
{ID: "1", Namespace: "myapp", UserID: "u1", DeviceID: "d1", Provider: "ntfy", Token: "secret-token-1"},
|
||||
{ID: "2", Namespace: "myapp", UserID: "u1", DeviceID: "d2", Provider: "expo", Token: "secret-token-2"},
|
||||
{ID: "3", Namespace: "myapp", UserID: "u2", DeviceID: "d3", Provider: "ntfy", Token: "secret-token-3"},
|
||||
{ID: "4", Namespace: "other", UserID: "u1", DeviceID: "d4", Provider: "ntfy", Token: "secret-token-4"},
|
||||
},
|
||||
}
|
||||
h := newHandlers(store, nil)
|
||||
|
||||
req := withAuth(httptest.NewRequest(http.MethodGet, "/v1/push/devices", nil), "myapp", "u1")
|
||||
rr := httptest.NewRecorder()
|
||||
h.ListDevicesHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
var resp struct {
|
||||
Devices []PushDeviceView `json:"devices"`
|
||||
}
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if len(resp.Devices) != 2 {
|
||||
t.Fatalf("expected 2 devices, got %d", len(resp.Devices))
|
||||
}
|
||||
// Tokens must NOT appear in response — they're not even in the struct.
|
||||
if bytes.Contains(rr.Body.Bytes(), []byte("secret-token")) {
|
||||
t.Errorf("response leaked a token: %s", rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// --- DeleteDeviceHandler ---
|
||||
|
||||
func TestDelete_owns_device_succeeds(t *testing.T) {
|
||||
store := &fakeStore{
|
||||
devices: []push.PushDevice{
|
||||
{ID: "row-d1", Namespace: "myapp", UserID: "u1", DeviceID: "d1"},
|
||||
},
|
||||
}
|
||||
h := newHandlers(store, nil)
|
||||
|
||||
req := withAuth(httptest.NewRequest(http.MethodDelete, "/v1/push/devices/row-d1", nil), "myapp", "u1")
|
||||
rr := httptest.NewRecorder()
|
||||
h.DeleteDeviceHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d (body: %s)", rr.Code, rr.Body.String())
|
||||
}
|
||||
if len(store.devices) != 0 {
|
||||
t.Errorf("expected device removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelete_other_users_device_returns_404(t *testing.T) {
|
||||
store := &fakeStore{
|
||||
devices: []push.PushDevice{
|
||||
{ID: "row-d1", Namespace: "myapp", UserID: "other-user", DeviceID: "d1"},
|
||||
},
|
||||
}
|
||||
h := newHandlers(store, nil)
|
||||
|
||||
req := withAuth(httptest.NewRequest(http.MethodDelete, "/v1/push/devices/row-d1", nil), "myapp", "u1")
|
||||
rr := httptest.NewRecorder()
|
||||
h.DeleteDeviceHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d", rr.Code)
|
||||
}
|
||||
if len(store.devices) != 1 {
|
||||
t.Errorf("expected device NOT removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelete_missing_id_returns_400(t *testing.T) {
|
||||
h := newHandlers(&fakeStore{}, nil)
|
||||
req := withAuth(httptest.NewRequest(http.MethodDelete, "/v1/push/devices/", nil), "myapp", "u1")
|
||||
rr := httptest.NewRecorder()
|
||||
h.DeleteDeviceHandler(rr, req)
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- SendHandler ---
|
||||
|
||||
func TestSend_dispatcher_called_for_user(t *testing.T) {
|
||||
var sent int32
|
||||
dispatcher := push.New(&fakeStore{
|
||||
devices: []push.PushDevice{
|
||||
{ID: "row-1", Namespace: "myapp", UserID: "target-user", Provider: "fake", Token: "tok"},
|
||||
},
|
||||
}, zap.NewNop())
|
||||
dispatcher.Register(&fakePushProvider{
|
||||
name: "fake",
|
||||
fn: func(ctx context.Context, msg push.PushMessage) error { atomic.AddInt32(&sent, 1); return nil },
|
||||
})
|
||||
|
||||
h := newHandlers(&fakeStore{}, dispatcher)
|
||||
|
||||
body, _ := json.Marshal(SendRequest{
|
||||
UserID: "target-user", Title: "hi", Body: "world",
|
||||
})
|
||||
req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/send", bytes.NewReader(body)), "myapp", "u1")
|
||||
rr := httptest.NewRecorder()
|
||||
h.SendHandler(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d (body: %s)", rr.Code, rr.Body.String())
|
||||
}
|
||||
if atomic.LoadInt32(&sent) != 1 {
|
||||
t.Errorf("expected provider called once, got %d", sent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_no_dispatcher_returns_503(t *testing.T) {
|
||||
h := newHandlers(&fakeStore{}, nil)
|
||||
req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/send", bytes.NewReader([]byte(`{"user_id":"u"}`))), "myapp", "u1")
|
||||
rr := httptest.NewRecorder()
|
||||
h.SendHandler(rr, req)
|
||||
if rr.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("expected 503, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_missing_user_id_returns_400(t *testing.T) {
|
||||
dispatcher := push.New(&fakeStore{}, zap.NewNop())
|
||||
h := newHandlers(&fakeStore{}, dispatcher)
|
||||
|
||||
body, _ := json.Marshal(SendRequest{})
|
||||
req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/send", bytes.NewReader(body)), "myapp", "u1")
|
||||
rr := httptest.NewRecorder()
|
||||
h.SendHandler(rr, req)
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
type fakePushProvider struct {
|
||||
name string
|
||||
fn func(ctx context.Context, msg push.PushMessage) error
|
||||
}
|
||||
|
||||
func (p *fakePushProvider) Name() string { return p.name }
|
||||
func (p *fakePushProvider) Send(ctx context.Context, msg push.PushMessage) error {
|
||||
if p.fn != nil {
|
||||
return p.fn(ctx, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestExtractIDFromPath(t *testing.T) {
|
||||
cases := []struct {
|
||||
path, prefix, want string
|
||||
}{
|
||||
{"/v1/push/devices/abc", "/v1/push/devices/", "abc"},
|
||||
{"/v1/push/devices/abc?x=1", "/v1/push/devices/", "abc"},
|
||||
{"/v1/push/devices/", "/v1/push/devices/", ""},
|
||||
{"/v1/other/abc", "/v1/push/devices/", ""},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := extractIDFromPath(c.path, c.prefix); got != c.want {
|
||||
t.Errorf("extractIDFromPath(%q, %q) = %q, want %q", c.path, c.prefix, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
150
core/pkg/gateway/handlers/push/types.go
Normal file
150
core/pkg/gateway/handlers/push/types.go
Normal file
@ -0,0 +1,150 @@
|
||||
// Package push provides HTTP handlers for managing push-notification
|
||||
// device registrations and sending pushes.
|
||||
//
|
||||
// Endpoints:
|
||||
//
|
||||
// GET /v1/push/devices — list caller's registered devices (tokens omitted)
|
||||
// POST /v1/push/devices — register / update a device
|
||||
// DELETE /v1/push/devices/{id} — unregister a device
|
||||
// POST /v1/push/send — send a push to a user (admin/internal scope)
|
||||
//
|
||||
// Device tokens are stored AES-256-GCM-encrypted in RQLite via the
|
||||
// pkg/push.RqliteDeviceStore. Tokens are NEVER returned by any endpoint —
|
||||
// the GET endpoint omits the token field for safety.
|
||||
package push
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
"github.com/DeBrosOfficial/network/pkg/push"
|
||||
)
|
||||
|
||||
// Handlers serves the /v1/push/* HTTP endpoints. Construct via NewHandlers;
|
||||
// it's safe for concurrent use.
|
||||
type Handlers struct {
|
||||
dispatcher *push.PushDispatcher
|
||||
store push.PushDeviceStore
|
||||
logger *logging.ColoredLogger
|
||||
}
|
||||
|
||||
// NewHandlers constructs a Handlers. Either argument may be nil — in which
|
||||
// case the corresponding endpoints return 503 Service Unavailable.
|
||||
func NewHandlers(dispatcher *push.PushDispatcher, store push.PushDeviceStore, logger *logging.ColoredLogger) *Handlers {
|
||||
return &Handlers{
|
||||
dispatcher: dispatcher,
|
||||
store: store,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterDeviceRequest is the JSON body accepted by POST /v1/push/devices.
//
// Re-posting with the same device_id (within the same namespace and user)
// updates the stored token rather than creating a second registration.
type RegisterDeviceRequest struct {
	// DeviceID is an app-supplied stable identifier, e.g. the OS-assigned
	// device UUID. Together with (namespace, user_id) it identifies the
	// registration.
	DeviceID string `json:"device_id"`

	// Provider selects the delivery backend: "ntfy" | "expo" | "apns".
	Provider string `json:"provider"`

	// Token is provider-specific:
	//   - ntfy: the topic path the device subscribes to (e.g. "ns/myapp/user-1")
	//   - expo: an ExponentPushToken[...]
	//   - apns: a hex APNs device token (future)
	Token string `json:"token"`

	// Platform is optional: "ios" | "android" | "web".
	Platform string `json:"platform,omitempty"`

	// AppVersion is optional.
	AppVersion string `json:"app_version,omitempty"`
}
|
||||
|
||||
// RegisterDeviceResponse is the JSON body returned by POST /v1/push/devices.
type RegisterDeviceResponse struct {
	// Status is a short status string, serialized as "status".
	Status string `json:"status"`
}
|
||||
|
||||
// PushDeviceView is what GET /v1/push/devices returns for each device. It
// deliberately has no token field, so a token cannot be serialized through
// this type.
type PushDeviceView struct {
	ID         string `json:"id"`
	DeviceID   string `json:"device_id"`
	Provider   string `json:"provider"`
	Platform   string `json:"platform,omitempty"`
	AppVersion string `json:"app_version,omitempty"`

	// Timestamps are plain integers; LastSeen is omitted when zero.
	CreatedAt int64 `json:"created_at"`
	UpdatedAt int64 `json:"updated_at"`
	LastSeen  int64 `json:"last_seen,omitempty"`
}
|
||||
|
||||
// SendRequest is the JSON body accepted by POST /v1/push/send.
//
// The push is fanned out to every device registered for UserID in the
// caller's namespace. Auth scope: see SendHandler — currently the caller
// acts on behalf of their own namespace; finer per-user authorization is
// the app's responsibility.
type SendRequest struct {
	// UserID is the recipient whose devices receive the push.
	UserID string `json:"user_id"`

	// Title and Body are the notification text.
	Title string `json:"title"`
	Body  string `json:"body"`

	// Optional delivery hints.
	Channel  string `json:"channel,omitempty"`
	Priority string `json:"priority,omitempty"` // "high" | "normal" | "" (default)
	Badge    int    `json:"badge,omitempty"`
	Sound    string `json:"sound,omitempty"`

	// Data is an optional free-form payload.
	Data map[string]interface{} `json:"data,omitempty"`
}
|
||||
|
||||
// SendResponse is the JSON body returned by POST /v1/push/send.
type SendResponse struct {
	// Status is a short status string, serialized as "status".
	Status string `json:"status"`
}
|
||||
|
||||
// resolveNamespace pulls the namespace set by auth middleware out of context.
|
||||
func resolveNamespace(r *http.Request) string {
|
||||
if v := r.Context().Value(ctxkeys.NamespaceOverride); v != nil {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// resolveCallerUserID extracts the JWT subject (typically the wallet) of
|
||||
// the caller, or empty if the request was authenticated by API key only.
|
||||
func resolveCallerUserID(r *http.Request) string {
|
||||
if v := r.Context().Value(ctxkeys.JWT); v != nil {
|
||||
if claims, ok := v.(*auth.JWTClaims); ok && claims != nil {
|
||||
return claims.Sub
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// writeError sends a JSON error response of the form {"error": message}
// with the given HTTP status code.
func writeError(w http.ResponseWriter, code int, message string) {
	payload := map[string]string{"error": message}

	// Headers must be set before WriteHeader commits the response.
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(code)

	// The status line is already sent, so an encode failure cannot be
	// reported to the client; it is deliberately ignored.
	_ = json.NewEncoder(w).Encode(payload)
}
|
||||
|
||||
// writeJSON sends v encoded as JSON with the given HTTP status code.
func writeJSON(w http.ResponseWriter, code int, v interface{}) {
	// Headers must be set before WriteHeader commits the response.
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(code)

	// The status line is already sent, so an encode failure cannot be
	// reported to the client; it is deliberately ignored.
	enc := json.NewEncoder(w)
	_ = enc.Encode(v)
}
|
||||
|
||||
// pickPriority maps the wire-format priority string to the typed enum.
|
||||
func pickPriority(s string) push.PushPriority {
|
||||
switch s {
|
||||
case "high":
|
||||
return push.PriorityHigh
|
||||
case "normal":
|
||||
return push.PriorityNormal
|
||||
default:
|
||||
return push.PriorityNormal
|
||||
}
|
||||
}
|
||||
|
||||
// boundCtx returns the request's own context unchanged. It exists as a
// seam so that request-scoped wrapping (rate-limit context etc.) can be
// added later in one place.
func boundCtx(r *http.Request) context.Context {
	reqCtx := r.Context()
	return reqCtx
}
|
||||
@ -42,6 +42,11 @@ func (h *WebRTCHandlers) CredentialsHandler(w http.ResponseWriter, r *http.Reque
|
||||
fmt.Sprintf("turns:%s:5349", h.turnDomain),
|
||||
)
|
||||
}
|
||||
// Stealth: TURNS via the SNI router on :443. Looks like ordinary HTTPS
|
||||
// to a passive observer / DPI; usable in restricted regions.
|
||||
if h.stealthCDNDomain != "" {
|
||||
uris = append(uris, fmt.Sprintf("turns:%s:443", h.stealthCDNDomain))
|
||||
}
|
||||
|
||||
h.logger.ComponentInfo(logging.ComponentGeneral, "Issued TURN credentials",
|
||||
zap.String("namespace", ns),
|
||||
|
||||
@ -17,10 +17,21 @@ type WebRTCHandlers struct {
|
||||
turnDomain string // TURN server domain for building URIs
|
||||
turnSecret string // HMAC-SHA1 shared secret for TURN credential generation
|
||||
|
||||
// stealthCDNDomain, when non-empty, causes CredentialsHandler to also
|
||||
// advertise turns:<stealthCDNDomain>:443 — the stealth TURN URI served
|
||||
// via the in-house SNI router. See pkg/sniproxy.
|
||||
stealthCDNDomain string
|
||||
|
||||
// proxyWebSocket is injected from the gateway to reuse its WebSocket proxy logic
|
||||
proxyWebSocket func(w http.ResponseWriter, r *http.Request, targetHost string) bool
|
||||
}
|
||||
|
||||
// SetStealthCDNDomain enables the stealth TURN URI in CredentialsHandler.
|
||||
// Pass empty string to disable. Safe to call before serving begins.
|
||||
func (h *WebRTCHandlers) SetStealthCDNDomain(domain string) {
|
||||
h.stealthCDNDomain = domain
|
||||
}
|
||||
|
||||
// NewWebRTCHandlers creates a new WebRTCHandlers instance.
|
||||
func NewWebRTCHandlers(
|
||||
logger *logging.ColoredLogger,
|
||||
|
||||
@ -12,6 +12,23 @@ import (
|
||||
// It closes the serverless engine, network client, database connections,
|
||||
// Olric cache client, and IPFS client in sequence.
|
||||
func (g *Gateway) Close() {
|
||||
// Flush PubSub aggregator buffers before tearing down the engine.
|
||||
// Pending events are dispatched via the invoker which still needs the
|
||||
// engine to be alive, so this MUST happen before the engine close.
|
||||
// Aggregator state is local to this node — events not flushed here are
|
||||
// lost (intended trade-off for high-frequency lossy streams).
|
||||
if g.pubsubDispatcher != nil {
|
||||
if agg := g.pubsubDispatcher.Aggregator(); agg != nil {
|
||||
// 5s budget — same as the engine close timeout below.
|
||||
// In-flight flushes call back into the invoker which still
|
||||
// needs the engine to be alive.
|
||||
if !agg.Shutdown(5 * time.Second) {
|
||||
g.logger.ComponentWarn(logging.ComponentGeneral,
|
||||
"PubSub aggregator shutdown timed out; some buffered events may be lost")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close serverless engine first
|
||||
if g.serverlessEngine != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
@ -643,6 +643,9 @@ func requiresNamespaceOwnership(p string) bool {
|
||||
if strings.HasPrefix(p, "/v1/webrtc/") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(p, "/v1/push/") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@ -119,10 +119,30 @@ func (g *Gateway) Routes() http.Handler {
|
||||
if g.pubsubHandlers != nil {
|
||||
mux.HandleFunc("/v1/pubsub/ws", g.pubsubHandlers.WebsocketHandler)
|
||||
mux.HandleFunc("/v1/pubsub/publish", g.pubsubHandlers.PublishHandler)
|
||||
mux.HandleFunc("/v1/pubsub/publish-batch", g.pubsubHandlers.PublishBatchHandler)
|
||||
mux.HandleFunc("/v1/pubsub/topics", g.pubsubHandlers.TopicsHandler)
|
||||
mux.HandleFunc("/v1/pubsub/presence", g.pubsubHandlers.PresenceHandler)
|
||||
}
|
||||
|
||||
// push notifications
|
||||
if g.pushHandlers != nil {
|
||||
// GET + POST share the path; the handler dispatches by method.
|
||||
mux.HandleFunc("/v1/push/devices", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
g.pushHandlers.ListDevicesHandler(w, r)
|
||||
case http.MethodPost:
|
||||
g.pushHandlers.RegisterDeviceHandler(w, r)
|
||||
default:
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
// DELETE /v1/push/devices/{id} — uses path-prefix routing because
|
||||
// net/http mux doesn't extract path params; the handler parses {id}.
|
||||
mux.HandleFunc("/v1/push/devices/", g.pushHandlers.DeleteDeviceHandler)
|
||||
mux.HandleFunc("/v1/push/send", g.pushHandlers.SendHandler)
|
||||
}
|
||||
|
||||
// operator node management (wallet JWT auth via middleware)
|
||||
if g.operatorHandler != nil {
|
||||
mux.HandleFunc("/v1/operator/invite", g.operatorHandler.HandleInvite)
|
||||
|
||||
@ -57,6 +57,7 @@ const (
|
||||
ComponentGateway Component = "GATEWAY"
|
||||
ComponentSFU Component = "SFU"
|
||||
ComponentTURN Component = "TURN"
|
||||
ComponentSNI Component = "SNI"
|
||||
)
|
||||
|
||||
// getComponentColor returns the color for a specific component
|
||||
|
||||
@ -29,6 +29,18 @@ func (a *ClientAdapter) Publish(ctx context.Context, topic string, data []byte)
|
||||
return a.manager.Publish(ctx, topic, data)
|
||||
}
|
||||
|
||||
// PublishBatch publishes multiple messages in parallel.
|
||||
// See Manager.PublishBatch for semantics.
|
||||
func (a *ClientAdapter) PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error {
|
||||
return a.manager.PublishBatch(ctx, msgs, opts)
|
||||
}
|
||||
|
||||
// PublishSame sends the same payload to every topic in parallel.
|
||||
// See Manager.PublishSame for semantics.
|
||||
func (a *ClientAdapter) PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error {
|
||||
return a.manager.PublishSame(ctx, topics, data, opts)
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes from a topic
|
||||
func (a *ClientAdapter) Unsubscribe(ctx context.Context, topic string) error {
|
||||
return a.manager.Unsubscribe(ctx, topic)
|
||||
|
||||
@ -3,9 +3,56 @@ package pubsub
|
||||
import (
	"context"
	"fmt"
	"sort"
	"strings"
	"sync"
	"time"

	"golang.org/x/sync/errgroup"
)
|
||||
|
||||
// defaultBatchConcurrency is the default cap on in-flight publishes within a single batch.
|
||||
const defaultBatchConcurrency = 32
|
||||
|
||||
// MaxBatchSize is the maximum number of messages allowed per PublishBatch call.
|
||||
// The HTTP handler enforces this; the Manager itself does not, so internal callers
|
||||
// (e.g. the SDK) can pass larger batches if they accept the responsibility.
|
||||
const MaxBatchSize = 100
|
||||
|
||||
// TopicMessage is a single entry of a batch publish: one payload destined
// for one topic.
type TopicMessage struct {
	// Topic is the destination topic name.
	Topic string
	// Data is the raw payload to publish.
	Data []byte
}
|
||||
|
||||
// PublishBatchOptions tunes how PublishBatch behaves. The zero value means
// fail-fast with the default concurrency cap.
type PublishBatchOptions struct {
	// BestEffort switches from fail-fast to attempt-everything. When true,
	// every publish is tried even if some fail, and the failures are
	// reported together as a *BatchError keyed by topic. When false (the
	// default), the first failure cancels remaining in-flight publishes and
	// that error is returned.
	BestEffort bool

	// MaxConcurrency is the upper bound on simultaneously in-flight
	// publishes for this batch. 0 means use defaultBatchConcurrency.
	MaxConcurrency int
}
|
||||
|
||||
// BatchError aggregates per-topic errors returned when PublishBatch is called
// with BestEffort=true and at least one publish failed.
type BatchError struct {
	Errors map[string]error // topic -> error
}

// Error summarizes the failure as the number of failed topics followed by
// their names in ascending order. Sorting keeps the message stable across
// calls; Go map iteration order is randomized, so unsorted names would make
// otherwise identical errors produce different strings.
//
// A nil receiver or an empty map yields "batch publish: no errors".
func (e *BatchError) Error() string {
	if e == nil || len(e.Errors) == 0 {
		return "batch publish: no errors"
	}
	names := make([]string, 0, len(e.Errors))
	for t := range e.Errors {
		names = append(names, t)
	}
	sort.Strings(names)
	return fmt.Sprintf("batch publish: %d topic(s) failed: %s", len(e.Errors), strings.Join(names, ", "))
}
|
||||
|
||||
// Publish publishes a message to a topic
|
||||
func (m *Manager) Publish(ctx context.Context, topic string, data []byte) error {
|
||||
if m.pubsub == nil {
|
||||
@ -58,3 +105,90 @@ func (m *Manager) Publish(ctx context.Context, topic string, data []byte) error
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishBatch publishes multiple messages in parallel, one per topic.
|
||||
// Default behavior is fail-fast: the first publish error cancels remaining work
|
||||
// and is returned. If opts.BestEffort is set, every publish is attempted and a
|
||||
// *BatchError is returned if any failed.
|
||||
//
|
||||
// Concurrency is bounded by opts.MaxConcurrency (default 32).
|
||||
// Empty msgs slice is a no-op and returns nil.
|
||||
func (m *Manager) PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error {
|
||||
if len(msgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
if m.pubsub == nil {
|
||||
return fmt.Errorf("pubsub not initialized")
|
||||
}
|
||||
|
||||
maxConc := opts.MaxConcurrency
|
||||
if maxConc <= 0 {
|
||||
maxConc = defaultBatchConcurrency
|
||||
}
|
||||
if maxConc > len(msgs) {
|
||||
maxConc = len(msgs)
|
||||
}
|
||||
sem := make(chan struct{}, maxConc)
|
||||
|
||||
if !opts.BestEffort {
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
for _, msg := range msgs {
|
||||
msg := msg
|
||||
g.Go(func() error {
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
case <-gctx.Done():
|
||||
return gctx.Err()
|
||||
}
|
||||
defer func() { <-sem }()
|
||||
return m.Publish(gctx, msg.Topic, msg.Data)
|
||||
})
|
||||
}
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
// Best-effort path: attempt every publish, collect per-topic errors.
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
errMu sync.Mutex
|
||||
errMap = map[string]error{}
|
||||
)
|
||||
for _, msg := range msgs {
|
||||
msg := msg
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
errMu.Lock()
|
||||
errMap[msg.Topic] = ctx.Err()
|
||||
errMu.Unlock()
|
||||
return
|
||||
}
|
||||
defer func() { <-sem }()
|
||||
if err := m.Publish(ctx, msg.Topic, msg.Data); err != nil {
|
||||
errMu.Lock()
|
||||
errMap[msg.Topic] = err
|
||||
errMu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if len(errMap) > 0 {
|
||||
return &BatchError{Errors: errMap}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishSame is a convenience wrapper that sends the same payload to every topic.
|
||||
func (m *Manager) PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error {
|
||||
if len(topics) == 0 {
|
||||
return nil
|
||||
}
|
||||
msgs := make([]TopicMessage, len(topics))
|
||||
for i, t := range topics {
|
||||
msgs[i] = TopicMessage{Topic: t, Data: data}
|
||||
}
|
||||
return m.PublishBatch(ctx, msgs, opts)
|
||||
}
|
||||
|
||||
196
core/pkg/pubsub/publish_batch_test.go
Normal file
196
core/pkg/pubsub/publish_batch_test.go
Normal file
@ -0,0 +1,196 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPublishBatch_empty_slice_returns_nil(t *testing.T) {
|
||||
mgr, cleanup := createTestManager(t, "test-ns")
|
||||
defer cleanup()
|
||||
|
||||
if err := mgr.PublishBatch(context.Background(), nil, PublishBatchOptions{}); err != nil {
|
||||
t.Fatalf("expected nil error for empty slice, got: %v", err)
|
||||
}
|
||||
if err := mgr.PublishBatch(context.Background(), []TopicMessage{}, PublishBatchOptions{}); err != nil {
|
||||
t.Fatalf("expected nil error for empty slice, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishBatch_happy_path(t *testing.T) {
|
||||
mgr, cleanup := createTestManager(t, "test-ns")
|
||||
defer cleanup()
|
||||
|
||||
msgs := []TopicMessage{
|
||||
{Topic: "a", Data: []byte("data-a")},
|
||||
{Topic: "b", Data: []byte("data-b")},
|
||||
{Topic: "c", Data: []byte("data-c")},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{}); err != nil {
|
||||
t.Fatalf("PublishBatch failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishSame_uses_same_payload(t *testing.T) {
|
||||
mgr, cleanup := createTestManager(t, "test-ns")
|
||||
defer cleanup()
|
||||
|
||||
topics := []string{"x", "y", "z"}
|
||||
data := []byte("shared")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := mgr.PublishSame(ctx, topics, data, PublishBatchOptions{}); err != nil {
|
||||
t.Fatalf("PublishSame failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishSame_empty_returns_nil(t *testing.T) {
|
||||
mgr, cleanup := createTestManager(t, "test-ns")
|
||||
defer cleanup()
|
||||
if err := mgr.PublishSame(context.Background(), nil, []byte("x"), PublishBatchOptions{}); err != nil {
|
||||
t.Fatalf("expected nil for empty topics, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishBatch_context_cancel_returns_error(t *testing.T) {
|
||||
mgr, cleanup := createTestManager(t, "test-ns")
|
||||
defer cleanup()
|
||||
|
||||
msgs := make([]TopicMessage, 50)
|
||||
for i := range msgs {
|
||||
msgs[i] = TopicMessage{Topic: fmt.Sprintf("topic-%d", i), Data: []byte("d")}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // cancel immediately
|
||||
|
||||
err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{})
|
||||
if err == nil {
|
||||
t.Fatal("expected context.Canceled error, got nil")
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Logf("got error (acceptable as long as it's an error): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishBatch_concurrency_limit(t *testing.T) {
|
||||
// Verify PublishBatch with low MaxConcurrency completes without deadlocking.
|
||||
// Each Publish in a no-peer test environment waits up to 2s for mesh formation,
|
||||
// so we use a small batch size to keep wall time bounded.
|
||||
mgr, cleanup := createTestManager(t, "test-ns")
|
||||
defer cleanup()
|
||||
|
||||
msgs := make([]TopicMessage, 8)
|
||||
for i := range msgs {
|
||||
msgs[i] = TopicMessage{Topic: fmt.Sprintf("ct-%d", i), Data: []byte("d")}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{MaxConcurrency: 2}); err != nil {
|
||||
t.Fatalf("PublishBatch with low concurrency failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPublishBatch_caps_concurrency_above_msg_count verifies that MaxConcurrency
|
||||
// is clamped to len(msgs) — passing 100 with 3 messages should not panic on
|
||||
// channel capacity.
|
||||
func TestPublishBatch_caps_concurrency_above_msg_count(t *testing.T) {
|
||||
mgr, cleanup := createTestManager(t, "test-ns")
|
||||
defer cleanup()
|
||||
|
||||
msgs := []TopicMessage{
|
||||
{Topic: "a", Data: []byte("1")},
|
||||
{Topic: "b", Data: []byte("2")},
|
||||
{Topic: "c", Data: []byte("3")},
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{MaxConcurrency: 100}); err != nil {
|
||||
t.Fatalf("PublishBatch failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchError_Error_summarizes(t *testing.T) {
|
||||
be := &BatchError{Errors: map[string]error{
|
||||
"topic-a": errors.New("boom"),
|
||||
"topic-b": errors.New("kaboom"),
|
||||
}}
|
||||
s := be.Error()
|
||||
if s == "" {
|
||||
t.Fatal("expected non-empty error string")
|
||||
}
|
||||
// Should mention both topics.
|
||||
if !contains(s, "topic-a") || !contains(s, "topic-b") {
|
||||
t.Errorf("error string %q should mention both failing topics", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchError_Error_empty_map(t *testing.T) {
|
||||
be := &BatchError{}
|
||||
if s := be.Error(); s == "" {
|
||||
t.Fatal("expected non-empty string even for empty map")
|
||||
}
|
||||
}
|
||||
|
||||
// contains reports whether substr occurs anywhere in s. It is a small
// local stand-in so this test file does not need to import strings.
// An empty substr is always found.
func contains(s, substr string) bool {
	width := len(substr)
	// last is the final start offset at which substr could still fit;
	// it is negative (so the loop never runs) when substr is longer than s.
	last := len(s) - width
	for start := 0; start <= last; start++ {
		if s[start:start+width] == substr {
			return true
		}
	}
	return false
}
|
||||
|
||||
// TestPublishBatch_concurrent_publishes_thread_safe ensures concurrent
|
||||
// PublishBatch invocations don't race on internal state.
|
||||
func TestPublishBatch_concurrent_publishes_thread_safe(t *testing.T) {
|
||||
mgr, cleanup := createTestManager(t, "test-ns")
|
||||
defer cleanup()
|
||||
|
||||
const goroutines = 8
|
||||
const msgsPerGoroutine = 5
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var failures int64
|
||||
|
||||
for g := 0; g < goroutines; g++ {
|
||||
wg.Add(1)
|
||||
go func(gid int) {
|
||||
defer wg.Done()
|
||||
msgs := make([]TopicMessage, msgsPerGoroutine)
|
||||
for i := range msgs {
|
||||
msgs[i] = TopicMessage{
|
||||
Topic: fmt.Sprintf("g%d-t%d", gid, i),
|
||||
Data: []byte("d"),
|
||||
}
|
||||
}
|
||||
if err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{}); err != nil {
|
||||
atomic.AddInt64(&failures, 1)
|
||||
t.Logf("goroutine %d failed: %v", gid, err)
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if failures > 0 {
|
||||
t.Errorf("%d concurrent batches failed", failures)
|
||||
}
|
||||
}
|
||||
172
core/pkg/push/device_store_rqlite.go
Normal file
172
core/pkg/push/device_store_rqlite.go
Normal file
@ -0,0 +1,172 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||
"github.com/DeBrosOfficial/network/pkg/secrets"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SecretsKeyPurpose is the HKDF purpose string for push token encryption.
// NewRqliteDeviceStore passes it to secrets.DeriveKey together with the
// cluster secret to derive a domain-separated AES key.
//
// NOTE(review): changing this value presumably makes previously stored
// tokens undecryptable — confirm against pkg/secrets before editing.
const SecretsKeyPurpose = "push-device-tokens"
|
||||
|
||||
// RqliteDeviceStore is a PushDeviceStore backed by RQLite + AES-256-GCM
// at-rest encryption of the push token. Construct it with
// NewRqliteDeviceStore so the encryption key is populated.
type RqliteDeviceStore struct {
	db     rqlite.Client // runs the push_devices statements
	encKey []byte        // derived once at construction
	logger *zap.Logger   // named "push-store" by the constructor
}
|
||||
|
||||
// NewRqliteDeviceStore derives the per-cluster encryption key from the
|
||||
// cluster secret and returns a ready-to-use store. The cluster secret is
|
||||
// the same one used for other at-rest encryption (see pkg/secrets).
|
||||
func NewRqliteDeviceStore(db rqlite.Client, clusterSecret string, logger *zap.Logger) (*RqliteDeviceStore, error) {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
key, err := secrets.DeriveKey(clusterSecret, SecretsKeyPurpose)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("derive push-device key: %w", err)
|
||||
}
|
||||
return &RqliteDeviceStore{
|
||||
db: db,
|
||||
encKey: key,
|
||||
logger: logger.Named("push-store"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// deviceRow receives one row of the SELECT issued by ListForUser. Fields are
// declared in the same order as the selected columns.
type deviceRow struct {
	ID, Namespace, UserID, DeviceID, Provider string
	// TokenEncrypted holds the stored ciphertext; it is decrypted before
	// being handed to callers.
	TokenEncrypted       string
	Platform, AppVersion string
	// Timestamps are stored as integers (Upsert writes Unix seconds).
	CreatedAt, UpdatedAt, LastSeen int64
}
|
||||
|
||||
// Upsert implements PushDeviceStore.
//
// It rejects devices missing namespace, user ID, device ID or provider, and
// returns ErrEmptyToken for an empty token. The token is encrypted with the
// store key before it is written. When a row with the same
// (namespace, user_id, device_id) already exists, provider, token, platform,
// app version, updated_at and last_seen are overwritten while id and
// created_at keep their stored values.
func (s *RqliteDeviceStore) Upsert(ctx context.Context, dev PushDevice) error {
	if dev.Namespace == "" || dev.UserID == "" || dev.DeviceID == "" {
		return fmt.Errorf("namespace, user_id, device_id required")
	}
	if dev.Provider == "" {
		return fmt.Errorf("provider required")
	}
	if dev.Token == "" {
		return ErrEmptyToken
	}

	// Only the ciphertext ever reaches the database.
	encToken, err := secrets.Encrypt(dev.Token, s.encKey)
	if err != nil {
		return fmt.Errorf("encrypt token: %w", err)
	}

	// Timestamps are Unix seconds. created_at is only defaulted here; on
	// conflict the stored value wins because the UPDATE below omits it.
	now := time.Now().Unix()
	if dev.CreatedAt == 0 {
		dev.CreatedAt = now
	}
	dev.UpdatedAt = now

	// A fresh id is only used when the INSERT path is taken.
	id := dev.ID
	if id == "" {
		id = uuid.New().String()
	}

	// SQLite UPSERT keyed on (namespace, user_id, device_id) per the migration's
	// UNIQUE constraint. On conflict we replace token + provider + metadata
	// while preserving the original id and created_at.
	//
	// NOTE(review): last_seen is written exactly as provided, so a zero
	// dev.LastSeen overwrites the stored value on conflict — confirm intended.
	query := `
		INSERT INTO push_devices
			(id, namespace, user_id, device_id, provider, token_encrypted,
			 platform, app_version, created_at, updated_at, last_seen)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
		ON CONFLICT(namespace, user_id, device_id) DO UPDATE SET
			provider = excluded.provider,
			token_encrypted = excluded.token_encrypted,
			platform = excluded.platform,
			app_version = excluded.app_version,
			updated_at = excluded.updated_at,
			last_seen = excluded.last_seen
	`
	_, err = s.db.Exec(ctx, query,
		id, dev.Namespace, dev.UserID, dev.DeviceID, dev.Provider, encToken,
		dev.Platform, dev.AppVer, dev.CreatedAt, dev.UpdatedAt, dev.LastSeen,
	)
	if err != nil {
		return fmt.Errorf("upsert push device: %w", err)
	}
	return nil
}
|
||||
|
||||
// Delete implements PushDeviceStore.
|
||||
func (s *RqliteDeviceStore) Delete(ctx context.Context, namespace, id string) error {
|
||||
if namespace == "" || id == "" {
|
||||
return fmt.Errorf("namespace and id required")
|
||||
}
|
||||
query := `DELETE FROM push_devices WHERE id = ? AND namespace = ?`
|
||||
res, err := s.db.Exec(ctx, query, id, namespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete push device: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return fmt.Errorf("push device not found: %s", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListForUser implements PushDeviceStore. Returns devices with decrypted tokens.
// Caller MUST treat tokens as sensitive.
//
// An empty namespace or userID yields (nil, nil) without querying. Rows whose
// token cannot be decrypted are logged and left out of the result rather than
// failing the whole call.
func (s *RqliteDeviceStore) ListForUser(ctx context.Context, namespace, userID string) ([]PushDevice, error) {
	if namespace == "" || userID == "" {
		return nil, nil
	}
	// COALESCE maps NULL optional columns to zero values for scanning.
	query := `
		SELECT id, namespace, user_id, device_id, provider, token_encrypted,
		       COALESCE(platform, ''), COALESCE(app_version, ''),
		       created_at, updated_at, COALESCE(last_seen, 0)
		FROM push_devices
		WHERE namespace = ? AND user_id = ?
	`
	var rows []deviceRow
	if err := s.db.Query(ctx, &rows, query, namespace, userID); err != nil {
		return nil, fmt.Errorf("query push devices: %w", err)
	}

	out := make([]PushDevice, 0, len(rows))
	for _, r := range rows {
		token, err := secrets.Decrypt(r.TokenEncrypted, s.encKey)
		if err != nil {
			// Best effort: one undecryptable row must not hide the user's
			// other devices. Only the device id is logged, never the token.
			s.logger.Warn("failed to decrypt push token; skipping device",
				zap.String("device_id", r.DeviceID),
				zap.Error(err))
			continue
		}
		out = append(out, PushDevice{
			ID:        r.ID,
			Namespace: r.Namespace,
			UserID:    r.UserID,
			DeviceID:  r.DeviceID,
			Provider:  r.Provider,
			Token:     token,
			Platform:  r.Platform,
			AppVer:    r.AppVersion,
			CreatedAt: r.CreatedAt,
			UpdatedAt: r.UpdatedAt,
			LastSeen:  r.LastSeen,
		})
	}
	return out, nil
}
|
||||
97
core/pkg/push/dispatcher.go
Normal file
97
core/pkg/push/dispatcher.go
Normal file
@ -0,0 +1,97 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// PushDispatcher routes push messages to the matching provider for each
// of a user's registered devices. Create one with New; the provider map is
// guarded by mu so Register, Provider and SendToUser can be used from
// multiple goroutines.
type PushDispatcher struct {
	mu        sync.RWMutex            // guards providers
	providers map[string]PushProvider // keyed by PushProvider.Name()
	devices   PushDeviceStore         // source of a user's registered devices
	logger    *zap.Logger             // named "push" by New
}
|
||||
|
||||
// New creates a dispatcher with the given device store. Register
|
||||
// providers before sending.
|
||||
func New(devices PushDeviceStore, logger *zap.Logger) *PushDispatcher {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
return &PushDispatcher{
|
||||
providers: map[string]PushProvider{},
|
||||
devices: devices,
|
||||
logger: logger.Named("push"),
|
||||
}
|
||||
}
|
||||
|
||||
// Register makes a provider available to dispatch. Calling Register with
|
||||
// the same name twice replaces the previous provider — useful in tests.
|
||||
func (d *PushDispatcher) Register(p PushProvider) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.providers[p.Name()] = p
|
||||
}
|
||||
|
||||
// Provider returns the registered provider by name, or nil.
|
||||
func (d *PushDispatcher) Provider(name string) PushProvider {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
return d.providers[name]
|
||||
}
|
||||
|
||||
// SendToUser fans out the message to every registered device for the
|
||||
// user. Each provider failure is logged but does not stop subsequent
|
||||
// devices. Returns the first encountered error (if any) so callers can
|
||||
// surface a partial-failure signal.
|
||||
//
|
||||
// SendToUser returns nil if the user has no registered devices — that
|
||||
// is normal, not an error.
|
||||
func (d *PushDispatcher) SendToUser(
|
||||
ctx context.Context,
|
||||
namespace, userID string,
|
||||
msg PushMessage,
|
||||
) error {
|
||||
devs, err := d.devices.ListForUser(ctx, namespace, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list devices: %w", err)
|
||||
}
|
||||
if len(devs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var firstErr error
|
||||
for _, dev := range devs {
|
||||
d.mu.RLock()
|
||||
p, ok := d.providers[dev.Provider]
|
||||
d.mu.RUnlock()
|
||||
if !ok {
|
||||
d.logger.Warn("push: dropping device with unregistered provider",
|
||||
zap.String("provider", dev.Provider),
|
||||
zap.String("device_id", dev.DeviceID),
|
||||
)
|
||||
if firstErr == nil {
|
||||
firstErr = fmt.Errorf("%w: %s", ErrUnknownProvider, dev.Provider)
|
||||
}
|
||||
continue
|
||||
}
|
||||
m := msg
|
||||
m.DeviceToken = dev.Token
|
||||
if err := p.Send(ctx, m); err != nil {
|
||||
d.logger.Warn("push: provider send failed",
|
||||
zap.String("provider", dev.Provider),
|
||||
zap.String("device_id", dev.DeviceID),
|
||||
zap.Error(err),
|
||||
)
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
149
core/pkg/push/dispatcher_test.go
Normal file
149
core/pkg/push/dispatcher_test.go
Normal file
@ -0,0 +1,149 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// fakeProvider records every Send call.
|
||||
type fakeProvider struct {
|
||||
name string
|
||||
sent int32
|
||||
lastToken string
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeProvider) Name() string { return f.name }
|
||||
func (f *fakeProvider) Send(ctx context.Context, msg PushMessage) error {
|
||||
atomic.AddInt32(&f.sent, 1)
|
||||
f.lastToken = msg.DeviceToken
|
||||
return f.err
|
||||
}
|
||||
|
||||
// fakeStore is an in-memory PushDeviceStore.
|
||||
type fakeStore struct {
|
||||
devices []PushDevice
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *fakeStore) Upsert(ctx context.Context, dev PushDevice) error {
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
s.devices = append(s.devices, dev)
|
||||
return nil
|
||||
}
|
||||
func (s *fakeStore) Delete(ctx context.Context, ns, id string) error { return nil }
|
||||
func (s *fakeStore) ListForUser(ctx context.Context, ns, userID string) ([]PushDevice, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
out := []PushDevice{}
|
||||
for _, d := range s.devices {
|
||||
if d.Namespace == ns && d.UserID == userID {
|
||||
out = append(out, d)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func TestSendToUser_no_devices_returns_nil(t *testing.T) {
|
||||
d := New(&fakeStore{}, zap.NewNop())
|
||||
if err := d.SendToUser(context.Background(), "ns", "u", PushMessage{Title: "x"}); err != nil {
|
||||
t.Fatalf("expected nil for no devices, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToUser_routes_to_correct_provider(t *testing.T) {
|
||||
store := &fakeStore{devices: []PushDevice{
|
||||
{Namespace: "ns", UserID: "u", Provider: "ntfy", Token: "ntfy-tok"},
|
||||
{Namespace: "ns", UserID: "u", Provider: "expo", Token: "expo-tok"},
|
||||
}}
|
||||
ntfy := &fakeProvider{name: "ntfy"}
|
||||
expo := &fakeProvider{name: "expo"}
|
||||
|
||||
d := New(store, zap.NewNop())
|
||||
d.Register(ntfy)
|
||||
d.Register(expo)
|
||||
|
||||
if err := d.SendToUser(context.Background(), "ns", "u", PushMessage{Title: "hi"}); err != nil {
|
||||
t.Fatalf("SendToUser: %v", err)
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&ntfy.sent) != 1 || ntfy.lastToken != "ntfy-tok" {
|
||||
t.Errorf("ntfy provider not called correctly: sent=%d token=%s", ntfy.sent, ntfy.lastToken)
|
||||
}
|
||||
if atomic.LoadInt32(&expo.sent) != 1 || expo.lastToken != "expo-tok" {
|
||||
t.Errorf("expo provider not called correctly: sent=%d token=%s", expo.sent, expo.lastToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToUser_unknown_provider_returns_error_continues(t *testing.T) {
|
||||
store := &fakeStore{devices: []PushDevice{
|
||||
{Namespace: "ns", UserID: "u", Provider: "ghost", Token: "tok"},
|
||||
{Namespace: "ns", UserID: "u", Provider: "ntfy", Token: "real"},
|
||||
}}
|
||||
ntfy := &fakeProvider{name: "ntfy"}
|
||||
|
||||
d := New(store, zap.NewNop())
|
||||
d.Register(ntfy)
|
||||
|
||||
err := d.SendToUser(context.Background(), "ns", "u", PushMessage{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown provider")
|
||||
}
|
||||
if !errors.Is(err, ErrUnknownProvider) {
|
||||
t.Errorf("expected ErrUnknownProvider, got %v", err)
|
||||
}
|
||||
// ntfy should still have been called.
|
||||
if atomic.LoadInt32(&ntfy.sent) != 1 {
|
||||
t.Error("ntfy should have been called for the second device")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToUser_provider_failure_returned_but_other_devices_still_processed(t *testing.T) {
|
||||
store := &fakeStore{devices: []PushDevice{
|
||||
{Namespace: "ns", UserID: "u", Provider: "expo", Token: "tok-1"},
|
||||
{Namespace: "ns", UserID: "u", Provider: "ntfy", Token: "tok-2"},
|
||||
}}
|
||||
expoErr := errors.New("expo down")
|
||||
expo := &fakeProvider{name: "expo", err: expoErr}
|
||||
ntfy := &fakeProvider{name: "ntfy"}
|
||||
|
||||
d := New(store, zap.NewNop())
|
||||
d.Register(expo)
|
||||
d.Register(ntfy)
|
||||
|
||||
err := d.SendToUser(context.Background(), "ns", "u", PushMessage{})
|
||||
if !errors.Is(err, expoErr) {
|
||||
t.Errorf("expected expo error, got %v", err)
|
||||
}
|
||||
if atomic.LoadInt32(&ntfy.sent) != 1 {
|
||||
t.Error("ntfy should have been called even though expo failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendToUser_store_error_propagated(t *testing.T) {
|
||||
storeErr := errors.New("store boom")
|
||||
d := New(&fakeStore{err: storeErr}, zap.NewNop())
|
||||
|
||||
err := d.SendToUser(context.Background(), "ns", "u", PushMessage{})
|
||||
if err == nil || !errors.Is(err, storeErr) {
|
||||
t.Errorf("expected store error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_replaces_existing_provider(t *testing.T) {
|
||||
d := New(&fakeStore{}, zap.NewNop())
|
||||
a := &fakeProvider{name: "ntfy"}
|
||||
b := &fakeProvider{name: "ntfy"}
|
||||
d.Register(a)
|
||||
d.Register(b)
|
||||
if d.Provider("ntfy") != b {
|
||||
t.Error("expected second Register to replace the first")
|
||||
}
|
||||
}
|
||||
160
core/pkg/push/providers/expo/expo.go
Normal file
160
core/pkg/push/providers/expo/expo.go
Normal file
@ -0,0 +1,160 @@
|
||||
// Package expo wraps the Expo push relay as a push.PushProvider.
|
||||
//
|
||||
// This is a thin port of the legacy gateway.PushNotificationService —
|
||||
// behaviour preserved, surface adapted to the provider abstraction.
|
||||
//
|
||||
// Long term Expo is intended to be replaced with direct APNs (iOS) +
|
||||
// ntfy (Android). This provider exists so the gateway can keep using
|
||||
// Expo while the migration happens.
|
||||
package expo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/push"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// expoAPIURL is the Expo push endpoint that Send posts every message to.
const expoAPIURL = "https://exp.host/--/api/v2/push/send"
|
||||
|
||||
// Config holds Expo provider settings. The zero value is usable: no access
// token and the default timeout.
type Config struct {
	// AccessToken is an optional Expo access token. When set, it's sent
	// as a Bearer token, which Expo uses for higher-priority delivery
	// and to attribute the send to your account.
	AccessToken string
	// Timeout bounds each Send call. 0 selects 10 seconds (matching the
	// previous PushNotificationService default). Negative values are
	// treated the same as 0.
	Timeout time.Duration
}
|
||||
|
||||
// Provider is the Expo push.PushProvider implementation. Create it with New
// so the HTTP client and logger are populated.
type Provider struct {
	accessToken string       // optional; sent as "Authorization: Bearer ..." when non-empty
	httpClient  *http.Client // carries the per-request timeout chosen in New
	logger      *zap.Logger  // named "expo" by New
}
|
||||
|
||||
// New creates a Provider with the given config.
|
||||
func New(cfg Config, logger *zap.Logger) *Provider {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
timeout := cfg.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
return &Provider{
|
||||
accessToken: cfg.AccessToken,
|
||||
httpClient: &http.Client{Timeout: timeout},
|
||||
logger: logger.Named("expo"),
|
||||
}
|
||||
}
|
||||
|
||||
// Name implements push.PushProvider.
|
||||
func (p *Provider) Name() string { return "expo" }
|
||||
|
||||
// expoMessage matches the wire format Expo expects. Send marshals exactly one
// of these (not an array) as the request body; empty optional fields are
// omitted from the JSON.
type expoMessage struct {
	To             string                 `json:"to"`                       // device push token
	Title          string                 `json:"title,omitempty"`          // notification title
	Body           string                 `json:"body,omitempty"`           // notification text
	Data           map[string]interface{} `json:"data,omitempty"`           // custom payload passed through from the message
	Sound          string                 `json:"sound,omitempty"`          // Send falls back to "default" when empty
	Badge          int                    `json:"badge,omitempty"`          // omitted when 0
	Priority       string                 `json:"priority,omitempty"`       // Send writes "default" or "high"
	MutableContent bool                   `json:"mutableContent,omitempty"` // Send always sets this to true
	ChannelID      string                 `json:"channelId,omitempty"`      // filled from the message's Channel
}
|
||||
|
||||
// expoTicket is the per-message response. Send treats any non-empty Status
// other than "ok" as a failure and includes Message in the returned error.
type expoTicket struct {
	Status  string `json:"status"`
	Message string `json:"message,omitempty"`
}
|
||||
|
||||
// expoResponse is the response envelope decoded by Send: a "data" field
// holding a list of tickets.
type expoResponse struct {
	Data []expoTicket `json:"data"`
}
|
||||
|
||||
// Send delivers a push via the Expo relay.
|
||||
func (p *Provider) Send(ctx context.Context, msg push.PushMessage) error {
|
||||
if msg.DeviceToken == "" {
|
||||
return push.ErrEmptyToken
|
||||
}
|
||||
|
||||
priority := "default"
|
||||
if msg.Priority == push.PriorityHigh {
|
||||
priority = "high"
|
||||
}
|
||||
|
||||
wire := expoMessage{
|
||||
To: msg.DeviceToken,
|
||||
Title: msg.Title,
|
||||
Body: msg.Body,
|
||||
Data: msg.Data,
|
||||
Sound: msg.Sound,
|
||||
Badge: msg.Badge,
|
||||
Priority: priority,
|
||||
MutableContent: true, // for iOS Notification Service Extension
|
||||
ChannelID: msg.Channel,
|
||||
}
|
||||
if wire.Sound == "" {
|
||||
wire.Sound = "default"
|
||||
}
|
||||
|
||||
body, err := json.Marshal(wire)
|
||||
if err != nil {
|
||||
return fmt.Errorf("expo: marshal: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, expoAPIURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("expo: build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if p.accessToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.accessToken)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("expo: post: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 16<<10))
|
||||
if err != nil {
|
||||
return fmt.Errorf("expo: read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("expo: http %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var er expoResponse
|
||||
if err := json.Unmarshal(respBody, &er); err != nil {
|
||||
// Older Expo responses sometimes return a bare array; try that fallback.
|
||||
var tickets []expoTicket
|
||||
if err2 := json.Unmarshal(respBody, &tickets); err2 == nil {
|
||||
er.Data = tickets
|
||||
} else {
|
||||
return fmt.Errorf("expo: parse response: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, t := range er.Data {
|
||||
if t.Status != "" && t.Status != "ok" {
|
||||
return fmt.Errorf("expo: ticket status %q: %s", t.Status, t.Message)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
118
core/pkg/push/providers/expo/expo_test.go
Normal file
118
core/pkg/push/providers/expo/expo_test.go
Normal file
@ -0,0 +1,118 @@
|
||||
package expo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/push"
|
||||
)
|
||||
|
||||
// roundTripFunc lets us mock http.Client transport for the Expo provider so
|
||||
// we can assert against requests without hitting exp.host.
|
||||
type roundTripFunc func(req *http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) }
|
||||
|
||||
func newTestProvider(rt roundTripFunc) *Provider {
|
||||
p := New(Config{}, nil)
|
||||
p.httpClient.Transport = rt
|
||||
return p
|
||||
}
|
||||
|
||||
func TestSend_empty_token_returns_ErrEmptyToken(t *testing.T) {
|
||||
p := New(Config{}, nil)
|
||||
err := p.Send(context.Background(), push.PushMessage{Body: "x"})
|
||||
if err != push.ErrEmptyToken {
|
||||
t.Errorf("expected ErrEmptyToken, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_happy_path(t *testing.T) {
|
||||
var gotPayload map[string]interface{}
|
||||
var gotAuth string
|
||||
|
||||
p := newTestProvider(func(req *http.Request) (*http.Response, error) {
|
||||
gotAuth = req.Header.Get("Authorization")
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
_ = json.Unmarshal(body, &gotPayload)
|
||||
resp := httptest.NewRecorder()
|
||||
resp.WriteHeader(200)
|
||||
_, _ = resp.WriteString(`{"data":[{"status":"ok"}]}`)
|
||||
return resp.Result(), nil
|
||||
})
|
||||
p.accessToken = "secret-token"
|
||||
|
||||
err := p.Send(context.Background(), push.PushMessage{
|
||||
DeviceToken: "ExponentPushToken[abc]",
|
||||
Title: "T", Body: "B",
|
||||
Priority: push.PriorityHigh,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Send failed: %v", err)
|
||||
}
|
||||
if gotAuth != "Bearer secret-token" {
|
||||
t.Errorf("auth header wrong: %s", gotAuth)
|
||||
}
|
||||
if gotPayload["to"] != "ExponentPushToken[abc]" {
|
||||
t.Errorf("to wrong: %v", gotPayload["to"])
|
||||
}
|
||||
if gotPayload["priority"] != "high" {
|
||||
t.Errorf("priority wrong: %v", gotPayload["priority"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_ticket_error_returns_error(t *testing.T) {
|
||||
p := newTestProvider(func(req *http.Request) (*http.Response, error) {
|
||||
resp := httptest.NewRecorder()
|
||||
resp.WriteHeader(200)
|
||||
_, _ = resp.WriteString(`{"data":[{"status":"error","message":"DeviceNotRegistered"}]}`)
|
||||
return resp.Result(), nil
|
||||
})
|
||||
err := p.Send(context.Background(), push.PushMessage{DeviceToken: "x", Body: "y"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for ticket failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_http_error_returns_error(t *testing.T) {
|
||||
p := newTestProvider(func(req *http.Request) (*http.Response, error) {
|
||||
resp := httptest.NewRecorder()
|
||||
resp.WriteHeader(500)
|
||||
_, _ = resp.WriteString(`upstream broken`)
|
||||
return resp.Result(), nil
|
||||
})
|
||||
err := p.Send(context.Background(), push.PushMessage{DeviceToken: "x", Body: "y"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTTP 500")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_normal_priority_maps_to_default(t *testing.T) {
|
||||
var gotPayload map[string]interface{}
|
||||
p := newTestProvider(func(req *http.Request) (*http.Response, error) {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
_ = json.Unmarshal(body, &gotPayload)
|
||||
resp := httptest.NewRecorder()
|
||||
resp.WriteHeader(200)
|
||||
_, _ = resp.WriteString(`{"data":[{"status":"ok"}]}`)
|
||||
return resp.Result(), nil
|
||||
})
|
||||
if err := p.Send(context.Background(), push.PushMessage{
|
||||
DeviceToken: "x", Body: "y", Priority: push.PriorityNormal,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if gotPayload["priority"] != "default" {
|
||||
t.Errorf("expected priority=default, got %v", gotPayload["priority"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestName(t *testing.T) {
|
||||
if New(Config{}, nil).Name() != "expo" {
|
||||
t.Error("expected Name=expo")
|
||||
}
|
||||
}
|
||||
132
core/pkg/push/providers/ntfy/ntfy.go
Normal file
132
core/pkg/push/providers/ntfy/ntfy.go
Normal file
@ -0,0 +1,132 @@
|
||||
// Package ntfy implements a push.PushProvider backed by an ntfy server.
|
||||
//
|
||||
// ntfy delivers notifications via plain HTTP POST to <baseURL>/<topic>.
|
||||
// We map PushMessage fields to ntfy headers:
|
||||
// - Title -> "Title"
|
||||
// - Priority -> "Priority"
|
||||
// - Channel -> "Tags"
|
||||
// - Data -> base64-encoded JSON in "X-Data"
|
||||
//
|
||||
// See https://docs.ntfy.sh/publish/#publish-as-json for details.
|
||||
package ntfy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/push"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Config holds per-provider settings.
type Config struct {
	// BaseURL is the ntfy HTTP endpoint (e.g. "http://localhost:8080" or
	// "https://push.example.com"). Trailing slash is tolerated: New strips
	// it. Send fails when BaseURL is empty.
	BaseURL string
	// AuthToken is an optional per-namespace bearer token. Leave empty to
	// disable authentication. When set it is sent as
	// "Authorization: Bearer <token>" on every request.
	AuthToken string
	// Timeout bounds each Send call. 0 selects 5 seconds; negative values
	// are treated the same as 0.
	Timeout time.Duration
}
|
||||
|
||||
// Provider is the ntfy push.PushProvider implementation. Create it with New
// so the base URL is normalised and the HTTP client is populated.
type Provider struct {
	baseURL    string       // Config.BaseURL with trailing slashes removed
	authToken  string       // optional bearer token
	httpClient *http.Client // carries the per-request timeout chosen in New
	logger     *zap.Logger  // named "ntfy" by New
}
|
||||
|
||||
// New creates a Provider with the given config.
|
||||
func New(cfg Config, logger *zap.Logger) *Provider {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
timeout := cfg.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
return &Provider{
|
||||
baseURL: strings.TrimRight(cfg.BaseURL, "/"),
|
||||
authToken: cfg.AuthToken,
|
||||
httpClient: &http.Client{Timeout: timeout},
|
||||
logger: logger.Named("ntfy"),
|
||||
}
|
||||
}
|
||||
|
||||
// Name implements push.PushProvider.
|
||||
func (p *Provider) Name() string { return "ntfy" }
|
||||
|
||||
// Send delivers a push notification to the device's ntfy topic.
|
||||
func (p *Provider) Send(ctx context.Context, msg push.PushMessage) error {
|
||||
if msg.DeviceToken == "" {
|
||||
return push.ErrEmptyToken
|
||||
}
|
||||
if p.baseURL == "" {
|
||||
return fmt.Errorf("ntfy: base URL not configured")
|
||||
}
|
||||
|
||||
// URL-escape each path segment of the device token. ntfy topics can be
|
||||
// hierarchical (e.g. "ns/myapp/user-1") and we want to preserve those
|
||||
// '/' separators while escaping any other special characters that
|
||||
// could let a malicious token escape the topic path.
|
||||
parts := strings.Split(msg.DeviceToken, "/")
|
||||
for i, p := range parts {
|
||||
parts[i] = url.PathEscape(p)
|
||||
}
|
||||
endpointURL := p.baseURL + "/" + strings.Join(parts, "/")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(msg.Body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("ntfy: build request: %w", err)
|
||||
}
|
||||
|
||||
if msg.Title != "" {
|
||||
req.Header.Set("Title", msg.Title)
|
||||
}
|
||||
if msg.Priority == push.PriorityHigh {
|
||||
req.Header.Set("Priority", "high")
|
||||
} else if msg.Priority == push.PriorityNormal {
|
||||
req.Header.Set("Priority", "default")
|
||||
}
|
||||
if msg.Channel != "" {
|
||||
// ntfy uses "Tags" for both visual emoji and operator-defined tags.
|
||||
req.Header.Set("Tags", msg.Channel)
|
||||
}
|
||||
if msg.Badge > 0 {
|
||||
req.Header.Set("X-Badge", fmt.Sprintf("%d", msg.Badge))
|
||||
}
|
||||
if len(msg.Data) > 0 {
|
||||
b, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ntfy: marshal data: %w", err)
|
||||
}
|
||||
req.Header.Set("X-Data", base64.StdEncoding.EncodeToString(b))
|
||||
}
|
||||
if p.authToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.authToken)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ntfy: post: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
|
||||
return fmt.Errorf("ntfy: http %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
// Drain body to allow connection reuse.
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
|
||||
return nil
|
||||
}
|
||||
191
core/pkg/push/providers/ntfy/ntfy_test.go
Normal file
191
core/pkg/push/providers/ntfy/ntfy_test.go
Normal file
@ -0,0 +1,191 @@
|
||||
package ntfy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/push"
|
||||
)
|
||||
|
||||
func TestSend_happy_path(t *testing.T) {
|
||||
var (
|
||||
gotPath string
|
||||
gotBody string
|
||||
gotTitle string
|
||||
gotPriority string
|
||||
gotAuth string
|
||||
)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotTitle = r.Header.Get("Title")
|
||||
gotPriority = r.Header.Get("Priority")
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
b, _ := io.ReadAll(r.Body)
|
||||
gotBody = string(b)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := New(Config{BaseURL: srv.URL, AuthToken: "secret"}, nil)
|
||||
err := p.Send(context.Background(), push.PushMessage{
|
||||
DeviceToken: "ns/myapp/user-1",
|
||||
Title: "Hello",
|
||||
Body: "World",
|
||||
Priority: push.PriorityHigh,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Send failed: %v", err)
|
||||
}
|
||||
if gotPath != "/ns/myapp/user-1" {
|
||||
t.Errorf("expected path /ns/myapp/user-1, got %s", gotPath)
|
||||
}
|
||||
if gotTitle != "Hello" {
|
||||
t.Errorf("expected Title=Hello, got %s", gotTitle)
|
||||
}
|
||||
if gotPriority != "high" {
|
||||
t.Errorf("expected Priority=high, got %s", gotPriority)
|
||||
}
|
||||
if gotAuth != "Bearer secret" {
|
||||
t.Errorf("expected Authorization=Bearer secret, got %s", gotAuth)
|
||||
}
|
||||
if gotBody != "World" {
|
||||
t.Errorf("expected body=World, got %s", gotBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_includes_data_header_when_data_set(t *testing.T) {
|
||||
var gotData string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotData = r.Header.Get("X-Data")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := New(Config{BaseURL: srv.URL}, nil)
|
||||
err := p.Send(context.Background(), push.PushMessage{
|
||||
DeviceToken: "topic",
|
||||
Body: "x",
|
||||
Data: map[string]interface{}{"call_id": "abc-123"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Send: %v", err)
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(gotData)
|
||||
if err != nil {
|
||||
t.Fatalf("X-Data not valid base64: %v", err)
|
||||
}
|
||||
var got map[string]interface{}
|
||||
if err := json.Unmarshal(decoded, &got); err != nil {
|
||||
t.Fatalf("X-Data not valid JSON: %v", err)
|
||||
}
|
||||
if got["call_id"] != "abc-123" {
|
||||
t.Errorf("data round-trip failed: got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_no_data_no_data_header(t *testing.T) {
|
||||
var gotData string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotData = r.Header.Get("X-Data")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := New(Config{BaseURL: srv.URL}, nil)
|
||||
if err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if gotData != "" {
|
||||
t.Errorf("expected no X-Data header, got %q", gotData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_no_auth_header_when_token_empty(t *testing.T) {
|
||||
var gotAuth string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := New(Config{BaseURL: srv.URL}, nil)
|
||||
if err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if gotAuth != "" {
|
||||
t.Errorf("expected no Authorization header, got %q", gotAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_4xx_returns_error_with_body_excerpt(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte("forbidden topic"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := New(Config{BaseURL: srv.URL}, nil)
|
||||
err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "403") || !strings.Contains(err.Error(), "forbidden") {
|
||||
t.Errorf("error should mention status and body, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_empty_token_returns_ErrEmptyToken(t *testing.T) {
|
||||
p := New(Config{BaseURL: "http://example.invalid"}, nil)
|
||||
err := p.Send(context.Background(), push.PushMessage{Body: "x"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty token")
|
||||
}
|
||||
if err != push.ErrEmptyToken {
|
||||
t.Errorf("expected ErrEmptyToken, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_short_timeout_returns_error(t *testing.T) {
|
||||
// Server that blocks for 2s — provider with 100ms timeout should give up.
|
||||
blockUntil := make(chan struct{})
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-blockUntil:
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}))
|
||||
defer func() { close(blockUntil); srv.Close() }()
|
||||
|
||||
p := New(Config{BaseURL: srv.URL, Timeout: 100 * time.Millisecond}, nil)
|
||||
start := time.Now()
|
||||
err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"})
|
||||
elapsed := time.Since(start)
|
||||
if err == nil {
|
||||
t.Error("expected timeout error")
|
||||
}
|
||||
if elapsed > 1*time.Second {
|
||||
t.Errorf("expected fast timeout, took %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_no_baseURL_returns_error(t *testing.T) {
|
||||
p := New(Config{}, nil)
|
||||
err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing base URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestName(t *testing.T) {
|
||||
p := New(Config{BaseURL: "http://x"}, nil)
|
||||
if p.Name() != "ntfy" {
|
||||
t.Errorf("expected Name=ntfy, got %s", p.Name())
|
||||
}
|
||||
}
|
||||
91
core/pkg/push/types.go
Normal file
91
core/pkg/push/types.go
Normal file
@ -0,0 +1,91 @@
|
||||
// Package push provides a generic push-notification abstraction for Orama.
|
||||
//
|
||||
// Apps register devices with a provider name ("ntfy", "expo", "apns", ...)
|
||||
// and a provider-specific token. The PushDispatcher routes outbound push
|
||||
// messages to the matching provider so call sites stay backend-agnostic.
|
||||
//
|
||||
// Long-term the platform aims to drop Expo in favour of direct APNs +
|
||||
// ntfy. The abstraction makes that swap a configuration change rather
|
||||
// than a code change.
|
||||
package push
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// PushPriority tells a provider how urgently a message should be delivered.
// A provider without a notion of priority ignores it.
type PushPriority string

// Priority values understood by providers.
const (
	// PriorityNormal is the string "normal".
	PriorityNormal PushPriority = "normal"
	// PriorityHigh is the string "high".
	PriorityHigh PushPriority = "high"
)
|
||||
|
||||
// PushMessage is the provider-agnostic message format.
|
||||
//
|
||||
// DeviceToken is the provider-specific identifier (e.g. an ntfy topic,
|
||||
// an Expo push token, an APNs device token). The PushDispatcher fills
|
||||
// it in per-device before calling Send.
|
||||
type PushMessage struct {
|
||||
DeviceToken string
|
||||
Title string
|
||||
Body string
|
||||
Data map[string]interface{}
|
||||
Badge int
|
||||
Sound string
|
||||
Channel string // "messages", "calls", etc — provider may map to its own channel concept
|
||||
Priority PushPriority
|
||||
}
|
||||
|
||||
// PushProvider is implemented by each backend (ntfy, expo, apns).
|
||||
type PushProvider interface {
|
||||
Name() string
|
||||
// Send delivers a single push. Returning an error counts as a delivery
|
||||
// failure for that device; the dispatcher logs it and continues.
|
||||
Send(ctx context.Context, msg PushMessage) error
|
||||
}
|
||||
|
||||
// PushDevice is one registered push target belonging to a user.
//
// Token is held in plaintext here; encryption is applied by the storage
// layer. Code that loads devices from the store must treat tokens as
// sensitive and must not log them.
type PushDevice struct {
	ID        string
	Namespace string
	UserID    string
	DeviceID  string // app-provided
	Provider  string // matches PushProvider.Name()
	Token     string // provider-specific token, plaintext in memory
	Platform  string // "ios" | "android" | "web"
	AppVer    string

	// Timestamps. CreatedAt is unix seconds; NOTE(review): UpdatedAt and
	// LastSeen presumably use the same unit — confirm against the store.
	CreatedAt int64 // unix seconds
	UpdatedAt int64
	LastSeen  int64
}
|
||||
|
||||
// PushDeviceStore persists per-user device registrations.
|
||||
type PushDeviceStore interface {
|
||||
// Upsert registers or updates a device. The Token is encrypted by the
|
||||
// implementation before being written to durable storage.
|
||||
Upsert(ctx context.Context, dev PushDevice) error
|
||||
|
||||
// Delete removes a single device by ID, scoped to the namespace.
|
||||
Delete(ctx context.Context, namespace, id string) error
|
||||
|
||||
// ListForUser returns all devices for a user within a namespace.
|
||||
// Tokens in the returned slice are decrypted.
|
||||
ListForUser(ctx context.Context, namespace, userID string) ([]PushDevice, error)
|
||||
}
|
||||
|
||||
// Sentinel errors returned by the push package.
var (
	// ErrUnknownProvider is what the dispatcher returns when a device
	// names a provider that has not been registered.
	ErrUnknownProvider = errors.New("push: unknown provider")

	// ErrEmptyToken is what a provider returns when Send is called with an
	// empty DeviceToken.
	ErrEmptyToken = errors.New("push: empty device token")
)
|
||||
295
core/pkg/serverless/aggregator/aggregator.go
Normal file
295
core/pkg/serverless/aggregator/aggregator.go
Normal file
@ -0,0 +1,295 @@
|
||||
// Package aggregator buffers PubSub trigger events per
|
||||
// (namespace, function, trigger) and flushes them as a single batched
|
||||
// invocation. It's used by the PubSub trigger dispatcher when a trigger
|
||||
// declares aggregation_window_ms > 0.
|
||||
//
|
||||
// State is local to each node — buffers are not replicated. This is by
|
||||
// design: aggregation is intended for high-frequency, lossy event streams
|
||||
// (presence, VAD signals, metrics). Crash recovery is not provided; an
|
||||
// orderly shutdown flushes pending buffers via Shutdown().
|
||||
package aggregator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// DefaultMaxBatchSize is the batch-size cap Buffer falls back to when a request's MaxBatchSize is <= 0.
|
||||
const DefaultMaxBatchSize = 100
|
||||
|
||||
// Event is a single buffered message. It mirrors the shape of the
// dispatcher's PubSubEvent but is declared locally so this package does not
// create an import cycle.
type Event struct {
	// Topic is the topic the message was published on.
	Topic string `json:"topic"`
	// Data is the message payload, kept as raw JSON and emitted verbatim.
	Data json.RawMessage `json:"data"`
	// Namespace is the namespace the event belongs to.
	Namespace string `json:"namespace"`
	// TriggerDepth — NOTE(review): presumably the trigger-chain nesting
	// depth; confirm against the dispatcher.
	TriggerDepth int `json:"trigger_depth"`
	// Timestamp — NOTE(review): unit is not visible here; confirm.
	Timestamp int64 `json:"timestamp"`
}
|
||||
|
||||
// BatchedPayload is what the function receives on a flush.
|
||||
// `Batched: true` lets a function differentiate single vs. batched mode
|
||||
// by parsing this discriminator first.
|
||||
type BatchedPayload struct {
|
||||
Batched bool `json:"batched"`
|
||||
Events []Event `json:"events"`
|
||||
}
|
||||
|
||||
// FlushFn is the caller-supplied callback invoked when a buffer flushes. It
|
||||
// receives the JSON-encoded BatchedPayload and a context whose deadline is the
|
||||
// aggregator's flush timeout. The aggregator does not retry after a flush.
|
||||
type FlushFn func(ctx context.Context, payload []byte)
|
||||
|
||||
// FlushReason records what caused a flush; it is useful for metrics and is
// written to the debug log on each flush.
type FlushReason string

// Causes of a flush.
const (
	// FlushReasonTimer: the aggregation window elapsed.
	FlushReasonTimer FlushReason = "timer"
	// FlushReasonSize: the buffer reached its maximum batch size.
	FlushReasonSize FlushReason = "size"
	// FlushReasonShutdown: Shutdown drained the buffer.
	FlushReasonShutdown FlushReason = "shutdown"
)
|
||||
|
||||
// FlushFnWithReason is the internal form of FlushFn that additionally receives
|
||||
// the FlushReason that triggered the flush; Buffer wraps each FlushFn in one.
|
||||
type FlushFnWithReason func(ctx context.Context, payload []byte, reason FlushReason)
|
||||
|
||||
// bufferKey is the map key for one in-memory buffer: the
// (namespace, function, trigger) tuple events are grouped by.
type bufferKey struct {
	Namespace, FunctionID, TriggerID string
}
|
||||
|
||||
type bufferEntry struct {
|
||||
events []Event
|
||||
timer *time.Timer
|
||||
windowMs int
|
||||
maxBatch int
|
||||
flushFn FlushFnWithReason
|
||||
}
|
||||
|
||||
// Aggregator buffers events per (namespace, function, trigger) and flushes
|
||||
// either when the window timer fires or when MaxBatch is reached.
|
||||
type Aggregator struct {
|
||||
mu sync.Mutex
|
||||
buffers map[bufferKey]*bufferEntry
|
||||
logger *zap.Logger
|
||||
flushTimeout time.Duration
|
||||
// inflight tracks dispatched flush goroutines so Shutdown can wait
|
||||
// for them to finish (or time out) before returning.
|
||||
inflight sync.WaitGroup
|
||||
}
|
||||
|
||||
// New creates an Aggregator. flushTimeout bounds the context passed to FlushFn.
|
||||
// 0 selects a sane default (60s).
|
||||
func New(logger *zap.Logger, flushTimeout time.Duration) *Aggregator {
|
||||
if flushTimeout <= 0 {
|
||||
flushTimeout = 60 * time.Second
|
||||
}
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
return &Aggregator{
|
||||
buffers: map[bufferKey]*bufferEntry{},
|
||||
logger: logger.Named("aggregator"),
|
||||
flushTimeout: flushTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// BufferRequest carries everything needed to add an event.
|
||||
type BufferRequest struct {
|
||||
Namespace string
|
||||
FunctionID string
|
||||
TriggerID string
|
||||
WindowMs int
|
||||
MaxBatchSize int
|
||||
FlushFn FlushFn // simple public form; internally promoted to FlushFnWithReason
|
||||
Event Event
|
||||
}
|
||||
|
||||
// Buffer adds an event to the matching buffer. Returns immediately —
|
||||
// the function is invoked later, asynchronously, when the window or
|
||||
// size threshold fires.
|
||||
//
|
||||
// If WindowMs <= 0, this method panics with a programming-error message
|
||||
// to surface misuse: callers should not buffer non-aggregating triggers.
|
||||
func (a *Aggregator) Buffer(req BufferRequest) {
|
||||
if req.WindowMs <= 0 {
|
||||
// Aggregator should never be called for non-aggregating triggers.
|
||||
// Panicking here makes the caller bug obvious during development.
|
||||
panic("aggregator: Buffer called with WindowMs <= 0")
|
||||
}
|
||||
maxBatch := req.MaxBatchSize
|
||||
if maxBatch <= 0 {
|
||||
maxBatch = DefaultMaxBatchSize
|
||||
}
|
||||
|
||||
key := bufferKey{Namespace: req.Namespace, FunctionID: req.FunctionID, TriggerID: req.TriggerID}
|
||||
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
entry, ok := a.buffers[key]
|
||||
if !ok {
|
||||
// Promote the user-facing FlushFn into the reason-aware variant.
|
||||
// We capture req.FlushFn so subsequent Buffer calls keep using it.
|
||||
userFn := req.FlushFn
|
||||
entry = &bufferEntry{
|
||||
events: make([]Event, 0, maxBatch),
|
||||
windowMs: req.WindowMs,
|
||||
maxBatch: maxBatch,
|
||||
flushFn: func(ctx context.Context, payload []byte, reason FlushReason) {
|
||||
if userFn != nil {
|
||||
userFn(ctx, payload)
|
||||
}
|
||||
},
|
||||
}
|
||||
a.buffers[key] = entry
|
||||
}
|
||||
|
||||
entry.events = append(entry.events, req.Event)
|
||||
|
||||
// Start the window timer on the first event of a new window.
|
||||
if entry.timer == nil {
|
||||
// Capture key by value for the closure.
|
||||
k := key
|
||||
entry.timer = time.AfterFunc(time.Duration(entry.windowMs)*time.Millisecond, func() {
|
||||
a.flushByTimer(k)
|
||||
})
|
||||
}
|
||||
|
||||
// Size-triggered flush.
|
||||
if len(entry.events) >= entry.maxBatch {
|
||||
a.flushLocked(key, entry, FlushReasonSize)
|
||||
}
|
||||
}
|
||||
|
||||
// flushByTimer is invoked by time.AfterFunc; it acquires the lock then flushes.
|
||||
func (a *Aggregator) flushByTimer(key bufferKey) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
entry, ok := a.buffers[key]
|
||||
if !ok {
|
||||
// Buffer already flushed by size threshold and the bucket removed.
|
||||
return
|
||||
}
|
||||
if len(entry.events) == 0 {
|
||||
// Defensive: empty buffer — drop it so the map stays bounded.
|
||||
delete(a.buffers, key)
|
||||
return
|
||||
}
|
||||
a.flushLocked(key, entry, FlushReasonTimer)
|
||||
}
|
||||
|
||||
// flushLocked must be called with a.mu held. It snapshots the current
|
||||
// events, removes the bucket entry, then dispatches the flush in a
|
||||
// goroutine so the caller doesn't block on the function invocation.
|
||||
//
|
||||
// Removing the bucket on flush keeps the buffers map bounded over the
|
||||
// lifetime of the process. If a subsequent event arrives for the same
|
||||
// (namespace, function, trigger) tuple, Buffer recreates the entry.
|
||||
func (a *Aggregator) flushLocked(key bufferKey, entry *bufferEntry, reason FlushReason) {
|
||||
if entry.timer != nil {
|
||||
entry.timer.Stop()
|
||||
entry.timer = nil
|
||||
}
|
||||
if len(entry.events) == 0 {
|
||||
// Empty bucket — drop it so the map doesn't accumulate.
|
||||
delete(a.buffers, key)
|
||||
return
|
||||
}
|
||||
events := entry.events
|
||||
|
||||
payload, err := json.Marshal(BatchedPayload{Batched: true, Events: events})
|
||||
if err != nil {
|
||||
a.logger.Error("failed to marshal batched payload",
|
||||
zap.String("namespace", key.Namespace),
|
||||
zap.String("function_id", key.FunctionID),
|
||||
zap.String("trigger_id", key.TriggerID),
|
||||
zap.Int("batch_size", len(events)),
|
||||
zap.Error(err),
|
||||
)
|
||||
// Still drop the bucket — there's no point retrying with the same data.
|
||||
delete(a.buffers, key)
|
||||
return
|
||||
}
|
||||
|
||||
a.logger.Debug("aggregator flush",
|
||||
zap.String("namespace", key.Namespace),
|
||||
zap.String("function_id", key.FunctionID),
|
||||
zap.String("trigger_id", key.TriggerID),
|
||||
zap.Int("batch_size", len(events)),
|
||||
zap.String("reason", string(reason)),
|
||||
)
|
||||
|
||||
flushFn := entry.flushFn
|
||||
timeout := a.flushTimeout
|
||||
delete(a.buffers, key)
|
||||
|
||||
a.inflight.Add(1)
|
||||
go func() {
|
||||
defer a.inflight.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
flushFn(ctx, payload, reason)
|
||||
}()
|
||||
}
|
||||
|
||||
// Shutdown drains all non-empty buffers and waits for the resulting flush
|
||||
// invocations to finish, bounded by `wait`. Callers should pass a wait
|
||||
// long enough to cover one function invocation (e.g. 5–10 seconds) but
|
||||
// short enough that a misbehaving function can't delay process exit.
|
||||
//
|
||||
// Returns true if all in-flight flushes completed before the deadline,
|
||||
// false on timeout (in which case some events are effectively lost).
|
||||
func (a *Aggregator) Shutdown(wait time.Duration) bool {
|
||||
a.mu.Lock()
|
||||
keys := make([]bufferKey, 0, len(a.buffers))
|
||||
for k := range a.buffers {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
for _, k := range keys {
|
||||
entry := a.buffers[k]
|
||||
if entry == nil {
|
||||
continue
|
||||
}
|
||||
if entry.timer != nil {
|
||||
entry.timer.Stop()
|
||||
entry.timer = nil
|
||||
}
|
||||
a.flushLocked(k, entry, FlushReasonShutdown)
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
if wait <= 0 {
|
||||
return true
|
||||
}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
a.inflight.Wait()
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return true
|
||||
case <-time.After(wait):
|
||||
a.logger.Warn("aggregator shutdown timed out; some buffered events may be lost")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Stats reports the current number of buffered events across all keys.
|
||||
// Useful for metrics.
|
||||
func (a *Aggregator) Stats() (numBuffers, totalEvents int) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
numBuffers = len(a.buffers)
|
||||
for _, e := range a.buffers {
|
||||
totalEvents += len(e.events)
|
||||
}
|
||||
return
|
||||
}
|
||||
307
core/pkg/serverless/aggregator/aggregator_test.go
Normal file
307
core/pkg/serverless/aggregator/aggregator_test.go
Normal file
@ -0,0 +1,307 @@
|
||||
package aggregator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestBuffer_panics_on_zero_window(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatal("expected panic when WindowMs <= 0")
|
||||
}
|
||||
}()
|
||||
a := New(zap.NewNop(), time.Second)
|
||||
a.Buffer(BufferRequest{
|
||||
Namespace: "ns",
|
||||
FunctionID: "fn",
|
||||
TriggerID: "tr",
|
||||
WindowMs: 0,
|
||||
FlushFn: func(ctx context.Context, payload []byte) {},
|
||||
Event: Event{Topic: "t"},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuffer_flushes_on_timer(t *testing.T) {
|
||||
a := New(zap.NewNop(), 5*time.Second)
|
||||
|
||||
var (
|
||||
got []Event
|
||||
gotMu sync.Mutex
|
||||
done = make(chan struct{})
|
||||
)
|
||||
|
||||
flush := func(ctx context.Context, payload []byte) {
|
||||
var p BatchedPayload
|
||||
if err := json.Unmarshal(payload, &p); err != nil {
|
||||
t.Errorf("unmarshal: %v", err)
|
||||
}
|
||||
gotMu.Lock()
|
||||
got = append(got, p.Events...)
|
||||
gotMu.Unlock()
|
||||
close(done)
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
a.Buffer(BufferRequest{
|
||||
Namespace: "ns",
|
||||
FunctionID: "fn",
|
||||
TriggerID: "tr",
|
||||
WindowMs: 100, // short window so test runs fast
|
||||
FlushFn: flush,
|
||||
Event: Event{Topic: "presence:user", Data: json.RawMessage(`"e"`)},
|
||||
})
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("flush did not fire within 2s")
|
||||
}
|
||||
|
||||
gotMu.Lock()
|
||||
defer gotMu.Unlock()
|
||||
if len(got) != 3 {
|
||||
t.Errorf("expected 3 buffered events, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuffer_flushes_on_max_batch_size(t *testing.T) {
|
||||
a := New(zap.NewNop(), 5*time.Second)
|
||||
|
||||
var (
|
||||
flushCount int32
|
||||
flushSize int32
|
||||
done = make(chan struct{})
|
||||
)
|
||||
|
||||
flush := func(ctx context.Context, payload []byte) {
|
||||
var p BatchedPayload
|
||||
_ = json.Unmarshal(payload, &p)
|
||||
atomic.AddInt32(&flushCount, 1)
|
||||
atomic.StoreInt32(&flushSize, int32(len(p.Events)))
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
close(done)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
a.Buffer(BufferRequest{
|
||||
Namespace: "ns",
|
||||
FunctionID: "fn",
|
||||
TriggerID: "tr",
|
||||
WindowMs: 30_000, // long enough that the timer won't fire
|
||||
MaxBatchSize: 5,
|
||||
FlushFn: flush,
|
||||
Event: Event{Topic: "t"},
|
||||
})
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("max-batch flush did not fire")
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&flushCount) != 1 {
|
||||
t.Errorf("expected 1 flush, got %d", flushCount)
|
||||
}
|
||||
if atomic.LoadInt32(&flushSize) != 5 {
|
||||
t.Errorf("expected batch size 5, got %d", flushSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuffer_separate_keys_independent(t *testing.T) {
|
||||
a := New(zap.NewNop(), 5*time.Second)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
counts = map[string]int{}
|
||||
flush = func(name string) FlushFn {
|
||||
return func(ctx context.Context, payload []byte) {
|
||||
var p BatchedPayload
|
||||
_ = json.Unmarshal(payload, &p)
|
||||
mu.Lock()
|
||||
counts[name] += len(p.Events)
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
a.Buffer(BufferRequest{
|
||||
Namespace: "ns", FunctionID: "fn", TriggerID: "tr-A",
|
||||
WindowMs: 100, FlushFn: flush("A"),
|
||||
Event: Event{Topic: "a"},
|
||||
})
|
||||
a.Buffer(BufferRequest{
|
||||
Namespace: "ns", FunctionID: "fn", TriggerID: "tr-B",
|
||||
WindowMs: 100, FlushFn: flush("B"),
|
||||
Event: Event{Topic: "b"},
|
||||
})
|
||||
a.Buffer(BufferRequest{
|
||||
Namespace: "ns", FunctionID: "fn", TriggerID: "tr-A",
|
||||
WindowMs: 100, FlushFn: flush("A"),
|
||||
Event: Event{Topic: "a2"},
|
||||
})
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if counts["A"] != 2 {
|
||||
t.Errorf("A: expected 2 events, got %d", counts["A"])
|
||||
}
|
||||
if counts["B"] != 1 {
|
||||
t.Errorf("B: expected 1 event, got %d", counts["B"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdown_flushes_all_buffers(t *testing.T) {
|
||||
a := New(zap.NewNop(), 2*time.Second)
|
||||
|
||||
var flushed int32
|
||||
flush := func(ctx context.Context, payload []byte) {
|
||||
atomic.AddInt32(&flushed, 1)
|
||||
}
|
||||
|
||||
for i := 0; i < 4; i++ {
|
||||
a.Buffer(BufferRequest{
|
||||
Namespace: "ns", FunctionID: "fn", TriggerID: "tr",
|
||||
WindowMs: 30_000,
|
||||
FlushFn: flush,
|
||||
Event: Event{Topic: "t"},
|
||||
})
|
||||
}
|
||||
// Different trigger key — should produce a separate flush.
|
||||
a.Buffer(BufferRequest{
|
||||
Namespace: "ns", FunctionID: "fn", TriggerID: "other",
|
||||
WindowMs: 30_000,
|
||||
FlushFn: flush,
|
||||
Event: Event{Topic: "t2"},
|
||||
})
|
||||
|
||||
a.Shutdown(2*time.Second)
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for atomic.LoadInt32(&flushed) < 2 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("expected 2 flushes from Shutdown, got %d", flushed)
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdown_skips_empty_buffers(t *testing.T) {
|
||||
a := New(zap.NewNop(), 2*time.Second)
|
||||
|
||||
var flushed int32
|
||||
flush := func(ctx context.Context, payload []byte) {
|
||||
atomic.AddInt32(&flushed, 1)
|
||||
}
|
||||
|
||||
// Add an event to create the buffer entry, then drain via size flush.
|
||||
a.Buffer(BufferRequest{
|
||||
Namespace: "ns", FunctionID: "fn", TriggerID: "tr",
|
||||
WindowMs: 30_000, MaxBatchSize: 1,
|
||||
FlushFn: flush, Event: Event{Topic: "t"},
|
||||
})
|
||||
|
||||
// Wait for the size-triggered flush to drain.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for atomic.LoadInt32(&flushed) < 1 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("size flush didn't fire")
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Now the buffer is empty. Shutdown should not flush again.
|
||||
a.Shutdown(2*time.Second)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
if atomic.LoadInt32(&flushed) != 1 {
|
||||
t.Errorf("Shutdown flushed an empty buffer: total flushes %d", flushed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStats_reports_buffered_state(t *testing.T) {
|
||||
a := New(zap.NewNop(), 2*time.Second)
|
||||
flush := func(ctx context.Context, payload []byte) {}
|
||||
|
||||
a.Buffer(BufferRequest{Namespace: "ns", FunctionID: "fn", TriggerID: "a", WindowMs: 30_000, FlushFn: flush, Event: Event{Topic: "t"}})
|
||||
a.Buffer(BufferRequest{Namespace: "ns", FunctionID: "fn", TriggerID: "a", WindowMs: 30_000, FlushFn: flush, Event: Event{Topic: "t"}})
|
||||
a.Buffer(BufferRequest{Namespace: "ns", FunctionID: "fn", TriggerID: "b", WindowMs: 30_000, FlushFn: flush, Event: Event{Topic: "t"}})
|
||||
|
||||
bufs, evs := a.Stats()
|
||||
if bufs != 2 {
|
||||
t.Errorf("expected 2 buffers, got %d", bufs)
|
||||
}
|
||||
if evs != 3 {
|
||||
t.Errorf("expected 3 buffered events, got %d", evs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuffer_concurrent_writes_no_race(t *testing.T) {
|
||||
// Run with -race: this should not detect any data races.
|
||||
a := New(zap.NewNop(), 2*time.Second)
|
||||
flush := func(ctx context.Context, payload []byte) {}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for g := 0; g < 8; g++ {
|
||||
wg.Add(1)
|
||||
go func(gid int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 50; i++ {
|
||||
a.Buffer(BufferRequest{
|
||||
Namespace: "ns",
|
||||
FunctionID: "fn",
|
||||
TriggerID: "tr",
|
||||
WindowMs: 200,
|
||||
FlushFn: flush,
|
||||
Event: Event{Topic: "t"},
|
||||
})
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
wg.Wait()
|
||||
// Drain.
|
||||
a.Shutdown(2*time.Second)
|
||||
}
|
||||
|
||||
func TestBuffer_payload_includes_batched_true_and_topic(t *testing.T) {
|
||||
a := New(zap.NewNop(), 2*time.Second)
|
||||
|
||||
got := make(chan BatchedPayload, 1)
|
||||
flush := func(ctx context.Context, payload []byte) {
|
||||
var p BatchedPayload
|
||||
if err := json.Unmarshal(payload, &p); err != nil {
|
||||
t.Errorf("unmarshal: %v", err)
|
||||
}
|
||||
got <- p
|
||||
}
|
||||
|
||||
a.Buffer(BufferRequest{
|
||||
Namespace: "ns", FunctionID: "fn", TriggerID: "tr",
|
||||
WindowMs: 50, FlushFn: flush,
|
||||
Event: Event{Topic: "presence:user-1", Data: json.RawMessage(`{"x":1}`)},
|
||||
})
|
||||
|
||||
select {
|
||||
case p := <-got:
|
||||
if !p.Batched {
|
||||
t.Error("payload should have Batched=true")
|
||||
}
|
||||
if len(p.Events) != 1 || p.Events[0].Topic != "presence:user-1" {
|
||||
t.Errorf("unexpected events: %+v", p.Events)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("flush did not fire")
|
||||
}
|
||||
}
|
||||
@ -328,6 +328,8 @@ func (e *Engine) registerHostModule(ctx context.Context) error {
|
||||
NewFunctionBuilder().WithFunc(e.hCacheIncrBy).Export("cache_incr_by").
|
||||
NewFunctionBuilder().WithFunc(e.hHTTPFetch).Export("http_fetch").
|
||||
NewFunctionBuilder().WithFunc(e.hPubSubPublish).Export("pubsub_publish").
|
||||
NewFunctionBuilder().WithFunc(e.hPubSubPublishBatch).Export("pubsub_publish_batch").
|
||||
NewFunctionBuilder().WithFunc(e.hPushSend).Export("push_send").
|
||||
NewFunctionBuilder().WithFunc(e.hLogInfo).Export("log_info").
|
||||
NewFunctionBuilder().WithFunc(e.hLogError).Export("log_error").
|
||||
Instantiate(ctx)
|
||||
@ -517,6 +519,46 @@ func (e *Engine) hPubSubPublish(ctx context.Context, mod api.Module, topicPtr, t
|
||||
return 1 // Success
|
||||
}
|
||||
|
||||
// hPubSubPublishBatch is the WASM-callable wrapper for PubSubPublishBatch.
|
||||
// Input: pointer/length of a JSON array of {topic, data_base64}.
|
||||
// Returns 1 on success, 0 on error.
|
||||
func (e *Engine) hPubSubPublishBatch(ctx context.Context, mod api.Module, msgsPtr, msgsLen uint32) uint32 {
|
||||
msgsJSON, ok := e.executor.ReadFromGuest(mod, msgsPtr, msgsLen)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
if err := e.hostServices.PubSubPublishBatch(ctx, msgsJSON); err != nil {
|
||||
e.logger.Error("host function pubsub_publish_batch failed", zap.Error(err))
|
||||
return 0
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// hPushSend is the WASM-callable wrapper for PushSend.
|
||||
// Inputs:
|
||||
// userIDPtr/userIDLen — UTF-8 user ID to push to (within the function's
|
||||
// own namespace; the namespace is server-side trusted)
|
||||
// msgPtr/msgLen — JSON payload matching hostfunctions.PushSendArgs
|
||||
// Returns 1 on success, 0 on error.
|
||||
func (e *Engine) hPushSend(ctx context.Context, mod api.Module,
|
||||
userIDPtr, userIDLen, msgPtr, msgLen uint32) uint32 {
|
||||
userID, ok := e.executor.ReadFromGuest(mod, userIDPtr, userIDLen)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
msgJSON, ok := e.executor.ReadFromGuest(mod, msgPtr, msgLen)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
if err := e.hostServices.PushSend(ctx, string(userID), msgJSON); err != nil {
|
||||
e.logger.Error("host function push_send failed",
|
||||
zap.String("user_id", string(userID)),
|
||||
zap.Error(err))
|
||||
return 0
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (e *Engine) hLogInfo(ctx context.Context, mod api.Module, ptr, size uint32) {
|
||||
msg, ok := e.executor.ReadFromGuest(mod, ptr, size)
|
||||
if ok {
|
||||
|
||||
@ -90,6 +90,14 @@ func (m *mockHostServices) PubSubPublish(ctx context.Context, topic string, data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHostServices) PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHostServices) PushSend(ctx context.Context, userID string, msgJSON []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHostServices) WSSend(ctx context.Context, clientID string, data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
||||
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||
"github.com/DeBrosOfficial/network/pkg/push"
|
||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||
"github.com/DeBrosOfficial/network/pkg/tlsutil"
|
||||
@ -13,6 +14,10 @@ import (
|
||||
)
|
||||
|
||||
// NewHostFunctions creates a new HostFunctions instance.
|
||||
//
|
||||
// pushDispatcher may be nil when push isn't configured on this gateway —
|
||||
// in that case PushSend hostfunc returns nil (silent no-op) so functions
|
||||
// remain portable across deployments with/without push.
|
||||
func NewHostFunctions(
|
||||
db rqlite.Client,
|
||||
cacheClient olriclib.Client,
|
||||
@ -20,6 +25,7 @@ func NewHostFunctions(
|
||||
pubsubAdapter *pubsub.ClientAdapter,
|
||||
wsManager serverless.WebSocketManager,
|
||||
secrets serverless.SecretsManager,
|
||||
pushDispatcher *push.PushDispatcher,
|
||||
cfg HostFunctionsConfig,
|
||||
logger *zap.Logger,
|
||||
) *HostFunctions {
|
||||
@ -29,15 +35,16 @@ func NewHostFunctions(
|
||||
}
|
||||
|
||||
return &HostFunctions{
|
||||
db: db,
|
||||
cacheClient: cacheClient,
|
||||
storage: storage,
|
||||
ipfsAPIURL: cfg.IPFSAPIURL,
|
||||
pubsub: pubsubAdapter,
|
||||
wsManager: wsManager,
|
||||
secrets: secrets,
|
||||
httpClient: tlsutil.NewHTTPClient(httpTimeout),
|
||||
logger: logger,
|
||||
logs: make([]serverless.LogEntry, 0),
|
||||
db: db,
|
||||
cacheClient: cacheClient,
|
||||
storage: storage,
|
||||
ipfsAPIURL: cfg.IPFSAPIURL,
|
||||
pubsub: pubsubAdapter,
|
||||
wsManager: wsManager,
|
||||
secrets: secrets,
|
||||
pushDispatcher: pushDispatcher,
|
||||
httpClient: tlsutil.NewHTTPClient(httpTimeout),
|
||||
logger: logger,
|
||||
logs: make([]serverless.LogEntry, 0),
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,8 +2,11 @@ package hostfunctions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||
)
|
||||
|
||||
@ -21,6 +24,64 @@ func (h *HostFunctions) PubSubPublish(ctx context.Context, topic string, data []
|
||||
return nil
|
||||
}
|
||||
|
||||
// pubSubBatchEntry is one element of the JSON array that PubSubPublishBatch
// decodes: a destination topic and its payload carried as base64 text.
type pubSubBatchEntry struct {
	// Topic is the topic to publish to; PubSubPublishBatch rejects "".
	Topic string `json:"topic"`
	// DataB64 is the message payload, decoded with base64.StdEncoding.
	DataB64 string `json:"data_base64"`
}
|
||||
|
||||
// PubSubPublishBatch publishes multiple messages in parallel.
|
||||
//
|
||||
// Input is JSON: [{"topic":"...","data_base64":"..."}, ...]
|
||||
// Up to pubsub.MaxBatchSize entries per call.
|
||||
//
|
||||
// Default behavior is fail-fast (first publish error is returned). The
|
||||
// host function does not currently expose a best-effort flag — WASM
|
||||
// callers that need it should call this function multiple times in
|
||||
// chunks they're willing to retry independently.
|
||||
func (h *HostFunctions) PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error {
|
||||
if h.pubsub == nil {
|
||||
return &serverless.HostFunctionError{Function: "pubsub_publish_batch", Cause: fmt.Errorf("pubsub not available")}
|
||||
}
|
||||
|
||||
var entries []pubSubBatchEntry
|
||||
if err := json.Unmarshal(msgsJSON, &entries); err != nil {
|
||||
return &serverless.HostFunctionError{Function: "pubsub_publish_batch", Cause: fmt.Errorf("invalid json: %w", err)}
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(entries) > pubsub.MaxBatchSize {
|
||||
return &serverless.HostFunctionError{
|
||||
Function: "pubsub_publish_batch",
|
||||
Cause: fmt.Errorf("too many messages: max %d per batch", pubsub.MaxBatchSize),
|
||||
}
|
||||
}
|
||||
|
||||
msgs := make([]pubsub.TopicMessage, 0, len(entries))
|
||||
for i, e := range entries {
|
||||
if e.Topic == "" {
|
||||
return &serverless.HostFunctionError{
|
||||
Function: "pubsub_publish_batch",
|
||||
Cause: fmt.Errorf("entry %d: empty topic", i),
|
||||
}
|
||||
}
|
||||
data, err := base64.StdEncoding.DecodeString(e.DataB64)
|
||||
if err != nil {
|
||||
return &serverless.HostFunctionError{
|
||||
Function: "pubsub_publish_batch",
|
||||
Cause: fmt.Errorf("entry %d (topic %q): bad base64: %w", i, e.Topic, err),
|
||||
}
|
||||
}
|
||||
msgs = append(msgs, pubsub.TopicMessage{Topic: e.Topic, Data: data})
|
||||
}
|
||||
|
||||
if err := h.pubsub.PublishBatch(ctx, msgs, pubsub.PublishBatchOptions{}); err != nil {
|
||||
return &serverless.HostFunctionError{Function: "pubsub_publish_batch", Cause: err}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WSSend sends data to a specific WebSocket client.
|
||||
func (h *HostFunctions) WSSend(ctx context.Context, clientID string, data []byte) error {
|
||||
if h.wsManager == nil {
|
||||
|
||||
103
core/pkg/serverless/hostfunctions/push.go
Normal file
103
core/pkg/serverless/hostfunctions/push.go
Normal file
@ -0,0 +1,103 @@
|
||||
package hostfunctions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/push"
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||
)
|
||||
|
||||
// PushSendArgs describes the JSON document a WASM caller passes as the
// msgJSON argument of PushSend. PushSend copies each field into a
// push.PushMessage; only Priority is translated (see below).
type PushSendArgs struct {
	Title   string `json:"title,omitempty"`
	Body    string `json:"body,omitempty"`
	Channel string `json:"channel,omitempty"`
	// Priority is "high", "normal" or empty. PushSend maps "high" to
	// push.PriorityHigh and every other value to push.PriorityNormal.
	Priority string                 `json:"priority,omitempty"`
	Badge    int                    `json:"badge,omitempty"`
	Sound    string                 `json:"sound,omitempty"`
	Data     map[string]interface{} `json:"data,omitempty"`
}
|
||||
|
||||
// MaxPushSendArgsBytes caps the JSON argument size at 16 KiB. Push payloads
// are small by construction (APNs caps at 4KB, ntfy/Expo similar).
|
||||
const MaxPushSendArgsBytes = 16 * 1024
|
||||
|
||||
// PushSend implements serverless.HostServices.PushSend.
|
||||
//
|
||||
// Sends a push notification to all devices the user has registered in the
|
||||
// function's namespace. The caller can only target users in their own
|
||||
// namespace — the dispatcher reads the namespace from the invocation
|
||||
// context (set by the engine before invoking).
|
||||
//
|
||||
// If push is not configured on this gateway (no dispatcher), this returns
|
||||
// nil (silent no-op) so functions remain portable across environments.
|
||||
func (h *HostFunctions) PushSend(ctx context.Context, userID string, msgJSON []byte) error {
|
||||
if h.pushDispatcher == nil {
|
||||
// Silent no-op — push isn't configured on this gateway.
|
||||
return nil
|
||||
}
|
||||
if userID == "" {
|
||||
return &serverless.HostFunctionError{
|
||||
Function: "push_send",
|
||||
Cause: fmt.Errorf("user_id required"),
|
||||
}
|
||||
}
|
||||
if len(msgJSON) > MaxPushSendArgsBytes {
|
||||
return &serverless.HostFunctionError{
|
||||
Function: "push_send",
|
||||
Cause: fmt.Errorf("msg too large: max %d bytes", MaxPushSendArgsBytes),
|
||||
}
|
||||
}
|
||||
|
||||
var args PushSendArgs
|
||||
if err := json.Unmarshal(msgJSON, &args); err != nil {
|
||||
return &serverless.HostFunctionError{
|
||||
Function: "push_send",
|
||||
Cause: fmt.Errorf("invalid json: %w", err),
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve namespace from the current invocation context. A function
|
||||
// can NEVER push to another namespace's users — the namespace is
|
||||
// trusted server-side, not from the WASM input.
|
||||
h.invCtxLock.RLock()
|
||||
var namespace string
|
||||
if h.invCtx != nil {
|
||||
namespace = h.invCtx.Namespace
|
||||
}
|
||||
h.invCtxLock.RUnlock()
|
||||
|
||||
if namespace == "" {
|
||||
return &serverless.HostFunctionError{
|
||||
Function: "push_send",
|
||||
Cause: fmt.Errorf("no namespace in invocation context"),
|
||||
}
|
||||
}
|
||||
|
||||
priority := push.PriorityNormal
|
||||
switch args.Priority {
|
||||
case "high":
|
||||
priority = push.PriorityHigh
|
||||
case "normal", "":
|
||||
priority = push.PriorityNormal
|
||||
}
|
||||
|
||||
msg := push.PushMessage{
|
||||
Title: args.Title,
|
||||
Body: args.Body,
|
||||
Channel: args.Channel,
|
||||
Priority: priority,
|
||||
Badge: args.Badge,
|
||||
Sound: args.Sound,
|
||||
Data: args.Data,
|
||||
}
|
||||
|
||||
if err := h.pushDispatcher.SendToUser(ctx, namespace, userID, msg); err != nil {
|
||||
return &serverless.HostFunctionError{Function: "push_send", Cause: err}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
||||
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||
"github.com/DeBrosOfficial/network/pkg/push"
|
||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||
olriclib "github.com/olric-data/olric"
|
||||
@ -32,6 +33,10 @@ type HostFunctions struct {
|
||||
httpClient *http.Client
|
||||
logger *zap.Logger
|
||||
|
||||
// pushDispatcher may be nil when push isn't configured for this gateway.
|
||||
// In that case PushSend returns nil silently — see hostfunctions/push.go.
|
||||
pushDispatcher *push.PushDispatcher
|
||||
|
||||
// Current invocation context (set per-execution)
|
||||
invCtx *serverless.InvocationContext
|
||||
invCtxLock sync.RWMutex
|
||||
|
||||
@ -176,6 +176,14 @@ func (m *MockHostServices) PubSubPublish(ctx context.Context, topic string, data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockHostServices) PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockHostServices) PushSend(ctx context.Context, userID string, msgJSON []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockHostServices) WSSend(ctx context.Context, clientID string, data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -6,14 +6,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless/aggregator"
|
||||
olriclib "github.com/olric-data/olric"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
// triggerCacheDMap is the Olric DMap name for caching trigger lookups.
|
||||
triggerCacheDMap = "pubsub_triggers"
|
||||
|
||||
// maxTriggerDepth prevents infinite loops when triggered functions publish
|
||||
// back to the same topic via the HTTP API.
|
||||
maxTriggerDepth = 5
|
||||
@ -37,6 +35,7 @@ type PubSubDispatcher struct {
|
||||
store *PubSubTriggerStore
|
||||
invoker *serverless.Invoker
|
||||
olricClient olriclib.Client // may be nil (cache disabled)
|
||||
aggregator *aggregator.Aggregator
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
@ -51,10 +50,17 @@ func NewPubSubDispatcher(
|
||||
store: store,
|
||||
invoker: invoker,
|
||||
olricClient: olricClient,
|
||||
aggregator: aggregator.New(logger, dispatchTimeout),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Aggregator exposes the underlying aggregator so callers (gateway lifecycle)
|
||||
// can flush pending buffers on shutdown.
|
||||
func (d *PubSubDispatcher) Aggregator() *aggregator.Aggregator {
|
||||
return d.aggregator
|
||||
}
|
||||
|
||||
// Dispatch looks up all triggers registered for the given topic+namespace and
|
||||
// invokes matching functions asynchronously. Each invocation runs in its own
|
||||
// goroutine and does not block the caller.
|
||||
@ -82,7 +88,7 @@ func (d *PubSubDispatcher) Dispatch(ctx context.Context, namespace, topic string
|
||||
return
|
||||
}
|
||||
|
||||
// Build the event payload once for all invocations
|
||||
// Build the per-event payload once for non-aggregating dispatches.
|
||||
event := PubSubEvent{
|
||||
Topic: topic,
|
||||
Data: json.RawMessage(data),
|
||||
@ -90,11 +96,6 @@ func (d *PubSubDispatcher) Dispatch(ctx context.Context, namespace, topic string
|
||||
TriggerDepth: depth + 1,
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
eventJSON, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
d.logger.Error("Failed to marshal PubSub event", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
d.logger.Debug("Dispatching PubSub triggers",
|
||||
zap.String("namespace", namespace),
|
||||
@ -103,94 +104,81 @@ func (d *PubSubDispatcher) Dispatch(ctx context.Context, namespace, topic string
|
||||
zap.Int("depth", depth),
|
||||
)
|
||||
|
||||
var (
|
||||
eventJSON []byte
|
||||
marshalErr error
|
||||
)
|
||||
|
||||
for _, match := range matches {
|
||||
if match.AggregationWindowMs > 0 {
|
||||
d.bufferEvent(match, event)
|
||||
continue
|
||||
}
|
||||
// Lazily marshal — non-aggregating triggers need eventJSON.
|
||||
if eventJSON == nil && marshalErr == nil {
|
||||
eventJSON, marshalErr = json.Marshal(event)
|
||||
if marshalErr != nil {
|
||||
d.logger.Error("Failed to marshal PubSub event", zap.Error(marshalErr))
|
||||
continue
|
||||
}
|
||||
}
|
||||
if marshalErr != nil {
|
||||
continue
|
||||
}
|
||||
go d.invokeFunction(match, eventJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateCache removes the cached trigger lookup for a namespace+topic.
|
||||
// Call this when triggers are added or removed.
|
||||
func (d *PubSubDispatcher) InvalidateCache(ctx context.Context, namespace, topic string) {
|
||||
if d.olricClient == nil {
|
||||
return
|
||||
}
|
||||
|
||||
dm, err := d.olricClient.NewDMap(triggerCacheDMap)
|
||||
if err != nil {
|
||||
d.logger.Debug("Failed to get trigger cache DMap for invalidation", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
key := cacheKey(namespace, topic)
|
||||
if _, err := dm.Delete(ctx, key); err != nil {
|
||||
d.logger.Debug("Failed to invalidate trigger cache", zap.String("key", key), zap.Error(err))
|
||||
}
|
||||
// bufferEvent routes an event through the aggregator. The flush callback
|
||||
// invokes the function with the batched payload.
|
||||
func (d *PubSubDispatcher) bufferEvent(match TriggerMatch, event PubSubEvent) {
|
||||
d.aggregator.Buffer(aggregator.BufferRequest{
|
||||
Namespace: match.Namespace,
|
||||
FunctionID: match.FunctionID,
|
||||
TriggerID: match.TriggerID,
|
||||
WindowMs: match.AggregationWindowMs,
|
||||
MaxBatchSize: match.AggregationMaxBatchSize,
|
||||
Event: aggregator.Event{
|
||||
Topic: event.Topic,
|
||||
Data: event.Data,
|
||||
Namespace: event.Namespace,
|
||||
TriggerDepth: event.TriggerDepth,
|
||||
Timestamp: event.Timestamp,
|
||||
},
|
||||
FlushFn: func(ctx context.Context, payload []byte) {
|
||||
req := &serverless.InvokeRequest{
|
||||
Namespace: match.Namespace,
|
||||
FunctionName: match.FunctionName,
|
||||
Input: payload,
|
||||
TriggerType: serverless.TriggerTypePubSub,
|
||||
}
|
||||
if _, err := d.invoker.Invoke(ctx, req); err != nil {
|
||||
d.logger.Warn("Aggregated PubSub invocation failed",
|
||||
zap.String("function", match.FunctionName),
|
||||
zap.String("trigger_id", match.TriggerID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// getMatches returns the trigger matches for a topic+namespace, using Olric cache when available.
|
||||
// InvalidateCache is now a no-op — the dispatcher no longer caches lookups.
|
||||
// Kept on the type so callers who used it still compile.
|
||||
func (d *PubSubDispatcher) InvalidateCache(ctx context.Context, namespace, topic string) {}
|
||||
|
||||
// getMatches returns the trigger matches for a topic+namespace.
|
||||
//
|
||||
// Caching note: an earlier revision cached results in Olric keyed by
|
||||
// (namespace, topic). With wildcard triggers the cache becomes
|
||||
// inconsistent — a single trigger Add/Remove invalidates an unbounded
|
||||
// number of resolved-topic keys. The cache was removed; re-introducing
|
||||
// it requires a generation-counter (or TTL) scheme that handles
|
||||
// wildcard pattern changes.
|
||||
func (d *PubSubDispatcher) getMatches(ctx context.Context, namespace, topic string) ([]TriggerMatch, error) {
|
||||
// Try cache first
|
||||
if d.olricClient != nil {
|
||||
if matches, ok := d.getCached(ctx, namespace, topic); ok {
|
||||
return matches, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Cache miss — query database
|
||||
matches, err := d.store.GetByTopicAndNamespace(ctx, topic, namespace)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Populate cache
|
||||
if d.olricClient != nil && matches != nil {
|
||||
d.setCache(ctx, namespace, topic, matches)
|
||||
}
|
||||
|
||||
return matches, nil
|
||||
return d.store.GetByTopicAndNamespace(ctx, topic, namespace)
|
||||
}
|
||||
|
||||
// getCached attempts to retrieve trigger matches from Olric cache.
|
||||
func (d *PubSubDispatcher) getCached(ctx context.Context, namespace, topic string) ([]TriggerMatch, bool) {
|
||||
dm, err := d.olricClient.NewDMap(triggerCacheDMap)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
key := cacheKey(namespace, topic)
|
||||
result, err := dm.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
data, err := result.Byte()
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var matches []TriggerMatch
|
||||
if err := json.Unmarshal(data, &matches); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return matches, true
|
||||
}
|
||||
|
||||
// setCache stores trigger matches in Olric cache.
|
||||
func (d *PubSubDispatcher) setCache(ctx context.Context, namespace, topic string, matches []TriggerMatch) {
|
||||
dm, err := d.olricClient.NewDMap(triggerCacheDMap)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
data, err := json.Marshal(matches)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := cacheKey(namespace, topic)
|
||||
_ = dm.Put(ctx, key, data)
|
||||
}
|
||||
|
||||
// invokeFunction invokes a single function for a trigger match.
|
||||
func (d *PubSubDispatcher) invokeFunction(match TriggerMatch, eventJSON []byte) {
|
||||
@ -224,7 +212,3 @@ func (d *PubSubDispatcher) invokeFunction(match TriggerMatch, eventJSON []byte)
|
||||
)
|
||||
}
|
||||
|
||||
// cacheKey returns the Olric cache key for a namespace+topic pair.
|
||||
func cacheKey(namespace, topic string) string {
|
||||
return "triggers:" + namespace + ":" + topic
|
||||
}
|
||||
|
||||
138
core/pkg/serverless/triggers/pattern.go
Normal file
138
core/pkg/serverless/triggers/pattern.go
Normal file
@ -0,0 +1,138 @@
|
||||
package triggers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// MaxPatternLength is the maximum allowed glob pattern length.
|
||||
const MaxPatternLength = 256
|
||||
|
||||
// ValidatePattern checks that a glob pattern is well-formed.
|
||||
// Empty patterns and unbalanced character classes return an error.
|
||||
// Patterns longer than MaxPatternLength are rejected to keep DB scans bounded.
|
||||
func ValidatePattern(p string) error {
|
||||
if p == "" {
|
||||
return fmt.Errorf("empty pattern")
|
||||
}
|
||||
if len(p) > MaxPatternLength {
|
||||
return fmt.Errorf("pattern too long: %d > %d", len(p), MaxPatternLength)
|
||||
}
|
||||
open := 0
|
||||
for i, c := range p {
|
||||
switch c {
|
||||
case '[':
|
||||
open++
|
||||
case ']':
|
||||
open--
|
||||
if open < 0 {
|
||||
return fmt.Errorf("unmatched ']' at position %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
if open != 0 {
|
||||
return fmt.Errorf("unmatched '['")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsWildcard reports whether p contains at least one of the glob
// metacharacters '*', '?' or '['. Callers can use it to choose between
// exact-match cache keys and wildcard scans.
func IsWildcard(p string) bool {
	return strings.IndexAny(p, "*?[") >= 0
}
|
||||
|
||||
// PatternMatches returns true when topic matches pattern under Orama's glob
|
||||
// semantics:
|
||||
// - '*' matches zero or more characters EXCEPT ':'
|
||||
// - '**' matches zero or more characters INCLUDING ':' (deep wildcard)
|
||||
// - '?' matches exactly one character (any)
|
||||
// - '[abc]' / '[!abc]' character classes
|
||||
//
|
||||
// SQLite's GLOB is the first-pass filter (in pubsub_store.go); this
|
||||
// post-filter enforces segment boundaries for single-'*' patterns since
|
||||
// SQLite GLOB treats '*' as "any chars including separators".
|
||||
func PatternMatches(pattern, topic string) bool {
|
||||
if strings.Contains(pattern, "**") {
|
||||
// Deep wildcards already accept across segment boundaries — SQLite GLOB
|
||||
// already accepted this row. No further filtering needed.
|
||||
return true
|
||||
}
|
||||
return strictGlobMatch(pattern, topic)
|
||||
}
|
||||
|
||||
// strictGlobMatch implements glob matching where '*' does NOT cross ':'.
|
||||
// Recursive backtracking matcher; bounded length keeps it cheap.
|
||||
func strictGlobMatch(pattern, s string) bool {
|
||||
pi, si := 0, 0
|
||||
starPi, starSi := -1, -1
|
||||
|
||||
for si < len(s) {
|
||||
if pi < len(pattern) {
|
||||
pc := pattern[pi]
|
||||
switch pc {
|
||||
case '?':
|
||||
pi++
|
||||
si++
|
||||
continue
|
||||
case '*':
|
||||
// Remember position so we can backtrack.
|
||||
starPi = pi
|
||||
starSi = si
|
||||
pi++
|
||||
continue
|
||||
case '[':
|
||||
end := strings.IndexByte(pattern[pi+1:], ']')
|
||||
if end < 0 {
|
||||
return false
|
||||
}
|
||||
class := pattern[pi+1 : pi+1+end]
|
||||
if matchClass(class, s[si]) {
|
||||
pi += end + 2
|
||||
si++
|
||||
continue
|
||||
}
|
||||
default:
|
||||
if pc == s[si] {
|
||||
pi++
|
||||
si++
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
// No match at this position — try to extend the last '*' if any,
|
||||
// but '*' must not cross a ':' segment separator.
|
||||
if starPi >= 0 && s[starSi] != ':' {
|
||||
starSi++
|
||||
pi = starPi + 1
|
||||
si = starSi
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Consume any trailing '*' in the pattern.
|
||||
for pi < len(pattern) && pattern[pi] == '*' {
|
||||
pi++
|
||||
}
|
||||
return pi == len(pattern)
|
||||
}
|
||||
|
||||
// matchClass reports whether byte c is accepted by a character-class body,
// i.e. the text between '[' and ']' in a glob pattern.
//
// Supported syntax:
//   - a leading '!' or '^' negates the class ('^' is the negation character
//     SQLite's GLOB uses; '!' is kept for existing patterns);
//   - "x-y" denotes the inclusive byte range x..y, as in SQLite's GLOB;
//   - every other byte stands for itself. A '-' that is the first or last
//     byte of the body is literal.
//
// An empty body matches nothing. A body holding only the negation character
// matches every byte.
//
// Ranges matter because SQLite GLOB is the first-pass filter for triggers.
// Without them a pattern such as "[a-c]x" passes the SQL filter for "bx" but
// is then rejected here, so range patterns can never fire.
func matchClass(class string, c byte) bool {
	if class == "" {
		return false
	}
	negate := false
	if class[0] == '!' || class[0] == '^' {
		negate = true
		class = class[1:]
	}

	matched := false
	for i := 0; i < len(class) && !matched; i++ {
		lo := class[i]
		// "lo-hi" with a real upper bound is a range; consume all three bytes.
		if i+2 < len(class) && class[i+1] == '-' {
			hi := class[i+2]
			matched = lo <= c && c <= hi
			i += 2
			continue
		}
		matched = lo == c
	}
	return matched != negate
}
|
||||
162
core/pkg/serverless/triggers/pattern_test.go
Normal file
162
core/pkg/serverless/triggers/pattern_test.go
Normal file
@ -0,0 +1,162 @@
|
||||
package triggers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidatePattern_empty_returns_error(t *testing.T) {
|
||||
if err := ValidatePattern(""); err == nil {
|
||||
t.Error("expected error for empty pattern")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePattern_too_long_returns_error(t *testing.T) {
|
||||
long := strings.Repeat("a", MaxPatternLength+1)
|
||||
if err := ValidatePattern(long); err == nil {
|
||||
t.Error("expected error for over-long pattern")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePattern_unbalanced_brackets_returns_error(t *testing.T) {
|
||||
cases := []string{"a[b", "a]b", "[a[b]", "a]"}
|
||||
for _, c := range cases {
|
||||
if err := ValidatePattern(c); err == nil {
|
||||
t.Errorf("expected error for %q", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePattern_valid_patterns_no_error(t *testing.T) {
|
||||
cases := []string{"foo", "foo:*", "foo:**", "*.bar", "[abc]xyz", "[!a]b", "?abc"}
|
||||
for _, c := range cases {
|
||||
if err := ValidatePattern(c); err != nil {
|
||||
t.Errorf("expected no error for %q, got: %v", c, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsWildcard(t *testing.T) {
|
||||
cases := map[string]bool{
|
||||
"foo": false,
|
||||
"foo:bar": false,
|
||||
"foo:*": true,
|
||||
"foo?bar": true,
|
||||
"[abc]xyz": true,
|
||||
"foo:**": true,
|
||||
"a:b:c:d:e:f": false,
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := IsWildcard(in); got != want {
|
||||
t.Errorf("IsWildcard(%q) = %v, want %v", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternMatches_exact(t *testing.T) {
|
||||
cases := []struct {
|
||||
pattern, topic string
|
||||
want bool
|
||||
}{
|
||||
{"foo", "foo", true},
|
||||
{"foo", "bar", false},
|
||||
{"foo:bar", "foo:bar", true},
|
||||
{"foo:bar", "foo:baz", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := PatternMatches(c.pattern, c.topic); got != c.want {
|
||||
t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternMatches_single_star_segment_bounded(t *testing.T) {
|
||||
cases := []struct {
|
||||
pattern, topic string
|
||||
want bool
|
||||
}{
|
||||
// '*' matches within a single segment
|
||||
{"presence:*", "presence:user-1", true},
|
||||
{"presence:*", "presence:user-2", true},
|
||||
{"presence:*", "presence:", true},
|
||||
// '*' does NOT cross ':'
|
||||
{"presence:*", "presence:user:device", false},
|
||||
{"a:*:b", "a:x:b", true},
|
||||
{"a:*:b", "a:x:y:b", false},
|
||||
// Different prefix
|
||||
{"presence:*", "calls:invite", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := PatternMatches(c.pattern, c.topic); got != c.want {
|
||||
t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternMatches_double_star_crosses_segments(t *testing.T) {
|
||||
cases := []struct {
|
||||
pattern, topic string
|
||||
want bool
|
||||
}{
|
||||
{"notify:**", "notify:user-1", true},
|
||||
{"notify:**", "notify:user:device:1", true},
|
||||
{"**", "anything:goes:here", true},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := PatternMatches(c.pattern, c.topic); got != c.want {
|
||||
t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternMatches_question_mark(t *testing.T) {
|
||||
cases := []struct {
|
||||
pattern, topic string
|
||||
want bool
|
||||
}{
|
||||
{"a?c", "abc", true},
|
||||
{"a?c", "axc", true},
|
||||
{"a?c", "ac", false},
|
||||
{"a?c", "abbc", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := PatternMatches(c.pattern, c.topic); got != c.want {
|
||||
t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternMatches_character_class(t *testing.T) {
|
||||
cases := []struct {
|
||||
pattern, topic string
|
||||
want bool
|
||||
}{
|
||||
{"[abc]xyz", "axyz", true},
|
||||
{"[abc]xyz", "bxyz", true},
|
||||
{"[abc]xyz", "dxyz", false},
|
||||
{"[!a]bc", "xbc", true},
|
||||
{"[!a]bc", "abc", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := PatternMatches(c.pattern, c.topic); got != c.want {
|
||||
t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternMatches_trailing_star_with_remaining_chars(t *testing.T) {
|
||||
// '*' can match zero characters at end.
|
||||
cases := []struct {
|
||||
pattern, topic string
|
||||
want bool
|
||||
}{
|
||||
{"foo*", "foo", true},
|
||||
{"foo*", "foobar", true},
|
||||
{"foo*", "foobar:baz", false}, // ':' breaks single '*'
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := PatternMatches(c.pattern, c.topic); got != c.want {
|
||||
t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -16,30 +16,43 @@ import (
|
||||
|
||||
// TriggerMatch contains the fields needed to dispatch a trigger invocation.
|
||||
// It's the result of JOINing function_pubsub_triggers with functions.
|
||||
//
|
||||
// Topic is the *resolved* topic that the published message was sent to,
|
||||
// not the pattern stored in the trigger. This lets aggregating functions
|
||||
// see which concrete topic each event came from.
|
||||
//
|
||||
// AggregationWindowMs > 0 indicates the dispatcher should buffer events
|
||||
// instead of invoking the function per event.
|
||||
type TriggerMatch struct {
|
||||
TriggerID string
|
||||
FunctionID string
|
||||
FunctionName string
|
||||
Namespace string
|
||||
Topic string
|
||||
TriggerID string
|
||||
FunctionID string
|
||||
FunctionName string
|
||||
Namespace string
|
||||
Topic string
|
||||
AggregationWindowMs int
|
||||
AggregationMaxBatchSize int
|
||||
}
|
||||
|
||||
// triggerRow maps to the function_pubsub_triggers table for query scanning.
|
||||
type triggerRow struct {
|
||||
ID string
|
||||
FunctionID string
|
||||
Topic string
|
||||
Enabled bool
|
||||
CreatedAt time.Time
|
||||
ID string
|
||||
FunctionID string
|
||||
TopicPattern string
|
||||
Enabled bool
|
||||
CreatedAt time.Time
|
||||
AggregationWindowMs int
|
||||
AggregationMaxBatchSize int
|
||||
}
|
||||
|
||||
// triggerMatchRow maps to the JOIN query result for scanning.
|
||||
type triggerMatchRow struct {
|
||||
TriggerID string
|
||||
FunctionID string
|
||||
FunctionName string
|
||||
Namespace string
|
||||
Topic string
|
||||
TriggerID string
|
||||
FunctionID string
|
||||
FunctionName string
|
||||
Namespace string
|
||||
TopicPattern string
|
||||
AggregationWindowMs int
|
||||
AggregationMaxBatchSize int
|
||||
}
|
||||
|
||||
// PubSubTriggerStore manages PubSub trigger persistence in RQLite.
|
||||
@ -57,30 +70,60 @@ func NewPubSubTriggerStore(db rqlite.Client, logger *zap.Logger) *PubSubTriggerS
|
||||
}
|
||||
|
||||
// Add registers a new PubSub trigger for a function.
|
||||
// `topicPattern` may be an exact topic or a SQLite GLOB pattern (e.g. "presence:*").
|
||||
// Returns the trigger ID.
|
||||
func (s *PubSubTriggerStore) Add(ctx context.Context, functionID, topic string) (string, error) {
|
||||
//
|
||||
// For backward compatibility, aggregation defaults to disabled (windowMs=0).
|
||||
// Use AddWithAggregation to opt in.
|
||||
func (s *PubSubTriggerStore) Add(ctx context.Context, functionID, topicPattern string) (string, error) {
|
||||
return s.AddWithAggregation(ctx, functionID, topicPattern, 0, 0)
|
||||
}
|
||||
|
||||
// AddWithAggregation registers a trigger with optional aggregation.
|
||||
// - aggregationWindowMs = 0 disables aggregation (per-event invocation, default).
|
||||
// - aggregationMaxBatchSize = 0 uses the default (100) when aggregation is enabled.
|
||||
func (s *PubSubTriggerStore) AddWithAggregation(
|
||||
ctx context.Context,
|
||||
functionID, topicPattern string,
|
||||
aggregationWindowMs, aggregationMaxBatchSize int,
|
||||
) (string, error) {
|
||||
if functionID == "" {
|
||||
return "", fmt.Errorf("function ID required")
|
||||
}
|
||||
if topic == "" {
|
||||
return "", fmt.Errorf("topic required")
|
||||
if err := ValidatePattern(topicPattern); err != nil {
|
||||
return "", fmt.Errorf("invalid topic pattern: %w", err)
|
||||
}
|
||||
if aggregationWindowMs < 0 || aggregationWindowMs > 60_000 {
|
||||
return "", fmt.Errorf("aggregation_window_ms must be between 0 and 60000")
|
||||
}
|
||||
if aggregationMaxBatchSize < 0 || aggregationMaxBatchSize > 1000 {
|
||||
return "", fmt.Errorf("aggregation_max_batch_size must be between 0 and 1000")
|
||||
}
|
||||
if aggregationWindowMs > 0 && aggregationMaxBatchSize == 0 {
|
||||
aggregationMaxBatchSize = 100
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
|
||||
// Write both `topic` (legacy) and `topic_pattern` (new). Keeping `topic`
|
||||
// populated lets old binaries running concurrently during a rolling
|
||||
// upgrade continue reading triggers. A future migration drops `topic`.
|
||||
query := `
|
||||
INSERT INTO function_pubsub_triggers (id, function_id, topic, enabled, created_at)
|
||||
VALUES (?, ?, ?, TRUE, ?)
|
||||
INSERT INTO function_pubsub_triggers (id, function_id, topic, topic_pattern, enabled, created_at, aggregation_window_ms, aggregation_max_batch_size)
|
||||
VALUES (?, ?, ?, ?, TRUE, ?, ?, ?)
|
||||
`
|
||||
if _, err := s.db.Exec(ctx, query, id, functionID, topic, now); err != nil {
|
||||
if _, err := s.db.Exec(ctx, query, id, functionID, topicPattern, topicPattern, now, aggregationWindowMs, aggregationMaxBatchSize); err != nil {
|
||||
return "", fmt.Errorf("failed to add pubsub trigger: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("PubSub trigger added",
|
||||
zap.String("trigger_id", id),
|
||||
zap.String("function_id", functionID),
|
||||
zap.String("topic", topic),
|
||||
zap.String("topic_pattern", topicPattern),
|
||||
zap.Bool("wildcard", IsWildcard(topicPattern)),
|
||||
zap.Int("aggregation_window_ms", aggregationWindowMs),
|
||||
zap.Int("aggregation_max_batch_size", aggregationMaxBatchSize),
|
||||
)
|
||||
|
||||
return id, nil
|
||||
@ -129,7 +172,7 @@ func (s *PubSubTriggerStore) ListByFunction(ctx context.Context, functionID stri
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, function_id, topic, enabled, created_at
|
||||
SELECT id, function_id, topic_pattern, enabled, created_at, aggregation_window_ms, aggregation_max_batch_size
|
||||
FROM function_pubsub_triggers
|
||||
WHERE function_id = ?
|
||||
`
|
||||
@ -142,18 +185,22 @@ func (s *PubSubTriggerStore) ListByFunction(ctx context.Context, functionID stri
|
||||
triggers := make([]serverless.PubSubTrigger, len(rows))
|
||||
for i, row := range rows {
|
||||
triggers[i] = serverless.PubSubTrigger{
|
||||
ID: row.ID,
|
||||
FunctionID: row.FunctionID,
|
||||
Topic: row.Topic,
|
||||
Enabled: row.Enabled,
|
||||
ID: row.ID,
|
||||
FunctionID: row.FunctionID,
|
||||
Topic: row.TopicPattern,
|
||||
Enabled: row.Enabled,
|
||||
AggregationWindowMs: row.AggregationWindowMs,
|
||||
AggregationMaxBatchSize: row.AggregationMaxBatchSize,
|
||||
}
|
||||
}
|
||||
|
||||
return triggers, nil
|
||||
}
|
||||
|
||||
// GetByTopicAndNamespace returns all enabled triggers for a topic within a namespace.
|
||||
// Only returns triggers for active functions.
|
||||
// GetByTopicAndNamespace returns all enabled triggers whose topic_pattern
|
||||
// matches `topic` within the namespace. Patterns are SQLite GLOB; the
|
||||
// post-filter enforces stricter segment-aware semantics.
|
||||
// Only triggers for active functions are returned.
|
||||
func (s *PubSubTriggerStore) GetByTopicAndNamespace(ctx context.Context, topic, namespace string) ([]TriggerMatch, error) {
|
||||
if topic == "" || namespace == "" {
|
||||
return nil, nil
|
||||
@ -161,10 +208,12 @@ func (s *PubSubTriggerStore) GetByTopicAndNamespace(ctx context.Context, topic,
|
||||
|
||||
query := `
|
||||
SELECT t.id AS trigger_id, t.function_id AS function_id,
|
||||
f.name AS function_name, f.namespace AS namespace, t.topic AS topic
|
||||
f.name AS function_name, f.namespace AS namespace, t.topic_pattern AS topic_pattern,
|
||||
t.aggregation_window_ms AS aggregation_window_ms,
|
||||
t.aggregation_max_batch_size AS aggregation_max_batch_size
|
||||
FROM function_pubsub_triggers t
|
||||
JOIN functions f ON t.function_id = f.id
|
||||
WHERE t.topic = ? AND f.namespace = ? AND t.enabled = TRUE AND f.status = 'active'
|
||||
WHERE ? GLOB t.topic_pattern AND f.namespace = ? AND t.enabled = TRUE AND f.status = 'active'
|
||||
`
|
||||
|
||||
var rows []triggerMatchRow
|
||||
@ -172,15 +221,21 @@ func (s *PubSubTriggerStore) GetByTopicAndNamespace(ctx context.Context, topic,
|
||||
return nil, fmt.Errorf("failed to query triggers for topic: %w", err)
|
||||
}
|
||||
|
||||
matches := make([]TriggerMatch, len(rows))
|
||||
for i, row := range rows {
|
||||
matches[i] = TriggerMatch{
|
||||
TriggerID: row.TriggerID,
|
||||
FunctionID: row.FunctionID,
|
||||
FunctionName: row.FunctionName,
|
||||
Namespace: row.Namespace,
|
||||
Topic: row.Topic,
|
||||
matches := make([]TriggerMatch, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
// Post-filter to enforce strict segment boundaries on '*'.
|
||||
if !PatternMatches(row.TopicPattern, topic) {
|
||||
continue
|
||||
}
|
||||
matches = append(matches, TriggerMatch{
|
||||
TriggerID: row.TriggerID,
|
||||
FunctionID: row.FunctionID,
|
||||
FunctionName: row.FunctionName,
|
||||
Namespace: row.Namespace,
|
||||
Topic: topic, // resolved topic, not the pattern
|
||||
AggregationWindowMs: row.AggregationWindowMs,
|
||||
AggregationMaxBatchSize: row.AggregationMaxBatchSize,
|
||||
})
|
||||
}
|
||||
|
||||
return matches, nil
|
||||
|
||||
@ -40,13 +40,6 @@ func TestDispatcher_DepthLimit(t *testing.T) {
|
||||
d.Dispatch(context.Background(), "ns", "topic", []byte("data"), maxTriggerDepth+1)
|
||||
}
|
||||
|
||||
func TestCacheKey(t *testing.T) {
|
||||
key := cacheKey("my-namespace", "my-topic")
|
||||
if key != "triggers:my-namespace:my-topic" {
|
||||
t.Errorf("unexpected cache key: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPubSubEvent_Marshal(t *testing.T) {
|
||||
event := PubSubEvent{
|
||||
Topic: "chat",
|
||||
|
||||
@ -288,11 +288,21 @@ type DBTrigger struct {
|
||||
}
|
||||
|
||||
// PubSubTrigger represents a pubsub trigger.
|
||||
//
|
||||
// Topic may be an exact topic name or a SQLite GLOB pattern (e.g.
|
||||
// "presence:*"). See pkg/serverless/triggers/pattern.go for matching rules.
|
||||
//
|
||||
// AggregationWindowMs > 0 enables event buffering: the dispatcher accumulates
|
||||
// events for at most that many milliseconds (or until AggregationMaxBatchSize
|
||||
// events have been collected, whichever comes first), then invokes the
|
||||
// function once with a batched payload of type BatchedPubSubEvent.
|
||||
type PubSubTrigger struct {
|
||||
ID string `json:"id"`
|
||||
FunctionID string `json:"function_id"`
|
||||
Topic string `json:"topic"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ID string `json:"id"`
|
||||
FunctionID string `json:"function_id"`
|
||||
Topic string `json:"topic"`
|
||||
Enabled bool `json:"enabled"`
|
||||
AggregationWindowMs int `json:"aggregation_window_ms,omitempty"`
|
||||
AggregationMaxBatchSize int `json:"aggregation_max_batch_size,omitempty"`
|
||||
}
|
||||
|
||||
// Timer represents a one-time scheduled execution.
|
||||
@ -337,6 +347,14 @@ type HostServices interface {
|
||||
|
||||
// PubSub operations
|
||||
PubSubPublish(ctx context.Context, topic string, data []byte) error
|
||||
PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error
|
||||
|
||||
// Push notifications. Sends to all of `userID`'s registered devices in
|
||||
// the function's namespace. `msgJSON` is the JSON-encoded PushSendArgs
|
||||
// shape (see hostfunctions.PushSend). Returns nil if push is not
|
||||
// configured (silent no-op) so functions can be portable across
|
||||
// namespaces with/without push enabled.
|
||||
PushSend(ctx context.Context, userID string, msgJSON []byte) error
|
||||
|
||||
// WebSocket operations (only valid in WS context)
|
||||
WSSend(ctx context.Context, clientID string, data []byte) error
|
||||
|
||||
106
core/pkg/sniproxy/router.go
Normal file
106
core/pkg/sniproxy/router.go
Normal file
@ -0,0 +1,106 @@
|
||||
package sniproxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Backend identifies the upstream that a proxied connection is forwarded to.
type Backend struct {
	// Name labels the backend in logs and metrics. It may be left empty.
	Name string
	// Network is the network handed to the dialer ("tcp", "tcp4" or
	// "tcp6"). Defaults to "tcp".
	Network string
	// Addr is the address to dial, for example "127.0.0.1:5349".
	Addr string
}
|
||||
|
||||
// Route maps an SNI value (or wildcard pattern) to a Backend.
|
||||
//
|
||||
// Match semantics:
|
||||
// - "example.com" matches exactly "example.com"
|
||||
// - "*.example.com" matches any single-label subdomain ("a.example.com"
|
||||
// but not "a.b.example.com" — single-label like DNS wildcards)
|
||||
type Route struct {
|
||||
Match string
|
||||
Backend Backend
|
||||
}
|
||||
|
||||
// Router atomically swaps a routing table while concurrent reads are in
|
||||
// flight. Reads are lock-free after the slice is published.
|
||||
type Router struct {
|
||||
mu sync.RWMutex
|
||||
routes []Route
|
||||
fallback Backend
|
||||
}
|
||||
|
||||
// NewRouter creates a router with no routes and the given fallback.
|
||||
func NewRouter(fallback Backend) *Router {
|
||||
return &Router{fallback: fallback}
|
||||
}
|
||||
|
||||
// Pick returns the matching backend for an SNI value, or the fallback if
|
||||
// no route matches (or if sni is empty).
|
||||
func (r *Router) Pick(sni string) Backend {
|
||||
if sni == "" {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.fallback
|
||||
}
|
||||
sni = strings.ToLower(sni)
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
for _, route := range r.routes {
|
||||
if matchSNI(route.Match, sni) {
|
||||
return route.Backend
|
||||
}
|
||||
}
|
||||
return r.fallback
|
||||
}
|
||||
|
||||
// Replace atomically swaps the routing table. The new routes replace the
|
||||
// old ones in their entirety; partial updates are not supported.
|
||||
func (r *Router) Replace(routes []Route, fallback Backend) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.routes = routes
|
||||
r.fallback = fallback
|
||||
}
|
||||
|
||||
// Routes returns a defensive copy of the current routes. For introspection.
|
||||
func (r *Router) Routes() []Route {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
out := make([]Route, len(r.routes))
|
||||
copy(out, r.routes)
|
||||
return out
|
||||
}
|
||||
|
||||
// Fallback returns the current fallback backend.
|
||||
func (r *Router) Fallback() Backend {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.fallback
|
||||
}
|
||||
|
||||
// matchSNI reports whether sni satisfies pattern under the rules documented
// on Route. Only pattern is lowercased here; the caller passes sni already
// lowercased.
func matchSNI(pattern, sni string) bool {
	want := strings.ToLower(pattern)
	switch {
	case want == sni:
		return true
	case !strings.HasPrefix(want, "*."):
		return false
	}

	// Wildcard form: cut the SNI at its first dot. It matches when the part
	// before the dot is a non-empty label and the part after it equals the
	// pattern without its leading "*.". Cutting at the first dot is what
	// limits the wildcard to a single label.
	dot := strings.IndexByte(sni, '.')
	return dot > 0 && sni[dot+1:] == want[2:]
}
|
||||
113
core/pkg/sniproxy/router_test.go
Normal file
113
core/pkg/sniproxy/router_test.go
Normal file
@ -0,0 +1,113 @@
|
||||
package sniproxy
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRouter_pick_exact_match(t *testing.T) {
|
||||
fb := Backend{Name: "fallback", Addr: "127.0.0.1:9000"}
|
||||
r := NewRouter(fb)
|
||||
r.Replace([]Route{
|
||||
{Match: "turn.example.com", Backend: Backend{Name: "turn", Addr: "127.0.0.1:5349"}},
|
||||
}, fb)
|
||||
|
||||
got := r.Pick("turn.example.com")
|
||||
if got.Addr != "127.0.0.1:5349" {
|
||||
t.Errorf("expected turn backend, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_pick_unmatched_returns_fallback(t *testing.T) {
|
||||
fb := Backend{Name: "caddy", Addr: "127.0.0.1:8443"}
|
||||
r := NewRouter(fb)
|
||||
r.Replace([]Route{
|
||||
{Match: "turn.example.com", Backend: Backend{Addr: "127.0.0.1:5349"}},
|
||||
}, fb)
|
||||
|
||||
if got := r.Pick("api.example.com"); got != fb {
|
||||
t.Errorf("expected fallback, got %+v", got)
|
||||
}
|
||||
if got := r.Pick(""); got != fb {
|
||||
t.Errorf("expected fallback for empty SNI, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_pick_case_insensitive(t *testing.T) {
|
||||
fb := Backend{Addr: "127.0.0.1:8443"}
|
||||
r := NewRouter(fb)
|
||||
r.Replace([]Route{
|
||||
{Match: "Turn.Example.Com", Backend: Backend{Addr: "127.0.0.1:5349"}},
|
||||
}, fb)
|
||||
|
||||
if got := r.Pick("turn.example.com"); got.Addr != "127.0.0.1:5349" {
|
||||
t.Errorf("expected case-insensitive match, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_pick_wildcard_subdomain(t *testing.T) {
|
||||
fb := Backend{Addr: "127.0.0.1:8443"}
|
||||
r := NewRouter(fb)
|
||||
r.Replace([]Route{
|
||||
{Match: "*.example.com", Backend: Backend{Name: "wild", Addr: "127.0.0.1:5349"}},
|
||||
}, fb)
|
||||
|
||||
cases := map[string]bool{
|
||||
"a.example.com": true,
|
||||
"foo.example.com": true,
|
||||
"a.b.example.com": false, // multi-label not allowed
|
||||
"example.com": false, // bare domain doesn't match *.example.com
|
||||
"other.com": false,
|
||||
}
|
||||
for sni, want := range cases {
|
||||
got := r.Pick(sni) == Backend{Name: "wild", Addr: "127.0.0.1:5349"}
|
||||
if got != want {
|
||||
t.Errorf("Pick(%q): want match=%v, got match=%v", sni, want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_replace_atomic(t *testing.T) {
|
||||
// Many concurrent reads against many concurrent Replace calls — should
|
||||
// never observe partial state. Run with -race.
|
||||
fb := Backend{Addr: "fb"}
|
||||
r := NewRouter(fb)
|
||||
r.Replace([]Route{{Match: "a.com", Backend: Backend{Addr: "1"}}}, fb)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stop := make(chan struct{})
|
||||
|
||||
// Readers
|
||||
for i := 0; i < 4; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
_ = r.Pick("a.com")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Writers
|
||||
for i := 0; i < 200; i++ {
|
||||
r.Replace([]Route{{Match: "a.com", Backend: Backend{Addr: "x"}}}, fb)
|
||||
}
|
||||
close(stop)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestRouter_routes_returns_copy(t *testing.T) {
|
||||
r := NewRouter(Backend{})
|
||||
original := []Route{{Match: "a", Backend: Backend{Addr: "1"}}}
|
||||
r.Replace(original, Backend{})
|
||||
got := r.Routes()
|
||||
got[0].Match = "mutated"
|
||||
if r.Routes()[0].Match != "a" {
|
||||
t.Error("Routes() should return a defensive copy")
|
||||
}
|
||||
}
|
||||
177
core/pkg/sniproxy/server.go
Normal file
177
core/pkg/sniproxy/server.go
Normal file
@ -0,0 +1,177 @@
|
||||
package sniproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Config holds the tunables of a Server. Every field has a default that
// NewServer applies when the value is zero or negative.
type Config struct {
	// ClientHelloTimeout limits how long the proxy waits for a parseable
	// ClientHello. Default: 5 seconds.
	ClientHelloTimeout time.Duration
	// BackendDialTimeout limits how long connecting to a backend may take.
	// Default: 5 seconds.
	BackendDialTimeout time.Duration
	// MaxConcurrentConns is the ceiling on simultaneously handled
	// connections, guarding against resource exhaustion. Default: 10000.
	MaxConcurrentConns int
}
|
||||
|
||||
// Server is a TCP-level SNI router. Create via NewServer, then call
|
||||
// Serve(listener) in a goroutine. Close cancels in-flight connections.
|
||||
type Server struct {
|
||||
router *Router
|
||||
cfg Config
|
||||
logger *zap.Logger
|
||||
|
||||
gate chan struct{} // bounded semaphore for concurrent connections
|
||||
wg sync.WaitGroup
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewServer constructs a Server with the given router and config.
|
||||
func NewServer(router *Router, cfg Config, logger *zap.Logger) *Server {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
if cfg.ClientHelloTimeout <= 0 {
|
||||
cfg.ClientHelloTimeout = 5 * time.Second
|
||||
}
|
||||
if cfg.BackendDialTimeout <= 0 {
|
||||
cfg.BackendDialTimeout = 5 * time.Second
|
||||
}
|
||||
if cfg.MaxConcurrentConns <= 0 {
|
||||
cfg.MaxConcurrentConns = 10000
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Server{
|
||||
router: router,
|
||||
cfg: cfg,
|
||||
logger: logger.Named("sniproxy"),
|
||||
gate: make(chan struct{}, cfg.MaxConcurrentConns),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Serve accepts connections from ln until ln.Accept returns a permanent
|
||||
// error or Close is called. Serve always returns a non-nil error.
|
||||
func (s *Server) Serve(ln net.Listener) error {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
// Check for shutdown via cancelled ctx.
|
||||
if s.ctx.Err() != nil {
|
||||
return s.ctx.Err()
|
||||
}
|
||||
// Net errors temporarily? Backoff briefly so we don't busy-loop.
|
||||
var ne net.Error
|
||||
if errors.As(err, &ne) && ne.Timeout() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case s.gate <- struct{}{}:
|
||||
default:
|
||||
s.logger.Warn("max concurrent connections reached, dropping",
|
||||
zap.Int("limit", s.cfg.MaxConcurrentConns),
|
||||
zap.String("remote", conn.RemoteAddr().String()),
|
||||
)
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func(c net.Conn) {
|
||||
defer s.wg.Done()
|
||||
defer func() { <-s.gate }()
|
||||
s.handle(c)
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// Close cancels in-flight connections and waits for handlers to drain.
|
||||
func (s *Server) Close() {
|
||||
s.cancel()
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
// handle processes a single accepted connection: peek SNI, dial backend,
|
||||
// replay peeked bytes, then bidirectional copy.
|
||||
func (s *Server) handle(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
sni, peeked, err := PeekClientHello(conn, s.cfg.ClientHelloTimeout)
|
||||
if err != nil {
|
||||
s.logger.Debug("ClientHello peek failed",
|
||||
zap.String("remote", conn.RemoteAddr().String()),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
backend := s.router.Pick(sni)
|
||||
if backend.Addr == "" {
|
||||
s.logger.Warn("no backend for SNI",
|
||||
zap.String("sni", sni),
|
||||
zap.String("remote", conn.RemoteAddr().String()),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
network := backend.Network
|
||||
if network == "" {
|
||||
network = "tcp"
|
||||
}
|
||||
|
||||
upstream, err := net.DialTimeout(network, backend.Addr, s.cfg.BackendDialTimeout)
|
||||
if err != nil {
|
||||
s.logger.Warn("backend dial failed",
|
||||
zap.String("sni", sni),
|
||||
zap.String("backend", backend.Addr),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
defer upstream.Close()
|
||||
|
||||
// Replay peeked bytes (the ClientHello + anything else buffered).
|
||||
if len(peeked) > 0 {
|
||||
if _, err := upstream.Write(peeked); err != nil {
|
||||
s.logger.Debug("replay to backend failed",
|
||||
zap.String("sni", sni),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Bidirectional copy. We close both connections when either side
|
||||
// finishes OR when the server is shutting down, so handle() can't
|
||||
// hang forever on a half-stuck peer.
|
||||
done := make(chan struct{}, 2)
|
||||
go func() {
|
||||
_, _ = io.Copy(upstream, conn)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
go func() {
|
||||
_, _ = io.Copy(conn, upstream)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-s.ctx.Done():
|
||||
}
|
||||
// Force both sides closed; second copy will exit immediately.
|
||||
upstream.Close()
|
||||
conn.Close()
|
||||
<-done // drain the second goroutine
|
||||
}
|
||||
143
core/pkg/sniproxy/server_test.go
Normal file
143
core/pkg/sniproxy/server_test.go
Normal file
@ -0,0 +1,143 @@
|
||||
package sniproxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// startEchoBackend starts a TCP server on a free loopback port. For each
// connection it performs a single read of up to 1024 bytes (with a
// 2-second deadline), publishes a copy of what it read on the returned
// channel, and closes the connection. Despite the name it writes nothing
// back. The accept loop ends when the returned listener is closed.
func startEchoBackend(t *testing.T) (net.Listener, <-chan []byte) {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}

	seen := make(chan []byte, 4)

	// capture records the first chunk a client sends.
	capture := func(c net.Conn) {
		defer c.Close()
		_ = c.SetReadDeadline(time.Now().Add(2 * time.Second))
		chunk := make([]byte, 1024)
		n, _ := c.Read(chunk)
		seen <- append([]byte(nil), chunk[:n]...)
	}

	go func() {
		for {
			c, err := listener.Accept()
			if err != nil {
				return
			}
			go capture(c)
		}
	}()

	return listener, seen
}
|
||||
|
||||
func TestServer_routes_TLS_to_correct_backend(t *testing.T) {
|
||||
turnLn, turnGot := startEchoBackend(t)
|
||||
defer turnLn.Close()
|
||||
caddyLn, caddyGot := startEchoBackend(t)
|
||||
defer caddyLn.Close()
|
||||
|
||||
router := NewRouter(Backend{Network: "tcp", Addr: caddyLn.Addr().String()})
|
||||
router.Replace([]Route{
|
||||
{Match: "turn.example.com", Backend: Backend{Network: "tcp", Addr: turnLn.Addr().String()}},
|
||||
}, router.Fallback())
|
||||
|
||||
srv := NewServer(router, Config{}, zap.NewNop())
|
||||
defer srv.Close()
|
||||
|
||||
frontLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer frontLn.Close()
|
||||
|
||||
go func() { _ = srv.Serve(frontLn) }()
|
||||
|
||||
// Client A: SNI=turn.example.com -> goes to turnLn
|
||||
dialAndStartTLS(t, frontLn.Addr().String(), "turn.example.com")
|
||||
|
||||
// Client B: SNI=other.example.com -> falls through to caddyLn
|
||||
dialAndStartTLS(t, frontLn.Addr().String(), "other.example.com")
|
||||
|
||||
select {
|
||||
case b := <-turnGot:
|
||||
if len(b) == 0 {
|
||||
t.Error("turn backend received empty bytes")
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Error("turn backend did not receive bytes")
|
||||
}
|
||||
|
||||
select {
|
||||
case b := <-caddyGot:
|
||||
if len(b) == 0 {
|
||||
t.Error("caddy fallback received empty bytes")
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Error("caddy fallback did not receive bytes")
|
||||
}
|
||||
}
|
||||
|
||||
// dialAndStartTLS connects to addr and starts a TLS handshake carrying the
// given SNI on a background goroutine, then returns immediately. The
// handshake only exists to put a ClientHello on the wire; it is expected to
// fail because the test backends do not speak TLS.
func dialAndStartTLS(t *testing.T, addr, sni string) {
	t.Helper()

	raw, err := net.Dial("tcp", addr)
	if err != nil {
		t.Fatal(err)
	}

	clientCfg := &tls.Config{ServerName: sni, InsecureSkipVerify: true}
	go func() {
		defer raw.Close()
		_ = raw.SetDeadline(time.Now().Add(2 * time.Second))
		_ = tls.Client(raw, clientCfg).Handshake() // failure is the expected outcome
	}()
}
|
||||
|
||||
func TestServer_no_backend_drops_connection(t *testing.T) {
|
||||
router := NewRouter(Backend{}) // empty fallback, empty Addr -> dropped
|
||||
srv := NewServer(router, Config{}, zap.NewNop())
|
||||
defer srv.Close()
|
||||
|
||||
frontLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer frontLn.Close()
|
||||
|
||||
go func() { _ = srv.Serve(frontLn) }()
|
||||
|
||||
conn, err := net.Dial("tcp", frontLn.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
c := tls.Client(conn, &tls.Config{ServerName: "x.example.com", InsecureSkipVerify: true})
|
||||
// Handshake should fail because connection is closed by proxy.
|
||||
go func() { _ = c.Handshake() }()
|
||||
|
||||
// Reader should see EOF quickly.
|
||||
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
br := bufio.NewReader(conn)
|
||||
_, err = br.ReadByte()
|
||||
if err == nil {
|
||||
t.Error("expected connection drop")
|
||||
}
|
||||
if !errors.Is(err, io.EOF) {
|
||||
// "use of closed network connection" is also fine.
|
||||
t.Logf("acceptable read error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
235
core/pkg/sniproxy/sni.go
Normal file
235
core/pkg/sniproxy/sni.go
Normal file
@ -0,0 +1,235 @@
|
||||
// Package sniproxy provides a TCP-level Server Name Indication (SNI) router.
|
||||
//
|
||||
// The router peeks at the unencrypted TLS ClientHello on each accepted
|
||||
// connection, extracts the SNI host name, and forwards the raw stream to
|
||||
// a backend. It does NOT terminate TLS — encrypted bytes pass through
|
||||
// verbatim. This lets one TCP port serve multiple TLS-speaking backends
|
||||
// (HTTPS for the gateway, TURNS for stealth WebRTC, etc.) without
|
||||
// sharing private keys with the proxy.
|
||||
//
|
||||
// Design goals:
|
||||
// - Zero TLS material on the proxy
|
||||
// - Bounded ClientHello read (no slowloris)
|
||||
// - Backend dial timeout
|
||||
// - Per-IP rate limiting
|
||||
//
|
||||
// SNI parsing follows RFC 5246 §7.4.1.2 (TLS record + ClientHello) and
|
||||
// RFC 6066 §3 (server_name extension).
|
||||
package sniproxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrNoSNI reports that a ClientHello was parsed but contained no
// server_name extension with a host_name entry.
var ErrNoSNI = errors.New("sniproxy: ClientHello has no SNI")
|
||||
|
||||
// MaxClientHelloBytes is the upper bound (16 KiB) on how much is read from
// a connection while looking for the SNI. Real ClientHello records are
// typically a few hundred bytes, so the cap is generous while still
// limiting the memory a single connection can claim.
const MaxClientHelloBytes = 16 << 10
|
||||
|
||||
// PeekClientHello reads bytes from conn until a TLS ClientHello has
|
||||
// been parsed (or MaxClientHelloBytes is exceeded). Returns the SNI
|
||||
// hostname (lowercased), the bytes consumed (must be replayed to the
|
||||
// backend), and any error.
|
||||
//
|
||||
// readTimeout bounds the wait — slowloris-style stalls return an error
|
||||
// quickly without holding the goroutine indefinitely.
|
||||
func PeekClientHello(conn net.Conn, readTimeout time.Duration) (string, []byte, error) {
|
||||
if readTimeout > 0 {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||
defer conn.SetReadDeadline(time.Time{})
|
||||
}
|
||||
|
||||
br := bufio.NewReaderSize(conn, MaxClientHelloBytes)
|
||||
|
||||
// Peek the TLS record header (5 bytes): content_type, version (2),
|
||||
// length (2). content_type for ClientHello is 22 (handshake).
|
||||
header, err := br.Peek(5)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("read tls record header: %w", err)
|
||||
}
|
||||
if header[0] != 22 {
|
||||
return "", nil, fmt.Errorf("not a TLS handshake record (type=%d)", header[0])
|
||||
}
|
||||
recLen := int(binary.BigEndian.Uint16(header[3:5]))
|
||||
if recLen <= 0 || 5+recLen > MaxClientHelloBytes {
|
||||
return "", nil, fmt.Errorf("invalid record length %d", recLen)
|
||||
}
|
||||
|
||||
full, err := br.Peek(5 + recLen)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("read tls record body: %w", err)
|
||||
}
|
||||
|
||||
sni, err := parseSNI(full[5:])
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// We've only peeked — drain the buffer to capture the bytes for replay.
|
||||
consumed := make([]byte, br.Buffered())
|
||||
if _, err := io.ReadFull(br, consumed); err != nil {
|
||||
return "", nil, fmt.Errorf("drain peeked bytes: %w", err)
|
||||
}
|
||||
|
||||
return sni, consumed, nil
|
||||
}
|
||||
|
||||
// parseSNI parses a TLS ClientHello body (without the 5-byte record
|
||||
// header) and returns the server_name extension value if present.
|
||||
func parseSNI(body []byte) (string, error) {
|
||||
r := newReader(body)
|
||||
|
||||
// Handshake type (1 byte) — must be 1 (ClientHello).
|
||||
hsType, err := r.readByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if hsType != 1 {
|
||||
return "", fmt.Errorf("not a ClientHello (handshake type %d)", hsType)
|
||||
}
|
||||
// Handshake length (3 bytes).
|
||||
if _, err := r.readBytes(3); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// client_version (2) + random (32).
|
||||
if _, err := r.readBytes(2 + 32); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// session_id.
|
||||
sidLen, err := r.readByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := r.readBytes(int(sidLen)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// cipher_suites length (2).
|
||||
csLen, err := r.readUint16()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := r.readBytes(int(csLen)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// compression_methods length (1).
|
||||
cmLen, err := r.readByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := r.readBytes(int(cmLen)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Extensions length (2). Optional — TLS 1.0 ClientHello can skip it.
|
||||
if r.remaining() < 2 {
|
||||
return "", ErrNoSNI
|
||||
}
|
||||
extTotalLen, err := r.readUint16()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if int(extTotalLen) > r.remaining() {
|
||||
return "", fmt.Errorf("extensions truncated")
|
||||
}
|
||||
|
||||
end := r.pos + int(extTotalLen)
|
||||
for r.pos < end {
|
||||
extType, err := r.readUint16()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
extLen, err := r.readUint16()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
extData, err := r.readBytes(int(extLen))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// server_name extension is type 0.
|
||||
if extType != 0 {
|
||||
continue
|
||||
}
|
||||
return parseServerName(extData)
|
||||
}
|
||||
return "", ErrNoSNI
|
||||
}
|
||||
|
||||
// parseServerName parses the body of a server_name extension and returns
|
||||
// the first host_name (type 0) entry.
|
||||
func parseServerName(data []byte) (string, error) {
|
||||
r := newReader(data)
|
||||
// server_name_list length (2).
|
||||
listLen, err := r.readUint16()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if int(listLen) > r.remaining() {
|
||||
return "", fmt.Errorf("server_name list truncated")
|
||||
}
|
||||
end := r.pos + int(listLen)
|
||||
for r.pos < end {
|
||||
nameType, err := r.readByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
nameLen, err := r.readUint16()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
nameBytes, err := r.readBytes(int(nameLen))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if nameType == 0 { // host_name
|
||||
return strings.ToLower(string(nameBytes)), nil
|
||||
}
|
||||
}
|
||||
return "", ErrNoSNI
|
||||
}
|
||||
|
||||
// reader is a tiny forward-only cursor over a byte slice, used by parseSNI
// and parseServerName. Every read either succeeds and advances pos, or
// fails with io.ErrUnexpectedEOF and leaves pos untouched.
type reader struct {
	buf []byte
	pos int
}

// newReader returns a reader positioned at the start of buf.
func newReader(buf []byte) *reader { return &reader{buf: buf} }

// remaining reports how many bytes are left after the cursor.
func (r *reader) remaining() int { return len(r.buf) - r.pos }

// readByte returns the next byte and advances the cursor by one.
func (r *reader) readByte() (byte, error) {
	if r.remaining() < 1 {
		return 0, io.ErrUnexpectedEOF
	}
	b := r.buf[r.pos]
	r.pos++
	return b, nil
}

// readUint16 returns the next two bytes decoded as a big-endian integer and
// advances the cursor by two.
func (r *reader) readUint16() (uint16, error) {
	if r.remaining() < 2 {
		return 0, io.ErrUnexpectedEOF
	}
	v := binary.BigEndian.Uint16(r.buf[r.pos : r.pos+2])
	r.pos += 2
	return v, nil
}

// readBytes returns the next n bytes and advances the cursor by n. The
// result aliases the underlying buffer; it is not a copy.
//
// A negative n is rejected with io.ErrUnexpectedEOF instead of panicking on
// the slice expression. The bound is checked against remaining() rather
// than by computing pos+n, so a very large n cannot overflow the check.
func (r *reader) readBytes(n int) ([]byte, error) {
	if n < 0 || n > r.remaining() {
		return nil, io.ErrUnexpectedEOF
	}
	b := r.buf[r.pos : r.pos+n]
	r.pos += n
	return b, nil
}
|
||||
173
core/pkg/sniproxy/sni_test.go
Normal file
173
core/pkg/sniproxy/sni_test.go
Normal file
@ -0,0 +1,173 @@
|
||||
package sniproxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// dialAndPeek dials a TLS handshake to the given listener and returns
|
||||
// what PeekClientHello on the server side parsed.
|
||||
func dialAndPeek(t *testing.T, ln net.Listener, sni string) (string, []byte, error) {
|
||||
t.Helper()
|
||||
|
||||
type result struct {
|
||||
sni string
|
||||
peeked []byte
|
||||
err error
|
||||
}
|
||||
resCh := make(chan result, 1)
|
||||
|
||||
// Server side: accept once, peek SNI.
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
resCh <- result{err: err}
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
s, p, err := PeekClientHello(conn, 2*time.Second)
|
||||
resCh <- result{sni: s, peeked: p, err: err}
|
||||
}()
|
||||
|
||||
// Client side: kick off a TLS handshake. We don't care if it
|
||||
// completes (no server cert) — we only need ClientHello to be sent.
|
||||
// Use a goroutine so the test doesn't deadlock waiting on Handshake.
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
c := tls.Client(conn, &tls.Config{
|
||||
ServerName: sni,
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
_ = c.Handshake() // expected to fail; we only needed the ClientHello
|
||||
}()
|
||||
|
||||
select {
|
||||
case r := <-resCh:
|
||||
return r.sni, r.peeked, r.err
|
||||
case <-time.After(5 * time.Second):
|
||||
return "", nil, errors.New("test timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeekClientHello_returns_sni(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
sni, peeked, err := dialAndPeek(t, ln, "example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("PeekClientHello: %v", err)
|
||||
}
|
||||
if sni != "example.com" {
|
||||
t.Errorf("expected sni=example.com, got %q", sni)
|
||||
}
|
||||
if len(peeked) == 0 {
|
||||
t.Error("expected non-empty peeked bytes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeekClientHello_lowercases_sni(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
sni, _, err := dialAndPeek(t, ln, "Example.COM")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sni != "example.com" {
|
||||
t.Errorf("expected lowercase, got %q", sni)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeekClientHello_non_tls_returns_error(t *testing.T) {
|
||||
a, b := net.Pipe()
|
||||
defer a.Close()
|
||||
defer b.Close()
|
||||
|
||||
go func() {
|
||||
// Send something that isn't a TLS handshake record.
|
||||
_, _ = a.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
|
||||
_ = a.Close()
|
||||
}()
|
||||
|
||||
_, _, err := PeekClientHello(b, 1*time.Second)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-TLS bytes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeekClientHello_short_record_returns_error(t *testing.T) {
|
||||
a, b := net.Pipe()
|
||||
defer a.Close()
|
||||
defer b.Close()
|
||||
|
||||
go func() {
|
||||
// One byte, then close — too short for record header.
|
||||
_, _ = a.Write([]byte{22})
|
||||
_ = a.Close()
|
||||
}()
|
||||
|
||||
_, _, err := PeekClientHello(b, 1*time.Second)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for short record")
|
||||
}
|
||||
// EOF or read error is acceptable.
|
||||
if !errors.Is(err, io.EOF) && err.Error() == "" {
|
||||
t.Logf("error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeekClientHello_concurrent_safe(t *testing.T) {
|
||||
// Verify no shared state leaks between PeekClientHello calls.
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 4; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
c := tls.Client(conn, &tls.Config{ServerName: "x.example.com", InsecureSkipVerify: true})
|
||||
_ = c.Handshake()
|
||||
}()
|
||||
}
|
||||
for i := 0; i < 4; i++ {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sni, _, err := PeekClientHello(conn, 2*time.Second)
|
||||
conn.Close()
|
||||
if err != nil {
|
||||
t.Errorf("peek %d: %v", i, err)
|
||||
}
|
||||
if sni != "x.example.com" {
|
||||
t.Errorf("peek %d: got %q", i, sni)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
BIN
core/sni-router
Executable file
BIN
core/sni-router
Executable file
Binary file not shown.
38
core/systemd/orama-sni-router.service
Normal file
38
core/systemd/orama-sni-router.service
Normal file
@ -0,0 +1,38 @@
|
||||
[Unit]
|
||||
Description=Orama SNI Router (TLS-level :443 → backend forwarder)
|
||||
Documentation=https://github.com/DeBrosOfficial/network
|
||||
After=network.target
|
||||
Before=caddy.service
|
||||
PartOf=orama-node.service
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
WorkingDirectory=/opt/orama
|
||||
EnvironmentFile=-/opt/orama/.orama/data/sni-router.env
|
||||
ExecStart=/opt/orama/bin/orama-sni-router --config sni-router.yaml
|
||||
|
||||
# Bind privileged ports (:80, :443) without running as root.
|
||||
AmbientCapabilities=CAP_NET_BIND_SERVICE
|
||||
CapabilityBoundingSet=CAP_NET_BIND_SERVICE
|
||||
|
||||
User=orama
|
||||
Group=orama
|
||||
NoNewPrivileges=yes
|
||||
ProtectSystem=strict
|
||||
ProtectHome=yes
|
||||
PrivateTmp=yes
|
||||
LimitNOFILE=65536
|
||||
|
||||
TimeoutStopSec=15s
|
||||
KillMode=mixed
|
||||
KillSignal=SIGTERM
|
||||
|
||||
Restart=on-failure
|
||||
RestartSec=5s
|
||||
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
SyslogIdentifier=orama-sni-router
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
Loading…
x
Reference in New Issue
Block a user