diff --git a/Makefile b/Makefile index a6874ce..3f9d161 100644 --- a/Makefile +++ b/Makefile @@ -63,7 +63,7 @@ test-e2e-quick: .PHONY: build clean test deps tidy fmt vet lint install-hooks redeploy-devnet redeploy-testnet release health -VERSION := 0.104.0 +VERSION := 0.105.0 COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo unknown) DATE ?= $(shell date -u +%Y-%m-%dT%H:%M:%SZ) LDFLAGS := -X 'main.version=$(VERSION)' -X 'main.commit=$(COMMIT)' -X 'main.date=$(DATE)' diff --git a/pkg/auth/phantom.go b/pkg/auth/phantom.go index 6e6a1e3..856245d 100644 --- a/pkg/auth/phantom.go +++ b/pkg/auth/phantom.go @@ -16,8 +16,17 @@ import ( qrterminal "github.com/mdp/qrterminal/v3" ) -// Hardcoded Phantom auth React app URL (deployed on Orama devnet) -const phantomAuthURL = "https://phantom-auth-y0w9aa.orama-devnet.network" +// defaultPhantomAuthURL is the default Phantom auth React app URL (deployed on Orama devnet). +// Override with ORAMA_PHANTOM_AUTH_URL environment variable. +const defaultPhantomAuthURL = "https://phantom-auth-y0w9aa.orama-devnet.network" + +// phantomAuthURL returns the Phantom auth URL, preferring the environment variable. +func phantomAuthURL() string { + if u := os.Getenv("ORAMA_PHANTOM_AUTH_URL"); u != "" { + return strings.TrimRight(u, "/") + } + return defaultPhantomAuthURL +} // PhantomSession represents a phantom auth session from the gateway. type PhantomSession struct { @@ -76,7 +85,7 @@ func PerformPhantomAuthentication(gatewayURL, namespace string) (*Credentials, e // 2. Build auth URL and display QR code authURL := fmt.Sprintf("%s/?session=%s&gateway=%s&namespace=%s", - phantomAuthURL, session.SessionID, url.QueryEscape(gatewayURL), url.QueryEscape(namespace)) + phantomAuthURL(), session.SessionID, url.QueryEscape(gatewayURL), url.QueryEscape(namespace)) fmt.Println("\nScan this QR code with your phone to authenticate:") fmt.Println() diff --git a/pkg/auth/simple_auth.go b/pkg/auth/simple_auth.go index 3b5f7b5..5e54fb3 100644 --- a/pkg/auth/simple_auth.go +++ b/pkg/auth/simple_auth.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "net/url" "os" "strings" "time" @@ -336,22 +337,15 @@ func retryAPIKeyRequest(gatewayURL string, client *http.Client, wallet, namespac return apiKey, nil } -// extractDomainFromURL extracts the domain from a URL -// Removes protocol (https://, http://), path, and port components -func extractDomainFromURL(url string) string { - // Remove protocol prefixes - url = strings.TrimPrefix(url, "https://") - url = strings.TrimPrefix(url, "http://") - - // Remove path component - if idx := strings.Index(url, "/"); idx != -1 { - url = url[:idx] +// extractDomainFromURL extracts the hostname from a URL, stripping scheme, port, and path. +func extractDomainFromURL(rawURL string) string { + // Ensure the URL has a scheme so net/url.Parse works correctly + if !strings.Contains(rawURL, "://") { + rawURL = "https://" + rawURL } - - // Remove port component - if idx := strings.Index(url, ":"); idx != -1 { - url = url[:idx] + u, err := url.Parse(rawURL) + if err != nil { + return "" } - - return url + return u.Hostname() } diff --git a/pkg/auth/wallet.go b/pkg/auth/wallet.go index 0a9344d..5be457c 100644 --- a/pkg/auth/wallet.go +++ b/pkg/auth/wallet.go @@ -168,7 +168,7 @@ func (as *AuthServer) handleCallback(w http.ResponseWriter, r *http.Request) { return } - // Send success response to browser + // Send success response to browser (API key is never exposed in HTML) w.Header().Set("Content-Type", "text/html") w.WriteHeader(http.StatusOK) fmt.Fprintf(w, ` @@ -181,30 +181,25 @@ func (as *AuthServer) handleCallback(w http.ResponseWriter, r *http.Request) { .container { background: white; padding: 30px; border-radius: 10px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); max-width: 500px; margin: 0 auto; } .success { color: #4CAF50; font-size: 48px; margin-bottom: 20px; } .details { background: #f8f9fa; padding: 20px; border-radius: 5px; margin: 20px 0; text-align: left; } - .key { font-family: monospace; background: #e9ecef; padding: 10px; border-radius: 3px; word-break: break-all; }
-
+

Authentication Successful!

You have successfully authenticated with your wallet.

-

🔑 Your Credentials:

-

API Key:

-
%s

Namespace: %s

Wallet: %s

%s
-

Your credentials have been saved securely to ~/.orama/credentials.json

-

You can now close this browser window and return to your terminal.

+

Your credentials have been saved securely. Return to your terminal to continue.

+

You can now close this browser window.

