mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-06-17 01:34:13 +00:00
feat(gateway): enforce jwt expiry on persistent websockets
- implement `wsJWTExpired` to validate token lifetime with a grace period - capture jwt expiry at connection upgrade and update via auth.refresh - close connections with custom code 4401 when tokens expire to force re-auth - add unit tests to verify expiry logic and state transitions
This commit is contained in:
parent
d113b75497
commit
4d700aed54
@ -157,6 +157,24 @@ func (h *ServerlessHandlers) getJWTSubjectFromRequest(r *http.Request) string {
|
|||||||
return strings.TrimSpace(claims.Sub)
|
return strings.TrimSpace(claims.Sub)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getJWTExpiryFromRequest returns the Bearer JWT's `exp` claim (unix seconds)
|
||||||
|
// if the request was JWT-authenticated, or 0 otherwise (e.g. API-key auth, or
|
||||||
|
// a token without an exp). Persistent WS connections capture this at upgrade
|
||||||
|
// to enforce mid-session expiry — a long-lived socket must stop serving RPCs
|
||||||
|
// once its authorizing token expires, unless refreshed via the #321
|
||||||
|
// auth.refresh control frame. Bugboard #868.
|
||||||
|
func (h *ServerlessHandlers) getJWTExpiryFromRequest(r *http.Request) int64 {
|
||||||
|
v := r.Context().Value(ctxkeys.JWT)
|
||||||
|
if v == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
claims, ok := v.(*auth.JWTClaims)
|
||||||
|
if !ok || claims == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return claims.Exp
|
||||||
|
}
|
||||||
|
|
||||||
// getWalletFromRequest extracts wallet address from JWT.
|
// getWalletFromRequest extracts wallet address from JWT.
|
||||||
func (h *ServerlessHandlers) getWalletFromRequest(r *http.Request) string {
|
func (h *ServerlessHandlers) getWalletFromRequest(r *http.Request) string {
|
||||||
// Import strings package functions inline to avoid circular dependencies
|
// Import strings package functions inline to avoid circular dependencies
|
||||||
|
|||||||
@ -0,0 +1,152 @@
|
|||||||
|
package serverless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestWSJWTExpired is the core security regression guard for bugboard #868: a
|
||||||
|
// persistent WS authenticates ONCE at upgrade, and the read loop must stop
|
||||||
|
// serving application frames once the authorizing JWT is past exp+grace.
|
||||||
|
//
|
||||||
|
// If wsJWTExpired starts returning false for a clearly-expired token (or true
|
||||||
|
// for a still-valid one), an expired token regains full RPC access — including
|
||||||
|
// turn.credentials minting — for the socket's lifetime.
|
||||||
|
func TestWSJWTExpired(t *testing.T) {
|
||||||
|
// Fixed reference instant so the table is deterministic (the read loop
|
||||||
|
// uses time.Now() in production; the pure function takes `now` for tests).
|
||||||
|
now := time.Unix(1_700_000_000, 0)
|
||||||
|
grace := 120 * time.Second
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
expUnix int64
|
||||||
|
now time.Time
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no expiry to enforce (API-key auth, exp=0) never expires",
|
||||||
|
expUnix: 0,
|
||||||
|
now: now,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative exp treated as no-expiry (defensive)",
|
||||||
|
expUnix: -5,
|
||||||
|
now: now,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "token valid, well before exp",
|
||||||
|
expUnix: now.Add(10 * time.Minute).Unix(),
|
||||||
|
now: now,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "token just past exp but inside grace window — still allowed",
|
||||||
|
expUnix: now.Add(-30 * time.Second).Unix(),
|
||||||
|
now: now,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "token exactly at exp+grace boundary — not yet expired (After is strict)",
|
||||||
|
expUnix: now.Add(-grace).Unix(),
|
||||||
|
now: now,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "token past exp+grace — expired, must reject",
|
||||||
|
expUnix: now.Add(-(grace + time.Second)).Unix(),
|
||||||
|
now: now,
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "token long expired — expired",
|
||||||
|
expUnix: now.Add(-24 * time.Hour).Unix(),
|
||||||
|
now: now,
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := wsJWTExpired(tc.expUnix, tc.now, grace)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("wsJWTExpired(exp=%d, now=%d, grace=%s) = %v; want %v",
|
||||||
|
tc.expUnix, tc.now.Unix(), grace, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetJWTExpiryFromRequest verifies the gateway reads the authorizing JWT's
|
||||||
|
// exp off the request context at upgrade. This is the value the read loop
|
||||||
|
// enforces for the socket's lifetime (#868); if it silently returns 0 for a
|
||||||
|
// JWT-authenticated request, expiry enforcement is disabled and the bug
|
||||||
|
// re-opens.
|
||||||
|
func TestGetJWTExpiryFromRequest(t *testing.T) {
|
||||||
|
h := newTestHandlers(nil)
|
||||||
|
|
||||||
|
t.Run("JWT with exp returns exp", func(t *testing.T) {
|
||||||
|
claims := &auth.JWTClaims{Sub: "alice", Exp: 1_700_000_123}
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), ctxkeys.JWT, claims))
|
||||||
|
|
||||||
|
if got := h.getJWTExpiryFromRequest(req); got != 1_700_000_123 {
|
||||||
|
t.Errorf("getJWTExpiryFromRequest = %d; want 1700000123", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no JWT on context returns 0 (API-key / unauthenticated)", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if got := h.getJWTExpiryFromRequest(req); got != 0 {
|
||||||
|
t.Errorf("getJWTExpiryFromRequest = %d; want 0", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil claims under key returns 0", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
var nilClaims *auth.JWTClaims
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), ctxkeys.JWT, nilClaims))
|
||||||
|
if got := h.getJWTExpiryFromRequest(req); got != 0 {
|
||||||
|
t.Errorf("getJWTExpiryFromRequest = %d; want 0", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWSAuthState_refreshExtendsExpiry documents the auth.refresh contract that
|
||||||
|
// the read loop relies on (#868 + #321): a successful auth.refresh moves the
|
||||||
|
// enforced expiry forward to the new token's exp, so a socket that refreshes
|
||||||
|
// before its grace window closes keeps serving RPCs uninterrupted.
|
||||||
|
//
|
||||||
|
// We assert the state-transition directly (the full handler needs a live WS
|
||||||
|
// conn for the ack write; that path is exercised by integration tests). The
|
||||||
|
// invariant: after refresh, a `now` that WOULD have expired the old token no
|
||||||
|
// longer expires the socket.
|
||||||
|
func TestWSAuthState_refreshExtendsExpiry(t *testing.T) {
|
||||||
|
now := time.Unix(1_700_000_000, 0)
|
||||||
|
grace := 120 * time.Second
|
||||||
|
|
||||||
|
oldExp := now.Add(-(grace + time.Minute)).Unix() // already past grace → expired
|
||||||
|
state := &wsAuthState{expUnix: oldExp}
|
||||||
|
|
||||||
|
if !wsJWTExpired(state.expUnix, now, grace) {
|
||||||
|
t.Fatalf("precondition: old token should be expired at now")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate what handleAuthRefresh does on success: adopt the new token's
|
||||||
|
// exp.
|
||||||
|
newExp := now.Add(15 * time.Minute).Unix()
|
||||||
|
state.expUnix = newExp
|
||||||
|
|
||||||
|
if wsJWTExpired(state.expUnix, now, grace) {
|
||||||
|
t.Errorf("after refresh the socket must NOT be expired (exp=%d, now=%d)",
|
||||||
|
state.expUnix, now.Unix())
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -22,6 +22,51 @@ import (
|
|||||||
// application traffic that goes straight to WASM. Bugboard #321.
|
// application traffic that goes straight to WASM. Bugboard #321.
|
||||||
var oramaControlFramePrefix = []byte(`"__orama"`)
|
var oramaControlFramePrefix = []byte(`"__orama"`)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// wsJWTExpiryGrace is the slack past a JWT's `exp` before the gateway
|
||||||
|
// stops serving application frames on a persistent WS. It covers clock
|
||||||
|
// skew between the gateway and the issuing path plus the client's
|
||||||
|
// refresh round-trip (the #321 auth.refresh control frame). Bugboard
|
||||||
|
// #868: without this, a socket authenticated ONCE at upgrade keeps full
|
||||||
|
// RPC access — including turn.credentials minting — for the socket's
|
||||||
|
// entire lifetime even after the token expires.
|
||||||
|
//
|
||||||
|
// Note: on the auth.refresh path ParseAndVerifyJWT independently allows
|
||||||
|
// its own ±60s exp skew, so worst-case service-past-exp is this grace
|
||||||
|
// plus that skew (~180s), not 120s flat. Both bounds are deliberate and
|
||||||
|
// the socket is force-closed once they elapse.
|
||||||
|
wsJWTExpiryGrace = 120 * time.Second
|
||||||
|
|
||||||
|
// wsCloseJWTExpired is the application-specific WS close code sent when a
|
||||||
|
// persistent socket is torn down for serving past its JWT expiry. It sits
|
||||||
|
// in the private-use range (4000-4999) and is distinct from protocol
|
||||||
|
// codes so clients can special-case it as "reconnect with a fresh token".
|
||||||
|
// Bugboard #868.
|
||||||
|
wsCloseJWTExpired = 4401
|
||||||
|
)
|
||||||
|
|
||||||
|
// wsAuthState carries the live JWT expiry for a persistent WS across the read
|
||||||
|
// loop and the auth.refresh control handler. Both run in the SAME goroutine —
|
||||||
|
// control frames are handled inline in the read loop before any frame reaches
|
||||||
|
// WASM — so the field needs no synchronization. Bugboard #868.
|
||||||
|
type wsAuthState struct {
|
||||||
|
// expUnix is the `exp` (unix seconds) of the JWT currently authorizing
|
||||||
|
// this socket. 0 means "no expiry to enforce" (e.g. API-key auth or a
|
||||||
|
// token without exp) — such sockets are exempt from mid-session expiry.
|
||||||
|
expUnix int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// wsJWTExpired reports whether a persistent WS authorized by a JWT expiring at
|
||||||
|
// expUnix (unix seconds) has passed its enforcement deadline at time now,
|
||||||
|
// allowing grace for clock skew + refresh round-trip. expUnix <= 0 means there
|
||||||
|
// is no expiry to enforce and is never considered expired. Bugboard #868.
|
||||||
|
func wsJWTExpired(expUnix int64, now time.Time, grace time.Duration) bool {
|
||||||
|
if expUnix <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return now.After(time.Unix(expUnix, 0).Add(grace))
|
||||||
|
}
|
||||||
|
|
||||||
// oramaControlFrame is the wire shape for gateway-handled control
|
// oramaControlFrame is the wire shape for gateway-handled control
|
||||||
// frames on a persistent WS. The single Type field discriminates;
|
// frames on a persistent WS. The single Type field discriminates;
|
||||||
// payload fields specific to each Type ride alongside.
|
// payload fields specific to each Type ride alongside.
|
||||||
@ -97,6 +142,12 @@ func (h *ServerlessHandlers) handlePersistentWebSocket(
|
|||||||
invCtx := h.buildPersistentInvocationContext(r, fn, clientID)
|
invCtx := h.buildPersistentInvocationContext(r, fn, clientID)
|
||||||
callerWallet := invCtx.CallerWallet
|
callerWallet := invCtx.CallerWallet
|
||||||
|
|
||||||
|
// Capture the authorizing JWT's expiry so the read loop can enforce it
|
||||||
|
// for the socket's lifetime (bugboard #868). A successful auth.refresh
|
||||||
|
// control frame updates this in place; 0 (non-JWT auth) disables the
|
||||||
|
// check.
|
||||||
|
authState := &wsAuthState{expUnix: h.getJWTExpiryFromRequest(r)}
|
||||||
|
|
||||||
// Instantiate the persistent module. This compiles once (cached) and
|
// Instantiate the persistent module. This compiles once (cached) and
|
||||||
// creates one wazero instance bound to this connection.
|
// creates one wazero instance bound to this connection.
|
||||||
module, err := h.engine.InstantiatePersistent(r.Context(), fn, invCtx)
|
module, err := h.engine.InstantiatePersistent(r.Context(), fn, invCtx)
|
||||||
@ -196,7 +247,7 @@ func (h *ServerlessHandlers) handlePersistentWebSocket(
|
|||||||
// avoids json.Unmarshal for every application frame. Only
|
// avoids json.Unmarshal for every application frame. Only
|
||||||
// frames carrying the `"__orama"` key get parsed.
|
// frames carrying the `"__orama"` key get parsed.
|
||||||
if bytes.Contains(frame, oramaControlFramePrefix) {
|
if bytes.Contains(frame, oramaControlFramePrefix) {
|
||||||
handled, ackErr := h.handleOramaControlFrame(frame, fn, inst, namespace, clientID, conn)
|
handled, ackErr := h.handleOramaControlFrame(frame, fn, inst, authState, namespace, clientID, conn)
|
||||||
if ackErr != nil {
|
if ackErr != nil {
|
||||||
h.logger.Warn("persistent WS: control-frame ack write failed",
|
h.logger.Warn("persistent WS: control-frame ack write failed",
|
||||||
zap.String("client_id", clientID),
|
zap.String("client_id", clientID),
|
||||||
@ -213,6 +264,26 @@ func (h *ServerlessHandlers) handlePersistentWebSocket(
|
|||||||
// application frame.
|
// application frame.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Bugboard #868: a persistent WS authenticates ONCE at upgrade.
|
||||||
|
// Before handing an application frame to WASM, reject it once the
|
||||||
|
// authorizing JWT is past exp+grace — otherwise an expired token
|
||||||
|
// keeps serving RPCs (incl. turn.credentials minting) indefinitely.
|
||||||
|
// The client keeps the socket alive by sending an
|
||||||
|
// {"__orama":"auth.refresh"} control frame (handled above, which
|
||||||
|
// bypasses this check) before the token expires. The check runs
|
||||||
|
// only on application frames so an expired client can still recover
|
||||||
|
// via auth.refresh rather than being locked out.
|
||||||
|
if wsJWTExpired(authState.expUnix, time.Now(), wsJWTExpiryGrace) {
|
||||||
|
h.logger.Info("persistent WS: closing — JWT expired without refresh",
|
||||||
|
zap.String("client_id", clientID),
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.Int64("jwt_exp", authState.expUnix))
|
||||||
|
_ = conn.WriteControl(websocket.CloseMessage,
|
||||||
|
websocket.FormatCloseMessage(wsCloseJWTExpired, "jwt expired; reconnect with a fresh token"),
|
||||||
|
time.Now().Add(time.Second))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
if err := inst.Submit(frame); err != nil {
|
if err := inst.Submit(frame); err != nil {
|
||||||
h.logger.Warn("persistent WS submit failed (queue full?)",
|
h.logger.Warn("persistent WS submit failed (queue full?)",
|
||||||
zap.String("client_id", clientID),
|
zap.String("client_id", clientID),
|
||||||
@ -276,6 +347,7 @@ func (h *ServerlessHandlers) handleOramaControlFrame(
|
|||||||
frame []byte,
|
frame []byte,
|
||||||
fn *serverless.Function,
|
fn *serverless.Function,
|
||||||
inst *persistent.Instance,
|
inst *persistent.Instance,
|
||||||
|
authState *wsAuthState,
|
||||||
namespace, clientID string,
|
namespace, clientID string,
|
||||||
conn *websocket.Conn,
|
conn *websocket.Conn,
|
||||||
) (handled bool, ackErr error) {
|
) (handled bool, ackErr error) {
|
||||||
@ -291,7 +363,7 @@ func (h *ServerlessHandlers) handleOramaControlFrame(
|
|||||||
|
|
||||||
switch ctrl.Type {
|
switch ctrl.Type {
|
||||||
case "auth.refresh":
|
case "auth.refresh":
|
||||||
return true, h.handleAuthRefresh(ctrl, fn, inst, namespace, clientID, conn)
|
return true, h.handleAuthRefresh(ctrl, fn, inst, authState, namespace, clientID, conn)
|
||||||
default:
|
default:
|
||||||
// Unknown control type — ack with an error so the client knows
|
// Unknown control type — ack with an error so the client knows
|
||||||
// the frame was seen but ignored. Treat as handled (don't
|
// the frame was seen but ignored. Treat as handled (don't
|
||||||
@ -312,6 +384,7 @@ func (h *ServerlessHandlers) handleAuthRefresh(
|
|||||||
ctrl oramaControlFrame,
|
ctrl oramaControlFrame,
|
||||||
fn *serverless.Function,
|
fn *serverless.Function,
|
||||||
inst *persistent.Instance,
|
inst *persistent.Instance,
|
||||||
|
authState *wsAuthState,
|
||||||
namespace, clientID string,
|
namespace, clientID string,
|
||||||
conn *websocket.Conn,
|
conn *websocket.Conn,
|
||||||
) error {
|
) error {
|
||||||
@ -407,6 +480,12 @@ func (h *ServerlessHandlers) handleAuthRefresh(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extend the socket's expiry enforcement to the new token's exp so the
|
||||||
|
// read loop keeps serving RPCs past the old deadline (bugboard #868).
|
||||||
|
// authState and the read loop share this goroutine, so the write is
|
||||||
|
// race-free.
|
||||||
|
authState.expUnix = claims.Exp
|
||||||
|
|
||||||
h.logger.Info("persistent WS: auth.refresh applied",
|
h.logger.Info("persistent WS: auth.refresh applied",
|
||||||
zap.String("client_id", clientID),
|
zap.String("client_id", clientID),
|
||||||
zap.String("namespace", namespace),
|
zap.String("namespace", namespace),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user