package gateway import ( "net" "net/http" "strings" "sync" "time" "github.com/DeBrosOfficial/network/pkg/auth" "github.com/DeBrosOfficial/network/pkg/httputil" ) // wireGuardNet is the WireGuard mesh subnet, parsed once at init. var wireGuardNet *net.IPNet func init() { _, wireGuardNet, _ = net.ParseCIDR(auth.WireGuardSubnet) } // RateLimiter implements a token-bucket rate limiter per client IP. type RateLimiter struct { mu sync.Mutex clients map[string]*bucket rate float64 // tokens per second burst int // max tokens (burst capacity) stopCh chan struct{} } type bucket struct { tokens float64 lastCheck time.Time } // NewRateLimiter creates a rate limiter. ratePerMinute is the sustained rate; // burst is the maximum number of requests that can be made in a short window. func NewRateLimiter(ratePerMinute, burst int) *RateLimiter { return &RateLimiter{ clients: make(map[string]*bucket), rate: float64(ratePerMinute) / 60.0, burst: burst, } } // Allow checks if a request from the given IP should be allowed. func (rl *RateLimiter) Allow(ip string) bool { rl.mu.Lock() defer rl.mu.Unlock() now := time.Now() b, exists := rl.clients[ip] if !exists { rl.clients[ip] = &bucket{tokens: float64(rl.burst) - 1, lastCheck: now} return true } // Refill tokens based on elapsed time elapsed := now.Sub(b.lastCheck).Seconds() b.tokens += elapsed * rl.rate if b.tokens > float64(rl.burst) { b.tokens = float64(rl.burst) } b.lastCheck = now if b.tokens >= 1 { b.tokens-- return true } return false } // Cleanup removes stale entries older than the given duration. func (rl *RateLimiter) Cleanup(maxAge time.Duration) { rl.mu.Lock() defer rl.mu.Unlock() cutoff := time.Now().Add(-maxAge) for ip, b := range rl.clients { if b.lastCheck.Before(cutoff) { delete(rl.clients, ip) } } } // StartCleanup runs periodic cleanup in a goroutine. Call Stop() to terminate it. func (rl *RateLimiter) StartCleanup(interval, maxAge time.Duration) { rl.stopCh = make(chan struct{}) go func() { ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ticker.C: rl.Cleanup(maxAge) case <-rl.stopCh: return } } }() } // Stop terminates the background cleanup goroutine. func (rl *RateLimiter) Stop() { if rl.stopCh != nil { close(rl.stopCh) } } // NamespaceRateLimiter provides per-namespace rate limiting using a sync.Map // for better concurrent performance than a single mutex. type NamespaceRateLimiter struct { limiters sync.Map // namespace -> *RateLimiter rate int // per-minute rate per namespace burst int } // NewNamespaceRateLimiter creates a per-namespace rate limiter. func NewNamespaceRateLimiter(ratePerMinute, burst int) *NamespaceRateLimiter { return &NamespaceRateLimiter{rate: ratePerMinute, burst: burst} } // Allow checks if a request for the given namespace should be allowed. func (nrl *NamespaceRateLimiter) Allow(namespace string) bool { if namespace == "" { return true } val, _ := nrl.limiters.LoadOrStore(namespace, NewRateLimiter(nrl.rate, nrl.burst)) return val.(*RateLimiter).Allow(namespace) } // rateLimitMiddleware returns 429 when a client exceeds the rate limit. // Internal traffic from the WireGuard subnet is exempt. func (g *Gateway) rateLimitMiddleware(next http.Handler) http.Handler { if g.rateLimiter == nil { return next } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := getClientIP(r) // Exempt internal cluster traffic (WireGuard subnet) if isInternalIP(ip) { next.ServeHTTP(w, r) return } if !g.rateLimiter.Allow(ip) { w.Header().Set("Retry-After", "5") http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } // namespaceRateLimitMiddleware enforces per-namespace rate limits. // It runs after auth middleware so the namespace is available in context. // // Feature #69: when g.rateLimitManager is set (production wiring), it's // preferred — supports per-namespace overrides via /v1/namespace/rate-limit // and emits the canonical RPC error envelope on 429 (so SDK clients see // a structured error code instead of plain text). The legacy // g.namespaceRateLimiter remains as a fallback for code paths that // haven't wired the manager yet. func (g *Gateway) namespaceRateLimitMiddleware(next http.Handler) http.Handler { if g.rateLimitManager == nil && g.namespaceRateLimiter == nil { return next } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { v := r.Context().Value(CtxKeyNamespaceOverride) ns, _ := v.(string) if ns == "" { next.ServeHTTP(w, r) return } allowed := true if g.rateLimitManager != nil { allowed = g.rateLimitManager.Allow(r.Context(), ns) } else if g.namespaceRateLimiter != nil { allowed = g.namespaceRateLimiter.Allow(ns) } if !allowed { // Canonical RPC error envelope (bug #212 contract) so SDKs // parse the rate-limit hit instead of seeing plain text. The // 60s retry hint maps to both the HTTP Retry-After header // and the envelope's retry_after field. httputil.WriteRPCError(w, http.StatusTooManyRequests, httputil.ErrCodeRateLimited, "namespace rate limit exceeded — back off and retry in a few seconds", httputil.WithRetryable(), httputil.WithRetryAfter(60)) return } next.ServeHTTP(w, r) }) } // isInternalIP returns true if the IP is in the WireGuard subnet // or is a loopback address. func isInternalIP(ipStr string) bool { // Strip port if present if strings.Contains(ipStr, ":") { host, _, err := net.SplitHostPort(ipStr) if err == nil { ipStr = host } } ip := net.ParseIP(ipStr) if ip == nil { return false } if ip.IsLoopback() { return true } return wireGuardNet.Contains(ip) }