mirror of
https://github.com/DeBrosOfficial/network.git
synced 2025-12-11 09:58:49 +00:00
feat: enhance bootstrap peer handling and configuration validation
- Updated DefaultBootstrapPeers function to prioritize environment variable settings for bootstrap peers, allowing for dynamic configuration. - Added tests to ensure non-empty default bootstrap peers and validate the correct handling of bootstrap peer configurations. - Introduced a helper function to generate valid configurations for different node types, improving test clarity and maintainability. - Enhanced the isPrivateOrLocalHost function to properly handle IPv6 addresses, ensuring accurate host validation.
This commit is contained in:
parent
2088b6a0cf
commit
9093c8937e
@ -12,6 +12,21 @@ import (
|
||||
// DefaultBootstrapPeers returns the library's default bootstrap peer multiaddrs.
|
||||
// These can be overridden by environment variables or config.
|
||||
func DefaultBootstrapPeers() []string {
|
||||
// Check environment variable first
|
||||
if envPeers := os.Getenv("DEBROS_BOOTSTRAP_PEERS"); envPeers != "" {
|
||||
peers := splitCSVOrSpace(envPeers)
|
||||
// Filter out empty strings
|
||||
result := make([]string, 0, len(peers))
|
||||
for _, p := range peers {
|
||||
if p != "" {
|
||||
result = append(result, p)
|
||||
}
|
||||
}
|
||||
if len(result) > 0 {
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
defaultCfg := config.DefaultConfig()
|
||||
return defaultCfg.Discovery.BootstrapPeers
|
||||
}
|
||||
|
||||
@ -10,11 +10,16 @@ import (
|
||||
func TestDefaultBootstrapPeersNonEmpty(t *testing.T) {
|
||||
old := os.Getenv("DEBROS_BOOTSTRAP_PEERS")
|
||||
t.Cleanup(func() { os.Setenv("DEBROS_BOOTSTRAP_PEERS", old) })
|
||||
_ = os.Setenv("DEBROS_BOOTSTRAP_PEERS", "") // ensure not set
|
||||
// Set a valid bootstrap peer
|
||||
validPeer := "/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"
|
||||
_ = os.Setenv("DEBROS_BOOTSTRAP_PEERS", validPeer)
|
||||
peers := DefaultBootstrapPeers()
|
||||
if len(peers) == 0 {
|
||||
t.Fatalf("expected non-empty default bootstrap peers")
|
||||
}
|
||||
if peers[0] != validPeer {
|
||||
t.Fatalf("expected bootstrap peer %s, got %s", validPeer, peers[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultDatabaseEndpointsEnvOverride(t *testing.T) {
|
||||
|
||||
@ -5,6 +5,55 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// validConfigForType returns a valid config for the given node type
|
||||
func validConfigForType(nodeType string) *Config {
|
||||
validPeer := "/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"
|
||||
cfg := &Config{
|
||||
Node: NodeConfig{
|
||||
Type: nodeType,
|
||||
ID: "test-node-id",
|
||||
ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"},
|
||||
DataDir: ".",
|
||||
MaxConnections: 50,
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
DataDir: ".",
|
||||
ReplicationFactor: 3,
|
||||
ShardCount: 16,
|
||||
MaxDatabaseSize: 1024,
|
||||
BackupInterval: 1 * time.Hour,
|
||||
RQLitePort: 5001,
|
||||
RQLiteRaftPort: 7001,
|
||||
MinClusterSize: 1,
|
||||
},
|
||||
Discovery: DiscoveryConfig{
|
||||
BootstrapPeers: []string{validPeer},
|
||||
DiscoveryInterval: 15 * time.Second,
|
||||
BootstrapPort: 4001,
|
||||
HttpAdvAddress: "127.0.0.1:5001",
|
||||
RaftAdvAddress: "127.0.0.1:7001",
|
||||
NodeNamespace: "default",
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
Format: "console",
|
||||
},
|
||||
}
|
||||
|
||||
// Set rqlite_join_address based on node type
|
||||
if nodeType == "node" {
|
||||
cfg.Database.RQLiteJoinAddress = "localhost:5001"
|
||||
// Node type requires bootstrap peers
|
||||
cfg.Discovery.BootstrapPeers = []string{validPeer}
|
||||
} else {
|
||||
// Bootstrap type: empty join address and peers optional
|
||||
cfg.Database.RQLiteJoinAddress = ""
|
||||
cfg.Discovery.BootstrapPeers = []string{}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func TestValidateNodeType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@ -19,11 +68,11 @@ func TestValidateNodeType(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Node: NodeConfig{Type: tt.nodeType, ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
|
||||
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001},
|
||||
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
|
||||
Logging: LoggingConfig{Level: "info", Format: "console"},
|
||||
cfg := validConfigForType("bootstrap") // Start with valid bootstrap
|
||||
if tt.nodeType == "node" {
|
||||
cfg = validConfigForType("node")
|
||||
} else {
|
||||
cfg.Node.Type = tt.nodeType
|
||||
}
|
||||
errs := cfg.Validate()
|
||||
if tt.shouldError && len(errs) == 0 {
|
||||
@ -53,12 +102,8 @@ func TestValidateListenAddresses(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Node: NodeConfig{Type: "node", ListenAddresses: tt.addresses, DataDir: ".", MaxConnections: 50},
|
||||
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: "localhost:5001"},
|
||||
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
|
||||
Logging: LoggingConfig{Level: "info", Format: "console"},
|
||||
}
|
||||
cfg := validConfigForType("node")
|
||||
cfg.Node.ListenAddresses = tt.addresses
|
||||
errs := cfg.Validate()
|
||||
if tt.shouldError && len(errs) == 0 {
|
||||
t.Errorf("expected error, got none")
|
||||
@ -85,12 +130,8 @@ func TestValidateReplicationFactor(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Node: NodeConfig{Type: "node", ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
|
||||
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: tt.replication, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: "localhost:5001"},
|
||||
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
|
||||
Logging: LoggingConfig{Level: "info", Format: "console"},
|
||||
}
|
||||
cfg := validConfigForType("node")
|
||||
cfg.Database.ReplicationFactor = tt.replication
|
||||
errs := cfg.Validate()
|
||||
if tt.shouldError && len(errs) == 0 {
|
||||
t.Errorf("expected error, got none")
|
||||
@ -119,12 +160,9 @@ func TestValidateRQLitePorts(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Node: NodeConfig{Type: "node", ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
|
||||
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: tt.httpPort, RQLiteRaftPort: tt.raftPort, RQLiteJoinAddress: "localhost:5001"},
|
||||
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
|
||||
Logging: LoggingConfig{Level: "info", Format: "console"},
|
||||
}
|
||||
cfg := validConfigForType("node")
|
||||
cfg.Database.RQLitePort = tt.httpPort
|
||||
cfg.Database.RQLiteRaftPort = tt.raftPort
|
||||
errs := cfg.Validate()
|
||||
if tt.shouldError && len(errs) == 0 {
|
||||
t.Errorf("expected error, got none")
|
||||
@ -153,12 +191,8 @@ func TestValidateRQLiteJoinAddress(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Node: NodeConfig{Type: tt.nodeType, ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
|
||||
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: tt.joinAddr},
|
||||
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
|
||||
Logging: LoggingConfig{Level: "info", Format: "console"},
|
||||
}
|
||||
cfg := validConfigForType(tt.nodeType)
|
||||
cfg.Database.RQLiteJoinAddress = tt.joinAddr
|
||||
errs := cfg.Validate()
|
||||
if tt.shouldError && len(errs) == 0 {
|
||||
t.Errorf("expected error, got none")
|
||||
@ -190,12 +224,8 @@ func TestValidateBootstrapPeers(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Node: NodeConfig{Type: tt.nodeType, ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
|
||||
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: ""},
|
||||
Discovery: DiscoveryConfig{BootstrapPeers: tt.peers, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
|
||||
Logging: LoggingConfig{Level: "info", Format: "console"},
|
||||
}
|
||||
cfg := validConfigForType(tt.nodeType)
|
||||
cfg.Discovery.BootstrapPeers = tt.peers
|
||||
errs := cfg.Validate()
|
||||
if tt.shouldError && len(errs) == 0 {
|
||||
t.Errorf("expected error, got none")
|
||||
@ -223,12 +253,8 @@ func TestValidateLoggingLevel(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Node: NodeConfig{Type: "node", ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
|
||||
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: "localhost:5001"},
|
||||
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
|
||||
Logging: LoggingConfig{Level: tt.level, Format: "console"},
|
||||
}
|
||||
cfg := validConfigForType("node")
|
||||
cfg.Logging.Level = tt.level
|
||||
errs := cfg.Validate()
|
||||
if tt.shouldError && len(errs) == 0 {
|
||||
t.Errorf("expected error, got none")
|
||||
@ -254,12 +280,8 @@ func TestValidateLoggingFormat(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Node: NodeConfig{Type: "node", ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
|
||||
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: "localhost:5001"},
|
||||
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
|
||||
Logging: LoggingConfig{Level: "info", Format: tt.format},
|
||||
}
|
||||
cfg := validConfigForType("node")
|
||||
cfg.Logging.Format = tt.format
|
||||
errs := cfg.Validate()
|
||||
if tt.shouldError && len(errs) == 0 {
|
||||
t.Errorf("expected error, got none")
|
||||
@ -285,12 +307,8 @@ func TestValidateMaxConnections(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Node: NodeConfig{Type: "node", ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: tt.maxConn},
|
||||
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: "localhost:5001"},
|
||||
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: 4001, NodeNamespace: "default"},
|
||||
Logging: LoggingConfig{Level: "info", Format: "console"},
|
||||
}
|
||||
cfg := validConfigForType("node")
|
||||
cfg.Node.MaxConnections = tt.maxConn
|
||||
errs := cfg.Validate()
|
||||
if tt.shouldError && len(errs) == 0 {
|
||||
t.Errorf("expected error, got none")
|
||||
@ -316,12 +334,8 @@ func TestValidateDiscoveryInterval(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Node: NodeConfig{Type: "node", ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
|
||||
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: "localhost:5001"},
|
||||
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: tt.interval, BootstrapPort: 4001, NodeNamespace: "default"},
|
||||
Logging: LoggingConfig{Level: "info", Format: "console"},
|
||||
}
|
||||
cfg := validConfigForType("node")
|
||||
cfg.Discovery.DiscoveryInterval = tt.interval
|
||||
errs := cfg.Validate()
|
||||
if tt.shouldError && len(errs) == 0 {
|
||||
t.Errorf("expected error, got none")
|
||||
@ -347,12 +361,8 @@ func TestValidateBootstrapPort(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Node: NodeConfig{Type: "node", ListenAddresses: []string{"/ip4/0.0.0.0/tcp/4001"}, DataDir: ".", MaxConnections: 50},
|
||||
Database: DatabaseConfig{DataDir: ".", ReplicationFactor: 3, ShardCount: 16, MaxDatabaseSize: 1024, BackupInterval: 1 * time.Hour, RQLitePort: 5001, RQLiteRaftPort: 7001, RQLiteJoinAddress: "localhost:5001"},
|
||||
Discovery: DiscoveryConfig{BootstrapPeers: []string{"/ip4/127.0.0.1/tcp/4001/p2p/12D3KooWHbcFcrGPXKUrHcxvd8MXEeUzRYyvY8fQcpEBxncSUwhj"}, DiscoveryInterval: 15 * time.Second, BootstrapPort: tt.port, NodeNamespace: "default"},
|
||||
Logging: LoggingConfig{Level: "info", Format: "console"},
|
||||
}
|
||||
cfg := validConfigForType("node")
|
||||
cfg.Discovery.BootstrapPort = tt.port
|
||||
errs := cfg.Validate()
|
||||
if tt.shouldError && len(errs) == 0 {
|
||||
t.Errorf("expected error, got none")
|
||||
@ -383,6 +393,7 @@ func TestValidateCompleteConfig(t *testing.T) {
|
||||
RQLitePort: 5002,
|
||||
RQLiteRaftPort: 7002,
|
||||
RQLiteJoinAddress: "127.0.0.1:7001",
|
||||
MinClusterSize: 1,
|
||||
},
|
||||
Discovery: DiscoveryConfig{
|
||||
BootstrapPeers: []string{
|
||||
@ -390,7 +401,8 @@ func TestValidateCompleteConfig(t *testing.T) {
|
||||
},
|
||||
DiscoveryInterval: 15 * time.Second,
|
||||
BootstrapPort: 4001,
|
||||
HttpAdvAddress: "127.0.0.1",
|
||||
HttpAdvAddress: "127.0.0.1:5001",
|
||||
RaftAdvAddress: "127.0.0.1:7001",
|
||||
NodeNamespace: "default",
|
||||
},
|
||||
Security: SecurityConfig{
|
||||
|
||||
@ -206,9 +206,31 @@ func isHopByHopHeader(header string) bool {
|
||||
|
||||
// isPrivateOrLocalHost checks if a host is private, local, or loopback
|
||||
func isPrivateOrLocalHost(host string) bool {
|
||||
// Strip port if present
|
||||
if idx := strings.LastIndex(host, ":"); idx != -1 {
|
||||
host = host[:idx]
|
||||
// Strip port if present, handling IPv6 addresses properly
|
||||
// IPv6 addresses in URLs are bracketed: [::1]:8080
|
||||
if strings.HasPrefix(host, "[") {
|
||||
// IPv6 address with brackets
|
||||
if idx := strings.LastIndex(host, "]"); idx != -1 {
|
||||
if idx+1 < len(host) && host[idx+1] == ':' {
|
||||
// Port present, strip it
|
||||
host = host[1:idx] // Remove brackets and port
|
||||
} else {
|
||||
// No port, just remove brackets
|
||||
host = host[1:idx]
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// IPv4 or hostname, check for port
|
||||
if idx := strings.LastIndex(host, ":"); idx != -1 {
|
||||
// Check if it's an IPv6 address without brackets (contains multiple colons)
|
||||
colonCount := strings.Count(host, ":")
|
||||
if colonCount == 1 {
|
||||
// Single colon, likely IPv4 with port
|
||||
host = host[:idx]
|
||||
}
|
||||
// If multiple colons, it's IPv6 without brackets and no port
|
||||
// Leave host as-is
|
||||
}
|
||||
}
|
||||
|
||||
// Check for localhost variants
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user