From f8de4af704386f07631e9919ccd700e36b30b33f Mon Sep 17 00:00:00 2001 From: anonpenguin23 Date: Tue, 9 Jun 2026 09:23:54 +0300 Subject: [PATCH] feat(sni-router): implement hot-reloading for route configuration - Add `FileRouteReloader` to watch and atomically update routes from disk - Refactor `main` to support seamless configuration updates without restarts - Ensure existing routes are preserved if a reload encounters an error --- core/cmd/sni-router/main.go | 68 ++++++++----- core/pkg/sniproxy/reloader.go | 93 ++++++++++++++++++ core/pkg/sniproxy/reloader_test.go | 143 ++++++++++++++++++++++++++++ core/pkg/turn/cert_reloader.go | 105 ++++++++++++++++++++ core/pkg/turn/cert_reloader_test.go | 110 +++++++++++++++++++++ core/pkg/turn/server.go | 31 ++++-- 6 files changed, 520 insertions(+), 30 deletions(-) create mode 100644 core/pkg/sniproxy/reloader.go create mode 100644 core/pkg/sniproxy/reloader_test.go create mode 100644 core/pkg/turn/cert_reloader.go create mode 100644 core/pkg/turn/cert_reloader_test.go diff --git a/core/cmd/sni-router/main.go b/core/cmd/sni-router/main.go index cc727df..32ae84b 100644 --- a/core/cmd/sni-router/main.go +++ b/core/cmd/sni-router/main.go @@ -90,10 +90,29 @@ func main() { zap.String("version", version), zap.String("commit", commit)) - cfg := parseConfig(logger) + cfg, configPath := parseConfig(logger) router := sniproxy.NewRouter(toBackend(cfg.Fallback)) - router.Replace(toRoutes(cfg.Routes), toBackend(cfg.Fallback)) + + // Hot-reload the route table from the config file so a namespace's + // cdn/turn SNI routes can be added or removed without restarting the + // router (Router.Replace swaps atomically under in-flight connections). + reloader := sniproxy.NewFileRouteReloader(configPath, + func() ([]sniproxy.Route, sniproxy.Backend, error) { + y, err := loadConfig(configPath) + if err != nil { + return nil, sniproxy.Backend{}, err + } + return toRoutes(y.Routes), toBackend(y.Fallback), nil + }, router, logger.Logger) + if err := reloader.Apply(); err != nil { + logger.ComponentError(logging.ComponentSNI, "Failed to install initial routes", + zap.Error(err)) + os.Exit(1) + } + routeStop := make(chan struct{}) + defer close(routeStop) + go reloader.Watch(sniproxy.DefaultRouteReloadInterval, routeStop) srv := sniproxy.NewServer(router, sniproxy.Config{ ClientHelloTimeout: cfg.ClientHelloTimeout, @@ -140,7 +159,7 @@ func main() { logger.ComponentInfo(logging.ComponentSNI, "SNI router shutdown complete") } -func parseConfig(logger *logging.ColoredLogger) yamlConfig { +func parseConfig(logger *logging.ColoredLogger) (yamlConfig, string) { configFlag := flag.String("config", "", "Config file path (absolute or filename in ~/.orama)") flag.Parse() @@ -166,28 +185,11 @@ func parseConfig(logger *logging.ColoredLogger) yamlConfig { } } - data, err := os.ReadFile(configPath) + y, err := loadConfig(configPath) if err != nil { - logger.ComponentError(logging.ComponentSNI, "Config file not found", + logger.ComponentError(logging.ComponentSNI, "Failed to load SNI router config", zap.String("path", configPath), zap.Error(err)) - fmt.Fprintf(os.Stderr, "\nConfig file not found at %s\n", configPath) - os.Exit(1) - } - - var y yamlConfig - if err := config.DecodeStrict(strings.NewReader(string(data)), &y); err != nil { - logger.ComponentError(logging.ComponentSNI, "Failed to parse SNI router config", - zap.Error(err)) - fmt.Fprintf(os.Stderr, "Configuration parse error: %v\n", err) - os.Exit(1) - } - - if errs := validateConfig(&y); len(errs) > 0 { - fmt.Fprintf(os.Stderr, "\nSNI router configuration errors (%d):\n", len(errs)) - for _, e := range errs { - fmt.Fprintf(os.Stderr, " - %s\n", e) - } - fmt.Fprintf(os.Stderr, "\nPlease fix the configuration and try again.\n") + fmt.Fprintf(os.Stderr, "\nSNI router configuration error: %v\n", err) os.Exit(1) } @@ -195,7 +197,25 @@ func parseConfig(logger *logging.ColoredLogger) yamlConfig { zap.String("path", configPath), ) - return y + return y, configPath +} + +// loadConfig reads, decodes, and validates the SNI router config file. Shared +// by the initial parse and every hot-reload, so it returns an error instead of +// exiting the process. +func loadConfig(path string) (yamlConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return yamlConfig{}, fmt.Errorf("read config %s: %w", path, err) + } + var y yamlConfig + if err := config.DecodeStrict(strings.NewReader(string(data)), &y); err != nil { + return yamlConfig{}, fmt.Errorf("parse config: %w", err) + } + if errs := validateConfig(&y); len(errs) > 0 { + return yamlConfig{}, fmt.Errorf("invalid config: %s", strings.Join(errs, "; ")) + } + return y, nil } // validateConfig returns a non-empty slice of human-readable errors on misconfig. diff --git a/core/pkg/sniproxy/reloader.go b/core/pkg/sniproxy/reloader.go new file mode 100644 index 0000000..8cabec4 --- /dev/null +++ b/core/pkg/sniproxy/reloader.go @@ -0,0 +1,93 @@ +package sniproxy + +import ( + "os" + "time" + + "go.uber.org/zap" +) + +// DefaultRouteReloadInterval is the default poll cadence for a FileRouteReloader. +// SNI route changes (a namespace enabling/disabling the stealth-TURN path) are +// infrequent, so 30s of detection latency is fine — and polling keeps the +// dependency surface minimal (no fsnotify), matching the TURNS cert reloader. +const DefaultRouteReloadInterval = 30 * time.Second + +// RouteSource produces the current route table + fallback backend. It returns +// an error when the underlying source (e.g. the YAML config file) is missing or +// invalid; on error the reloader KEEPS the routes already installed in the +// Router rather than dropping traffic for a bad edit. +type RouteSource func() (routes []Route, fallback Backend, err error) + +// FileRouteReloader watches a config file's mtime and re-applies its routes to +// a Router when it changes — so the SNI route table can be updated (e.g. a new +// namespace's cdn/turn routes added) WITHOUT restarting the router. The +// Router's Replace swaps the table atomically while connections are in flight, +// so reloads are seamless. Mirrors the TURNS cert hot-reload pattern. +// +// modTime is only ever touched by the goroutine running Watch (after the +// synchronous startup Apply), so it needs no lock; the routes themselves live +// behind the Router's own mutex. +type FileRouteReloader struct { + path string + source RouteSource + router *Router + logger *zap.Logger + modTime time.Time +} + +// NewFileRouteReloader creates a reloader. source must read/parse the file at +// path; router receives the Replace calls. +func NewFileRouteReloader(path string, source RouteSource, router *Router, logger *zap.Logger) *FileRouteReloader { + if logger == nil { + logger = zap.NewNop() + } + return &FileRouteReloader{path: path, source: source, router: router, logger: logger} +} + +// Apply loads the routes from the source and atomically installs them in the +// Router, recording the config file's mtime. On a source error it returns the +// error and leaves the Router untouched. +func (r *FileRouteReloader) Apply() error { + routes, fallback, err := r.source() + if err != nil { + return err + } + r.router.Replace(routes, fallback) + if fi, statErr := os.Stat(r.path); statErr == nil { + r.modTime = fi.ModTime() + } + return nil +} + +// Watch polls the config file's mtime every interval and re-applies the routes +// when it advances. Blocks until stop is closed. A failed reload logs a warning +// and keeps the currently-installed routes (a bad edit must not blackhole +// traffic). +func (r *FileRouteReloader) Watch(interval time.Duration, stop <-chan struct{}) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-stop: + return + case <-ticker.C: + fi, err := os.Stat(r.path) + if err != nil { + // File briefly absent during an atomic rename — retry next tick. + continue + } + if !fi.ModTime().After(r.modTime) { + continue + } + if err := r.Apply(); err != nil { + r.logger.Warn("SNI route reload failed; keeping current routes", + zap.String("config_path", r.path), zap.Error(err)) + continue + } + r.logger.Info("SNI routes hot-reloaded", + zap.String("config_path", r.path), + zap.Int("routes", len(r.router.Routes()))) + } + } +} diff --git a/core/pkg/sniproxy/reloader_test.go b/core/pkg/sniproxy/reloader_test.go new file mode 100644 index 0000000..d7dc1d6 --- /dev/null +++ b/core/pkg/sniproxy/reloader_test.go @@ -0,0 +1,143 @@ +package sniproxy + +import ( + "errors" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "go.uber.org/zap" +) + +// feat-41: the SNI router hot-reloads its route table from disk so a namespace's +// cdn/turn routes can be added/removed without restarting the router. These pin +// the initial apply, the hot-reload-on-change path, and the resilience contract +// (a bad source keeps the currently-installed routes serving). + +func writeFile(t *testing.T, dir, name, content string) string { + t.Helper() + p := filepath.Join(dir, name) + if err := os.WriteFile(p, []byte(content), 0o644); err != nil { + t.Fatalf("write %s: %v", p, err) + } + return p +} + +func TestFileRouteReloader_appliesInitialRoutes(t *testing.T) { + path := writeFile(t, t.TempDir(), "routes.yaml", "v1") + source := func() ([]Route, Backend, error) { + return []Route{ + {Match: "cdn.ns-a.example.com", Backend: Backend{Addr: "127.0.0.1:5349"}}, + }, Backend{Addr: "127.0.0.1:8443"}, nil + } + router := NewRouter(Backend{Addr: "unset"}) + r := NewFileRouteReloader(path, source, router, zap.NewNop()) + + if err := r.Apply(); err != nil { + t.Fatalf("Apply: %v", err) + } + if got := len(router.Routes()); got != 1 { + t.Fatalf("want 1 route after initial apply, got %d", got) + } + if b := router.Pick("cdn.ns-a.example.com"); b.Addr != "127.0.0.1:5349" { + t.Errorf("route not installed; Pick gave %q", b.Addr) + } + if router.Fallback().Addr != "127.0.0.1:8443" { + t.Errorf("fallback not installed; got %q", router.Fallback().Addr) + } +} + +func TestFileRouteReloader_hotReloadsOnFileChange(t *testing.T) { + path := writeFile(t, t.TempDir(), "routes.yaml", "v1") + + var mu sync.Mutex + version := 1 + source := func() ([]Route, Backend, error) { + mu.Lock() + defer mu.Unlock() + if version == 1 { + return []Route{{Match: "a.example.com", Backend: Backend{Addr: "127.0.0.1:1"}}}, + Backend{Addr: "fb:1"}, nil + } + return []Route{ + {Match: "a.example.com", Backend: Backend{Addr: "127.0.0.1:1"}}, + {Match: "b.example.com", Backend: Backend{Addr: "127.0.0.1:2"}}, + }, Backend{Addr: "fb:2"}, nil + } + router := NewRouter(Backend{Addr: "unset"}) + r := NewFileRouteReloader(path, source, router, zap.NewNop()) + if err := r.Apply(); err != nil { + t.Fatalf("initial Apply: %v", err) + } + if len(router.Routes()) != 1 { + t.Fatalf("want 1 route initially, got %d", len(router.Routes())) + } + + // "Renew": flip the source to v2 and advance the file mtime so the watcher + // detects the change regardless of filesystem timestamp granularity. + mu.Lock() + version = 2 + mu.Unlock() + future := time.Now().Add(2 * time.Second) + if err := os.Chtimes(path, future, future); err != nil { + t.Fatalf("chtimes: %v", err) + } + + stop := make(chan struct{}) + defer close(stop) + go r.Watch(5*time.Millisecond, stop) + + deadline := time.Now().Add(3 * time.Second) + for time.Now().Before(deadline) { + if len(router.Routes()) == 2 && router.Fallback().Addr == "fb:2" { + return // hot-reloaded + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("routes were not hot-reloaded (have %d routes, fallback %q)", + len(router.Routes()), router.Fallback().Addr) +} + +func TestFileRouteReloader_keepsRoutesOnSourceError(t *testing.T) { + path := writeFile(t, t.TempDir(), "routes.yaml", "v1") + + var mu sync.Mutex + fail := false + source := func() ([]Route, Backend, error) { + mu.Lock() + defer mu.Unlock() + if fail { + return nil, Backend{}, errors.New("invalid config") + } + return []Route{{Match: "a.example.com", Backend: Backend{Addr: "127.0.0.1:1"}}}, + Backend{Addr: "fb:1"}, nil + } + router := NewRouter(Backend{Addr: "unset"}) + r := NewFileRouteReloader(path, source, router, zap.NewNop()) + if err := r.Apply(); err != nil { + t.Fatalf("initial Apply: %v", err) + } + + // Make the source fail, then trigger a reload via an mtime bump. + mu.Lock() + fail = true + mu.Unlock() + future := time.Now().Add(2 * time.Second) + if err := os.Chtimes(path, future, future); err != nil { + t.Fatalf("chtimes: %v", err) + } + + stop := make(chan struct{}) + go r.Watch(5*time.Millisecond, stop) + time.Sleep(200 * time.Millisecond) // let it tick + hit the failing source + close(stop) + + if got := len(router.Routes()); got != 1 { + t.Errorf("a failed reload must keep the previous routes; got %d routes", got) + } + if router.Fallback().Addr != "fb:1" { + t.Errorf("a failed reload must keep the previous fallback; got %q", router.Fallback().Addr) + } +} diff --git a/core/pkg/turn/cert_reloader.go b/core/pkg/turn/cert_reloader.go new file mode 100644 index 0000000..6602144 --- /dev/null +++ b/core/pkg/turn/cert_reloader.go @@ -0,0 +1,105 @@ +package turn + +import ( + "crypto/tls" + "fmt" + "os" + "sync" + "time" + + "go.uber.org/zap" +) + +// turnCertReloadInterval is how often the TURNS certificate file is polled for +// changes. TLS cert renewals (Caddy DNS-01 for cdn.) happen on the +// order of weeks, so a minute of detection latency is irrelevant; polling keeps +// the dependency surface minimal (no fsnotify) and is robust across the +// atomic-rename pattern certbot/Caddy use when writing a renewed cert. +const turnCertReloadInterval = 60 * time.Second + +// certReloader serves the current TURNS certificate through a tls.Config +// GetCertificate callback and hot-reloads it when the cert file changes on +// disk. This lets a Caddy-renewed certificate be picked up WITHOUT restarting +// the TURN server — a restart would tear down every active relay (~30s RTC +// drop for users mid-call). See plans/platform/04_STEALTH_TURN.md, the +// "cert renewal during cutover" note. +type certReloader struct { + certPath string + keyPath string + logger *zap.Logger + + mu sync.RWMutex + cert *tls.Certificate + modTime time.Time +} + +// newCertReloader loads the initial cert/key pair. Returns an error if the +// initial load fails — TURNS cannot start without a valid certificate. +func newCertReloader(certPath, keyPath string, logger *zap.Logger) (*certReloader, error) { + r := &certReloader{certPath: certPath, keyPath: keyPath, logger: logger} + if err := r.reload(); err != nil { + return nil, err + } + return r, nil +} + +// reload reads the cert/key pair from disk and atomically swaps it in. On +// failure it leaves the previously-loaded certificate in place: a renewal that +// momentarily presents a half-written or mismatched cert/key file must never +// take TURNS down — the old (still-valid) cert keeps serving until the next +// successful reload. +func (r *certReloader) reload() error { + cert, err := tls.LoadX509KeyPair(r.certPath, r.keyPath) + if err != nil { + return fmt.Errorf("load TURNS cert/key (%s): %w", r.certPath, err) + } + var mod time.Time + if fi, statErr := os.Stat(r.certPath); statErr == nil { + mod = fi.ModTime() + } + r.mu.Lock() + r.cert = &cert + r.modTime = mod + r.mu.Unlock() + return nil +} + +// GetCertificate is the tls.Config.GetCertificate callback. It always returns +// the most recently loaded certificate, so every new TLS handshake uses the +// current cert without the listener being recreated. +func (r *certReloader) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + r.mu.RLock() + defer r.mu.RUnlock() + return r.cert, nil +} + +// watch polls the cert file's mtime every interval and reloads when it advances. +// Blocks until stop is closed. +func (r *certReloader) watch(interval time.Duration, stop <-chan struct{}) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-stop: + return + case <-ticker.C: + fi, err := os.Stat(r.certPath) + if err != nil { + // File briefly absent during an atomic rename — retry next tick. + continue + } + r.mu.RLock() + unchanged := !fi.ModTime().After(r.modTime) + r.mu.RUnlock() + if unchanged { + continue + } + if err := r.reload(); err != nil { + r.logger.Warn("TURNS cert reload failed; keeping previous certificate", + zap.String("cert_path", r.certPath), zap.Error(err)) + continue + } + r.logger.Info("TURNS cert hot-reloaded", zap.String("cert_path", r.certPath)) + } + } +} diff --git a/core/pkg/turn/cert_reloader_test.go b/core/pkg/turn/cert_reloader_test.go new file mode 100644 index 0000000..0b4d4f9 --- /dev/null +++ b/core/pkg/turn/cert_reloader_test.go @@ -0,0 +1,110 @@ +package turn + +import ( + "bytes" + "os" + "path/filepath" + "testing" + "time" + + "go.uber.org/zap" +) + +// feat-41: TURNS cert hot-reload lets a Caddy-renewed certificate be picked up +// without restarting the TURN server (a restart drops every active relay). These +// pin: initial load, in-process reload when the file changes, resilience (a bad +// reload keeps the previous cert serving), and the missing-file failure. + +func writeTestCert(t *testing.T, dir string) (certPath, keyPath string) { + t.Helper() + certPath = filepath.Join(dir, "cert.pem") + keyPath = filepath.Join(dir, "key.pem") + if err := GenerateSelfSignedCert(certPath, keyPath, "127.0.0.1"); err != nil { + t.Fatalf("GenerateSelfSignedCert: %v", err) + } + return certPath, keyPath +} + +func leafDER(t *testing.T, r *certReloader) []byte { + t.Helper() + c, err := r.GetCertificate(nil) + if err != nil { + t.Fatalf("GetCertificate: %v", err) + } + if c == nil || len(c.Certificate) == 0 { + t.Fatal("GetCertificate returned an empty certificate") + } + return c.Certificate[0] +} + +func TestNewCertReloader_failsOnMissingFiles(t *testing.T) { + if _, err := newCertReloader("/no/such/cert.pem", "/no/such/key.pem", zap.NewNop()); err == nil { + t.Fatal("expected an error when the cert/key files do not exist") + } +} + +func TestCertReloader_loadsAndServesCert(t *testing.T) { + certPath, keyPath := writeTestCert(t, t.TempDir()) + r, err := newCertReloader(certPath, keyPath, zap.NewNop()) + if err != nil { + t.Fatalf("newCertReloader: %v", err) + } + if got := leafDER(t, r); len(got) == 0 { + t.Fatal("served certificate has no leaf") + } +} + +func TestCertReloader_hotReloadsOnFileChange(t *testing.T) { + dir := t.TempDir() + certPath, keyPath := writeTestCert(t, dir) + r, err := newCertReloader(certPath, keyPath, zap.NewNop()) + if err != nil { + t.Fatalf("newCertReloader: %v", err) + } + before := leafDER(t, r) + + // Renew: overwrite with a freshly-generated cert/key pair (different serial + // + key → different leaf) and advance the mtime so the watcher detects it. + if err := GenerateSelfSignedCert(certPath, keyPath, "127.0.0.1"); err != nil { + t.Fatalf("regenerate cert: %v", err) + } + future := time.Now().Add(2 * time.Second) + if err := os.Chtimes(certPath, future, future); err != nil { + t.Fatalf("chtimes: %v", err) + } + + stop := make(chan struct{}) + defer close(stop) + go r.watch(5*time.Millisecond, stop) + + deadline := time.Now().Add(3 * time.Second) + for time.Now().Before(deadline) { + if !bytes.Equal(leafDER(t, r), before) { + return // hot-reloaded — the served cert changed + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("certificate was not hot-reloaded after the file changed") +} + +func TestCertReloader_keepsOldCertOnReloadError(t *testing.T) { + certPath, keyPath := writeTestCert(t, t.TempDir()) + r, err := newCertReloader(certPath, keyPath, zap.NewNop()) + if err != nil { + t.Fatalf("newCertReloader: %v", err) + } + before := leafDER(t, r) + + // Corrupt the cert file (simulates a half-written renewal). + if err := os.WriteFile(certPath, []byte("not a pem cert"), 0o644); err != nil { + t.Fatalf("corrupt cert: %v", err) + } + if err := r.reload(); err == nil { + t.Fatal("expected reload to fail on a corrupt cert file") + } + + // The previously-loaded cert must still be served (TURNS must not go down). + if got := leafDER(t, r); !bytes.Equal(got, before) { + t.Error("a failed reload must keep serving the previous certificate") + } +} diff --git a/core/pkg/turn/server.go b/core/pkg/turn/server.go index c80a2f9..d80f361 100644 --- a/core/pkg/turn/server.go +++ b/core/pkg/turn/server.go @@ -23,6 +23,9 @@ type Server struct { conn net.PacketConn // UDP listener on primary port (3478) tcpListener net.Listener // Plain TCP listener on primary port (3478) tlsListener net.Listener // TLS TCP listener for TURNS (port 5349) + + certReloader *certReloader // hot-reloads the TURNS cert; nil when TURNS disabled + certStop chan struct{} // closed to stop the cert-reload watcher goroutine } // NewServer creates and starts a TURN server. @@ -79,23 +82,31 @@ func NewServer(cfg *Config, logger *zap.Logger) (*Server, error) { }, }) - // TURNS: TLS over TCP listener (port 5349) if configured + // TURNS: TLS over TCP listener (port 5349) if configured. + // + // The cert is served via a hot-reloading GetCertificate callback rather + // than a static Certificates slice, so a Caddy-renewed cert is picked up + // in-process without restarting TURN (a restart drops every active relay + // ~30s). See certReloader / plans/platform/04_STEALTH_TURN.md. if cfg.TURNSListenAddr != "" && cfg.TLSCertPath != "" && cfg.TLSKeyPath != "" { - cert, err := tls.LoadX509KeyPair(cfg.TLSCertPath, cfg.TLSKeyPath) + reloader, err := newCertReloader(cfg.TLSCertPath, cfg.TLSKeyPath, s.logger) if err != nil { - conn.Close() + s.closeListeners() return nil, fmt.Errorf("failed to load TLS cert/key: %w", err) } tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS12, + GetCertificate: reloader.GetCertificate, + MinVersion: tls.VersionTLS12, } tlsListener, err := tls.Listen("tcp", cfg.TURNSListenAddr, tlsConfig) if err != nil { - conn.Close() + s.closeListeners() return nil, fmt.Errorf("failed to listen on %s: %w", cfg.TURNSListenAddr, err) } s.tlsListener = tlsListener + s.certReloader = reloader + s.certStop = make(chan struct{}) + go reloader.watch(turnCertReloadInterval, s.certStop) listenerConfigs = append(listenerConfigs, pionTurn.ListenerConfig{ Listener: tlsListener, @@ -207,7 +218,15 @@ func (s *Server) Close() error { return nil } +// closeListeners stops the cert watcher and closes all listeners. It is +// idempotent (every field is nil-guarded and nil'd after use) but is NOT +// mutex-protected — it relies on its call sites being single-threaded relative +// to each other (sequential construction, plus a single Close() from main). func (s *Server) closeListeners() { + if s.certStop != nil { + close(s.certStop) + s.certStop = nil + } if s.conn != nil { s.conn.Close() s.conn = nil