anonpenguin23 0379dc39f1 feat(core): implement sni-router for stealth turn
- add `orama-sni-router` binary to build process
- introduce `cmd/sni-router` for TLS-level SNI routing
- add documentation for stealth turn deployment architecture
2026-05-03 18:20:21 +03:00

236 lines
6.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package sniproxy provides a TCP-level Server Name Indication (SNI) router.
//
// The router peeks at the unencrypted TLS ClientHello on each accepted
// connection, extracts the SNI host name, and forwards the raw stream to
// a backend. It does NOT terminate TLS — encrypted bytes pass through
// verbatim. This lets one TCP port serve multiple TLS-speaking backends
// (HTTPS for the gateway, TURNS for stealth WebRTC, etc.) without
// sharing private keys with the proxy.
//
// Design goals:
// - Zero TLS material on the proxy
// - Bounded ClientHello read (no slowloris)
// - Backend dial timeout
// - Per-IP rate limiting
//
// SNI parsing follows RFC 5246 §7.4.1.2 (TLS record + ClientHello) and
// RFC 6066 §3 (server_name extension).
package sniproxy
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strings"
"time"
)
// ErrNoSNI is returned when a ClientHello was parsed but yielded no host
// name, such as when it carries no extensions block, no server_name
// extension, or a server_name list without a host_name entry.
var ErrNoSNI = errors.New("sniproxy: ClientHello has no SNI")
// MaxClientHelloBytes bounds how many bytes we'll read while looking for
// the SNI. TLS ClientHello records are typically 200-500 bytes; this is
// a generous cap that still defends against memory abuse.
//
// PeekClientHello uses it both as the size of its read buffer and as the
// upper limit for the TLS record header plus the declared record length.
const MaxClientHelloBytes = 16 * 1024
// PeekClientHello inspects the first TLS record on conn and returns the
// SNI host name found in its ClientHello together with every byte that
// was read from conn while doing so.
//
// The returned bytes start at the TLS record header and cover at least
// the whole first record; they are no longer available from conn, so the
// caller must write them to the backend before relaying the rest of the
// stream. The host name is lowercased by the parser.
//
// When readTimeout is positive, a read deadline of now+readTimeout is
// set on conn for the duration of the call and cleared again on return,
// so a client that stalls mid-handshake produces an error instead of
// blocking the goroutine indefinitely.
//
// An error is returned, with an empty name and nil bytes, when the
// record header or body cannot be read, when the record is not a
// handshake record (content type 22), when the record length is zero or
// header plus body would exceed MaxClientHelloBytes, or when SNI parsing
// fails (including ErrNoSNI).
func PeekClientHello(conn net.Conn, readTimeout time.Duration) (string, []byte, error) {
	const (
		recordHeaderLen = 5  // content_type (1) + version (2) + length (2)
		handshakeRecord = 22 // TLS content_type for handshake messages
	)

	if readTimeout > 0 {
		// Best effort: a conn that cannot take a deadline is still served.
		_ = conn.SetReadDeadline(time.Now().Add(readTimeout))
		defer conn.SetReadDeadline(time.Time{})
	}

	// The buffer is large enough to hold the biggest record we accept,
	// so both Peek calls below can be satisfied without consuming input.
	peeker := bufio.NewReaderSize(conn, MaxClientHelloBytes)

	hdr, err := peeker.Peek(recordHeaderLen)
	if err != nil {
		return "", nil, fmt.Errorf("read tls record header: %w", err)
	}
	if contentType := hdr[0]; contentType != handshakeRecord {
		return "", nil, fmt.Errorf("not a TLS handshake record (type=%d)", contentType)
	}

	payloadLen := int(binary.BigEndian.Uint16(hdr[3:recordHeaderLen]))
	recordEnd := recordHeaderLen + payloadLen
	if payloadLen <= 0 || recordEnd > MaxClientHelloBytes {
		return "", nil, fmt.Errorf("invalid record length %d", payloadLen)
	}

	record, err := peeker.Peek(recordEnd)
	if err != nil {
		return "", nil, fmt.Errorf("read tls record body: %w", err)
	}

	host, err := parseSNI(record[recordHeaderLen:])
	if err != nil {
		return "", nil, err
	}

	// Nothing has been consumed yet. Pull out everything the buffered
	// reader holds so the caller can replay it to the backend verbatim.
	replay := make([]byte, peeker.Buffered())
	if _, err := io.ReadFull(peeker, replay); err != nil {
		return "", nil, fmt.Errorf("drain peeked bytes: %w", err)
	}
	return host, replay, nil
}
// parseSNI parses the payload of a TLS handshake record (the bytes after
// the 5-byte record header), which must begin with a ClientHello
// message, and returns the result of parsing its first server_name
// extension.
//
// The fields ahead of the extensions (handshake length, client_version,
// random, session_id, cipher_suites, compression_methods) are skipped
// without being interpreted. Extension parsing is confined to the bytes
// covered by the declared extensions length: an extension whose header
// or payload would run past that block yields io.ErrUnexpectedEOF
// instead of being read from whatever bytes happen to follow.
//
// It returns ErrNoSNI when fewer than two bytes remain where the
// extensions length would be, or when no server_name extension is
// present; io.ErrUnexpectedEOF when a field runs past the available
// bytes; and a descriptive error for a handshake type other than
// ClientHello or an extensions length larger than the remaining body.
func parseSNI(body []byte) (string, error) {
	r := newReader(body)

	// Handshake type (1 byte) — must be 1 (ClientHello).
	hsType, err := r.readByte()
	if err != nil {
		return "", err
	}
	if hsType != 1 {
		return "", fmt.Errorf("not a ClientHello (handshake type %d)", hsType)
	}

	// Handshake length (3 bytes).
	if _, err := r.readBytes(3); err != nil {
		return "", err
	}

	// client_version (2) + random (32).
	if _, err := r.readBytes(2 + 32); err != nil {
		return "", err
	}

	// session_id: 1-byte length prefix, then the value.
	sidLen, err := r.readByte()
	if err != nil {
		return "", err
	}
	if _, err := r.readBytes(int(sidLen)); err != nil {
		return "", err
	}

	// cipher_suites: 2-byte length prefix, then the list.
	csLen, err := r.readUint16()
	if err != nil {
		return "", err
	}
	if _, err := r.readBytes(int(csLen)); err != nil {
		return "", err
	}

	// compression_methods: 1-byte length prefix, then the list.
	cmLen, err := r.readByte()
	if err != nil {
		return "", err
	}
	if _, err := r.readBytes(int(cmLen)); err != nil {
		return "", err
	}

	// Extensions length (2). Optional — TLS 1.0 ClientHello can skip it.
	if r.remaining() < 2 {
		return "", ErrNoSNI
	}
	extTotalLen, err := r.readUint16()
	if err != nil {
		return "", err
	}
	if int(extTotalLen) > r.remaining() {
		return "", fmt.Errorf("extensions truncated")
	}

	// Carve out exactly the declared extensions block and walk it with
	// its own cursor, so no extension can reach into trailing bytes.
	extBlock, err := r.readBytes(int(extTotalLen))
	if err != nil {
		return "", err
	}
	exts := newReader(extBlock)

	for exts.remaining() > 0 {
		extType, err := exts.readUint16()
		if err != nil {
			return "", err
		}
		extLen, err := exts.readUint16()
		if err != nil {
			return "", err
		}
		extData, err := exts.readBytes(int(extLen))
		if err != nil {
			return "", err
		}
		// server_name extension is type 0.
		if extType == 0 {
			return parseServerName(extData)
		}
	}
	return "", ErrNoSNI
}
// parseServerName parses the body of a server_name extension and returns
// the first host_name (name type 0) entry, lowercased.
//
// Entries are read only from the bytes covered by the declared
// server_name_list length; an entry that would run past that list yields
// io.ErrUnexpectedEOF instead of being read from trailing bytes. Entries
// with a name type other than 0 are skipped.
//
// It returns ErrNoSNI when the list holds no host_name entry or when the
// host_name is empty (RFC 6066 requires at least one byte), so callers
// never receive an empty name with a nil error. It returns a descriptive
// error when the declared list length exceeds the extension body.
func parseServerName(data []byte) (string, error) {
	r := newReader(data)

	// server_name_list length (2).
	listLen, err := r.readUint16()
	if err != nil {
		return "", err
	}
	if int(listLen) > r.remaining() {
		return "", fmt.Errorf("server_name list truncated")
	}

	// Walk exactly the declared list with its own cursor.
	list, err := r.readBytes(int(listLen))
	if err != nil {
		return "", err
	}
	entries := newReader(list)

	for entries.remaining() > 0 {
		nameType, err := entries.readByte()
		if err != nil {
			return "", err
		}
		nameLen, err := entries.readUint16()
		if err != nil {
			return "", err
		}
		nameBytes, err := entries.readBytes(int(nameLen))
		if err != nil {
			return "", err
		}
		if nameType != 0 { // not host_name
			continue
		}
		if len(nameBytes) == 0 {
			// A zero-length host_name gives nothing to route on.
			return "", ErrNoSNI
		}
		return strings.ToLower(string(nameBytes)), nil
	}
	return "", ErrNoSNI
}
// reader is a minimal forward-only cursor over a byte slice, used by
// parseSNI and parseServerName to walk length-prefixed TLS structures.
// Every read either succeeds and advances pos, or fails with
// io.ErrUnexpectedEOF and leaves pos unchanged.
type reader struct {
	buf []byte // bytes being parsed; never written to
	pos int    // index of the next unread byte in buf
}

