feat(gateway): enforce jwt expiry on persistent websockets

- implement `wsJWTExpired` to validate token lifetime with a grace period
- capture jwt expiry at connection upgrade and update via auth.refresh
- close connections with custom code 4401 when tokens expire to force re-auth
- add unit tests to verify expiry logic and state transitions
This commit is contained in:
anonpenguin23 2026-06-12 10:12:21 +03:00
parent d113b75497
commit 4d700aed54
3 changed files with 251 additions and 2 deletions

View File

@ -157,6 +157,24 @@ func (h *ServerlessHandlers) getJWTSubjectFromRequest(r *http.Request) string {
return strings.TrimSpace(claims.Sub) return strings.TrimSpace(claims.Sub)
} }
// getJWTExpiryFromRequest returns the Bearer JWT's `exp` claim (unix seconds)
// if the request was JWT-authenticated, or 0 otherwise (e.g. API-key auth, or
// a token without an exp). Persistent WS connections capture this at upgrade
// to enforce mid-session expiry — a long-lived socket must stop serving RPCs
// once its authorizing token expires, unless refreshed via the #321
// auth.refresh control frame. Bugboard #868.
func (h *ServerlessHandlers) getJWTExpiryFromRequest(r *http.Request) int64 {
v := r.Context().Value(ctxkeys.JWT)
if v == nil {
return 0
}
claims, ok := v.(*auth.JWTClaims)
if !ok || claims == nil {
return 0
}
return claims.Exp
}
// getWalletFromRequest extracts wallet address from JWT. // getWalletFromRequest extracts wallet address from JWT.
func (h *ServerlessHandlers) getWalletFromRequest(r *http.Request) string { func (h *ServerlessHandlers) getWalletFromRequest(r *http.Request) string {
// Import strings package functions inline to avoid circular dependencies // Import strings package functions inline to avoid circular dependencies

View File

@ -0,0 +1,152 @@
package serverless
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
"github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
)
// TestWSJWTExpired is the core security regression guard for bugboard #868: a
// persistent WS authenticates ONCE at upgrade, and the read loop must stop
// serving application frames once the authorizing JWT is past exp+grace.
//
// If wsJWTExpired starts returning false for a clearly-expired token (or true
// for a still-valid one), an expired token regains full RPC access — including
// turn.credentials minting — for the socket's lifetime.
func TestWSJWTExpired(t *testing.T) {
// Fixed reference instant so the table is deterministic (the read loop
// uses time.Now() in production; the pure function takes `now` for tests).
now := time.Unix(1_700_000_000, 0)
grace := 120 * time.Second
cases := []struct {
name string
expUnix int64
now time.Time
want bool
}{
{
name: "no expiry to enforce (API-key auth, exp=0) never expires",
expUnix: 0,
now: now,
want: false,
},
{
name: "negative exp treated as no-expiry (defensive)",
expUnix: -5,
now: now,
want: false,
},
{
name: "token valid, well before exp",
expUnix: now.Add(10 * time.Minute).Unix(),
now: now,
want: false,
},
{
name: "token just past exp but inside grace window — still allowed",
expUnix: now.Add(-30 * time.Second).Unix(),
now: now,
want: false,
},
{
name: "token exactly at exp+grace boundary — not yet expired (After is strict)",
expUnix: now.Add(-grace).Unix(),
now: now,
want: false,
},
{
name: "token past exp+grace — expired, must reject",
expUnix: now.Add(-(grace + time.Second)).Unix(),
now: now,
want: true,
},
{
name: "token long expired — expired",
expUnix: now.Add(-24 * time.Hour).Unix(),
now: now,
want: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := wsJWTExpired(tc.expUnix, tc.now, grace)
if got != tc.want {
t.Errorf("wsJWTExpired(exp=%d, now=%d, grace=%s) = %v; want %v",
tc.expUnix, tc.now.Unix(), grace, got, tc.want)
}
})
}
}
// TestGetJWTExpiryFromRequest verifies the gateway reads the authorizing JWT's
// exp off the request context at upgrade. This is the value the read loop
// enforces for the socket's lifetime (#868); if it silently returns 0 for a
// JWT-authenticated request, expiry enforcement is disabled and the bug
// re-opens.
func TestGetJWTExpiryFromRequest(t *testing.T) {
h := newTestHandlers(nil)
t.Run("JWT with exp returns exp", func(t *testing.T) {
claims := &auth.JWTClaims{Sub: "alice", Exp: 1_700_000_123}
req := httptest.NewRequest(http.MethodGet, "/", nil)
req = req.WithContext(context.WithValue(req.Context(), ctxkeys.JWT, claims))
if got := h.getJWTExpiryFromRequest(req); got != 1_700_000_123 {
t.Errorf("getJWTExpiryFromRequest = %d; want 1700000123", got)
}
})
t.Run("no JWT on context returns 0 (API-key / unauthenticated)", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
if got := h.getJWTExpiryFromRequest(req); got != 0 {
t.Errorf("getJWTExpiryFromRequest = %d; want 0", got)
}
})
t.Run("nil claims under key returns 0", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
var nilClaims *auth.JWTClaims
req = req.WithContext(context.WithValue(req.Context(), ctxkeys.JWT, nilClaims))
if got := h.getJWTExpiryFromRequest(req); got != 0 {
t.Errorf("getJWTExpiryFromRequest = %d; want 0", got)
}
})
}
// TestWSAuthState_refreshExtendsExpiry documents the auth.refresh contract that
// the read loop relies on (#868 + #321): a successful auth.refresh moves the
// enforced expiry forward to the new token's exp, so a socket that refreshes
// before its grace window closes keeps serving RPCs uninterrupted.
//
// We assert the state-transition directly (the full handler needs a live WS
// conn for the ack write; that path is exercised by integration tests). The
// invariant: after refresh, a `now` that WOULD have expired the old token no
// longer expires the socket.
func TestWSAuthState_refreshExtendsExpiry(t *testing.T) {
now := time.Unix(1_700_000_000, 0)
grace := 120 * time.Second
oldExp := now.Add(-(grace + time.Minute)).Unix() // already past grace → expired
state := &wsAuthState{expUnix: oldExp}
if !wsJWTExpired(state.expUnix, now, grace) {
t.Fatalf("precondition: old token should be expired at now")
}
// Simulate what handleAuthRefresh does on success: adopt the new token's
// exp.
newExp := now.Add(15 * time.Minute).Unix()
state.expUnix = newExp
if wsJWTExpired(state.expUnix, now, grace) {
t.Errorf("after refresh the socket must NOT be expired (exp=%d, now=%d)",
state.expUnix, now.Unix())
}
}

