From 7690b22c0ac0ef09a7496817753edc0dc0574df4 Mon Sep 17 00:00:00 2001 From: anonpenguin23 Date: Fri, 6 Feb 2026 08:30:11 +0200 Subject: [PATCH] Improved performance on request journey with cache and some tricks --- pkg/gateway/gateway.go | 12 ++ pkg/gateway/middleware.go | 159 +++++++++++++++++------- pkg/gateway/middleware_cache.go | 121 +++++++++++++++++++ pkg/gateway/request_log_batcher.go | 188 +++++++++++++++++++++++++++++ 4 files changed, 439 insertions(+), 41 deletions(-) create mode 100644 pkg/gateway/middleware_cache.go create mode 100644 pkg/gateway/request_log_batcher.go diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index 0bb8d8b..2b78e40 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -105,6 +105,12 @@ type Gateway struct { processManager *process.Manager healthChecker *health.HealthChecker + // Middleware cache for auth/routing lookups (eliminates redundant DB queries) + mwCache *middlewareCache + + // Request log batcher (aggregates writes instead of per-request inserts) + logBatcher *requestLogBatcher + // Rate limiter rateLimiter *RateLimiter @@ -298,6 +304,12 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) { ) } + // Initialize middleware cache (60s TTL for auth/routing lookups) + gw.mwCache = newMiddlewareCache(60 * time.Second) + + // Initialize request log batcher (flush every 5 seconds) + gw.logBatcher = newRequestLogBatcher(gw, 5*time.Second, 100) + // Initialize rate limiter (300 req/min, burst 50) gw.rateLimiter = NewRateLimiter(300, 50) gw.rateLimiter.StartCleanup(5*time.Minute, 10*time.Minute) diff --git a/pkg/gateway/middleware.go b/pkg/gateway/middleware.go index 2e58317..c8d8423 100644 --- a/pkg/gateway/middleware.go +++ b/pkg/gateway/middleware.go @@ -64,7 +64,14 @@ func (g *Gateway) validateAuthForNamespaceProxy(r *http.Request) (namespace stri return "", "" // No credentials provided } - // Look up API key in main cluster RQLite + // Check middleware cache first + if g.mwCache != nil { + if cachedNS, ok := g.mwCache.GetAPIKeyNamespace(key); ok { + return cachedNS, "" + } + } + + // Cache miss — look up API key in main cluster RQLite db := g.client.Database() internalCtx := client.WithInternalAuth(r.Context()) q := "SELECT namespaces.name FROM api_keys JOIN namespaces ON api_keys.namespace_id = namespaces.id WHERE api_keys.key = ? LIMIT 1" @@ -86,6 +93,11 @@ func (g *Gateway) validateAuthForNamespaceProxy(r *http.Request) (namespace stri return "", "invalid API key" } + // Cache the result + if g.mwCache != nil { + g.mwCache.SetAPIKeyNamespace(key, ns) + } + return ns, "" } @@ -208,8 +220,24 @@ func (g *Gateway) loggingMiddleware(next http.Handler) http.Handler { zap.String("duration", dur.String()), ) - // Persist request log asynchronously (best-effort) - go g.persistRequestLog(r, srw, dur) + // Enqueue log entry for batched persistence (replaces per-request DB writes) + if g.logBatcher != nil { + apiKey := "" + if v := r.Context().Value(ctxKeyAPIKey); v != nil { + if s, ok := v.(string); ok { + apiKey = s + } + } + g.logBatcher.Add(requestLogEntry{ + method: r.Method, + path: r.URL.Path, + statusCode: srw.status, + bytesOut: srw.bytes, + durationMs: dur.Milliseconds(), + ip: getClientIP(r), + apiKey: apiKey, + }) + } }) } @@ -278,7 +306,17 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler { return } - // Look up API key in DB and derive namespace + // Check middleware cache first for API key → namespace mapping + if g.mwCache != nil { + if cachedNS, ok := g.mwCache.GetAPIKeyNamespace(key); ok { + reqCtx := context.WithValue(r.Context(), ctxKeyAPIKey, key) + reqCtx = context.WithValue(reqCtx, CtxKeyNamespaceOverride, cachedNS) + next.ServeHTTP(w, r.WithContext(reqCtx)) + return + } + } + + // Cache miss — look up API key in DB and derive namespace // Use authClient for namespace gateways (validates against global RQLite) // Otherwise use regular client for global gateways authClient := g.client @@ -319,6 +357,11 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler { return } + // Cache the result for subsequent requests + if g.mwCache != nil { + g.mwCache.SetAPIKeyNamespace(key, ns) + } + // Attach auth metadata to context for downstream use reqCtx := context.WithValue(r.Context(), ctxKeyAPIKey, key) reqCtx = context.WithValue(reqCtx, CtxKeyNamespaceOverride, ns) @@ -441,6 +484,18 @@ func (g *Gateway) authorizationMiddleware(next http.Handler) http.Handler { return } + // Skip ownership checks for requests pre-authenticated by the main gateway. + // The main gateway already validated the API key and resolved the namespace + // before proxying, so re-checking ownership against the namespace RQLite is + // redundant and adds ~300ms of unnecessary latency (3 DB round-trips). + if r.Header.Get(HeaderInternalAuthValidated) == "true" { + clientIP := getClientIP(r) + if isInternalIP(clientIP) { + next.ServeHTTP(w, r) + return + } + } + // Cross-namespace access control for namespace gateways // The gateway's ClientNamespace determines which namespace this gateway serves gatewayNamespace := "default" @@ -779,50 +834,72 @@ func (g *Gateway) handleNamespaceGatewayRequest(w http.ResponseWriter, r *http.R return } - // Look up namespace cluster gateway using internal (WireGuard) IPs for inter-node proxying - db := g.client.Database() - internalCtx := client.WithInternalAuth(r.Context()) - - // Query all ready namespace gateways and choose a stable target. - // Random selection causes WS subscribe and publish calls to hit different - // nodes, which makes pubsub delivery flaky for short-lived subscriptions. - query := ` - SELECT COALESCE(dn.internal_ip, dn.ip_address), npa.gateway_http_port - FROM namespace_port_allocations npa - JOIN namespace_clusters nc ON npa.namespace_cluster_id = nc.id - JOIN dns_nodes dn ON npa.node_id = dn.id - WHERE nc.namespace_name = ? AND nc.status = 'ready' - ` - result, err := db.Query(internalCtx, query, namespaceName) - if err != nil || result == nil || len(result.Rows) == 0 { - g.logger.ComponentWarn(logging.ComponentGeneral, "namespace gateway not found", - zap.String("namespace", namespaceName), - ) - http.Error(w, "Namespace gateway not found", http.StatusNotFound) - return - } - + // Check middleware cache for namespace gateway targets type namespaceGatewayTarget struct { ip string port int } - targets := make([]namespaceGatewayTarget, 0, len(result.Rows)) - for _, row := range result.Rows { - if len(row) == 0 { - continue - } - ip := getString(row[0]) - if ip == "" { - continue - } - port := 10004 - if len(row) > 1 { - if p := getInt(row[1]); p > 0 { - port = p + var targets []namespaceGatewayTarget + + if g.mwCache != nil { + if cached, ok := g.mwCache.GetNamespaceTargets(namespaceName); ok { + for _, t := range cached { + targets = append(targets, namespaceGatewayTarget{ip: t.ip, port: t.port}) } } - targets = append(targets, namespaceGatewayTarget{ip: ip, port: port}) } + + // Cache miss — look up namespace cluster gateway from DB + if len(targets) == 0 { + db := g.client.Database() + internalCtx := client.WithInternalAuth(r.Context()) + + // Query all ready namespace gateways and choose a stable target. + // Random selection causes WS subscribe and publish calls to hit different + // nodes, which makes pubsub delivery flaky for short-lived subscriptions. + query := ` + SELECT COALESCE(dn.internal_ip, dn.ip_address), npa.gateway_http_port + FROM namespace_port_allocations npa + JOIN namespace_clusters nc ON npa.namespace_cluster_id = nc.id + JOIN dns_nodes dn ON npa.node_id = dn.id + WHERE nc.namespace_name = ? AND nc.status = 'ready' + ` + result, err := db.Query(internalCtx, query, namespaceName) + if err != nil || result == nil || len(result.Rows) == 0 { + g.logger.ComponentWarn(logging.ComponentGeneral, "namespace gateway not found", + zap.String("namespace", namespaceName), + ) + http.Error(w, "Namespace gateway not found", http.StatusNotFound) + return + } + + for _, row := range result.Rows { + if len(row) == 0 { + continue + } + ip := getString(row[0]) + if ip == "" { + continue + } + port := 10004 + if len(row) > 1 { + if p := getInt(row[1]); p > 0 { + port = p + } + } + targets = append(targets, namespaceGatewayTarget{ip: ip, port: port}) + } + + // Cache the result for subsequent requests + if g.mwCache != nil && len(targets) > 0 { + cacheTargets := make([]gatewayTarget, len(targets)) + for i, t := range targets { + cacheTargets[i] = gatewayTarget{ip: t.ip, port: t.port} + } + g.mwCache.SetNamespaceTargets(namespaceName, cacheTargets) + } + } + if len(targets) == 0 { http.Error(w, "Namespace gateway not available", http.StatusServiceUnavailable) return diff --git a/pkg/gateway/middleware_cache.go b/pkg/gateway/middleware_cache.go new file mode 100644 index 0000000..fab8bcb --- /dev/null +++ b/pkg/gateway/middleware_cache.go @@ -0,0 +1,121 @@ +package gateway + +import ( + "sync" + "time" +) + +// middlewareCache provides in-memory TTL caching for frequently-queried middleware +// data that rarely changes. This eliminates redundant RQLite round-trips for: +// - API key → namespace lookups (authMiddleware, validateAuthForNamespaceProxy) +// - Namespace → gateway targets (handleNamespaceGatewayRequest) +type middlewareCache struct { + // apiKeyToNamespace caches API key → namespace name mappings. + // These rarely change and are looked up on every authenticated request. + apiKeyNS map[string]*cachedValue + apiKeyNSMu sync.RWMutex + + // nsGatewayTargets caches namespace → []gatewayTarget for namespace routing. + // Updated infrequently (only when namespace clusters change). + nsTargets map[string]*cachedGatewayTargets + nsTargetsMu sync.RWMutex + + ttl time.Duration +} + +type cachedValue struct { + value string + expiresAt time.Time +} + +type gatewayTarget struct { + ip string + port int +} + +type cachedGatewayTargets struct { + targets []gatewayTarget + expiresAt time.Time +} + +func newMiddlewareCache(ttl time.Duration) *middlewareCache { + mc := &middlewareCache{ + apiKeyNS: make(map[string]*cachedValue), + nsTargets: make(map[string]*cachedGatewayTargets), + ttl: ttl, + } + go mc.cleanup() + return mc +} + +// GetAPIKeyNamespace returns the cached namespace for an API key, or "" if not cached/expired. +func (mc *middlewareCache) GetAPIKeyNamespace(apiKey string) (string, bool) { + mc.apiKeyNSMu.RLock() + defer mc.apiKeyNSMu.RUnlock() + + entry, ok := mc.apiKeyNS[apiKey] + if !ok || time.Now().After(entry.expiresAt) { + return "", false + } + return entry.value, true +} + +// SetAPIKeyNamespace caches an API key → namespace mapping. +func (mc *middlewareCache) SetAPIKeyNamespace(apiKey, namespace string) { + mc.apiKeyNSMu.Lock() + defer mc.apiKeyNSMu.Unlock() + + mc.apiKeyNS[apiKey] = &cachedValue{ + value: namespace, + expiresAt: time.Now().Add(mc.ttl), + } +} + +// GetNamespaceTargets returns cached gateway targets for a namespace, or nil if not cached/expired. +func (mc *middlewareCache) GetNamespaceTargets(namespace string) ([]gatewayTarget, bool) { + mc.nsTargetsMu.RLock() + defer mc.nsTargetsMu.RUnlock() + + entry, ok := mc.nsTargets[namespace] + if !ok || time.Now().After(entry.expiresAt) { + return nil, false + } + return entry.targets, true +} + +// SetNamespaceTargets caches namespace gateway targets. +func (mc *middlewareCache) SetNamespaceTargets(namespace string, targets []gatewayTarget) { + mc.nsTargetsMu.Lock() + defer mc.nsTargetsMu.Unlock() + + mc.nsTargets[namespace] = &cachedGatewayTargets{ + targets: targets, + expiresAt: time.Now().Add(mc.ttl), + } +} + +// cleanup periodically removes expired entries to prevent memory leaks. +func (mc *middlewareCache) cleanup() { + ticker := time.NewTicker(2 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + now := time.Now() + + mc.apiKeyNSMu.Lock() + for k, v := range mc.apiKeyNS { + if now.After(v.expiresAt) { + delete(mc.apiKeyNS, k) + } + } + mc.apiKeyNSMu.Unlock() + + mc.nsTargetsMu.Lock() + for k, v := range mc.nsTargets { + if now.After(v.expiresAt) { + delete(mc.nsTargets, k) + } + } + mc.nsTargetsMu.Unlock() + } +} diff --git a/pkg/gateway/request_log_batcher.go b/pkg/gateway/request_log_batcher.go new file mode 100644 index 0000000..5c12382 --- /dev/null +++ b/pkg/gateway/request_log_batcher.go @@ -0,0 +1,188 @@ +package gateway + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/DeBrosOfficial/network/pkg/client" + "github.com/DeBrosOfficial/network/pkg/logging" + "go.uber.org/zap" +) + +// requestLogEntry holds a single request log to be batched. +type requestLogEntry struct { + method string + path string + statusCode int + bytesOut int + durationMs int64 + ip string + apiKey string // raw API key (resolved to ID at flush time in batch) +} + +// requestLogBatcher aggregates request logs and flushes them to RQLite in bulk +// instead of issuing 3 DB writes per request (INSERT log + SELECT api_key_id + UPDATE last_used). +type requestLogBatcher struct { + gw *Gateway + entries []requestLogEntry + mu sync.Mutex + interval time.Duration + maxBatch int + stopCh chan struct{} +} + +func newRequestLogBatcher(gw *Gateway, interval time.Duration, maxBatch int) *requestLogBatcher { + b := &requestLogBatcher{ + gw: gw, + entries: make([]requestLogEntry, 0, maxBatch), + interval: interval, + maxBatch: maxBatch, + stopCh: make(chan struct{}), + } + go b.run() + return b +} + +// Add enqueues a log entry. If the buffer is full, it triggers an early flush. +func (b *requestLogBatcher) Add(entry requestLogEntry) { + b.mu.Lock() + b.entries = append(b.entries, entry) + needsFlush := len(b.entries) >= b.maxBatch + b.mu.Unlock() + + if needsFlush { + go b.flush() + } +} + +// run is the background loop that flushes logs periodically. +func (b *requestLogBatcher) run() { + ticker := time.NewTicker(b.interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + b.flush() + case <-b.stopCh: + b.flush() // final flush on stop + return + } + } +} + +// flush writes all buffered log entries to RQLite in a single batch. +func (b *requestLogBatcher) flush() { + b.mu.Lock() + if len(b.entries) == 0 { + b.mu.Unlock() + return + } + batch := b.entries + b.entries = make([]requestLogEntry, 0, b.maxBatch) + b.mu.Unlock() + + if b.gw.client == nil { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + db := b.gw.client.Database() + + // Collect unique API keys that need ID resolution + apiKeySet := make(map[string]struct{}) + for _, e := range batch { + if e.apiKey != "" { + apiKeySet[e.apiKey] = struct{}{} + } + } + + // Batch-resolve API key IDs in a single query + apiKeyIDs := make(map[string]int64) + if len(apiKeySet) > 0 { + keys := make([]string, 0, len(apiKeySet)) + for k := range apiKeySet { + keys = append(keys, k) + } + + placeholders := make([]string, len(keys)) + args := make([]interface{}, len(keys)) + for i, k := range keys { + placeholders[i] = "?" + args[i] = k + } + + q := fmt.Sprintf("SELECT id, key FROM api_keys WHERE key IN (%s)", strings.Join(placeholders, ",")) + res, err := db.Query(client.WithInternalAuth(ctx), q, args...) + if err == nil && res != nil { + for _, row := range res.Rows { + if len(row) >= 2 { + var id int64 + switch v := row[0].(type) { + case float64: + id = int64(v) + case int64: + id = v + } + if key, ok := row[1].(string); ok && id > 0 { + apiKeyIDs[key] = id + } + } + } + } + } + + // Build batch INSERT for request_logs + if len(batch) > 0 { + var sb strings.Builder + sb.WriteString("INSERT INTO request_logs (method, path, status_code, bytes_out, duration_ms, ip, api_key_id) VALUES ") + + args := make([]interface{}, 0, len(batch)*7) + for i, e := range batch { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString("(?, ?, ?, ?, ?, ?, ?)") + + var apiKeyID interface{} = nil + if e.apiKey != "" { + if id, ok := apiKeyIDs[e.apiKey]; ok { + apiKeyID = id + } + } + args = append(args, e.method, e.path, e.statusCode, e.bytesOut, e.durationMs, e.ip, apiKeyID) + } + + _, _ = db.Query(client.WithInternalAuth(ctx), sb.String(), args...) + } + + // Batch UPDATE last_used_at for all API keys seen in this batch + if len(apiKeyIDs) > 0 { + ids := make([]string, 0, len(apiKeyIDs)) + args := make([]interface{}, 0, len(apiKeyIDs)) + for _, id := range apiKeyIDs { + ids = append(ids, "?") + args = append(args, id) + } + + q := fmt.Sprintf("UPDATE api_keys SET last_used_at = CURRENT_TIMESTAMP WHERE id IN (%s)", strings.Join(ids, ",")) + _, _ = db.Query(client.WithInternalAuth(ctx), q, args...) + } + + if b.gw.logger != nil { + b.gw.logger.ComponentDebug(logging.ComponentGeneral, "request logs flushed", + zap.Int("count", len(batch)), + zap.Int("api_keys", len(apiKeyIDs)), + ) + } +} + +// Stop signals the batcher to stop and flush remaining entries. +func (b *requestLogBatcher) Stop() { + close(b.stopCh) +}