diff --git a/pkg/cli/cmd/namespacecmd/rqlite.go b/pkg/cli/cmd/namespacecmd/rqlite.go new file mode 100644 index 0000000..3cd9944 --- /dev/null +++ b/pkg/cli/cmd/namespacecmd/rqlite.go @@ -0,0 +1,219 @@ +package namespacecmd + +import ( + "bufio" + "crypto/tls" + "fmt" + "io" + "net/http" + "os" + "strings" + + "github.com/DeBrosOfficial/network/pkg/auth" + "github.com/spf13/cobra" +) + +var rqliteCmd = &cobra.Command{ + Use: "rqlite", + Short: "Manage the namespace's internal RQLite database", + Long: "Export and import the namespace's internal RQLite database (stores deployments, DNS records, API keys, etc.).", +} + +var rqliteExportCmd = &cobra.Command{ + Use: "export", + Short: "Export the namespace's RQLite database to a local SQLite file", + Long: "Downloads a consistent SQLite snapshot of the namespace's internal RQLite database.", + RunE: rqliteExport, +} + +var rqliteImportCmd = &cobra.Command{ + Use: "import", + Short: "Import a SQLite dump into the namespace's RQLite (DESTRUCTIVE)", + Long: `Replaces the namespace's entire RQLite database with the contents of the provided SQLite file. + +WARNING: This is a destructive operation. All existing data in the namespace's RQLite +(deployments, DNS records, API keys, etc.) will be replaced with the imported file.`, + RunE: rqliteImport, +} + +func init() { + rqliteExportCmd.Flags().StringP("output", "o", "", "Output file path (default: rqlite-export.db)") + + rqliteImportCmd.Flags().StringP("input", "i", "", "Input SQLite file path") + _ = rqliteImportCmd.MarkFlagRequired("input") + + rqliteCmd.AddCommand(rqliteExportCmd) + rqliteCmd.AddCommand(rqliteImportCmd) + + Cmd.AddCommand(rqliteCmd) +} + +func rqliteExport(cmd *cobra.Command, args []string) error { + output, _ := cmd.Flags().GetString("output") + if output == "" { + output = "rqlite-export.db" + } + + apiURL := nsRQLiteAPIURL() + token, err := nsRQLiteAuthToken() + if err != nil { + return err + } + + url := apiURL + "/v1/rqlite/export" + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+token) + + client := &http.Client{ + Timeout: 0, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + } + + fmt.Printf("Exporting RQLite database to %s...\n", output) + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to connect to gateway: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("export failed (HTTP %d): %s", resp.StatusCode, string(body)) + } + + outFile, err := os.Create(output) + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + defer outFile.Close() + + written, err := io.Copy(outFile, resp.Body) + if err != nil { + os.Remove(output) + return fmt.Errorf("failed to write export file: %w", err) + } + + fmt.Printf("Export complete: %s (%d bytes)\n", output, written) + return nil +} + +func rqliteImport(cmd *cobra.Command, args []string) error { + input, _ := cmd.Flags().GetString("input") + + info, err := os.Stat(input) + if err != nil { + return fmt.Errorf("cannot access input file: %w", err) + } + if info.IsDir() { + return fmt.Errorf("input path is a directory, not a file") + } + + store, err := auth.LoadEnhancedCredentials() + if err != nil { + return fmt.Errorf("failed to load credentials: %w", err) + } + gatewayURL := auth.GetDefaultGatewayURL() + creds := store.GetDefaultCredential(gatewayURL) + if creds == nil || !creds.IsValid() { + return fmt.Errorf("not authenticated. Run 'orama auth login' first") + } + + namespace := creds.Namespace + if namespace == "" { + namespace = "default" + } + + fmt.Printf("WARNING: This will REPLACE the entire RQLite database for namespace '%s'.\n", namespace) + fmt.Printf("All existing data (deployments, DNS records, API keys, etc.) will be lost.\n") + fmt.Printf("Importing from: %s (%d bytes)\n\n", input, info.Size()) + fmt.Printf("Type the namespace name '%s' to confirm: ", namespace) + + scanner := bufio.NewScanner(os.Stdin) + scanner.Scan() + confirmation := strings.TrimSpace(scanner.Text()) + if confirmation != namespace { + return fmt.Errorf("aborted - namespace name did not match") + } + + apiURL := nsRQLiteAPIURL() + token, err := nsRQLiteAuthToken() + if err != nil { + return err + } + + file, err := os.Open(input) + if err != nil { + return fmt.Errorf("failed to open input file: %w", err) + } + defer file.Close() + + url := apiURL + "/v1/rqlite/import" + + req, err := http.NewRequest(http.MethodPost, url, file) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/octet-stream") + req.ContentLength = info.Size() + + client := &http.Client{ + Timeout: 0, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + } + + fmt.Printf("Importing database...\n") + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to connect to gateway: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("import failed (HTTP %d): %s", resp.StatusCode, string(body)) + } + + fmt.Printf("Import complete. The namespace '%s' RQLite database has been replaced.\n", namespace) + return nil +} + +func nsRQLiteAPIURL() string { + if url := os.Getenv("ORAMA_API_URL"); url != "" { + return url + } + return auth.GetDefaultGatewayURL() +} + +func nsRQLiteAuthToken() (string, error) { + if token := os.Getenv("ORAMA_TOKEN"); token != "" { + return token, nil + } + + store, err := auth.LoadEnhancedCredentials() + if err != nil { + return "", fmt.Errorf("failed to load credentials: %w", err) + } + + gatewayURL := auth.GetDefaultGatewayURL() + creds := store.GetDefaultCredential(gatewayURL) + if creds == nil { + return "", fmt.Errorf("no credentials found for %s. Run 'orama auth login' to authenticate", gatewayURL) + } + + if !creds.IsValid() { + return "", fmt.Errorf("credentials expired for %s. Run 'orama auth login' to re-authenticate", gatewayURL) + } + + return creds.APIKey, nil +} diff --git a/pkg/deployments/replica_manager.go b/pkg/deployments/replica_manager.go index a34121c..4db6123 100644 --- a/pkg/deployments/replica_manager.go +++ b/pkg/deployments/replica_manager.go @@ -84,7 +84,7 @@ func (rm *ReplicaManager) SelectReplicaNodes(ctx context.Context, primaryNodeID } // CreateReplica inserts a replica record for a deployment on a specific node. -func (rm *ReplicaManager) CreateReplica(ctx context.Context, deploymentID, nodeID string, port int, isPrimary bool) error { +func (rm *ReplicaManager) CreateReplica(ctx context.Context, deploymentID, nodeID string, port int, isPrimary bool, status ReplicaStatus) error { internalCtx := client.WithInternalAuth(ctx) query := ` @@ -98,7 +98,7 @@ func (rm *ReplicaManager) CreateReplica(ctx context.Context, deploymentID, nodeI ` now := time.Now() - _, err := rm.db.Exec(internalCtx, query, deploymentID, nodeID, port, ReplicaStatusActive, isPrimary, now, now) + _, err := rm.db.Exec(internalCtx, query, deploymentID, nodeID, port, status, isPrimary, now, now) if err != nil { return &DeploymentError{ Message: fmt.Sprintf("failed to create replica for deployment %s on node %s", deploymentID, nodeID), @@ -161,7 +161,7 @@ func (rm *ReplicaManager) GetActiveReplicaNodes(ctx context.Context, deploymentI } var rows []nodeRow - query := `SELECT node_id FROM deployment_replicas WHERE deployment_id = ? AND status = ?` + query := `SELECT node_id FROM deployment_replicas WHERE deployment_id = ? AND status = ? AND port > 0` err := rm.db.Query(internalCtx, &rows, query, deploymentID, ReplicaStatusActive) if err != nil { return nil, &DeploymentError{ @@ -259,7 +259,8 @@ func (rm *ReplicaManager) GetNodeIP(ctx context.Context, nodeID string) (string, } var rows []nodeRow - query := `SELECT COALESCE(internal_ip, ip_address) AS ip_address FROM dns_nodes WHERE id = ? LIMIT 1` + // Use public IP for DNS A records (internal/WG IPs are not reachable from the internet) + query := `SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1` err := rm.db.Query(internalCtx, &rows, query, nodeID) if err != nil { return "", err diff --git a/pkg/gateway/handlers/deployments/go_handler.go b/pkg/gateway/handlers/deployments/go_handler.go index 9b2a355..a7a29e1 100644 --- a/pkg/gateway/handlers/deployments/go_handler.go +++ b/pkg/gateway/handlers/deployments/go_handler.go @@ -119,14 +119,6 @@ func (h *GoHandler) HandleUpload(w http.ResponseWriter, r *http.Request) { return } - // Create DNS records (use background context since HTTP context will be cancelled) - go func() { - if err := h.service.CreateDNSRecords(context.Background(), deployment); err != nil { - h.logger.Error("Background DNS creation failed", - zap.String("deployment", deployment.Name), zap.Error(err)) - } - }() - // Build response urls := h.service.BuildDeploymentURLs(deployment) diff --git a/pkg/gateway/handlers/deployments/nextjs_handler.go b/pkg/gateway/handlers/deployments/nextjs_handler.go index 611bab6..8bee467 100644 --- a/pkg/gateway/handlers/deployments/nextjs_handler.go +++ b/pkg/gateway/handlers/deployments/nextjs_handler.go @@ -125,14 +125,6 @@ func (h *NextJSHandler) HandleUpload(w http.ResponseWriter, r *http.Request) { } } - // Create DNS records (use background context since HTTP context will be cancelled) - go func() { - if err := h.service.CreateDNSRecords(context.Background(), deployment); err != nil { - h.logger.Error("Background DNS creation failed", - zap.String("deployment", deployment.Name), zap.Error(err)) - } - }() - // Build response urls := h.service.BuildDeploymentURLs(deployment) diff --git a/pkg/gateway/handlers/deployments/nodejs_handler.go b/pkg/gateway/handlers/deployments/nodejs_handler.go index 284ee52..38d301b 100644 --- a/pkg/gateway/handlers/deployments/nodejs_handler.go +++ b/pkg/gateway/handlers/deployments/nodejs_handler.go @@ -111,14 +111,6 @@ func (h *NodeJSHandler) HandleUpload(w http.ResponseWriter, r *http.Request) { return } - // Create DNS records (use background context since HTTP context will be cancelled) - go func() { - if err := h.service.CreateDNSRecords(context.Background(), deployment); err != nil { - h.logger.Error("Background DNS creation failed", - zap.String("deployment", deployment.Name), zap.Error(err)) - } - }() - // Build response urls := h.service.BuildDeploymentURLs(deployment) diff --git a/pkg/gateway/handlers/deployments/replica_handler.go b/pkg/gateway/handlers/deployments/replica_handler.go index 93c04c7..8cda600 100644 --- a/pkg/gateway/handlers/deployments/replica_handler.go +++ b/pkg/gateway/handlers/deployments/replica_handler.go @@ -99,6 +99,16 @@ func (h *ReplicaHandler) HandleSetup(w http.ResponseWriter, r *http.Request) { return } + // Release port if setup fails after this point + setupOK := false + defer func() { + if !setupOK { + if deallocErr := h.service.portAllocator.DeallocatePort(ctx, req.DeploymentID); deallocErr != nil { + h.logger.Error("Failed to deallocate port after setup failure", zap.Error(deallocErr)) + } + } + }() + // Create the deployment directory deployPath := filepath.Join(h.baseDeployPath, req.Namespace, req.Name) if err := os.MkdirAll(deployPath, 0755); err != nil { @@ -152,6 +162,8 @@ func (h *ReplicaHandler) HandleSetup(w http.ResponseWriter, r *http.Request) { return } + setupOK = true + // Wait for health check if err := h.processManager.WaitForHealthy(ctx, deployment, 90*time.Second); err != nil { h.logger.Warn("Replica did not become healthy", zap.Error(err)) @@ -159,7 +171,7 @@ func (h *ReplicaHandler) HandleSetup(w http.ResponseWriter, r *http.Request) { // Update replica record to active with the port if h.service.replicaManager != nil { - h.service.replicaManager.CreateReplica(ctx, req.DeploymentID, h.service.nodePeerID, port, false) + h.service.replicaManager.CreateReplica(ctx, req.DeploymentID, h.service.nodePeerID, port, false, deployments.ReplicaStatusActive) } resp := map[string]interface{}{ @@ -384,6 +396,11 @@ func (h *ReplicaHandler) HandleTeardown(w http.ResponseWriter, r *http.Request) h.logger.Warn("Failed to remove replica files", zap.Error(err)) } + // Deallocate the port + if err := h.service.portAllocator.DeallocatePort(ctx, req.DeploymentID); err != nil { + h.logger.Warn("Failed to deallocate port during teardown", zap.Error(err)) + } + // Update replica status if h.service.replicaManager != nil { h.service.replicaManager.UpdateReplicaStatus(ctx, req.DeploymentID, h.service.nodePeerID, deployments.ReplicaStatusRemoving) diff --git a/pkg/gateway/handlers/deployments/service.go b/pkg/gateway/handlers/deployments/service.go index 06f7581..906c482 100644 --- a/pkg/gateway/handlers/deployments/service.go +++ b/pkg/gateway/handlers/deployments/service.go @@ -270,7 +270,7 @@ func (s *DeploymentService) createDeploymentReplicas(ctx context.Context, deploy primaryNodeID := deployment.HomeNodeID // Register the primary replica - if err := s.replicaManager.CreateReplica(ctx, deployment.ID, primaryNodeID, deployment.Port, true); err != nil { + if err := s.replicaManager.CreateReplica(ctx, deployment.ID, primaryNodeID, deployment.Port, true, deployments.ReplicaStatusActive); err != nil { s.logger.Error("Failed to create primary replica record", zap.String("deployment_id", deployment.ID), zap.Error(err), @@ -278,6 +278,23 @@ func (s *DeploymentService) createDeploymentReplicas(ctx context.Context, deploy return } + // Create DNS record for the home node (synchronous, before replicas) + dnsName := deployment.Subdomain + if dnsName == "" { + dnsName = deployment.Name + } + fqdn := fmt.Sprintf("%s.%s.", dnsName, s.BaseDomain()) + if nodeIP, err := s.getNodeIP(ctx, deployment.HomeNodeID); err != nil { + s.logger.Error("Failed to get home node IP for DNS", zap.String("node_id", deployment.HomeNodeID), zap.Error(err)) + } else if err := s.createDNSRecord(ctx, fqdn, "A", nodeIP, deployment.Namespace, deployment.ID); err != nil { + s.logger.Error("Failed to create DNS record for home node", zap.Error(err)) + } else { + s.logger.Info("Created DNS record for home node", + zap.String("fqdn", fqdn), + zap.String("ip", nodeIP), + ) + } + // Select a secondary node secondaryNodes, err := s.replicaManager.SelectReplicaNodes(ctx, primaryNodeID, deployments.DefaultReplicaCount-1) if err != nil { @@ -302,12 +319,17 @@ func (s *DeploymentService) createDeploymentReplicas(ctx context.Context, deploy if isStatic { // Static deployments: content is in IPFS, no process to start - if err := s.replicaManager.CreateReplica(ctx, deployment.ID, nodeID, 0, false); err != nil { + if err := s.replicaManager.CreateReplica(ctx, deployment.ID, nodeID, 0, false, deployments.ReplicaStatusActive); err != nil { s.logger.Error("Failed to create static replica", zap.String("deployment_id", deployment.ID), zap.String("node_id", nodeID), zap.Error(err), ) + } else { + // Create DNS record for static replica + if nodeIP, err := s.replicaManager.GetNodeIP(ctx, nodeID); err == nil { + s.createDNSRecord(ctx, fqdn, "A", nodeIP, deployment.Namespace, deployment.ID) + } } } else { // Dynamic deployments: fan out to the secondary node to set up the process @@ -328,7 +350,7 @@ func (s *DeploymentService) setupDynamicReplica(ctx context.Context, deployment } // Create the replica record in pending status - if err := s.replicaManager.CreateReplica(ctx, deployment.ID, nodeID, 0, false); err != nil { + if err := s.replicaManager.CreateReplica(ctx, deployment.ID, nodeID, 0, false, deployments.ReplicaStatusPending); err != nil { s.logger.Error("Failed to create pending replica record", zap.String("deployment_id", deployment.ID), zap.String("node_id", nodeID), @@ -368,13 +390,22 @@ func (s *DeploymentService) setupDynamicReplica(ctx context.Context, deployment } // Update replica with allocated port - if port, ok := resp["port"].(float64); ok && port > 0 { - s.replicaManager.CreateReplica(ctx, deployment.ID, nodeID, int(port), false) + port, ok := resp["port"].(float64) + if !ok || port <= 0 { + s.logger.Error("Replica setup returned invalid port", + zap.String("deployment_id", deployment.ID), + zap.String("node_id", nodeID), + zap.Any("port_value", resp["port"]), + ) + s.replicaManager.UpdateReplicaStatus(ctx, deployment.ID, nodeID, deployments.ReplicaStatusFailed) + return } + s.replicaManager.CreateReplica(ctx, deployment.ID, nodeID, int(port), false, deployments.ReplicaStatusActive) s.logger.Info("Dynamic replica set up on remote node", zap.String("deployment_id", deployment.ID), zap.String("node_id", nodeID), + zap.Int("port", int(port)), ) // Create DNS record for the replica node (after successful setup) @@ -653,8 +684,8 @@ func (s *DeploymentService) getNodeIP(ctx context.Context, nodeID string) (strin var rows []nodeRow - // Try full node ID first (prefer internal/WG IP for cross-node communication) - query := `SELECT COALESCE(internal_ip, ip_address) AS ip_address FROM dns_nodes WHERE id = ? LIMIT 1` + // Use public IP for DNS A records (internal/WG IPs are not reachable from the internet) + query := `SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1` err := s.db.Query(ctx, &rows, query, nodeID) if err != nil { return "", err diff --git a/pkg/gateway/handlers/deployments/static_handler.go b/pkg/gateway/handlers/deployments/static_handler.go index bbf9069..7a1b909 100644 --- a/pkg/gateway/handlers/deployments/static_handler.go +++ b/pkg/gateway/handlers/deployments/static_handler.go @@ -154,14 +154,6 @@ func (h *StaticDeploymentHandler) HandleUpload(w http.ResponseWriter, r *http.Re return } - // Create DNS records (use background context since HTTP context will be cancelled) - go func() { - if err := h.service.CreateDNSRecords(context.Background(), deployment); err != nil { - h.logger.Error("Background DNS creation failed", - zap.String("deployment", deployment.Name), zap.Error(err)) - } - }() - // Build URLs urls := h.service.BuildDeploymentURLs(deployment) diff --git a/pkg/gateway/handlers/serverless/invoke_handler.go b/pkg/gateway/handlers/serverless/invoke_handler.go index 809ad84..1bdb067 100644 --- a/pkg/gateway/handlers/serverless/invoke_handler.go +++ b/pkg/gateway/handlers/serverless/invoke_handler.go @@ -70,6 +70,13 @@ func (h *ServerlessHandlers) InvokeFunction(w http.ResponseWriter, r *http.Reque statusCode = http.StatusUnauthorized } + if resp == nil { + writeJSON(w, statusCode, map[string]interface{}{ + "error": err.Error(), + }) + return + } + writeJSON(w, statusCode, map[string]interface{}{ "request_id": resp.RequestID, "status": resp.Status, diff --git a/pkg/gateway/routes.go b/pkg/gateway/routes.go index 19530f6..63831f1 100644 --- a/pkg/gateway/routes.go +++ b/pkg/gateway/routes.go @@ -67,6 +67,10 @@ func (g *Gateway) Routes() http.Handler { mux.HandleFunc("/v1/auth/phantom/complete", g.authHandlers.PhantomCompleteHandler) } + // RQLite native backup/restore proxy (namespace auth via /v1/rqlite/ prefix) + mux.HandleFunc("/v1/rqlite/export", g.rqliteExportHandler) + mux.HandleFunc("/v1/rqlite/import", g.rqliteImportHandler) + // rqlite ORM HTTP gateway (mounts /v1/rqlite/* endpoints) if g.ormHTTP != nil { g.ormHTTP.BasePath = "/v1/rqlite" diff --git a/pkg/gateway/rqlite_backup_handler.go b/pkg/gateway/rqlite_backup_handler.go new file mode 100644 index 0000000..fb11bba --- /dev/null +++ b/pkg/gateway/rqlite_backup_handler.go @@ -0,0 +1,133 @@ +package gateway + +import ( + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/DeBrosOfficial/network/pkg/logging" + "go.uber.org/zap" +) + +// rqliteExportHandler handles GET /v1/rqlite/export +// Proxies to the namespace's RQLite /db/backup endpoint to download a raw SQLite snapshot. +// Protected by requiresNamespaceOwnership() via the /v1/rqlite/ prefix. +func (g *Gateway) rqliteExportHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + rqliteURL := g.rqliteBaseURL() + if rqliteURL == "" { + writeError(w, http.StatusServiceUnavailable, "RQLite not configured") + return + } + + backupURL := rqliteURL + "/db/backup" + + client := &http.Client{Timeout: 5 * time.Minute} + resp, err := client.Get(backupURL) + if err != nil { + g.logger.ComponentError(logging.ComponentGeneral, "rqlite export: failed to reach RQLite backup endpoint", + zap.String("url", backupURL), zap.Error(err)) + writeError(w, http.StatusBadGateway, fmt.Sprintf("failed to reach RQLite: %v", err)) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + writeError(w, resp.StatusCode, fmt.Sprintf("RQLite backup failed: %s", string(body))) + return + } + + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", "attachment; filename=rqlite-export.db") + if resp.ContentLength > 0 { + w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength)) + } + w.WriteHeader(http.StatusOK) + + written, err := io.Copy(w, resp.Body) + if err != nil { + g.logger.ComponentError(logging.ComponentGeneral, "rqlite export: error streaming backup", + zap.Int64("bytes_written", written), zap.Error(err)) + return + } + + g.logger.ComponentInfo(logging.ComponentGeneral, "rqlite export completed", zap.Int64("bytes", written)) +} + +// rqliteImportHandler handles POST /v1/rqlite/import +// Proxies the request body (raw SQLite binary) to the namespace's RQLite /db/load endpoint. +// This is a DESTRUCTIVE operation that replaces the entire database. +// Protected by requiresNamespaceOwnership() via the /v1/rqlite/ prefix. +func (g *Gateway) rqliteImportHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + rqliteURL := g.rqliteBaseURL() + if rqliteURL == "" { + writeError(w, http.StatusServiceUnavailable, "RQLite not configured") + return + } + + ct := r.Header.Get("Content-Type") + if !strings.HasPrefix(ct, "application/octet-stream") { + writeError(w, http.StatusBadRequest, "Content-Type must be application/octet-stream") + return + } + + loadURL := rqliteURL + "/db/load" + + proxyReq, err := http.NewRequestWithContext(r.Context(), http.MethodPost, loadURL, r.Body) + if err != nil { + writeError(w, http.StatusInternalServerError, fmt.Sprintf("failed to create proxy request: %v", err)) + return + } + proxyReq.Header.Set("Content-Type", "application/octet-stream") + if r.ContentLength > 0 { + proxyReq.ContentLength = r.ContentLength + } + + client := &http.Client{Timeout: 5 * time.Minute} + resp, err := client.Do(proxyReq) + if err != nil { + g.logger.ComponentError(logging.ComponentGeneral, "rqlite import: failed to reach RQLite load endpoint", + zap.String("url", loadURL), zap.Error(err)) + writeError(w, http.StatusBadGateway, fmt.Sprintf("failed to reach RQLite: %v", err)) + return + } + defer resp.Body.Close() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + + if resp.StatusCode != http.StatusOK { + writeError(w, resp.StatusCode, fmt.Sprintf("RQLite load failed: %s", string(body))) + return + } + + g.logger.ComponentInfo(logging.ComponentGeneral, "rqlite import completed successfully") + + writeJSON(w, http.StatusOK, map[string]any{ + "status": "ok", + "message": "database imported successfully", + }) +} + +// rqliteBaseURL returns the raw RQLite HTTP URL for proxying native API calls. +func (g *Gateway) rqliteBaseURL() string { + dsn := g.cfg.RQLiteDSN + if dsn == "" { + dsn = "http://localhost:5001" + } + if idx := strings.Index(dsn, "?"); idx != -1 { + dsn = dsn[:idx] + } + return strings.TrimRight(dsn, "/") +} diff --git a/pkg/gateway/rqlite_backup_handler_test.go b/pkg/gateway/rqlite_backup_handler_test.go new file mode 100644 index 0000000..d5cf12c --- /dev/null +++ b/pkg/gateway/rqlite_backup_handler_test.go @@ -0,0 +1,214 @@ +package gateway + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/DeBrosOfficial/network/pkg/logging" +) + +func newRQLiteTestLogger() *logging.ColoredLogger { + l, _ := logging.NewColoredLogger(logging.ComponentGeneral, false) + return l +} + +func TestRqliteBaseURL(t *testing.T) { + tests := []struct { + name string + dsn string + want string + }{ + {"empty defaults to localhost:5001", "", "http://localhost:5001"}, + {"simple URL", "http://10.0.0.1:10000", "http://10.0.0.1:10000"}, + {"strips query params", "http://10.0.0.1:10000?foo=bar", "http://10.0.0.1:10000"}, + {"strips trailing slash", "http://10.0.0.1:10000/", "http://10.0.0.1:10000"}, + {"strips both", "http://10.0.0.1:10000/?foo=bar", "http://10.0.0.1:10000"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gw := &Gateway{cfg: &Config{RQLiteDSN: tt.dsn}} + got := gw.rqliteBaseURL() + if got != tt.want { + t.Errorf("rqliteBaseURL() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestRqliteExportHandler_MethodNotAllowed(t *testing.T) { + gw := &Gateway{cfg: &Config{RQLiteDSN: "http://localhost:5001"}} + + for _, method := range []string{http.MethodPost, http.MethodPut, http.MethodDelete} { + req := httptest.NewRequest(method, "/v1/rqlite/export", nil) + rr := httptest.NewRecorder() + gw.rqliteExportHandler(rr, req) + + if rr.Code != http.StatusMethodNotAllowed { + t.Errorf("method %s: got status %d, want %d", method, rr.Code, http.StatusMethodNotAllowed) + } + } +} + +func TestRqliteExportHandler_Success(t *testing.T) { + backupData := "fake-sqlite-binary-data" + + mockRQLite := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/db/backup" { + t.Errorf("unexpected path: %s", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + return + } + if r.Method != http.MethodGet { + t.Errorf("unexpected method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + w.WriteHeader(http.StatusOK) + w.Write([]byte(backupData)) + })) + defer mockRQLite.Close() + + gw := &Gateway{ + cfg: &Config{RQLiteDSN: mockRQLite.URL}, + logger: newRQLiteTestLogger(), + } + + req := httptest.NewRequest(http.MethodGet, "/v1/rqlite/export", nil) + rr := httptest.NewRecorder() + gw.rqliteExportHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("got status %d, want %d, body: %s", rr.Code, http.StatusOK, rr.Body.String()) + } + + if ct := rr.Header().Get("Content-Type"); ct != "application/octet-stream" { + t.Errorf("Content-Type = %q, want application/octet-stream", ct) + } + + if cd := rr.Header().Get("Content-Disposition"); !strings.Contains(cd, "rqlite-export.db") { + t.Errorf("Content-Disposition = %q, want to contain 'rqlite-export.db'", cd) + } + + if body := rr.Body.String(); body != backupData { + t.Errorf("body = %q, want %q", body, backupData) + } +} + +func TestRqliteExportHandler_RQLiteError(t *testing.T) { + mockRQLite := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("rqlite internal error")) + })) + defer mockRQLite.Close() + + gw := &Gateway{ + cfg: &Config{RQLiteDSN: mockRQLite.URL}, + logger: newRQLiteTestLogger(), + } + + req := httptest.NewRequest(http.MethodGet, "/v1/rqlite/export", nil) + rr := httptest.NewRecorder() + gw.rqliteExportHandler(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("got status %d, want %d", rr.Code, http.StatusInternalServerError) + } +} + +func TestRqliteImportHandler_MethodNotAllowed(t *testing.T) { + gw := &Gateway{cfg: &Config{RQLiteDSN: "http://localhost:5001"}} + + for _, method := range []string{http.MethodGet, http.MethodPut, http.MethodDelete} { + req := httptest.NewRequest(method, "/v1/rqlite/import", nil) + rr := httptest.NewRecorder() + gw.rqliteImportHandler(rr, req) + + if rr.Code != http.StatusMethodNotAllowed { + t.Errorf("method %s: got status %d, want %d", method, rr.Code, http.StatusMethodNotAllowed) + } + } +} + +func TestRqliteImportHandler_WrongContentType(t *testing.T) { + gw := &Gateway{cfg: &Config{RQLiteDSN: "http://localhost:5001"}} + + req := httptest.NewRequest(http.MethodPost, "/v1/rqlite/import", strings.NewReader("data")) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + gw.rqliteImportHandler(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("got status %d, want %d", rr.Code, http.StatusBadRequest) + } +} + +func TestRqliteImportHandler_Success(t *testing.T) { + importData := "fake-sqlite-binary-data" + var receivedBody string + + mockRQLite := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/db/load" { + t.Errorf("unexpected path: %s", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + return + } + if r.Method != http.MethodPost { + t.Errorf("unexpected method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if ct := r.Header.Get("Content-Type"); ct != "application/octet-stream" { + t.Errorf("Content-Type = %q, want application/octet-stream", ct) + } + body, _ := io.ReadAll(r.Body) + receivedBody = string(body) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer mockRQLite.Close() + + gw := &Gateway{ + cfg: &Config{RQLiteDSN: mockRQLite.URL}, + logger: newRQLiteTestLogger(), + } + + req := httptest.NewRequest(http.MethodPost, "/v1/rqlite/import", strings.NewReader(importData)) + req.Header.Set("Content-Type", "application/octet-stream") + rr := httptest.NewRecorder() + gw.rqliteImportHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("got status %d, want %d, body: %s", rr.Code, http.StatusOK, rr.Body.String()) + } + + if receivedBody != importData { + t.Errorf("RQLite received body %q, want %q", receivedBody, importData) + } +} + +func TestRqliteImportHandler_RQLiteError(t *testing.T) { + mockRQLite := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("load failed")) + })) + defer mockRQLite.Close() + + gw := &Gateway{ + cfg: &Config{RQLiteDSN: mockRQLite.URL}, + logger: newRQLiteTestLogger(), + } + + req := httptest.NewRequest(http.MethodPost, "/v1/rqlite/import", strings.NewReader("data")) + req.Header.Set("Content-Type", "application/octet-stream") + rr := httptest.NewRecorder() + gw.rqliteImportHandler(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("got status %d, want %d", rr.Code, http.StatusInternalServerError) + } +}