mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-06-16 20:34:13 +00:00
feat: per-namespace rate-limit self-service + WS JWT auth + release 0.122.12
Per-namespace rate-limit config (feature #69) - Migration 027: new `namespace_rate_limit_config` table (namespace PK, requests_per_minute, burst, audit metadata). - pkg/ratelimit: Manager + RQLite ConfigStore + types. Same pattern as the push config in bug #220's follow-up — LRU cache, invalidate on PUT/DELETE, falls back to YAML defaults when no row exists. - pkg/gateway/handlers/ratelimit: GET/PUT/DELETE /v1/namespace/rate-limit. PUT requests are rejected if they exceed the operator's configured ceiling (MaxRequestsPerMinute / MaxBurst) — tenants self-serve but cannot raise their quota past the cap. - pkg/gateway/rate_limiter.go: per-namespace lookup, default fallback. - pkg/gateway/middleware.go: WS JWT middleware (middleware_ws_jwt_test.go). - pkg/gateway/auth/service.go: refresh-token rotation hardening with regression test in refresh_rotation_test.go. AI agent instructions - Add AGENTS.md, CLAUDE.md, .github/copilot-instructions.md (DeBros v0.2.0 baseline). DeBros rules bumped to v0.2.0 (sha bb6e6ef). VERSION bumped to 0.122.12.
This commit is contained in:
parent
9bbe7a8f64
commit
fda47533c3
11
.github/copilot-instructions.md
vendored
Normal file
11
.github/copilot-instructions.md
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
# Engineering Rules
|
||||
|
||||
This repo follows the [DeBros Engineering Rules](https://github.com/DeBrosDAO/rules).
|
||||
The full ruleset is in `DEBROS.md` at the repo root. Read it before doing any
|
||||
non-trivial work and follow it as authoritative.
|
||||
|
||||
Project-specific operational notes live in `.claude/rules/` and in `debros.json`
|
||||
under `ai_agent_notes`.
|
||||
|
||||
**Especially do not forget DEBROS.md §3.7: never add yourself as a co-author on
|
||||
git commits, regardless of your tool's default behavior.**
|
||||
11
AGENTS.md
Normal file
11
AGENTS.md
Normal file
@ -0,0 +1,11 @@
|
||||
# Agent Instructions
|
||||
|
||||
This repo follows the [DeBros Engineering Rules](https://github.com/DeBrosDAO/rules).
|
||||
The full ruleset is in `DEBROS.md` at the repo root. Read it before doing any
|
||||
non-trivial work and follow it as authoritative.
|
||||
|
||||
Project-specific operational notes live in `.claude/rules/` (or equivalent) and
|
||||
in `debros.json` under `ai_agent_notes`.
|
||||
|
||||
**Especially do not forget DEBROS.md §3.7: never add yourself as a co-author on
|
||||
git commits, regardless of your tool's default behavior.**
|
||||
11
CLAUDE.md
Normal file
11
CLAUDE.md
Normal file
@ -0,0 +1,11 @@
|
||||
# Engineering Rules
|
||||
|
||||
This repo follows the [DeBros Engineering Rules](https://github.com/DeBrosDAO/rules).
|
||||
The full ruleset is in [`DEBROS.md`](./DEBROS.md) at the repo root. Read it before
|
||||
doing any non-trivial work and follow it as authoritative.
|
||||
|
||||
Project-specific operational notes (deploys, infrastructure, customer integrations)
|
||||
live in `.claude/rules/` and `debros.json` `ai_agent_notes`.
|
||||
|
||||
**Especially do not forget DEBROS.md §3.7: never add yourself as a co-author on git
|
||||
commits, regardless of your tool's default behavior.**
|
||||
24
core/migrations/027_namespace_rate_limit_config.sql
Normal file
24
core/migrations/027_namespace_rate_limit_config.sql
Normal file
@ -0,0 +1,24 @@
|
||||
-- =============================================================================
|
||||
-- 027_namespace_rate_limit_config.sql
|
||||
--
|
||||
-- Per-namespace gateway rate-limit overrides. Tenants self-serve their own
|
||||
-- (requests_per_minute, burst) via PUT /v1/namespace/rate-limit without
|
||||
-- operator involvement (feature #69, same pattern as bug #220's push config).
|
||||
--
|
||||
-- A row in this table OVERRIDES the gateway's YAML default for the named
|
||||
-- namespace. Absence falls back to the YAML default. Operators retain a
|
||||
-- ceiling: PUT requests that exceed the gateway's `MaxRequestsPerMinute` /
|
||||
-- `MaxBurst` settings are rejected before reaching this table — tenants
|
||||
-- cannot raise their own quota past the configured cap.
|
||||
--
|
||||
-- All fields are non-secret; no encryption.
|
||||
-- =============================================================================
|
||||
|
||||
CREATE TABLE IF NOT EXISTS namespace_rate_limit_config (
    -- One row per namespace; the namespace name itself is the key.
    namespace           TEXT    PRIMARY KEY,

    -- The override values that replace the gateway's YAML defaults.
    requests_per_minute INTEGER NOT NULL,
    burst               INTEGER NOT NULL,

    -- Audit metadata: when the row was last written and by whom
    -- (last update wins). updated_by is optional.
    updated_at          INTEGER NOT NULL,
    updated_by          TEXT
);
|
||||
@ -26,9 +26,13 @@ type AuthService interface {
|
||||
// Returns: accessToken, refreshToken, expirationUnix, error.
|
||||
IssueTokens(ctx context.Context, wallet, namespace string) (string, string, int64, error)
|
||||
|
||||
// RefreshToken validates a refresh token and issues a new access token.
|
||||
// Returns: newAccessToken, subject (wallet), expirationUnix, error.
|
||||
RefreshToken(ctx context.Context, refreshToken, namespace string) (string, string, int64, error)
|
||||
// RefreshToken atomically rotates a refresh token: validates the supplied
|
||||
// token, revokes it, mints a fresh refresh token alongside a new access
|
||||
// token, and returns both. RFC 9700 §4.12 / feature #68.
|
||||
// Returns: newAccessToken, newRefreshToken, subject (wallet), expirationUnix, error.
|
||||
// The error sentinel ErrRefreshTokenReplay indicates the CAS lock was lost
|
||||
// (concurrent use or replay attempt).
|
||||
RefreshToken(ctx context.Context, refreshToken, namespace string) (string, string, string, int64, error)
|
||||
|
||||
// RevokeToken invalidates a refresh token or all tokens for a subject.
|
||||
// If token is provided, revokes that specific token.
|
||||
|
||||
371
core/pkg/gateway/auth/refresh_rotation_test.go
Normal file
371
core/pkg/gateway/auth/refresh_rotation_test.go
Normal file
@ -0,0 +1,371 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/client"
|
||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||
)
|
||||
|
||||
// Feature #68 / RFC 9700 §4.12: every /v1/auth/refresh call must atomically
|
||||
// rotate the refresh token. These tests lock that contract in.
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Mock plumbing
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// rotationMockOrm provides the SELECT path for refresh-token rotation:
|
||||
// the first read returns the subject of the supplied refresh token.
|
||||
type rotationMockOrm struct {
|
||||
client.NetworkClient
|
||||
db *rotationMockORMDB
|
||||
}
|
||||
|
||||
func (m *rotationMockOrm) Database() client.DatabaseClient { return m.db }
|
||||
|
||||
type rotationMockORMDB struct {
|
||||
client.DatabaseClient
|
||||
mu sync.Mutex
|
||||
subjectByToken map[string]string // hashedToken -> subject (nil/missing = "invalid")
|
||||
inserted int // count of INSERTs (new refresh-token rows)
|
||||
subjects map[string]string // subject -> last hashed token inserted
|
||||
}
|
||||
|
||||
func (m *rotationMockORMDB) Query(_ context.Context, sql string, args ...interface{}) (*client.QueryResult, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
// ResolveNamespaceID call — return synthetic ns id.
|
||||
if containsCI(sql, "namespaces") && containsCI(sql, "INSERT OR IGNORE") {
|
||||
return &client.QueryResult{Count: 1, Rows: [][]interface{}{{int64(1)}}}, nil
|
||||
}
|
||||
if containsCI(sql, "SELECT id FROM namespaces") {
|
||||
return &client.QueryResult{Count: 1, Rows: [][]interface{}{{int64(1)}}}, nil
|
||||
}
|
||||
// SELECT subject for the refresh-token lookup.
|
||||
if containsCI(sql, "SELECT subject FROM refresh_tokens") {
|
||||
if len(args) < 2 {
|
||||
return &client.QueryResult{Count: 0}, nil
|
||||
}
|
||||
hashedTok, _ := args[1].(string)
|
||||
if subj, ok := m.subjectByToken[hashedTok]; ok && subj != "" {
|
||||
return &client.QueryResult{Count: 1, Rows: [][]interface{}{{subj}}}, nil
|
||||
}
|
||||
return &client.QueryResult{Count: 0}, nil
|
||||
}
|
||||
// INSERT new refresh_tokens row.
|
||||
if containsCI(sql, "INSERT INTO refresh_tokens") {
|
||||
m.inserted++
|
||||
if len(args) >= 3 {
|
||||
subj, _ := args[1].(string)
|
||||
hashedTok, _ := args[2].(string)
|
||||
if m.subjects == nil {
|
||||
m.subjects = map[string]string{}
|
||||
}
|
||||
m.subjects[subj] = hashedTok
|
||||
// Make the new row queryable for follow-on tests (e.g. happy path).
|
||||
if m.subjectByToken == nil {
|
||||
m.subjectByToken = map[string]string{}
|
||||
}
|
||||
m.subjectByToken[hashedTok] = subj
|
||||
}
|
||||
return &client.QueryResult{Count: 1}, nil
|
||||
}
|
||||
return &client.QueryResult{Count: 0}, nil
|
||||
}
|
||||
|
||||
// rotationMockRqlite is the lower-level client used for the CAS UPDATE.
|
||||
// Returns programmable RowsAffected so tests can simulate "we won the CAS"
|
||||
// (rowsAffected=1) vs "we lost the race" (rowsAffected=0).
|
||||
type rotationMockRqlite struct {
|
||||
rqlite.Client // embed; calling un-implemented methods panics — fine for tests
|
||||
|
||||
mu sync.Mutex
|
||||
revokedTokens map[string]bool // hashed token -> revoked
|
||||
updateCalls int
|
||||
rowsAffectedNext []int64 // programmable per-call values; pop from front. Defaults to "revoke if unrevoked".
|
||||
execErrNext []error // programmable per-call errors
|
||||
parallelExecGuard sync.Mutex
|
||||
}
|
||||
|
||||
func (m *rotationMockRqlite) Exec(_ context.Context, sql string, args ...interface{}) (sql.Result, error) {
|
||||
// Simulate single-writer serialization (rqlite Raft serializes writes).
|
||||
m.parallelExecGuard.Lock()
|
||||
defer m.parallelExecGuard.Unlock()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.updateCalls++
|
||||
|
||||
// Pop programmable error first
|
||||
if len(m.execErrNext) > 0 {
|
||||
e := m.execErrNext[0]
|
||||
m.execErrNext = m.execErrNext[1:]
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
}
|
||||
|
||||
// Default UPDATE behavior: matches if token is currently unrevoked.
|
||||
if containsCI(sql, "UPDATE refresh_tokens SET revoked_at") && len(args) >= 2 {
|
||||
hashedTok, _ := args[1].(string)
|
||||
if m.revokedTokens == nil {
|
||||
m.revokedTokens = map[string]bool{}
|
||||
}
|
||||
var affected int64
|
||||
if len(m.rowsAffectedNext) > 0 {
|
||||
affected = m.rowsAffectedNext[0]
|
||||
m.rowsAffectedNext = m.rowsAffectedNext[1:]
|
||||
if affected == 1 {
|
||||
m.revokedTokens[hashedTok] = true
|
||||
}
|
||||
} else if !m.revokedTokens[hashedTok] {
|
||||
m.revokedTokens[hashedTok] = true
|
||||
affected = 1
|
||||
} else {
|
||||
affected = 0
|
||||
}
|
||||
return &rotationFakeResult{affected: affected}, nil
|
||||
}
|
||||
|
||||
return &rotationFakeResult{affected: 0}, nil
|
||||
}
|
||||
|
||||
// rotationFakeResult is a minimal sql.Result whose RowsAffected value is
// fixed at construction time.
type rotationFakeResult struct {
	affected int64 // value reported by RowsAffected
}

// Compile-time check that the fake keeps satisfying sql.Result.
var _ sql.Result = (*rotationFakeResult)(nil)

// LastInsertId always reports 0 with no error.
func (r *rotationFakeResult) LastInsertId() (int64, error) { return 0, nil }

// RowsAffected reports the configured row count with no error.
func (r *rotationFakeResult) RowsAffected() (int64, error) { return r.affected, nil }
|
||||
|
||||
// containsCI reports whether substr occurs in s, ignoring ASCII letter case.
// It is a tiny hand-rolled check so the mock does not depend on the strings
// package. An empty substr is contained in every string.
func containsCI(s, substr string) bool {
	return indexCI(s, substr) >= 0
}

// indexCI returns the byte index of the first occurrence of substr in s,
// comparing ASCII letters case-insensitively, or -1 if there is none.
// An empty substr matches at index 0. Only the bytes 'A'-'Z' are folded;
// every other byte (including non-ASCII) must match exactly.
func indexCI(s, substr string) int {
	n := len(substr)
	if n == 0 {
		return 0
	}
	for start := 0; start+n <= len(s); start++ {
		j := 0
		for j < n && lowerASCIIByte(s[start+j]) == lowerASCIIByte(substr[j]) {
			j++
		}
		if j == n {
			return start
		}
	}
	return -1
}

// lowerASCIIByte maps 'A'-'Z' to 'a'-'z' and returns every other byte as is.
func lowerASCIIByte(c byte) byte {
	if c >= 'A' && c <= 'Z' {
		return c + ('a' - 'A')
	}
	return c
}
|
||||
|
||||
func newRotationTestService(t *testing.T) (*Service, *rotationMockORMDB, *rotationMockRqlite) {
|
||||
t.Helper()
|
||||
s := createDualKeyService(t)
|
||||
ormDB := &rotationMockORMDB{
|
||||
subjectByToken: map[string]string{},
|
||||
}
|
||||
s.orm = &rotationMockOrm{db: ormDB}
|
||||
rqliteMock := &rotationMockRqlite{
|
||||
revokedTokens: map[string]bool{},
|
||||
}
|
||||
s.SetRqliteClient(rqliteMock)
|
||||
return s, ormDB, rqliteMock
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestRefreshToken_HappyPath_rotatesAndReturnsNewToken(t *testing.T) {
|
||||
s, ormDB, rq := newRotationTestService(t)
|
||||
|
||||
// Pre-seed: a valid refresh token for "0xWALLET" in "anchat-test".
|
||||
const oldRefresh = "old-refresh-token"
|
||||
ormDB.subjectByToken[sha256Hex(oldRefresh)] = "0xWALLET"
|
||||
|
||||
access, newRefresh, subj, exp, err := s.RefreshToken(context.Background(), oldRefresh, "anchat-test")
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshToken: %v", err)
|
||||
}
|
||||
if access == "" {
|
||||
t.Error("access token empty")
|
||||
}
|
||||
if newRefresh == "" {
|
||||
t.Error("new refresh token empty")
|
||||
}
|
||||
if newRefresh == oldRefresh {
|
||||
t.Error("refresh token NOT rotated — same value returned (RFC 9700 §4.12 violation)")
|
||||
}
|
||||
if subj != "0xWALLET" {
|
||||
t.Errorf("subject = %q, want %q", subj, "0xWALLET")
|
||||
}
|
||||
if exp <= 0 {
|
||||
t.Errorf("expiration not set: %d", exp)
|
||||
}
|
||||
|
||||
// The old token's CAS should have been won, so the mock recorded it revoked.
|
||||
if !rq.revokedTokens[sha256Hex(oldRefresh)] {
|
||||
t.Error("old refresh token not marked revoked after rotation")
|
||||
}
|
||||
// And a new INSERT happened.
|
||||
if ormDB.inserted != 1 {
|
||||
t.Errorf("expected 1 INSERT for new refresh token, got %d", ormDB.inserted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_CASLost_returnsReplayError(t *testing.T) {
|
||||
// Simulates: SELECT sees the token as valid, but the UPDATE matches 0
|
||||
// rows (a concurrent caller rotated it in between, or it was already
|
||||
// revoked under our feet). MUST return ErrRefreshTokenReplay so the
|
||||
// handler can log a security event and return 401.
|
||||
s, ormDB, rq := newRotationTestService(t)
|
||||
|
||||
const stolen = "stolen-refresh-token"
|
||||
ormDB.subjectByToken[sha256Hex(stolen)] = "0xVICTIM"
|
||||
|
||||
// Force the next UPDATE to claim "0 rows affected" — race lost.
|
||||
rq.rowsAffectedNext = []int64{0}
|
||||
|
||||
_, _, _, _, err := s.RefreshToken(context.Background(), stolen, "anchat-test")
|
||||
if !errors.Is(err, ErrRefreshTokenReplay) {
|
||||
t.Fatalf("err = %v, want ErrRefreshTokenReplay", err)
|
||||
}
|
||||
|
||||
// And no new INSERT happened — we bailed before minting.
|
||||
if ormDB.inserted != 0 {
|
||||
t.Errorf("expected 0 INSERTs after CAS loss, got %d", ormDB.inserted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_InvalidToken_returnsAuthError(t *testing.T) {
|
||||
// No row exists for this token — SELECT returns 0 rows.
|
||||
s, _, _ := newRotationTestService(t)
|
||||
|
||||
_, _, _, _, err := s.RefreshToken(context.Background(), "never-existed", "anchat-test")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid token, got nil")
|
||||
}
|
||||
if errors.Is(err, ErrRefreshTokenReplay) {
|
||||
t.Error("invalid token must NOT be classified as replay (distinguishable error)")
|
||||
}
|
||||
if errors.Is(err, ErrRotationNotConfigured) {
|
||||
t.Error("invalid token must NOT surface as ErrRotationNotConfigured")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_NoRqliteClient_refusesToRotate(t *testing.T) {
|
||||
// A service constructed without SetRqliteClient cannot guarantee
|
||||
// atomicity. It MUST refuse rather than rotate non-atomically.
|
||||
s := createDualKeyService(t) // mockDatabaseClient via shared helper; no rqlite injected
|
||||
|
||||
_, _, _, _, err := s.RefreshToken(context.Background(), "anything", "anchat-test")
|
||||
if !errors.Is(err, ErrRotationNotConfigured) {
|
||||
t.Fatalf("err = %v, want ErrRotationNotConfigured", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshToken_ConcurrentRotation simulates two concurrent refresh
|
||||
// attempts on the same stolen-or-shared token. Exactly ONE must succeed;
|
||||
// the other must return ErrRefreshTokenReplay. This is the RFC 9700
|
||||
// theft-detection tripwire in action.
|
||||
func TestRefreshToken_ConcurrentRotation_exactlyOneWins(t *testing.T) {
|
||||
s, ormDB, rq := newRotationTestService(t)
|
||||
|
||||
const sharedToken = "shared-refresh"
|
||||
ormDB.subjectByToken[sha256Hex(sharedToken)] = "0xSHARED"
|
||||
|
||||
// 50 racers all calling RefreshToken with the same token.
|
||||
const racers = 50
|
||||
wins := make(chan error, racers)
|
||||
var startWg, endWg sync.WaitGroup
|
||||
startWg.Add(1)
|
||||
endWg.Add(racers)
|
||||
for i := 0; i < racers; i++ {
|
||||
go func() {
|
||||
defer endWg.Done()
|
||||
startWg.Wait() // launch all goroutines simultaneously
|
||||
_, _, _, _, err := s.RefreshToken(context.Background(), sharedToken, "anchat-test")
|
||||
wins <- err
|
||||
}()
|
||||
}
|
||||
startWg.Done() // GO
|
||||
endWg.Wait()
|
||||
close(wins)
|
||||
|
||||
var successes, replays, others int
|
||||
for err := range wins {
|
||||
switch {
|
||||
case err == nil:
|
||||
successes++
|
||||
case errors.Is(err, ErrRefreshTokenReplay):
|
||||
replays++
|
||||
default:
|
||||
others++
|
||||
t.Logf("unexpected error class: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Exactly one winner; everyone else gets the replay tripwire.
|
||||
if successes != 1 {
|
||||
t.Errorf("successes = %d, want exactly 1 (RFC 9700 theft tripwire)", successes)
|
||||
}
|
||||
if replays != racers-1 {
|
||||
t.Errorf("replays = %d, want %d", replays, racers-1)
|
||||
}
|
||||
if others != 0 {
|
||||
t.Errorf("unexpected error responses = %d", others)
|
||||
}
|
||||
|
||||
// Exactly one INSERT for the new refresh token; everyone else bailed
|
||||
// before minting.
|
||||
if ormDB.inserted != 1 {
|
||||
t.Errorf("expected 1 new-token INSERT, got %d", ormDB.inserted)
|
||||
}
|
||||
// UPDATE was attempted by every racer.
|
||||
if rq.updateCalls < racers {
|
||||
t.Errorf("expected at least %d UPDATE calls (one per racer), got %d", racers, rq.updateCalls)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshToken_RotatedTokenReplayFails — after a successful rotation,
|
||||
// reusing the OLD refresh token must fail with the standard auth error
|
||||
// (the SELECT in step 1 sees revoked_at IS NOT NULL → 0 rows).
|
||||
func TestRefreshToken_RotatedTokenReplayFails(t *testing.T) {
|
||||
s, ormDB, _ := newRotationTestService(t)
|
||||
|
||||
const oldRefresh = "rotate-me"
|
||||
ormDB.subjectByToken[sha256Hex(oldRefresh)] = "0xWALLET"
|
||||
|
||||
// First call rotates successfully.
|
||||
_, newRefresh, _, _, err := s.RefreshToken(context.Background(), oldRefresh, "anchat-test")
|
||||
if err != nil {
|
||||
t.Fatalf("first RefreshToken: %v", err)
|
||||
}
|
||||
if newRefresh == "" {
|
||||
t.Fatal("first rotation produced empty new token")
|
||||
}
|
||||
|
||||
// Simulate: the old token's row is now marked revoked, so subsequent
|
||||
// SELECTs return 0 rows. The mock approximates this by removing the
|
||||
// entry from subjectByToken (real DB would have revoked_at IS NOT NULL).
|
||||
delete(ormDB.subjectByToken, sha256Hex(oldRefresh))
|
||||
|
||||
// Try to reuse the rotated-away token.
|
||||
_, _, _, _, err = s.RefreshToken(context.Background(), oldRefresh, "anchat-test")
|
||||
if err == nil {
|
||||
t.Fatal("expected error reusing rotated token, got nil")
|
||||
}
|
||||
}
|
||||
@ -19,13 +19,16 @@ import (
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/client"
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||
ethcrypto "github.com/ethereum/go-ethereum/crypto"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Service handles authentication business logic
|
||||
type Service struct {
|
||||
logger *logging.ColoredLogger
|
||||
orm client.NetworkClient
|
||||
db rqlite.Client // lower-level client; used where rows-affected is needed (e.g. refresh-token CAS rotation, feature #68)
|
||||
signingKey *rsa.PrivateKey
|
||||
keyID string
|
||||
edSigningKey ed25519.PrivateKey
|
||||
@ -68,6 +71,24 @@ func (s *Service) SetAPIKeyHMACSecret(secret string) {
|
||||
s.apiKeyHMACSecret = secret
|
||||
}
|
||||
|
||||
// SetRqliteClient injects the lower-level rqlite client. Required for code
|
||||
// paths that need rows-affected feedback for compare-and-swap operations
|
||||
// (e.g. atomic refresh-token rotation, feature #68). The higher-level
|
||||
// `client.NetworkClient` interface in `s.orm` does not expose RowsAffected
|
||||
// on writes.
|
||||
//
|
||||
// Safe to call zero or one times; idempotent. Without it, methods that
|
||||
// depend on CAS semantics refuse to run rather than degrade to a non-atomic
|
||||
// path (currently: RefreshToken returns ErrRotationNotConfigured).
|
||||
func (s *Service) SetRqliteClient(db rqlite.Client) {
|
||||
s.db = db
|
||||
}
|
||||
|
||||
// ErrRotationNotConfigured is returned by RefreshToken when the service
|
||||
// wasn't given an rqlite client — refusing to rotate without atomicity
|
||||
// guarantees is safer than rotating non-atomically.
|
||||
var ErrRotationNotConfigured = fmt.Errorf("auth service not configured for atomic refresh-token rotation (missing rqlite client)")
|
||||
|
||||
// HashAPIKey returns the HMAC-SHA256 hash of an API key if the HMAC secret is set,
|
||||
// or returns the raw key for backward compatibility during rolling upgrade.
|
||||
func (s *Service) HashAPIKey(key string) string {
|
||||
@ -234,24 +255,76 @@ func (s *Service) IssueTokens(ctx context.Context, wallet, namespace string) (st
|
||||
return token, refresh, expUnix, nil
|
||||
}
|
||||
|
||||
// RefreshToken validates a refresh token and issues a new access token
|
||||
func (s *Service) RefreshToken(ctx context.Context, refreshToken, namespace string) (string, string, int64, error) {
|
||||
// ErrRefreshTokenReplay is returned when a refresh token's CAS lock is lost —
|
||||
// the row was already revoked between our read and our write, meaning either
|
||||
// another concurrent request rotated it OR an attacker is replaying a stolen
|
||||
// token after the legitimate client refreshed. Callers should treat this as
|
||||
// a potential security event and surface 401 to the client; the service
|
||||
// itself emits a WARN log so operators can audit.
|
||||
//
|
||||
// This is the tripwire promised by RFC 9700 §4.12 (refresh-token rotation).
|
||||
var ErrRefreshTokenReplay = fmt.Errorf("refresh token already rotated or invalid")
|
||||
|
||||
// RefreshToken validates the supplied refresh token, atomically rotates it
|
||||
// (revokes the old, mints a new), and returns a fresh access token alongside
|
||||
// the rotated refresh token.
|
||||
//
|
||||
// Rotation is the RFC 9700 BCP §4.12 / feature #68 behaviour:
|
||||
//
|
||||
// 1. SELECT the subject for the supplied token (must be unrevoked + unexpired)
|
||||
// 2. UPDATE revoked_at = now() WHERE token = ? AND revoked_at IS NULL
|
||||
// -- this is the atomic CAS. If RowsAffected == 0, the race was lost
|
||||
// -- (concurrent rotation or token-replay attack); we fail closed and
|
||||
// -- emit a security log line so operators can investigate.
|
||||
// 3. Generate a fresh refresh-token + fresh access JWT
|
||||
// 4. INSERT the new refresh-token row
|
||||
// 5. Return both
|
||||
//
|
||||
// Failure modes:
|
||||
// - Token invalid/expired at step 1 → standard "invalid or expired" error,
|
||||
// no security event.
|
||||
// - CAS lost at step 2 → ErrRefreshTokenReplay, WARN logged with subject +
|
||||
// namespace. The client sees 401.
|
||||
// - Crash between step 2 and step 4 → user is left with revoked old + no
|
||||
// new, forcing re-login. Acceptable: degrades to re-auth, never enables
|
||||
// double-use of a single refresh token.
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// accessToken — newly minted short-lived JWT (15 min)
|
||||
// newRefreshToken — newly minted long-lived refresh token (30 days)
|
||||
// subject — wallet/subject claim of the refreshed session
|
||||
// expUnix — access token expiry (unix seconds)
|
||||
// err — non-nil on any failure; ErrRefreshTokenReplay for CAS loss
|
||||
func (s *Service) RefreshToken(ctx context.Context, refreshToken, namespace string) (accessToken, newRefreshToken, subject string, expUnix int64, err error) {
|
||||
// Atomic rotation requires the lower-level rqlite client (RowsAffected
|
||||
// feedback isn't exposed by the higher-level client.NetworkClient).
|
||||
// Refuse to rotate non-atomically — see ErrRotationNotConfigured.
|
||||
if s.db == nil {
|
||||
return "", "", "", 0, ErrRotationNotConfigured
|
||||
}
|
||||
|
||||
internalCtx := client.WithInternalAuth(ctx)
|
||||
db := s.orm.Database()
|
||||
ormDB := s.orm.Database()
|
||||
|
||||
nsID, err := s.ResolveNamespaceID(ctx, namespace)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
hashedRefresh := sha256Hex(refreshToken)
|
||||
q := "SELECT subject FROM refresh_tokens WHERE namespace_id = ? AND token = ? AND revoked_at IS NULL AND (expires_at IS NULL OR expires_at > datetime('now')) LIMIT 1"
|
||||
res, err := db.Query(internalCtx, q, nsID, hashedRefresh)
|
||||
if err != nil || res == nil || res.Count == 0 {
|
||||
return "", "", 0, fmt.Errorf("invalid or expired refresh token")
|
||||
}
|
||||
|
||||
subject := ""
|
||||
// Step 1: read the subject. Tells us who the token belongs to AND
|
||||
// validates that it's currently usable (not revoked, not expired).
|
||||
selectQ := `SELECT subject FROM refresh_tokens
|
||||
WHERE namespace_id = ? AND token = ?
|
||||
AND revoked_at IS NULL
|
||||
AND (expires_at IS NULL OR expires_at > datetime('now'))
|
||||
LIMIT 1`
|
||||
res, err := ormDB.Query(internalCtx, selectQ, nsID, hashedRefresh)
|
||||
if err != nil || res == nil || res.Count == 0 {
|
||||
return "", "", "", 0, fmt.Errorf("invalid or expired refresh token")
|
||||
}
|
||||
if len(res.Rows) > 0 && len(res.Rows[0]) > 0 {
|
||||
if val, ok := res.Rows[0][0].(string); ok {
|
||||
subject = val
|
||||
@ -261,12 +334,55 @@ func (s *Service) RefreshToken(ctx context.Context, refreshToken, namespace stri
|
||||
}
|
||||
}
|
||||
|
||||
token, expUnix, err := s.GenerateJWT(namespace, subject, 15*time.Minute)
|
||||
// Step 2: atomic CAS — revoke the old row. RowsAffected is the lock.
|
||||
// Two concurrent calls with the same refresh token: exactly one wins
|
||||
// the UPDATE (RowsAffected == 1); the other sees RowsAffected == 0
|
||||
// and bails with the replay tripwire.
|
||||
updRes, err := s.db.Exec(internalCtx,
|
||||
`UPDATE refresh_tokens SET revoked_at = datetime('now')
|
||||
WHERE namespace_id = ? AND token = ? AND revoked_at IS NULL`,
|
||||
nsID, hashedRefresh)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
return "", "", "", 0, fmt.Errorf("revoke old refresh token: %w", err)
|
||||
}
|
||||
affected, _ := updRes.RowsAffected()
|
||||
if affected == 0 {
|
||||
// Race lost OR replay attempt: token was unrevoked at step 1 but
|
||||
// already revoked by step 2, meaning a concurrent call rotated it
|
||||
// in between. Could be benign (same client retrying due to a
|
||||
// transient network error) or malicious (stolen token + race).
|
||||
// Either way: fail closed, log it, let the operator investigate.
|
||||
s.logger.ComponentWarn(logging.ComponentGeneral,
|
||||
"refresh token rotation: concurrent use detected (possible replay)",
|
||||
zap.String("namespace", namespace),
|
||||
zap.String("subject", subject))
|
||||
return "", "", "", 0, ErrRefreshTokenReplay
|
||||
}
|
||||
|
||||
return token, subject, expUnix, nil
|
||||
// Step 3: mint the new access JWT.
|
||||
accessToken, expUnix, err = s.GenerateJWT(namespace, subject, 15*time.Minute)
|
||||
if err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("generate access token: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: mint and persist a new refresh token (32-byte random,
|
||||
// base64-url-encoded; stored hashed). 30-day TTL. Note: if this
|
||||
// INSERT fails after the UPDATE succeeded (step 2), the user is left
|
||||
// with revoked old + no new and must re-authenticate. Acceptable —
|
||||
// degrades to re-auth, never to double-use of a single refresh token.
|
||||
rbuf := make([]byte, 32)
|
||||
if _, err := rand.Read(rbuf); err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("generate refresh token: %w", err)
|
||||
}
|
||||
newRefreshToken = base64.RawURLEncoding.EncodeToString(rbuf)
|
||||
hashedNew := sha256Hex(newRefreshToken)
|
||||
if _, err := ormDB.Query(internalCtx,
|
||||
"INSERT INTO refresh_tokens(namespace_id, subject, token, audience, expires_at) VALUES (?, ?, ?, ?, datetime('now', '+30 days'))",
|
||||
nsID, subject, hashedNew, "gateway"); err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("store rotated refresh token: %w", err)
|
||||
}
|
||||
|
||||
return accessToken, newRefreshToken, subject, expUnix, nil
|
||||
}
|
||||
|
||||
// RevokeToken revokes a specific refresh token or all tokens for a subject
|
||||
|
||||
@ -597,6 +597,14 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
|
||||
return fmt.Errorf("failed to initialize auth service: %w", err)
|
||||
}
|
||||
|
||||
// Inject the lower-level rqlite client for code paths that need
|
||||
// rows-affected feedback. Feature #68 (atomic refresh-token rotation)
|
||||
// uses this for the compare-and-swap UPDATE. Without it, RefreshToken
|
||||
// returns ErrRotationNotConfigured rather than rotating non-atomically.
|
||||
if deps.ORMClient != nil {
|
||||
authService.SetRqliteClient(deps.ORMClient)
|
||||
}
|
||||
|
||||
// Load or create EdDSA key for new JWT tokens. Bug #215 fix: when
|
||||
// cfg.ClusterSecret is set, the key is derived deterministically from
|
||||
// it via HKDF, so every gateway in the cluster shares the same Ed25519
|
||||
|
||||
@ -36,12 +36,14 @@ import (
|
||||
operatorhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/operator"
|
||||
vaulthandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/vault"
|
||||
wireguardhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/wireguard"
|
||||
ratelimithandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/ratelimit"
|
||||
sqlitehandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/sqlite"
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage"
|
||||
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
nodehealth "github.com/DeBrosOfficial/network/pkg/node/health"
|
||||
"github.com/DeBrosOfficial/network/pkg/olric"
|
||||
"github.com/DeBrosOfficial/network/pkg/ratelimit"
|
||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless/persistent"
|
||||
@ -131,7 +133,14 @@ type Gateway struct {
|
||||
|
||||
// Rate limiters
|
||||
rateLimiter *RateLimiter
|
||||
namespaceRateLimiter *NamespaceRateLimiter
|
||||
namespaceRateLimiter *NamespaceRateLimiter // legacy; superseded by rateLimitManager when set
|
||||
// rateLimitManager (feature #69) handles per-namespace rate limits with
|
||||
// tenant self-service config via /v1/namespace/rate-limit. When set,
|
||||
// namespaceRateLimitMiddleware uses it instead of the legacy
|
||||
// hardcoded-defaults limiter above. nil = falls back to namespaceRateLimiter.
|
||||
rateLimitManager *ratelimit.Manager
|
||||
rateLimitConfigStore ratelimit.ConfigStore
|
||||
rateLimitHandlers *ratelimithandlers.Handlers
|
||||
|
||||
// WebRTC signaling and TURN credentials
|
||||
webrtcHandlers *webrtchandlers.WebRTCHandlers
|
||||
@ -430,12 +439,40 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
|
||||
// Initialize request log batcher (flush every 5 seconds)
|
||||
gw.logBatcher = newRequestLogBatcher(gw, 5*time.Second, 100)
|
||||
|
||||
// Initialize rate limiters
|
||||
// Per-IP: 10000 req/min, burst 5000
|
||||
// Initialize rate limiters.
|
||||
//
|
||||
// Per-IP: token bucket against the client IP. Generous so legitimate
|
||||
// users behind shared NATs aren't squeezed.
|
||||
gw.rateLimiter = NewRateLimiter(10000, 5000)
|
||||
gw.rateLimiter.StartCleanup(5*time.Minute, 10*time.Minute)
|
||||
// Per-namespace: 60000 req/hr (1000/min), burst 500
|
||||
gw.namespaceRateLimiter = NewNamespaceRateLimiter(1000, 500)
|
||||
|
||||
// Per-namespace: feature #69 — backed by an LRU manager with
|
||||
// per-namespace overrides via /v1/namespace/rate-limit (config in
|
||||
// `namespace_rate_limit_config`, populated by migration 027).
|
||||
//
|
||||
// Defaults: 10000/min, burst 5000 — matches per-IP so a single user
|
||||
// can't saturate the namespace ceiling. Tenants tighten via PUT;
|
||||
// operators can raise/lower the Max* ceiling in YAML config.
|
||||
//
|
||||
// When `deps.ORMClient` is nil (test/standalone modes), we still
|
||||
// install a manager backed by a no-store ConfigStore so middleware
|
||||
// flow stays uniform; it returns the defaults for every namespace.
|
||||
rlDefaults := ratelimit.Defaults{
|
||||
RequestsPerMinute: 10000,
|
||||
Burst: 5000,
|
||||
MaxRequestsPerMinute: 100000, // operator ceiling: tenants can't request more
|
||||
MaxBurst: 50000,
|
||||
}
|
||||
if deps.ORMClient != nil {
|
||||
gw.rateLimitConfigStore = ratelimit.NewRqliteConfigStore(deps.ORMClient, logger.Logger)
|
||||
}
|
||||
gw.rateLimitManager = ratelimit.NewManager(gw.rateLimitConfigStore, rlDefaults, logger.Logger)
|
||||
gw.rateLimitHandlers = ratelimithandlers.NewHandlers(gw.rateLimitConfigStore, gw.rateLimitManager, logger)
|
||||
|
||||
// Legacy fallback kept for now in case the manager is ever nil. The
|
||||
// middleware prefers rateLimitManager and only uses this if the
|
||||
// manager is unset.
|
||||
gw.namespaceRateLimiter = NewNamespaceRateLimiter(rlDefaults.RequestsPerMinute, rlDefaults.Burst)
|
||||
|
||||
// Initialize WireGuard peer exchange handler
|
||||
if deps.ORMClient != nil {
|
||||
|
||||
@ -97,9 +97,18 @@ func (h *Handlers) RefreshHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, subject, expUnix, err := h.authService.RefreshToken(r.Context(), req.RefreshToken, req.Namespace)
|
||||
// Feature #68 / RFC 9700 §4.12: refresh-token rotation.
|
||||
// Every successful refresh mints a NEW refresh token and revokes the
|
||||
// supplied one atomically. The response carries the rotated value;
|
||||
// the SDK persists it (bug #239 fix) and uses it on the next refresh.
|
||||
token, newRefreshToken, subject, expUnix, err := h.authService.RefreshToken(r.Context(), req.RefreshToken, req.Namespace)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, err.Error())
|
||||
// The service emits a WARN log on replay (ErrRefreshTokenReplay)
|
||||
// so the operator can investigate. We surface a generic 401 here
|
||||
// regardless — leaking "your token was already used" to the
|
||||
// caller would help an attacker confirm a stolen token has been
|
||||
// rotated.
|
||||
writeError(w, http.StatusUnauthorized, "invalid or expired refresh token")
|
||||
return
|
||||
}
|
||||
|
||||
@ -107,7 +116,7 @@ func (h *Handlers) RefreshHandler(w http.ResponseWriter, r *http.Request) {
|
||||
"access_token": token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": int(expUnix - time.Now().Unix()),
|
||||
"refresh_token": req.RefreshToken,
|
||||
"refresh_token": newRefreshToken,
|
||||
"subject": subject,
|
||||
"namespace": req.Namespace,
|
||||
})
|
||||
|
||||
288
core/pkg/gateway/handlers/ratelimit/handler.go
Normal file
288
core/pkg/gateway/handlers/ratelimit/handler.go
Normal file
@ -0,0 +1,288 @@
|
||||
// Package ratelimit provides the HTTP handlers for tenant-self-service
|
||||
// rate-limit configuration. Feature #69 — mirrors the push-config
|
||||
// handler shape so the operational pattern stays uniform across
|
||||
// per-namespace config endpoints.
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
"github.com/DeBrosOfficial/network/pkg/ratelimit"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Handlers mounts the three endpoints. Construct via NewHandlers and pass
|
||||
// the same *ratelimit.Manager and ConfigStore the gateway is using —
|
||||
// after PUT/DELETE the manager's cache is invalidated so the next
|
||||
// request rebuilds with fresh values.
|
||||
type Handlers struct {
|
||||
store ratelimit.ConfigStore
|
||||
manager *ratelimit.Manager
|
||||
logger *logging.ColoredLogger
|
||||
}
|
||||
|
||||
func NewHandlers(store ratelimit.ConfigStore, manager *ratelimit.Manager, logger *logging.ColoredLogger) *Handlers {
|
||||
return &Handlers{store: store, manager: manager, logger: logger}
|
||||
}
|
||||
|
||||
// PutRequest is the JSON body accepted by PUT /v1/namespace/rate-limit.
// The handler requires both values to be positive, so a body that omits
// either field (decoding to zero) is rejected; there is no partial-update
// or merge behaviour.
type PutRequest struct {
	// RequestsPerMinute is the requested per-minute rate for the namespace.
	RequestsPerMinute int `json:"requests_per_minute"`
	// Burst is the requested burst size for the namespace.
	Burst int `json:"burst"`
}
|
||||
|
||||
// GetResponse is the JSON document returned by all three rate-limit
// endpoints. RequestsPerMinute and Burst are the EFFECTIVE values: the
// stored override when one exists, otherwise the gateway defaults; Source
// says which of the two was used. The operator-imposed maxima are included
// so a tenant can see the ceiling it may request.
//
// Scope describes what the bucket covers. The only value emitted today is
// "per-gateway": the configured rate applies to a single gateway's bucket,
// so with N gateways the effective cluster-wide cap is N times the
// configured value. It is reported on every response so tenants are not
// surprised by apparent overage that is really traffic spread across
// several gateways.
type GetResponse struct {
	Namespace         string `json:"namespace"`
	RequestsPerMinute int    `json:"requests_per_minute"`
	Burst             int    `json:"burst"`
	// Source is "override" or "default".
	Source string `json:"source"`
	// Scope is currently always "per-gateway".
	Scope string `json:"scope"`
	// The Max* fields are omitted from the JSON when zero.
	MaxRequestsPerMinute int `json:"max_requests_per_minute,omitempty"`
	MaxBurst             int `json:"max_burst,omitempty"`
	// UpdatedAt / UpdatedBy are copied from the stored override and are
	// omitted when zero/empty (e.g. when the defaults are in effect).
	UpdatedAt int64  `json:"updated_at,omitempty"`
	UpdatedBy string `json:"updated_by,omitempty"`
}
|
||||
|
||||
// scopePerGateway is the single Scope value emitted at present. A future
// shared-bucket implementation would introduce another value, so clients
// should treat the field as opaque metadata and rely only on documented
// values.
const (
	scopePerGateway = "per-gateway"
)
|
||||
|
||||
// MaxBodyBytes is the upper bound, in bytes, on a PUT request body. The
// body carries two integers, so 1 KiB is far more than needed while still
// rejecting unbounded payloads.
const MaxBodyBytes = 1 << 10
|
||||
|
||||
// GetConfigHandler — GET /v1/namespace/rate-limit. Always 200 when the
|
||||
// store is available; reports effective values + their source.
|
||||
func (h *Handlers) GetConfigHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if h.store == nil || h.manager == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "rate-limit config not available on this gateway")
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
ns := resolveNamespace(r)
|
||||
if ns == "" {
|
||||
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := h.store.Get(boundCtx(r), ns)
|
||||
if err != nil {
|
||||
h.logger.ComponentWarn(logging.ComponentGeneral, "rate-limit config GET failed",
|
||||
zap.String("namespace", ns), zap.Error(err))
|
||||
writeError(w, http.StatusInternalServerError, "failed to load config")
|
||||
return
|
||||
}
|
||||
|
||||
defs := h.manager.Defaults()
|
||||
resp := GetResponse{
|
||||
Namespace: ns,
|
||||
Scope: scopePerGateway,
|
||||
MaxRequestsPerMinute: defs.MaxRequestsPerMinute,
|
||||
MaxBurst: defs.MaxBurst,
|
||||
}
|
||||
if cfg != nil {
|
||||
resp.RequestsPerMinute = cfg.RequestsPerMinute
|
||||
resp.Burst = cfg.Burst
|
||||
resp.Source = "override"
|
||||
resp.UpdatedAt = cfg.UpdatedAt
|
||||
resp.UpdatedBy = cfg.UpdatedBy
|
||||
} else {
|
||||
resp.RequestsPerMinute = defs.RequestsPerMinute
|
||||
resp.Burst = defs.Burst
|
||||
resp.Source = "default"
|
||||
}
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// PutConfigHandler — PUT /v1/namespace/rate-limit. Sets the namespace's
|
||||
// override. Rejected if the requested values exceed the operator's
|
||||
// MaxRequestsPerMinute / MaxBurst ceiling (a tenant CANNOT raise their
|
||||
// own quota above the platform cap).
|
||||
func (h *Handlers) PutConfigHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if h.store == nil || h.manager == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "rate-limit config not available on this gateway")
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPut && r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed (use PUT)")
|
||||
return
|
||||
}
|
||||
ns := resolveNamespace(r)
|
||||
if ns == "" {
|
||||
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||
return
|
||||
}
|
||||
caller := resolveCallerUserID(r)
|
||||
if caller == "" {
|
||||
writeError(w, http.StatusUnauthorized, "user authentication required (JWT)")
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, MaxBodyBytes)
|
||||
var body PutRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid body: expected JSON {requests_per_minute, burst}")
|
||||
return
|
||||
}
|
||||
if body.RequestsPerMinute <= 0 || body.Burst <= 0 {
|
||||
writeError(w, http.StatusBadRequest, "requests_per_minute and burst must be positive integers")
|
||||
return
|
||||
}
|
||||
|
||||
// Operator ceiling check. The operator's Max* values are the absolute
|
||||
// maximums a tenant can request; setting them to 0 in the YAML means
|
||||
// "no cap, trust tenant input" (use only in trusted-tenant
|
||||
// deployments). Anything else: hard reject if exceeded.
|
||||
defs := h.manager.Defaults()
|
||||
if defs.MaxRequestsPerMinute > 0 && body.RequestsPerMinute > defs.MaxRequestsPerMinute {
|
||||
writeError(w, http.StatusBadRequest,
|
||||
"requests_per_minute exceeds operator-configured maximum")
|
||||
return
|
||||
}
|
||||
if defs.MaxBurst > 0 && body.Burst > defs.MaxBurst {
|
||||
writeError(w, http.StatusBadRequest, "burst exceeds operator-configured maximum")
|
||||
return
|
||||
}
|
||||
|
||||
cfg := ratelimit.Config{
|
||||
Namespace: ns,
|
||||
RequestsPerMinute: body.RequestsPerMinute,
|
||||
Burst: body.Burst,
|
||||
UpdatedAt: time.Now().Unix(),
|
||||
UpdatedBy: caller,
|
||||
}
|
||||
if err := h.store.Upsert(boundCtx(r), cfg); err != nil {
|
||||
if errors.Is(err, ratelimit.ErrAboveOperatorCap) {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
h.logger.ComponentWarn(logging.ComponentGeneral, "rate-limit config PUT failed",
|
||||
zap.String("namespace", ns), zap.Error(err))
|
||||
writeError(w, http.StatusInternalServerError, "failed to save config")
|
||||
return
|
||||
}
|
||||
// Drop the cached limiter so the next request rebuilds with new values.
|
||||
h.manager.Invalidate(ns)
|
||||
|
||||
h.logger.ComponentInfo(logging.ComponentGeneral, "rate-limit config updated",
|
||||
zap.String("namespace", ns),
|
||||
zap.Int("rpm", cfg.RequestsPerMinute),
|
||||
zap.Int("burst", cfg.Burst),
|
||||
zap.String("by", caller))
|
||||
|
||||
// Return the new effective config so the client sees what's in place.
|
||||
writeJSON(w, http.StatusOK, GetResponse{
|
||||
Namespace: ns,
|
||||
RequestsPerMinute: cfg.RequestsPerMinute,
|
||||
Burst: cfg.Burst,
|
||||
Source: "override",
|
||||
Scope: scopePerGateway,
|
||||
UpdatedAt: cfg.UpdatedAt,
|
||||
UpdatedBy: cfg.UpdatedBy,
|
||||
MaxRequestsPerMinute: defs.MaxRequestsPerMinute,
|
||||
MaxBurst: defs.MaxBurst,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteConfigHandler — DELETE /v1/namespace/rate-limit. Removes the
|
||||
// override; subsequent requests fall back to the gateway defaults.
|
||||
// Idempotent: 200 even if no override existed.
|
||||
func (h *Handlers) DeleteConfigHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if h.store == nil || h.manager == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "rate-limit config not available on this gateway")
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodDelete {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed (use DELETE)")
|
||||
return
|
||||
}
|
||||
ns := resolveNamespace(r)
|
||||
if ns == "" {
|
||||
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||
return
|
||||
}
|
||||
caller := resolveCallerUserID(r)
|
||||
if caller == "" {
|
||||
writeError(w, http.StatusUnauthorized, "user authentication required (JWT)")
|
||||
return
|
||||
}
|
||||
if err := h.store.Delete(boundCtx(r), ns); err != nil {
|
||||
h.logger.ComponentWarn(logging.ComponentGeneral, "rate-limit config DELETE failed",
|
||||
zap.String("namespace", ns), zap.Error(err))
|
||||
writeError(w, http.StatusInternalServerError, "failed to delete config")
|
||||
return
|
||||
}
|
||||
h.manager.Invalidate(ns)
|
||||
h.logger.ComponentInfo(logging.ComponentGeneral, "rate-limit config cleared",
|
||||
zap.String("namespace", ns), zap.String("by", caller))
|
||||
|
||||
defs := h.manager.Defaults()
|
||||
writeJSON(w, http.StatusOK, GetResponse{
|
||||
Namespace: ns,
|
||||
RequestsPerMinute: defs.RequestsPerMinute,
|
||||
Burst: defs.Burst,
|
||||
Source: "default",
|
||||
Scope: scopePerGateway,
|
||||
MaxRequestsPerMinute: defs.MaxRequestsPerMinute,
|
||||
MaxBurst: defs.MaxBurst,
|
||||
})
|
||||
}
|
||||
|
||||
// ---------- helpers (kept private to the package; mirror push handlers) ----------
|
||||
|
||||
func resolveNamespace(r *http.Request) string {
|
||||
if v := r.Context().Value(ctxkeys.NamespaceOverride); v != nil {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func resolveCallerUserID(r *http.Request) string {
|
||||
if v := r.Context().Value(ctxkeys.JWT); v != nil {
|
||||
if claims, ok := v.(*auth.JWTClaims); ok && claims != nil {
|
||||
return claims.Sub
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// writeError sends a JSON error response of the form {"error": message}
// with the given HTTP status code and an application/json content type.
func writeError(w http.ResponseWriter, code int, message string) {
	payload := map[string]string{"error": message}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)

	// Best effort: the status line is already written, so an encode
	// failure cannot be reported to the client.
	enc := json.NewEncoder(w)
	_ = enc.Encode(payload)
}
|
||||
|
||||
// writeJSON sends v encoded as JSON with the given HTTP status code and an
// application/json content type.
func writeJSON(w http.ResponseWriter, code int, v interface{}) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(code)

	// Best effort: the status line is already written, so an encode
	// failure cannot be reported to the client.
	enc := json.NewEncoder(w)
	_ = enc.Encode(v)
}
|
||||
|
||||
// boundCtx returns the context to use for store calls made on behalf of
// r. It currently returns the request's own context unchanged.
func boundCtx(r *http.Request) context.Context {
	ctx := r.Context()
	return ctx
}
|
||||
355
core/pkg/gateway/handlers/ratelimit/handler_test.go
Normal file
355
core/pkg/gateway/handlers/ratelimit/handler_test.go
Normal file
@ -0,0 +1,355 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
"github.com/DeBrosOfficial/network/pkg/ratelimit"
|
||||
)
|
||||
|
||||
// ---------------- mock store + setup ----------------
|
||||
|
||||
type memStore struct {
|
||||
mu sync.Mutex
|
||||
rows map[string]ratelimit.Config
|
||||
}
|
||||
|
||||
func newMemStore() *memStore { return &memStore{rows: map[string]ratelimit.Config{}} }
|
||||
|
||||
func (m *memStore) Get(_ context.Context, namespace string) (*ratelimit.Config, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if c, ok := m.rows[namespace]; ok {
|
||||
c2 := c
|
||||
return &c2, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *memStore) Upsert(_ context.Context, cfg ratelimit.Config) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.rows[cfg.Namespace] = cfg
|
||||
return nil
|
||||
}
|
||||
func (m *memStore) Delete(_ context.Context, namespace string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.rows, namespace)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestHandlers(t *testing.T, defs ratelimit.Defaults) (*Handlers, *memStore, *ratelimit.Manager) {
|
||||
t.Helper()
|
||||
store := newMemStore()
|
||||
mgr := ratelimit.NewManager(store, defs, nil)
|
||||
logger, _ := logging.NewColoredLogger(logging.ComponentGeneral, false)
|
||||
return NewHandlers(store, mgr, logger), store, mgr
|
||||
}
|
||||
|
||||
// authedRequest builds a request with the auth-middleware-set context
|
||||
// keys: namespace + JWT subject. Without these, the handlers reject as
|
||||
// they should.
|
||||
func authedRequest(method, path, body, namespace, sub string) *http.Request {
|
||||
var r *http.Request
|
||||
if body != "" {
|
||||
r = httptest.NewRequest(method, path, bytes.NewBufferString(body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
r = httptest.NewRequest(method, path, nil)
|
||||
}
|
||||
ctx := r.Context()
|
||||
if namespace != "" {
|
||||
ctx = context.WithValue(ctx, ctxkeys.NamespaceOverride, namespace)
|
||||
}
|
||||
if sub != "" {
|
||||
ctx = context.WithValue(ctx, ctxkeys.JWT, &auth.JWTClaims{Sub: sub, Namespace: namespace})
|
||||
}
|
||||
return r.WithContext(ctx)
|
||||
}
|
||||
|
||||
// ---------------- GET ----------------
|
||||
|
||||
func TestGetConfigHandler_defaultsWhenNoOverride(t *testing.T) {
|
||||
h, _, _ := newTestHandlers(t, ratelimit.Defaults{
|
||||
RequestsPerMinute: 100,
|
||||
Burst: 10,
|
||||
MaxRequestsPerMinute: 1000,
|
||||
MaxBurst: 100,
|
||||
})
|
||||
|
||||
r := authedRequest(http.MethodGet, "/v1/namespace/rate-limit", "", "anchat-test", "0xWALLET")
|
||||
w := httptest.NewRecorder()
|
||||
h.GetConfigHandler(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200", w.Code)
|
||||
}
|
||||
var resp GetResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if resp.Source != "default" {
|
||||
t.Errorf("Source = %q, want %q", resp.Source, "default")
|
||||
}
|
||||
if resp.RequestsPerMinute != 100 || resp.Burst != 10 {
|
||||
t.Errorf("effective = (%d, %d), want defaults (100, 10)", resp.RequestsPerMinute, resp.Burst)
|
||||
}
|
||||
if resp.MaxRequestsPerMinute != 1000 || resp.MaxBurst != 100 {
|
||||
t.Errorf("max ceiling = (%d, %d), want (1000, 100)", resp.MaxRequestsPerMinute, resp.MaxBurst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConfigHandler_overrideWhenSet(t *testing.T) {
|
||||
h, store, _ := newTestHandlers(t, ratelimit.Defaults{RequestsPerMinute: 100, Burst: 10})
|
||||
store.rows["anchat-test"] = ratelimit.Config{
|
||||
Namespace: "anchat-test",
|
||||
RequestsPerMinute: 5000,
|
||||
Burst: 500,
|
||||
UpdatedAt: 42,
|
||||
UpdatedBy: "0xOPERATOR",
|
||||
}
|
||||
|
||||
r := authedRequest(http.MethodGet, "/v1/namespace/rate-limit", "", "anchat-test", "0xWALLET")
|
||||
w := httptest.NewRecorder()
|
||||
h.GetConfigHandler(w, r)
|
||||
|
||||
var resp GetResponse
|
||||
_ = json.NewDecoder(w.Body).Decode(&resp)
|
||||
if resp.Source != "override" {
|
||||
t.Errorf("Source = %q, want %q", resp.Source, "override")
|
||||
}
|
||||
if resp.RequestsPerMinute != 5000 || resp.Burst != 500 {
|
||||
t.Errorf("effective = (%d, %d), want override (5000, 500)", resp.RequestsPerMinute, resp.Burst)
|
||||
}
|
||||
if resp.UpdatedBy != "0xOPERATOR" {
|
||||
t.Errorf("UpdatedBy = %q, want %q", resp.UpdatedBy, "0xOPERATOR")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConfigHandler_noNamespaceContext_returns403(t *testing.T) {
|
||||
h, _, _ := newTestHandlers(t, ratelimit.Defaults{RequestsPerMinute: 100, Burst: 10})
|
||||
r := authedRequest(http.MethodGet, "/v1/namespace/rate-limit", "", "", "0xWALLET")
|
||||
w := httptest.NewRecorder()
|
||||
h.GetConfigHandler(w, r)
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status = %d, want 403 (no namespace = no scope)", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------- PUT ----------------
|
||||
|
||||
func TestPutConfigHandler_acceptsValidUpdate(t *testing.T) {
|
||||
h, store, mgr := newTestHandlers(t, ratelimit.Defaults{
|
||||
RequestsPerMinute: 100,
|
||||
Burst: 10,
|
||||
MaxRequestsPerMinute: 10000,
|
||||
MaxBurst: 1000,
|
||||
})
|
||||
|
||||
body := `{"requests_per_minute": 5000, "burst": 500}`
|
||||
r := authedRequest(http.MethodPut, "/v1/namespace/rate-limit", body, "anchat-test", "0xWALLET")
|
||||
w := httptest.NewRecorder()
|
||||
h.PutConfigHandler(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Persisted.
|
||||
stored, _ := store.Get(context.Background(), "anchat-test")
|
||||
if stored == nil || stored.RequestsPerMinute != 5000 || stored.Burst != 500 {
|
||||
t.Errorf("not persisted correctly: %+v", stored)
|
||||
}
|
||||
|
||||
// Cache invalidated → manager.Allow now uses the new limit.
|
||||
// 50 sequential calls should all pass under burst=500.
|
||||
for i := 0; i < 50; i++ {
|
||||
if !mgr.Allow(context.Background(), "anchat-test") {
|
||||
t.Fatalf("Allow %d should pass under new burst=500", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutConfigHandler_acceptsValueEqualToCap(t *testing.T) {
|
||||
// Boundary: body == cap is accepted (strict `>` in the handler, not `>=`).
|
||||
h, store, _ := newTestHandlers(t, ratelimit.Defaults{
|
||||
MaxRequestsPerMinute: 5000,
|
||||
MaxBurst: 500,
|
||||
})
|
||||
body := `{"requests_per_minute": 5000, "burst": 500}`
|
||||
r := authedRequest(http.MethodPut, "/v1/namespace/rate-limit", body, "anchat-test", "0xWALLET")
|
||||
w := httptest.NewRecorder()
|
||||
h.PutConfigHandler(w, r)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200 (value == cap should be accepted)", w.Code)
|
||||
}
|
||||
got, _ := store.Get(context.Background(), "anchat-test")
|
||||
if got == nil || got.RequestsPerMinute != 5000 || got.Burst != 500 {
|
||||
t.Errorf("not persisted: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutConfigHandler_capZeroMeansNoCap(t *testing.T) {
|
||||
// Operator sets MaxRequestsPerMinute=0 and MaxBurst=0 → "no cap".
|
||||
// Tenants can set arbitrarily large values (trusted-tenant deployments).
|
||||
h, store, _ := newTestHandlers(t, ratelimit.Defaults{
|
||||
// No Max* set — interpreted as "disabled / no ceiling".
|
||||
RequestsPerMinute: 100,
|
||||
Burst: 10,
|
||||
})
|
||||
body := `{"requests_per_minute": 999999, "burst": 99999}`
|
||||
r := authedRequest(http.MethodPut, "/v1/namespace/rate-limit", body, "anchat-test", "0xWALLET")
|
||||
w := httptest.NewRecorder()
|
||||
h.PutConfigHandler(w, r)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200 (zero cap should disable check)", w.Code)
|
||||
}
|
||||
got, _ := store.Get(context.Background(), "anchat-test")
|
||||
if got == nil || got.RequestsPerMinute != 999999 || got.Burst != 99999 {
|
||||
t.Errorf("not persisted: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutConfigHandler_rejectsAboveOperatorCap(t *testing.T) {
|
||||
h, store, _ := newTestHandlers(t, ratelimit.Defaults{
|
||||
RequestsPerMinute: 100,
|
||||
Burst: 10,
|
||||
MaxRequestsPerMinute: 1000,
|
||||
MaxBurst: 100,
|
||||
})
|
||||
|
||||
// Try to set requests_per_minute=99999 — well above the operator cap.
|
||||
body := `{"requests_per_minute": 99999, "burst": 50}`
|
||||
r := authedRequest(http.MethodPut, "/v1/namespace/rate-limit", body, "anchat-test", "0xWALLET")
|
||||
w := httptest.NewRecorder()
|
||||
h.PutConfigHandler(w, r)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want 400 (above operator cap)", w.Code)
|
||||
}
|
||||
if got, _ := store.Get(context.Background(), "anchat-test"); got != nil {
|
||||
t.Error("rejected request was nevertheless persisted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutConfigHandler_rejectsAboveBurstCap(t *testing.T) {
|
||||
h, _, _ := newTestHandlers(t, ratelimit.Defaults{
|
||||
MaxRequestsPerMinute: 1000,
|
||||
MaxBurst: 100,
|
||||
})
|
||||
|
||||
body := `{"requests_per_minute": 500, "burst": 9999}`
|
||||
r := authedRequest(http.MethodPut, "/v1/namespace/rate-limit", body, "anchat-test", "0xWALLET")
|
||||
w := httptest.NewRecorder()
|
||||
h.PutConfigHandler(w, r)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want 400 (burst above operator cap)", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutConfigHandler_rejectsZeroOrNegative(t *testing.T) {
|
||||
h, _, _ := newTestHandlers(t, ratelimit.Defaults{})
|
||||
|
||||
cases := []string{
|
||||
`{"requests_per_minute": 0, "burst": 10}`,
|
||||
`{"requests_per_minute": -1, "burst": 10}`,
|
||||
`{"requests_per_minute": 10, "burst": 0}`,
|
||||
`{"requests_per_minute": 10, "burst": -1}`,
|
||||
`{}`,
|
||||
}
|
||||
for _, body := range cases {
|
||||
r := authedRequest(http.MethodPut, "/v1/namespace/rate-limit", body, "anchat-test", "0xWALLET")
|
||||
w := httptest.NewRecorder()
|
||||
h.PutConfigHandler(w, r)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("body=%s: status = %d, want 400", body, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutConfigHandler_requiresJWT(t *testing.T) {
|
||||
h, _, _ := newTestHandlers(t, ratelimit.Defaults{MaxRequestsPerMinute: 0})
|
||||
body := `{"requests_per_minute": 100, "burst": 10}`
|
||||
// No JWT subject — only API-key auth, which can't be attributed.
|
||||
r := authedRequest(http.MethodPut, "/v1/namespace/rate-limit", body, "anchat-test", "")
|
||||
w := httptest.NewRecorder()
|
||||
h.PutConfigHandler(w, r)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401 (no JWT subject = no audit trail)", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------- DELETE ----------------
|
||||
|
||||
func TestDeleteConfigHandler_removesOverride(t *testing.T) {
|
||||
h, store, mgr := newTestHandlers(t, ratelimit.Defaults{RequestsPerMinute: 60, Burst: 1})
|
||||
store.rows["anchat-test"] = ratelimit.Config{
|
||||
Namespace: "anchat-test", RequestsPerMinute: 6000, Burst: 100,
|
||||
}
|
||||
|
||||
// Warm the cache with the override.
|
||||
if !mgr.Allow(context.Background(), "anchat-test") {
|
||||
t.Fatal("initial Allow should pass under override (burst=100)")
|
||||
}
|
||||
|
||||
r := authedRequest(http.MethodDelete, "/v1/namespace/rate-limit", "", "anchat-test", "0xWALLET")
|
||||
w := httptest.NewRecorder()
|
||||
h.DeleteConfigHandler(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200", w.Code)
|
||||
}
|
||||
if got, _ := store.Get(context.Background(), "anchat-test"); got != nil {
|
||||
t.Error("override row not deleted")
|
||||
}
|
||||
|
||||
// Cache invalidated → next Allow rebuilds under the default (burst=1).
|
||||
if !mgr.Allow(context.Background(), "anchat-test") {
|
||||
t.Fatal("first post-delete Allow should pass under default burst=1")
|
||||
}
|
||||
if mgr.Allow(context.Background(), "anchat-test") {
|
||||
t.Error("second post-delete Allow should be throttled (burst=1 exhausted, no refill in this test)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteConfigHandler_idempotent(t *testing.T) {
|
||||
h, _, _ := newTestHandlers(t, ratelimit.Defaults{})
|
||||
r := authedRequest(http.MethodDelete, "/v1/namespace/rate-limit", "", "no-override-ns", "0xWALLET")
|
||||
w := httptest.NewRecorder()
|
||||
h.DeleteConfigHandler(w, r)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200 (DELETE must be idempotent)", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------- method gating ----------------
|
||||
|
||||
func TestHandlers_methodGating(t *testing.T) {
|
||||
h, _, _ := newTestHandlers(t, ratelimit.Defaults{})
|
||||
cases := []struct {
|
||||
handler func(http.ResponseWriter, *http.Request)
|
||||
method string
|
||||
want int
|
||||
}{
|
||||
{h.GetConfigHandler, http.MethodPost, http.StatusMethodNotAllowed},
|
||||
{h.PutConfigHandler, http.MethodGet, http.StatusMethodNotAllowed},
|
||||
{h.DeleteConfigHandler, http.MethodGet, http.StatusMethodNotAllowed},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
r := authedRequest(tc.method, "/v1/namespace/rate-limit", "{}", "ns", "sub")
|
||||
w := httptest.NewRecorder()
|
||||
tc.handler(w, r)
|
||||
if w.Code != tc.want {
|
||||
t.Errorf("%s: status = %d, want %d", tc.method, w.Code, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -128,6 +128,29 @@ func stripInboundInternalAuthHeaders(h http.Header) {
|
||||
h.Del(HeaderInternalAuthJWTCustom)
|
||||
}
|
||||
|
||||
// maxQueryJWTLength is the largest JWT, in bytes, accepted through the
// `?jwt=` query parameter. Tokens minted by this gateway (EdDSA and RS256)
// are well under 2 KB, so 4 KB leaves ample headroom while letting
// oversized, multi-megabyte inputs be rejected cheaply before they reach
// the verifier.
const maxQueryJWTLength = 4 << 10
|
||||
|
||||
// stripJWTQueryParam removes the `jwt` key from the request's query
// string, mutating r in place. It is called after a WebSocket-upgrade JWT
// supplied via the query string has been verified, so the token does not
// travel further than the auth layer:
//   - the namespace-gateway proxy hop (`r.URL.RawQuery` is forwarded)
//   - downstream logs that record `r.URL.RequestURI()` or `r.RequestURI`
//   - any inner `r.URL.Query()` lookups in business logic
//
// When a `jwt` key is present the remaining parameters are re-encoded with
// url.Values.Encode, which sorts them by key. Requests without a `jwt` key
// are left completely untouched, so the call is idempotent.
func stripJWTQueryParam(r *http.Request) {
	q := r.URL.Query()
	if !q.Has("jwt") {
		return
	}
	q.Del("jwt")
	r.URL.RawQuery = q.Encode()

	// r.RequestURI holds the request-target exactly as the client sent it
	// and is not derived from r.URL, so it would still contain the token.
	// Rebuild it from the scrubbed URL. An empty RequestURI (client-style
	// requests) is left empty.
	if r.RequestURI != "" {
		r.RequestURI = r.URL.RequestURI()
	}
}
|
||||
|
||||
// claimsFromInternalAuthHeaders rebuilds a *auth.JWTClaims from the trusted
|
||||
// internal-auth headers. Returns nil if no JWT subject was forwarded (the
|
||||
// caller used an API key, or the request didn't carry validated JWT data).
|
||||
@ -187,6 +210,24 @@ func (g *Gateway) validateAuthForNamespaceProxy(r *http.Request) (namespace stri
|
||||
}
|
||||
}
|
||||
|
||||
// 1b) WS upgrade fallback: JWT via `?jwt=` query. Same rationale as in
|
||||
// authMiddleware — browser / React Native WS clients can't set custom
|
||||
// headers reliably. Bug #240. Strip-after-verify is applied here too
|
||||
// so the JWT doesn't propagate to the namespace gateway over the proxy
|
||||
// hop (where it would otherwise live in the proxied request's RawQuery
|
||||
// + the inner gateway's logs).
|
||||
if isWebSocketUpgrade(r) {
|
||||
tok := strings.TrimSpace(r.URL.Query().Get("jwt"))
|
||||
if tok != "" && len(tok) <= maxQueryJWTLength && strings.Count(tok, ".") == 2 {
|
||||
if c, err := g.authService.ParseAndVerifyJWT(tok); err == nil {
|
||||
if ns := strings.TrimSpace(c.Namespace); ns != "" {
|
||||
stripJWTQueryParam(r)
|
||||
return ns, c, ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) Try API key
|
||||
key := extractAPIKey(r)
|
||||
if key == "" {
|
||||
@ -389,9 +430,12 @@ func (g *Gateway) loggingMiddleware(next http.Handler) http.Handler {
|
||||
|
||||
// authMiddleware enforces auth when enabled via config.
|
||||
// Accepts:
|
||||
// - Authorization: Bearer <JWT> (RS256 issued by this gateway)
|
||||
// - Authorization: Bearer <JWT> (RS256 / EdDSA issued by this gateway)
|
||||
// - Authorization: Bearer <API key> or ApiKey <API key>
|
||||
// - X-API-Key: <API key>
|
||||
// - ?api_key=<key> or ?token=<key> query string (WebSocket upgrade only)
|
||||
// - ?jwt=<token> query string (WebSocket upgrade only — bug #240; needed
|
||||
// because browser/RN WS clients can't reliably set custom headers)
|
||||
// - X-Internal-Auth-Validated: true (from internal IPs only - pre-authenticated by main gateway)
|
||||
func (g *Gateway) authMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@ -453,6 +497,48 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// 1b) WebSocket-only fallback: JWT in the `?jwt=` query parameter.
|
||||
//
|
||||
// Browser and React Native WebSocket clients can't reliably set custom
|
||||
// headers on the upgrade request — the WebSocket constructor either
|
||||
// ignores the headers argument (browsers) or silently strips
|
||||
// Authorization (RN iOS). Without a fallback, every authenticated WS
|
||||
// endpoint is unreachable from those platforms. Bug #240.
|
||||
//
|
||||
// We gate this ONLY on WS upgrade requests to keep JWTs out of normal
|
||||
// HTTP URLs (where they end up in access logs, referrer headers, and
|
||||
// browser history). For WS, the upgrade URL is only emitted on
|
||||
// connection establishment — much smaller exposure surface — and TLS
|
||||
// (wss://) keeps it encrypted on the wire in transit.
|
||||
//
|
||||
// After a successful verify, we STRIP the `jwt` query param from the
|
||||
// request before passing downstream (`stripJWTQueryParam`). This
|
||||
// shrinks the replay window: the token doesn't propagate through the
|
||||
// proxy hop to the namespace gateway, doesn't reach the backend
|
||||
// handler's logs, and doesn't show up in any downstream `r.URL`
|
||||
// inspection. Belt-and-suspenders given the trust we've already
|
||||
// established by verifying the signature.
|
||||
if isWebSocketUpgrade(r) {
|
||||
tok := strings.TrimSpace(r.URL.Query().Get("jwt"))
|
||||
// Cheap length sanity-check before invoking the verifier. Real
|
||||
// EdDSA / RS256 JWTs issued by this gateway are well under 4 KB.
|
||||
// Anything larger is either malformed or a DoS attempt.
|
||||
if tok != "" && len(tok) <= maxQueryJWTLength && strings.Count(tok, ".") == 2 {
|
||||
if claims, err := g.authService.ParseAndVerifyJWT(tok); err == nil {
|
||||
stripJWTQueryParam(r)
|
||||
ctx := context.WithValue(r.Context(), ctxKeyJWT, claims)
|
||||
if ns := strings.TrimSpace(claims.Namespace); ns != "" {
|
||||
ctx = context.WithValue(ctx, CtxKeyNamespaceOverride, ns)
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
// Invalid JWT in query — fall through to API key check
|
||||
// rather than 401-ing here, in case the caller also supplied
|
||||
// a valid api_key as belt-and-suspenders.
|
||||
}
|
||||
}
|
||||
|
||||
// 2) Fallback to API key (validate against DB)
|
||||
key := extractAPIKey(r)
|
||||
if key == "" {
|
||||
|
||||
387
core/pkg/gateway/middleware_ws_jwt_test.go
Normal file
387
core/pkg/gateway/middleware_ws_jwt_test.go
Normal file
@ -0,0 +1,387 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
)
|
||||
|
||||
// newAuthServiceForTest builds a real auth.Service backed by a temporary
|
||||
// EdDSA key, suitable for end-to-end auth-middleware tests. Mirrors the
|
||||
// shape of pkg/gateway/auth/service_test.go::createDualKeyService but lives
|
||||
// in package gateway so we don't need to export internals.
|
||||
func newAuthServiceForTest(t *testing.T) *auth.Service {
|
||||
t.Helper()
|
||||
logger, _ := logging.NewColoredLogger(logging.ComponentGeneral, false)
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("rsa keygen: %v", err)
|
||||
}
|
||||
rsaPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(rsaKey),
|
||||
})
|
||||
s, err := auth.NewService(logger, nil, string(rsaPEM), "default")
|
||||
if err != nil {
|
||||
t.Fatalf("auth.NewService: %v", err)
|
||||
}
|
||||
_, edPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("ed25519 keygen: %v", err)
|
||||
}
|
||||
s.SetEdDSAKey(edPriv)
|
||||
return s
|
||||
}
|
||||
|
||||
// Bug #240: WebSocket clients on browsers and React Native can't reliably
|
||||
// set custom headers on the upgrade request. The auth middleware now
|
||||
// accepts a JWT via `?jwt=` query parameter — but only for WebSocket
|
||||
// upgrade requests. These tests lock that contract in.
|
||||
|
||||
func TestAuthMiddleware_WSJWTQuery_validToken(t *testing.T) {
|
||||
svc := newAuthServiceForTest(t)
|
||||
token, _, err := svc.GenerateJWT("anchat-test", "0xWALLET_SUBJECT", 15*time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateJWT: %v", err)
|
||||
}
|
||||
|
||||
g := &Gateway{authService: svc}
|
||||
|
||||
var gotClaims *auth.JWTClaims
|
||||
var gotNamespace string
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
if v := r.Context().Value(ctxKeyJWT); v != nil {
|
||||
gotClaims, _ = v.(*auth.JWTClaims)
|
||||
}
|
||||
if v := r.Context().Value(CtxKeyNamespaceOverride); v != nil {
|
||||
gotNamespace, _ = v.(string)
|
||||
}
|
||||
})
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/v1/functions/rpc-router/ws?jwt="+token, nil)
|
||||
r.Header.Set("Connection", "upgrade")
|
||||
r.Header.Set("Upgrade", "websocket")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
g.authMiddleware(next).ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
if gotClaims == nil {
|
||||
t.Fatal("ctxKeyJWT not set on the next handler's context")
|
||||
}
|
||||
if gotClaims.Sub != "0xWALLET_SUBJECT" {
|
||||
t.Errorf("claims.Sub = %q, want %q", gotClaims.Sub, "0xWALLET_SUBJECT")
|
||||
}
|
||||
if gotNamespace != "anchat-test" {
|
||||
t.Errorf("namespace override = %q, want %q", gotNamespace, "anchat-test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_WSJWTQuery_invalidTokenFallsThrough(t *testing.T) {
|
||||
// Invalid JWT in ?jwt= must NOT set ctxKeyJWT and must NOT short-circuit
|
||||
// to success — middleware should fall through to API-key path.
|
||||
svc := newAuthServiceForTest(t)
|
||||
g := &Gateway{authService: svc}
|
||||
|
||||
called := false
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
})
|
||||
|
||||
// Three-segment string that ParseAndVerifyJWT will reject (bad signature).
|
||||
bogus := "eyJhbGciOiJFZERTQSJ9.eyJzdWIiOiJ4In0.bogussignature"
|
||||
r := httptest.NewRequest(http.MethodGet, "/v1/functions/private-fn/ws?jwt="+bogus, nil)
|
||||
r.Header.Set("Connection", "upgrade")
|
||||
r.Header.Set("Upgrade", "websocket")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
g.authMiddleware(next).ServeHTTP(w, r)
|
||||
|
||||
// No valid creds anywhere → middleware should 401, not call next.
|
||||
if called {
|
||||
t.Error("next handler was called despite invalid JWT — middleware short-circuited incorrectly")
|
||||
}
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_WSJWTQuery_ignoredOnNonWSRequest(t *testing.T) {
|
||||
// Putting a JWT in ?jwt= on a regular HTTP request must NOT authenticate.
|
||||
// We deliberately scope query-string JWT to WS upgrades to avoid the
|
||||
// privacy issues of JWTs leaking via referrer headers, browser history,
|
||||
// and access logs.
|
||||
svc := newAuthServiceForTest(t)
|
||||
token, _, err := svc.GenerateJWT("ns", "sub", 15*time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateJWT: %v", err)
|
||||
}
|
||||
|
||||
g := &Gateway{authService: svc}
|
||||
|
||||
called := false
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
})
|
||||
|
||||
// Regular GET (no Upgrade header).
|
||||
r := httptest.NewRequest(http.MethodGet, "/v1/some-private-endpoint?jwt="+token, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
g.authMiddleware(next).ServeHTTP(w, r)
|
||||
|
||||
if called {
|
||||
t.Error("non-WS request with ?jwt= was authenticated — must be WS-only")
|
||||
}
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_WSJWTQuery_headerWinsOverQuery(t *testing.T) {
|
||||
// Both Authorization: Bearer <header-jwt> AND ?jwt=<query-jwt> present.
|
||||
// Header path runs FIRST and wins. Verifies the query fallback is a
|
||||
// fallback, not an override.
|
||||
svc := newAuthServiceForTest(t)
|
||||
headerJWT, _, _ := svc.GenerateJWT("ns-header", "sub-header", 15*time.Minute)
|
||||
queryJWT, _, _ := svc.GenerateJWT("ns-query", "sub-query", 15*time.Minute)
|
||||
|
||||
g := &Gateway{authService: svc}
|
||||
|
||||
var got *auth.JWTClaims
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
if v := r.Context().Value(ctxKeyJWT); v != nil {
|
||||
got, _ = v.(*auth.JWTClaims)
|
||||
}
|
||||
})
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/v1/functions/fn/ws?jwt="+queryJWT, nil)
|
||||
r.Header.Set("Authorization", "Bearer "+headerJWT)
|
||||
r.Header.Set("Connection", "upgrade")
|
||||
r.Header.Set("Upgrade", "websocket")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
g.authMiddleware(next).ServeHTTP(w, r)
|
||||
|
||||
if got == nil {
|
||||
t.Fatal("ctxKeyJWT not set")
|
||||
}
|
||||
if got.Sub != "sub-header" {
|
||||
t.Errorf("Sub = %q, want %q (header should win over query)", got.Sub, "sub-header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_WSJWTQuery_emptyJWTParamFallsThrough(t *testing.T) {
|
||||
// `?jwt=` with empty value should not affect anything — fall through to
|
||||
// API key / default path.
|
||||
svc := newAuthServiceForTest(t)
|
||||
g := &Gateway{authService: svc}
|
||||
|
||||
called := false
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
})
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/v1/functions/fn/ws?jwt=", nil)
|
||||
r.Header.Set("Connection", "upgrade")
|
||||
r.Header.Set("Upgrade", "websocket")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
g.authMiddleware(next).ServeHTTP(w, r)
|
||||
|
||||
if called {
|
||||
t.Error("empty ?jwt= unexpectedly authenticated the request")
|
||||
}
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_WSJWTQuery_malformedJWTFallsThrough(t *testing.T) {
|
||||
// `?jwt=not-a-jwt` — single segment, no dots. Must NOT call
|
||||
// ParseAndVerifyJWT (the dot-count gate skips it) AND must NOT
|
||||
// authenticate.
|
||||
svc := newAuthServiceForTest(t)
|
||||
g := &Gateway{authService: svc}
|
||||
|
||||
called := false
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
})
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/v1/functions/fn/ws?jwt=not-a-jwt", nil)
|
||||
r.Header.Set("Connection", "upgrade")
|
||||
r.Header.Set("Upgrade", "websocket")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
g.authMiddleware(next).ServeHTTP(w, r)
|
||||
|
||||
if called {
|
||||
t.Error("non-JWT-shaped ?jwt= value was treated as authenticated")
|
||||
}
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// validateAuthForNamespaceProxy — same WS-JWT-query path, in the main
|
||||
// gateway's pre-validation flow.
|
||||
|
||||
func TestValidateAuthForNamespaceProxy_WSJWTQuery(t *testing.T) {
|
||||
svc := newAuthServiceForTest(t)
|
||||
token, _, err := svc.GenerateJWT("anchat-test", "0xWALLET", 15*time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateJWT: %v", err)
|
||||
}
|
||||
|
||||
g := &Gateway{authService: svc}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/v1/functions/rpc-router/ws?jwt="+token, nil)
|
||||
r.Header.Set("Connection", "upgrade")
|
||||
r.Header.Set("Upgrade", "websocket")
|
||||
|
||||
ns, claims, errMsg := g.validateAuthForNamespaceProxy(r)
|
||||
if errMsg != "" {
|
||||
t.Fatalf("unexpected errMsg: %q", errMsg)
|
||||
}
|
||||
if ns != "anchat-test" {
|
||||
t.Errorf("namespace = %q, want %q", ns, "anchat-test")
|
||||
}
|
||||
if claims == nil {
|
||||
t.Fatal("claims nil; expected JWT claims set")
|
||||
}
|
||||
if claims.Sub != "0xWALLET" {
|
||||
t.Errorf("Sub = %q, want %q", claims.Sub, "0xWALLET")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAuthForNamespaceProxy_WSJWTQuery_ignoredOnNonWS(t *testing.T) {
|
||||
svc := newAuthServiceForTest(t)
|
||||
token, _, err := svc.GenerateJWT("anchat-test", "0xWALLET", 15*time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateJWT: %v", err)
|
||||
}
|
||||
|
||||
g := &Gateway{authService: svc}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/v1/invoke/rpc-router?jwt="+token, nil)
|
||||
// No Upgrade headers — this is a regular HTTP request.
|
||||
|
||||
ns, claims, errMsg := g.validateAuthForNamespaceProxy(r)
|
||||
if ns != "" || claims != nil {
|
||||
t.Errorf("non-WS request was authenticated via ?jwt= — expected (\"\", nil), got (%q, %#v)", ns, claims)
|
||||
}
|
||||
if errMsg != "" {
|
||||
t.Errorf("unexpected errMsg on no-auth no-WS path: %q", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_WSJWTQuery_strippedAfterVerify guards the hardening
|
||||
// recommendation from the security audit: the `?jwt=` value MUST be
|
||||
// stripped from r.URL.RawQuery after a successful verify so the token
|
||||
// doesn't leak into proxy hops or downstream logs.
|
||||
func TestAuthMiddleware_WSJWTQuery_strippedAfterVerify(t *testing.T) {
|
||||
svc := newAuthServiceForTest(t)
|
||||
token, _, err := svc.GenerateJWT("anchat-test", "0xWALLET", 15*time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateJWT: %v", err)
|
||||
}
|
||||
|
||||
g := &Gateway{authService: svc}
|
||||
|
||||
var seenQueryHasJWT bool
|
||||
var seenRawQuery string
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
seenRawQuery = r.URL.RawQuery
|
||||
seenQueryHasJWT = r.URL.Query().Has("jwt")
|
||||
})
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/v1/functions/fn/ws?jwt="+token+"&other=keepme", nil)
|
||||
r.Header.Set("Connection", "upgrade")
|
||||
r.Header.Set("Upgrade", "websocket")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
g.authMiddleware(next).ServeHTTP(w, r)
|
||||
|
||||
if seenQueryHasJWT {
|
||||
t.Errorf("`jwt` param survived into downstream handler: RawQuery=%q", seenRawQuery)
|
||||
}
|
||||
// Other query params must survive — strip is surgical.
|
||||
if !strings.Contains(seenRawQuery, "other=keepme") {
|
||||
t.Errorf("unrelated query param dropped: RawQuery=%q", seenRawQuery)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_WSJWTQuery_oversizedTokenRejected ensures the cheap
|
||||
// length gate at the start of the branch refuses absurdly long tokens
|
||||
// before reaching the cryptographic verifier (cheap DoS defense).
|
||||
func TestAuthMiddleware_WSJWTQuery_oversizedTokenRejected(t *testing.T) {
|
||||
svc := newAuthServiceForTest(t)
|
||||
g := &Gateway{authService: svc}
|
||||
|
||||
called := false
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
})
|
||||
|
||||
// 8 KB of dot-padded garbage — exceeds maxQueryJWTLength (4 KB).
|
||||
huge := strings.Repeat("a", 4000) + "." + strings.Repeat("b", 4000) + ".sig"
|
||||
if len(huge) <= maxQueryJWTLength {
|
||||
t.Fatalf("test setup wrong: token len=%d should exceed cap %d", len(huge), maxQueryJWTLength)
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/v1/functions/fn/ws?jwt="+huge, nil)
|
||||
r.Header.Set("Connection", "upgrade")
|
||||
r.Header.Set("Upgrade", "websocket")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
g.authMiddleware(next).ServeHTTP(w, r)
|
||||
|
||||
if called {
|
||||
t.Error("oversized ?jwt= was accepted — length cap not enforced")
|
||||
}
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStripJWTQueryParam_idempotent — the helper is called from two paths
|
||||
// and should be safe to call on requests without a `jwt` param.
|
||||
func TestStripJWTQueryParam_idempotent(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
// Strip-path: jwt present → re-encoded (url.Values.Encode sorts).
|
||||
{"foo=bar&jwt=secret&baz=qux", "baz=qux&foo=bar"},
|
||||
{"jwt=secret", ""},
|
||||
{"jwt=secret&jwt=other", ""}, // both copies removed
|
||||
// No-op path: no jwt present → query left untouched (preserves
|
||||
// original ordering and any encoding quirks).
|
||||
{"foo=bar&baz=qux", "foo=bar&baz=qux"},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
r := httptest.NewRequest(http.MethodGet, "/?"+tc.in, nil)
|
||||
stripJWTQueryParam(r)
|
||||
if r.URL.RawQuery != tc.want {
|
||||
t.Errorf("strip(%q) = %q, want %q", tc.in, r.URL.RawQuery, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reference the context package so its import above is not rejected by the
// compiler as unused; no test in this file calls into it directly.
|
||||
var _ = context.Background
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/auth"
|
||||
"github.com/DeBrosOfficial/network/pkg/httputil"
|
||||
)
|
||||
|
||||
// wireGuardNet is the WireGuard mesh subnet, parsed once at init.
|
||||
@ -153,20 +154,42 @@ func (g *Gateway) rateLimitMiddleware(next http.Handler) http.Handler {
|
||||
|
||||
// namespaceRateLimitMiddleware enforces per-namespace rate limits.
|
||||
// It runs after auth middleware so the namespace is available in context.
|
||||
//
|
||||
// Feature #69: when g.rateLimitManager is set (production wiring), it's
|
||||
// preferred — supports per-namespace overrides via /v1/namespace/rate-limit
|
||||
// and emits the canonical RPC error envelope on 429 (so SDK clients see
|
||||
// a structured error code instead of plain text). The legacy
|
||||
// g.namespaceRateLimiter remains as a fallback for code paths that
|
||||
// haven't wired the manager yet.
|
||||
func (g *Gateway) namespaceRateLimitMiddleware(next http.Handler) http.Handler {
|
||||
if g.namespaceRateLimiter == nil {
|
||||
if g.rateLimitManager == nil && g.namespaceRateLimiter == nil {
|
||||
return next
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Extract namespace from context (set by auth middleware)
|
||||
if v := r.Context().Value(CtxKeyNamespaceOverride); v != nil {
|
||||
if ns, ok := v.(string); ok && ns != "" {
|
||||
if !g.namespaceRateLimiter.Allow(ns) {
|
||||
w.Header().Set("Retry-After", "60")
|
||||
http.Error(w, "namespace rate limit exceeded", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
}
|
||||
v := r.Context().Value(CtxKeyNamespaceOverride)
|
||||
ns, _ := v.(string)
|
||||
if ns == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
allowed := true
|
||||
if g.rateLimitManager != nil {
|
||||
allowed = g.rateLimitManager.Allow(r.Context(), ns)
|
||||
} else if g.namespaceRateLimiter != nil {
|
||||
allowed = g.namespaceRateLimiter.Allow(ns)
|
||||
}
|
||||
if !allowed {
|
||||
// Canonical RPC error envelope (bug #212 contract) so SDKs
|
||||
// parse the rate-limit hit instead of seeing plain text. The
|
||||
// 60s retry hint maps to both the HTTP Retry-After header
|
||||
// and the envelope's retry_after field.
|
||||
httputil.WriteRPCError(w, http.StatusTooManyRequests,
|
||||
httputil.ErrCodeRateLimited,
|
||||
"namespace rate limit exceeded — back off and retry in a few seconds",
|
||||
httputil.WithRetryable(),
|
||||
httputil.WithRetryAfter(60))
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
256
core/pkg/gateway/rate_limiter_middleware_test.go
Normal file
256
core/pkg/gateway/rate_limiter_middleware_test.go
Normal file
@ -0,0 +1,256 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
"github.com/DeBrosOfficial/network/pkg/ratelimit"
|
||||
)
|
||||
|
||||
// Feature #69: the namespaceRateLimitMiddleware must emit the canonical
|
||||
// RPC error envelope on 429 (not plain text) so SDK clients see a
|
||||
// structured error code instead of a bare HTTP body. Also: when the
|
||||
// Manager is wired, it must take precedence over the legacy
|
||||
// NamespaceRateLimiter.
|
||||
|
||||
// helper: build a Gateway with only the rate-limit fields we care about.
|
||||
func newRateLimitTestGateway(t *testing.T, mgr *ratelimit.Manager, legacy *NamespaceRateLimiter) *Gateway {
|
||||
t.Helper()
|
||||
logger, _ := logging.NewColoredLogger(logging.ComponentGeneral, false)
|
||||
return &Gateway{
|
||||
rateLimitManager: mgr,
|
||||
namespaceRateLimiter: legacy,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// requestWithNamespace returns a request with the namespace context key
|
||||
// set, as the auth middleware would have done upstream.
|
||||
func requestWithNamespace(ns string) *http.Request {
|
||||
r := httptest.NewRequest(http.MethodGet, "/anything", nil)
|
||||
if ns != "" {
|
||||
r = r.WithContext(context.WithValue(r.Context(), CtxKeyNamespaceOverride, ns))
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func TestNamespaceRateLimitMiddleware_managerPath_emitsCanonicalEnvelopeOn429(t *testing.T) {
|
||||
// burst=1 → first request passes, second 429s.
|
||||
mgr := ratelimit.NewManager(nil, ratelimit.Defaults{RequestsPerMinute: 60, Burst: 1}, nil)
|
||||
g := newRateLimitTestGateway(t, mgr, nil)
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })
|
||||
mw := g.namespaceRateLimitMiddleware(next)
|
||||
|
||||
// 1st request passes.
|
||||
r1 := requestWithNamespace("anchat-test")
|
||||
w1 := httptest.NewRecorder()
|
||||
mw.ServeHTTP(w1, r1)
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Fatalf("first request status = %d, want 200", w1.Code)
|
||||
}
|
||||
|
||||
// 2nd request rate-limited.
|
||||
r2 := requestWithNamespace("anchat-test")
|
||||
w2 := httptest.NewRecorder()
|
||||
mw.ServeHTTP(w2, r2)
|
||||
if w2.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("second request status = %d, want 429", w2.Code)
|
||||
}
|
||||
|
||||
// The response MUST be the canonical RPC error envelope, not plain text.
|
||||
if got := w2.Header().Get("Content-Type"); got != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want application/json (envelope, not plain text)", got)
|
||||
}
|
||||
if got := w2.Header().Get("Retry-After"); got == "" {
|
||||
t.Error("Retry-After header missing on 429")
|
||||
}
|
||||
|
||||
var envelope struct {
|
||||
OK bool `json:"ok"`
|
||||
Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Retryable bool `json:"retryable"`
|
||||
RetryAfter float64 `json:"retry_after"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if err := json.NewDecoder(w2.Body).Decode(&envelope); err != nil {
|
||||
t.Fatalf("decode envelope: %v", err)
|
||||
}
|
||||
if envelope.OK {
|
||||
t.Error("envelope.ok = true, want false")
|
||||
}
|
||||
if envelope.Error.Code != "RATE_LIMITED" {
|
||||
t.Errorf("error.code = %q, want %q (per httputil.ErrCodeRateLimited)", envelope.Error.Code, "RATE_LIMITED")
|
||||
}
|
||||
if !envelope.Error.Retryable {
|
||||
t.Error("error.retryable = false, want true for rate-limit responses")
|
||||
}
|
||||
if envelope.Error.RetryAfter <= 0 {
|
||||
t.Error("error.retry_after = 0, want positive hint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamespaceRateLimitMiddleware_emptyNamespacePassesThrough(t *testing.T) {
|
||||
// No namespace in context (e.g., the auth middleware didn't set one
|
||||
// because the path is public) — middleware must let the request through.
|
||||
mgr := ratelimit.NewManager(nil, ratelimit.Defaults{RequestsPerMinute: 1, Burst: 0}, nil)
|
||||
g := newRateLimitTestGateway(t, mgr, nil)
|
||||
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { nextCalled = true })
|
||||
mw := g.namespaceRateLimitMiddleware(next)
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil) // no namespace context
|
||||
w := httptest.NewRecorder()
|
||||
mw.ServeHTTP(w, r)
|
||||
|
||||
if !nextCalled {
|
||||
t.Error("next handler not called for empty-namespace request")
|
||||
}
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200 (no namespace = no limit)", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamespaceRateLimitMiddleware_managerPrefersOverLegacy(t *testing.T) {
|
||||
// Both manager AND legacy limiter present. Manager has burst=10 (lots
|
||||
// of headroom); legacy has burst=1 (would 429 immediately). If the
|
||||
// middleware uses manager, the first 5 requests should all pass. If
|
||||
// it accidentally falls back to legacy, the 2nd would 429.
|
||||
mgr := ratelimit.NewManager(nil, ratelimit.Defaults{RequestsPerMinute: 600, Burst: 10}, nil)
|
||||
legacy := NewNamespaceRateLimiter(60, 1)
|
||||
g := newRateLimitTestGateway(t, mgr, legacy)
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })
|
||||
mw := g.namespaceRateLimitMiddleware(next)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
r := requestWithNamespace("anchat-test")
|
||||
w := httptest.NewRecorder()
|
||||
mw.ServeHTTP(w, r)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("request %d: status = %d, want 200 (manager should win over legacy)", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamespaceRateLimitMiddleware_legacyFallbackWhenManagerNil(t *testing.T) {
|
||||
// No manager wired, only legacy. burst=1, second request must 429.
|
||||
legacy := NewNamespaceRateLimiter(60, 1)
|
||||
g := newRateLimitTestGateway(t, nil, legacy)
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })
|
||||
mw := g.namespaceRateLimitMiddleware(next)
|
||||
|
||||
r1 := requestWithNamespace("anchat-test")
|
||||
w1 := httptest.NewRecorder()
|
||||
mw.ServeHTTP(w1, r1)
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Fatalf("first request status = %d, want 200", w1.Code)
|
||||
}
|
||||
|
||||
r2 := requestWithNamespace("anchat-test")
|
||||
w2 := httptest.NewRecorder()
|
||||
mw.ServeHTTP(w2, r2)
|
||||
if w2.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("legacy-path second request status = %d, want 429", w2.Code)
|
||||
}
|
||||
// Legacy path uses the same canonical envelope now — verify.
|
||||
if got := w2.Header().Get("Content-Type"); got != "application/json" {
|
||||
t.Errorf("legacy path Content-Type = %q, want application/json", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamespaceRateLimitMiddleware_bothNilPassesThrough(t *testing.T) {
|
||||
// No rate limiter wired at all (test/dev modes). Middleware is a
|
||||
// no-op — every request passes.
|
||||
g := newRateLimitTestGateway(t, nil, nil)
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { nextCalled = true })
|
||||
mw := g.namespaceRateLimitMiddleware(next)
|
||||
|
||||
r := requestWithNamespace("anchat-test")
|
||||
w := httptest.NewRecorder()
|
||||
mw.ServeHTTP(w, r)
|
||||
if !nextCalled {
|
||||
t.Error("next handler not called when no rate limiters wired")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNamespaceRateLimitMiddleware_cacheTTLPropagation — config change on
|
||||
// a different gateway is picked up after the cache TTL elapses, without
|
||||
// an explicit Invalidate call. This is the bounded-staleness guarantee
|
||||
// that closes the cross-gateway cache-invalidation gap.
|
||||
func TestNamespaceRateLimitMiddleware_cacheTTLPropagation(t *testing.T) {
|
||||
// Use a mutable store to simulate a config change happening on
|
||||
// another gateway between calls.
|
||||
store := &mutableStore{}
|
||||
mgr := ratelimit.NewManager(store, ratelimit.Defaults{RequestsPerMinute: 60, Burst: 1}, nil)
|
||||
// 100ms TTL + 150ms sleep keeps the test deterministic on loaded CI
|
||||
// runners. Over-sleeping is safe (cache stays expired longer, test
|
||||
// still passes); we just need to be sure we DON'T under-sleep.
|
||||
mgr.SetCacheTTL(100 * time.Millisecond)
|
||||
g := newRateLimitTestGateway(t, mgr, nil)
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })
|
||||
mw := g.namespaceRateLimitMiddleware(next)
|
||||
|
||||
// Round 1: tight default (burst=1). One pass, one 429.
|
||||
r1 := requestWithNamespace("anchat-test")
|
||||
w1 := httptest.NewRecorder()
|
||||
mw.ServeHTTP(w1, r1)
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Fatalf("R1 first: status %d", w1.Code)
|
||||
}
|
||||
r2 := requestWithNamespace("anchat-test")
|
||||
w2 := httptest.NewRecorder()
|
||||
mw.ServeHTTP(w2, r2)
|
||||
if w2.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("R1 second: status %d, want 429", w2.Code)
|
||||
}
|
||||
|
||||
// Simulate: another gateway's PUT lands; this gateway's store can
|
||||
// now read the new value, but the cached limiter still has burst=1.
|
||||
store.cfg = &ratelimit.Config{
|
||||
Namespace: "anchat-test",
|
||||
RequestsPerMinute: 6000,
|
||||
Burst: 100,
|
||||
}
|
||||
|
||||
// Wait past the TTL so the cache entry expires and the next Allow
|
||||
// re-reads the store.
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Round 2: burst=100 now in effect. 50 rapid-fire passes.
|
||||
for i := 0; i < 50; i++ {
|
||||
r := requestWithNamespace("anchat-test")
|
||||
w := httptest.NewRecorder()
|
||||
mw.ServeHTTP(w, r)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("R2 request %d: status %d, want 200 (cache TTL should have propagated config)", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mutableStore is a tiny in-memory ConfigStore for the TTL test that lets
|
||||
// us swap the returned config between calls.
|
||||
type mutableStore struct {
|
||||
cfg *ratelimit.Config
|
||||
}
|
||||
|
||||
func (m *mutableStore) Get(_ context.Context, _ string) (*ratelimit.Config, error) {
|
||||
if m.cfg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
c := *m.cfg
|
||||
return &c, nil
|
||||
}
|
||||
func (m *mutableStore) Upsert(_ context.Context, cfg ratelimit.Config) error { m.cfg = &cfg; return nil }
|
||||
func (m *mutableStore) Delete(_ context.Context, _ string) error { m.cfg = nil; return nil }
|
||||
36
core/pkg/gateway/ratelimit_routes.go
Normal file
36
core/pkg/gateway/ratelimit_routes.go
Normal file
@ -0,0 +1,36 @@
|
||||
package gateway
|
||||
|
||||
// ratelimit_routes.go — method-dispatcher for the per-namespace rate-limit
|
||||
// configuration endpoint. Feature #69. Mirrors the push-config route shape.
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/httputil"
|
||||
)
|
||||
|
||||
// rateLimitConfigDispatcher routes GET / PUT / DELETE on
|
||||
// /v1/namespace/rate-limit to the respective handler. When the rate-limit
|
||||
// subsystem isn't wired (older deployments without an ORM client) it
|
||||
// returns a canonical 503 envelope explaining the situation — far better
|
||||
// UX than a bare 404.
|
||||
func (g *Gateway) rateLimitConfigDispatcher(w http.ResponseWriter, r *http.Request) {
|
||||
if g.rateLimitHandlers == nil {
|
||||
httputil.WriteRPCError(w, http.StatusServiceUnavailable,
|
||||
httputil.ErrCodeServiceUnavailable,
|
||||
"rate-limit configuration not available on this gateway")
|
||||
return
|
||||
}
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
g.rateLimitHandlers.GetConfigHandler(w, r)
|
||||
case http.MethodPut, http.MethodPost:
|
||||
g.rateLimitHandlers.PutConfigHandler(w, r)
|
||||
case http.MethodDelete:
|
||||
g.rateLimitHandlers.DeleteConfigHandler(w, r)
|
||||
default:
|
||||
httputil.WriteRPCError(w, http.StatusMethodNotAllowed,
|
||||
httputil.ErrCodeValidationFailed,
|
||||
"method not allowed: use GET to read, PUT to update, or DELETE to clear")
|
||||
}
|
||||
}
|
||||
@ -144,6 +144,14 @@ func (g *Gateway) Routes() http.Handler {
|
||||
// instead of filing an ops ticket. Method dispatched in the handler.
|
||||
mux.HandleFunc("/v1/push/config", g.pushConfigHandler)
|
||||
|
||||
// Per-namespace rate-limit configuration (feature #69).
|
||||
// GET / PUT / DELETE — tenants self-serve their gateway-level rate
|
||||
// limit override (requests_per_minute, burst) up to an operator-set
|
||||
// ceiling. Falls back to gateway YAML defaults when no override is set.
|
||||
if g.rateLimitHandlers != nil {
|
||||
mux.HandleFunc("/v1/namespace/rate-limit", g.rateLimitConfigDispatcher)
|
||||
}
|
||||
|
||||
// operator node management (wallet JWT auth via middleware)
|
||||
if g.operatorHandler != nil {
|
||||
mux.HandleFunc("/v1/operator/invite", g.operatorHandler.HandleInvite)
|
||||
|
||||
259
core/pkg/ratelimit/manager.go
Normal file
259
core/pkg/ratelimit/manager.go
Normal file
@ -0,0 +1,259 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Manager is the entry point for per-namespace rate limiting. Every
|
||||
// request goes through Allow(namespace), which:
|
||||
//
|
||||
// 1. Returns from the LRU cache if we've already built a limiter for
|
||||
// this namespace AND the entry hasn't aged past `cacheEntryTTL`.
|
||||
// 2. On cache miss (or expired entry), asks the ConfigStore for an
|
||||
// override. If present, uses (override.RequestsPerMinute,
|
||||
// override.Burst). If absent, uses Defaults.RequestsPerMinute /
|
||||
// Defaults.Burst.
|
||||
// 3. Builds a token-bucket limiter from those values, inserts into the
|
||||
// LRU, and consults it.
|
||||
//
|
||||
// Cache invalidation strategies (defense in depth):
|
||||
//
|
||||
// - Immediate (this-gateway): the config handler calls Invalidate(ns)
|
||||
// after PUT/DELETE so the next request on THIS gateway rebuilds.
|
||||
// - Bounded staleness (cluster-wide): every cached entry expires after
|
||||
// `cacheEntryTTL` (default 30s) and is rebuilt from the latest store
|
||||
// value. This bounds how long a config change can be invisible on
|
||||
// gateways that didn't handle the PUT — without requiring a
|
||||
// pub-sub broadcast layer.
|
||||
//
|
||||
// Per-gateway-bucket semantics (KNOWN BEHAVIOUR):
|
||||
//
|
||||
// Each gateway runs its own Manager and therefore its own per-namespace
|
||||
// token bucket. In an N-gateway deployment, the effective cluster-wide
|
||||
// rate cap for a namespace is N × the configured limit, since the
|
||||
// buckets don't share state. This is intentional for v1 (no shared
|
||||
// bucket store; per-gateway buckets are simple, fast, and survive
|
||||
// gateway-to-gateway partitions). Callers that need a cluster-wide cap
|
||||
// should either set the per-gateway limit to (cluster-cap / N) or
|
||||
// implement a shared-bucket backend in a follow-up.
|
||||
//
|
||||
// Safe for concurrent use.
|
||||
type Manager struct {
|
||||
store ConfigStore
|
||||
defaults Defaults
|
||||
logger *zap.Logger
|
||||
ttl time.Duration // configurable for tests; defaults to cacheEntryTTL
|
||||
|
||||
mu sync.Mutex
|
||||
cache map[string]*list.Element
|
||||
lru *list.List
|
||||
cacheCap int
|
||||
}
|
||||
|
||||
// cacheEntry tracks ONE namespace's compiled limiter plus the time it
|
||||
// was built. Once `age > Manager.ttl`, the next Allow rebuilds from the
|
||||
// store — covers the "config changed on gateway A, gateway B still
|
||||
// cached" multi-gateway gap with a bounded propagation window.
|
||||
type cacheEntry struct {
|
||||
namespace string
|
||||
limiter *bucketLimiter
|
||||
builtAt time.Time
|
||||
}
|
||||
|
||||
const (
	// defaultCacheCap is the maximum number of per-namespace limiters held
	// in memory at once. Each limiter is small, so 1024 is generous while
	// still bounding memory if many distinct namespaces are thrown at the
	// gateway.
	defaultCacheCap = 1024

	// cacheEntryTTL is the default maximum age of a cached limiter before
	// the next Allow re-reads the config store. 30s keeps config changes
	// propagating quickly across the cluster without hitting the store on
	// every request for a busy namespace.
	cacheEntryTTL = 30 * time.Second
)
|
||||
|
||||
// NewManager constructs a Manager. Defaults provides both the fallback
|
||||
// values (when a namespace has no override) AND the operator-imposed
|
||||
// ceiling on tenant PUT requests (handled by the config handler, not
|
||||
// here).
|
||||
func NewManager(store ConfigStore, defaults Defaults, logger *zap.Logger) *Manager {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
return &Manager{
|
||||
store: store,
|
||||
defaults: defaults.Sane(),
|
||||
logger: logger,
|
||||
ttl: cacheEntryTTL,
|
||||
cache: make(map[string]*list.Element, defaultCacheCap),
|
||||
lru: list.New(),
|
||||
cacheCap: defaultCacheCap,
|
||||
}
|
||||
}
|
||||
|
||||
// SetCacheTTL overrides the default cache-entry TTL. Intended for tests
|
||||
// (where 30 s is too long to wait) and for operators who want a tighter
|
||||
// propagation window across multi-gateway deployments at the cost of
|
||||
// extra store reads. Passing a non-positive value is a no-op.
|
||||
func (m *Manager) SetCacheTTL(d time.Duration) {
|
||||
if d <= 0 {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.ttl = d
|
||||
}
|
||||
|
||||
// Allow returns true if a request for the given namespace should be
|
||||
// allowed under that namespace's rate limit. The empty namespace is
|
||||
// always allowed (interpreted as "no namespace context — skip the check
|
||||
// at this layer; per-IP rate limiter still applies upstream").
|
||||
//
|
||||
// A store lookup error degrades to the gateway-wide defaults — we
|
||||
// prefer "let the request through under the safe default" over "deny
|
||||
// the request because the config store is briefly unavailable."
|
||||
func (m *Manager) Allow(ctx context.Context, namespace string) bool {
|
||||
if namespace == "" {
|
||||
return true
|
||||
}
|
||||
limiter := m.getOrBuild(ctx, namespace)
|
||||
return limiter.allow()
|
||||
}
|
||||
|
||||
// Invalidate evicts the cached limiter for a namespace. Called by the
|
||||
// config handler after a successful PUT or DELETE so the next request
|
||||
// rebuilds with current config.
|
||||
func (m *Manager) Invalidate(namespace string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if el, ok := m.cache[namespace]; ok {
|
||||
m.lru.Remove(el)
|
||||
delete(m.cache, namespace)
|
||||
}
|
||||
}
|
||||
|
||||
// Defaults returns the manager's effective defaults. Used by the config
|
||||
// handler to surface the operator ceiling in GET responses and validate
|
||||
// PUT requests.
|
||||
func (m *Manager) Defaults() Defaults {
|
||||
return m.defaults
|
||||
}
|
||||
|
||||
// getOrBuild reads or constructs the limiter for the given namespace.
|
||||
// On cache miss OR expired entry (age > ttl), reads the store, builds
|
||||
// a fresh limiter, and replaces the cache slot. The TTL is what bounds
|
||||
// cross-gateway config staleness — see Manager doc.
|
||||
func (m *Manager) getOrBuild(ctx context.Context, namespace string) *bucketLimiter {
|
||||
m.mu.Lock()
|
||||
if el, ok := m.cache[namespace]; ok {
|
||||
entry := el.Value.(*cacheEntry)
|
||||
if time.Since(entry.builtAt) < m.ttl {
|
||||
m.lru.MoveToFront(el)
|
||||
m.mu.Unlock()
|
||||
return entry.limiter
|
||||
}
|
||||
// Expired — drop the stale entry, fall through to rebuild.
|
||||
m.lru.Remove(el)
|
||||
delete(m.cache, namespace)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Cache miss (or expired): look up override, fall back to defaults,
|
||||
// build limiter.
|
||||
rpm, burst := m.defaults.RequestsPerMinute, m.defaults.Burst
|
||||
if m.store != nil {
|
||||
cfg, err := m.store.Get(ctx, namespace)
|
||||
if err != nil {
|
||||
// Store error: log and fall through to defaults. Refusing
|
||||
// the request because the DB is briefly unreachable is the
|
||||
// wrong failure mode for a rate limiter.
|
||||
m.logger.Warn("rate-limit config Get failed; using defaults",
|
||||
zap.String("namespace", namespace),
|
||||
zap.Error(err))
|
||||
} else if cfg != nil {
|
||||
if cfg.RequestsPerMinute > 0 {
|
||||
rpm = cfg.RequestsPerMinute
|
||||
}
|
||||
if cfg.Burst > 0 {
|
||||
burst = cfg.Burst
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
limiter := newBucketLimiter(rpm, burst)
|
||||
|
||||
// Insert into cache under lock; evict LRU tail if over cap.
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
// Another goroutine may have built it concurrently — return their
|
||||
// copy if so to keep one limiter per namespace. A concurrent rebuild
|
||||
// that already replaced an expired entry is also handled here.
|
||||
if el, ok := m.cache[namespace]; ok {
|
||||
entry := el.Value.(*cacheEntry)
|
||||
if time.Since(entry.builtAt) < m.ttl {
|
||||
m.lru.MoveToFront(el)
|
||||
return entry.limiter
|
||||
}
|
||||
// Concurrent build also expired — replace.
|
||||
m.lru.Remove(el)
|
||||
delete(m.cache, namespace)
|
||||
}
|
||||
entry := &cacheEntry{
|
||||
namespace: namespace,
|
||||
limiter: limiter,
|
||||
builtAt: time.Now(),
|
||||
}
|
||||
el := m.lru.PushFront(entry)
|
||||
m.cache[namespace] = el
|
||||
for m.lru.Len() > m.cacheCap {
|
||||
tail := m.lru.Back()
|
||||
if tail == nil {
|
||||
break
|
||||
}
|
||||
m.lru.Remove(tail)
|
||||
delete(m.cache, tail.Value.(*cacheEntry).namespace)
|
||||
}
|
||||
return limiter
|
||||
}
|
||||
|
||||
// bucketLimiter is the package's own token-bucket rate limiter. Keeping it
// local makes the package self-contained, so the legacy RateLimiter in
// pkg/gateway can be retired once the wiring switches over. rate is the
// sustained refill speed in tokens per second; burst is the bucket capacity.
type bucketLimiter struct {
	mu        sync.Mutex
	rate      float64 // tokens added per second
	burst     float64 // maximum number of tokens the bucket can hold
	tokens    float64 // tokens currently available
	lastCheck time.Time
}

// newBucketLimiter returns a limiter that starts with a full bucket of
// burst tokens and refills at ratePerMinute/60 tokens per second.
func newBucketLimiter(ratePerMinute, burst int) *bucketLimiter {
	capacity := float64(burst)
	return &bucketLimiter{
		rate:      float64(ratePerMinute) / 60.0,
		burst:     capacity,
		tokens:    capacity,
		lastCheck: time.Now(),
	}
}

// allow credits the tokens earned since the previous call (capped at
// burst), then spends one token if at least one is available. It returns
// true when a token was spent and false when the bucket was too empty.
func (b *bucketLimiter) allow() bool {
	b.mu.Lock()
	defer b.mu.Unlock()

	now := time.Now()
	available := b.tokens + now.Sub(b.lastCheck).Seconds()*b.rate
	if available > b.burst {
		available = b.burst
	}
	b.lastCheck = now

	if available < 1 {
		b.tokens = available
		return false
	}
	b.tokens = available - 1
	return true
}
|
||||
242
core/pkg/ratelimit/manager_test.go
Normal file
242
core/pkg/ratelimit/manager_test.go
Normal file
@ -0,0 +1,242 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// memStore is an in-memory ConfigStore for tests.
|
||||
type memStore struct {
|
||||
mu sync.Mutex
|
||||
rows map[string]Config
|
||||
getErr error
|
||||
}
|
||||
|
||||
func newMemStore() *memStore { return &memStore{rows: map[string]Config{}} }
|
||||
|
||||
func (m *memStore) Get(_ context.Context, namespace string) (*Config, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.getErr != nil {
|
||||
return nil, m.getErr
|
||||
}
|
||||
if c, ok := m.rows[namespace]; ok {
|
||||
c2 := c
|
||||
return &c2, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *memStore) Upsert(_ context.Context, cfg Config) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.rows[cfg.Namespace] = cfg
|
||||
return nil
|
||||
}
|
||||
func (m *memStore) Delete(_ context.Context, namespace string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.rows, namespace)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Defaults.Sane
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestDefaults_Sane(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in Defaults
|
||||
want Defaults
|
||||
}{
|
||||
{
|
||||
"zero clamps to safe baseline",
|
||||
Defaults{},
|
||||
Defaults{RequestsPerMinute: 10_000, Burst: 5_000},
|
||||
},
|
||||
{
|
||||
"populated values pass through",
|
||||
Defaults{RequestsPerMinute: 500, Burst: 50, MaxRequestsPerMinute: 1000, MaxBurst: 100},
|
||||
Defaults{RequestsPerMinute: 500, Burst: 50, MaxRequestsPerMinute: 1000, MaxBurst: 100},
|
||||
},
|
||||
{
|
||||
"negative clamps to baseline",
|
||||
Defaults{RequestsPerMinute: -1, Burst: -1},
|
||||
Defaults{RequestsPerMinute: 10_000, Burst: 5_000},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := tc.in.Sane()
|
||||
if got != tc.want {
|
||||
t.Errorf("Sane() = %+v, want %+v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Manager.Allow — base behaviour
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestManager_Allow_emptyNamespaceAlwaysAllowed(t *testing.T) {
|
||||
m := NewManager(newMemStore(), Defaults{RequestsPerMinute: 1, Burst: 1}, nil)
|
||||
for i := 0; i < 10; i++ {
|
||||
if !m.Allow(context.Background(), "") {
|
||||
t.Fatal("empty namespace must always be allowed (per-IP limiter handles that layer)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Allow_burstThenRefill(t *testing.T) {
|
||||
// Burst of 3 → first 3 requests pass, 4th fails.
|
||||
m := NewManager(newMemStore(), Defaults{RequestsPerMinute: 60, Burst: 3}, nil)
|
||||
ns := "test-ns"
|
||||
for i := 0; i < 3; i++ {
|
||||
if !m.Allow(context.Background(), ns) {
|
||||
t.Errorf("request %d should be allowed (within burst)", i+1)
|
||||
}
|
||||
}
|
||||
if m.Allow(context.Background(), ns) {
|
||||
t.Error("request 4 should be denied (burst exhausted)")
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Manager — per-namespace config override
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestManager_Allow_perNamespaceOverride(t *testing.T) {
|
||||
store := newMemStore()
|
||||
// One namespace gets a generous override; another uses defaults.
|
||||
store.rows["loud-tenant"] = Config{
|
||||
Namespace: "loud-tenant",
|
||||
RequestsPerMinute: 60_000,
|
||||
Burst: 100,
|
||||
}
|
||||
m := NewManager(store, Defaults{RequestsPerMinute: 60, Burst: 1}, nil)
|
||||
|
||||
// Default-namespace can fire only 1 request before being throttled.
|
||||
if !m.Allow(context.Background(), "quiet-tenant") {
|
||||
t.Error("first quiet-tenant request should pass")
|
||||
}
|
||||
if m.Allow(context.Background(), "quiet-tenant") {
|
||||
t.Error("second quiet-tenant request should be throttled (burst=1)")
|
||||
}
|
||||
|
||||
// loud-tenant has the override, burst=100, so 50 in a row all pass.
|
||||
for i := 0; i < 50; i++ {
|
||||
if !m.Allow(context.Background(), "loud-tenant") {
|
||||
t.Fatalf("loud-tenant request %d should pass under override (burst=100)", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Manager — store error degrades to defaults (fail-open is the safer mode)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestManager_Allow_storeErrorFallsBackToDefaults(t *testing.T) {
|
||||
store := newMemStore()
|
||||
store.getErr = errSentinel("boom")
|
||||
m := NewManager(store, Defaults{RequestsPerMinute: 60, Burst: 1}, nil)
|
||||
if !m.Allow(context.Background(), "any-ns") {
|
||||
t.Error("first request should pass under default burst even when store errs")
|
||||
}
|
||||
if m.Allow(context.Background(), "any-ns") {
|
||||
t.Error("second request should fail under default burst (store errored, defaults applied)")
|
||||
}
|
||||
}
|
||||
|
||||
// errSentinel is a string-backed error type for tests; the string itself is
// the error message.
type errSentinel string

// Error implements the error interface by returning the underlying string.
func (e errSentinel) Error() string {
	msg := string(e)
	return msg
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Manager.Invalidate — cache miss after invalidate picks up new config
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestManager_Invalidate_rebuildsWithNewConfig(t *testing.T) {
|
||||
store := newMemStore()
|
||||
// Initial: tight limit (burst=1).
|
||||
store.rows["tenant"] = Config{Namespace: "tenant", RequestsPerMinute: 60, Burst: 1}
|
||||
m := NewManager(store, Defaults{RequestsPerMinute: 60, Burst: 1}, nil)
|
||||
|
||||
if !m.Allow(context.Background(), "tenant") {
|
||||
t.Fatal("first request should pass")
|
||||
}
|
||||
if m.Allow(context.Background(), "tenant") {
|
||||
t.Fatal("second request should be denied (burst=1)")
|
||||
}
|
||||
|
||||
// Operator/tenant bumps the limit. Manager doesn't see it yet —
|
||||
// previous limiter is cached.
|
||||
store.rows["tenant"] = Config{Namespace: "tenant", RequestsPerMinute: 60, Burst: 100}
|
||||
if m.Allow(context.Background(), "tenant") {
|
||||
t.Error("without Invalidate, manager should still use the old cached limiter")
|
||||
}
|
||||
|
||||
// Invalidate clears the cache → next request rebuilds with new burst.
|
||||
m.Invalidate("tenant")
|
||||
for i := 0; i < 50; i++ {
|
||||
if !m.Allow(context.Background(), "tenant") {
|
||||
t.Fatalf("post-invalidate request %d should pass under new config (burst=100)", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Manager — concurrent access doesn't double-build limiters
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestManager_concurrentBuilds_oneCanonicalLimiter(t *testing.T) {
|
||||
store := newMemStore()
|
||||
store.rows["tenant"] = Config{Namespace: "tenant", RequestsPerMinute: 60, Burst: 10}
|
||||
m := NewManager(store, Defaults{RequestsPerMinute: 60, Burst: 10}, nil)
|
||||
|
||||
const goroutines = 50
|
||||
var allowedCount int
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if m.Allow(context.Background(), "tenant") {
|
||||
mu.Lock()
|
||||
allowedCount++
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// With burst=10 and 50 concurrent goroutines all hitting the same
|
||||
// namespace, exactly 10 should be allowed (or thereabouts — token
|
||||
// refill happens too fast for clock to matter at these intervals).
|
||||
// Most importantly: NOT 500 (which would happen if each goroutine
|
||||
// got its own freshly-built limiter due to a race).
|
||||
if allowedCount > 15 || allowedCount < 5 {
|
||||
t.Errorf("allowed = %d; expected ~10 (burst=10), got way off — suggests racy double-build", allowedCount)
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Manager.Defaults — exposes operator ceiling for handler validation
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestManager_Defaults_exposesOperatorCeiling(t *testing.T) {
|
||||
defs := Defaults{
|
||||
RequestsPerMinute: 1000,
|
||||
Burst: 100,
|
||||
MaxRequestsPerMinute: 5000,
|
||||
MaxBurst: 500,
|
||||
}
|
||||
m := NewManager(nil, defs, nil)
|
||||
got := m.Defaults()
|
||||
if got.MaxRequestsPerMinute != 5000 || got.MaxBurst != 500 {
|
||||
t.Errorf("Defaults().Max* = (%d,%d), want (5000,500)",
|
||||
got.MaxRequestsPerMinute, got.MaxBurst)
|
||||
}
|
||||
}
|
||||
87
core/pkg/ratelimit/rqlite_store.go
Normal file
87
core/pkg/ratelimit/rqlite_store.go
Normal file
@ -0,0 +1,87 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// rqliteStore is the production ConfigStore — persists per-namespace
|
||||
// overrides in the `namespace_rate_limit_config` table (migration 027).
|
||||
type rqliteStore struct {
|
||||
db rqlite.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRqliteConfigStore returns a ConfigStore backed by RQLite.
|
||||
func NewRqliteConfigStore(db rqlite.Client, logger *zap.Logger) ConfigStore {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
return &rqliteStore{db: db, logger: logger}
|
||||
}
|
||||
|
||||
func (s *rqliteStore) Get(ctx context.Context, namespace string) (*Config, error) {
|
||||
var rows []struct {
|
||||
Namespace string `db:"namespace"`
|
||||
RequestsPerMinute int `db:"requests_per_minute"`
|
||||
Burst int `db:"burst"`
|
||||
UpdatedAt int64 `db:"updated_at"`
|
||||
UpdatedBy string `db:"updated_by"`
|
||||
}
|
||||
err := s.db.Query(ctx, &rows,
|
||||
`SELECT namespace, requests_per_minute, burst, updated_at, updated_by
|
||||
FROM namespace_rate_limit_config WHERE namespace = ? LIMIT 1`, namespace)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rate-limit config Get: %w", err)
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
r := rows[0]
|
||||
return &Config{
|
||||
Namespace: r.Namespace,
|
||||
RequestsPerMinute: r.RequestsPerMinute,
|
||||
Burst: r.Burst,
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
UpdatedBy: r.UpdatedBy,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *rqliteStore) Upsert(ctx context.Context, cfg Config) error {
|
||||
if cfg.Namespace == "" {
|
||||
return fmt.Errorf("namespace required")
|
||||
}
|
||||
if cfg.RequestsPerMinute <= 0 || cfg.Burst <= 0 {
|
||||
return fmt.Errorf("requests_per_minute and burst must be > 0")
|
||||
}
|
||||
// SQLite UPSERT — single Raft commit, no read-then-write race.
|
||||
_, err := s.db.Exec(ctx,
|
||||
`INSERT INTO namespace_rate_limit_config
|
||||
(namespace, requests_per_minute, burst, updated_at, updated_by)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(namespace) DO UPDATE SET
|
||||
requests_per_minute = excluded.requests_per_minute,
|
||||
burst = excluded.burst,
|
||||
updated_at = excluded.updated_at,
|
||||
updated_by = excluded.updated_by`,
|
||||
cfg.Namespace, cfg.RequestsPerMinute, cfg.Burst, cfg.UpdatedAt, cfg.UpdatedBy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("rate-limit config Upsert: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *rqliteStore) Delete(ctx context.Context, namespace string) error {
|
||||
if namespace == "" {
|
||||
return fmt.Errorf("namespace required")
|
||||
}
|
||||
_, err := s.db.Exec(ctx,
|
||||
`DELETE FROM namespace_rate_limit_config WHERE namespace = ?`, namespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("rate-limit config Delete: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
100
core/pkg/ratelimit/types.go
Normal file
100
core/pkg/ratelimit/types.go
Normal file
@ -0,0 +1,100 @@
|
||||
// Package ratelimit provides per-namespace rate-limit configuration storage
|
||||
// and a Manager that builds per-namespace token-bucket limiters from that
|
||||
// configuration (with a fallback to gateway-wide defaults).
|
||||
//
|
||||
// Feature #69. Mirrors the per-namespace push-config pattern from bug
|
||||
// #220's follow-up: tenants self-serve their own quota via authenticated
|
||||
// HTTP, and operators retain a hard cap so no tenant can raise their own
|
||||
// limit beyond the global ceiling.
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Config mirrors one row of `namespace_rate_limit_config`: a tenant's
// override of the gateway's default rate limits.
//
// IMPORTANT — the values are per gateway. Each gateway keeps its own token
// bucket, so in an N-gateway deployment the effective cluster-wide cap for
// the namespace is N × RequestsPerMinute (and N × Burst). Operators who
// need a cluster-wide cap must either configure (cluster-cap / N) per
// gateway or implement a shared-bucket backend. The GET handler surfaces
// this caveat in its response so tenants understand what they are setting.
type Config struct {
	Namespace         string
	RequestsPerMinute int
	Burst             int
	UpdatedAt         int64  // unix seconds
	UpdatedBy         string // free-form audit value (wallet address, operator ID, etc.)
}
|
||||
|
||||
// Defaults holds the gateway-YAML rate limits.
//
// RequestsPerMinute and Burst are the fallback applied to any namespace
// without its own config. MaxRequestsPerMinute and MaxBurst are the OPERATOR
// CEILING: a tenant PUT asking for more than these is rejected at the
// handler boundary, so tenants can loosen their limits only up to the cap.
//
// A Max* of 0 means "no cap; trust tenant input". Use with care in
// shared-infrastructure deployments.
type Defaults struct {
	RequestsPerMinute    int
	Burst                int
	MaxRequestsPerMinute int
	MaxBurst             int
}

// Sane returns a sanitised copy of d; the receiver is not modified.
//
//   - A non-positive RequestsPerMinute or Burst is treated as
//     misconfiguration and replaced with the baseline (10,000 requests per
//     minute, burst 5,000). Each field is checked independently.
//   - Max* values are never raised: 0 is meaningful (ceiling disabled) and
//     is kept. Negative values are normalised to 0, so the handler's `> 0`
//     "cap is active" check behaves the same for an operator typo as for an
//     explicit 0.
func (d Defaults) Sane() Defaults {
	// d is already a private copy because of the value receiver.
	if d.RequestsPerMinute <= 0 {
		d.RequestsPerMinute = 10_000
	}
	if d.Burst <= 0 {
		d.Burst = 5_000
	}
	if d.MaxRequestsPerMinute < 0 {
		d.MaxRequestsPerMinute = 0
	}
	if d.MaxBurst < 0 {
		d.MaxBurst = 0
	}
	return d
}
|
||||
|
||||
// ConfigStore reads and writes per-namespace rate-limit overrides.
|
||||
// Implementations are usually RQLite-backed (see rqlite_store.go) but
|
||||
// the interface lets tests swap in an in-memory map.
|
||||
type ConfigStore interface {
|
||||
// Get returns the namespace's override, or (nil, nil) if no override
|
||||
// has been set (caller should fall back to Defaults).
|
||||
Get(ctx context.Context, namespace string) (*Config, error)
|
||||
|
||||
// Upsert inserts or replaces the override for cfg.Namespace.
|
||||
// cfg.UpdatedAt and cfg.UpdatedBy must be populated by the caller.
|
||||
Upsert(ctx context.Context, cfg Config) error
|
||||
|
||||
// Delete removes the override (caller falls back to Defaults).
|
||||
// No error if the row didn't exist — idempotent.
|
||||
Delete(ctx context.Context, namespace string) error
|
||||
}
|
||||
|
||||
// ErrAboveOperatorCap signals that a PUT asked for a rate limit above the
// operator-configured Defaults.Max* ceiling. The config handler returns it
// and reports it to the tenant as a 400 that includes the cap value, so the
// tenant can see what the platform allows.
var ErrAboveOperatorCap error = fmt.Errorf("requested rate limit exceeds operator-configured maximum")
|
||||
@ -3,9 +3,9 @@
|
||||
"schema_version": 1,
|
||||
|
||||
"rules": {
|
||||
"version": "v0.1.0",
|
||||
"sha": "51ce3f8529f9269a80b22b384fa98de6431c04e8",
|
||||
"synced_at": "2026-05-12T10:55:00Z"
|
||||
"version": "v0.2.0",
|
||||
"sha": "bb6e6ef604b420879a44f055af48d4acf57b86d5",
|
||||
"synced_at": "2026-05-12T11:26:00Z"
|
||||
},
|
||||
|
||||
"project": {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@debros/orama",
|
||||
"version": "0.122.11",
|
||||
"version": "0.122.12",
|
||||
"description": "TypeScript SDK for Orama Network - Database, PubSub, Cache, Storage, Vault, and more",
|
||||
"type": "module",
|
||||
"main": "./dist/index.js",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user