mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-06-16 21:54:14 +00:00
feat(sni-router): implement hot-reloading for route configuration
- Add `FileRouteReloader` to watch the route configuration file and atomically update routes from disk.
- Refactor `main` to support seamless configuration updates without restarts.
- Ensure existing routes are preserved if a reload encounters an error.
This commit is contained in:
parent
32f7b3824e
commit
f8de4af704
@ -90,10 +90,29 @@ func main() {
|
||||
zap.String("version", version),
|
||||
zap.String("commit", commit))
|
||||
|
||||
cfg := parseConfig(logger)
|
||||
cfg, configPath := parseConfig(logger)
|
||||
|
||||
router := sniproxy.NewRouter(toBackend(cfg.Fallback))
|
||||
router.Replace(toRoutes(cfg.Routes), toBackend(cfg.Fallback))
|
||||
|
||||
// Hot-reload the route table from the config file so a namespace's
|
||||
// cdn/turn SNI routes can be added or removed without restarting the
|
||||
// router (Router.Replace swaps atomically under in-flight connections).
|
||||
reloader := sniproxy.NewFileRouteReloader(configPath,
|
||||
func() ([]sniproxy.Route, sniproxy.Backend, error) {
|
||||
y, err := loadConfig(configPath)
|
||||
if err != nil {
|
||||
return nil, sniproxy.Backend{}, err
|
||||
}
|
||||
return toRoutes(y.Routes), toBackend(y.Fallback), nil
|
||||
}, router, logger.Logger)
|
||||
if err := reloader.Apply(); err != nil {
|
||||
logger.ComponentError(logging.ComponentSNI, "Failed to install initial routes",
|
||||
zap.Error(err))
|
||||
os.Exit(1)
|
||||
}
|
||||
routeStop := make(chan struct{})
|
||||
defer close(routeStop)
|
||||
go reloader.Watch(sniproxy.DefaultRouteReloadInterval, routeStop)
|
||||
|
||||
srv := sniproxy.NewServer(router, sniproxy.Config{
|
||||
ClientHelloTimeout: cfg.ClientHelloTimeout,
|
||||
@ -140,7 +159,7 @@ func main() {
|
||||
logger.ComponentInfo(logging.ComponentSNI, "SNI router shutdown complete")
|
||||
}
|
||||
|
||||
func parseConfig(logger *logging.ColoredLogger) yamlConfig {
|
||||
func parseConfig(logger *logging.ColoredLogger) (yamlConfig, string) {
|
||||
configFlag := flag.String("config", "", "Config file path (absolute or filename in ~/.orama)")
|
||||
flag.Parse()
|
||||
|
||||
@ -166,28 +185,11 @@ func parseConfig(logger *logging.ColoredLogger) yamlConfig {
|
||||
}
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
y, err := loadConfig(configPath)
|
||||
if err != nil {
|
||||
logger.ComponentError(logging.ComponentSNI, "Config file not found",
|
||||
logger.ComponentError(logging.ComponentSNI, "Failed to load SNI router config",
|
||||
zap.String("path", configPath), zap.Error(err))
|
||||
fmt.Fprintf(os.Stderr, "\nConfig file not found at %s\n", configPath)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var y yamlConfig
|
||||
if err := config.DecodeStrict(strings.NewReader(string(data)), &y); err != nil {
|
||||
logger.ComponentError(logging.ComponentSNI, "Failed to parse SNI router config",
|
||||
zap.Error(err))
|
||||
fmt.Fprintf(os.Stderr, "Configuration parse error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if errs := validateConfig(&y); len(errs) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\nSNI router configuration errors (%d):\n", len(errs))
|
||||
for _, e := range errs {
|
||||
fmt.Fprintf(os.Stderr, " - %s\n", e)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nPlease fix the configuration and try again.\n")
|
||||
fmt.Fprintf(os.Stderr, "\nSNI router configuration error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@ -195,7 +197,25 @@ func parseConfig(logger *logging.ColoredLogger) yamlConfig {
|
||||
zap.String("path", configPath),
|
||||
)
|
||||
|
||||
return y
|
||||
return y, configPath
|
||||
}
|
||||
|
||||
// loadConfig reads, decodes, and validates the SNI router config file. Shared
|
||||
// by the initial parse and every hot-reload, so it returns an error instead of
|
||||
// exiting the process.
|
||||
func loadConfig(path string) (yamlConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return yamlConfig{}, fmt.Errorf("read config %s: %w", path, err)
|
||||
}
|
||||
var y yamlConfig
|
||||
if err := config.DecodeStrict(strings.NewReader(string(data)), &y); err != nil {
|
||||
return yamlConfig{}, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
if errs := validateConfig(&y); len(errs) > 0 {
|
||||
return yamlConfig{}, fmt.Errorf("invalid config: %s", strings.Join(errs, "; "))
|
||||
}
|
||||
return y, nil
|
||||
}
|
||||
|
||||
// validateConfig returns a non-empty slice of human-readable errors on misconfig.
|
||||
|
||||
93
core/pkg/sniproxy/reloader.go
Normal file
93
core/pkg/sniproxy/reloader.go
Normal file
@ -0,0 +1,93 @@
|
||||
package sniproxy
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// DefaultRouteReloadInterval is the default poll cadence for a
// FileRouteReloader's Watch loop.
//
// SNI route changes (a namespace enabling/disabling the stealth-TURN path) are
// infrequent, so 30s of detection latency is fine — and polling keeps the
// dependency surface minimal (no fsnotify), matching the TURNS cert reloader.
const DefaultRouteReloadInterval = 30 * time.Second
|
||||
|
||||
// RouteSource produces the current route table and fallback backend.
//
// It returns an error when the underlying source (e.g. the YAML config file)
// is missing or invalid. On error the reloader KEEPS the routes already
// installed in the Router rather than dropping traffic for a bad edit, so the
// routes/fallback values returned alongside a non-nil error are ignored.
type RouteSource func() (routes []Route, fallback Backend, err error)
|
||||
|
||||
// FileRouteReloader watches a config file's mtime and re-applies its routes to
// a Router when it changes — so the SNI route table can be updated (e.g. a new
// namespace's cdn/turn routes added) WITHOUT restarting the router. The
// Router's Replace swaps the table atomically while connections are in flight,
// so reloads are seamless. Mirrors the TURNS cert hot-reload pattern.
//
// modTime is only ever touched by the goroutine running Watch (after the
// synchronous startup Apply), so it needs no lock; the routes themselves live
// behind the Router's own mutex.
type FileRouteReloader struct {
	path    string      // config file whose mtime is polled by Watch
	source  RouteSource // loads the routes + fallback; invoked by Apply
	router  *Router     // receives Replace on every successful Apply
	logger  *zap.Logger // never nil when built via NewFileRouteReloader
	modTime time.Time   // file mtime recorded by the last successful Apply
}
|
||||
|
||||
// NewFileRouteReloader creates a reloader. source must read/parse the file at
|
||||
// path; router receives the Replace calls.
|
||||
func NewFileRouteReloader(path string, source RouteSource, router *Router, logger *zap.Logger) *FileRouteReloader {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
return &FileRouteReloader{path: path, source: source, router: router, logger: logger}
|
||||
}
|
||||
|
||||
// Apply loads the routes from the source and atomically installs them in the
|
||||
// Router, recording the config file's mtime. On a source error it returns the
|
||||
// error and leaves the Router untouched.
|
||||
func (r *FileRouteReloader) Apply() error {
|
||||
routes, fallback, err := r.source()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.router.Replace(routes, fallback)
|
||||
if fi, statErr := os.Stat(r.path); statErr == nil {
|
||||
r.modTime = fi.ModTime()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Watch polls the config file's mtime every interval and re-applies the routes
|
||||
// when it advances. Blocks until stop is closed. A failed reload logs a warning
|
||||
// and keeps the currently-installed routes (a bad edit must not blackhole
|
||||
// traffic).
|
||||
func (r *FileRouteReloader) Watch(interval time.Duration, stop <-chan struct{}) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
fi, err := os.Stat(r.path)
|
||||
if err != nil {
|
||||
// File briefly absent during an atomic rename — retry next tick.
|
||||
continue
|
||||
}
|
||||
if !fi.ModTime().After(r.modTime) {
|
||||
continue
|
||||
}
|
||||
if err := r.Apply(); err != nil {
|
||||
r.logger.Warn("SNI route reload failed; keeping current routes",
|
||||
zap.String("config_path", r.path), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
r.logger.Info("SNI routes hot-reloaded",
|
||||
zap.String("config_path", r.path),
|
||||
zap.Int("routes", len(r.router.Routes())))
|
||||
}
|
||||
}
|
||||
}
|
||||
143
core/pkg/sniproxy/reloader_test.go
Normal file
143
core/pkg/sniproxy/reloader_test.go
Normal file
@ -0,0 +1,143 @@
|
||||
package sniproxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// feat-41: the SNI router hot-reloads its route table from disk so a namespace's
|
||||
// cdn/turn routes can be added/removed without restarting the router. These pin
|
||||
// the initial apply, the hot-reload-on-change path, and the resilience contract
|
||||
// (a bad source keeps the currently-installed routes serving).
|
||||
|
||||
// writeFile creates (or overwrites) dir/name with content using mode 0644 and
// returns the full path. Any write error fails the calling test immediately.
func writeFile(t *testing.T, dir, name, content string) string {
	t.Helper()

	target := filepath.Join(dir, name)
	err := os.WriteFile(target, []byte(content), 0o644)
	if err != nil {
		t.Fatalf("write %s: %v", target, err)
	}
	return target
}
|
||||
|
||||
func TestFileRouteReloader_appliesInitialRoutes(t *testing.T) {
|
||||
path := writeFile(t, t.TempDir(), "routes.yaml", "v1")
|
||||
source := func() ([]Route, Backend, error) {
|
||||
return []Route{
|
||||
{Match: "cdn.ns-a.example.com", Backend: Backend{Addr: "127.0.0.1:5349"}},
|
||||
}, Backend{Addr: "127.0.0.1:8443"}, nil
|
||||
}
|
||||
router := NewRouter(Backend{Addr: "unset"})
|
||||
r := NewFileRouteReloader(path, source, router, zap.NewNop())
|
||||
|
||||
if err := r.Apply(); err != nil {
|
||||
t.Fatalf("Apply: %v", err)
|
||||
}
|
||||
if got := len(router.Routes()); got != 1 {
|
||||
t.Fatalf("want 1 route after initial apply, got %d", got)
|
||||
}
|
||||
if b := router.Pick("cdn.ns-a.example.com"); b.Addr != "127.0.0.1:5349" {
|
||||
t.Errorf("route not installed; Pick gave %q", b.Addr)
|
||||
}
|
||||
if router.Fallback().Addr != "127.0.0.1:8443" {
|
||||
t.Errorf("fallback not installed; got %q", router.Fallback().Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileRouteReloader_hotReloadsOnFileChange(t *testing.T) {
|
||||
path := writeFile(t, t.TempDir(), "routes.yaml", "v1")
|
||||
|
||||
var mu sync.Mutex
|
||||
version := 1
|
||||
source := func() ([]Route, Backend, error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if version == 1 {
|
||||
return []Route{{Match: "a.example.com", Backend: Backend{Addr: "127.0.0.1:1"}}},
|
||||
Backend{Addr: "fb:1"}, nil
|
||||
}
|
||||
return []Route{
|
||||
{Match: "a.example.com", Backend: Backend{Addr: "127.0.0.1:1"}},
|
||||
{Match: "b.example.com", Backend: Backend{Addr: "127.0.0.1:2"}},
|
||||
}, Backend{Addr: "fb:2"}, nil
|
||||
}
|
||||
router := NewRouter(Backend{Addr: "unset"})
|
||||
r := NewFileRouteReloader(path, source, router, zap.NewNop())
|
||||
if err := r.Apply(); err != nil {
|
||||
t.Fatalf("initial Apply: %v", err)
|
||||
}
|
||||
if len(router.Routes()) != 1 {
|
||||
t.Fatalf("want 1 route initially, got %d", len(router.Routes()))
|
||||
}
|
||||
|
||||
// "Renew": flip the source to v2 and advance the file mtime so the watcher
|
||||
// detects the change regardless of filesystem timestamp granularity.
|
||||
mu.Lock()
|
||||
version = 2
|
||||
mu.Unlock()
|
||||
future := time.Now().Add(2 * time.Second)
|
||||
if err := os.Chtimes(path, future, future); err != nil {
|
||||
t.Fatalf("chtimes: %v", err)
|
||||
}
|
||||
|
||||
stop := make(chan struct{})
|
||||
defer close(stop)
|
||||
go r.Watch(5*time.Millisecond, stop)
|
||||
|
||||
deadline := time.Now().Add(3 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if len(router.Routes()) == 2 && router.Fallback().Addr == "fb:2" {
|
||||
return // hot-reloaded
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("routes were not hot-reloaded (have %d routes, fallback %q)",
|
||||
len(router.Routes()), router.Fallback().Addr)
|
||||
}
|
||||
|
||||
func TestFileRouteReloader_keepsRoutesOnSourceError(t *testing.T) {
|
||||
path := writeFile(t, t.TempDir(), "routes.yaml", "v1")
|
||||
|
||||
var mu sync.Mutex
|
||||
fail := false
|
||||
source := func() ([]Route, Backend, error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if fail {
|
||||
return nil, Backend{}, errors.New("invalid config")
|
||||
}
|
||||
return []Route{{Match: "a.example.com", Backend: Backend{Addr: "127.0.0.1:1"}}},
|
||||
Backend{Addr: "fb:1"}, nil
|
||||
}
|
||||
router := NewRouter(Backend{Addr: "unset"})
|
||||
r := NewFileRouteReloader(path, source, router, zap.NewNop())
|
||||
if err := r.Apply(); err != nil {
|
||||
t.Fatalf("initial Apply: %v", err)
|
||||
}
|
||||
|
||||
// Make the source fail, then trigger a reload via an mtime bump.
|
||||
mu.Lock()
|
||||
fail = true
|
||||
mu.Unlock()
|
||||
future := time.Now().Add(2 * time.Second)
|
||||
if err := os.Chtimes(path, future, future); err != nil {
|
||||
t.Fatalf("chtimes: %v", err)
|
||||
}
|
||||
|
||||
stop := make(chan struct{})
|
||||
go r.Watch(5*time.Millisecond, stop)
|
||||
time.Sleep(200 * time.Millisecond) // let it tick + hit the failing source
|
||||
close(stop)
|
||||
|
||||
if got := len(router.Routes()); got != 1 {
|
||||
t.Errorf("a failed reload must keep the previous routes; got %d routes", got)
|
||||
}
|
||||
if router.Fallback().Addr != "fb:1" {
|
||||
t.Errorf("a failed reload must keep the previous fallback; got %q", router.Fallback().Addr)
|
||||
}
|
||||
}
|
||||
105
core/pkg/turn/cert_reloader.go
Normal file
105
core/pkg/turn/cert_reloader.go
Normal file
@ -0,0 +1,105 @@
|
||||
package turn
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// turnCertReloadInterval is how often the TURNS certificate file is polled for
// changes by certReloader.watch.
//
// TLS cert renewals (Caddy DNS-01 for cdn.<base-domain>) happen on the order
// of weeks, so a minute of detection latency is irrelevant; polling keeps the
// dependency surface minimal (no fsnotify) and is robust across the
// atomic-rename pattern certbot/Caddy use when writing a renewed cert.
const turnCertReloadInterval = 60 * time.Second
|
||||
|
||||
// certReloader serves the current TURNS certificate through a tls.Config
// GetCertificate callback and hot-reloads it when the cert file changes on
// disk. This lets a Caddy-renewed certificate be picked up WITHOUT restarting
// the TURN server — a restart would tear down every active relay (~30s RTC
// drop for users mid-call). See plans/platform/04_STEALTH_TURN.md, the
// "cert renewal during cutover" note.
type certReloader struct {
	certPath string      // certificate file; also the file whose mtime is polled
	keyPath  string      // private key file loaded together with certPath
	logger   *zap.Logger // used by watch to report reload outcomes

	mu      sync.RWMutex     // guards cert and modTime
	cert    *tls.Certificate // most recently loaded pair, served by GetCertificate
	modTime time.Time        // certPath mtime recorded by the last successful reload
}
|
||||
|
||||
// newCertReloader loads the initial cert/key pair. Returns an error if the
|
||||
// initial load fails — TURNS cannot start without a valid certificate.
|
||||
func newCertReloader(certPath, keyPath string, logger *zap.Logger) (*certReloader, error) {
|
||||
r := &certReloader{certPath: certPath, keyPath: keyPath, logger: logger}
|
||||
if err := r.reload(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// reload reads the cert/key pair from disk and atomically swaps it in. On
|
||||
// failure it leaves the previously-loaded certificate in place: a renewal that
|
||||
// momentarily presents a half-written or mismatched cert/key file must never
|
||||
// take TURNS down — the old (still-valid) cert keeps serving until the next
|
||||
// successful reload.
|
||||
func (r *certReloader) reload() error {
|
||||
cert, err := tls.LoadX509KeyPair(r.certPath, r.keyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load TURNS cert/key (%s): %w", r.certPath, err)
|
||||
}
|
||||
var mod time.Time
|
||||
if fi, statErr := os.Stat(r.certPath); statErr == nil {
|
||||
mod = fi.ModTime()
|
||||
}
|
||||
r.mu.Lock()
|
||||
r.cert = &cert
|
||||
r.modTime = mod
|
||||
r.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCertificate is the tls.Config.GetCertificate callback. It always returns
|
||||
// the most recently loaded certificate, so every new TLS handshake uses the
|
||||
// current cert without the listener being recreated.
|
||||
func (r *certReloader) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.cert, nil
|
||||
}
|
||||
|
||||
// watch polls the cert file's mtime every interval and reloads when it advances.
|
||||
// Blocks until stop is closed.
|
||||
func (r *certReloader) watch(interval time.Duration, stop <-chan struct{}) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
fi, err := os.Stat(r.certPath)
|
||||
if err != nil {
|
||||
// File briefly absent during an atomic rename — retry next tick.
|
||||
continue
|
||||
}
|
||||
r.mu.RLock()
|
||||
unchanged := !fi.ModTime().After(r.modTime)
|
||||
r.mu.RUnlock()
|
||||
if unchanged {
|
||||
continue
|
||||
}
|
||||
if err := r.reload(); err != nil {
|
||||
r.logger.Warn("TURNS cert reload failed; keeping previous certificate",
|
||||
zap.String("cert_path", r.certPath), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
r.logger.Info("TURNS cert hot-reloaded", zap.String("cert_path", r.certPath))
|
||||
}
|
||||
}
|
||||
}
|
||||
110
core/pkg/turn/cert_reloader_test.go
Normal file
110
core/pkg/turn/cert_reloader_test.go
Normal file
@ -0,0 +1,110 @@
|
||||
package turn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// feat-41: TURNS cert hot-reload lets a Caddy-renewed certificate be picked up
|
||||
// without restarting the TURN server (a restart drops every active relay). These
|
||||
// pin: initial load, in-process reload when the file changes, resilience (a bad
|
||||
// reload keeps the previous cert serving), and the missing-file failure.
|
||||
|
||||
func writeTestCert(t *testing.T, dir string) (certPath, keyPath string) {
|
||||
t.Helper()
|
||||
certPath = filepath.Join(dir, "cert.pem")
|
||||
keyPath = filepath.Join(dir, "key.pem")
|
||||
if err := GenerateSelfSignedCert(certPath, keyPath, "127.0.0.1"); err != nil {
|
||||
t.Fatalf("GenerateSelfSignedCert: %v", err)
|
||||
}
|
||||
return certPath, keyPath
|
||||
}
|
||||
|
||||
func leafDER(t *testing.T, r *certReloader) []byte {
|
||||
t.Helper()
|
||||
c, err := r.GetCertificate(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCertificate: %v", err)
|
||||
}
|
||||
if c == nil || len(c.Certificate) == 0 {
|
||||
t.Fatal("GetCertificate returned an empty certificate")
|
||||
}
|
||||
return c.Certificate[0]
|
||||
}
|
||||
|
||||
func TestNewCertReloader_failsOnMissingFiles(t *testing.T) {
|
||||
if _, err := newCertReloader("/no/such/cert.pem", "/no/such/key.pem", zap.NewNop()); err == nil {
|
||||
t.Fatal("expected an error when the cert/key files do not exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertReloader_loadsAndServesCert(t *testing.T) {
|
||||
certPath, keyPath := writeTestCert(t, t.TempDir())
|
||||
r, err := newCertReloader(certPath, keyPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("newCertReloader: %v", err)
|
||||
}
|
||||
if got := leafDER(t, r); len(got) == 0 {
|
||||
t.Fatal("served certificate has no leaf")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertReloader_hotReloadsOnFileChange(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPath, keyPath := writeTestCert(t, dir)
|
||||
r, err := newCertReloader(certPath, keyPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("newCertReloader: %v", err)
|
||||
}
|
||||
before := leafDER(t, r)
|
||||
|
||||
// Renew: overwrite with a freshly-generated cert/key pair (different serial
|
||||
// + key → different leaf) and advance the mtime so the watcher detects it.
|
||||
if err := GenerateSelfSignedCert(certPath, keyPath, "127.0.0.1"); err != nil {
|
||||
t.Fatalf("regenerate cert: %v", err)
|
||||
}
|
||||
future := time.Now().Add(2 * time.Second)
|
||||
if err := os.Chtimes(certPath, future, future); err != nil {
|
||||
t.Fatalf("chtimes: %v", err)
|
||||
}
|
||||
|
||||
stop := make(chan struct{})
|
||||
defer close(stop)
|
||||
go r.watch(5*time.Millisecond, stop)
|
||||
|
||||
deadline := time.Now().Add(3 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if !bytes.Equal(leafDER(t, r), before) {
|
||||
return // hot-reloaded — the served cert changed
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
t.Fatal("certificate was not hot-reloaded after the file changed")
|
||||
}
|
||||
|
||||
func TestCertReloader_keepsOldCertOnReloadError(t *testing.T) {
|
||||
certPath, keyPath := writeTestCert(t, t.TempDir())
|
||||
r, err := newCertReloader(certPath, keyPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("newCertReloader: %v", err)
|
||||
}
|
||||
before := leafDER(t, r)
|
||||
|
||||
// Corrupt the cert file (simulates a half-written renewal).
|
||||
if err := os.WriteFile(certPath, []byte("not a pem cert"), 0o644); err != nil {
|
||||
t.Fatalf("corrupt cert: %v", err)
|
||||
}
|
||||
if err := r.reload(); err == nil {
|
||||
t.Fatal("expected reload to fail on a corrupt cert file")
|
||||
}
|
||||
|
||||
// The previously-loaded cert must still be served (TURNS must not go down).
|
||||
if got := leafDER(t, r); !bytes.Equal(got, before) {
|
||||
t.Error("a failed reload must keep serving the previous certificate")
|
||||
}
|
||||
}
|
||||
@ -23,6 +23,9 @@ type Server struct {
|
||||
conn net.PacketConn // UDP listener on primary port (3478)
|
||||
tcpListener net.Listener // Plain TCP listener on primary port (3478)
|
||||
tlsListener net.Listener // TLS TCP listener for TURNS (port 5349)
|
||||
|
||||
certReloader *certReloader // hot-reloads the TURNS cert; nil when TURNS disabled
|
||||
certStop chan struct{} // closed to stop the cert-reload watcher goroutine
|
||||
}
|
||||
|
||||
// NewServer creates and starts a TURN server.
|
||||
@ -79,23 +82,31 @@ func NewServer(cfg *Config, logger *zap.Logger) (*Server, error) {
|
||||
},
|
||||
})
|
||||
|
||||
// TURNS: TLS over TCP listener (port 5349) if configured
|
||||
// TURNS: TLS over TCP listener (port 5349) if configured.
|
||||
//
|
||||
// The cert is served via a hot-reloading GetCertificate callback rather
|
||||
// than a static Certificates slice, so a Caddy-renewed cert is picked up
|
||||
// in-process without restarting TURN (a restart drops every active relay
|
||||
// ~30s). See certReloader / plans/platform/04_STEALTH_TURN.md.
|
||||
if cfg.TURNSListenAddr != "" && cfg.TLSCertPath != "" && cfg.TLSKeyPath != "" {
|
||||
cert, err := tls.LoadX509KeyPair(cfg.TLSCertPath, cfg.TLSKeyPath)
|
||||
reloader, err := newCertReloader(cfg.TLSCertPath, cfg.TLSKeyPath, s.logger)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.closeListeners()
|
||||
return nil, fmt.Errorf("failed to load TLS cert/key: %w", err)
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
GetCertificate: reloader.GetCertificate,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
tlsListener, err := tls.Listen("tcp", cfg.TURNSListenAddr, tlsConfig)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.closeListeners()
|
||||
return nil, fmt.Errorf("failed to listen on %s: %w", cfg.TURNSListenAddr, err)
|
||||
}
|
||||
s.tlsListener = tlsListener
|
||||
s.certReloader = reloader
|
||||
s.certStop = make(chan struct{})
|
||||
go reloader.watch(turnCertReloadInterval, s.certStop)
|
||||
|
||||
listenerConfigs = append(listenerConfigs, pionTurn.ListenerConfig{
|
||||
Listener: tlsListener,
|
||||
@ -207,7 +218,15 @@ func (s *Server) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// closeListeners stops the cert watcher and closes all listeners. It is
|
||||
// idempotent (every field is nil-guarded and nil'd after use) but is NOT
|
||||
// mutex-protected — it relies on its call sites being single-threaded relative
|
||||
// to each other (sequential construction, plus a single Close() from main).
|
||||
func (s *Server) closeListeners() {
|
||||
if s.certStop != nil {
|
||||
close(s.certStop)
|
||||
s.certStop = nil
|
||||
}
|
||||
if s.conn != nil {
|
||||
s.conn.Close()
|
||||
s.conn = nil
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user