orama/pkg/gateway/rate_limiter.go
anonpenguin23 fd87eec476 feat(security): add manifest signing, TLS TOFU, refresh token migration
- Invalidate plaintext refresh tokens (migration 019)
- Add `--sign` flag to `orama build` for rootwallet manifest signing
- Add `--ca-fingerprint` TOFU verification for production joins/invites
- Save cluster secrets from join (RQLite auth, Olric key, IPFS peers)
- Add RQLite auth config fields
2026-02-28 15:40:43 +02:00

194 lines
4.8 KiB
Go

package gateway
import (
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/DeBrosOfficial/network/pkg/auth"
)
// wireGuardNet is the WireGuard mesh subnet (auth.WireGuardSubnet), parsed once
// at package initialisation and consulted by isInternalIP.
var wireGuardNet *net.IPNet

func init() {
	_, subnet, err := net.ParseCIDR(auth.WireGuardSubnet)
	if err != nil {
		// A malformed subnet constant is a programming error. Fail at startup
		// instead of leaving wireGuardNet nil and panicking on the first request.
		panic("gateway: invalid auth.WireGuardSubnet: " + err.Error())
	}
	wireGuardNet = subnet
}
// RateLimiter implements a token-bucket rate limiter keyed by client IP.
//
// Every key owns a bucket that holds at most burst tokens. Tokens refill
// continuously at rate tokens per second, and each allowed request spends one
// token. All methods are safe for concurrent use: the bucket map and the
// cleanup stop channel are guarded by mu.
type RateLimiter struct {
	mu      sync.Mutex
	clients map[string]*bucket
	rate    float64       // tokens per second
	burst   int           // max tokens (burst capacity)
	stopCh  chan struct{} // closed to stop the cleanup goroutine; nil when none runs
}

// bucket is the per-key token-bucket state.
type bucket struct {
	tokens    float64   // tokens currently available (fractional while refilling)
	lastCheck time.Time // when tokens was last brought up to date
}

// NewRateLimiter creates a rate limiter. ratePerMinute is the sustained rate;
// burst is the maximum number of requests that can be made in a short window.
//
// A negative ratePerMinute is treated as 0 (buckets never refill). A burst
// below 1 is raised to 1; otherwise a bucket could never hold a whole token and
// every request after a client's first would be rejected.
func NewRateLimiter(ratePerMinute, burst int) *RateLimiter {
	if ratePerMinute < 0 {
		ratePerMinute = 0
	}
	if burst < 1 {
		burst = 1
	}
	return &RateLimiter{
		clients: make(map[string]*bucket),
		rate:    float64(ratePerMinute) / 60.0,
		burst:   burst,
	}
}

// Allow reports whether a request from the given key (normally a client IP)
// may proceed, consuming one token if so. A key seen for the first time (or
// again after Cleanup removed it) starts with a full bucket.
func (rl *RateLimiter) Allow(ip string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	b, ok := rl.clients[ip]
	if !ok {
		b = &bucket{tokens: float64(rl.burst), lastCheck: now}
		rl.clients[ip] = b
	}

	// Refill for the time elapsed since the previous request, capped at burst.
	// For a brand-new bucket the elapsed time is zero, so this is a no-op.
	b.tokens += now.Sub(b.lastCheck).Seconds() * rl.rate
	if b.tokens > float64(rl.burst) {
		b.tokens = float64(rl.burst)
	}
	b.lastCheck = now

	if b.tokens < 1 {
		return false
	}
	b.tokens--
	return true
}

// Cleanup removes buckets whose last request is older than maxAge.
func (rl *RateLimiter) Cleanup(maxAge time.Duration) {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	cutoff := time.Now().Add(-maxAge)
	for ip, b := range rl.clients {
		if b.lastCheck.Before(cutoff) {
			delete(rl.clients, ip)
		}
	}
}

// StartCleanup runs Cleanup(maxAge) every interval in a background goroutine
// until Stop is called. Calling it again stops the previously started
// goroutine before starting a new one, so at most one is ever running.
// A non-positive interval starts nothing, because time.NewTicker would panic.
func (rl *RateLimiter) StartCleanup(interval, maxAge time.Duration) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	rl.stopLocked()
	if interval <= 0 {
		return
	}

	// The goroutine captures its own channel so it never reads rl.stopCh
	// without the lock.
	stop := make(chan struct{})
	rl.stopCh = stop
	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				rl.Cleanup(maxAge)
			case <-stop:
				return
			}
		}
	}()
}

// Stop terminates the background cleanup goroutine started by StartCleanup.
// It is safe to call more than once, and without a prior StartCleanup.
func (rl *RateLimiter) Stop() {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	rl.stopLocked()
}

