anonpenguin23 3b8139802c feat: APNs silent-drop guard + persistent-WS mid-session JWT refresh
#348 - APNs silent-drop guard
Apple's APNs silently returns HTTP 200 for pushes with no visible
content (no title, no body, no badge, no sound, no
content-available=1) and then drops them — which looked to the WASM
caller like a successful delivery. Now rejected up-front with the new
push.ErrEmptyContent sentinel, and the APNs provider returns the
structured push.PushError shape (HTTPStatus, Reason, Unregistered,
Wrapped) so the dispatcher can branch on Unregistered to remove dead
tokens automatically. Legacy ErrDeviceUnregistered sentinel is
preserved for errors.Is compatibility (wrapped inside PushError).

Always logs APNs HTTP response (status, reason, apns_id, token prefix)
so future silent-drop classes show up in operator logs.

content-available is also now correctly mapped from snake_case
Data["content_available"] (any truthy variant) into Apple's
canonical "content-available": 1 inside the aps dictionary.

#321 - mid-session JWT refresh on persistent WS
Long-lived persistent WS connections used to have to close+reconnect
when the JWT rolled — losing per-instance state, message queues, and
subscriptions. The handler now accepts an "auth.refresh" control
frame: client sends the new token, the gateway re-verifies it via
the new JWTVerifier interface, updates the per-instance invCtx
in-place (persistent.Instance.UpdateInvocationContext), and acks. No close, no
state loss.

JWTVerifier is optional — handlers set it via SetJWTVerifier at
gateway init. When unwired, the handler nacks with a "not supported
on this gateway" response and clients fall back to the old
close+reconnect path, so older deploys don't break.

Other:
- push/dispatcher.go: SendToUserDetailed returns per-device PushError
  shape so callers can act on Unregistered / HTTPStatus / Reason.
- serverless/hostfunctions/push.go: WASM host functions for the new
  detailed-error shape.
- serverless/persistent/instance.go: UpdateInvocationContext for mid-session re-auth.

Tests:
- ws_persistent_control_test.go: auth.refresh ack/nack paths.
- apns_test.go: empty-content rejection, PushError shape on 410 +
  generic non-200, content-available mapping.
- dispatcher_detailed_test.go: SendToUserDetailed result shape.
- instance_update_invctx_test.go: invCtx update is per-instance, not
  cross-tenant.

VERSION bumped to 0.122.27.
2026-05-19 18:19:21 +03:00

361 lines
13 KiB
Go

package persistent
import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/DeBrosOfficial/network/pkg/serverless"
"github.com/tetratelabs/wazero/api"
"go.uber.org/zap"
)
type (
	// WSOpenInput is the JSON document handed to the guest's ws_open export;
	// the guest decodes it from the payload bytes it receives. Field order
	// fixes the order of keys in the encoded JSON.
	WSOpenInput struct {
		ClientID  string            `json:"client_id"`
		Wallet    string            `json:"wallet,omitempty"`
		Namespace string            `json:"namespace"`
		Headers   map[string]string `json:"headers,omitempty"` // filtered subset
	}

	// CloseReason says why a connection is being torn down. Its text is the
	// payload handed to the guest's ws_close export.
	CloseReason string

	// SendFunc writes one frame to the underlying WebSocket and returns an
	// error once the connection is closed. ws_send host-function calls made
	// by an instance are forwarded to a SendFunc.
	SendFunc func(data []byte) error
)

