orama/core/pkg/sniproxy/server.go
anonpenguin23 0379dc39f1 feat(core): implement sni-router for stealth turn
- add `orama-sni-router` binary to build process
- introduce `cmd/sni-router` for TLS-level SNI routing
- add documentation for stealth turn deployment architecture
2026-05-03 18:20:21 +03:00

178 lines
4.2 KiB
Go

package sniproxy
import (
"context"
"errors"
"io"
"net"
"sync"
"time"
"go.uber.org/zap"
)
// Config holds the tunable limits of the proxy server. NewServer replaces
// every field that is zero or negative with its default.
type Config struct {
	// ClientHelloTimeout is the longest the server waits for a parseable
	// ClientHello from a new connection. Default: 5 seconds.
	ClientHelloTimeout time.Duration

	// BackendDialTimeout is the longest a connect to the chosen backend
	// may take. Default: 5 seconds.
	BackendDialTimeout time.Duration

	// MaxConcurrentConns is the ceiling on connections handled at the same
	// time; it exists to prevent resource exhaustion. Default: 10000.
	MaxConcurrentConns int
}
// Server routes raw TCP connections to backends based on the SNI found in
// the TLS ClientHello. Build one with NewServer, run Serve(listener) in its
// own goroutine, and call Close to cancel in-flight connections.
type Server struct {
	router *Router     // maps an SNI value to a backend
	cfg    Config      // limits, with defaults already applied by NewServer
	logger *zap.Logger // named "sniproxy" by NewServer

	// gate is a counting semaphore: one token is held per connection being
	// handled, and its capacity is the concurrent-connection limit.
	gate chan struct{}

	// wg tracks the per-connection handler goroutines so Close can wait
	// for them to finish.
	wg sync.WaitGroup

	// ctx is cancelled through cancel to signal shutdown to Serve and to
	// the connection handlers.
	ctx    context.Context
	cancel context.CancelFunc
}
// NewServer builds a Server that routes with router and is tuned by cfg.
//
// Any cfg field that is zero or negative is replaced by its default
// (5 second timeouts, 10000 concurrent connections). A nil logger is
// replaced by a no-op logger. The returned server does nothing until
// Serve is called.
func NewServer(router *Router, cfg Config, logger *zap.Logger) *Server {
	const (
		defaultTimeout  = 5 * time.Second
		defaultMaxConns = 10000
	)

	if cfg.ClientHelloTimeout <= 0 {
		cfg.ClientHelloTimeout = defaultTimeout
	}
	if cfg.BackendDialTimeout <= 0 {
		cfg.BackendDialTimeout = defaultTimeout
	}
	if cfg.MaxConcurrentConns <= 0 {
		cfg.MaxConcurrentConns = defaultMaxConns
	}
	if logger == nil {
		logger = zap.NewNop()
	}

	srv := &Server{
		router: router,
		cfg:    cfg,
		logger: logger.Named("sniproxy"),
		// One slot per allowed concurrent connection.
		gate: make(chan struct{}, cfg.MaxConcurrentConns),
	}
	// The server owns its own root context; Close cancels it.
	srv.ctx, srv.cancel = context.WithCancel(context.Background())
	return srv
}
// Serve accepts connections from ln and hands each one to its own handler
// goroutine. It returns when ln.Accept fails with a non-timeout error or
// when the server is shut down via Close; the returned error is always
// non-nil (the context error on shutdown, the Accept error otherwise).
//
// On shutdown Serve closes ln itself so that a blocked Accept is released.
// Connections that arrive while MaxConcurrentConns handlers are already
// running are logged and closed immediately rather than queued.
func (s *Server) Serve(ln net.Listener) error {
	// Close only cancels s.ctx, which a blocked Accept cannot observe.
	// This watcher closes the listener on cancellation so Serve actually
	// returns; it exits via stop when Serve returns for any other reason.
	stop := make(chan struct{})
	defer close(stop)
	go func() {
		select {
		case <-s.ctx.Done():
			_ = ln.Close() // best effort: only needed to unblock Accept
		case <-stop:
		}
	}()

	for {
		conn, err := ln.Accept()
		if err != nil {
			// An Accept error after cancellation is just the shutdown.
			if ctxErr := s.ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			// Timeout-class errors are retried after a short pause so we
			// don't busy-loop; the pause is cut short by shutdown.
			var ne net.Error
			if errors.As(err, &ne) && ne.Timeout() {
				select {
				case <-time.After(50 * time.Millisecond):
				case <-s.ctx.Done():
					return s.ctx.Err()
				}
				continue
			}
			return err
		}

		// Don't start new handlers once shutdown has begun.
		if ctxErr := s.ctx.Err(); ctxErr != nil {
			conn.Close()
			return ctxErr
		}

		// Non-blocking acquire: when every slot is taken, shed the
		// connection instead of stalling the accept loop.
		select {
		case s.gate <- struct{}{}:
		default:
			s.logger.Warn("max concurrent connections reached, dropping",
				zap.Int("limit", s.cfg.MaxConcurrentConns),
				zap.String("remote", conn.RemoteAddr().String()),
			)
			conn.Close()
			continue
		}

		s.wg.Add(1)
		go func(c net.Conn) {
			defer s.wg.Done()
			defer func() { <-s.gate }() // release the slot
			s.handle(c)
		}(conn)
	}
}
// Close shuts the server down: it cancels the server context, which
// in-flight connection handlers watch, and then blocks until every handler
// goroutine has returned.
func (s *Server) Close() {
	// Signal shutdown first so handlers start unwinding...
	s.cancel()
	// ...then wait for all of them to finish.
	s.wg.Wait()
}
// handle services one accepted connection from start to finish:
//
//  1. peek the TLS ClientHello to learn the SNI,
//  2. ask the router for a backend and dial it,
//  3. replay the peeked bytes to the backend,
//  4. copy bytes in both directions until either side stops or the server
//     shuts down.
//
// Both conn and the backend connection are closed before handle returns.
// Failures are logged and end the connection; nothing is returned.
func (s *Server) handle(conn net.Conn) {
	defer conn.Close()

	sni, peeked, err := PeekClientHello(conn, s.cfg.ClientHelloTimeout)
	if err != nil {
		s.logger.Debug("ClientHello peek failed",
			zap.String("remote", conn.RemoteAddr().String()),
			zap.Error(err),
		)
		return
	}

	// An empty Addr is how the router reports "no backend for this SNI".
	backend := s.router.Pick(sni)
	if backend.Addr == "" {
		s.logger.Warn("no backend for SNI",
			zap.String("sni", sni),
			zap.String("remote", conn.RemoteAddr().String()),
		)
		return
	}

	network := backend.Network
	if network == "" {
		network = "tcp"
	}

	// Dial under the server context so Close aborts a pending connect
	// instead of waiting out the full BackendDialTimeout.
	dialer := net.Dialer{Timeout: s.cfg.BackendDialTimeout}
	upstream, err := dialer.DialContext(s.ctx, network, backend.Addr)
	if err != nil {
		if s.ctx.Err() != nil {
			// The dial was cut short by shutdown; that is not a backend
			// failure worth a warning.
			return
		}
		s.logger.Warn("backend dial failed",
			zap.String("sni", sni),
			zap.String("backend", backend.Addr),
			zap.Error(err),
		)
		return
	}
	defer upstream.Close()

	// Replay peeked bytes (the ClientHello + anything else buffered) so the
	// backend sees the stream from its very first byte.
	if len(peeked) > 0 {
		if _, err := upstream.Write(peeked); err != nil {
			s.logger.Debug("replay to backend failed",
				zap.String("sni", sni),
				zap.Error(err),
			)
			return
		}
	}

	// Bidirectional copy. We close both connections when either side
	// finishes OR when the server is shutting down, so handle() can't
	// hang forever on a half-stuck peer. The channel is buffered for both
	// goroutines so neither can block on its send.
	done := make(chan struct{}, 2)
	go func() {
		// Copy errors are expected here (the peer or we closed the
		// socket) and carry no actionable information.
		_, _ = io.Copy(upstream, conn)
		done <- struct{}{}
	}()
	go func() {
		_, _ = io.Copy(conn, upstream)
		done <- struct{}{}
	}()

	var finished int
	select {
	case <-done:
		finished = 1
	case <-s.ctx.Done():
	}

	// Force both sides closed; any copy still running fails its next
	// read or write and exits.
	upstream.Close()
	conn.Close()

	// Wait for every copy goroutine that has not reported yet so none
	// outlives handle. (The deferred Close calls above then run a second
	// time, which is harmless.)
	for ; finished < 2; finished++ {
		<-done
	}
}