`, - result.APIKey, result.Namespace, result.Wallet, func() string { diff --git a/pkg/cli/production/invite/command.go b/pkg/cli/production/invite/command.go index 57ca730..d2240cd 100644 --- a/pkg/cli/production/invite/command.go +++ b/pkg/cli/production/invite/command.go @@ -1,12 +1,13 @@ package invite import ( + "bytes" "crypto/rand" "encoding/hex" + "encoding/json" "fmt" "net/http" "os" - "strings" "time" "gopkg.in/yaml.v3" @@ -89,12 +90,18 @@ func readNodeDomain() (string, error) { return config.Node.Domain, nil } -// insertToken inserts an invite token into RQLite via HTTP API +// insertToken inserts an invite token into RQLite via HTTP API using parameterized queries func insertToken(token, createdBy, expiresAt string) error { - body := fmt.Sprintf(`[["INSERT INTO invite_tokens (token, created_by, expires_at) VALUES ('%s', '%s', '%s')"]]`, - token, createdBy, expiresAt) + stmt := []interface{}{ + "INSERT INTO invite_tokens (token, created_by, expires_at) VALUES (?, ?, ?)", + token, createdBy, expiresAt, + } + payload, err := json.Marshal([]interface{}{stmt}) + if err != nil { + return fmt.Errorf("failed to marshal query: %w", err) + } - req, err := http.NewRequest("POST", "http://localhost:5001/db/execute", strings.NewReader(body)) + req, err := http.NewRequest("POST", "http://localhost:5001/db/execute", bytes.NewReader(payload)) if err != nil { return err } diff --git a/pkg/client/client.go b/pkg/client/client.go index 6b75f80..cdc92cc 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -195,7 +195,7 @@ func (c *Client) Connect() error { c.logger.Info("App namespace retrieved", zap.String("namespace", namespace)) c.logger.Info("Calling pubsub.NewClientAdapter...") - adapter := pubsub.NewClientAdapter(c.libp2pPS, namespace) + adapter := pubsub.NewClientAdapter(c.libp2pPS, namespace, c.logger) c.logger.Info("pubsub.NewClientAdapter completed successfully") c.logger.Info("Creating pubSubBridge...") diff --git a/pkg/client/network_client.go b/pkg/client/network_client.go index 029125e..ce3f1c6 100644 --- a/pkg/client/network_client.go +++ b/pkg/client/network_client.go @@ -28,7 +28,9 @@ func (n *NetworkInfoImpl) GetPeers(ctx context.Context) ([]PeerInfo, error) { } // Get peers from LibP2P host + n.client.mu.RLock() host := n.client.host + n.client.mu.RUnlock() if host == nil { return nil, fmt.Errorf("no host available") } @@ -87,7 +89,10 @@ func (n *NetworkInfoImpl) GetStatus(ctx context.Context) (*NetworkStatus, error) return nil, fmt.Errorf("authentication required: %w - run CLI commands to authenticate automatically", err) } + n.client.mu.RLock() host := n.client.host + dbClient := n.client.database + n.client.mu.RUnlock() if host == nil { return nil, fmt.Errorf("no host available") } @@ -97,7 +102,6 @@ func (n *NetworkInfoImpl) GetStatus(ctx context.Context) (*NetworkStatus, error) // Try to get database size from RQLite (optional - don't fail if unavailable) var dbSize int64 = 0 - dbClient := n.client.database if conn, err := dbClient.getRQLiteConnection(); err == nil { // Query database size (rough estimate) if result, err := conn.QueryOne("SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()"); err == nil { diff --git a/pkg/config/paths.go b/pkg/config/paths.go index 4335c77..d092770 100644 --- a/pkg/config/paths.go +++ b/pkg/config/paths.go @@ -4,8 +4,22 @@ import ( "fmt" "os" "path/filepath" + "strings" ) +// ExpandPath expands environment variables and ~ in a path. +func ExpandPath(path string) (string, error) { + path = os.ExpandEnv(path) + if strings.HasPrefix(path, "~") { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to determine home directory: %w", err) + } + path = filepath.Join(home, path[1:]) + } + return path, nil +} + // ConfigDir returns the path to the DeBros config directory (~/.orama). func ConfigDir() (string, error) { home, err := os.UserHomeDir() diff --git a/pkg/config/validate/database.go b/pkg/config/validate/database.go index b74e957..ade58ed 100644 --- a/pkg/config/validate/database.go +++ b/pkg/config/validate/database.go @@ -45,9 +45,11 @@ func ValidateDatabase(dc DatabaseConfig) []error { Message: fmt.Sprintf("must be >= 1; got %d", dc.ReplicationFactor), }) } else if dc.ReplicationFactor%2 == 0 { - // Warn about even replication factor (Raft best practice: odd) - // For now we log a note but don't error - _ = fmt.Sprintf("note: database.replication_factor %d is even; Raft recommends odd numbers for quorum", dc.ReplicationFactor) + errs = append(errs, ValidationError{ + Path: "database.replication_factor", + Message: fmt.Sprintf("value %d is even; Raft recommends odd numbers for quorum", dc.ReplicationFactor), + Hint: "use 1, 3, or 5 for proper Raft consensus", + }) } // Validate shard_count diff --git a/pkg/config/validate/validators.go b/pkg/config/validate/validators.go index 8195ec9..fbab893 100644 --- a/pkg/config/validate/validators.go +++ b/pkg/config/validate/validators.go @@ -34,7 +34,7 @@ func ValidateDataDir(path string) error { if strings.HasPrefix(expandedPath, "~") { home, err := os.UserHomeDir() if err != nil { - return fmt.Errorf("cannot determine home directory: %v", err) + return fmt.Errorf("cannot determine home directory: %w", err) } expandedPath = filepath.Join(home, expandedPath[1:]) } @@ -47,7 +47,7 @@ func ValidateDataDir(path string) error { // Try to write a test file to check permissions testFile := filepath.Join(expandedPath, ".write_test") if err := os.WriteFile(testFile, []byte(""), 0644); err != nil { - return fmt.Errorf("directory not writable: %v", err) + return fmt.Errorf("directory not writable: %w", err) } os.Remove(testFile) } else if os.IsNotExist(err) { @@ -59,7 +59,7 @@ func ValidateDataDir(path string) error { // Allow parent not existing - it will be created at runtime if info, err := os.Stat(parent); err != nil { if !os.IsNotExist(err) { - return fmt.Errorf("parent directory not accessible: %v", err) + return fmt.Errorf("parent directory not accessible: %w", err) } // Parent doesn't exist either - that's ok, will be created } else if !info.IsDir() { @@ -67,11 +67,11 @@ func ValidateDataDir(path string) error { } else { // Parent exists, check if writable if err := ValidateDirWritable(parent); err != nil { - return fmt.Errorf("parent directory not writable: %v", err) + return fmt.Errorf("parent directory not writable: %w", err) } } } else { - return fmt.Errorf("cannot access path: %v", err) + return fmt.Errorf("cannot access path: %w", err) } return nil @@ -81,7 +81,7 @@ func ValidateDataDir(path string) error { func ValidateDirWritable(path string) error { info, err := os.Stat(path) if err != nil { - return fmt.Errorf("cannot access directory: %v", err) + return fmt.Errorf("cannot access directory: %w", err) } if !info.IsDir() { return fmt.Errorf("path is not a directory") @@ -90,7 +90,7 @@ func ValidateDirWritable(path string) error { // Try to write a test file testFile := filepath.Join(path, ".write_test") if err := os.WriteFile(testFile, []byte(""), 0644); err != nil { - return fmt.Errorf("directory not writable: %v", err) + return fmt.Errorf("directory not writable: %w", err) } os.Remove(testFile) @@ -101,7 +101,7 @@ func ValidateDirWritable(path string) error { func ValidateFileReadable(path string) error { _, err := os.Stat(path) if err != nil { - return fmt.Errorf("cannot read file: %v", err) + return fmt.Errorf("cannot read file: %w", err) } return nil } diff --git a/pkg/deployments/health/checker.go b/pkg/deployments/health/checker.go index 69c0c0f..51cef1b 100644 --- a/pkg/deployments/health/checker.go +++ b/pkg/deployments/health/checker.go @@ -224,7 +224,12 @@ func (hc *HealthChecker) checkConsecutiveFailures(ctx context.Context, deploymen INSERT INTO deployment_events (deployment_id, event_type, message, created_at) VALUES (?, 'health_failed', 'Deployment marked as failed after 3 consecutive health check failures', ?) ` - hc.db.Exec(ctx, eventQuery, deploymentID, time.Now()) + if _, err := hc.db.Exec(ctx, eventQuery, deploymentID, time.Now()); err != nil { + hc.logger.Error("Failed to record health_failed event", + zap.String("deployment", deploymentID), + zap.Error(err), + ) + } } } } diff --git a/pkg/deployments/process/manager.go b/pkg/deployments/process/manager.go index aeceb9f..de9076d 100644 --- a/pkg/deployments/process/manager.go +++ b/pkg/deployments/process/manager.go @@ -194,9 +194,9 @@ func (m *Manager) Stop(ctx context.Context, deployment *deployments.Deployment) // stopDirect stops a directly spawned process func (m *Manager) stopDirect(serviceName string) error { m.processesMu.Lock() - cmd, exists := m.processes[serviceName] - m.processesMu.Unlock() + defer m.processesMu.Unlock() + cmd, exists := m.processes[serviceName] if !exists || cmd.Process == nil { return nil // Already stopped } @@ -511,11 +511,10 @@ func (m *Manager) GetStats(ctx context.Context, deployment *deployments.Deployme // Direct mode (macOS) — only disk, no /proc serviceName := m.getServiceName(deployment) m.processesMu.RLock() - cmd, exists := m.processes[serviceName] - m.processesMu.RUnlock() - if exists && cmd.Process != nil { + if cmd, exists := m.processes[serviceName]; exists && cmd.Process != nil { stats.PID = cmd.Process.Pid } + m.processesMu.RUnlock() return stats, nil } diff --git a/pkg/discovery/discovery.go b/pkg/discovery/discovery.go index 1d1ec60..e3e06b3 100644 --- a/pkg/discovery/discovery.go +++ b/pkg/discovery/discovery.go @@ -19,6 +19,42 @@ import ( // Protocol ID for peer exchange const PeerExchangeProtocol = "/debros/peer-exchange/1.0.0" +// libp2pPort is the standard port used for libp2p peer connections. +// Filtering on this port prevents cross-connecting with IPFS (4101) or IPFS Cluster (9096/9098). +const libp2pPort = 4001 + +// filterLibp2pAddrs returns only multiaddrs with TCP port 4001 (standard libp2p port). +func filterLibp2pAddrs(addrs []multiaddr.Multiaddr) []multiaddr.Multiaddr { + filtered := make([]multiaddr.Multiaddr, 0, len(addrs)) + for _, addr := range addrs { + port, err := addr.ValueForProtocol(multiaddr.P_TCP) + if err != nil { + continue + } + portNum, err := strconv.Atoi(port) + if err != nil || portNum != libp2pPort { + continue + } + filtered = append(filtered, addr) + } + return filtered +} + +// hasLibp2pAddr returns true if any of the peer's addresses use the standard libp2p port. +func hasLibp2pAddr(addrs []multiaddr.Multiaddr) bool { + for _, addr := range addrs { + port, err := addr.ValueForProtocol(multiaddr.P_TCP) + if err != nil { + continue + } + portNum, err := strconv.Atoi(port) + if err == nil && portNum == libp2pPort { + return true + } + } + return false +} + // PeerExchangeRequest represents a request for peer information type PeerExchangeRequest struct { Limit int `json:"limit"` @@ -116,38 +152,11 @@ func (d *Manager) handlePeerExchangeStream(s network.Stream) { continue } - // Filter addresses to only include port 4001 (standard libp2p port) - // This prevents including non-libp2p service ports (like RQLite ports) in peer exchange - const libp2pPort = 4001 - filteredAddrs := make([]multiaddr.Multiaddr, 0) - filteredCount := 0 - for _, addr := range addrs { - // Extract TCP port from multiaddr - port, err := addr.ValueForProtocol(multiaddr.P_TCP) - if err == nil { - portNum, err := strconv.Atoi(port) - if err == nil { - // Only include addresses with port 4001 - if portNum == libp2pPort { - filteredAddrs = append(filteredAddrs, addr) - } else { - filteredCount++ - } - } - // Skip addresses with unparseable ports - } else { - // Skip non-TCP addresses (libp2p uses TCP) - filteredCount++ - } - } - - // If no addresses remain after filtering, skip this peer - // (Filtering is routine - no need to log every occurrence) + filteredAddrs := filterLibp2pAddrs(addrs) if len(filteredAddrs) == 0 { continue } - // Convert addresses to strings addrStrs := make([]string, len(filteredAddrs)) for i, addr := range filteredAddrs { addrStrs[i] = addr.String() @@ -253,38 +262,20 @@ func (d *Manager) discoverViaPeerstore(ctx context.Context, maxConnections int) // Iterate over peerstore known peers peers := d.host.Peerstore().Peers() - // Only connect to peers on our standard LibP2P port to avoid cross-connecting - // with IPFS/IPFS Cluster instances that use different ports - const libp2pPort = 4001 - for _, pid := range peers { if connected >= maxConnections { break } - // Skip self if pid == d.host.ID() { continue } - // Skip already connected peers if d.host.Network().Connectedness(pid) != network.NotConnected { continue } - // Filter peers to only include those with addresses on our port (4001) - // This prevents attempting to connect to IPFS (port 4101) or IPFS Cluster (port 9096/9098) + // Only connect to peers with addresses on the standard libp2p port peerInfo := d.host.Peerstore().PeerInfo(pid) - hasValidPort := false - for _, addr := range peerInfo.Addrs { - if port, err := addr.ValueForProtocol(multiaddr.P_TCP); err == nil { - if portNum, err := strconv.Atoi(port); err == nil && portNum == libp2pPort { - hasValidPort = true - break - } - } - } - - // Skip peers without valid port 4001 addresses - if !hasValidPort { + if !hasLibp2pAddr(peerInfo.Addrs) { continue } @@ -356,28 +347,17 @@ func (d *Manager) discoverViaPeerExchange(ctx context.Context, maxConnections in } // Parse and filter addresses to only include port 4001 (standard libp2p port) - const libp2pPort = 4001 - addrs := make([]multiaddr.Multiaddr, 0, len(peerInfo.Addrs)) + parsedAddrs := make([]multiaddr.Multiaddr, 0, len(peerInfo.Addrs)) for _, addrStr := range peerInfo.Addrs { ma, err := multiaddr.NewMultiaddr(addrStr) if err != nil { d.logger.Debug("Failed to parse multiaddr", zap.Error(err)) continue } - // Only include addresses with port 4001 - port, err := ma.ValueForProtocol(multiaddr.P_TCP) - if err == nil { - portNum, err := strconv.Atoi(port) - if err == nil && portNum == libp2pPort { - addrs = append(addrs, ma) - } - // Skip addresses with wrong ports - } - // Skip non-TCP addresses + parsedAddrs = append(parsedAddrs, ma) } - + addrs := filterLibp2pAddrs(parsedAddrs) if len(addrs) == 0 { - // Skip peers without valid addresses - no need to log every occurrence continue } diff --git a/pkg/environments/production/installers/ipfs.go b/pkg/environments/production/installers/ipfs.go index d8fa906..0a0439e 100644 --- a/pkg/environments/production/installers/ipfs.go +++ b/pkg/environments/production/installers/ipfs.go @@ -96,7 +96,9 @@ func (ii *IPFSInstaller) Install() error { found = true // Ensure it's executable if info.Mode()&0111 == 0 { - os.Chmod(loc, 0755) + if err := os.Chmod(loc, 0755); err != nil { + return fmt.Errorf("failed to make ipfs executable at %s: %w", loc, err) + } } break } diff --git a/pkg/errors/http.go b/pkg/errors/http.go index d1b90eb..9223865 100644 --- a/pkg/errors/http.go +++ b/pkg/errors/http.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "net/http" + "strconv" ) // HTTPError represents an HTTP error response. @@ -211,7 +212,7 @@ func ToHTTPError(err error, traceID string) *HTTPError { } case errors.As(err, &rateLimitErr): if rateLimitErr.RetryAfter > 0 { - httpErr.Details["retry_after"] = string(rune(rateLimitErr.RetryAfter)) + httpErr.Details["retry_after"] = strconv.Itoa(rateLimitErr.RetryAfter) } case errors.As(err, &serviceErr): if serviceErr.Service != "" { @@ -234,7 +235,7 @@ func WriteHTTPError(w http.ResponseWriter, err error, traceID string) { // Add retry-after header for rate limit errors var rateLimitErr *RateLimitError if errors.As(err, &rateLimitErr) && rateLimitErr.RetryAfter > 0 { - w.Header().Set("Retry-After", string(rune(rateLimitErr.RetryAfter))) + w.Header().Set("Retry-After", strconv.Itoa(rateLimitErr.RetryAfter)) } // Add WWW-Authenticate header for unauthorized errors diff --git a/pkg/gateway/config_validate.go b/pkg/gateway/config_validate.go index baae7be..e4e086d 100644 --- a/pkg/gateway/config_validate.go +++ b/pkg/gateway/config_validate.go @@ -20,7 +20,7 @@ func (c *Config) ValidateConfig() []error { errs = append(errs, fmt.Errorf("gateway.listen_addr: must not be empty")) } else { if err := validateListenAddr(c.ListenAddr); err != nil { - errs = append(errs, fmt.Errorf("gateway.listen_addr: %v", err)) + errs = append(errs, fmt.Errorf("gateway.listen_addr: %w", err)) } } @@ -36,7 +36,7 @@ func (c *Config) ValidateConfig() []error { _, err := multiaddr.NewMultiaddr(peer) if err != nil { - errs = append(errs, fmt.Errorf("%s: invalid multiaddr: %v; expected /ip{4,6}/.../tcp//p2p/", path, err)) + errs = append(errs, fmt.Errorf("%s: invalid multiaddr: %w", path, err)) continue } @@ -66,7 +66,7 @@ func (c *Config) ValidateConfig() []error { // Validate rqlite_dsn if provided if c.RQLiteDSN != "" { if err := validateRQLiteDSN(c.RQLiteDSN); err != nil { - errs = append(errs, fmt.Errorf("gateway.rqlite_dsn: %v", err)) + errs = append(errs, fmt.Errorf("gateway.rqlite_dsn: %w", err)) } } @@ -116,7 +116,7 @@ func validateListenAddr(addr string) error { func validateRQLiteDSN(dsn string) error { u, err := url.Parse(dsn) if err != nil { - return fmt.Errorf("invalid URL: %v", err) + return fmt.Errorf("invalid URL: %w", err) } if u.Scheme != "http" && u.Scheme != "https" { diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index 9df7478..7194bb8 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -378,17 +378,18 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) { gw.processManager = process.NewManager(logger.Logger) // Create deployment service + baseDomain := gw.cfg.BaseDomain + if baseDomain == "" { + baseDomain = "dbrs.space" + } gw.deploymentService = deploymentshandlers.NewDeploymentService( deps.ORMClient, gw.homeNodeManager, gw.portAllocator, gw.replicaManager, logger.Logger, + baseDomain, ) - // Set base domain from config - if gw.cfg.BaseDomain != "" { - gw.deploymentService.SetBaseDomain(gw.cfg.BaseDomain) - } // Set node peer ID so deployments run on the node that receives the request if gw.cfg.NodePeerID != "" { gw.deploymentService.SetNodePeerID(gw.cfg.NodePeerID) diff --git a/pkg/gateway/handlers/deployments/service.go b/pkg/gateway/handlers/deployments/service.go index 0388cf9..38af662 100644 --- a/pkg/gateway/handlers/deployments/service.go +++ b/pkg/gateway/handlers/deployments/service.go @@ -34,13 +34,15 @@ type DeploymentService struct { nodePeerID string // Current node's peer ID (deployments run on this node) } -// NewDeploymentService creates a new deployment service +// NewDeploymentService creates a new deployment service. +// baseDomain is required and sets the domain used for deployment URLs (e.g., "dbrs.space"). func NewDeploymentService( db rqlite.Client, homeNodeManager *deployments.HomeNodeManager, portAllocator *deployments.PortAllocator, replicaManager *deployments.ReplicaManager, logger *zap.Logger, + baseDomain string, ) *DeploymentService { return &DeploymentService{ db: db, @@ -48,7 +50,7 @@ func NewDeploymentService( portAllocator: portAllocator, replicaManager: replicaManager, logger: logger, - baseDomain: "dbrs.space", // default + baseDomain: baseDomain, } } @@ -65,11 +67,8 @@ func (s *DeploymentService) SetNodePeerID(peerID string) { s.nodePeerID = peerID } -// BaseDomain returns the configured base domain +// BaseDomain returns the configured base domain. func (s *DeploymentService) BaseDomain() string { - if s.baseDomain == "" { - return "dbrs.space" - } return s.baseDomain } diff --git a/pkg/gateway/instance_spawner.go b/pkg/gateway/instance_spawner.go index 6d0563e..77eea8f 100644 --- a/pkg/gateway/instance_spawner.go +++ b/pkg/gateway/instance_spawner.go @@ -63,11 +63,14 @@ type GatewayInstance struct { OlricServers []string // Connection to namespace Olric ConfigPath string PID int - Status InstanceNodeStatus StartedAt time.Time - LastHealthCheck time.Time cmd *exec.Cmd logger *zap.Logger + + // mu protects mutable state accessed concurrently by the monitor goroutine. + mu sync.RWMutex + Status InstanceNodeStatus + LastHealthCheck time.Time } // InstanceConfig holds configuration for spawning a Gateway instance @@ -130,13 +133,14 @@ func (is *InstanceSpawner) SpawnInstance(ctx context.Context, cfg InstanceConfig is.mu.Lock() if existing, ok := is.instances[key]; ok { - is.mu.Unlock() - // Instance already exists, return it if running - if existing.Status == InstanceStatusRunning { + existing.mu.RLock() + status := existing.Status + existing.mu.RUnlock() + if status == InstanceStatusRunning { + is.mu.Unlock() return existing, nil } // Otherwise, remove it and start fresh - is.mu.Lock() delete(is.instances, key) } is.mu.Unlock() @@ -261,8 +265,10 @@ func (is *InstanceSpawner) SpawnInstance(ctx context.Context, cfg InstanceConfig } } + instance.mu.Lock() instance.Status = InstanceStatusRunning instance.LastHealthCheck = time.Now() + instance.mu.Unlock() instance.logger.Info("Gateway instance started successfully", zap.Int("pid", instance.PID), @@ -356,7 +362,9 @@ func (is *InstanceSpawner) StopInstance(ctx context.Context, ns, nodeID string) } } + instance.mu.Lock() instance.Status = InstanceStatusStopped + instance.mu.Unlock() return nil } @@ -415,9 +423,9 @@ func (is *InstanceSpawner) HealthCheck(ctx context.Context, ns, nodeID string) ( healthy, err := instance.IsHealthy(ctx) if healthy { - is.mu.Lock() + instance.mu.Lock() instance.LastHealthCheck = time.Now() - is.mu.Unlock() + instance.mu.Unlock() } return healthy, err } @@ -474,7 +482,7 @@ func (is *InstanceSpawner) monitorInstance(instance *GatewayInstance) { healthy, _ := instance.IsHealthy(ctx) cancel() - is.mu.Lock() + instance.mu.Lock() if healthy { instance.Status = InstanceStatusRunning instance.LastHealthCheck = time.Now() @@ -482,13 +490,13 @@ func (is *InstanceSpawner) monitorInstance(instance *GatewayInstance) { instance.Status = InstanceStatusFailed instance.logger.Warn("Gateway instance health check failed") } - is.mu.Unlock() + instance.mu.Unlock() // Check if process is still running if instance.cmd != nil && instance.cmd.ProcessState != nil && instance.cmd.ProcessState.Exited() { - is.mu.Lock() + instance.mu.Lock() instance.Status = InstanceStatusStopped - is.mu.Unlock() + instance.mu.Unlock() instance.logger.Warn("Gateway instance process exited unexpectedly") return } diff --git a/pkg/gateway/lifecycle.go b/pkg/gateway/lifecycle.go index fd2ec4d..049336d 100644 --- a/pkg/gateway/lifecycle.go +++ b/pkg/gateway/lifecycle.go @@ -50,4 +50,12 @@ func (g *Gateway) Close() { g.logger.ComponentWarn(logging.ComponentGeneral, "error during IPFS client close", zap.Error(err)) } } + + // Stop background goroutines + if g.mwCache != nil { + g.mwCache.Stop() + } + if g.rateLimiter != nil { + g.rateLimiter.Stop() + } } diff --git a/pkg/gateway/middleware.go b/pkg/gateway/middleware.go index c4ff731..ba2b386 100644 --- a/pkg/gateway/middleware.go +++ b/pkg/gateway/middleware.go @@ -2,7 +2,7 @@ package gateway import ( "context" - "encoding/json" + "fmt" "hash/fnv" "io" "net" @@ -64,41 +64,40 @@ func (g *Gateway) validateAuthForNamespaceProxy(r *http.Request) (namespace stri return "", "" // No credentials provided } - // Check middleware cache first + ns, err := g.lookupAPIKeyNamespace(r.Context(), key, g.client) + if err != nil { + return "", "invalid API key" + } + return ns, "" +} + +// lookupAPIKeyNamespace resolves an API key to its namespace using cache and DB. +// dbClient controls which database is queried (global vs namespace-specific). +// Returns the namespace name or an error if the key is invalid. +func (g *Gateway) lookupAPIKeyNamespace(ctx context.Context, key string, dbClient client.NetworkClient) (string, error) { if g.mwCache != nil { if cachedNS, ok := g.mwCache.GetAPIKeyNamespace(key); ok { - return cachedNS, "" + return cachedNS, nil } } - // Cache miss — look up API key in main cluster RQLite - db := g.client.Database() - internalCtx := client.WithInternalAuth(r.Context()) + db := dbClient.Database() + internalCtx := client.WithInternalAuth(ctx) q := "SELECT namespaces.name FROM api_keys JOIN namespaces ON api_keys.namespace_id = namespaces.id WHERE api_keys.key = ? LIMIT 1" res, err := db.Query(internalCtx, q, key) if err != nil || res == nil || res.Count == 0 || len(res.Rows) == 0 || len(res.Rows[0]) == 0 { - return "", "invalid API key" + return "", fmt.Errorf("invalid API key") } - // Extract namespace name - var ns string - if s, ok := res.Rows[0][0].(string); ok { - ns = strings.TrimSpace(s) - } else { - b, _ := json.Marshal(res.Rows[0][0]) - _ = json.Unmarshal(b, &ns) - ns = strings.TrimSpace(ns) - } + ns := getString(res.Rows[0][0]) if ns == "" { - return "", "invalid API key" + return "", fmt.Errorf("invalid API key") } - // Cache the result if g.mwCache != nil { g.mwCache.SetAPIKeyNamespace(key, ns) } - - return ns, "" + return ns, nil } // isWebSocketUpgrade checks if the request is a WebSocket upgrade request @@ -179,7 +178,7 @@ func (g *Gateway) proxyWebSocket(w http.ResponseWriter, r *http.Request, targetH // withMiddleware adds CORS, security headers, rate limiting, and logging middleware func (g *Gateway) withMiddleware(next http.Handler) http.Handler { - // Order: logging -> security headers -> rate limit -> CORS -> domain routing -> auth -> namespace rate limit -> handler + // Order: logging -> security headers -> rate limit -> CORS -> domain routing -> auth -> authorization -> namespace rate limit -> handler return g.loggingMiddleware( g.securityHeadersMiddleware( g.rateLimitMiddleware( @@ -309,30 +308,13 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler { return } - // Check middleware cache first for API key → namespace mapping - if g.mwCache != nil { - if cachedNS, ok := g.mwCache.GetAPIKeyNamespace(key); ok { - reqCtx := context.WithValue(r.Context(), ctxKeyAPIKey, key) - reqCtx = context.WithValue(reqCtx, CtxKeyNamespaceOverride, cachedNS) - next.ServeHTTP(w, r.WithContext(reqCtx)) - return - } - } - - // Cache miss — look up API key in DB and derive namespace - // Use authClient for namespace gateways (validates against global RQLite) - // Otherwise use regular client for global gateways - authClient := g.client + // Look up API key → namespace (uses cache + DB) + dbClient := g.client if g.authClient != nil { - authClient = g.authClient + dbClient = g.authClient } - db := authClient.Database() - // Use internal auth for DB validation (auth not established yet) - internalCtx := client.WithInternalAuth(r.Context()) - // Join to namespaces to resolve name in one query - q := "SELECT namespaces.name FROM api_keys JOIN namespaces ON api_keys.namespace_id = namespaces.id WHERE api_keys.key = ? LIMIT 1" - res, err := db.Query(internalCtx, q, key) - if err != nil || res == nil || res.Count == 0 || len(res.Rows) == 0 || len(res.Rows[0]) == 0 { + ns, err := g.lookupAPIKeyNamespace(r.Context(), key, dbClient) + if err != nil { if isPublic { next.ServeHTTP(w, r) return @@ -341,29 +323,6 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler { writeError(w, http.StatusUnauthorized, "invalid API key") return } - // Extract namespace name - var ns string - if s, ok := res.Rows[0][0].(string); ok { - ns = strings.TrimSpace(s) - } else { - b, _ := json.Marshal(res.Rows[0][0]) - _ = json.Unmarshal(b, &ns) - ns = strings.TrimSpace(ns) - } - if ns == "" { - if isPublic { - next.ServeHTTP(w, r) - return - } - w.Header().Set("WWW-Authenticate", "Bearer error=\"invalid_token\"") - writeError(w, http.StatusUnauthorized, "invalid API key") - return - } - - // Cache the result for subsequent requests - if g.mwCache != nil { - g.mwCache.SetAPIKeyNamespace(key, ns) - } // Attach auth metadata to context for downstream use reqCtx := context.WithValue(r.Context(), ctxKeyAPIKey, key) diff --git a/pkg/gateway/middleware_cache.go b/pkg/gateway/middleware_cache.go index fab8bcb..7c51a76 100644 --- a/pkg/gateway/middleware_cache.go +++ b/pkg/gateway/middleware_cache.go @@ -20,7 +20,8 @@ type middlewareCache struct { nsTargets map[string]*cachedGatewayTargets nsTargetsMu sync.RWMutex - ttl time.Duration + ttl time.Duration + stopCh chan struct{} } type cachedValue struct { @@ -43,11 +44,17 @@ func newMiddlewareCache(ttl time.Duration) *middlewareCache { apiKeyNS: make(map[string]*cachedValue), nsTargets: make(map[string]*cachedGatewayTargets), ttl: ttl, + stopCh: make(chan struct{}), } go mc.cleanup() return mc } +// Stop stops the background cleanup goroutine. +func (mc *middlewareCache) Stop() { + close(mc.stopCh) +} + // GetAPIKeyNamespace returns the cached namespace for an API key, or "" if not cached/expired. func (mc *middlewareCache) GetAPIKeyNamespace(apiKey string) (string, bool) { mc.apiKeyNSMu.RLock() @@ -99,23 +106,28 @@ func (mc *middlewareCache) cleanup() { ticker := time.NewTicker(2 * time.Minute) defer ticker.Stop() - for range ticker.C { - now := time.Now() + for { + select { + case <-ticker.C: + now := time.Now() - mc.apiKeyNSMu.Lock() - for k, v := range mc.apiKeyNS { - if now.After(v.expiresAt) { - delete(mc.apiKeyNS, k) + mc.apiKeyNSMu.Lock() + for k, v := range mc.apiKeyNS { + if now.After(v.expiresAt) { + delete(mc.apiKeyNS, k) + } } - } - mc.apiKeyNSMu.Unlock() + mc.apiKeyNSMu.Unlock() - mc.nsTargetsMu.Lock() - for k, v := range mc.nsTargets { - if now.After(v.expiresAt) { - delete(mc.nsTargets, k) + mc.nsTargetsMu.Lock() + for k, v := range mc.nsTargets { + if now.After(v.expiresAt) { + delete(mc.nsTargets, k) + } } + mc.nsTargetsMu.Unlock() + case <-mc.stopCh: + return } - mc.nsTargetsMu.Unlock() } } diff --git a/pkg/gateway/rate_limiter.go b/pkg/gateway/rate_limiter.go index d080602..8d05568 100644 --- a/pkg/gateway/rate_limiter.go +++ b/pkg/gateway/rate_limiter.go @@ -8,12 +8,20 @@ import ( "time" ) +// wireGuardNet is the WireGuard mesh subnet, parsed once at init. +var wireGuardNet *net.IPNet + +func init() { + _, wireGuardNet, _ = net.ParseCIDR("10.0.0.0/8") +} + // RateLimiter implements a token-bucket rate limiter per client IP. type RateLimiter struct { mu sync.Mutex clients map[string]*bucket rate float64 // tokens per second burst int // max tokens (burst capacity) + stopCh chan struct{} } type bucket struct { @@ -71,17 +79,30 @@ func (rl *RateLimiter) Cleanup(maxAge time.Duration) { } } -// StartCleanup runs periodic cleanup in a goroutine. +// StartCleanup runs periodic cleanup in a goroutine. Call Stop() to terminate it. func (rl *RateLimiter) StartCleanup(interval, maxAge time.Duration) { + rl.stopCh = make(chan struct{}) go func() { ticker := time.NewTicker(interval) defer ticker.Stop() - for range ticker.C { - rl.Cleanup(maxAge) + for { + select { + case <-ticker.C: + rl.Cleanup(maxAge) + case <-rl.stopCh: + return + } } }() } +// Stop terminates the background cleanup goroutine. +func (rl *RateLimiter) Stop() { + if rl.stopCh != nil { + close(rl.stopCh) + } +} + // NamespaceRateLimiter provides per-namespace rate limiting using a sync.Map // for better concurrent performance than a single mutex. type NamespaceRateLimiter struct { @@ -167,6 +188,5 @@ func isInternalIP(ipStr string) bool { return true } // 10.0.0.0/8 — WireGuard mesh - _, wgNet, _ := net.ParseCIDR("10.0.0.0/8") - return wgNet.Contains(ip) + return wireGuardNet.Contains(ip) } diff --git a/pkg/gateway/request_log_batcher.go b/pkg/gateway/request_log_batcher.go index 5c12382..9ac00e6 100644 --- a/pkg/gateway/request_log_batcher.go +++ b/pkg/gateway/request_log_batcher.go @@ -158,7 +158,9 @@ func (b *requestLogBatcher) flush() { args = append(args, e.method, e.path, e.statusCode, e.bytesOut, e.durationMs, e.ip, apiKeyID) } - _, _ = db.Query(client.WithInternalAuth(ctx), sb.String(), args...) + if _, err := db.Query(client.WithInternalAuth(ctx), sb.String(), args...); err != nil && b.gw.logger != nil { + b.gw.logger.ComponentWarn(logging.ComponentGeneral, "failed to flush request logs", zap.Error(err)) + } } // Batch UPDATE last_used_at for all API keys seen in this batch @@ -171,7 +173,9 @@ func (b *requestLogBatcher) flush() { } q := fmt.Sprintf("UPDATE api_keys SET last_used_at = CURRENT_TIMESTAMP WHERE id IN (%s)", strings.Join(ids, ",")) - _, _ = db.Query(client.WithInternalAuth(ctx), q, args...) + if _, err := db.Query(client.WithInternalAuth(ctx), q, args...); err != nil && b.gw.logger != nil { + b.gw.logger.ComponentWarn(logging.ComponentGeneral, "failed to update api key last_used_at", zap.Error(err)) + } } if b.gw.logger != nil { diff --git a/pkg/httputil/validation.go b/pkg/httputil/validation.go index d99baca..b81df5c 100644 --- a/pkg/httputil/validation.go +++ b/pkg/httputil/validation.go @@ -46,12 +46,14 @@ func ValidateTopicName(topic string) bool { return topicRegex.MatchString(topic) } -// ValidateWalletAddress checks if a string looks like an Ethereum wallet address. -// Valid addresses are 40 hex characters, optionally prefixed with "0x". -var walletRegex = regexp.MustCompile(`^(0x)?[0-9a-fA-F]{40}$`) +// ValidateWalletAddress checks if a string looks like a valid wallet address. +// Supports Ethereum (40 hex chars, optional "0x" prefix) and Solana (32-44 base58 chars). +var ethWalletRegex = regexp.MustCompile(`^(0x)?[0-9a-fA-F]{40}$`) +var solanaWalletRegex = regexp.MustCompile(`^[1-9A-HJ-NP-Za-km-z]{32,44}$`) func ValidateWalletAddress(wallet string) bool { - return walletRegex.MatchString(strings.TrimSpace(wallet)) + wallet = strings.TrimSpace(wallet) + return ethWalletRegex.MatchString(wallet) || solanaWalletRegex.MatchString(wallet) } // NormalizeWalletAddress normalizes a wallet address by removing "0x" prefix and converting to lowercase. diff --git a/pkg/httputil/validation_test.go b/pkg/httputil/validation_test.go index 7c40be0..338f037 100644 --- a/pkg/httputil/validation_test.go +++ b/pkg/httputil/validation_test.go @@ -174,6 +174,21 @@ func TestValidateWalletAddress(t *testing.T) { wallet: "", valid: false, }, + { + name: "valid Solana address", + wallet: "7EcDhSYGxXyscszYEp35KHN8vvw3svAuLKTzXwCFLtV", + valid: true, + }, + { + name: "valid Solana address 44 chars", + wallet: "DRpbCBMxVnDK7maPMoGQfFiDro5Z4Ztgcyih2yZbpaHY", + valid: true, + }, + { + name: "invalid Solana - too short", + wallet: "7EcDhSYGx", + valid: false, + }, } for _, tt := range tests { diff --git a/pkg/node/libp2p.go b/pkg/node/libp2p.go index c19151c..00119b4 100644 --- a/pkg/node/libp2p.go +++ b/pkg/node/libp2p.go @@ -84,7 +84,7 @@ func (n *Node) startLibP2P() error { } // Create pubsub adapter - n.pubsub = pubsub.NewClientAdapter(ps, n.config.Discovery.NodeNamespace) + n.pubsub = pubsub.NewClientAdapter(ps, n.config.Discovery.NodeNamespace, n.logger.Logger) n.logger.Info("Initialized pubsub adapter on namespace", zap.String("namespace", n.config.Discovery.NodeNamespace)) // Connect to peers diff --git a/pkg/node/node.go b/pkg/node/node.go index c4b8438..2686284 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -5,8 +5,6 @@ import ( "fmt" "net/http" "os" - "path/filepath" - "strings" "time" "github.com/DeBrosOfficial/network/pkg/config" @@ -65,15 +63,10 @@ func NewNode(cfg *config.Config) (*Node, error) { func (n *Node) Start(ctx context.Context) error { n.logger.Info("Starting network node", zap.String("data_dir", n.config.Node.DataDir)) - // Expand ~ in data directory path - dataDir := n.config.Node.DataDir - dataDir = os.ExpandEnv(dataDir) - if strings.HasPrefix(dataDir, "~") { - home, err := os.UserHomeDir() - if err != nil { - return fmt.Errorf("failed to determine home directory: %w", err) - } - dataDir = filepath.Join(home, dataDir[1:]) + // Expand ~ and env vars in data directory path + dataDir, err := config.ExpandPath(n.config.Node.DataDir) + if err != nil { + return fmt.Errorf("failed to expand data directory path: %w", err) } // Create data directory diff --git a/pkg/node/utils.go b/pkg/node/utils.go index d9d366c..b4577f8 100644 --- a/pkg/node/utils.go +++ b/pkg/node/utils.go @@ -9,9 +9,9 @@ import ( "net" "os" "path/filepath" - "strings" "time" + "github.com/DeBrosOfficial/network/pkg/config" "github.com/DeBrosOfficial/network/pkg/encryption" "github.com/multiformats/go-multiaddr" ) @@ -74,11 +74,11 @@ func addJitter(interval time.Duration) time.Duration { } func loadNodePeerIDFromIdentity(dataDir string) string { - identityFile := filepath.Join(os.ExpandEnv(dataDir), "identity.key") - if strings.HasPrefix(identityFile, "~") { - home, _ := os.UserHomeDir() - identityFile = filepath.Join(home, identityFile[1:]) + expanded, err := config.ExpandPath(dataDir) + if err != nil { + return "" } + identityFile := filepath.Join(expanded, "identity.key") if info, err := encryption.LoadIdentity(identityFile); err == nil { return info.PeerID.String() @@ -98,7 +98,9 @@ func extractPEMFromTLSCert(tlsCert *tls.Certificate, certPath, keyPath string) e defer certFile.Close() for _, certBytes := range tlsCert.Certificate { - pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}) + if err := pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}); err != nil { + return fmt.Errorf("failed to encode certificate PEM: %w", err) + } } if tlsCert.PrivateKey == nil { @@ -111,17 +113,20 @@ func extractPEMFromTLSCert(tlsCert *tls.Certificate, certPath, keyPath string) e } defer keyFile.Close() - var keyBytes []byte - switch key := tlsCert.PrivateKey.(type) { - case *x509.Certificate: - keyBytes, _ = x509.MarshalPKCS8PrivateKey(key) - default: - keyBytes, _ = x509.MarshalPKCS8PrivateKey(tlsCert.PrivateKey) + keyBytes, err := x509.MarshalPKCS8PrivateKey(tlsCert.PrivateKey) + if err != nil { + return fmt.Errorf("failed to marshal private key: %w", err) } - pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes}) - os.Chmod(certPath, 0644) - os.Chmod(keyPath, 0600) + if err := pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes}); err != nil { + return fmt.Errorf("failed to encode private key PEM: %w", err) + } + if err := os.Chmod(certPath, 0644); err != nil { + return fmt.Errorf("failed to set certificate permissions: %w", err) + } + if err := os.Chmod(keyPath, 0600); err != nil { + return fmt.Errorf("failed to set private key permissions: %w", err) + } return nil } diff --git a/pkg/olric/instance_spawner.go b/pkg/olric/instance_spawner.go index fa02eda..6510bf3 100644 --- a/pkg/olric/instance_spawner.go +++ b/pkg/olric/instance_spawner.go @@ -64,13 +64,17 @@ type OlricInstance struct { ConfigPath string DataDir string PID int - Status InstanceNodeStatus StartedAt time.Time - LastHealthCheck time.Time cmd *exec.Cmd logFile *os.File // kept open for process lifetime waitDone chan struct{} // closed when cmd.Wait() completes logger *zap.Logger + + // mu protects mutable state (Status, LastHealthCheck) accessed concurrently + // by the monitor goroutine and external callers. + mu sync.RWMutex + Status InstanceNodeStatus + LastHealthCheck time.Time } // InstanceConfig holds configuration for spawning an Olric instance @@ -130,7 +134,10 @@ func (is *InstanceSpawner) SpawnInstance(ctx context.Context, cfg InstanceConfig is.mu.Lock() if existing, ok := is.instances[key]; ok { - if existing.Status == InstanceStatusRunning || existing.Status == InstanceStatusStarting { + existing.mu.RLock() + status := existing.Status + existing.mu.RUnlock() + if status == InstanceStatusRunning || status == InstanceStatusStarting { is.mu.Unlock() return existing, nil } @@ -243,8 +250,10 @@ func (is *InstanceSpawner) SpawnInstance(ctx context.Context, cfg InstanceConfig } } + instance.mu.Lock() instance.Status = InstanceStatusRunning instance.LastHealthCheck = time.Now() + instance.mu.Unlock() instance.logger.Info("Olric instance started successfully", zap.Int("pid", instance.PID), @@ -331,7 +340,9 @@ func (is *InstanceSpawner) StopInstance(ctx context.Context, ns, nodeID string) } } + instance.mu.Lock() instance.Status = InstanceStatusStopped + instance.mu.Unlock() return nil } @@ -390,9 +401,9 @@ func (is *InstanceSpawner) HealthCheck(ctx context.Context, ns, nodeID string) ( healthy, err := instance.IsHealthy(ctx) if healthy { - is.mu.Lock() + instance.mu.Lock() instance.LastHealthCheck = time.Now() - is.mu.Unlock() + instance.mu.Unlock() } return healthy, err } @@ -450,13 +461,16 @@ func (is *InstanceSpawner) monitorInstance(instance *OlricInstance) { select { case <-instance.waitDone: // Process exited — update status and stop monitoring - is.mu.Lock() + is.mu.RLock() key := instanceKey(instance.Namespace, instance.NodeID) - if _, exists := is.instances[key]; exists { + _, exists := is.instances[key] + is.mu.RUnlock() + if exists { + instance.mu.Lock() instance.Status = InstanceStatusStopped + instance.mu.Unlock() instance.logger.Warn("Olric instance process exited unexpectedly") } - is.mu.Unlock() return case <-ticker.C: } @@ -474,7 +488,7 @@ func (is *InstanceSpawner) monitorInstance(instance *OlricInstance) { healthy, _ := instance.IsHealthy(ctx) cancel() - is.mu.Lock() + instance.mu.Lock() if healthy { instance.Status = InstanceStatusRunning instance.LastHealthCheck = time.Now() @@ -482,7 +496,7 @@ func (is *InstanceSpawner) monitorInstance(instance *OlricInstance) { instance.Status = InstanceStatusFailed instance.logger.Warn("Olric instance health check failed") } - is.mu.Unlock() + instance.mu.Unlock() } } diff --git a/pkg/pubsub/adapter.go b/pkg/pubsub/adapter.go index 51e0893..de8f4c5 100644 --- a/pkg/pubsub/adapter.go +++ b/pkg/pubsub/adapter.go @@ -4,6 +4,7 @@ import ( "context" pubsub "github.com/libp2p/go-libp2p-pubsub" + "go.uber.org/zap" ) // ClientAdapter adapts the pubsub Manager to work with the existing client interface @@ -12,9 +13,9 @@ type ClientAdapter struct { } // NewClientAdapter creates a new adapter for the pubsub manager -func NewClientAdapter(ps *pubsub.PubSub, namespace string) *ClientAdapter { +func NewClientAdapter(ps *pubsub.PubSub, namespace string, logger *zap.Logger) *ClientAdapter { return &ClientAdapter{ - manager: NewManager(ps, namespace), + manager: NewManager(ps, namespace, logger), } } diff --git a/pkg/pubsub/discovery_integration.go b/pkg/pubsub/discovery_integration.go index 4016a63..caec4c1 100644 --- a/pkg/pubsub/discovery_integration.go +++ b/pkg/pubsub/discovery_integration.go @@ -2,10 +2,10 @@ package pubsub import ( "context" - "log" "time" pubsub "github.com/libp2p/go-libp2p-pubsub" + "go.uber.org/zap" ) // announceTopicInterest helps with peer discovery by announcing interest in a topic. @@ -34,18 +34,22 @@ func (m *Manager) announceTopicInterest(topicName string) { // forceTopicPeerDiscovery uses a simple strategy to announce presence on the topic. // It publishes lightweight discovery pings continuously to maintain mesh health. func (m *Manager) forceTopicPeerDiscovery(topicName string, topic *pubsub.Topic) { - log.Printf("[PUBSUB] Starting continuous peer discovery for topic: %s", topicName) - + m.logger.Debug("Starting continuous peer discovery", zap.String("topic", topicName)) + // Initial aggressive discovery phase (10 attempts) for attempt := 0; attempt < 10; attempt++ { peers := topic.ListPeers() if len(peers) > 0 { - log.Printf("[PUBSUB] Topic %s: Found %d peers in initial discovery", topicName, len(peers)) + m.logger.Debug("Found peers in initial discovery", + zap.String("topic", topicName), + zap.Int("peers", len(peers))) break } - log.Printf("[PUBSUB] Topic %s: Initial attempt %d, sending discovery ping", topicName, attempt+1) - + m.logger.Debug("Sending discovery ping", + zap.String("topic", topicName), + zap.Int("attempt", attempt+1)) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) discoveryMsg := []byte("PEER_DISCOVERY_PING") _ = topic.Publish(ctx, discoveryMsg) @@ -57,25 +61,25 @@ func (m *Manager) forceTopicPeerDiscovery(topicName string, topic *pubsub.Topic) } time.Sleep(delay) } - + // Continuous maintenance phase - keep pinging every 15 seconds ticker := time.NewTicker(15 * time.Second) defer ticker.Stop() - + for i := 0; i < 20; i++ { // Run for ~5 minutes total <-ticker.C peers := topic.ListPeers() - + if len(peers) == 0 { - log.Printf("[PUBSUB] Topic %s: No peers, sending maintenance ping", topicName) + m.logger.Debug("No peers, sending maintenance ping", zap.String("topic", topicName)) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) discoveryMsg := []byte("PEER_DISCOVERY_PING") _ = topic.Publish(ctx, discoveryMsg) cancel() } } - - log.Printf("[PUBSUB] Topic %s: Peer discovery maintenance completed", topicName) + + m.logger.Debug("Peer discovery maintenance completed", zap.String("topic", topicName)) } // monitorTopicPeers periodically checks topic peer connectivity and stops once peers are found. diff --git a/pkg/pubsub/manager.go b/pkg/pubsub/manager.go index 4c481e8..6f5a92e 100644 --- a/pkg/pubsub/manager.go +++ b/pkg/pubsub/manager.go @@ -6,6 +6,7 @@ import ( "sync" pubsub "github.com/libp2p/go-libp2p-pubsub" + "go.uber.org/zap" ) // Manager handles pub/sub operations @@ -14,6 +15,7 @@ type Manager struct { topics map[string]*pubsub.Topic subscriptions map[string]*topicSubscription namespace string + logger *zap.Logger mu sync.RWMutex } @@ -27,12 +29,13 @@ type topicSubscription struct { } // NewManager creates a new pubsub manager -func NewManager(ps *pubsub.PubSub, namespace string) *Manager { - return &Manager { +func NewManager(ps *pubsub.PubSub, namespace string, logger *zap.Logger) *Manager { + return &Manager{ pubsub: ps, topics: make(map[string]*pubsub.Topic), subscriptions: make(map[string]*topicSubscription), namespace: namespace, + logger: logger.Named("pubsub"), } } diff --git a/pkg/pubsub/manager_test.go b/pkg/pubsub/manager_test.go index 612297d..f7014f1 100644 --- a/pkg/pubsub/manager_test.go +++ b/pkg/pubsub/manager_test.go @@ -8,6 +8,7 @@ import ( "github.com/libp2p/go-libp2p" pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/libp2p/go-libp2p/core/peer" + "go.uber.org/zap" ) func createTestManager(t *testing.T, ns string) (*Manager, func()) { @@ -24,7 +25,7 @@ func createTestManager(t *testing.T, ns string) (*Manager, func()) { t.Fatalf("failed to create gossipsub: %v", err) } - mgr := NewManager(ps, ns) + mgr := NewManager(ps, ns, zap.NewNop()) cleanup := func() { mgr.Close() @@ -165,13 +166,13 @@ func TestManager_PubSub(t *testing.T) { h1, _ := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0")) ps1, _ := pubsub.NewGossipSub(ctx, h1) - mgr1 := NewManager(ps1, "test") + mgr1 := NewManager(ps1, "test", zap.NewNop()) defer h1.Close() defer mgr1.Close() h2, _ := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0")) ps2, _ := pubsub.NewGossipSub(ctx, h2) - mgr2 := NewManager(ps2, "test") + mgr2 := NewManager(ps2, "test", zap.NewNop()) defer h2.Close() defer mgr2.Close() diff --git a/pkg/rqlite/util.go b/pkg/rqlite/util.go index 693be82..e24d54f 100644 --- a/pkg/rqlite/util.go +++ b/pkg/rqlite/util.go @@ -3,15 +3,15 @@ package rqlite import ( "os" "path/filepath" - "strings" "time" + + "github.com/DeBrosOfficial/network/pkg/config" ) func (r *RQLiteManager) rqliteDataDirPath() (string, error) { - dataDir := os.ExpandEnv(r.dataDir) - if strings.HasPrefix(dataDir, "~") { - home, _ := os.UserHomeDir() - dataDir = filepath.Join(home, dataDir[1:]) + dataDir, err := config.ExpandPath(r.dataDir) + if err != nil { + return "", err } return filepath.Join(dataDir, "rqlite"), nil } diff --git a/pkg/systemd/manager.go b/pkg/systemd/manager.go index 78a0f5f..65605ac 100644 --- a/pkg/systemd/manager.go +++ b/pkg/systemd/manager.go @@ -59,7 +59,7 @@ func (m *Manager) StartService(namespace string, serviceType ServiceType) error zap.Error(err), zap.String("output", string(output)), zap.String("cmd", cmd.String())) - return fmt.Errorf("failed to start %s: %w\nOutput: %s", svcName, err, string(output)) + return fmt.Errorf("failed to start %s: %w; output: %s", svcName, err, string(output)) } m.logger.Info("Service started successfully", @@ -82,7 +82,7 @@ func (m *Manager) StopService(namespace string, serviceType ServiceType) error { m.logger.Debug("Service already stopped or not loaded", zap.String("service", svcName)) return nil } - return fmt.Errorf("failed to stop %s: %w\nOutput: %s", svcName, err, string(output)) + return fmt.Errorf("failed to stop %s: %w; output: %s", svcName, err, string(output)) } m.logger.Info("Service stopped successfully", zap.String("service", svcName)) @@ -98,7 +98,7 @@ func (m *Manager) RestartService(namespace string, serviceType ServiceType) erro cmd := exec.Command("sudo", "-n", "systemctl", "restart", svcName) if output, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("failed to restart %s: %w\nOutput: %s", svcName, err, string(output)) + return fmt.Errorf("failed to restart %s: %w; output: %s", svcName, err, string(output)) } m.logger.Info("Service restarted successfully", zap.String("service", svcName)) @@ -114,7 +114,7 @@ func (m *Manager) EnableService(namespace string, serviceType ServiceType) error cmd := exec.Command("sudo", "-n", "systemctl", "enable", svcName) if output, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("failed to enable %s: %w\nOutput: %s", svcName, err, string(output)) + return fmt.Errorf("failed to enable %s: %w; output: %s", svcName, err, string(output)) } m.logger.Info("Service enabled successfully", zap.String("service", svcName)) @@ -135,7 +135,7 @@ func (m *Manager) DisableService(namespace string, serviceType ServiceType) erro m.logger.Debug("Service not loaded", zap.String("service", svcName)) return nil } - return fmt.Errorf("failed to disable %s: %w\nOutput: %s", svcName, err, string(output)) + return fmt.Errorf("failed to disable %s: %w; output: %s", svcName, err, string(output)) } m.logger.Info("Service disabled successfully", zap.String("service", svcName)) @@ -172,7 +172,7 @@ func (m *Manager) IsServiceActive(namespace string, serviceType ServiceType) (bo zap.String("service", svcName), zap.Error(err), zap.String("output", outputStr)) - return false, fmt.Errorf("failed to check service status: %w\nOutput: %s", err, outputStr) + return false, fmt.Errorf("failed to check service status: %w; output: %s", err, outputStr) } isActive := outputStr == "active" @@ -187,7 +187,7 @@ func (m *Manager) ReloadDaemon() error { m.logger.Info("Reloading systemd daemon") cmd := exec.Command("sudo", "-n", "systemctl", "daemon-reload") if output, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("failed to reload systemd daemon: %w\nOutput: %s", err, string(output)) + return fmt.Errorf("failed to reload systemd daemon: %w; output: %s", err, string(output)) } return nil } @@ -231,7 +231,7 @@ func (m *Manager) ListNamespaceServices() ([]string, error) { cmd := exec.Command("sudo", "-n", "systemctl", "list-units", "--all", "--no-legend", "debros-namespace-*@*.service") output, err := cmd.CombinedOutput() if err != nil { - return nil, fmt.Errorf("failed to list namespace services: %w\nOutput: %s", err, string(output)) + return nil, fmt.Errorf("failed to list namespace services: %w; output: %s", err, string(output)) } var services []string