// Reasons reported to ws_close.
const (
	CloseReasonClientDisconnect CloseReason = "client_disconnect"
	CloseReasonServerShutdown   CloseReason = "server_shutdown"
	CloseReasonIdleTimeout      CloseReason = "idle_timeout"
	CloseReasonHandlerError     CloseReason = "handler_error"
	CloseReasonRejected         CloseReason = "rejected_by_open"
)
// Instance owns one WASM module bound to one WebSocket connection.
//
// Lifecycle:
//
//  1. NewInstance wraps an already-instantiated module.
//  2. Open(ctx, input) calls ws_open; a non-zero return code rejects the
//     connection.
//  3. Run(ctx) drains the inbound queue, calling ws_frame once per frame. It
//     blocks until ctx is cancelled, the queue is closed by Close, or a frame
//     traps / asks to close.
//  4. Submit(frame) enqueues a frame for Run without blocking; it returns an
//     error when the queue is full or the instance is closed.
//  5. Close(ctx, reason) calls ws_close and closes the wazero module; only the
//     first call has any effect.
type Instance struct {
	// Identity of the connection and the function serving it.
	clientID     string
	functionName string
	namespace    string

	// The guest module (owned here: Close closes it) and the exports that
	// NewInstance resolved from it.
	module  api.Module
	openFn  api.Function // ws_open
	frameFn api.Function // ws_frame
	closeFn api.Function // ws_close
	allocFn api.Function // orama_alloc, falling back to malloc; copies payloads into the guest
	memory  api.Memory

	// invCtx is the caller identity attached to every export call's ctx via
	// serverless.WithInvocationContext (see withInvCtx). Host functions read
	// identity from that per-call ctx instead of the shared HostFunctions
	// singleton field, which keeps concurrent persistent WS connections from
	// seeing each other's identity (see pkg/serverless/invocation_context.go).
	//
	// Mutable since bug #321 (mid-session re-auth): UpdateInvocationContext
	// swaps it when the client rotates its JWT. invCtxMu guards every read and
	// write; readers take the read lock.
	invCtx   *serverless.InvocationContext
	invCtxMu sync.RWMutex

	// inbound carries frames from Submit to Run; its capacity is
	// Config.MaxInflightFrames (after defaulting).
	inbound chan []byte
	logger  *zap.Logger

	// frameTimeout bounds each individual export call; it comes from
	// Config.FrameTimeoutSec (after defaulting).
	frameTimeout time.Duration

	// closeOnce runs Close's shutdown sequence exactly once; closed is set at
	// the start of that sequence and checked by Submit.
	closeOnce sync.Once
	closed    atomic.Bool
}
// Config carries the settings for one persistent instance. The gateway fills
// it from the function's metadata; zero values for the numeric knobs select
// NewInstance's defaults.
type Config struct {
	// Identity of the WebSocket client and the function/namespace serving it.
	ClientID     string
	FunctionName string
	Namespace    string

	// FrameTimeoutSec bounds each ws_open / ws_frame / ws_close call.
	// Zero or negative means 30 seconds.
	FrameTimeoutSec int
	// MaxInflightFrames is the capacity of the inbound frame queue.
	// Zero or negative means 64.
	MaxInflightFrames int

	// InvocationContext is attached to every WASM-host call's ctx so the
	// instance's caller identity (JWT subject, wallet, claims, ws client ID)
	// cannot leak between concurrent persistent WS connections.
	//
	// REQUIRED: NewInstance fails when it is nil. Without it, host functions
	// would fall back to the shared HostFunctions singleton field and re-open
	// the cross-tenant identity leak this machinery exists to close (see
	// pkg/serverless/invocation_context.go).
	InvocationContext *serverless.InvocationContext
}
// NewInstance wraps an already-instantiated wazero module as a persistent
// instance bound to one WebSocket connection.
//
// It returns an error when:
//   - cfg.InvocationContext is nil (see Config.InvocationContext);
//   - module is nil;
//   - any of the required exports ws_open, ws_frame or ws_close is missing;
//   - neither orama_alloc nor malloc is exported (one is needed to copy
//     payload bytes into guest memory);
//   - the module exports no memory.
//
// Zero or negative cfg.FrameTimeoutSec / cfg.MaxInflightFrames fall back to
// 30 seconds and 64 frames. A nil logger is replaced by a no-op logger.
//
// Ownership: on success the returned Instance owns module, and Close closes
// it. On error module is left untouched and remains the caller's to close.
func NewInstance(module api.Module, cfg Config, logger *zap.Logger) (*Instance, error) {
	// Reject nil invCtx loud and early. A persistent instance without
	// per-call invCtx propagation falls back to the singleton field on
	// every host call, which races across concurrent connections — the
	// exact bug this design exists to prevent. Caller MUST populate.
	if cfg.InvocationContext == nil {
		return nil, fmt.Errorf("persistent: Config.InvocationContext is required (nil would re-open the cross-tenant identity-leak race; see pkg/serverless/invocation_context.go)")
	}
	// Every lookup below calls a method on module; report a nil module as an
	// error instead of panicking on a nil interface.
	if module == nil {
		return nil, fmt.Errorf("persistent: module is nil")
	}

	openFn := module.ExportedFunction("ws_open")
	if openFn == nil {
		return nil, fmt.Errorf("persistent: module missing ws_open export")
	}
	frameFn := module.ExportedFunction("ws_frame")
	if frameFn == nil {
		return nil, fmt.Errorf("persistent: module missing ws_frame export")
	}
	closeFn := module.ExportedFunction("ws_close")
	if closeFn == nil {
		return nil, fmt.Errorf("persistent: module missing ws_close export")
	}

	// Prefer orama_alloc; fall back to malloc.
	allocFn := module.ExportedFunction("orama_alloc")
	if allocFn == nil {
		allocFn = module.ExportedFunction("malloc")
	}
	if allocFn == nil {
		return nil, fmt.Errorf("persistent: module missing orama_alloc/malloc export (required to pass payload bytes)")
	}

	memory := module.Memory()
	if memory == nil {
		return nil, fmt.Errorf("persistent: module exports no memory")
	}

	frameTimeout := time.Duration(cfg.FrameTimeoutSec) * time.Second
	if frameTimeout <= 0 {
		frameTimeout = 30 * time.Second
	}
	maxInflight := cfg.MaxInflightFrames
	if maxInflight <= 0 {
		maxInflight = 64
	}

	// Run and Close log through this logger; zap's *Logger methods are not
	// nil-safe, so a nil logger would panic on the first log call.
	if logger == nil {
		logger = zap.NewNop()
	}

	return &Instance{
		clientID:     cfg.ClientID,
		functionName: cfg.FunctionName,
		namespace:    cfg.Namespace,
		module:       module,
		openFn:       openFn,
		frameFn:      frameFn,
		closeFn:      closeFn,
		allocFn:      allocFn,
		memory:       memory,
		invCtx:       cfg.InvocationContext,
		inbound:      make(chan []byte, maxInflight),
		logger:       logger,
		frameTimeout: frameTimeout,
	}, nil
}
// withInvCtx derives a ctx that carries the instance's current invocation
// context. Host functions invoked during the export call then read the caller
// identity from that ctx rather than from the shared HostFunctions singleton.
// Open, handleFrame and Close all wrap their ctx with it.
//
// The invocation context is snapshotted under the read lock, so a later
// UpdateInvocationContext does not affect a ctx already returned. NewInstance
// and UpdateInvocationContext never store nil. The nil branch is therefore
// only defensive (e.g. a zero-value Instance) and returns ctx unchanged.
func (i *Instance) withInvCtx(ctx context.Context) context.Context {
	i.invCtxMu.RLock()
	snapshot := i.invCtx
	i.invCtxMu.RUnlock()

	if snapshot != nil {
		return serverless.WithInvocationContext(ctx, snapshot)
	}
	return ctx
}
// UpdateInvocationContext replaces the per-instance invocation context under
// the write lock. The WS handler uses it to apply a mid-session JWT rotation
// (bugboard #321 — the `__orama:auth.refresh` control frame). The client's new
// JWT subject, wallet and claims then reach every subsequent host call without
// tearing down the WebSocket.
//
// It is safe to call from the WS read loop while Run's goroutine reads the
// field through withInvCtx. Export calls whose ctx was already derived from
// the OLD context keep that identity until they return. That is intentional:
// an in-flight invocation finishes under the identity it started with.
//
// A nil newInvCtx is rejected with an error and the current context is kept.
// This preserves NewInstance's "invCtx is required" invariant; a nil swap
// would re-open the cross-tenant race documented in
// pkg/serverless/invocation_context.go.
func (i *Instance) UpdateInvocationContext(newInvCtx *serverless.InvocationContext) error {
	if newInvCtx != nil {
		i.invCtxMu.Lock()
		defer i.invCtxMu.Unlock()
		i.invCtx = newInvCtx
		return nil
	}
	return fmt.Errorf("persistent: UpdateInvocationContext: nil invCtx (would re-open the cross-tenant identity-leak race)")
}
// CurrentInvocationContext returns the invocation context that withInvCtx
// would attach to the next export call. The WS handler uses it to audit
// identity transitions on a mid-session JWT refresh (bug #321).
//
// The read takes invCtxMu's read lock, so it is safe alongside a concurrent
// UpdateInvocationContext. The returned pointer is a snapshot; a later update
// does not change it. For any Instance built by NewInstance the result is
// non-nil, because neither NewInstance nor UpdateInvocationContext accepts nil.
func (i *Instance) CurrentInvocationContext() *serverless.InvocationContext {
	i.invCtxMu.RLock()
	current := i.invCtx
	i.invCtxMu.RUnlock()
	return current
}
// ClientID reports the WebSocket client ID this instance was created for
// (Config.ClientID).
func (i *Instance) ClientID() string {
	return i.clientID
}
// Open runs the guest's ws_open export with input encoded as JSON. The call
// runs under the instance's invocation context and a frameTimeout deadline.
//
// It returns an error in three cases:
//   - input cannot be marshalled;
//   - the export call itself fails (reported as a trap);
//   - ws_open returns a non-zero code, meaning the guest rejected the
//     connection.
func (i *Instance) Open(ctx context.Context, input WSOpenInput) error {
	body, err := json.Marshal(input)
	if err != nil {
		return fmt.Errorf("persistent.Open: marshal input: %w", err)
	}

	openCtx, cancel := context.WithTimeout(i.withInvCtx(ctx), i.frameTimeout)
	defer cancel()

	rc, callErr := i.callExport(openCtx, i.openFn, body)
	switch {
	case callErr != nil:
		return fmt.Errorf("persistent.Open: ws_open trap: %w", callErr)
	case rc != 0:
		return fmt.Errorf("persistent.Open: ws_open rejected (rc=%d)", rc)
	default:
		return nil
	}
}
// Submit enqueues frame for Run to process. It never blocks:
//   - If the inbound queue (capacity Config.MaxInflightFrames) is full, it
//     returns an error and the caller should drop the connection.
//   - Once the instance has been closed, it returns an "instance closed"
//     error.
//
// The slice is queued without copying, so the caller must not modify it after
// a successful Submit.
//
// Submit may race with Close. Close sets `closed` first and only later calls
// close(i.inbound), so a Submit that passed the closed check just before Close
// began can reach the send after the channel is closed. In Go, a send on a
// closed channel panics even inside a select with a default case. That panic
// is recovered here and reported as the same "instance closed" error instead
// of crashing the caller's goroutine.
func (i *Instance) Submit(frame []byte) (err error) {
	if i.closed.Load() {
		return fmt.Errorf("persistent.Submit: instance closed")
	}
	defer func() {
		// The only operation below that can panic is the channel send, and it
		// panics only when Close has already closed i.inbound.
		if recover() != nil {
			err = fmt.Errorf("persistent.Submit: instance closed")
		}
	}()
	select {
	case i.inbound <- frame:
		return nil
	default:
		return fmt.Errorf("persistent.Submit: inbound queue full (max %d)", cap(i.inbound))
	}
}
// Run is the instance's frame loop. It takes frames off the inbound queue one
// at a time and hands each to ws_frame via handleFrame.
//
// It returns in three cases:
//   - ctx is cancelled;
//   - the queue has been closed (by Close) and is empty;
//   - handleFrame fails, either because the export call errored or because
//     the guest asked to close (rc == 1). The failure is logged at Warn level,
//     and the caller should then disconnect the WebSocket.
//
// Frames are handled serially because wazero module instances are not
// goroutine-safe.
func (i *Instance) Run(ctx context.Context) {
	for {
		var (
			frame []byte
			open  bool
		)
		select {
		case <-ctx.Done():
			return
		case frame, open = <-i.inbound:
		}
		if !open {
			return
		}

		err := i.handleFrame(ctx, frame)
		if err == nil {
			continue
		}
		// Frame errors trigger close — caller should disconnect WS.
		i.logger.Warn("persistent ws_frame error",
			zap.String("client_id", i.clientID),
			zap.String("function", i.functionName),
			zap.Error(err),
		)
		return
	}
}
// handleFrame delivers one frame to the guest's ws_frame export. The call runs
// under the instance's invocation context and a frameTimeout deadline.
//
// It returns an error in two cases:
//   - the export call fails;
//   - ws_frame returns 1, which is the guest's request to close the connection.
//
// Any other return code, including other non-zero values, counts as success.
func (i *Instance) handleFrame(ctx context.Context, frame []byte) error {
	frameCtx, cancel := context.WithTimeout(i.withInvCtx(ctx), i.frameTimeout)
	defer cancel()

	rc, err := i.callExport(frameCtx, i.frameFn, frame)
	switch {
	case err != nil:
		return fmt.Errorf("ws_frame trap: %w", err)
	case rc == 1:
		return fmt.Errorf("ws_frame requested close")
	}
	return nil
}
// Close invokes the guest's ws_close export with reason (best-effort, bounded
// by frameTimeout) and then closes the wazero module. The work runs inside
// closeOnce. Only the first call performs it; later calls are no-ops, though
// a concurrent call waits for the first to finish, per sync.Once.
//
// Shutdown order:
//  1. mark the instance closed so new Submit calls are refused;
//  2. start a goroutine that drains and discards queued frames;
//  3. call ws_close under the instance's current invocation context; errors
//     are only logged at Debug level;
//  4. close the inbound channel, which ends the drain goroutine and lets Run
//     return once the queue is empty;
//  5. close the module using context.Background() rather than ctx.
//
// NOTE(review): nothing here waits for Run. If Run is still inside ws_frame on
// another goroutine, ws_close and module.Close can overlap with it, even though
// Run documents that wazero instances are not goroutine-safe. Presumably
// callers cancel Run's ctx and wait for it to return before calling Close —
// confirm in the WS handler.
func (i *Instance) Close(ctx context.Context, reason CloseReason) {
	i.closeOnce.Do(func() {
		i.closed.Store(true)
		// Discard frames still queued, so a Run that is still alive does not
		// keep feeding them to ws_frame. Submit uses a non-blocking send and
		// never blocks, so this is not about unblocking senders. The goroutine
		// exits when close(i.inbound) below runs and the queue is empty.
		go func() {
			for range i.inbound {
			}
		}()
		// Best-effort ws_close — don't propagate errors; we're shutting down.
		closeCtx, cancel := context.WithTimeout(i.withInvCtx(ctx), i.frameTimeout)
		defer cancel()
		if _, err := i.callExport(closeCtx, i.closeFn, []byte(reason)); err != nil {
			i.logger.Debug("persistent ws_close ignored error",
				zap.String("client_id", i.clientID),
				zap.Error(err),
			)
		}
		// NOTE(review): a Submit that passed its closed check before
		// closed.Store(true) above can still attempt its send after this
		// close, and a send on a closed channel panics unless Submit guards
		// against it.
		close(i.inbound)
		// Module teardown deliberately uses context.Background(), not ctx.
		if err := i.module.Close(context.Background()); err != nil {
			i.logger.Debug("persistent module.Close error",
				zap.String("client_id", i.clientID),
				zap.Error(err))
		}
	})
}
// callExport copies payload into guest memory and calls fn(ptr, len(payload)).
//
// For a non-empty payload, space is reserved with the module's allocator
// export (allocFn: orama_alloc or malloc) and the bytes are written there. An
// empty payload is passed as (0, 0) without allocating. The export's first
// result is returned truncated to uint32, or 0 if the export returns nothing.
//
// It returns an error when:
//   - the allocator call fails;
//   - the allocator returns no result;
//   - the allocator returns a null (0) pointer;
//   - the payload cannot be written to guest memory;
//   - fn itself returns an error (passed through unwrapped).
//
// Allocations are never freed here. Each non-empty payload stays allocated in
// the guest until the module is closed by Close, so a long-lived WebSocket
// session accumulates one buffer per frame.
// NOTE(review): the previous doc assumed the function's memory_limit_mb caps
// this growth (the allocator would then start failing). Reclaiming it would
// need a free export, but whether the guest takes ownership of the buffer is
// not visible from here — confirm before adding one.
func (i *Instance) callExport(ctx context.Context, fn api.Function, payload []byte) (uint32, error) {
	var ptr uint64
	if len(payload) > 0 {
		results, err := i.allocFn.Call(ctx, uint64(len(payload)))
		if err != nil {
			return 0, fmt.Errorf("alloc: %w", err)
		}
		// A malformed allocator export could return nothing; indexing
		// results[0] would then panic instead of producing an error.
		if len(results) == 0 {
			return 0, fmt.Errorf("alloc: export returned no result")
		}
		ptr = results[0]
		// Allocators signal failure with a null pointer; writing the payload
		// at address 0 would silently clobber guest memory.
		if uint32(ptr) == 0 {
			return 0, fmt.Errorf("alloc: returned null pointer for %d bytes (oom?)", len(payload))
		}
		if !i.memory.Write(uint32(ptr), payload) {
			return 0, fmt.Errorf("memory write failed (oom?)")
		}
	}

	results, err := fn.Call(ctx, ptr, uint64(len(payload)))
	if err != nil {
		return 0, err
	}
	if len(results) == 0 {
		return 0, nil
	}
	return uint32(results[0]), nil
}