// stopLocked signals the running cleanup goroutine, if any, to exit and
// forgets its channel so that a repeated stop is a no-op. rl.mu must be held.
func (rl *RateLimiter) stopLocked() {
	if rl.stopCh != nil {
		close(rl.stopCh)
		rl.stopCh = nil
	}
}
// NamespaceRateLimiter provides per-namespace rate limiting. Each namespace gets
// its own RateLimiter, held in a sync.Map so that requests for different
// namespaces do not contend on a single mutex.
//
// NOTE(review): limiters are created lazily and never removed, so memory grows
// with the number of distinct namespaces seen. This presumably relies on the
// namespace set being bounded (namespaces come from authenticated request
// context) — confirm.
type NamespaceRateLimiter struct {
	limiters sync.Map // namespace -> *RateLimiter
	rate     int      // per-minute rate per namespace
	burst    int
}

// NewNamespaceRateLimiter creates a per-namespace rate limiter in which every
// namespace is allowed ratePerMinute sustained requests with the given burst.
func NewNamespaceRateLimiter(ratePerMinute, burst int) *NamespaceRateLimiter {
	return &NamespaceRateLimiter{rate: ratePerMinute, burst: burst}
}

// Allow reports whether a request for the given namespace should be allowed.
// An empty namespace is never limited.
func (nrl *NamespaceRateLimiter) Allow(namespace string) bool {
	if namespace == "" {
		return true
	}
	val, ok := nrl.limiters.Load(namespace)
	if !ok {
		// Allocate a limiter only on a miss. LoadOrStore settles races between
		// concurrent first requests so that all callers share one limiter.
		val, _ = nrl.limiters.LoadOrStore(namespace, NewRateLimiter(nrl.rate, nrl.burst))
	}
	return val.(*RateLimiter).Allow(namespace)
}
// rateLimitMiddleware wraps next with per-client-IP rate limiting. Once a
// client's budget is spent it answers 429 Too Many Requests with
// "Retry-After: 5". Clients for which isInternalIP is true (loopback or the
// WireGuard subnet) bypass the limiter entirely. If the gateway has no rate
// limiter configured, next is returned unwrapped.
func (g *Gateway) rateLimitMiddleware(next http.Handler) http.Handler {
	if g.rateLimiter == nil {
		return next
	}
	limited := func(w http.ResponseWriter, r *http.Request) {
		clientIP := getClientIP(r)
		switch {
		case isInternalIP(clientIP):
			// Internal cluster traffic is never throttled.
		case !g.rateLimiter.Allow(clientIP):
			w.Header().Set("Retry-After", "5")
			http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(limited)
}
// namespaceRateLimitMiddleware enforces per-namespace rate limits, answering
// 429 Too Many Requests with "Retry-After: 60" when a namespace is over budget.
// It must run after the auth middleware, which is expected to store the
// namespace in the request context under CtxKeyNamespaceOverride. Requests
// without a non-empty string namespace pass straight through. If no namespace
// limiter is configured, next is returned unwrapped.
func (g *Gateway) namespaceRateLimitMiddleware(next http.Handler) http.Handler {
	if g.namespaceRateLimiter == nil {
		return next
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// A missing or non-string context value yields "" here.
		namespace, _ := r.Context().Value(CtxKeyNamespaceOverride).(string)
		if namespace != "" && !g.namespaceRateLimiter.Allow(namespace) {
			w.Header().Set("Retry-After", "60")
			http.Error(w, "namespace rate limit exceeded", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// isInternalIP reports whether ipStr is a loopback address or lies inside the
// WireGuard mesh subnet. ipStr may be:
//   - a bare IP ("10.0.0.5", "::1"),
//   - a host:port pair ("10.0.0.5:443", "[::1]:80"), or
//   - a bracketed IPv6 literal without a port ("[::1]").
//
// Anything that does not parse as an IP is treated as external.
func isInternalIP(ipStr string) bool {
	host := ipStr
	// Strip the port if present. For input without a port (including bare
	// IPv6 such as "::1") SplitHostPort fails and host stays as given.
	if h, _, err := net.SplitHostPort(ipStr); err == nil {
		host = h
	}
	// Accept a bracketed IPv6 literal that carries no port.
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	ip := net.ParseIP(host)
	if ip == nil {
		return false
	}
	if ip.IsLoopback() {
		return true
	}
	// (*net.IPNet).Contains dereferences its receiver, so guard against an
	// unset subnet instead of panicking on every external request.
	return wireGuardNet != nil && wireGuardNet.Contains(ip)
}