// Mirror of https://github.com/DeBrosOfficial/orama.git
// (synced 2026-03-17 11:46:58 +00:00; 193 lines, 4.8 KiB, Go).

package gateway
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// wireGuardNet is the WireGuard mesh subnet (10.0.0.0/8), built once during
// package initialization. The parse error is intentionally discarded: the
// input is a constant, known-valid CIDR literal.
var wireGuardNet = func() *net.IPNet {
	_, subnet, _ := net.ParseCIDR("10.0.0.0/8")
	return subnet
}()
|
|
|
|
// RateLimiter implements a token-bucket rate limiter keyed by client IP.
//
// Each key owns a bucket holding at most burst tokens. The bucket refills
// continuously at rate tokens per second, and every allowed request spends
// one token. Every method takes mu, so a RateLimiter is safe for concurrent
// use.
type RateLimiter struct {
	mu      sync.Mutex
	clients map[string]*bucket
	rate    float64       // tokens per second
	burst   int           // max tokens (burst capacity)
	stopCh  chan struct{} // stops the StartCleanup goroutine; guarded by mu, nil when none is running
}

// bucket is the token-bucket state for a single key.
type bucket struct {
	tokens    float64   // tokens currently available; refill is fractional
	lastCheck time.Time // time of the last refill, also used for idle eviction
}

// NewRateLimiter creates a rate limiter. ratePerMinute is the sustained rate;
// burst is the maximum number of requests that can be made in a short window.
func NewRateLimiter(ratePerMinute, burst int) *RateLimiter {
	return &RateLimiter{
		clients: make(map[string]*bucket),
		rate:    float64(ratePerMinute) / 60.0,
		burst:   burst,
	}
}

// Allow reports whether a request from the given IP should be allowed, and
// consumes a token if so.
//
// The first request from an unseen IP is always allowed. That IP's bucket
// starts full, and this request spends one token from it (even if burst < 1).
func (rl *RateLimiter) Allow(ip string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	b, exists := rl.clients[ip]
	if !exists {
		rl.clients[ip] = &bucket{tokens: float64(rl.burst) - 1, lastCheck: now}
		return true
	}

	// Refill tokens based on elapsed time, capped at burst capacity.
	elapsed := now.Sub(b.lastCheck).Seconds()
	b.tokens += elapsed * rl.rate
	if b.tokens > float64(rl.burst) {
		b.tokens = float64(rl.burst)
	}
	b.lastCheck = now

	if b.tokens >= 1 {
		b.tokens--
		return true
	}
	return false
}

// Cleanup removes entries whose last Allow call is older than maxAge.
//
// An evicted IP starts over with a full bucket on its next request. Choose
// maxAge of at least burst/rate seconds (the full-refill time) so eviction
// never grants more tokens than refilling would have.
func (rl *RateLimiter) Cleanup(maxAge time.Duration) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	cutoff := time.Now().Add(-maxAge)
	for ip, b := range rl.clients {
		if b.lastCheck.Before(cutoff) {
			delete(rl.clients, ip)
		}
	}
}

// StartCleanup runs Cleanup(maxAge) every interval in a background goroutine.
// Call Stop to terminate it. If a cleanup goroutine is already running, it is
// stopped first, so at most one loop is active at a time.
//
// interval must be positive; otherwise time.NewTicker panics. It does so here,
// in the caller's goroutine, rather than inside the background goroutine.
func (rl *RateLimiter) StartCleanup(interval, maxAge time.Duration) {
	ticker := time.NewTicker(interval)
	stop := make(chan struct{})

	rl.mu.Lock()
	if rl.stopCh != nil {
		close(rl.stopCh)
	}
	rl.stopCh = stop
	rl.mu.Unlock()

	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				rl.Cleanup(maxAge)
			case <-stop: // local copy: never reads rl.stopCh without the lock
				return
			}
		}
	}()
}

// Stop terminates the background cleanup goroutine, if any. It is safe to
// call multiple times and before StartCleanup.
func (rl *RateLimiter) Stop() {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	if rl.stopCh != nil {
		close(rl.stopCh)
		rl.stopCh = nil // prevents a second Stop from closing a closed channel
	}
}
|
|
|
|
// NamespaceRateLimiter provides per-namespace rate limiting using a sync.Map
|
|
// for better concurrent performance than a single mutex.
|
|
type NamespaceRateLimiter struct {
|
|
limiters sync.Map // namespace -> *RateLimiter
|
|
rate int // per-minute rate per namespace
|
|
burst int
|
|
}
|
|
|
|
// NewNamespaceRateLimiter creates a per-namespace rate limiter.
|
|
func NewNamespaceRateLimiter(ratePerMinute, burst int) *NamespaceRateLimiter {
|
|
return &NamespaceRateLimiter{rate: ratePerMinute, burst: burst}
|
|
}
|
|
|
|
// Allow checks if a request for the given namespace should be allowed.
|
|
func (nrl *NamespaceRateLimiter) Allow(namespace string) bool {
|
|
if namespace == "" {
|
|
return true
|
|
}
|
|
val, _ := nrl.limiters.LoadOrStore(namespace, NewRateLimiter(nrl.rate, nrl.burst))
|
|
return val.(*RateLimiter).Allow(namespace)
|
|
}
|
|
|
|
// rateLimitMiddleware returns 429 when a client exceeds the rate limit.
|
|
// Internal traffic from the WireGuard subnet (10.0.0.0/8) is exempt.
|
|
func (g *Gateway) rateLimitMiddleware(next http.Handler) http.Handler {
|
|
if g.rateLimiter == nil {
|
|
return next
|
|
}
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ip := getClientIP(r)
|
|
|
|
// Exempt internal cluster traffic (WireGuard subnet)
|
|
if isInternalIP(ip) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
if !g.rateLimiter.Allow(ip) {
|
|
w.Header().Set("Retry-After", "5")
|
|
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// namespaceRateLimitMiddleware enforces per-namespace rate limits.
|
|
// It runs after auth middleware so the namespace is available in context.
|
|
func (g *Gateway) namespaceRateLimitMiddleware(next http.Handler) http.Handler {
|
|
if g.namespaceRateLimiter == nil {
|
|
return next
|
|
}
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Extract namespace from context (set by auth middleware)
|
|
if v := r.Context().Value(CtxKeyNamespaceOverride); v != nil {
|
|
if ns, ok := v.(string); ok && ns != "" {
|
|
if !g.namespaceRateLimiter.Allow(ns) {
|
|
w.Header().Set("Retry-After", "60")
|
|
http.Error(w, "namespace rate limit exceeded", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// isInternalIP returns true if the IP is in the WireGuard 10.0.0.0/8 subnet
|
|
// or is a loopback address.
|
|
func isInternalIP(ipStr string) bool {
|
|
// Strip port if present
|
|
if strings.Contains(ipStr, ":") {
|
|
host, _, err := net.SplitHostPort(ipStr)
|
|
if err == nil {
|
|
ipStr = host
|
|
}
|
|
}
|
|
ip := net.ParseIP(ipStr)
|
|
if ip == nil {
|
|
return false
|
|
}
|
|
if ip.IsLoopback() {
|
|
return true
|
|
}
|
|
// 10.0.0.0/8 — WireGuard mesh
|
|
return wireGuardNet.Contains(ip)
|
|
}
|