package gateway import ( "net" "net/http" "strings" "sync" "time" "github.com/DeBrosOfficial/network/pkg/auth" ) // wireGuardNet is the WireGuard mesh subnet, parsed once at init. var wireGuardNet *net.IPNet func init() { _, wireGuardNet, _ = net.ParseCIDR(auth.WireGuardSubnet) } // RateLimiter implements a token-bucket rate limiter per client IP. type RateLimiter struct { mu sync.Mutex clients map[string]*bucket rate float64 // tokens per second burst int // max tokens (burst capacity) stopCh chan struct{} } type bucket struct { tokens float64 lastCheck time.Time } // NewRateLimiter creates a rate limiter. ratePerMinute is the sustained rate; // burst is the maximum number of requests that can be made in a short window. func NewRateLimiter(ratePerMinute, burst int) *RateLimiter { return &RateLimiter{ clients: make(map[string]*bucket), rate: float64(ratePerMinute) / 60.0, burst: burst, } } // Allow checks if a request from the given IP should be allowed. func (rl *RateLimiter) Allow(ip string) bool { rl.mu.Lock() defer rl.mu.Unlock() now := time.Now() b, exists := rl.clients[ip] if !exists { rl.clients[ip] = &bucket{tokens: float64(rl.burst) - 1, lastCheck: now} return true } // Refill tokens based on elapsed time elapsed := now.Sub(b.lastCheck).Seconds() b.tokens += elapsed * rl.rate if b.tokens > float64(rl.burst) { b.tokens = float64(rl.burst) } b.lastCheck = now if b.tokens >= 1 { b.tokens-- return true } return false } // Cleanup removes stale entries older than the given duration. func (rl *RateLimiter) Cleanup(maxAge time.Duration) { rl.mu.Lock() defer rl.mu.Unlock() cutoff := time.Now().Add(-maxAge) for ip, b := range rl.clients { if b.lastCheck.Before(cutoff) { delete(rl.clients, ip) } } } // StartCleanup runs periodic cleanup in a goroutine. Call Stop() to terminate it. func (rl *RateLimiter) StartCleanup(interval, maxAge time.Duration) { rl.stopCh = make(chan struct{}) go func() { ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ticker.C: rl.Cleanup(maxAge) case <-rl.stopCh: return } } }() } // Stop terminates the background cleanup goroutine. func (rl *RateLimiter) Stop() { if rl.stopCh != nil { close(rl.stopCh) } } // NamespaceRateLimiter provides per-namespace rate limiting using a sync.Map // for better concurrent performance than a single mutex. type NamespaceRateLimiter struct { limiters sync.Map // namespace -> *RateLimiter rate int // per-minute rate per namespace burst int } // NewNamespaceRateLimiter creates a per-namespace rate limiter. func NewNamespaceRateLimiter(ratePerMinute, burst int) *NamespaceRateLimiter { return &NamespaceRateLimiter{rate: ratePerMinute, burst: burst} } // Allow checks if a request for the given namespace should be allowed. func (nrl *NamespaceRateLimiter) Allow(namespace string) bool { if namespace == "" { return true } val, _ := nrl.limiters.LoadOrStore(namespace, NewRateLimiter(nrl.rate, nrl.burst)) return val.(*RateLimiter).Allow(namespace) } // rateLimitMiddleware returns 429 when a client exceeds the rate limit. // Internal traffic from the WireGuard subnet is exempt. func (g *Gateway) rateLimitMiddleware(next http.Handler) http.Handler { if g.rateLimiter == nil { return next } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := getClientIP(r) // Exempt internal cluster traffic (WireGuard subnet) if isInternalIP(ip) { next.ServeHTTP(w, r) return } if !g.rateLimiter.Allow(ip) { w.Header().Set("Retry-After", "5") http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } // namespaceRateLimitMiddleware enforces per-namespace rate limits. // It runs after auth middleware so the namespace is available in context. func (g *Gateway) namespaceRateLimitMiddleware(next http.Handler) http.Handler { if g.namespaceRateLimiter == nil { return next } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Extract namespace from context (set by auth middleware) if v := r.Context().Value(CtxKeyNamespaceOverride); v != nil { if ns, ok := v.(string); ok && ns != "" { if !g.namespaceRateLimiter.Allow(ns) { w.Header().Set("Retry-After", "60") http.Error(w, "namespace rate limit exceeded", http.StatusTooManyRequests) return } } } next.ServeHTTP(w, r) }) } // isInternalIP returns true if the IP is in the WireGuard subnet // or is a loopback address. func isInternalIP(ipStr string) bool { // Strip port if present if strings.Contains(ipStr, ":") { host, _, err := net.SplitHostPort(ipStr) if err == nil { ipStr = host } } ip := net.ParseIP(ipStr) if ip == nil { return false } if ip.IsLoopback() { return true } return wireGuardNet.Contains(ip) }