feat: enhance bootstrap peer handling and configuration validation

- Updated DefaultBootstrapPeers function to prioritize environment variable settings for bootstrap peers, allowing for dynamic configuration.
- Added tests to ensure non-empty default bootstrap peers and validate the correct handling of bootstrap peer configurations.
- Introduced a helper function to generate valid configurations for different node types, improving test clarity and maintainability.
- Enhanced the isPrivateOrLocalHost function to properly handle IPv6 addresses, ensuring accurate host validation.
This commit is contained in:
anonpenguin23 2025-11-03 07:30:27 +02:00
parent 2088b6a0cf
commit 9093c8937e
4 changed files with 124 additions and 70 deletions

View File

@ -12,6 +12,21 @@ import (
// DefaultBootstrapPeers returns the library's default bootstrap peer multiaddrs.
// These can be overridden by environment variables or config.
func DefaultBootstrapPeers() []string {
// Check environment variable first
if envPeers := os.Getenv("DEBROS_BOOTSTRAP_PEERS"); envPeers != "" {
peers := splitCSVOrSpace(envPeers)
// Filter out empty strings
result := make([]string, 0, len(peers))
for _, p := range peers {
if p != "" {
result = append(result, p)
}
}
if len(result) > 0 {
return result
}
}
defaultCfg := config.DefaultConfig()
return defaultCfg.Discovery.BootstrapPeers
}

View File

@ -10,11 +10,16 @@ import (
func TestDefaultBootstrapPeersNonEmpty(t *testing.T) {
old := os.Getenv("DEBROS_BOOTSTRAP_PEERS")
t.Cleanup(func() { os.Setenv("DEBROS_BOOTSTRAP_PEERS", old) })
_ = os.Setenv("DEBROS_BOOTSTRAP_PEERS", "") // ensure not set
// Set a valid bootstrap peer
validPeer := "/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"
_ = os.Setenv("DEBROS_BOOTSTRAP_PEERS", validPeer)
peers := DefaultBootstrapPeers()
if len(peers) == 0 {
t.Fatalf("expected non-empty default bootstrap peers")
}
if peers[0] != validPeer {
t.Fatalf("expected bootstrap peer %s, got %s", validPeer, peers[0])
}
}
func TestDefaultDatabaseEndpointsEnvOverride(t *testing.T) {

View File

@ -5,6 +5,55 @@ import (
"time"
)
// validConfigForType returns a valid config for the given node type
func validConfigForType(nodeType string) *Config {
validPeer := "/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"
cfg := &Config{
Node: NodeConfig{
Type: nodeType,
ID: "test-node-id",
ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"},
DataDir: ".",
MaxConnections: 50,
},
Database: DatabaseConfig{
DataDir: ".",
ReplicationFactor: 3,
ShardCount: 16,
MaxDatabaseSize: 1024,
BackupInterval: 1 * time.Hour,
RQLitePort: 5001,
RQLiteRaftPort: 7001,
MinClusterSize: 1,
},
Discovery: DiscoveryConfig{
BootstrapPeers: []string{validPeer},
DiscoveryInterval: 15 * time.Second,
BootstrapPort: 4001,
HttpAdvAddress: "127.0.0.1:5001",
RaftAdvAddress: "127.0.0.1:7001",
NodeNamespace: "default",
},
Logging: LoggingConfig{
Level: "info",
Format: "console",
},
}
// Set rqlite_join_address based on node type
if nodeType == "node" {
cfg.Database.RQLiteJoinAddress = "localhost:5001"
// Node type requires bootstrap peers
cfg.Discovery.BootstrapPeers = []string{validPeer}
} else {
// Bootstrap type: empty join address and peers optional
cfg.Database.RQLiteJoinAddress = ""
cfg.Discovery.BootstrapPeers = []string{}
}
return cfg
}
func TestValidateNodeType(t *testing.T) {
tests := []struct {
name string
@ -19,11 +68,11 @@ func TestValidateNodeType(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
Node: NodeConfig{Type: tt.nodeType, ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001},
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
Logging: LoggingConfig{Level: "info", Format: "console"},
cfg := validConfigForType("bootstrap") // Start with valid bootstrap
if tt.nodeType == "node" {
cfg = validConfigForType("node")
} else {
cfg.Node.Type = tt.nodeType
}
errs := cfg.Validate()
if tt.shouldError && len(errs) == 0 {
@ -53,12 +102,8 @@ func TestValidateListenAddresses(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
Node: NodeConfig{Type: "node", ListenAddresses: tt.addresses, DataDir: ".", MaxConnections: 50},
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: "localhost:5001"},
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
Logging: LoggingConfig{Level: "info", Format: "console"},
}
cfg := validConfigForType("node")
cfg.Node.ListenAddresses = tt.addresses
errs := cfg.Validate()
if tt.shouldError && len(errs) == 0 {
t.Errorf("expected error, got none")
@ -85,12 +130,8 @@ func TestValidateReplicationFactor(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
Node: NodeConfig{Type: "node", ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: tt.replication, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: "localhost:5001"},
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
Logging: LoggingConfig{Level: "info", Format: "console"},
}
cfg := validConfigForType("node")
cfg.Database.ReplicationFactor = tt.replication
errs := cfg.Validate()
if tt.shouldError && len(errs) == 0 {
t.Errorf("expected error, got none")
@ -119,12 +160,9 @@ func TestValidateRQLitePorts(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
Node: NodeConfig{Type: "node", ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: tt.httpPort, RQLiteRaftPort: tt.raftPort, RQLiteJoinAddress: "localhost:5001"},
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
Logging: LoggingConfig{Level: "info", Format: "console"},
}
cfg := validConfigForType("node")
cfg.Database.RQLitePort = tt.httpPort
cfg.Database.RQLiteRaftPort = tt.raftPort
errs := cfg.Validate()
if tt.shouldError && len(errs) == 0 {
t.Errorf("expected error, got none")
@ -153,12 +191,8 @@ func TestValidateRQLiteJoinAddress(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
Node: NodeConfig{Type: tt.nodeType, ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: tt.joinAddr},
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
Logging: LoggingConfig{Level: "info", Format: "console"},
}
cfg := validConfigForType(tt.nodeType)
cfg.Database.RQLiteJoinAddress = tt.joinAddr
errs := cfg.Validate()
if tt.shouldError && len(errs) == 0 {
t.Errorf("expected error, got none")
@ -190,12 +224,8 @@ func TestValidateBootstrapPeers(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
Node: NodeConfig{Type: tt.nodeType, ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: ""},
Discovery: DiscoveryConfig{BootstrapPeers: tt.peers, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
Logging: LoggingConfig{Level: "info", Format: "console"},
}
cfg := validConfigForType(tt.nodeType)
cfg.Discovery.BootstrapPeers = tt.peers
errs := cfg.Validate()
if tt.shouldError && len(errs) == 0 {
t.Errorf("expected error, got none")
@ -223,12 +253,8 @@ func TestValidateLoggingLevel(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
Node: NodeConfig{Type: "node", ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: "localhost:5001"},
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
Logging: LoggingConfig{Level: tt.level, Format: "console"},
}
cfg := validConfigForType("node")
cfg.Logging.Level = tt.level
errs := cfg.Validate()
if tt.shouldError && len(errs) == 0 {
t.Errorf("expected error, got none")
@ -254,12 +280,8 @@ func TestValidateLoggingFormat(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
Node: NodeConfig{Type: "node", ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: "localhost:5001"},
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
Logging: LoggingConfig{Level: "info", Format: tt.format},
}
cfg := validConfigForType("node")
cfg.Logging.Format = tt.format
errs := cfg.Validate()
if tt.shouldError && len(errs) == 0 {
t.Errorf("expected error, got none")
@ -285,12 +307,8 @@ func TestValidateMaxConnections(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
Node: NodeConfig{Type: "node", ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: tt.maxConn},
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: "localhost:5001"},
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
Logging: LoggingConfig{Level: "info", Format: "console"},
}
cfg := validConfigForType("node")
cfg.Node.MaxConnections = tt.maxConn
errs := cfg.Validate()
if tt.shouldError && len(errs) == 0 {
t.Errorf("expected error, got none")
@ -316,12 +334,8 @@ func TestValidateDiscoveryInterval(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
Node: NodeConfig{Type: "node", ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: "localhost:5001"},
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: tt.interval, BootstrapPort: 4001, NodeNamespace: "default"},
Logging: LoggingConfig{Level: "info", Format: "console"},
}
cfg := validConfigForType("node")
cfg.Discovery.DiscoveryInterval = tt.interval
errs := cfg.Validate()
if tt.shouldError && len(errs) == 0 {
t.Errorf("expected error, got none")
@ -347,12 +361,8 @@ func TestValidateBootstrapPort(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
Node: NodeConfig{Type: "node", ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: "localhost:5001"},
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: tt.port, NodeNamespace: "default"},
Logging: LoggingConfig{Level: "info", Format: "console"},
}
cfg := validConfigForType("node")
cfg.Discovery.BootstrapPort = tt.port
errs := cfg.Validate()
if tt.shouldError && len(errs) == 0 {
t.Errorf("expected error, got none")
@ -383,6 +393,7 @@ func TestValidateCompleteConfig(t *testing.T) {
RQLitePort: 5002,
RQLiteRaftPort: 7002,
RQLiteJoinAddress: "127.0.0.1:7001",
MinClusterSize: 1,
},
Discovery: DiscoveryConfig{
BootstrapPeers: []string{
@ -390,7 +401,8 @@ func TestValidateCompleteConfig(t *testing.T) {
},
DiscoveryInterval: 15 * time.Second,
BootstrapPort: 4001,
HttpAdvAddress: "127.0.0.1",
HttpAdvAddress: "127.0.0.1:5001",
RaftAdvAddress: "127.0.0.1:7001",
NodeNamespace: "default",
},
Security: SecurityConfig{

View File

@ -206,9 +206,31 @@ func isHopByHopHeader(header string) bool {
// isPrivateOrLocalHost checks if a host is private, local, or loopback
func isPrivateOrLocalHost(host string) bool {
// Strip port if present
if idx := strings.LastIndex(host, ":"); idx != -1 {
host = host[:idx]
// Strip port if present, handling IPv6 addresses properly
// IPv6 addresses in URLs are bracketed: [::1]:8080
if strings.HasPrefix(host, "[") {
// IPv6 address with brackets
if idx := strings.LastIndex(host, "]"); idx != -1 {
if idx+1 < len(host) && host[idx+1] == ':' {
// Port present, strip it
host = host[1:idx] // Remove brackets and port
} else {
// No port, just remove brackets
host = host[1:idx]
}
}
} else {
// IPv4 or hostname, check for port
if idx := strings.LastIndex(host, ":"); idx != -1 {
// Check if it's an IPv6 address without brackets (contains multiple colons)
colonCount := strings.Count(host, ":")
if colonCount == 1 {
// Single colon, likely IPv4 with port
host = host[:idx]
}
// If multiple colons, it's IPv6 without brackets and no port
// Leave host as-is
}
}
// Check for localhost variants