// newReader returns a cursor positioned at the start of buf.
func newReader(buf []byte) *reader { return &reader{buf: buf} }

// remaining reports how many unread bytes are left.
func (r *reader) remaining() int { return len(r.buf) - r.pos }

// readByte consumes and returns the next byte.
func (r *reader) readByte() (byte, error) {
	if r.remaining() < 1 {
		return 0, io.ErrUnexpectedEOF
	}
	b := r.buf[r.pos]
	r.pos++
	return b, nil
}

// readUint16 consumes the next two bytes as a big-endian unsigned integer.
func (r *reader) readUint16() (uint16, error) {
	if r.remaining() < 2 {
		return 0, io.ErrUnexpectedEOF
	}
	v := binary.BigEndian.Uint16(r.buf[r.pos : r.pos+2])
	r.pos += 2
	return v, nil
}

// readBytes consumes the next n bytes and returns them as a sub-slice of
// the underlying buffer (no copy). The result's capacity is clipped to n
// so an append by the caller cannot overwrite bytes not yet parsed.
//
// A negative n, or an n larger than what remains, yields
// io.ErrUnexpectedEOF. The check compares n against remaining() rather
// than computing pos+n first, so it cannot be defeated by integer
// overflow and a negative n cannot reach the slice expression and panic.
func (r *reader) readBytes(n int) ([]byte, error) {
	if n < 0 || n > r.remaining() {
		return nil, io.ErrUnexpectedEOF
	}
	end := r.pos + n
	b := r.buf[r.pos:end:end]
	r.pos = end
	return b, nil
}