View File

@ -22,6 +22,51 @@ import (
// application traffic that goes straight to WASM. Bugboard #321. // application traffic that goes straight to WASM. Bugboard #321.
var oramaControlFramePrefix = []byte(`"__orama"`) var oramaControlFramePrefix = []byte(`"__orama"`)
const (
// wsJWTExpiryGrace is the slack past a JWT's `exp` before the gateway
// stops serving application frames on a persistent WS. It covers clock
// skew between the gateway and the issuing path plus the client's
// refresh round-trip (the #321 auth.refresh control frame). Bugboard
// #868: without this, a socket authenticated ONCE at upgrade keeps full
// RPC access — including turn.credentials minting — for the socket's
// entire lifetime even after the token expires.
//
// Note: on the auth.refresh path ParseAndVerifyJWT independently allows
// its own ±60s exp skew, so worst-case service-past-exp is this grace
// plus that skew (~180s), not 120s flat. Both bounds are deliberate and
// the socket is force-closed once they elapse.
wsJWTExpiryGrace = 120 * time.Second
// wsCloseJWTExpired is the application-specific WS close code sent when a
// persistent socket is torn down for serving past its JWT expiry. It sits
// in the private-use range (4000-4999) and is distinct from protocol
// codes so clients can special-case it as "reconnect with a fresh token".
// Bugboard #868.
wsCloseJWTExpired = 4401
)
// wsAuthState carries the live JWT expiry for a persistent WS across the read
// loop and the auth.refresh control handler. Both run in the SAME goroutine —
// control frames are handled inline in the read loop before any frame reaches
// WASM — so the field needs no synchronization. Bugboard #868.
type wsAuthState struct {
// expUnix is the `exp` (unix seconds) of the JWT currently authorizing
// this socket. 0 means "no expiry to enforce" (e.g. API-key auth or a
// token without exp) — such sockets are exempt from mid-session expiry.
expUnix int64
}
// wsJWTExpired reports whether a persistent WS authorized by a JWT expiring at
// expUnix (unix seconds) has passed its enforcement deadline at time now,
// allowing grace for clock skew + refresh round-trip. expUnix <= 0 means there
// is no expiry to enforce and is never considered expired. Bugboard #868.
func wsJWTExpired(expUnix int64, now time.Time, grace time.Duration) bool {
if expUnix <= 0 {
return false
}
return now.After(time.Unix(expUnix, 0).Add(grace))
}
// oramaControlFrame is the wire shape for gateway-handled control // oramaControlFrame is the wire shape for gateway-handled control
// frames on a persistent WS. The single Type field discriminates; // frames on a persistent WS. The single Type field discriminates;
// payload fields specific to each Type ride alongside. // payload fields specific to each Type ride alongside.
@ -97,6 +142,12 @@ func (h *ServerlessHandlers) handlePersistentWebSocket(
invCtx := h.buildPersistentInvocationContext(r, fn, clientID) invCtx := h.buildPersistentInvocationContext(r, fn, clientID)
callerWallet := invCtx.CallerWallet callerWallet := invCtx.CallerWallet
// Capture the authorizing JWT's expiry so the read loop can enforce it
// for the socket's lifetime (bugboard #868). A successful auth.refresh
// control frame updates this in place; 0 (non-JWT auth) disables the
// check.
authState := &wsAuthState{expUnix: h.getJWTExpiryFromRequest(r)}
// Instantiate the persistent module. This compiles once (cached) and // Instantiate the persistent module. This compiles once (cached) and
// creates one wazero instance bound to this connection. // creates one wazero instance bound to this connection.
module, err := h.engine.InstantiatePersistent(r.Context(), fn, invCtx) module, err := h.engine.InstantiatePersistent(r.Context(), fn, invCtx)
@ -196,7 +247,7 @@ func (h *ServerlessHandlers) handlePersistentWebSocket(
// avoids json.Unmarshal for every application frame. Only // avoids json.Unmarshal for every application frame. Only
// frames carrying the `"__orama"` key get parsed. // frames carrying the `"__orama"` key get parsed.
if bytes.Contains(frame, oramaControlFramePrefix) { if bytes.Contains(frame, oramaControlFramePrefix) {
handled, ackErr := h.handleOramaControlFrame(frame, fn, inst, namespace, clientID, conn) handled, ackErr := h.handleOramaControlFrame(frame, fn, inst, authState, namespace, clientID, conn)
if ackErr != nil { if ackErr != nil {
h.logger.Warn("persistent WS: control-frame ack write failed", h.logger.Warn("persistent WS: control-frame ack write failed",
zap.String("client_id", clientID), zap.String("client_id", clientID),
@ -213,6 +264,26 @@ func (h *ServerlessHandlers) handlePersistentWebSocket(
// application frame. // application frame.
} }
// Bugboard #868: a persistent WS authenticates ONCE at upgrade.
// Before handing an application frame to WASM, reject it once the
// authorizing JWT is past exp+grace — otherwise an expired token
// keeps serving RPCs (incl. turn.credentials minting) indefinitely.
// The client keeps the socket alive by sending an
// {"__orama":"auth.refresh"} control frame (handled above, which
// bypasses this check) before the token expires. The check runs
// only on application frames so an expired client can still recover
// via auth.refresh rather than being locked out.
if wsJWTExpired(authState.expUnix, time.Now(), wsJWTExpiryGrace) {
h.logger.Info("persistent WS: closing — JWT expired without refresh",
zap.String("client_id", clientID),
zap.String("namespace", namespace),
zap.Int64("jwt_exp", authState.expUnix))
_ = conn.WriteControl(websocket.CloseMessage,
websocket.FormatCloseMessage(wsCloseJWTExpired, "jwt expired; reconnect with a fresh token"),
time.Now().Add(time.Second))
break
}
if err := inst.Submit(frame); err != nil { if err := inst.Submit(frame); err != nil {
h.logger.Warn("persistent WS submit failed (queue full?)", h.logger.Warn("persistent WS submit failed (queue full?)",
zap.String("client_id", clientID), zap.String("client_id", clientID),
@ -276,6 +347,7 @@ func (h *ServerlessHandlers) handleOramaControlFrame(
frame []byte, frame []byte,
fn *serverless.Function, fn *serverless.Function,
inst *persistent.Instance, inst *persistent.Instance,
authState *wsAuthState,
namespace, clientID string, namespace, clientID string,
conn *websocket.Conn, conn *websocket.Conn,
) (handled bool, ackErr error) { ) (handled bool, ackErr error) {
@ -291,7 +363,7 @@ func (h *ServerlessHandlers) handleOramaControlFrame(
switch ctrl.Type { switch ctrl.Type {
case "auth.refresh": case "auth.refresh":
return true, h.handleAuthRefresh(ctrl, fn, inst, namespace, clientID, conn) return true, h.handleAuthRefresh(ctrl, fn, inst, authState, namespace, clientID, conn)
default: default:
// Unknown control type — ack with an error so the client knows // Unknown control type — ack with an error so the client knows
// the frame was seen but ignored. Treat as handled (don't // the frame was seen but ignored. Treat as handled (don't
@ -312,6 +384,7 @@ func (h *ServerlessHandlers) handleAuthRefresh(
ctrl oramaControlFrame, ctrl oramaControlFrame,
fn *serverless.Function, fn *serverless.Function,
inst *persistent.Instance, inst *persistent.Instance,
authState *wsAuthState,
namespace, clientID string, namespace, clientID string,
conn *websocket.Conn, conn *websocket.Conn,
) error { ) error {
@ -407,6 +480,12 @@ func (h *ServerlessHandlers) handleAuthRefresh(
}) })
} }
// Extend the socket's expiry enforcement to the new token's exp so the
// read loop keeps serving RPCs past the old deadline (bugboard #868).
// authState and the read loop share this goroutine, so the write is
// race-free.
authState.expUnix = claims.Exp
h.logger.Info("persistent WS: auth.refresh applied", h.logger.Info("persistent WS: auth.refresh applied",
zap.String("client_id", clientID), zap.String("client_id", clientID),
zap.String("namespace", namespace), zap.String("namespace", namespace),