Merge pull request #76 from DeBrosOfficial/0.80.0

0.80.0
This commit is contained in:
anonpenguin 2026-01-05 20:00:41 +02:00 committed by GitHub
commit d34404ec87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
84 changed files with 12900 additions and 7287 deletions

92
.cursor/rules/network.mdc Normal file
View File

@ -0,0 +1,92 @@
---
alwaysApply: true
---
# AI Instructions
You have access to the **network** MCP (Model Context Protocol) server for this project. This MCP provides deep, pre-analyzed context about the codebase that is far more accurate than default file searching.
## IMPORTANT: Always Use MCP First
**Before making any code changes or answering questions about this codebase, ALWAYS consult the MCP tools first.**
The MCP has pre-indexed the entire codebase with semantic understanding, embeddings, and structural analysis. While you can use your own file search capabilities, the MCP provides much better context because:
- It understands code semantics, not just text matching
- It has pre-analyzed the architecture, patterns, and relationships
- It can answer questions about intent and purpose, not just content
## Available MCP Tools
### Code Understanding
- `network_ask_question` - Ask natural language questions about the codebase. Use this for "how does X work?", "where is Y implemented?", "what does Z do?" questions. The MCP will search relevant code and provide informed answers.
- `network_search_code` - Semantic code search. Find code by meaning, not just text. Great for finding implementations, patterns, or related functionality.
- `network_get_architecture` - Get the full project architecture overview including tech stack, design patterns, domain entities, and API endpoints.
- `network_get_file_summary` - Get a detailed summary of what a specific file does, its purpose, exports, and responsibilities.
- `network_find_function` - Find a specific function or method definition by name across the codebase.
- `network_list_functions` - List all functions defined in a specific file.
### Skills (Learned Procedures)
Skills are reusable procedures that the agent has learned about this specific project (e.g., "how to deploy", "how to run tests", "how to add a new API endpoint").
- `network_list_skills` - List all learned skills for this project.
- `network_get_skill` - Get detailed information about a specific skill including its step-by-step procedure.
- `network_execute_skill` - Get the procedure for a learned skill so you can execute it step by step. Returns prerequisites, warnings, and commands to run.
- `network_learn_skill` - Teach the agent a new skill. The agent will explore, discover, and memorize how to perform this task.
- `network_get_learning_status` - Check the status of an ongoing skill learning session.
- `network_answer_question` - Answer a question that the learning agent asked during skill learning.
- `network_cancel_learning` - Cancel an active learning session.
- `network_forget_skill` - Delete a learned skill.
- `network_update_skill` - Update a learned skill with corrections or new information (e.g., 'Use .env.prod instead of .env', 'Add step to backup database first', 'The port should be 3000 not 8080').
#### Skill Learning Workflow (IMPORTANT)
When learning a skill, follow this **collaborative, goal-oriented workflow**. You (Cursor) are the executor, the MCP agent provides guidance:
**Goal-Oriented Learning**: The agent identifies specific GOALS (pieces of information to gather) and tracks progress by goal completion, not by iterations.
1. **Start Learning**: Call `learn_skill` with name and detailed description
2. **Monitor Progress**: Call `get_learning_status` to check progress
3. **Handle Status Responses**:
- `active` → Learning in progress, check again in a few seconds
- `waiting_input` → The agent has a question. Read it and call `answer_question` with your response
- `waiting_execution` → **IMPORTANT**: The agent needs you to run a command!
- Read the `pendingExecution.command` from the response
- **Execute the command yourself** using your terminal access
- Call `answer_question` with the command output
- `completed` → Skill learned successfully!
- `failed` → Check errors and try again
4. **Repeat** steps 2-3 until status is `completed`
**Key Insight**: The MCP agent runs on the server and cannot SSH to remote servers directly. When it needs remote access, it generates the SSH command for YOU to execute. You have terminal access - use it!
**User Override Commands**: If the agent gets stuck, you can include these keywords in your answer:
- `COMPLETE` or `SKIP` - Skip to synthesis phase and generate the skill from current data
- `PHASE:synthesizing` - Force transition to drafting phase
- `GOAL:goal_id=value` - Directly provide a goal's value (e.g., `GOAL:cluster_secret=abc123`)
- `I have provided X` - Tell the agent it already has certain information
**Example for `waiting_execution`**:
```
// Status response shows:
// pendingExecution: { command: "ssh root@192.168.1.1 'ls -la /home/user/.orama'" }
//
// You should:
// 1. Run the command in your terminal
// 2. Get the output
// 3. Call answer_question with the output
```
## Recommended Workflow
1. **For questions:** Use `network_ask_question` or `network_search_code` to understand the codebase.
---
# DeBros Network Gateway
This project is a high-performance, edge-focused API gateway and reverse proxy designed to bridge decentralized web services with standard HTTP clients. It serves as a comprehensive middleware layer that facilitates wallet-based authentication, distributed caching via Olric, IPFS storage interaction, and serverless execution of WebAssembly (Wasm) functions. Additionally, it provides specialized utility services such as push notifications and an anonymizing proxy, ensuring secure and monitored communication between users and decentralized infrastructure.
**Architecture:** API Gateway / Edge Middleware
## Tech Stack
- **backend:** Go

View File

@ -1,68 +0,0 @@
// Project-local debug tasks
//
// For more documentation on how to configure debug tasks,
// see: https://zed.dev/docs/debugger
[
{
"label": "Gateway Go (Delve)",
"adapter": "Delve",
"request": "launch",
"mode": "debug",
"program": "./cmd/gateway",
"env": {
"GATEWAY_ADDR": ":6001",
"GATEWAY_BOOTSTRAP_PEERS": "/ip4/localhost/tcp/4001/p2p/12D3KooWSHHwEY6cga3ng7tD1rzStAU58ogQXVMX3LZJ6Gqf6dee",
"GATEWAY_NAMESPACE": "default",
"GATEWAY_API_KEY": "ak_iGustrsFk9H8uXpwczCATe5U:default"
}
},
{
"label": "E2E Test Go (Delve)",
"adapter": "Delve",
"request": "launch",
"mode": "test",
"buildFlags": "-tags e2e",
"program": "./e2e",
"env": {
"GATEWAY_API_KEY": "ak_iGustrsFk9H8uXpwczCATe5U:default"
},
"args": ["-test.v"]
},
{
"adapter": "Delve",
"label": "Gateway Go 6001 Port (Delve)",
"request": "launch",
"mode": "debug",
"program": "./cmd/gateway",
"env": {
"GATEWAY_ADDR": ":6001",
"GATEWAY_BOOTSTRAP_PEERS": "/ip4/localhost/tcp/4001/p2p/12D3KooWSHHwEY6cga3ng7tD1rzStAU58ogQXVMX3LZJ6Gqf6dee",
"GATEWAY_NAMESPACE": "default",
"GATEWAY_API_KEY": "ak_iGustrsFk9H8uXpwczCATe5U:default"
}
},
{
"adapter": "Delve",
"label": "Network CLI - peers (Delve)",
"request": "launch",
"mode": "debug",
"program": "./cmd/cli",
"args": ["peers"]
},
{
"adapter": "Delve",
"label": "Network CLI - PubSub Subscribe (Delve)",
"request": "launch",
"mode": "debug",
"program": "./cmd/cli",
"args": ["pubsub", "subscribe", "monitoring"]
},
{
"adapter": "Delve",
"label": "Node Go (Delve)",
"request": "launch",
"mode": "debug",
"program": "./cmd/node",
"args": ["--config", "configs/node.yaml"]
}
]

View File

@ -13,6 +13,59 @@ The format is based on [Keep a Changelog][keepachangelog] and adheres to [Semant
### Deprecated ### Deprecated
### Fixed ### Fixed
## [0.82.0] - 2026-01-03
### Added
- Added PubSub Presence feature, allowing clients to track members connected to a topic via WebSocket.
- Added a new tool, `rqlite-mcp`, which implements the Model Communication Protocol (MCP) for Rqlite, enabling AI models to interact with the database using tools.
### Changed
- Updated the development environment to include and manage the new `rqlite-mcp` service.
### Deprecated
### Removed
### Fixed
\n
## [0.81.0] - 2025-12-31
### Added
- Implemented a new, robust authentication service within the Gateway for handling wallet-based challenges, signature verification (ETH/SOL), JWT issuance, and API key management.
- Introduced automatic recovery logic for RQLite to detect and recover from split-brain scenarios and ensure cluster stability during restarts.
### Changed
- Refactored the production installation command (`dbn prod install`) by moving installer logic and utility functions into a dedicated `pkg/cli/utils` package for better modularity and maintainability.
- Reworked the core logic for starting and managing LibP2P, RQLite, and the HTTP Gateway within the Node, including improved peer reconnection and cluster configuration synchronization.
### Deprecated
### Removed
### Fixed
- Corrected IPFS Cluster configuration logic to properly handle port assignments and ensure correct IPFS API addresses are used, resolving potential connection issues between cluster components.
## [0.80.0] - 2025-12-29
### Added
- Implemented the core Serverless Functions Engine, allowing users to deploy and execute WASM-based functions (e.g., Go compiled with TinyGo).
- Added new database migration (004) to support serverless functions, including tables for functions, secrets, cron triggers, database triggers, pubsub triggers, timers, jobs, and invocation logs.
- Added new API endpoints for managing and invoking serverless functions (`/v1/functions`, `/v1/invoke`, `/v1/functions/{name}/invoke`, `/v1/functions/{name}/ws`).
- Introduced `WSPubSubClient` for E2E testing of WebSocket PubSub functionality.
- Added examples and a build script for creating WASM serverless functions (Echo, Hello, Counter).
### Changed
- Updated Go version requirement from 1.23.8 to 1.24.0 in `go.mod`.
- Refactored RQLite client to improve data type handling and conversion, especially for `sql.Null*` types and number parsing.
- Improved RQLite cluster discovery logic to safely handle new nodes joining an existing cluster without clearing existing Raft state unless necessary (log index 0).
### Deprecated
### Removed
### Fixed
- Corrected an issue in the `install` command where dry-run summaries were missing newlines.
## [0.72.1] - 2025-12-09 ## [0.72.1] - 2025-12-09
### Added ### Added

View File

@ -19,7 +19,7 @@ test-e2e:
.PHONY: build clean test run-node run-node2 run-node3 run-example deps tidy fmt vet lint clear-ports install-hooks kill .PHONY: build clean test run-node run-node2 run-node3 run-example deps tidy fmt vet lint clear-ports install-hooks kill
VERSION := 0.72.1 VERSION := 0.82.0
COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo unknown) COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo unknown)
DATE ?= $(shell date -u +%Y-%m-%dT%H:%M:%SZ) DATE ?= $(shell date -u +%Y-%m-%dT%H:%M:%SZ)
LDFLAGS := -X 'main.version=$(VERSION)' -X 'main.commit=$(COMMIT)' -X 'main.date=$(DATE)' LDFLAGS := -X 'main.version=$(VERSION)' -X 'main.commit=$(COMMIT)' -X 'main.date=$(DATE)'
@ -31,6 +31,7 @@ build: deps
go build -ldflags "$(LDFLAGS)" -o bin/identity ./cmd/identity go build -ldflags "$(LDFLAGS)" -o bin/identity ./cmd/identity
go build -ldflags "$(LDFLAGS)" -o bin/orama-node ./cmd/node go build -ldflags "$(LDFLAGS)" -o bin/orama-node ./cmd/node
go build -ldflags "$(LDFLAGS)" -o bin/orama cmd/cli/main.go go build -ldflags "$(LDFLAGS)" -o bin/orama cmd/cli/main.go
go build -ldflags "$(LDFLAGS)" -o bin/rqlite-mcp ./cmd/rqlite-mcp
# Inject gateway build metadata via pkg path variables # Inject gateway build metadata via pkg path variables
go build -ldflags "$(LDFLAGS) -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildVersion=$(VERSION)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildCommit=$(COMMIT)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildTime=$(DATE)'" -o bin/gateway ./cmd/gateway go build -ldflags "$(LDFLAGS) -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildVersion=$(VERSION)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildCommit=$(COMMIT)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildTime=$(DATE)'" -o bin/gateway ./cmd/gateway
@echo "Build complete! Run ./bin/orama version" @echo "Build complete! Run ./bin/orama version"
@ -71,14 +72,9 @@ run-gateway:
@echo "Note: Config must be in ~/.orama/data/gateway.yaml" @echo "Note: Config must be in ~/.orama/data/gateway.yaml"
go run ./cmd/orama-gateway go run ./cmd/orama-gateway
# Setup local domain names for development
setup-domains:
@echo "Setting up local domains..."
@sudo bash scripts/setup-local-domains.sh
# Development environment target # Development environment target
# Uses orama dev up to start full stack with dependency and port checking # Uses orama dev up to start full stack with dependency and port checking
dev: build setup-domains dev: build
@./bin/orama dev up @./bin/orama dev up
# Graceful shutdown of all dev services # Graceful shutdown of all dev services

View File

@ -26,27 +26,25 @@ make stop
After running `make dev`, test service health using these curl requests: After running `make dev`, test service health using these curl requests:
> **Note:** Local domains (node-1.local, etc.) require running `sudo make setup-domains` first. Alternatively, use `localhost` with port numbers.
### Node Unified Gateways ### Node Unified Gateways
Each node is accessible via a single unified gateway port: Each node is accessible via a single unified gateway port:
```bash ```bash
# Node-1 (port 6001) # Node-1 (port 6001)
curl http://node-1.local:6001/health curl http://localhost:6001/health
# Node-2 (port 6002) # Node-2 (port 6002)
curl http://node-2.local:6002/health curl http://localhost:6002/health
# Node-3 (port 6003) # Node-3 (port 6003)
curl http://node-3.local:6003/health curl http://localhost:6003/health
# Node-4 (port 6004) # Node-4 (port 6004)
curl http://node-4.local:6004/health curl http://localhost:6004/health
# Node-5 (port 6005) # Node-5 (port 6005)
curl http://node-5.local:6005/health curl http://localhost:6005/health
``` ```
## Network Architecture ## Network Architecture
@ -129,6 +127,54 @@ make build
./bin/orama auth logout ./bin/orama auth logout
``` ```
## Serverless Functions (WASM)
Orama supports high-performance serverless function execution using WebAssembly (WASM). Functions are isolated, secure, and can interact with network services like the distributed cache.
### 1. Build Functions
Functions must be compiled to WASM. We recommend using [TinyGo](https://tinygo.org/).
```bash
# Build example functions to examples/functions/bin/
./examples/functions/build.sh
```
### 2. Deployment
Deploy your compiled `.wasm` file to the network via the Gateway.
```bash
# Deploy a function
curl -X POST http://localhost:6001/v1/functions \
-H "Authorization: Bearer <your_api_key>" \
-F "name=hello-world" \
-F "namespace=default" \
-F "wasm=@./examples/functions/bin/hello.wasm"
```
### 3. Invocation
Trigger your function with a JSON payload. The function receives the payload via `stdin` and returns its response via `stdout`.
```bash
# Invoke via HTTP
curl -X POST http://localhost:6001/v1/functions/hello-world/invoke \
-H "Authorization: Bearer <your_api_key>" \
-H "Content-Type: application/json" \
-d '{"name": "Developer"}'
```
### 4. Management
```bash
# List all functions in a namespace
curl http://localhost:6001/v1/functions?namespace=default
# Delete a function
curl -X DELETE http://localhost:6001/v1/functions/hello-world?namespace=default
```
## Production Deployment ## Production Deployment
### Prerequisites ### Prerequisites
@ -262,6 +308,11 @@ sudo orama install
- `POST /v1/pubsub/publish` - Publish message - `POST /v1/pubsub/publish` - Publish message
- `GET /v1/pubsub/topics` - List topics - `GET /v1/pubsub/topics` - List topics
- `GET /v1/pubsub/ws?topic=<name>` - WebSocket subscribe - `GET /v1/pubsub/ws?topic=<name>` - WebSocket subscribe
- `POST /v1/functions` - Deploy function (multipart/form-data)
- `POST /v1/functions/{name}/invoke` - Invoke function
- `GET /v1/functions` - List functions
- `DELETE /v1/functions/{name}` - Delete function
- `GET /v1/functions/{name}/logs` - Get function logs
See `openapi/gateway.yaml` for complete API specification. See `openapi/gateway.yaml` for complete API specification.

320
cmd/rqlite-mcp/main.go Normal file
View File

@ -0,0 +1,320 @@
package main
import (
"bufio"
"encoding/json"
"fmt"
"log"
"os"
"strings"
"time"
"github.com/rqlite/gorqlite"
)
// MCP JSON-RPC types
type JSONRPCRequest struct {
JSONRPC string `json:"jsonrpc"`
ID any `json:"id,omitempty"`
Method string `json:"method"`
Params json.RawMessage `json:"params,omitempty"`
}
type JSONRPCResponse struct {
JSONRPC string `json:"jsonrpc"`
ID any `json:"id"`
Result any `json:"result,omitempty"`
Error *ResponseError `json:"error,omitempty"`
}
type ResponseError struct {
Code int `json:"code"`
Message string `json:"message"`
}
// Tool definition
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema any `json:"inputSchema"`
}
// Tool call types
type CallToolRequest struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}
type TextContent struct {
Type string `json:"type"`
Text string `json:"text"`
}
type CallToolResult struct {
Content []TextContent `json:"content"`
IsError bool `json:"isError,omitempty"`
}
type MCPServer struct {
conn *gorqlite.Connection
}
func NewMCPServer(rqliteURL string) (*MCPServer, error) {
conn, err := gorqlite.Open(rqliteURL)
if err != nil {
return nil, err
}
return &MCPServer{
conn: conn,
}, nil
}
func (s *MCPServer) handleRequest(req JSONRPCRequest) JSONRPCResponse {
var resp JSONRPCResponse
resp.JSONRPC = "2.0"
resp.ID = req.ID
// Debug logging disabled to prevent excessive disk writes
// log.Printf("Received method: %s", req.Method)
switch req.Method {
case "initialize":
resp.Result = map[string]any{
"protocolVersion": "2024-11-05",
"capabilities": map[string]any{
"tools": map[string]any{},
},
"serverInfo": map[string]any{
"name": "rqlite-mcp",
"version": "0.1.0",
},
}
case "notifications/initialized":
// This is a notification, no response needed
return JSONRPCResponse{}
case "tools/list":
// Debug logging disabled to prevent excessive disk writes
tools := []Tool{
{
Name: "list_tables",
Description: "List all tables in the Rqlite database",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{},
},
},
{
Name: "query",
Description: "Run a SELECT query on the Rqlite database",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"sql": map[string]any{
"type": "string",
"description": "The SQL SELECT query to run",
},
},
"required": []string{"sql"},
},
},
{
Name: "execute",
Description: "Run an INSERT, UPDATE, or DELETE statement on the Rqlite database",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"sql": map[string]any{
"type": "string",
"description": "The SQL statement (INSERT, UPDATE, DELETE) to run",
},
},
"required": []string{"sql"},
},
},
}
resp.Result = map[string]any{"tools": tools}
case "tools/call":
var callReq CallToolRequest
if err := json.Unmarshal(req.Params, &callReq); err != nil {
resp.Error = &ResponseError{Code: -32700, Message: "Parse error"}
return resp
}
resp.Result = s.handleToolCall(callReq)
default:
// Debug logging disabled to prevent excessive disk writes
resp.Error = &ResponseError{Code: -32601, Message: "Method not found"}
}
return resp
}
func (s *MCPServer) handleToolCall(req CallToolRequest) CallToolResult {
// Debug logging disabled to prevent excessive disk writes
// log.Printf("Tool call: %s", req.Name)
switch req.Name {
case "list_tables":
rows, err := s.conn.QueryOne("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
if err != nil {
return errorResult(fmt.Sprintf("Error listing tables: %v", err))
}
var tables []string
for rows.Next() {
slice, err := rows.Slice()
if err == nil && len(slice) > 0 {
tables = append(tables, fmt.Sprint(slice[0]))
}
}
if len(tables) == 0 {
return textResult("No tables found")
}
return textResult(strings.Join(tables, "\n"))
case "query":
var args struct {
SQL string `json:"sql"`
}
if err := json.Unmarshal(req.Arguments, &args); err != nil {
return errorResult(fmt.Sprintf("Invalid arguments: %v", err))
}
// Debug logging disabled to prevent excessive disk writes
rows, err := s.conn.QueryOne(args.SQL)
if err != nil {
return errorResult(fmt.Sprintf("Query error: %v", err))
}
var result strings.Builder
cols := rows.Columns()
result.WriteString(strings.Join(cols, " | ") + "\n")
result.WriteString(strings.Repeat("-", len(cols)*10) + "\n")
rowCount := 0
for rows.Next() {
vals, err := rows.Slice()
if err != nil {
continue
}
rowCount++
for i, v := range vals {
if i > 0 {
result.WriteString(" | ")
}
result.WriteString(fmt.Sprint(v))
}
result.WriteString("\n")
}
result.WriteString(fmt.Sprintf("\n(%d rows)", rowCount))
return textResult(result.String())
case "execute":
var args struct {
SQL string `json:"sql"`
}
if err := json.Unmarshal(req.Arguments, &args); err != nil {
return errorResult(fmt.Sprintf("Invalid arguments: %v", err))
}
// Debug logging disabled to prevent excessive disk writes
res, err := s.conn.WriteOne(args.SQL)
if err != nil {
return errorResult(fmt.Sprintf("Execution error: %v", err))
}
return textResult(fmt.Sprintf("Rows affected: %d", res.RowsAffected))
default:
return errorResult(fmt.Sprintf("Unknown tool: %s", req.Name))
}
}
func textResult(text string) CallToolResult {
return CallToolResult{
Content: []TextContent{
{
Type: "text",
Text: text,
},
},
}
}
func errorResult(text string) CallToolResult {
return CallToolResult{
Content: []TextContent{
{
Type: "text",
Text: text,
},
},
IsError: true,
}
}
func main() {
// Log to stderr so stdout is clean for JSON-RPC
log.SetOutput(os.Stderr)
rqliteURL := "http://localhost:5001"
if u := os.Getenv("RQLITE_URL"); u != "" {
rqliteURL = u
}
var server *MCPServer
var err error
// Retry connecting to rqlite
maxRetries := 30
for i := 0; i < maxRetries; i++ {
server, err = NewMCPServer(rqliteURL)
if err == nil {
break
}
if i%5 == 0 {
log.Printf("Waiting for Rqlite at %s... (%d/%d)", rqliteURL, i+1, maxRetries)
}
time.Sleep(1 * time.Second)
}
if err != nil {
log.Fatalf("Failed to connect to Rqlite after %d retries: %v", maxRetries, err)
}
log.Printf("MCP Rqlite server started (stdio transport)")
log.Printf("Connected to Rqlite at %s", rqliteURL)
// Read JSON-RPC requests from stdin, write responses to stdout
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
continue
}
var req JSONRPCRequest
if err := json.Unmarshal([]byte(line), &req); err != nil {
// Debug logging disabled to prevent excessive disk writes
continue
}
resp := server.handleRequest(req)
// Don't send response for notifications (no ID)
if req.ID == nil && strings.HasPrefix(req.Method, "notifications/") {
continue
}
respData, err := json.Marshal(resp)
if err != nil {
// Debug logging disabled to prevent excessive disk writes
continue
}
fmt.Println(string(respData))
}
if err := scanner.Err(); err != nil {
// Debug logging disabled to prevent excessive disk writes
}
}

View File

@ -6,13 +6,16 @@ import (
"bytes" "bytes"
"context" "context"
"database/sql" "database/sql"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math/rand" "math/rand"
"net/http" "net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -20,6 +23,7 @@ import (
"github.com/DeBrosOfficial/network/pkg/client" "github.com/DeBrosOfficial/network/pkg/client"
"github.com/DeBrosOfficial/network/pkg/config" "github.com/DeBrosOfficial/network/pkg/config"
"github.com/DeBrosOfficial/network/pkg/ipfs" "github.com/DeBrosOfficial/network/pkg/ipfs"
"github.com/gorilla/websocket"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"go.uber.org/zap" "go.uber.org/zap"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
@ -135,14 +139,26 @@ func GetRQLiteNodes() []string {
// queryAPIKeyFromRQLite queries the SQLite database directly for an API key // queryAPIKeyFromRQLite queries the SQLite database directly for an API key
func queryAPIKeyFromRQLite() (string, error) { func queryAPIKeyFromRQLite() (string, error) {
// Build database path from bootstrap/node config // 1. Check environment variable first
if envKey := os.Getenv("DEBROS_API_KEY"); envKey != "" {
return envKey, nil
}
// 2. Build database path from bootstrap/node config
homeDir, err := os.UserHomeDir() homeDir, err := os.UserHomeDir()
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get home directory: %w", err) return "", fmt.Errorf("failed to get home directory: %w", err)
} }
// Try all node data directories // Try all node data directories (both production and development paths)
dbPaths := []string{ dbPaths := []string{
// Development paths (~/.orama/node-x/...)
filepath.Join(homeDir, ".orama", "node-1", "rqlite", "db.sqlite"),
filepath.Join(homeDir, ".orama", "node-2", "rqlite", "db.sqlite"),
filepath.Join(homeDir, ".orama", "node-3", "rqlite", "db.sqlite"),
filepath.Join(homeDir, ".orama", "node-4", "rqlite", "db.sqlite"),
filepath.Join(homeDir, ".orama", "node-5", "rqlite", "db.sqlite"),
// Production paths (~/.orama/data/node-x/...)
filepath.Join(homeDir, ".orama", "data", "node-1", "rqlite", "db.sqlite"), filepath.Join(homeDir, ".orama", "data", "node-1", "rqlite", "db.sqlite"),
filepath.Join(homeDir, ".orama", "data", "node-2", "rqlite", "db.sqlite"), filepath.Join(homeDir, ".orama", "data", "node-2", "rqlite", "db.sqlite"),
filepath.Join(homeDir, ".orama", "data", "node-3", "rqlite", "db.sqlite"), filepath.Join(homeDir, ".orama", "data", "node-3", "rqlite", "db.sqlite"),
@ -644,3 +660,296 @@ func CleanupCacheEntry(t *testing.T, dmapName, key string) {
t.Logf("warning: delete cache entry returned status %d", status) t.Logf("warning: delete cache entry returned status %d", status)
} }
} }
// ============================================================================
// WebSocket PubSub Client for E2E Tests
// ============================================================================
// WSPubSubClient is a WebSocket-based PubSub client that connects to the gateway
type WSPubSubClient struct {
t *testing.T
conn *websocket.Conn
topic string
handlers []func(topic string, data []byte) error
msgChan chan []byte
doneChan chan struct{}
mu sync.RWMutex
writeMu sync.Mutex // Protects concurrent writes to WebSocket
closed bool
}
// WSPubSubMessage represents a message received from the gateway
type WSPubSubMessage struct {
Data string `json:"data"` // base64 encoded
Timestamp int64 `json:"timestamp"` // unix milliseconds
Topic string `json:"topic"`
}
// NewWSPubSubClient creates a new WebSocket PubSub client connected to a topic
func NewWSPubSubClient(t *testing.T, topic string) (*WSPubSubClient, error) {
t.Helper()
// Build WebSocket URL
gatewayURL := GetGatewayURL()
wsURL := strings.Replace(gatewayURL, "http://", "ws://", 1)
wsURL = strings.Replace(wsURL, "https://", "wss://", 1)
u, err := url.Parse(wsURL + "/v1/pubsub/ws")
if err != nil {
return nil, fmt.Errorf("failed to parse WebSocket URL: %w", err)
}
q := u.Query()
q.Set("topic", topic)
u.RawQuery = q.Encode()
// Set up headers with authentication
headers := http.Header{}
if apiKey := GetAPIKey(); apiKey != "" {
headers.Set("Authorization", "Bearer "+apiKey)
}
// Connect to WebSocket
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
conn, resp, err := dialer.Dial(u.String(), headers)
if err != nil {
if resp != nil {
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return nil, fmt.Errorf("websocket dial failed (status %d): %w - body: %s", resp.StatusCode, err, string(body))
}
return nil, fmt.Errorf("websocket dial failed: %w", err)
}
client := &WSPubSubClient{
t: t,
conn: conn,
topic: topic,
handlers: make([]func(topic string, data []byte) error, 0),
msgChan: make(chan []byte, 128),
doneChan: make(chan struct{}),
}
// Start reader goroutine
go client.readLoop()
return client, nil
}
// NewWSPubSubPresenceClient creates a new WebSocket PubSub client with presence parameters
func NewWSPubSubPresenceClient(t *testing.T, topic, memberID string, meta map[string]interface{}) (*WSPubSubClient, error) {
t.Helper()
// Build WebSocket URL
gatewayURL := GetGatewayURL()
wsURL := strings.Replace(gatewayURL, "http://", "ws://", 1)
wsURL = strings.Replace(wsURL, "https://", "wss://", 1)
u, err := url.Parse(wsURL + "/v1/pubsub/ws")
if err != nil {
return nil, fmt.Errorf("failed to parse WebSocket URL: %w", err)
}
q := u.Query()
q.Set("topic", topic)
q.Set("presence", "true")
q.Set("member_id", memberID)
if meta != nil {
metaJSON, _ := json.Marshal(meta)
q.Set("member_meta", string(metaJSON))
}
u.RawQuery = q.Encode()
// Set up headers with authentication
headers := http.Header{}
if apiKey := GetAPIKey(); apiKey != "" {
headers.Set("Authorization", "Bearer "+apiKey)
}
// Connect to WebSocket
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
conn, resp, err := dialer.Dial(u.String(), headers)
if err != nil {
if resp != nil {
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return nil, fmt.Errorf("websocket dial failed (status %d): %w - body: %s", resp.StatusCode, err, string(body))
}
return nil, fmt.Errorf("websocket dial failed: %w", err)
}
client := &WSPubSubClient{
t: t,
conn: conn,
topic: topic,
handlers: make([]func(topic string, data []byte) error, 0),
msgChan: make(chan []byte, 128),
doneChan: make(chan struct{}),
}
// Start reader goroutine
go client.readLoop()
return client, nil
}
// readLoop reads messages from the WebSocket and dispatches to handlers
func (c *WSPubSubClient) readLoop() {
defer close(c.doneChan)
for {
_, message, err := c.conn.ReadMessage()
if err != nil {
c.mu.RLock()
closed := c.closed
c.mu.RUnlock()
if !closed {
// Only log if not intentionally closed
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
c.t.Logf("websocket read error: %v", err)
}
}
return
}
// Parse the message envelope
var msg WSPubSubMessage
if err := json.Unmarshal(message, &msg); err != nil {
c.t.Logf("failed to unmarshal message: %v", err)
continue
}
// Decode base64 data
data, err := base64.StdEncoding.DecodeString(msg.Data)
if err != nil {
c.t.Logf("failed to decode base64 data: %v", err)
continue
}
// Send to message channel
select {
case c.msgChan <- data:
default:
c.t.Logf("message channel full, dropping message")
}
// Dispatch to handlers
c.mu.RLock()
handlers := make([]func(topic string, data []byte) error, len(c.handlers))
copy(handlers, c.handlers)
c.mu.RUnlock()
for _, handler := range handlers {
if err := handler(msg.Topic, data); err != nil {
c.t.Logf("handler error: %v", err)
}
}
}
}
// Subscribe adds a message handler
func (c *WSPubSubClient) Subscribe(handler func(topic string, data []byte) error) {
c.mu.Lock()
defer c.mu.Unlock()
c.handlers = append(c.handlers, handler)
}
// Publish sends a message to the topic
func (c *WSPubSubClient) Publish(data []byte) error {
c.mu.RLock()
closed := c.closed
c.mu.RUnlock()
if closed {
return fmt.Errorf("client is closed")
}
// Protect concurrent writes to WebSocket
c.writeMu.Lock()
defer c.writeMu.Unlock()
return c.conn.WriteMessage(websocket.TextMessage, data)
}
// ReceiveWithTimeout waits for a message with timeout
func (c *WSPubSubClient) ReceiveWithTimeout(timeout time.Duration) ([]byte, error) {
select {
case msg := <-c.msgChan:
return msg, nil
case <-time.After(timeout):
return nil, fmt.Errorf("timeout waiting for message")
case <-c.doneChan:
return nil, fmt.Errorf("connection closed")
}
}
// Close closes the WebSocket connection
func (c *WSPubSubClient) Close() error {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return nil
}
c.closed = true
c.mu.Unlock()
// Send close message
_ = c.conn.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
// Close connection
return c.conn.Close()
}
// Topic returns the topic this client is subscribed to
func (c *WSPubSubClient) Topic() string {
return c.topic
}
// WSPubSubClientPair represents a publisher and subscriber pair for testing
type WSPubSubClientPair struct {
Publisher *WSPubSubClient
Subscriber *WSPubSubClient
Topic string
}
// NewWSPubSubClientPair creates a publisher and subscriber pair for a topic
func NewWSPubSubClientPair(t *testing.T, topic string) (*WSPubSubClientPair, error) {
t.Helper()
// Create subscriber first
sub, err := NewWSPubSubClient(t, topic)
if err != nil {
return nil, fmt.Errorf("failed to create subscriber: %w", err)
}
// Small delay to ensure subscriber is registered
time.Sleep(100 * time.Millisecond)
// Create publisher
pub, err := NewWSPubSubClient(t, topic)
if err != nil {
sub.Close()
return nil, fmt.Errorf("failed to create publisher: %w", err)
}
return &WSPubSubClientPair{
Publisher: pub,
Subscriber: sub,
Topic: topic,
}, nil
}
// Close closes both publisher and subscriber
func (p *WSPubSubClientPair) Close() {
if p.Publisher != nil {
p.Publisher.Close()
}
if p.Subscriber != nil {
p.Subscriber.Close()
}
}

View File

@ -3,82 +3,46 @@
package e2e package e2e
import ( import (
"context"
"fmt" "fmt"
"sync" "sync"
"testing" "testing"
"time" "time"
) )
func newMessageCollector(ctx context.Context, buffer int) (chan []byte, func(string, []byte) error) { // TestPubSub_SubscribePublish tests basic pub/sub functionality via WebSocket
if buffer <= 0 {
buffer = 1
}
ch := make(chan []byte, buffer)
handler := func(_ string, data []byte) error {
copied := append([]byte(nil), data...)
select {
case ch <- copied:
case <-ctx.Done():
}
return nil
}
return ch, handler
}
func waitForMessage(ctx context.Context, ch <-chan []byte) ([]byte, error) {
select {
case msg := <-ch:
return msg, nil
case <-ctx.Done():
return nil, fmt.Errorf("context finished while waiting for pubsub message: %w", ctx.Err())
}
}
func TestPubSub_SubscribePublish(t *testing.T) { func TestPubSub_SubscribePublish(t *testing.T) {
SkipIfMissingGateway(t) SkipIfMissingGateway(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Create two clients
client1 := NewNetworkClient(t)
client2 := NewNetworkClient(t)
if err := client1.Connect(); err != nil {
t.Fatalf("client1 connect failed: %v", err)
}
defer client1.Disconnect()
if err := client2.Connect(); err != nil {
t.Fatalf("client2 connect failed: %v", err)
}
defer client2.Disconnect()
topic := GenerateTopic() topic := GenerateTopic()
message := "test-message-from-client1" message := "test-message-from-publisher"
// Subscribe on client2 // Create subscriber first
messageCh, handler := newMessageCollector(ctx, 1) subscriber, err := NewWSPubSubClient(t, topic)
if err := client2.PubSub().Subscribe(ctx, topic, handler); err != nil { if err != nil {
t.Fatalf("subscribe failed: %v", err) t.Fatalf("failed to create subscriber: %v", err)
} }
defer client2.PubSub().Unsubscribe(ctx, topic) defer subscriber.Close()
// Give subscription time to propagate and mesh to form // Give subscriber time to register
Delay(2000) Delay(200)
// Publish from client1 // Create publisher
if err := client1.PubSub().Publish(ctx, topic, []byte(message)); err != nil { publisher, err := NewWSPubSubClient(t, topic)
if err != nil {
t.Fatalf("failed to create publisher: %v", err)
}
defer publisher.Close()
// Give connections time to stabilize
Delay(200)
// Publish message
if err := publisher.Publish([]byte(message)); err != nil {
t.Fatalf("publish failed: %v", err) t.Fatalf("publish failed: %v", err)
} }
// Receive message on client2 // Receive message on subscriber
recvCtx, recvCancel := context.WithTimeout(ctx, 10*time.Second) msg, err := subscriber.ReceiveWithTimeout(10 * time.Second)
defer recvCancel()
msg, err := waitForMessage(recvCtx, messageCh)
if err != nil { if err != nil {
t.Fatalf("receive failed: %v", err) t.Fatalf("receive failed: %v", err)
} }
@ -88,154 +52,126 @@ func TestPubSub_SubscribePublish(t *testing.T) {
} }
} }
// TestPubSub_MultipleSubscribers tests that multiple subscribers receive the same message
func TestPubSub_MultipleSubscribers(t *testing.T) { func TestPubSub_MultipleSubscribers(t *testing.T) {
SkipIfMissingGateway(t) SkipIfMissingGateway(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Create three clients
clientPub := NewNetworkClient(t)
clientSub1 := NewNetworkClient(t)
clientSub2 := NewNetworkClient(t)
if err := clientPub.Connect(); err != nil {
t.Fatalf("publisher connect failed: %v", err)
}
defer clientPub.Disconnect()
if err := clientSub1.Connect(); err != nil {
t.Fatalf("subscriber1 connect failed: %v", err)
}
defer clientSub1.Disconnect()
if err := clientSub2.Connect(); err != nil {
t.Fatalf("subscriber2 connect failed: %v", err)
}
defer clientSub2.Disconnect()
topic := GenerateTopic() topic := GenerateTopic()
message1 := "message-for-sub1" message1 := "message-1"
message2 := "message-for-sub2" message2 := "message-2"
// Subscribe on both clients // Create two subscribers
sub1Ch, sub1Handler := newMessageCollector(ctx, 4) sub1, err := NewWSPubSubClient(t, topic)
if err := clientSub1.PubSub().Subscribe(ctx, topic, sub1Handler); err != nil { if err != nil {
t.Fatalf("subscribe1 failed: %v", err) t.Fatalf("failed to create subscriber1: %v", err)
} }
defer clientSub1.PubSub().Unsubscribe(ctx, topic) defer sub1.Close()
sub2Ch, sub2Handler := newMessageCollector(ctx, 4) sub2, err := NewWSPubSubClient(t, topic)
if err := clientSub2.PubSub().Subscribe(ctx, topic, sub2Handler); err != nil { if err != nil {
t.Fatalf("subscribe2 failed: %v", err) t.Fatalf("failed to create subscriber2: %v", err)
} }
defer clientSub2.PubSub().Unsubscribe(ctx, topic) defer sub2.Close()
// Give subscriptions time to propagate // Give subscribers time to register
Delay(500) Delay(200)
// Create publisher
publisher, err := NewWSPubSubClient(t, topic)
if err != nil {
t.Fatalf("failed to create publisher: %v", err)
}
defer publisher.Close()
// Give connections time to stabilize
Delay(200)
// Publish first message // Publish first message
if err := clientPub.PubSub().Publish(ctx, topic, []byte(message1)); err != nil { if err := publisher.Publish([]byte(message1)); err != nil {
t.Fatalf("publish1 failed: %v", err) t.Fatalf("publish1 failed: %v", err)
} }
// Both subscribers should receive first message // Both subscribers should receive first message
recvCtx, recvCancel := context.WithTimeout(ctx, 10*time.Second) msg1a, err := sub1.ReceiveWithTimeout(10 * time.Second)
defer recvCancel()
msg1a, err := waitForMessage(recvCtx, sub1Ch)
if err != nil { if err != nil {
t.Fatalf("sub1 receive1 failed: %v", err) t.Fatalf("sub1 receive1 failed: %v", err)
} }
if string(msg1a) != message1 { if string(msg1a) != message1 {
t.Fatalf("sub1: expected %q, got %q", message1, string(msg1a)) t.Fatalf("sub1: expected %q, got %q", message1, string(msg1a))
} }
msg1b, err := waitForMessage(recvCtx, sub2Ch) msg1b, err := sub2.ReceiveWithTimeout(10 * time.Second)
if err != nil { if err != nil {
t.Fatalf("sub2 receive1 failed: %v", err) t.Fatalf("sub2 receive1 failed: %v", err)
} }
if string(msg1b) != message1 { if string(msg1b) != message1 {
t.Fatalf("sub2: expected %q, got %q", message1, string(msg1b)) t.Fatalf("sub2: expected %q, got %q", message1, string(msg1b))
} }
// Publish second message // Publish second message
if err := clientPub.PubSub().Publish(ctx, topic, []byte(message2)); err != nil { if err := publisher.Publish([]byte(message2)); err != nil {
t.Fatalf("publish2 failed: %v", err) t.Fatalf("publish2 failed: %v", err)
} }
// Both subscribers should receive second message // Both subscribers should receive second message
recvCtx2, recvCancel2 := context.WithTimeout(ctx, 10*time.Second) msg2a, err := sub1.ReceiveWithTimeout(10 * time.Second)
defer recvCancel2()
msg2a, err := waitForMessage(recvCtx2, sub1Ch)
if err != nil { if err != nil {
t.Fatalf("sub1 receive2 failed: %v", err) t.Fatalf("sub1 receive2 failed: %v", err)
} }
if string(msg2a) != message2 { if string(msg2a) != message2 {
t.Fatalf("sub1: expected %q, got %q", message2, string(msg2a)) t.Fatalf("sub1: expected %q, got %q", message2, string(msg2a))
} }
msg2b, err := waitForMessage(recvCtx2, sub2Ch) msg2b, err := sub2.ReceiveWithTimeout(10 * time.Second)
if err != nil { if err != nil {
t.Fatalf("sub2 receive2 failed: %v", err) t.Fatalf("sub2 receive2 failed: %v", err)
} }
if string(msg2b) != message2 { if string(msg2b) != message2 {
t.Fatalf("sub2: expected %q, got %q", message2, string(msg2b)) t.Fatalf("sub2: expected %q, got %q", message2, string(msg2b))
} }
} }
// TestPubSub_Deduplication tests that multiple identical messages are all received
func TestPubSub_Deduplication(t *testing.T) { func TestPubSub_Deduplication(t *testing.T) {
SkipIfMissingGateway(t) SkipIfMissingGateway(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Create two clients
clientPub := NewNetworkClient(t)
clientSub := NewNetworkClient(t)
if err := clientPub.Connect(); err != nil {
t.Fatalf("publisher connect failed: %v", err)
}
defer clientPub.Disconnect()
if err := clientSub.Connect(); err != nil {
t.Fatalf("subscriber connect failed: %v", err)
}
defer clientSub.Disconnect()
topic := GenerateTopic() topic := GenerateTopic()
message := "duplicate-test-message" message := "duplicate-test-message"
// Subscribe on client // Create subscriber
messageCh, handler := newMessageCollector(ctx, 3) subscriber, err := NewWSPubSubClient(t, topic)
if err := clientSub.PubSub().Subscribe(ctx, topic, handler); err != nil { if err != nil {
t.Fatalf("subscribe failed: %v", err) t.Fatalf("failed to create subscriber: %v", err)
} }
defer clientSub.PubSub().Unsubscribe(ctx, topic) defer subscriber.Close()
// Give subscription time to propagate and mesh to form // Give subscriber time to register
Delay(2000) Delay(200)
// Create publisher
publisher, err := NewWSPubSubClient(t, topic)
if err != nil {
t.Fatalf("failed to create publisher: %v", err)
}
defer publisher.Close()
// Give connections time to stabilize
Delay(200)
// Publish the same message multiple times // Publish the same message multiple times
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
if err := clientPub.PubSub().Publish(ctx, topic, []byte(message)); err != nil { if err := publisher.Publish([]byte(message)); err != nil {
t.Fatalf("publish %d failed: %v", i, err) t.Fatalf("publish %d failed: %v", i, err)
} }
// Small delay between publishes
Delay(50)
} }
// Receive messages - should get all (no dedup filter on subscribe) // Receive messages - should get all (no dedup filter)
recvCtx, recvCancel := context.WithTimeout(ctx, 5*time.Second)
defer recvCancel()
receivedCount := 0 receivedCount := 0
for receivedCount < 3 { for receivedCount < 3 {
if _, err := waitForMessage(recvCtx, messageCh); err != nil { _, err := subscriber.ReceiveWithTimeout(5 * time.Second)
if err != nil {
break break
} }
receivedCount++ receivedCount++
@ -244,40 +180,35 @@ func TestPubSub_Deduplication(t *testing.T) {
if receivedCount < 1 { if receivedCount < 1 {
t.Fatalf("expected to receive at least 1 message, got %d", receivedCount) t.Fatalf("expected to receive at least 1 message, got %d", receivedCount)
} }
t.Logf("received %d messages", receivedCount)
} }
// TestPubSub_ConcurrentPublish tests concurrent message publishing
func TestPubSub_ConcurrentPublish(t *testing.T) { func TestPubSub_ConcurrentPublish(t *testing.T) {
SkipIfMissingGateway(t) SkipIfMissingGateway(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Create clients
clientPub := NewNetworkClient(t)
clientSub := NewNetworkClient(t)
if err := clientPub.Connect(); err != nil {
t.Fatalf("publisher connect failed: %v", err)
}
defer clientPub.Disconnect()
if err := clientSub.Connect(); err != nil {
t.Fatalf("subscriber connect failed: %v", err)
}
defer clientSub.Disconnect()
topic := GenerateTopic() topic := GenerateTopic()
numMessages := 10 numMessages := 10
// Subscribe // Create subscriber
messageCh, handler := newMessageCollector(ctx, numMessages) subscriber, err := NewWSPubSubClient(t, topic)
if err := clientSub.PubSub().Subscribe(ctx, topic, handler); err != nil { if err != nil {
t.Fatalf("subscribe failed: %v", err) t.Fatalf("failed to create subscriber: %v", err)
} }
defer clientSub.PubSub().Unsubscribe(ctx, topic) defer subscriber.Close()
// Give subscription time to propagate and mesh to form // Give subscriber time to register
Delay(2000) Delay(200)
// Create publisher
publisher, err := NewWSPubSubClient(t, topic)
if err != nil {
t.Fatalf("failed to create publisher: %v", err)
}
defer publisher.Close()
// Give connections time to stabilize
Delay(200)
// Publish multiple messages concurrently // Publish multiple messages concurrently
var wg sync.WaitGroup var wg sync.WaitGroup
@ -286,7 +217,7 @@ func TestPubSub_ConcurrentPublish(t *testing.T) {
go func(idx int) { go func(idx int) {
defer wg.Done() defer wg.Done()
msg := fmt.Sprintf("concurrent-msg-%d", idx) msg := fmt.Sprintf("concurrent-msg-%d", idx)
if err := clientPub.PubSub().Publish(ctx, topic, []byte(msg)); err != nil { if err := publisher.Publish([]byte(msg)); err != nil {
t.Logf("publish %d failed: %v", idx, err) t.Logf("publish %d failed: %v", idx, err)
} }
}(i) }(i)
@ -294,12 +225,10 @@ func TestPubSub_ConcurrentPublish(t *testing.T) {
wg.Wait() wg.Wait()
// Receive messages // Receive messages
recvCtx, recvCancel := context.WithTimeout(ctx, 10*time.Second)
defer recvCancel()
receivedCount := 0 receivedCount := 0
for receivedCount < numMessages { for receivedCount < numMessages {
if _, err := waitForMessage(recvCtx, messageCh); err != nil { _, err := subscriber.ReceiveWithTimeout(10 * time.Second)
if err != nil {
break break
} }
receivedCount++ receivedCount++
@ -310,107 +239,110 @@ func TestPubSub_ConcurrentPublish(t *testing.T) {
} }
} }
// TestPubSub_TopicIsolation tests that messages are isolated to their topics
func TestPubSub_TopicIsolation(t *testing.T) { func TestPubSub_TopicIsolation(t *testing.T) {
SkipIfMissingGateway(t) SkipIfMissingGateway(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Create clients
clientPub := NewNetworkClient(t)
clientSub := NewNetworkClient(t)
if err := clientPub.Connect(); err != nil {
t.Fatalf("publisher connect failed: %v", err)
}
defer clientPub.Disconnect()
if err := clientSub.Connect(); err != nil {
t.Fatalf("subscriber connect failed: %v", err)
}
defer clientSub.Disconnect()
topic1 := GenerateTopic() topic1 := GenerateTopic()
topic2 := GenerateTopic() topic2 := GenerateTopic()
msg1 := "message-on-topic1"
// Subscribe to topic1
messageCh, handler := newMessageCollector(ctx, 2)
if err := clientSub.PubSub().Subscribe(ctx, topic1, handler); err != nil {
t.Fatalf("subscribe1 failed: %v", err)
}
defer clientSub.PubSub().Unsubscribe(ctx, topic1)
// Give subscription time to propagate and mesh to form
Delay(2000)
// Publish to topic2
msg2 := "message-on-topic2" msg2 := "message-on-topic2"
if err := clientPub.PubSub().Publish(ctx, topic2, []byte(msg2)); err != nil {
// Create subscriber for topic1
sub1, err := NewWSPubSubClient(t, topic1)
if err != nil {
t.Fatalf("failed to create subscriber1: %v", err)
}
defer sub1.Close()
// Create subscriber for topic2
sub2, err := NewWSPubSubClient(t, topic2)
if err != nil {
t.Fatalf("failed to create subscriber2: %v", err)
}
defer sub2.Close()
// Give subscribers time to register
Delay(200)
// Create publishers
pub1, err := NewWSPubSubClient(t, topic1)
if err != nil {
t.Fatalf("failed to create publisher1: %v", err)
}
defer pub1.Close()
pub2, err := NewWSPubSubClient(t, topic2)
if err != nil {
t.Fatalf("failed to create publisher2: %v", err)
}
defer pub2.Close()
// Give connections time to stabilize
Delay(200)
// Publish to topic2 first
if err := pub2.Publish([]byte(msg2)); err != nil {
t.Fatalf("publish2 failed: %v", err) t.Fatalf("publish2 failed: %v", err)
} }
// Publish to topic1 // Publish to topic1
msg1 := "message-on-topic1" if err := pub1.Publish([]byte(msg1)); err != nil {
if err := clientPub.PubSub().Publish(ctx, topic1, []byte(msg1)); err != nil {
t.Fatalf("publish1 failed: %v", err) t.Fatalf("publish1 failed: %v", err)
} }
// Receive on sub1 - should get msg1 only // Sub1 should receive msg1 only
recvCtx, recvCancel := context.WithTimeout(ctx, 10*time.Second) received1, err := sub1.ReceiveWithTimeout(10 * time.Second)
defer recvCancel()
msg, err := waitForMessage(recvCtx, messageCh)
if err != nil { if err != nil {
t.Fatalf("receive failed: %v", err) t.Fatalf("sub1 receive failed: %v", err)
}
if string(received1) != msg1 {
t.Fatalf("sub1: expected %q, got %q", msg1, string(received1))
} }
if string(msg) != msg1 { // Sub2 should receive msg2 only
t.Fatalf("expected %q, got %q", msg1, string(msg)) received2, err := sub2.ReceiveWithTimeout(10 * time.Second)
if err != nil {
t.Fatalf("sub2 receive failed: %v", err)
}
if string(received2) != msg2 {
t.Fatalf("sub2: expected %q, got %q", msg2, string(received2))
} }
} }
// TestPubSub_EmptyMessage tests sending and receiving empty messages
func TestPubSub_EmptyMessage(t *testing.T) { func TestPubSub_EmptyMessage(t *testing.T) {
SkipIfMissingGateway(t) SkipIfMissingGateway(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Create clients
clientPub := NewNetworkClient(t)
clientSub := NewNetworkClient(t)
if err := clientPub.Connect(); err != nil {
t.Fatalf("publisher connect failed: %v", err)
}
defer clientPub.Disconnect()
if err := clientSub.Connect(); err != nil {
t.Fatalf("subscriber connect failed: %v", err)
}
defer clientSub.Disconnect()
topic := GenerateTopic() topic := GenerateTopic()
// Subscribe // Create subscriber
messageCh, handler := newMessageCollector(ctx, 1) subscriber, err := NewWSPubSubClient(t, topic)
if err := clientSub.PubSub().Subscribe(ctx, topic, handler); err != nil { if err != nil {
t.Fatalf("subscribe failed: %v", err) t.Fatalf("failed to create subscriber: %v", err)
} }
defer clientSub.PubSub().Unsubscribe(ctx, topic) defer subscriber.Close()
// Give subscription time to propagate and mesh to form // Give subscriber time to register
Delay(2000) Delay(200)
// Create publisher
publisher, err := NewWSPubSubClient(t, topic)
if err != nil {
t.Fatalf("failed to create publisher: %v", err)
}
defer publisher.Close()
// Give connections time to stabilize
Delay(200)
// Publish empty message // Publish empty message
if err := clientPub.PubSub().Publish(ctx, topic, []byte("")); err != nil { if err := publisher.Publish([]byte("")); err != nil {
t.Fatalf("publish empty failed: %v", err) t.Fatalf("publish empty failed: %v", err)
} }
// Receive on sub - should get empty message // Receive on subscriber - should get empty message
recvCtx, recvCancel := context.WithTimeout(ctx, 10*time.Second) msg, err := subscriber.ReceiveWithTimeout(10 * time.Second)
defer recvCancel()
msg, err := waitForMessage(recvCtx, messageCh)
if err != nil { if err != nil {
t.Fatalf("receive failed: %v", err) t.Fatalf("receive failed: %v", err)
} }
@ -419,3 +351,111 @@ func TestPubSub_EmptyMessage(t *testing.T) {
t.Fatalf("expected empty message, got %q", string(msg)) t.Fatalf("expected empty message, got %q", string(msg))
} }
} }
// TestPubSub_LargeMessage tests sending and receiving large messages
func TestPubSub_LargeMessage(t *testing.T) {
SkipIfMissingGateway(t)
topic := GenerateTopic()
// Create a large message (100KB)
largeMessage := make([]byte, 100*1024)
for i := range largeMessage {
largeMessage[i] = byte(i % 256)
}
// Create subscriber
subscriber, err := NewWSPubSubClient(t, topic)
if err != nil {
t.Fatalf("failed to create subscriber: %v", err)
}
defer subscriber.Close()
// Give subscriber time to register
Delay(200)
// Create publisher
publisher, err := NewWSPubSubClient(t, topic)
if err != nil {
t.Fatalf("failed to create publisher: %v", err)
}
defer publisher.Close()
// Give connections time to stabilize
Delay(200)
// Publish large message
if err := publisher.Publish(largeMessage); err != nil {
t.Fatalf("publish large message failed: %v", err)
}
// Receive on subscriber
msg, err := subscriber.ReceiveWithTimeout(30 * time.Second)
if err != nil {
t.Fatalf("receive failed: %v", err)
}
if len(msg) != len(largeMessage) {
t.Fatalf("expected message of length %d, got %d", len(largeMessage), len(msg))
}
// Verify content
for i := range msg {
if msg[i] != largeMessage[i] {
t.Fatalf("message content mismatch at byte %d", i)
}
}
}
// TestPubSub_RapidPublish tests rapid message publishing
func TestPubSub_RapidPublish(t *testing.T) {
SkipIfMissingGateway(t)
topic := GenerateTopic()
numMessages := 50
// Create subscriber
subscriber, err := NewWSPubSubClient(t, topic)
if err != nil {
t.Fatalf("failed to create subscriber: %v", err)
}
defer subscriber.Close()
// Give subscriber time to register
Delay(200)
// Create publisher
publisher, err := NewWSPubSubClient(t, topic)
if err != nil {
t.Fatalf("failed to create publisher: %v", err)
}
defer publisher.Close()
// Give connections time to stabilize
Delay(200)
// Publish messages rapidly
for i := 0; i < numMessages; i++ {
msg := fmt.Sprintf("rapid-msg-%d", i)
if err := publisher.Publish([]byte(msg)); err != nil {
t.Fatalf("publish %d failed: %v", i, err)
}
}
// Receive messages
receivedCount := 0
for receivedCount < numMessages {
_, err := subscriber.ReceiveWithTimeout(10 * time.Second)
if err != nil {
break
}
receivedCount++
}
// Allow some message loss due to buffering
minExpected := numMessages * 80 / 100 // 80% minimum
if receivedCount < minExpected {
t.Fatalf("expected at least %d messages, got %d", minExpected, receivedCount)
}
t.Logf("received %d/%d messages (%.1f%%)", receivedCount, numMessages, float64(receivedCount)*100/float64(numMessages))
}

122
e2e/pubsub_presence_test.go Normal file
View File

@ -0,0 +1,122 @@
//go:build e2e
package e2e
import (
"context"
"encoding/json"
"fmt"
"net/http"
"testing"
"time"
)
func TestPubSub_Presence(t *testing.T) {
SkipIfMissingGateway(t)
topic := GenerateTopic()
memberID := "user123"
memberMeta := map[string]interface{}{"name": "Alice"}
// 1. Subscribe with presence
client1, err := NewWSPubSubPresenceClient(t, topic, memberID, memberMeta)
if err != nil {
t.Fatalf("failed to create presence client: %v", err)
}
defer client1.Close()
// Wait for join event
msg, err := client1.ReceiveWithTimeout(5 * time.Second)
if err != nil {
t.Fatalf("did not receive join event: %v", err)
}
var event map[string]interface{}
if err := json.Unmarshal(msg, &event); err != nil {
t.Fatalf("failed to unmarshal event: %v", err)
}
if event["type"] != "presence.join" {
t.Fatalf("expected presence.join event, got %v", event["type"])
}
if event["member_id"] != memberID {
t.Fatalf("expected member_id %s, got %v", memberID, event["member_id"])
}
// 2. Query presence endpoint
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req := &HTTPRequest{
Method: http.MethodGet,
URL: fmt.Sprintf("%s/v1/pubsub/presence?topic=%s", GetGatewayURL(), topic),
}
body, status, err := req.Do(ctx)
if err != nil {
t.Fatalf("presence query failed: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
var resp map[string]interface{}
if err := DecodeJSON(body, &resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if resp["count"] != float64(1) {
t.Fatalf("expected count 1, got %v", resp["count"])
}
members := resp["members"].([]interface{})
if len(members) != 1 {
t.Fatalf("expected 1 member, got %d", len(members))
}
member := members[0].(map[string]interface{})
if member["member_id"] != memberID {
t.Fatalf("expected member_id %s, got %v", memberID, member["member_id"])
}
// 3. Subscribe second member
memberID2 := "user456"
client2, err := NewWSPubSubPresenceClient(t, topic, memberID2, nil)
if err != nil {
t.Fatalf("failed to create second presence client: %v", err)
}
// We'll close client2 later to test leave event
// Client1 should receive join event for Client2
msg2, err := client1.ReceiveWithTimeout(5 * time.Second)
if err != nil {
t.Fatalf("client1 did not receive join event for client2: %v", err)
}
if err := json.Unmarshal(msg2, &event); err != nil {
t.Fatalf("failed to unmarshal event: %v", err)
}
if event["type"] != "presence.join" || event["member_id"] != memberID2 {
t.Fatalf("expected presence.join for %s, got %v for %v", memberID2, event["type"], event["member_id"])
}
// 4. Disconnect client2 and verify leave event
client2.Close()
msg3, err := client1.ReceiveWithTimeout(5 * time.Second)
if err != nil {
t.Fatalf("client1 did not receive leave event for client2: %v", err)
}
if err := json.Unmarshal(msg3, &event); err != nil {
t.Fatalf("failed to unmarshal event: %v", err)
}
if event["type"] != "presence.leave" || event["member_id"] != memberID2 {
t.Fatalf("expected presence.leave for %s, got %v for %v", memberID2, event["type"], event["member_id"])
}
}

123
e2e/serverless_test.go Normal file
View File

@ -0,0 +1,123 @@
//go:build e2e
package e2e
import (
"bytes"
"context"
"io"
"mime/multipart"
"net/http"
"os"
"testing"
"time"
)
func TestServerless_DeployAndInvoke(t *testing.T) {
SkipIfMissingGateway(t)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
wasmPath := "../examples/functions/bin/hello.wasm"
if _, err := os.Stat(wasmPath); os.IsNotExist(err) {
t.Skip("hello.wasm not found")
}
wasmBytes, err := os.ReadFile(wasmPath)
if err != nil {
t.Fatalf("failed to read hello.wasm: %v", err)
}
funcName := "e2e-hello"
namespace := "default"
// 1. Deploy function
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
// Add metadata
_ = writer.WriteField("name", funcName)
_ = writer.WriteField("namespace", namespace)
// Add WASM file
part, err := writer.CreateFormFile("wasm", funcName+".wasm")
if err != nil {
t.Fatalf("failed to create form file: %v", err)
}
part.Write(wasmBytes)
writer.Close()
deployReq, _ := http.NewRequestWithContext(ctx, "POST", GetGatewayURL()+"/v1/functions", &buf)
deployReq.Header.Set("Content-Type", writer.FormDataContentType())
if apiKey := GetAPIKey(); apiKey != "" {
deployReq.Header.Set("Authorization", "Bearer "+apiKey)
}
client := NewHTTPClient(1 * time.Minute)
resp, err := client.Do(deployReq)
if err != nil {
t.Fatalf("deploy request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("deploy failed with status %d: %s", resp.StatusCode, string(body))
}
// 2. Invoke function
invokePayload := []byte(`{"name": "E2E Tester"}`)
invokeReq, _ := http.NewRequestWithContext(ctx, "POST", GetGatewayURL()+"/v1/functions/"+funcName+"/invoke", bytes.NewReader(invokePayload))
invokeReq.Header.Set("Content-Type", "application/json")
if apiKey := GetAPIKey(); apiKey != "" {
invokeReq.Header.Set("Authorization", "Bearer "+apiKey)
}
resp, err = client.Do(invokeReq)
if err != nil {
t.Fatalf("invoke request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("invoke failed with status %d: %s", resp.StatusCode, string(body))
}
output, _ := io.ReadAll(resp.Body)
expected := "Hello, E2E Tester!"
if !bytes.Contains(output, []byte(expected)) {
t.Errorf("output %q does not contain %q", string(output), expected)
}
// 3. List functions
listReq, _ := http.NewRequestWithContext(ctx, "GET", GetGatewayURL()+"/v1/functions?namespace="+namespace, nil)
if apiKey := GetAPIKey(); apiKey != "" {
listReq.Header.Set("Authorization", "Bearer "+apiKey)
}
resp, err = client.Do(listReq)
if err != nil {
t.Fatalf("list request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("list failed with status %d", resp.StatusCode)
}
// 4. Delete function
deleteReq, _ := http.NewRequestWithContext(ctx, "DELETE", GetGatewayURL()+"/v1/functions/"+funcName+"?namespace="+namespace, nil)
if apiKey := GetAPIKey(); apiKey != "" {
deleteReq.Header.Set("Authorization", "Bearer "+apiKey)
}
resp, err = client.Do(deleteReq)
if err != nil {
t.Fatalf("delete request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("delete failed with status %d", resp.StatusCode)
}
}

158
example.http Normal file
View File

@ -0,0 +1,158 @@
### Orama Network Gateway API Examples
# This file is designed for the VS Code "REST Client" extension.
# It demonstrates the core capabilities of the DeBros Network Gateway.
@baseUrl = http://localhost:6001
@apiKey = ak_X32jj2fiin8zzv0hmBKTC5b5:default
@contentType = application/json
############################################################
### 1. SYSTEM & HEALTH
############################################################
# @name HealthCheck
GET {{baseUrl}}/v1/health
X-API-Key: {{apiKey}}
###
# @name SystemStatus
# Returns the full status of the gateway and connected services
GET {{baseUrl}}/v1/status
X-API-Key: {{apiKey}}
###
# @name NetworkStatus
# Returns the P2P network status and PeerID
GET {{baseUrl}}/v1/network/status
X-API-Key: {{apiKey}}
############################################################
### 2. DISTRIBUTED CACHE (OLRIC)
############################################################
# @name CachePut
# Stores a value in the distributed cache (DMap)
POST {{baseUrl}}/v1/cache/put
X-API-Key: {{apiKey}}
Content-Type: {{contentType}}
{
"dmap": "demo-cache",
"key": "video-demo",
"value": "Hello from REST Client!"
}
###
# @name CacheGet
# Retrieves a value from the distributed cache
POST {{baseUrl}}/v1/cache/get
X-API-Key: {{apiKey}}
Content-Type: {{contentType}}
{
"dmap": "demo-cache",
"key": "video-demo"
}
###
# @name CacheScan
# Scans for keys in a specific DMap
POST {{baseUrl}}/v1/cache/scan
X-API-Key: {{apiKey}}
Content-Type: {{contentType}}
{
"dmap": "demo-cache"
}
############################################################
### 3. DECENTRALIZED STORAGE (IPFS)
############################################################
# @name StorageUpload
# Uploads a file to IPFS (Multipart)
POST {{baseUrl}}/v1/storage/upload
X-API-Key: {{apiKey}}
Content-Type: multipart/form-data; boundary=boundary
--boundary
Content-Disposition: form-data; name="file"; filename="demo.txt"
Content-Type: text/plain
This is a demonstration of decentralized storage on the Sonr Network.
--boundary--
###
# @name StorageStatus
# Check the pinning status and replication of a CID
# Replace {cid} with the CID returned from the upload above
@demoCid = bafkreid76y6x6v2n5o4n6n5o4n6n5o4n6n5o4n6n5o4
GET {{baseUrl}}/v1/storage/status/{{demoCid}}
X-API-Key: {{apiKey}}
###
# @name StorageDownload
# Retrieve content directly from IPFS via the gateway
GET {{baseUrl}}/v1/storage/get/{{demoCid}}
X-API-Key: {{apiKey}}
############################################################
### 4. REAL-TIME PUB/SUB
############################################################
# @name ListTopics
# Lists all active topics in the current namespace
GET {{baseUrl}}/v1/pubsub/topics
X-API-Key: {{apiKey}}
###
# @name PublishMessage
# Publishes a base64 encoded message to a topic
POST {{baseUrl}}/v1/pubsub/publish
X-API-Key: {{apiKey}}
Content-Type: {{contentType}}
{
"topic": "network-updates",
"data_base64": "U29uciBOZXR3b3JrIGlzIGF3ZXNvbWUh"
}
############################################################
### 5. SERVERLESS FUNCTIONS
############################################################
# @name ListFunctions
# Lists all deployed serverless functions
GET {{baseUrl}}/v1/functions
X-API-Key: {{apiKey}}
###
# @name InvokeFunction
# Invokes a deployed function by name
# Path: /v1/invoke/{namespace}/{functionName}
POST {{baseUrl}}/v1/invoke/default/hello
X-API-Key: {{apiKey}}
Content-Type: {{contentType}}
{
"name": "Developer"
}
###
# @name WhoAmI
# Validates the API Key and returns caller identity
GET {{baseUrl}}/v1/auth/whoami
X-API-Key: {{apiKey}}

42
examples/functions/build.sh Executable file
View File

@ -0,0 +1,42 @@
#!/bin/bash
# Build all example functions to WASM using TinyGo
#
# Prerequisites:
# - TinyGo installed: https://tinygo.org/getting-started/install/
# - On macOS: brew install tinygo
#
# Usage: ./build.sh
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_DIR="$SCRIPT_DIR/bin"
# Check if TinyGo is installed
if ! command -v tinygo &> /dev/null; then
echo "Error: TinyGo is not installed."
echo "Install it with: brew install tinygo (macOS) or see https://tinygo.org/getting-started/install/"
exit 1
fi
# Create output directory
mkdir -p "$OUTPUT_DIR"
echo "Building example functions to WASM..."
echo
# Build each function
for dir in "$SCRIPT_DIR"/*/; do
if [ -f "$dir/main.go" ]; then
name=$(basename "$dir")
echo "Building $name..."
cd "$dir"
tinygo build -o "$OUTPUT_DIR/$name.wasm" -target wasi main.go
echo " -> $OUTPUT_DIR/$name.wasm"
fi
done
echo
echo "Done! WASM files are in $OUTPUT_DIR/"
ls -lh "$OUTPUT_DIR"/*.wasm 2>/dev/null || echo "No WASM files built."

View File

@ -0,0 +1,66 @@
// Example: Counter function with Olric cache
// This function demonstrates using the distributed cache to maintain state.
// Compile with: tinygo build -o counter.wasm -target wasi main.go
//
// Note: This example shows the CONCEPT. Actual host function integration
// requires the host function bindings to be exposed to the WASM module.
package main
import (
"encoding/json"
"os"
)
func main() {
// Read input from stdin
var input []byte
buf := make([]byte, 1024)
for {
n, err := os.Stdin.Read(buf)
if n > 0 {
input = append(input, buf[:n]...)
}
if err != nil {
break
}
}
// Parse input
var payload struct {
Action string `json:"action"` // "increment", "decrement", "get", "reset"
CounterID string `json:"counter_id"`
}
if err := json.Unmarshal(input, &payload); err != nil {
response := map[string]interface{}{
"error": "Invalid JSON input",
}
output, _ := json.Marshal(response)
os.Stdout.Write(output)
return
}
if payload.CounterID == "" {
payload.CounterID = "default"
}
// NOTE: In the real implementation, this would use host functions:
// - cache_get(key) to read the counter
// - cache_put(key, value, ttl) to write the counter
//
// For this example, we just simulate the logic:
response := map[string]interface{}{
"counter_id": payload.CounterID,
"action": payload.Action,
"message": "Counter operations require cache host functions",
"example": map[string]interface{}{
"increment": "cache_put('counter:' + counter_id, current + 1)",
"decrement": "cache_put('counter:' + counter_id, current - 1)",
"get": "cache_get('counter:' + counter_id)",
"reset": "cache_put('counter:' + counter_id, 0)",
},
}
output, _ := json.Marshal(response)
os.Stdout.Write(output)
}

View File

@ -0,0 +1,50 @@
// Example: Echo function
// This is a simple serverless function that echoes back the input.
// Compile with: tinygo build -o echo.wasm -target wasi main.go
package main
import (
"encoding/json"
"os"
)
// Input is read from stdin, output is written to stdout.
// The Orama serverless engine passes the invocation payload via stdin
// and expects the response on stdout.
func main() {
// Read all input from stdin
var input []byte
buf := make([]byte, 1024)
for {
n, err := os.Stdin.Read(buf)
if n > 0 {
input = append(input, buf[:n]...)
}
if err != nil {
break
}
}
// Parse input as JSON (optional - could also just echo raw bytes)
var payload map[string]interface{}
if err := json.Unmarshal(input, &payload); err != nil {
// Not JSON, just echo the raw input
response := map[string]interface{}{
"echo": string(input),
}
output, _ := json.Marshal(response)
os.Stdout.Write(output)
return
}
// Create response
response := map[string]interface{}{
"echo": payload,
"message": "Echo function received your input!",
}
output, _ := json.Marshal(response)
os.Stdout.Write(output)
}

View File

@ -0,0 +1,42 @@
// Example: Hello function
// This is a simple serverless function that returns a greeting.
// Compile with: tinygo build -o hello.wasm -target wasi main.go
package main
import (
"encoding/json"
"os"
)
func main() {
// Read input from stdin
var input []byte
buf := make([]byte, 1024)
for {
n, err := os.Stdin.Read(buf)
if n > 0 {
input = append(input, buf[:n]...)
}
if err != nil {
break
}
}
// Parse input to get name
var payload struct {
Name string `json:"name"`
}
if err := json.Unmarshal(input, &payload); err != nil || payload.Name == "" {
payload.Name = "World"
}
// Create greeting response
response := map[string]interface{}{
"greeting": "Hello, " + payload.Name + "!",
"message": "This is a serverless function running on Orama Network",
}
output, _ := json.Marshal(response)
os.Stdout.Write(output)
}

7
go.mod
View File

@ -1,6 +1,6 @@
module github.com/DeBrosOfficial/network module github.com/DeBrosOfficial/network
go 1.23.8 go 1.24.0
toolchain go1.24.1 toolchain go1.24.1
@ -10,6 +10,7 @@ require (
github.com/charmbracelet/lipgloss v1.0.0 github.com/charmbracelet/lipgloss v1.0.0
github.com/ethereum/go-ethereum v1.13.14 github.com/ethereum/go-ethereum v1.13.14
github.com/go-chi/chi/v5 v5.2.3 github.com/go-chi/chi/v5 v5.2.3
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3 github.com/gorilla/websocket v1.5.3
github.com/libp2p/go-libp2p v0.41.1 github.com/libp2p/go-libp2p v0.41.1
github.com/libp2p/go-libp2p-pubsub v0.14.2 github.com/libp2p/go-libp2p-pubsub v0.14.2
@ -18,6 +19,7 @@ require (
github.com/multiformats/go-multiaddr v0.15.0 github.com/multiformats/go-multiaddr v0.15.0
github.com/olric-data/olric v0.7.0 github.com/olric-data/olric v0.7.0
github.com/rqlite/gorqlite v0.0.0-20250609141355-ac86a4a1c9a8 github.com/rqlite/gorqlite v0.0.0-20250609141355-ac86a4a1c9a8
github.com/tetratelabs/wazero v1.11.0
go.uber.org/zap v1.27.0 go.uber.org/zap v1.27.0
golang.org/x/crypto v0.40.0 golang.org/x/crypto v0.40.0
golang.org/x/net v0.42.0 golang.org/x/net v0.42.0
@ -54,7 +56,6 @@ require (
github.com/google/btree v1.1.3 // indirect github.com/google/btree v1.1.3 // indirect
github.com/google/gopacket v1.1.19 // indirect github.com/google/gopacket v1.1.19 // indirect
github.com/google/pprof v0.0.0-20250208200701-d0013a598941 // indirect github.com/google/pprof v0.0.0-20250208200701-d0013a598941 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-immutable-radix v1.3.1 // indirect github.com/hashicorp/go-immutable-radix v1.3.1 // indirect
github.com/hashicorp/go-metrics v0.5.4 // indirect github.com/hashicorp/go-metrics v0.5.4 // indirect
@ -154,7 +155,7 @@ require (
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 // indirect golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 // indirect
golang.org/x/mod v0.26.0 // indirect golang.org/x/mod v0.26.0 // indirect
golang.org/x/sync v0.16.0 // indirect golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.34.0 // indirect golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.27.0 // indirect golang.org/x/text v0.27.0 // indirect
golang.org/x/tools v0.35.0 // indirect golang.org/x/tools v0.35.0 // indirect
google.golang.org/protobuf v1.36.6 // indirect google.golang.org/protobuf v1.36.6 // indirect

6
go.sum
View File

@ -487,6 +487,8 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
github.com/tetratelabs/wazero v1.11.0 h1:+gKemEuKCTevU4d7ZTzlsvgd1uaToIDtlQlmNbwqYhA=
github.com/tetratelabs/wazero v1.11.0/go.mod h1:eV28rsN8Q+xwjogd7f4/Pp4xFxO7uOGbLcD/LzB1wiU=
github.com/tidwall/btree v1.1.0/go.mod h1:TzIRzen6yHbibdSfK6t8QimqbUnoxUSrZfeW7Uob0q4= github.com/tidwall/btree v1.1.0/go.mod h1:TzIRzen6yHbibdSfK6t8QimqbUnoxUSrZfeW7Uob0q4=
github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI=
github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY=
@ -627,8 +629,8 @@ golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=

View File

@ -0,0 +1,243 @@
-- Orama Network - Serverless Functions Engine (Phase 4)
-- WASM-based serverless function execution with triggers, jobs, and secrets
BEGIN;
-- =============================================================================
-- FUNCTIONS TABLE
-- Core function registry with versioning support
-- =============================================================================
CREATE TABLE IF NOT EXISTS functions (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
namespace TEXT NOT NULL,
version INTEGER NOT NULL DEFAULT 1,
wasm_cid TEXT NOT NULL,
source_cid TEXT,
memory_limit_mb INTEGER NOT NULL DEFAULT 64,
timeout_seconds INTEGER NOT NULL DEFAULT 30,
is_public BOOLEAN NOT NULL DEFAULT FALSE,
retry_count INTEGER NOT NULL DEFAULT 0,
retry_delay_seconds INTEGER NOT NULL DEFAULT 5,
dlq_topic TEXT,
status TEXT NOT NULL DEFAULT 'active',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_by TEXT NOT NULL,
UNIQUE(namespace, name)
);
CREATE INDEX IF NOT EXISTS idx_functions_namespace ON functions(namespace);
CREATE INDEX IF NOT EXISTS idx_functions_name ON functions(namespace, name);
CREATE INDEX IF NOT EXISTS idx_functions_status ON functions(status);
-- =============================================================================
-- FUNCTION ENVIRONMENT VARIABLES
-- Non-sensitive configuration per function
-- =============================================================================
CREATE TABLE IF NOT EXISTS function_env_vars (
id TEXT PRIMARY KEY,
function_id TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(function_id, key),
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_function_env_vars_function ON function_env_vars(function_id);
-- =============================================================================
-- FUNCTION SECRETS
-- Encrypted secrets per namespace (shared across functions in namespace)
-- =============================================================================
CREATE TABLE IF NOT EXISTS function_secrets (
id TEXT PRIMARY KEY,
namespace TEXT NOT NULL,
name TEXT NOT NULL,
encrypted_value BLOB NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(namespace, name)
);
CREATE INDEX IF NOT EXISTS idx_function_secrets_namespace ON function_secrets(namespace);
-- =============================================================================
-- CRON TRIGGERS
-- Scheduled function execution using cron expressions
-- =============================================================================
CREATE TABLE IF NOT EXISTS function_cron_triggers (
id TEXT PRIMARY KEY,
function_id TEXT NOT NULL,
cron_expression TEXT NOT NULL,
next_run_at TIMESTAMP,
last_run_at TIMESTAMP,
last_status TEXT,
last_error TEXT,
enabled BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_function_cron_triggers_function ON function_cron_triggers(function_id);
CREATE INDEX IF NOT EXISTS idx_function_cron_triggers_next_run ON function_cron_triggers(next_run_at)
WHERE enabled = TRUE;
-- =============================================================================
-- DATABASE TRIGGERS
-- Trigger functions on database changes (INSERT/UPDATE/DELETE)
-- =============================================================================
CREATE TABLE IF NOT EXISTS function_db_triggers (
id TEXT PRIMARY KEY,
function_id TEXT NOT NULL,
table_name TEXT NOT NULL,
operation TEXT NOT NULL CHECK(operation IN ('INSERT', 'UPDATE', 'DELETE')),
condition TEXT,
enabled BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_function_db_triggers_function ON function_db_triggers(function_id);
CREATE INDEX IF NOT EXISTS idx_function_db_triggers_table ON function_db_triggers(table_name, operation)
WHERE enabled = TRUE;
-- =============================================================================
-- PUBSUB TRIGGERS
-- Trigger functions on pubsub messages
-- =============================================================================
CREATE TABLE IF NOT EXISTS function_pubsub_triggers (
id TEXT PRIMARY KEY,
function_id TEXT NOT NULL,
topic TEXT NOT NULL,
enabled BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_function_pubsub_triggers_function ON function_pubsub_triggers(function_id);
CREATE INDEX IF NOT EXISTS idx_function_pubsub_triggers_topic ON function_pubsub_triggers(topic)
WHERE enabled = TRUE;
-- =============================================================================
-- ONE-TIME TIMERS
-- Schedule functions to run once at a specific time
-- =============================================================================
CREATE TABLE IF NOT EXISTS function_timers (
id TEXT PRIMARY KEY,
function_id TEXT NOT NULL,
run_at TIMESTAMP NOT NULL,
payload TEXT,
status TEXT NOT NULL DEFAULT 'pending' CHECK(status IN ('pending', 'running', 'completed', 'failed')),
error TEXT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
completed_at TIMESTAMP,
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_function_timers_function ON function_timers(function_id);
CREATE INDEX IF NOT EXISTS idx_function_timers_pending ON function_timers(run_at)
WHERE status = 'pending';
-- =============================================================================
-- BACKGROUND JOBS
-- Long-running async function execution
-- =============================================================================
CREATE TABLE IF NOT EXISTS function_jobs (
id TEXT PRIMARY KEY,
function_id TEXT NOT NULL,
payload TEXT,
status TEXT NOT NULL DEFAULT 'pending' CHECK(status IN ('pending', 'running', 'completed', 'failed', 'cancelled')),
progress INTEGER NOT NULL DEFAULT 0 CHECK(progress >= 0 AND progress <= 100),
result TEXT,
error TEXT,
started_at TIMESTAMP,
completed_at TIMESTAMP,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_function_jobs_function ON function_jobs(function_id);
CREATE INDEX IF NOT EXISTS idx_function_jobs_status ON function_jobs(status);
CREATE INDEX IF NOT EXISTS idx_function_jobs_pending ON function_jobs(created_at)
WHERE status = 'pending';
-- =============================================================================
-- INVOCATION LOGS
-- Record of all function invocations for debugging and metrics
-- =============================================================================
CREATE TABLE IF NOT EXISTS function_invocations (
id TEXT PRIMARY KEY,
function_id TEXT NOT NULL,
request_id TEXT NOT NULL,
trigger_type TEXT NOT NULL,
caller_wallet TEXT,
input_size INTEGER,
output_size INTEGER,
started_at TIMESTAMP NOT NULL,
completed_at TIMESTAMP,
duration_ms INTEGER,
status TEXT CHECK(status IN ('success', 'error', 'timeout')),
error_message TEXT,
memory_used_mb REAL,
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_function_invocations_function ON function_invocations(function_id);
CREATE INDEX IF NOT EXISTS idx_function_invocations_request ON function_invocations(request_id);
CREATE INDEX IF NOT EXISTS idx_function_invocations_time ON function_invocations(started_at);
CREATE INDEX IF NOT EXISTS idx_function_invocations_status ON function_invocations(function_id, status);
-- =============================================================================
-- FUNCTION LOGS
-- Captured log output from function execution
-- =============================================================================
CREATE TABLE IF NOT EXISTS function_logs (
id TEXT PRIMARY KEY,
function_id TEXT NOT NULL,
invocation_id TEXT NOT NULL,
level TEXT NOT NULL CHECK(level IN ('info', 'warn', 'error', 'debug')),
message TEXT NOT NULL,
timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE,
FOREIGN KEY (invocation_id) REFERENCES function_invocations(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_function_logs_invocation ON function_logs(invocation_id);
CREATE INDEX IF NOT EXISTS idx_function_logs_function ON function_logs(function_id, timestamp);
-- =============================================================================
-- DB CHANGE TRACKING
-- Track last processed row for database triggers (CDC-like)
-- =============================================================================
CREATE TABLE IF NOT EXISTS function_db_change_tracking (
id TEXT PRIMARY KEY,
trigger_id TEXT NOT NULL UNIQUE,
last_row_id INTEGER,
last_updated_at TIMESTAMP,
last_check_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (trigger_id) REFERENCES function_db_triggers(id) ON DELETE CASCADE
);
-- =============================================================================
-- RATE LIMITING
-- Track request counts for rate limiting
-- =============================================================================
CREATE TABLE IF NOT EXISTS function_rate_limits (
id TEXT PRIMARY KEY,
window_key TEXT NOT NULL,
count INTEGER NOT NULL DEFAULT 0,
window_start TIMESTAMP NOT NULL,
UNIQUE(window_key, window_start)
);
CREATE INDEX IF NOT EXISTS idx_function_rate_limits_window ON function_rate_limits(window_key, window_start);
-- =============================================================================
-- MIGRATION VERSION TRACKING
-- =============================================================================
INSERT OR IGNORE INTO schema_migrations(version) VALUES (4);
COMMIT;

View File

@ -2,8 +2,6 @@ package cli
import ( import (
"bufio" "bufio"
"encoding/hex"
"errors"
"flag" "flag"
"fmt" "fmt"
"net" "net"
@ -11,269 +9,12 @@ import (
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"strings" "strings"
"syscall"
"time" "time"
"github.com/DeBrosOfficial/network/pkg/config" "github.com/DeBrosOfficial/network/pkg/cli/utils"
"github.com/DeBrosOfficial/network/pkg/environments/production" "github.com/DeBrosOfficial/network/pkg/environments/production"
"github.com/DeBrosOfficial/network/pkg/installer"
"github.com/multiformats/go-multiaddr"
) )
// IPFSPeerInfo holds IPFS peer information for configuring Peering.Peers
type IPFSPeerInfo struct {
PeerID string
Addrs []string
}
// IPFSClusterPeerInfo contains IPFS Cluster peer information for cluster discovery
type IPFSClusterPeerInfo struct {
PeerID string
Addrs []string
}
// validateSwarmKey validates that a swarm key is 64 hex characters
func validateSwarmKey(key string) error {
key = strings.TrimSpace(key)
if len(key) != 64 {
return fmt.Errorf("swarm key must be 64 hex characters (32 bytes), got %d", len(key))
}
if _, err := hex.DecodeString(key); err != nil {
return fmt.Errorf("swarm key must be valid hexadecimal: %w", err)
}
return nil
}
// runInteractiveInstaller launches the TUI installer
func runInteractiveInstaller() {
config, err := installer.Run()
if err != nil {
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
os.Exit(1)
}
// Convert TUI config to install args and run installation
var args []string
args = append(args, "--vps-ip", config.VpsIP)
args = append(args, "--domain", config.Domain)
args = append(args, "--branch", config.Branch)
if config.NoPull {
args = append(args, "--no-pull")
}
if !config.IsFirstNode {
if config.JoinAddress != "" {
args = append(args, "--join", config.JoinAddress)
}
if config.ClusterSecret != "" {
args = append(args, "--cluster-secret", config.ClusterSecret)
}
if config.SwarmKeyHex != "" {
args = append(args, "--swarm-key", config.SwarmKeyHex)
}
if len(config.Peers) > 0 {
args = append(args, "--peers", strings.Join(config.Peers, ","))
}
// Pass IPFS peer info for Peering.Peers configuration
if config.IPFSPeerID != "" {
args = append(args, "--ipfs-peer", config.IPFSPeerID)
}
if len(config.IPFSSwarmAddrs) > 0 {
args = append(args, "--ipfs-addrs", strings.Join(config.IPFSSwarmAddrs, ","))
}
// Pass IPFS Cluster peer info for cluster peer_addresses configuration
if config.IPFSClusterPeerID != "" {
args = append(args, "--ipfs-cluster-peer", config.IPFSClusterPeerID)
}
if len(config.IPFSClusterAddrs) > 0 {
args = append(args, "--ipfs-cluster-addrs", strings.Join(config.IPFSClusterAddrs, ","))
}
}
// Re-run with collected args
handleProdInstall(args)
}
// showDryRunSummary displays what would be done during installation without making changes
func showDryRunSummary(vpsIP, domain, branch string, peers []string, joinAddress string, isFirstNode bool, oramaDir string) {
fmt.Printf("\n" + strings.Repeat("=", 70) + "\n")
fmt.Printf("DRY RUN - No changes will be made\n")
fmt.Printf(strings.Repeat("=", 70) + "\n\n")
fmt.Printf("📋 Installation Summary:\n")
fmt.Printf(" VPS IP: %s\n", vpsIP)
fmt.Printf(" Domain: %s\n", domain)
fmt.Printf(" Branch: %s\n", branch)
if isFirstNode {
fmt.Printf(" Node Type: First node (creates new cluster)\n")
} else {
fmt.Printf(" Node Type: Joining existing cluster\n")
if joinAddress != "" {
fmt.Printf(" Join Address: %s\n", joinAddress)
}
if len(peers) > 0 {
fmt.Printf(" Peers: %d peer(s)\n", len(peers))
for _, peer := range peers {
fmt.Printf(" - %s\n", peer)
}
}
}
fmt.Printf("\n📁 Directories that would be created:\n")
fmt.Printf(" %s/configs/\n", oramaDir)
fmt.Printf(" %s/secrets/\n", oramaDir)
fmt.Printf(" %s/data/ipfs/repo/\n", oramaDir)
fmt.Printf(" %s/data/ipfs-cluster/\n", oramaDir)
fmt.Printf(" %s/data/rqlite/\n", oramaDir)
fmt.Printf(" %s/logs/\n", oramaDir)
fmt.Printf(" %s/tls-cache/\n", oramaDir)
fmt.Printf("\n🔧 Binaries that would be installed:\n")
fmt.Printf(" - Go (if not present)\n")
fmt.Printf(" - RQLite 8.43.0\n")
fmt.Printf(" - IPFS/Kubo 0.38.2\n")
fmt.Printf(" - IPFS Cluster (latest)\n")
fmt.Printf(" - Olric 0.7.0\n")
fmt.Printf(" - anyone-client (npm)\n")
fmt.Printf(" - DeBros binaries (built from %s branch)\n", branch)
fmt.Printf("\n🔐 Secrets that would be generated:\n")
fmt.Printf(" - Cluster secret (64-hex)\n")
fmt.Printf(" - IPFS swarm key\n")
fmt.Printf(" - Node identity (Ed25519 keypair)\n")
fmt.Printf("\n📝 Configuration files that would be created:\n")
fmt.Printf(" - %s/configs/node.yaml\n", oramaDir)
fmt.Printf(" - %s/configs/olric/config.yaml\n", oramaDir)
fmt.Printf("\n⚙ Systemd services that would be created:\n")
fmt.Printf(" - debros-ipfs.service\n")
fmt.Printf(" - debros-ipfs-cluster.service\n")
fmt.Printf(" - debros-olric.service\n")
fmt.Printf(" - debros-node.service (includes embedded gateway + RQLite)\n")
fmt.Printf(" - debros-anyone-client.service\n")
fmt.Printf("\n🌐 Ports that would be used:\n")
fmt.Printf(" External (must be open in firewall):\n")
fmt.Printf(" - 80 (HTTP for ACME/Let's Encrypt)\n")
fmt.Printf(" - 443 (HTTPS gateway)\n")
fmt.Printf(" - 4101 (IPFS swarm)\n")
fmt.Printf(" - 7001 (RQLite Raft)\n")
fmt.Printf(" Internal (localhost only):\n")
fmt.Printf(" - 4501 (IPFS API)\n")
fmt.Printf(" - 5001 (RQLite HTTP)\n")
fmt.Printf(" - 6001 (Unified gateway)\n")
fmt.Printf(" - 8080 (IPFS gateway)\n")
fmt.Printf(" - 9050 (Anyone SOCKS5)\n")
fmt.Printf(" - 9094 (IPFS Cluster API)\n")
fmt.Printf(" - 3320/3322 (Olric)\n")
fmt.Printf("\n" + strings.Repeat("=", 70) + "\n")
fmt.Printf("To proceed with installation, run without --dry-run\n")
fmt.Printf(strings.Repeat("=", 70) + "\n\n")
}
// validateGeneratedConfig loads and validates the generated node configuration
func validateGeneratedConfig(oramaDir string) error {
configPath := filepath.Join(oramaDir, "configs", "node.yaml")
// Check if config file exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("configuration file not found at %s", configPath)
}
// Load the config file
file, err := os.Open(configPath)
if err != nil {
return fmt.Errorf("failed to open config file: %w", err)
}
defer file.Close()
var cfg config.Config
if err := config.DecodeStrict(file, &cfg); err != nil {
return fmt.Errorf("failed to parse config: %w", err)
}
// Validate the configuration
if errs := cfg.Validate(); len(errs) > 0 {
var errMsgs []string
for _, e := range errs {
errMsgs = append(errMsgs, e.Error())
}
return fmt.Errorf("configuration validation errors:\n - %s", strings.Join(errMsgs, "\n - "))
}
return nil
}
// validateDNSRecord validates that the domain points to the expected IP address
// Returns nil if DNS is valid, warning message if DNS doesn't match but continues,
// or error if DNS lookup fails completely
func validateDNSRecord(domain, expectedIP string) error {
if domain == "" {
return nil // No domain provided, skip validation
}
ips, err := net.LookupIP(domain)
if err != nil {
// DNS lookup failed - this is a warning, not a fatal error
// The user might be setting up DNS after installation
fmt.Printf(" ⚠️ DNS lookup failed for %s: %v\n", domain, err)
fmt.Printf(" Make sure DNS is configured before enabling HTTPS\n")
return nil
}
// Check if any resolved IP matches the expected IP
for _, ip := range ips {
if ip.String() == expectedIP {
fmt.Printf(" ✓ DNS validated: %s → %s\n", domain, expectedIP)
return nil
}
}
// DNS doesn't point to expected IP - warn but continue
resolvedIPs := make([]string, len(ips))
for i, ip := range ips {
resolvedIPs[i] = ip.String()
}
fmt.Printf(" ⚠️ DNS mismatch: %s resolves to %v, expected %s\n", domain, resolvedIPs, expectedIP)
fmt.Printf(" HTTPS certificate generation may fail until DNS is updated\n")
return nil
}
// normalizePeers normalizes and validates peer multiaddrs
func normalizePeers(peersStr string) ([]string, error) {
if peersStr == "" {
return nil, nil
}
// Split by comma and trim whitespace
rawPeers := strings.Split(peersStr, ",")
peers := make([]string, 0, len(rawPeers))
seen := make(map[string]bool)
for _, peer := range rawPeers {
peer = strings.TrimSpace(peer)
if peer == "" {
continue
}
// Validate multiaddr format
if _, err := multiaddr.NewMultiaddr(peer); err != nil {
return nil, fmt.Errorf("invalid multiaddr %q: %w", peer, err)
}
// Deduplicate
if !seen[peer] {
peers = append(peers, peer)
seen[peer] = true
}
}
return peers, nil
}
// HandleProdCommand handles production environment commands // HandleProdCommand handles production environment commands
func HandleProdCommand(args []string) { func HandleProdCommand(args []string) {
if len(args) == 0 { if len(args) == 0 {
@ -368,294 +109,6 @@ func showProdHelp() {
fmt.Printf(" orama logs node --follow\n") fmt.Printf(" orama logs node --follow\n")
} }
func handleProdInstall(args []string) {
// Parse arguments using flag.FlagSet
fs := flag.NewFlagSet("install", flag.ContinueOnError)
fs.SetOutput(os.Stderr)
force := fs.Bool("force", false, "Reconfigure all settings")
skipResourceChecks := fs.Bool("ignore-resource-checks", false, "Skip disk/RAM/CPU prerequisite validation")
vpsIP := fs.String("vps-ip", "", "VPS public IP address")
domain := fs.String("domain", "", "Domain for this node (e.g., node-123.orama.network)")
peersStr := fs.String("peers", "", "Comma-separated peer multiaddrs to connect to")
joinAddress := fs.String("join", "", "RQLite join address (IP:port) to join existing cluster")
branch := fs.String("branch", "main", "Git branch to use (main or nightly)")
clusterSecret := fs.String("cluster-secret", "", "Hex-encoded 32-byte cluster secret (for joining existing cluster)")
swarmKey := fs.String("swarm-key", "", "64-hex IPFS swarm key (for joining existing private network)")
ipfsPeerID := fs.String("ipfs-peer", "", "IPFS peer ID to connect to (auto-discovered from peer domain)")
ipfsAddrs := fs.String("ipfs-addrs", "", "Comma-separated IPFS swarm addresses (auto-discovered from peer domain)")
ipfsClusterPeerID := fs.String("ipfs-cluster-peer", "", "IPFS Cluster peer ID to connect to (auto-discovered from peer domain)")
ipfsClusterAddrs := fs.String("ipfs-cluster-addrs", "", "Comma-separated IPFS Cluster addresses (auto-discovered from peer domain)")
interactive := fs.Bool("interactive", false, "Run interactive TUI installer")
dryRun := fs.Bool("dry-run", false, "Show what would be done without making changes")
noPull := fs.Bool("no-pull", false, "Skip git clone/pull, use existing /home/debros/src")
if err := fs.Parse(args); err != nil {
if err == flag.ErrHelp {
return
}
fmt.Fprintf(os.Stderr, "❌ Failed to parse flags: %v\n", err)
os.Exit(1)
}
// Launch TUI installer if --interactive flag or no required args provided
if *interactive || (*vpsIP == "" && len(args) == 0) {
runInteractiveInstaller()
return
}
// Validate branch
if *branch != "main" && *branch != "nightly" {
fmt.Fprintf(os.Stderr, "❌ Invalid branch: %s (must be 'main' or 'nightly')\n", *branch)
os.Exit(1)
}
// Normalize and validate peers
peers, err := normalizePeers(*peersStr)
if err != nil {
fmt.Fprintf(os.Stderr, "❌ Invalid peers: %v\n", err)
fmt.Fprintf(os.Stderr, " Example: --peers /ip4/10.0.0.1/tcp/4001/p2p/Qm...,/ip4/10.0.0.2/tcp/4001/p2p/Qm...\n")
os.Exit(1)
}
// Validate setup requirements
if os.Geteuid() != 0 {
fmt.Fprintf(os.Stderr, "❌ Production install must be run as root (use sudo)\n")
os.Exit(1)
}
// Validate VPS IP is provided
if *vpsIP == "" {
fmt.Fprintf(os.Stderr, "❌ --vps-ip is required\n")
fmt.Fprintf(os.Stderr, " Usage: sudo orama install --vps-ip <public_ip>\n")
fmt.Fprintf(os.Stderr, " Or run: sudo orama install --interactive\n")
os.Exit(1)
}
// Determine if this is the first node (creates new cluster) or joining existing cluster
isFirstNode := len(peers) == 0 && *joinAddress == ""
if isFirstNode {
fmt.Printf(" First node detected - will create new cluster\n")
} else {
fmt.Printf(" Joining existing cluster\n")
// Cluster secret is required when joining
if *clusterSecret == "" {
fmt.Fprintf(os.Stderr, "❌ --cluster-secret is required when joining an existing cluster\n")
fmt.Fprintf(os.Stderr, " Provide the 64-hex secret from an existing node (cat ~/.orama/secrets/cluster-secret)\n")
os.Exit(1)
}
if err := production.ValidateClusterSecret(*clusterSecret); err != nil {
fmt.Fprintf(os.Stderr, "❌ Invalid --cluster-secret: %v\n", err)
os.Exit(1)
}
// Swarm key is required when joining
if *swarmKey == "" {
fmt.Fprintf(os.Stderr, "❌ --swarm-key is required when joining an existing cluster\n")
fmt.Fprintf(os.Stderr, " Provide the 64-hex swarm key from an existing node:\n")
fmt.Fprintf(os.Stderr, " cat ~/.orama/secrets/swarm.key | tail -1\n")
os.Exit(1)
}
if err := validateSwarmKey(*swarmKey); err != nil {
fmt.Fprintf(os.Stderr, "❌ Invalid --swarm-key: %v\n", err)
os.Exit(1)
}
}
oramaHome := "/home/debros"
oramaDir := oramaHome + "/.orama"
// If cluster secret was provided, save it to secrets directory before setup
if *clusterSecret != "" {
secretsDir := filepath.Join(oramaDir, "secrets")
if err := os.MkdirAll(secretsDir, 0755); err != nil {
fmt.Fprintf(os.Stderr, "❌ Failed to create secrets directory: %v\n", err)
os.Exit(1)
}
secretPath := filepath.Join(secretsDir, "cluster-secret")
if err := os.WriteFile(secretPath, []byte(*clusterSecret), 0600); err != nil {
fmt.Fprintf(os.Stderr, "❌ Failed to save cluster secret: %v\n", err)
os.Exit(1)
}
fmt.Printf(" ✓ Cluster secret saved\n")
}
// If swarm key was provided, save it to secrets directory in full format
if *swarmKey != "" {
secretsDir := filepath.Join(oramaDir, "secrets")
if err := os.MkdirAll(secretsDir, 0755); err != nil {
fmt.Fprintf(os.Stderr, "❌ Failed to create secrets directory: %v\n", err)
os.Exit(1)
}
// Convert 64-hex key to full swarm.key format
swarmKeyContent := fmt.Sprintf("/key/swarm/psk/1.0.0/\n/base16/\n%s\n", strings.ToUpper(*swarmKey))
swarmKeyPath := filepath.Join(secretsDir, "swarm.key")
if err := os.WriteFile(swarmKeyPath, []byte(swarmKeyContent), 0600); err != nil {
fmt.Fprintf(os.Stderr, "❌ Failed to save swarm key: %v\n", err)
os.Exit(1)
}
fmt.Printf(" ✓ Swarm key saved\n")
}
// Store IPFS peer info for later use in IPFS configuration
var ipfsPeerInfo *IPFSPeerInfo
if *ipfsPeerID != "" && *ipfsAddrs != "" {
ipfsPeerInfo = &IPFSPeerInfo{
PeerID: *ipfsPeerID,
Addrs: strings.Split(*ipfsAddrs, ","),
}
}
// Store IPFS Cluster peer info for cluster peer discovery
var ipfsClusterPeerInfo *IPFSClusterPeerInfo
if *ipfsClusterPeerID != "" {
var addrs []string
if *ipfsClusterAddrs != "" {
addrs = strings.Split(*ipfsClusterAddrs, ",")
}
ipfsClusterPeerInfo = &IPFSClusterPeerInfo{
PeerID: *ipfsClusterPeerID,
Addrs: addrs,
}
}
setup := production.NewProductionSetup(oramaHome, os.Stdout, *force, *branch, *noPull, *skipResourceChecks)
// Inform user if skipping git pull
if *noPull {
fmt.Printf(" ⚠️ --no-pull flag enabled: Skipping git clone/pull\n")
fmt.Printf(" Using existing repository at /home/debros/src\n")
}
// Check port availability before proceeding
if err := ensurePortsAvailable("install", defaultPorts()); err != nil {
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
os.Exit(1)
}
// Validate DNS if domain is provided
if *domain != "" {
fmt.Printf("\n🌐 Pre-flight DNS validation...\n")
validateDNSRecord(*domain, *vpsIP)
}
// Dry-run mode: show what would be done and exit
if *dryRun {
showDryRunSummary(*vpsIP, *domain, *branch, peers, *joinAddress, isFirstNode, oramaDir)
return
}
// Save branch preference for future upgrades
if err := production.SaveBranchPreference(oramaDir, *branch); err != nil {
fmt.Fprintf(os.Stderr, "⚠️ Warning: Failed to save branch preference: %v\n", err)
}
// Phase 1: Check prerequisites
fmt.Printf("\n📋 Phase 1: Checking prerequisites...\n")
if err := setup.Phase1CheckPrerequisites(); err != nil {
fmt.Fprintf(os.Stderr, "❌ Prerequisites check failed: %v\n", err)
os.Exit(1)
}
// Phase 2: Provision environment
fmt.Printf("\n🛠 Phase 2: Provisioning environment...\n")
if err := setup.Phase2ProvisionEnvironment(); err != nil {
fmt.Fprintf(os.Stderr, "❌ Environment provisioning failed: %v\n", err)
os.Exit(1)
}
// Phase 2b: Install binaries
fmt.Printf("\nPhase 2b: Installing binaries...\n")
if err := setup.Phase2bInstallBinaries(); err != nil {
fmt.Fprintf(os.Stderr, "❌ Binary installation failed: %v\n", err)
os.Exit(1)
}
// Phase 3: Generate secrets FIRST (before service initialization)
// This ensures cluster secret and swarm key exist before repos are seeded
fmt.Printf("\n🔐 Phase 3: Generating secrets...\n")
if err := setup.Phase3GenerateSecrets(); err != nil {
fmt.Fprintf(os.Stderr, "❌ Secret generation failed: %v\n", err)
os.Exit(1)
}
// Phase 4: Generate configs (BEFORE service initialization)
// This ensures node.yaml exists before services try to access it
fmt.Printf("\n⚙ Phase 4: Generating configurations...\n")
enableHTTPS := *domain != ""
if err := setup.Phase4GenerateConfigs(peers, *vpsIP, enableHTTPS, *domain, *joinAddress); err != nil {
fmt.Fprintf(os.Stderr, "❌ Configuration generation failed: %v\n", err)
os.Exit(1)
}
// Validate generated configuration
fmt.Printf(" Validating generated configuration...\n")
if err := validateGeneratedConfig(oramaDir); err != nil {
fmt.Fprintf(os.Stderr, "❌ Configuration validation failed: %v\n", err)
os.Exit(1)
}
fmt.Printf(" ✓ Configuration validated\n")
// Phase 2c: Initialize services (after config is in place)
fmt.Printf("\nPhase 2c: Initializing services...\n")
var prodIPFSPeer *production.IPFSPeerInfo
if ipfsPeerInfo != nil {
prodIPFSPeer = &production.IPFSPeerInfo{
PeerID: ipfsPeerInfo.PeerID,
Addrs: ipfsPeerInfo.Addrs,
}
}
var prodIPFSClusterPeer *production.IPFSClusterPeerInfo
if ipfsClusterPeerInfo != nil {
prodIPFSClusterPeer = &production.IPFSClusterPeerInfo{
PeerID: ipfsClusterPeerInfo.PeerID,
Addrs: ipfsClusterPeerInfo.Addrs,
}
}
if err := setup.Phase2cInitializeServices(peers, *vpsIP, prodIPFSPeer, prodIPFSClusterPeer); err != nil {
fmt.Fprintf(os.Stderr, "❌ Service initialization failed: %v\n", err)
os.Exit(1)
}
// Phase 5: Create systemd services
fmt.Printf("\n🔧 Phase 5: Creating systemd services...\n")
if err := setup.Phase5CreateSystemdServices(enableHTTPS); err != nil {
fmt.Fprintf(os.Stderr, "❌ Service creation failed: %v\n", err)
os.Exit(1)
}
// Log completion with actual peer ID
setup.LogSetupComplete(setup.NodePeerID)
fmt.Printf("✅ Production installation complete!\n\n")
// For first node, print important secrets and identifiers
if isFirstNode {
fmt.Printf("📋 Save these for joining future nodes:\n\n")
// Print cluster secret
clusterSecretPath := filepath.Join(oramaDir, "secrets", "cluster-secret")
if clusterSecretData, err := os.ReadFile(clusterSecretPath); err == nil {
fmt.Printf(" Cluster Secret (--cluster-secret):\n")
fmt.Printf(" %s\n\n", string(clusterSecretData))
}
// Print swarm key
swarmKeyPath := filepath.Join(oramaDir, "secrets", "swarm.key")
if swarmKeyData, err := os.ReadFile(swarmKeyPath); err == nil {
swarmKeyContent := strings.TrimSpace(string(swarmKeyData))
lines := strings.Split(swarmKeyContent, "\n")
if len(lines) >= 3 {
// Extract just the hex part (last line)
fmt.Printf(" IPFS Swarm Key (--swarm-key, last line only):\n")
fmt.Printf(" %s\n\n", lines[len(lines)-1])
}
}
// Print peer ID
fmt.Printf(" Node Peer ID:\n")
fmt.Printf(" %s\n\n", setup.NodePeerID)
}
}
func handleProdUpgrade(args []string) { func handleProdUpgrade(args []string) {
// Parse arguments using flag.FlagSet // Parse arguments using flag.FlagSet
fs := flag.NewFlagSet("upgrade", flag.ContinueOnError) fs := flag.NewFlagSet("upgrade", flag.ContinueOnError)
@ -767,7 +220,7 @@ func handleProdUpgrade(args []string) {
} }
// Check port availability after stopping services // Check port availability after stopping services
if err := ensurePortsAvailable("prod upgrade", defaultPorts()); err != nil { if err := utils.EnsurePortsAvailable("prod upgrade", utils.DefaultPorts()); err != nil {
fmt.Fprintf(os.Stderr, "❌ %v\n", err) fmt.Fprintf(os.Stderr, "❌ %v\n", err)
os.Exit(1) os.Exit(1)
} }
@ -945,7 +398,7 @@ func handleProdUpgrade(args []string) {
fmt.Fprintf(os.Stderr, " ⚠️ Warning: Failed to reload systemd daemon: %v\n", err) fmt.Fprintf(os.Stderr, " ⚠️ Warning: Failed to reload systemd daemon: %v\n", err)
} }
// Restart services to apply changes - use getProductionServices to only restart existing services // Restart services to apply changes - use getProductionServices to only restart existing services
services := getProductionServices() services := utils.GetProductionServices()
if len(services) == 0 { if len(services) == 0 {
fmt.Printf(" ⚠️ No services found to restart\n") fmt.Printf(" ⚠️ No services found to restart\n")
} else { } else {
@ -991,10 +444,9 @@ func handleProdStatus() {
fmt.Printf("Services:\n") fmt.Printf("Services:\n")
found := false found := false
for _, svc := range serviceNames { for _, svc := range serviceNames {
cmd := exec.Command("systemctl", "is-active", "--quiet", svc) active, _ := utils.IsServiceActive(svc)
err := cmd.Run()
status := "❌ Inactive" status := "❌ Inactive"
if err == nil { if active {
status = "✅ Active" status = "✅ Active"
found = true found = true
} }
@ -1016,52 +468,6 @@ func handleProdStatus() {
fmt.Printf("\nView logs with: dbn prod logs <service>\n") fmt.Printf("\nView logs with: dbn prod logs <service>\n")
} }
// resolveServiceName resolves service aliases to actual systemd service names
func resolveServiceName(alias string) ([]string, error) {
// Service alias mapping (unified - no bootstrap/node distinction)
aliases := map[string][]string{
"node": {"debros-node"},
"ipfs": {"debros-ipfs"},
"cluster": {"debros-ipfs-cluster"},
"ipfs-cluster": {"debros-ipfs-cluster"},
"gateway": {"debros-gateway"},
"olric": {"debros-olric"},
"rqlite": {"debros-node"}, // RQLite logs are in node logs
}
// Check if it's an alias
if serviceNames, ok := aliases[strings.ToLower(alias)]; ok {
// Filter to only existing services
var existing []string
for _, svc := range serviceNames {
unitPath := filepath.Join("/etc/systemd/system", svc+".service")
if _, err := os.Stat(unitPath); err == nil {
existing = append(existing, svc)
}
}
if len(existing) == 0 {
return nil, fmt.Errorf("no services found for alias %q", alias)
}
return existing, nil
}
// Check if it's already a full service name
unitPath := filepath.Join("/etc/systemd/system", alias+".service")
if _, err := os.Stat(unitPath); err == nil {
return []string{alias}, nil
}
// Try without .service suffix
if !strings.HasSuffix(alias, ".service") {
unitPath = filepath.Join("/etc/systemd/system", alias+".service")
if _, err := os.Stat(unitPath); err == nil {
return []string{alias}, nil
}
}
return nil, fmt.Errorf("service %q not found. Use: node, ipfs, cluster, gateway, olric, or full service name", alias)
}
func handleProdLogs(args []string) { func handleProdLogs(args []string) {
if len(args) == 0 { if len(args) == 0 {
fmt.Fprintf(os.Stderr, "Usage: dbn prod logs <service> [--follow]\n") fmt.Fprintf(os.Stderr, "Usage: dbn prod logs <service> [--follow]\n")
@ -1079,7 +485,7 @@ func handleProdLogs(args []string) {
} }
// Resolve service alias to actual service names // Resolve service alias to actual service names
serviceNames, err := resolveServiceName(serviceAlias) serviceNames, err := utils.ResolveServiceName(serviceAlias)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "❌ %v\n", err) fmt.Fprintf(os.Stderr, "❌ %v\n", err)
fmt.Fprintf(os.Stderr, "\nAvailable service aliases: node, ipfs, cluster, gateway, olric\n") fmt.Fprintf(os.Stderr, "\nAvailable service aliases: node, ipfs, cluster, gateway, olric\n")
@ -1109,7 +515,7 @@ func handleProdLogs(args []string) {
} else { } else {
for i, svc := range serviceNames { for i, svc := range serviceNames {
if i > 0 { if i > 0 {
fmt.Printf("\n" + strings.Repeat("=", 70) + "\n\n") fmt.Print("\n" + strings.Repeat("=", 70) + "\n\n")
} }
fmt.Printf("📋 Logs for %s:\n\n", svc) fmt.Printf("📋 Logs for %s:\n\n", svc)
cmd := exec.Command("journalctl", "-u", svc, "-n", "50") cmd := exec.Command("journalctl", "-u", svc, "-n", "50")
@ -1138,145 +544,6 @@ func handleProdLogs(args []string) {
} }
} }
// errServiceNotFound marks units that systemd does not know about.
var errServiceNotFound = errors.New("service not found")
type portSpec struct {
Name string
Port int
}
var servicePorts = map[string][]portSpec{
"debros-gateway": {{"Gateway API", 6001}},
"debros-olric": {{"Olric HTTP", 3320}, {"Olric Memberlist", 3322}},
"debros-node": {{"RQLite HTTP", 5001}, {"RQLite Raft", 7001}},
"debros-ipfs": {{"IPFS API", 4501}, {"IPFS Gateway", 8080}, {"IPFS Swarm", 4101}},
"debros-ipfs-cluster": {{"IPFS Cluster API", 9094}},
}
// defaultPorts is used for fresh installs/upgrades before unit files exist.
func defaultPorts() []portSpec {
return []portSpec{
{"IPFS Swarm", 4001},
{"IPFS API", 4501},
{"IPFS Gateway", 8080},
{"Gateway API", 6001},
{"RQLite HTTP", 5001},
{"RQLite Raft", 7001},
{"IPFS Cluster API", 9094},
{"Olric HTTP", 3320},
{"Olric Memberlist", 3322},
}
}
func isServiceActive(service string) (bool, error) {
cmd := exec.Command("systemctl", "is-active", "--quiet", service)
if err := cmd.Run(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
switch exitErr.ExitCode() {
case 3:
return false, nil
case 4:
return false, errServiceNotFound
}
}
return false, err
}
return true, nil
}
func isServiceEnabled(service string) (bool, error) {
cmd := exec.Command("systemctl", "is-enabled", "--quiet", service)
if err := cmd.Run(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
switch exitErr.ExitCode() {
case 1:
return false, nil // Service is disabled
case 4:
return false, errServiceNotFound
}
}
return false, err
}
return true, nil
}
func collectPortsForServices(services []string, skipActive bool) ([]portSpec, error) {
seen := make(map[int]portSpec)
for _, svc := range services {
if skipActive {
active, err := isServiceActive(svc)
if err != nil {
return nil, fmt.Errorf("unable to check %s: %w", svc, err)
}
if active {
continue
}
}
for _, spec := range servicePorts[svc] {
if _, ok := seen[spec.Port]; !ok {
seen[spec.Port] = spec
}
}
}
ports := make([]portSpec, 0, len(seen))
for _, spec := range seen {
ports = append(ports, spec)
}
return ports, nil
}
func ensurePortsAvailable(action string, ports []portSpec) error {
for _, spec := range ports {
ln, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", spec.Port))
if err != nil {
if errors.Is(err, syscall.EADDRINUSE) || strings.Contains(err.Error(), "address already in use") {
return fmt.Errorf("%s cannot continue: %s (port %d) is already in use", action, spec.Name, spec.Port)
}
return fmt.Errorf("%s cannot continue: failed to inspect %s (port %d): %w", action, spec.Name, spec.Port, err)
}
_ = ln.Close()
}
return nil
}
// getProductionServices returns a list of all DeBros production service names that exist
func getProductionServices() []string {
// Unified service names (no bootstrap/node distinction)
allServices := []string{
"debros-gateway",
"debros-node",
"debros-olric",
"debros-ipfs-cluster",
"debros-ipfs",
"debros-anyone-client",
}
// Filter to only existing services by checking if unit file exists
var existing []string
for _, svc := range allServices {
unitPath := filepath.Join("/etc/systemd/system", svc+".service")
if _, err := os.Stat(unitPath); err == nil {
existing = append(existing, svc)
}
}
return existing
}
func isServiceMasked(service string) (bool, error) {
cmd := exec.Command("systemctl", "is-enabled", service)
output, err := cmd.CombinedOutput()
if err != nil {
outputStr := string(output)
if strings.Contains(outputStr, "masked") {
return true, nil
}
return false, err
}
return false, nil
}
func handleProdStart() { func handleProdStart() {
if os.Geteuid() != 0 { if os.Geteuid() != 0 {
fmt.Fprintf(os.Stderr, "❌ Production commands must be run as root (use sudo)\n") fmt.Fprintf(os.Stderr, "❌ Production commands must be run as root (use sudo)\n")
@ -1285,7 +552,7 @@ func handleProdStart() {
fmt.Printf("Starting all DeBros production services...\n") fmt.Printf("Starting all DeBros production services...\n")
services := getProductionServices() services := utils.GetProductionServices()
if len(services) == 0 { if len(services) == 0 {
fmt.Printf(" ⚠️ No DeBros services found\n") fmt.Printf(" ⚠️ No DeBros services found\n")
return return
@ -1301,7 +568,7 @@ func handleProdStart() {
inactive := make([]string, 0, len(services)) inactive := make([]string, 0, len(services))
for _, svc := range services { for _, svc := range services {
// Check if service is masked and unmask it // Check if service is masked and unmask it
masked, err := isServiceMasked(svc) masked, err := utils.IsServiceMasked(svc)
if err == nil && masked { if err == nil && masked {
fmt.Printf(" ⚠️ %s is masked, unmasking...\n", svc) fmt.Printf(" ⚠️ %s is masked, unmasking...\n", svc)
if err := exec.Command("systemctl", "unmask", svc).Run(); err != nil { if err := exec.Command("systemctl", "unmask", svc).Run(); err != nil {
@ -1311,7 +578,7 @@ func handleProdStart() {
} }
} }
active, err := isServiceActive(svc) active, err := utils.IsServiceActive(svc)
if err != nil { if err != nil {
fmt.Printf(" ⚠️ Unable to check %s: %v\n", svc, err) fmt.Printf(" ⚠️ Unable to check %s: %v\n", svc, err)
continue continue
@ -1319,7 +586,7 @@ func handleProdStart() {
if active { if active {
fmt.Printf(" %s already running\n", svc) fmt.Printf(" %s already running\n", svc)
// Re-enable if disabled (in case it was stopped with 'dbn prod stop') // Re-enable if disabled (in case it was stopped with 'dbn prod stop')
enabled, err := isServiceEnabled(svc) enabled, err := utils.IsServiceEnabled(svc)
if err == nil && !enabled { if err == nil && !enabled {
if err := exec.Command("systemctl", "enable", svc).Run(); err != nil { if err := exec.Command("systemctl", "enable", svc).Run(); err != nil {
fmt.Printf(" ⚠️ Failed to re-enable %s: %v\n", svc, err) fmt.Printf(" ⚠️ Failed to re-enable %s: %v\n", svc, err)
@ -1338,12 +605,12 @@ func handleProdStart() {
} }
// Check port availability for services we're about to start // Check port availability for services we're about to start
ports, err := collectPortsForServices(inactive, false) ports, err := utils.CollectPortsForServices(inactive, false)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "❌ %v\n", err) fmt.Fprintf(os.Stderr, "❌ %v\n", err)
os.Exit(1) os.Exit(1)
} }
if err := ensurePortsAvailable("prod start", ports); err != nil { if err := utils.EnsurePortsAvailable("prod start", ports); err != nil {
fmt.Fprintf(os.Stderr, "❌ %v\n", err) fmt.Fprintf(os.Stderr, "❌ %v\n", err)
os.Exit(1) os.Exit(1)
} }
@ -1351,7 +618,7 @@ func handleProdStart() {
// Enable and start inactive services // Enable and start inactive services
for _, svc := range inactive { for _, svc := range inactive {
// Re-enable the service first (in case it was disabled by 'dbn prod stop') // Re-enable the service first (in case it was disabled by 'dbn prod stop')
enabled, err := isServiceEnabled(svc) enabled, err := utils.IsServiceEnabled(svc)
if err == nil && !enabled { if err == nil && !enabled {
if err := exec.Command("systemctl", "enable", svc).Run(); err != nil { if err := exec.Command("systemctl", "enable", svc).Run(); err != nil {
fmt.Printf(" ⚠️ Failed to enable %s: %v\n", svc, err) fmt.Printf(" ⚠️ Failed to enable %s: %v\n", svc, err)
@ -1385,7 +652,7 @@ func handleProdStop() {
fmt.Printf("Stopping all DeBros production services...\n") fmt.Printf("Stopping all DeBros production services...\n")
services := getProductionServices() services := utils.GetProductionServices()
if len(services) == 0 { if len(services) == 0 {
fmt.Printf(" ⚠️ No DeBros services found\n") fmt.Printf(" ⚠️ No DeBros services found\n")
return return
@ -1424,7 +691,7 @@ func handleProdStop() {
hadError := false hadError := false
for _, svc := range services { for _, svc := range services {
active, err := isServiceActive(svc) active, err := utils.IsServiceActive(svc)
if err != nil { if err != nil {
fmt.Printf(" ⚠️ Unable to check %s: %v\n", svc, err) fmt.Printf(" ⚠️ Unable to check %s: %v\n", svc, err)
hadError = true hadError = true
@ -1441,7 +708,7 @@ func handleProdStop() {
} else { } else {
// Wait and verify again // Wait and verify again
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
if stillActive, _ := isServiceActive(svc); stillActive { if stillActive, _ := utils.IsServiceActive(svc); stillActive {
fmt.Printf(" ❌ %s restarted itself (Restart=always)\n", svc) fmt.Printf(" ❌ %s restarted itself (Restart=always)\n", svc)
hadError = true hadError = true
} else { } else {
@ -1451,7 +718,7 @@ func handleProdStop() {
} }
// Disable the service to prevent it from auto-starting on boot // Disable the service to prevent it from auto-starting on boot
enabled, err := isServiceEnabled(svc) enabled, err := utils.IsServiceEnabled(svc)
if err != nil { if err != nil {
fmt.Printf(" ⚠️ Unable to check if %s is enabled: %v\n", svc, err) fmt.Printf(" ⚠️ Unable to check if %s is enabled: %v\n", svc, err)
// Continue anyway - try to disable // Continue anyway - try to disable
@ -1486,7 +753,7 @@ func handleProdRestart() {
fmt.Printf("Restarting all DeBros production services...\n") fmt.Printf("Restarting all DeBros production services...\n")
services := getProductionServices() services := utils.GetProductionServices()
if len(services) == 0 { if len(services) == 0 {
fmt.Printf(" ⚠️ No DeBros services found\n") fmt.Printf(" ⚠️ No DeBros services found\n")
return return
@ -1495,7 +762,7 @@ func handleProdRestart() {
// Stop all active services first // Stop all active services first
fmt.Printf(" Stopping services...\n") fmt.Printf(" Stopping services...\n")
for _, svc := range services { for _, svc := range services {
active, err := isServiceActive(svc) active, err := utils.IsServiceActive(svc)
if err != nil { if err != nil {
fmt.Printf(" ⚠️ Unable to check %s: %v\n", svc, err) fmt.Printf(" ⚠️ Unable to check %s: %v\n", svc, err)
continue continue
@ -1512,12 +779,12 @@ func handleProdRestart() {
} }
// Check port availability before restarting // Check port availability before restarting
ports, err := collectPortsForServices(services, false) ports, err := utils.CollectPortsForServices(services, false)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "❌ %v\n", err) fmt.Fprintf(os.Stderr, "❌ %v\n", err)
os.Exit(1) os.Exit(1)
} }
if err := ensurePortsAvailable("prod restart", ports); err != nil { if err := utils.EnsurePortsAvailable("prod restart", ports); err != nil {
fmt.Fprintf(os.Stderr, "❌ %v\n", err) fmt.Fprintf(os.Stderr, "❌ %v\n", err)
os.Exit(1) os.Exit(1)
} }

View File

@ -2,6 +2,8 @@ package cli
import ( import (
"testing" "testing"
"github.com/DeBrosOfficial/network/pkg/cli/utils"
) )
// TestProdCommandFlagParsing verifies that prod command flags are parsed correctly // TestProdCommandFlagParsing verifies that prod command flags are parsed correctly
@ -156,7 +158,7 @@ func TestNormalizePeers(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
peers, err := normalizePeers(tt.input) peers, err := utils.NormalizePeers(tt.input)
if tt.expectError && err == nil { if tt.expectError && err == nil {
t.Errorf("expected error but got none") t.Errorf("expected error but got none")

264
pkg/cli/prod_install.go Normal file
View File

@ -0,0 +1,264 @@
package cli
import (
"flag"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/DeBrosOfficial/network/pkg/cli/utils"
"github.com/DeBrosOfficial/network/pkg/environments/production"
)
func handleProdInstall(args []string) {
// Parse arguments using flag.FlagSet
fs := flag.NewFlagSet("install", flag.ContinueOnError)
fs.SetOutput(os.Stderr)
vpsIP := fs.String("vps-ip", "", "Public IP of this VPS (required)")
domain := fs.String("domain", "", "Domain name for HTTPS (optional, e.g. gateway.example.com)")
branch := fs.String("branch", "main", "Git branch to use (main or nightly)")
noPull := fs.Bool("no-pull", false, "Skip git clone/pull, use existing repository in /home/debros/src")
force := fs.Bool("force", false, "Force reconfiguration even if already installed")
dryRun := fs.Bool("dry-run", false, "Show what would be done without making changes")
skipResourceChecks := fs.Bool("skip-checks", false, "Skip minimum resource checks (RAM/CPU)")
// Cluster join flags
joinAddress := fs.String("join", "", "Join an existing cluster (e.g. 1.2.3.4:7001)")
clusterSecret := fs.String("cluster-secret", "", "Cluster secret for IPFS Cluster (required if joining)")
swarmKey := fs.String("swarm-key", "", "IPFS Swarm key (required if joining)")
peersStr := fs.String("peers", "", "Comma-separated list of bootstrap peer multiaddrs")
// IPFS/Cluster specific info for Peering configuration
ipfsPeerID := fs.String("ipfs-peer", "", "Peer ID of existing IPFS node to peer with")
ipfsAddrs := fs.String("ipfs-addrs", "", "Comma-separated multiaddrs of existing IPFS node")
ipfsClusterPeerID := fs.String("ipfs-cluster-peer", "", "Peer ID of existing IPFS Cluster node")
ipfsClusterAddrs := fs.String("ipfs-cluster-addrs", "", "Comma-separated multiaddrs of existing IPFS Cluster node")
if err := fs.Parse(args); err != nil {
if err == flag.ErrHelp {
return
}
fmt.Fprintf(os.Stderr, "❌ Failed to parse flags: %v\n", err)
os.Exit(1)
}
// Validate required flags
if *vpsIP == "" && !*dryRun {
fmt.Fprintf(os.Stderr, "❌ Error: --vps-ip is required for installation\n")
fmt.Fprintf(os.Stderr, " Example: dbn prod install --vps-ip 1.2.3.4\n")
os.Exit(1)
}
if os.Geteuid() != 0 && !*dryRun {
fmt.Fprintf(os.Stderr, "❌ Production installation must be run as root (use sudo)\n")
os.Exit(1)
}
oramaHome := "/home/debros"
oramaDir := oramaHome + "/.orama"
fmt.Printf("🚀 Starting production installation...\n\n")
isFirstNode := *joinAddress == ""
peers, err := utils.NormalizePeers(*peersStr)
if err != nil {
fmt.Fprintf(os.Stderr, "❌ Invalid peers: %v\n", err)
os.Exit(1)
}
// If cluster secret was provided, save it to secrets directory before setup
if *clusterSecret != "" {
secretsDir := filepath.Join(oramaDir, "secrets")
if err := os.MkdirAll(secretsDir, 0755); err != nil {
fmt.Fprintf(os.Stderr, "❌ Failed to create secrets directory: %v\n", err)
os.Exit(1)
}
secretPath := filepath.Join(secretsDir, "cluster-secret")
if err := os.WriteFile(secretPath, []byte(*clusterSecret), 0600); err != nil {
fmt.Fprintf(os.Stderr, "❌ Failed to save cluster secret: %v\n", err)
os.Exit(1)
}
fmt.Printf(" ✓ Cluster secret saved\n")
}
// If swarm key was provided, save it to secrets directory in full format
if *swarmKey != "" {
secretsDir := filepath.Join(oramaDir, "secrets")
if err := os.MkdirAll(secretsDir, 0755); err != nil {
fmt.Fprintf(os.Stderr, "❌ Failed to create secrets directory: %v\n", err)
os.Exit(1)
}
// Convert 64-hex key to full swarm.key format
swarmKeyContent := fmt.Sprintf("/key/swarm/psk/1.0.0/\n/base16/\n%s\n", strings.ToUpper(*swarmKey))
swarmKeyPath := filepath.Join(secretsDir, "swarm.key")
if err := os.WriteFile(swarmKeyPath, []byte(swarmKeyContent), 0600); err != nil {
fmt.Fprintf(os.Stderr, "❌ Failed to save swarm key: %v\n", err)
os.Exit(1)
}
fmt.Printf(" ✓ Swarm key saved\n")
}
// Store IPFS peer info for peering
var ipfsPeerInfo *utils.IPFSPeerInfo
if *ipfsPeerID != "" {
var addrs []string
if *ipfsAddrs != "" {
addrs = strings.Split(*ipfsAddrs, ",")
}
ipfsPeerInfo = &utils.IPFSPeerInfo{
PeerID: *ipfsPeerID,
Addrs: addrs,
}
}
// Store IPFS Cluster peer info for cluster peer discovery
var ipfsClusterPeerInfo *utils.IPFSClusterPeerInfo
if *ipfsClusterPeerID != "" {
var addrs []string
if *ipfsClusterAddrs != "" {
addrs = strings.Split(*ipfsClusterAddrs, ",")
}
ipfsClusterPeerInfo = &utils.IPFSClusterPeerInfo{
PeerID: *ipfsClusterPeerID,
Addrs: addrs,
}
}
setup := production.NewProductionSetup(oramaHome, os.Stdout, *force, *branch, *noPull, *skipResourceChecks)
// Inform user if skipping git pull
if *noPull {
fmt.Printf(" ⚠️ --no-pull flag enabled: Skipping git clone/pull\n")
fmt.Printf(" Using existing repository at /home/debros/src\n")
}
// Check port availability before proceeding
if err := utils.EnsurePortsAvailable("install", utils.DefaultPorts()); err != nil {
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
os.Exit(1)
}
// Validate DNS if domain is provided
if *domain != "" {
fmt.Printf("\n🌐 Pre-flight DNS validation...\n")
utils.ValidateDNSRecord(*domain, *vpsIP)
}
// Dry-run mode: show what would be done and exit
if *dryRun {
utils.ShowDryRunSummary(*vpsIP, *domain, *branch, peers, *joinAddress, isFirstNode, oramaDir)
return
}
// Save branch preference for future upgrades
if err := production.SaveBranchPreference(oramaDir, *branch); err != nil {
fmt.Fprintf(os.Stderr, "⚠️ Warning: Failed to save branch preference: %v\n", err)
}
// Phase 1: Check prerequisites
fmt.Printf("\n📋 Phase 1: Checking prerequisites...\n")
if err := setup.Phase1CheckPrerequisites(); err != nil {
fmt.Fprintf(os.Stderr, "❌ Prerequisites check failed: %v\n", err)
os.Exit(1)
}
// Phase 2: Provision environment
fmt.Printf("\n🛠 Phase 2: Provisioning environment...\n")
if err := setup.Phase2ProvisionEnvironment(); err != nil {
fmt.Fprintf(os.Stderr, "❌ Environment provisioning failed: %v\n", err)
os.Exit(1)
}
// Phase 2b: Install binaries
fmt.Printf("\nPhase 2b: Installing binaries...\n")
if err := setup.Phase2bInstallBinaries(); err != nil {
fmt.Fprintf(os.Stderr, "❌ Binary installation failed: %v\n", err)
os.Exit(1)
}
// Phase 3: Generate secrets FIRST (before service initialization)
// This ensures cluster secret and swarm key exist before repos are seeded
fmt.Printf("\n🔐 Phase 3: Generating secrets...\n")
if err := setup.Phase3GenerateSecrets(); err != nil {
fmt.Fprintf(os.Stderr, "❌ Secret generation failed: %v\n", err)
os.Exit(1)
}
// Phase 4: Generate configs (BEFORE service initialization)
// This ensures node.yaml exists before services try to access it
fmt.Printf("\n⚙ Phase 4: Generating configurations...\n")
enableHTTPS := *domain != ""
if err := setup.Phase4GenerateConfigs(peers, *vpsIP, enableHTTPS, *domain, *joinAddress); err != nil {
fmt.Fprintf(os.Stderr, "❌ Configuration generation failed: %v\n", err)
os.Exit(1)
}
// Validate generated configuration
fmt.Printf(" Validating generated configuration...\n")
if err := utils.ValidateGeneratedConfig(oramaDir); err != nil {
fmt.Fprintf(os.Stderr, "❌ Configuration validation failed: %v\n", err)
os.Exit(1)
}
fmt.Printf(" ✓ Configuration validated\n")
// Phase 2c: Initialize services (after config is in place)
fmt.Printf("\nPhase 2c: Initializing services...\n")
var prodIPFSPeer *production.IPFSPeerInfo
if ipfsPeerInfo != nil {
prodIPFSPeer = &production.IPFSPeerInfo{
PeerID: ipfsPeerInfo.PeerID,
Addrs: ipfsPeerInfo.Addrs,
}
}
var prodIPFSClusterPeer *production.IPFSClusterPeerInfo
if ipfsClusterPeerInfo != nil {
prodIPFSClusterPeer = &production.IPFSClusterPeerInfo{
PeerID: ipfsClusterPeerInfo.PeerID,
Addrs: ipfsClusterPeerInfo.Addrs,
}
}
if err := setup.Phase2cInitializeServices(peers, *vpsIP, prodIPFSPeer, prodIPFSClusterPeer); err != nil {
fmt.Fprintf(os.Stderr, "❌ Service initialization failed: %v\n", err)
os.Exit(1)
}
// Phase 5: Create systemd services
fmt.Printf("\n🔧 Phase 5: Creating systemd services...\n")
if err := setup.Phase5CreateSystemdServices(enableHTTPS); err != nil {
fmt.Fprintf(os.Stderr, "❌ Service creation failed: %v\n", err)
os.Exit(1)
}
// Log completion with actual peer ID
setup.LogSetupComplete(setup.NodePeerID)
fmt.Printf("✅ Production installation complete!\n\n")
// For first node, print important secrets and identifiers
if isFirstNode {
fmt.Printf("📋 Save these for joining future nodes:\n\n")
// Print cluster secret
clusterSecretPath := filepath.Join(oramaDir, "secrets", "cluster-secret")
if clusterSecretData, err := os.ReadFile(clusterSecretPath); err == nil {
fmt.Printf(" Cluster Secret (--cluster-secret):\n")
fmt.Printf(" %s\n\n", string(clusterSecretData))
}
// Print swarm key
swarmKeyPath := filepath.Join(oramaDir, "secrets", "swarm.key")
if swarmKeyData, err := os.ReadFile(swarmKeyPath); err == nil {
swarmKeyContent := strings.TrimSpace(string(swarmKeyData))
lines := strings.Split(swarmKeyContent, "\n")
if len(lines) >= 3 {
// Extract just the hex part (last line)
fmt.Printf(" IPFS Swarm Key (--swarm-key, last line only):\n")
fmt.Printf(" %s\n\n", lines[len(lines)-1])
}
}
// Print peer ID
fmt.Printf(" Node Peer ID:\n")
fmt.Printf(" %s\n\n", setup.NodePeerID)
}
}

97
pkg/cli/utils/install.go Normal file
View File

@ -0,0 +1,97 @@
package utils
import (
"fmt"
"strings"
)
// IPFSPeerInfo holds IPFS peer information for configuring Peering.Peers
type IPFSPeerInfo struct {
PeerID string
Addrs []string
}
// IPFSClusterPeerInfo contains IPFS Cluster peer information for cluster discovery
type IPFSClusterPeerInfo struct {
PeerID string
Addrs []string
}
// ShowDryRunSummary displays what would be done during installation without making changes
func ShowDryRunSummary(vpsIP, domain, branch string, peers []string, joinAddress string, isFirstNode bool, oramaDir string) {
fmt.Print("\n" + strings.Repeat("=", 70) + "\n")
fmt.Printf("DRY RUN - No changes will be made\n")
fmt.Print(strings.Repeat("=", 70) + "\n\n")
fmt.Printf("📋 Installation Summary:\n")
fmt.Printf(" VPS IP: %s\n", vpsIP)
fmt.Printf(" Domain: %s\n", domain)
fmt.Printf(" Branch: %s\n", branch)
if isFirstNode {
fmt.Printf(" Node Type: First node (creates new cluster)\n")
} else {
fmt.Printf(" Node Type: Joining existing cluster\n")
if joinAddress != "" {
fmt.Printf(" Join Address: %s\n", joinAddress)
}
if len(peers) > 0 {
fmt.Printf(" Peers: %d peer(s)\n", len(peers))
for _, peer := range peers {
fmt.Printf(" - %s\n", peer)
}
}
}
fmt.Printf("\n📁 Directories that would be created:\n")
fmt.Printf(" %s/configs/\n", oramaDir)
fmt.Printf(" %s/secrets/\n", oramaDir)
fmt.Printf(" %s/data/ipfs/repo/\n", oramaDir)
fmt.Printf(" %s/data/ipfs-cluster/\n", oramaDir)
fmt.Printf(" %s/data/rqlite/\n", oramaDir)
fmt.Printf(" %s/logs/\n", oramaDir)
fmt.Printf(" %s/tls-cache/\n", oramaDir)
fmt.Printf("\n🔧 Binaries that would be installed:\n")
fmt.Printf(" - Go (if not present)\n")
fmt.Printf(" - RQLite 8.43.0\n")
fmt.Printf(" - IPFS/Kubo 0.38.2\n")
fmt.Printf(" - IPFS Cluster (latest)\n")
fmt.Printf(" - Olric 0.7.0\n")
fmt.Printf(" - anyone-client (npm)\n")
fmt.Printf(" - DeBros binaries (built from %s branch)\n", branch)
fmt.Printf("\n🔐 Secrets that would be generated:\n")
fmt.Printf(" - Cluster secret (64-hex)\n")
fmt.Printf(" - IPFS swarm key\n")
fmt.Printf(" - Node identity (Ed25519 keypair)\n")
fmt.Printf("\n📝 Configuration files that would be created:\n")
fmt.Printf(" - %s/configs/node.yaml\n", oramaDir)
fmt.Printf(" - %s/configs/olric/config.yaml\n", oramaDir)
fmt.Printf("\n⚙ Systemd services that would be created:\n")
fmt.Printf(" - debros-ipfs.service\n")
fmt.Printf(" - debros-ipfs-cluster.service\n")
fmt.Printf(" - debros-olric.service\n")
fmt.Printf(" - debros-node.service (includes embedded gateway + RQLite)\n")
fmt.Printf(" - debros-anyone-client.service\n")
fmt.Printf("\n🌐 Ports that would be used:\n")
fmt.Printf(" External (must be open in firewall):\n")
fmt.Printf(" - 80 (HTTP for ACME/Let's Encrypt)\n")
fmt.Printf(" - 443 (HTTPS gateway)\n")
fmt.Printf(" - 4101 (IPFS swarm)\n")
fmt.Printf(" - 7001 (RQLite Raft)\n")
fmt.Printf(" Internal (localhost only):\n")
fmt.Printf(" - 4501 (IPFS API)\n")
fmt.Printf(" - 5001 (RQLite HTTP)\n")
fmt.Printf(" - 6001 (Unified gateway)\n")
fmt.Printf(" - 8080 (IPFS gateway)\n")
fmt.Printf(" - 9050 (Anyone SOCKS5)\n")
fmt.Printf(" - 9094 (IPFS Cluster API)\n")
fmt.Printf(" - 3320/3322 (Olric)\n")
fmt.Print("\n" + strings.Repeat("=", 70) + "\n")
fmt.Printf("To proceed with installation, run without --dry-run\n")
fmt.Print(strings.Repeat("=", 70) + "\n\n")
}

217
pkg/cli/utils/systemd.go Normal file
View File

@ -0,0 +1,217 @@
package utils
import (
"errors"
"fmt"
"net"
"os"
"os/exec"
"path/filepath"
"strings"
"syscall"
)
var ErrServiceNotFound = errors.New("service not found")
// PortSpec defines a port and its name for checking availability
type PortSpec struct {
Name string
Port int
}
var ServicePorts = map[string][]PortSpec{
"debros-gateway": {
{Name: "Gateway API", Port: 6001},
},
"debros-olric": {
{Name: "Olric HTTP", Port: 3320},
{Name: "Olric Memberlist", Port: 3322},
},
"debros-node": {
{Name: "RQLite HTTP", Port: 5001},
{Name: "RQLite Raft", Port: 7001},
},
"debros-ipfs": {
{Name: "IPFS API", Port: 4501},
{Name: "IPFS Gateway", Port: 8080},
{Name: "IPFS Swarm", Port: 4101},
},
"debros-ipfs-cluster": {
{Name: "IPFS Cluster API", Port: 9094},
},
}
// DefaultPorts is used for fresh installs/upgrades before unit files exist.
func DefaultPorts() []PortSpec {
return []PortSpec{
{Name: "IPFS Swarm", Port: 4001},
{Name: "IPFS API", Port: 4501},
{Name: "IPFS Gateway", Port: 8080},
{Name: "Gateway API", Port: 6001},
{Name: "RQLite HTTP", Port: 5001},
{Name: "RQLite Raft", Port: 7001},
{Name: "IPFS Cluster API", Port: 9094},
{Name: "Olric HTTP", Port: 3320},
{Name: "Olric Memberlist", Port: 3322},
}
}
// ResolveServiceName resolves service aliases to actual systemd service names
func ResolveServiceName(alias string) ([]string, error) {
// Service alias mapping (unified - no bootstrap/node distinction)
aliases := map[string][]string{
"node": {"debros-node"},
"ipfs": {"debros-ipfs"},
"cluster": {"debros-ipfs-cluster"},
"ipfs-cluster": {"debros-ipfs-cluster"},
"gateway": {"debros-gateway"},
"olric": {"debros-olric"},
"rqlite": {"debros-node"}, // RQLite logs are in node logs
}
// Check if it's an alias
if serviceNames, ok := aliases[strings.ToLower(alias)]; ok {
// Filter to only existing services
var existing []string
for _, svc := range serviceNames {
unitPath := filepath.Join("/etc/systemd/system", svc+".service")
if _, err := os.Stat(unitPath); err == nil {
existing = append(existing, svc)
}
}
if len(existing) == 0 {
return nil, fmt.Errorf("no services found for alias %q", alias)
}
return existing, nil
}
// Check if it's already a full service name
unitPath := filepath.Join("/etc/systemd/system", alias+".service")
if _, err := os.Stat(unitPath); err == nil {
return []string{alias}, nil
}
// Try without .service suffix
if !strings.HasSuffix(alias, ".service") {
unitPath = filepath.Join("/etc/systemd/system", alias+".service")
if _, err := os.Stat(unitPath); err == nil {
return []string{alias}, nil
}
}
return nil, fmt.Errorf("service %q not found. Use: node, ipfs, cluster, gateway, olric, or full service name", alias)
}
// IsServiceActive checks if a systemd service is currently active (running)
func IsServiceActive(service string) (bool, error) {
cmd := exec.Command("systemctl", "is-active", "--quiet", service)
if err := cmd.Run(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
switch exitErr.ExitCode() {
case 3:
return false, nil
case 4:
return false, ErrServiceNotFound
}
}
return false, err
}
return true, nil
}
// IsServiceEnabled checks if a systemd service is enabled to start on boot
func IsServiceEnabled(service string) (bool, error) {
cmd := exec.Command("systemctl", "is-enabled", "--quiet", service)
if err := cmd.Run(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
switch exitErr.ExitCode() {
case 1:
return false, nil // Service is disabled
case 4:
return false, ErrServiceNotFound
}
}
return false, err
}
return true, nil
}
// IsServiceMasked checks if a systemd service is masked
func IsServiceMasked(service string) (bool, error) {
cmd := exec.Command("systemctl", "is-enabled", service)
output, err := cmd.CombinedOutput()
if err != nil {
outputStr := string(output)
if strings.Contains(outputStr, "masked") {
return true, nil
}
return false, err
}
return false, nil
}
// GetProductionServices returns a list of all DeBros production service names that exist
func GetProductionServices() []string {
// Unified service names (no bootstrap/node distinction)
allServices := []string{
"debros-gateway",
"debros-node",
"debros-olric",
"debros-ipfs-cluster",
"debros-ipfs",
"debros-anyone-client",
}
// Filter to only existing services by checking if unit file exists
var existing []string
for _, svc := range allServices {
unitPath := filepath.Join("/etc/systemd/system", svc+".service")
if _, err := os.Stat(unitPath); err == nil {
existing = append(existing, svc)
}
}
return existing
}
// CollectPortsForServices returns a list of ports used by the specified services
func CollectPortsForServices(services []string, skipActive bool) ([]PortSpec, error) {
seen := make(map[int]PortSpec)
for _, svc := range services {
if skipActive {
active, err := IsServiceActive(svc)
if err != nil {
return nil, fmt.Errorf("unable to check %s: %w", svc, err)
}
if active {
continue
}
}
for _, spec := range ServicePorts[svc] {
if _, ok := seen[spec.Port]; !ok {
seen[spec.Port] = spec
}
}
}
ports := make([]PortSpec, 0, len(seen))
for _, spec := range seen {
ports = append(ports, spec)
}
return ports, nil
}
// EnsurePortsAvailable checks if the specified ports are available
func EnsurePortsAvailable(action string, ports []PortSpec) error {
for _, spec := range ports {
ln, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", spec.Port))
if err != nil {
if errors.Is(err, syscall.EADDRINUSE) || strings.Contains(err.Error(), "address already in use") {
return fmt.Errorf("%s cannot continue: %s (port %d) is already in use", action, spec.Name, spec.Port)
}
return fmt.Errorf("%s cannot continue: failed to inspect %s (port %d): %w", action, spec.Name, spec.Port, err)
}
_ = ln.Close()
}
return nil
}

113
pkg/cli/utils/validation.go Normal file
View File

@ -0,0 +1,113 @@
package utils
import (
"fmt"
"net"
"os"
"path/filepath"
"strings"
"github.com/DeBrosOfficial/network/pkg/config"
"github.com/multiformats/go-multiaddr"
)
// ValidateGeneratedConfig loads and validates the generated node configuration
func ValidateGeneratedConfig(oramaDir string) error {
configPath := filepath.Join(oramaDir, "configs", "node.yaml")
// Check if config file exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("configuration file not found at %s", configPath)
}
// Load the config file
file, err := os.Open(configPath)
if err != nil {
return fmt.Errorf("failed to open config file: %w", err)
}
defer file.Close()
var cfg config.Config
if err := config.DecodeStrict(file, &cfg); err != nil {
return fmt.Errorf("failed to parse config: %w", err)
}
// Validate the configuration
if errs := cfg.Validate(); len(errs) > 0 {
var errMsgs []string
for _, e := range errs {
errMsgs = append(errMsgs, e.Error())
}
return fmt.Errorf("configuration validation errors:\n - %s", strings.Join(errMsgs, "\n - "))
}
return nil
}
// ValidateDNSRecord validates that the domain points to the expected IP address
// Returns nil if DNS is valid, warning message if DNS doesn't match but continues,
// or error if DNS lookup fails completely
func ValidateDNSRecord(domain, expectedIP string) error {
if domain == "" {
return nil // No domain provided, skip validation
}
ips, err := net.LookupIP(domain)
if err != nil {
// DNS lookup failed - this is a warning, not a fatal error
// The user might be setting up DNS after installation
fmt.Printf(" ⚠️ DNS lookup failed for %s: %v\n", domain, err)
fmt.Printf(" Make sure DNS is configured before enabling HTTPS\n")
return nil
}
// Check if any resolved IP matches the expected IP
for _, ip := range ips {
if ip.String() == expectedIP {
fmt.Printf(" ✓ DNS validated: %s → %s\n", domain, expectedIP)
return nil
}
}
// DNS doesn't point to expected IP - warn but continue
resolvedIPs := make([]string, len(ips))
for i, ip := range ips {
resolvedIPs[i] = ip.String()
}
fmt.Printf(" ⚠️ DNS mismatch: %s resolves to %v, expected %s\n", domain, resolvedIPs, expectedIP)
fmt.Printf(" HTTPS certificate generation may fail until DNS is updated\n")
return nil
}
// NormalizePeers normalizes and validates peer multiaddrs
func NormalizePeers(peersStr string) ([]string, error) {
if peersStr == "" {
return nil, nil
}
// Split by comma and trim whitespace
rawPeers := strings.Split(peersStr, ",")
peers := make([]string, 0, len(rawPeers))
seen := make(map[string]bool)
for _, peer := range rawPeers {
peer = strings.TrimSpace(peer)
if peer == "" {
continue
}
// Validate multiaddr format
if _, err := multiaddr.NewMultiaddr(peer); err != nil {
return nil, fmt.Errorf("invalid multiaddr %q: %w", peer, err)
}
// Deduplicate
if !seen[peer] {
peers = append(peers, peer)
seen[peer] = true
}
}
return peers, nil
}

View File

@ -329,6 +329,18 @@ func (c *Client) getAppNamespace() string {
return c.config.AppName return c.config.AppName
} }
// PubSubAdapter returns the underlying pubsub.ClientAdapter for direct use by serverless functions.
// This bypasses the authentication checks used by PubSub() since serverless functions
// are already authenticated via the gateway.
func (c *Client) PubSubAdapter() *pubsub.ClientAdapter {
c.mu.RLock()
defer c.mu.RUnlock()
if c.pubsub == nil {
return nil
}
return c.pubsub.adapter
}
// requireAccess enforces that credentials are present and that any context-based namespace overrides match // requireAccess enforces that credentials are present and that any context-based namespace overrides match
func (c *Client) requireAccess(ctx context.Context) error { func (c *Client) requireAccess(ctx context.Context) error {
// Allow internal system operations to bypass authentication // Allow internal system operations to bypass authentication

View File

@ -1,6 +1,7 @@
package config package config
import ( import (
"encoding/hex"
"fmt" "fmt"
"net" "net"
"os" "os"
@ -585,3 +586,15 @@ func extractTCPPort(multiaddrStr string) string {
} }
return "" return ""
} }
// ValidateSwarmKey validates that a swarm key is 64 hex characters
func ValidateSwarmKey(key string) error {
key = strings.TrimSpace(key)
if len(key) != 64 {
return fmt.Errorf("swarm key must be 64 hex characters (32 bytes), got %d", len(key))
}
if _, err := hex.DecodeString(key); err != nil {
return fmt.Errorf("swarm key must be valid hexadecimal: %w", err)
}
return nil
}

View File

@ -78,7 +78,7 @@ func (dc *DependencyChecker) CheckAll() ([]string, error) {
errMsg := fmt.Sprintf("Missing %d required dependencies:\n%s\n\nInstall them with:\n%s", errMsg := fmt.Sprintf("Missing %d required dependencies:\n%s\n\nInstall them with:\n%s",
len(missing), strings.Join(missing, ", "), strings.Join(hints, "\n")) len(missing), strings.Join(missing, ", "), strings.Join(hints, "\n"))
return missing, fmt.Errorf(errMsg) return missing, fmt.Errorf("%s", errMsg)
} }
// PortChecker validates that required ports are available // PortChecker validates that required ports are available
@ -113,7 +113,7 @@ func (pc *PortChecker) CheckAll() ([]int, error) {
errMsg := fmt.Sprintf("The following ports are unavailable: %v\n\nFree them or stop conflicting services and try again", errMsg := fmt.Sprintf("The following ports are unavailable: %v\n\nFree them or stop conflicting services and try again",
unavailable) unavailable)
return unavailable, fmt.Errorf(errMsg) return unavailable, fmt.Errorf("%s", errMsg)
} }
// isPortAvailable checks if a TCP port is available for binding // isPortAvailable checks if a TCP port is available for binding

View File

@ -0,0 +1,287 @@
package development
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/tlsutil"
)
// ipfsNodeInfo holds information about an IPFS node for peer discovery
type ipfsNodeInfo struct {
name string
ipfsPath string
apiPort int
swarmPort int
gatewayPort int
peerID string
}
func (pm *ProcessManager) buildIPFSNodes(topology *Topology) []ipfsNodeInfo {
var nodes []ipfsNodeInfo
for _, nodeSpec := range topology.Nodes {
nodes = append(nodes, ipfsNodeInfo{
name: nodeSpec.Name,
ipfsPath: filepath.Join(pm.oramaDir, nodeSpec.DataDir, "ipfs/repo"),
apiPort: nodeSpec.IPFSAPIPort,
swarmPort: nodeSpec.IPFSSwarmPort,
gatewayPort: nodeSpec.IPFSGatewayPort,
peerID: "",
})
}
return nodes
}
func (pm *ProcessManager) startIPFS(ctx context.Context) error {
topology := DefaultTopology()
nodes := pm.buildIPFSNodes(topology)
for i := range nodes {
os.MkdirAll(nodes[i].ipfsPath, 0755)
if _, err := os.Stat(filepath.Join(nodes[i].ipfsPath, "config")); os.IsNotExist(err) {
fmt.Fprintf(pm.logWriter, " Initializing IPFS (%s)...\n", nodes[i].name)
cmd := exec.CommandContext(ctx, "ipfs", "init", "--profile=server", "--repo-dir="+nodes[i].ipfsPath)
if _, err := cmd.CombinedOutput(); err != nil {
fmt.Fprintf(pm.logWriter, " Warning: ipfs init failed: %v\n", err)
}
swarmKeyPath := filepath.Join(pm.oramaDir, "swarm.key")
if data, err := os.ReadFile(swarmKeyPath); err == nil {
os.WriteFile(filepath.Join(nodes[i].ipfsPath, "swarm.key"), data, 0600)
}
}
peerID, err := configureIPFSRepo(nodes[i].ipfsPath, nodes[i].apiPort, nodes[i].gatewayPort, nodes[i].swarmPort)
if err != nil {
fmt.Fprintf(pm.logWriter, " Warning: failed to configure IPFS repo for %s: %v\n", nodes[i].name, err)
} else {
nodes[i].peerID = peerID
fmt.Fprintf(pm.logWriter, " Peer ID for %s: %s\n", nodes[i].name, peerID)
}
}
for i := range nodes {
pidPath := filepath.Join(pm.pidsDir, fmt.Sprintf("ipfs-%s.pid", nodes[i].name))
logPath := filepath.Join(pm.oramaDir, "logs", fmt.Sprintf("ipfs-%s.log", nodes[i].name))
cmd := exec.CommandContext(ctx, "ipfs", "daemon", "--enable-pubsub-experiment", "--repo-dir="+nodes[i].ipfsPath)
logFile, _ := os.Create(logPath)
cmd.Stdout = logFile
cmd.Stderr = logFile
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start ipfs-%s: %w", nodes[i].name, err)
}
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
pm.processes[fmt.Sprintf("ipfs-%s", nodes[i].name)] = &ManagedProcess{
Name: fmt.Sprintf("ipfs-%s", nodes[i].name),
PID: cmd.Process.Pid,
StartTime: time.Now(),
LogPath: logPath,
}
fmt.Fprintf(pm.logWriter, "✓ IPFS (%s) started (PID: %d, API: %d, Swarm: %d)\n", nodes[i].name, cmd.Process.Pid, nodes[i].apiPort, nodes[i].swarmPort)
}
time.Sleep(2 * time.Second)
if err := pm.seedIPFSPeersWithHTTP(ctx, nodes); err != nil {
fmt.Fprintf(pm.logWriter, "⚠️ Failed to seed IPFS peers: %v\n", err)
}
return nil
}
func configureIPFSRepo(repoPath string, apiPort, gatewayPort, swarmPort int) (string, error) {
configPath := filepath.Join(repoPath, "config")
data, err := os.ReadFile(configPath)
if err != nil {
return "", fmt.Errorf("failed to read IPFS config: %w", err)
}
var config map[string]interface{}
if err := json.Unmarshal(data, &config); err != nil {
return "", fmt.Errorf("failed to parse IPFS config: %w", err)
}
config["Addresses"] = map[string]interface{}{
"API": []string{fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", apiPort)},
"Gateway": []string{fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", gatewayPort)},
"Swarm": []string{
fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", swarmPort),
fmt.Sprintf("/ip6/::/tcp/%d", swarmPort),
},
}
config["AutoConf"] = map[string]interface{}{
"Enabled": false,
}
config["Bootstrap"] = []string{}
if dns, ok := config["DNS"].(map[string]interface{}); ok {
dns["Resolvers"] = map[string]interface{}{}
} else {
config["DNS"] = map[string]interface{}{
"Resolvers": map[string]interface{}{},
}
}
if routing, ok := config["Routing"].(map[string]interface{}); ok {
routing["DelegatedRouters"] = []string{}
} else {
config["Routing"] = map[string]interface{}{
"DelegatedRouters": []string{},
}
}
if ipns, ok := config["Ipns"].(map[string]interface{}); ok {
ipns["DelegatedPublishers"] = []string{}
} else {
config["Ipns"] = map[string]interface{}{
"DelegatedPublishers": []string{},
}
}
if api, ok := config["API"].(map[string]interface{}); ok {
api["HTTPHeaders"] = map[string][]string{
"Access-Control-Allow-Origin": {"*"},
"Access-Control-Allow-Methods": {"GET", "PUT", "POST", "DELETE", "OPTIONS"},
"Access-Control-Allow-Headers": {"Content-Type", "X-Requested-With"},
"Access-Control-Expose-Headers": {"Content-Length", "Content-Range"},
}
} else {
config["API"] = map[string]interface{}{
"HTTPHeaders": map[string][]string{
"Access-Control-Allow-Origin": {"*"},
"Access-Control-Allow-Methods": {"GET", "PUT", "POST", "DELETE", "OPTIONS"},
"Access-Control-Allow-Headers": {"Content-Type", "X-Requested-With"},
"Access-Control-Expose-Headers": {"Content-Length", "Content-Range"},
},
}
}
updatedData, err := json.MarshalIndent(config, "", " ")
if err != nil {
return "", fmt.Errorf("failed to marshal IPFS config: %w", err)
}
if err := os.WriteFile(configPath, updatedData, 0644); err != nil {
return "", fmt.Errorf("failed to write IPFS config: %w", err)
}
if id, ok := config["Identity"].(map[string]interface{}); ok {
if peerID, ok := id["PeerID"].(string); ok {
return peerID, nil
}
}
return "", fmt.Errorf("could not extract peer ID from config")
}
func (pm *ProcessManager) seedIPFSPeersWithHTTP(ctx context.Context, nodes []ipfsNodeInfo) error {
fmt.Fprintf(pm.logWriter, " Seeding IPFS local bootstrap peers via HTTP API...\n")
for _, node := range nodes {
if err := pm.waitIPFSReady(ctx, node); err != nil {
fmt.Fprintf(pm.logWriter, " Warning: failed to wait for IPFS readiness for %s: %v\n", node.name, err)
}
}
for i, node := range nodes {
httpURL := fmt.Sprintf("http://127.0.0.1:%d/api/v0/bootstrap/rm?all=true", node.apiPort)
if err := pm.ipfsHTTPCall(ctx, httpURL, "POST"); err != nil {
fmt.Fprintf(pm.logWriter, " Warning: failed to clear bootstrap for %s: %v\n", node.name, err)
}
for j, otherNode := range nodes {
if i == j {
continue
}
multiaddr := fmt.Sprintf("/ip4/127.0.0.1/tcp/%d/p2p/%s", otherNode.swarmPort, otherNode.peerID)
httpURL := fmt.Sprintf("http://127.0.0.1:%d/api/v0/bootstrap/add?arg=%s", node.apiPort, url.QueryEscape(multiaddr))
if err := pm.ipfsHTTPCall(ctx, httpURL, "POST"); err != nil {
fmt.Fprintf(pm.logWriter, " Warning: failed to add bootstrap peer for %s: %v\n", node.name, err)
}
}
}
return nil
}
func (pm *ProcessManager) waitIPFSReady(ctx context.Context, node ipfsNodeInfo) error {
maxRetries := 30
retryInterval := 500 * time.Millisecond
for attempt := 0; attempt < maxRetries; attempt++ {
httpURL := fmt.Sprintf("http://127.0.0.1:%d/api/v0/version", node.apiPort)
if err := pm.ipfsHTTPCall(ctx, httpURL, "POST"); err == nil {
return nil
}
select {
case <-time.After(retryInterval):
continue
case <-ctx.Done():
return ctx.Err()
}
}
return fmt.Errorf("IPFS daemon %s did not become ready", node.name)
}
func (pm *ProcessManager) ipfsHTTPCall(ctx context.Context, urlStr string, method string) error {
client := tlsutil.NewHTTPClient(5 * time.Second)
req, err := http.NewRequestWithContext(ctx, method, urlStr, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("HTTP call failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return fmt.Errorf("HTTP %d", resp.StatusCode)
}
return nil
}
func readIPFSConfigValue(ctx context.Context, repoPath string, key string) (string, error) {
configPath := filepath.Join(repoPath, "config")
data, err := os.ReadFile(configPath)
if err != nil {
return "", fmt.Errorf("failed to read IPFS config: %w", err)
}
lines := strings.Split(string(data), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.Contains(line, key) {
parts := strings.SplitN(line, ":", 2)
if len(parts) == 2 {
value := strings.TrimSpace(parts[1])
value = strings.Trim(value, `",`)
if value != "" {
return value, nil
}
}
}
}
return "", fmt.Errorf("key %s not found in IPFS config", key)
}

View File

@ -0,0 +1,314 @@
package development
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
)
func (pm *ProcessManager) startIPFSCluster(ctx context.Context) error {
topology := DefaultTopology()
var nodes []struct {
name string
clusterPath string
restAPIPort int
clusterPort int
ipfsPort int
}
for _, nodeSpec := range topology.Nodes {
nodes = append(nodes, struct {
name string
clusterPath string
restAPIPort int
clusterPort int
ipfsPort int
}{
nodeSpec.Name,
filepath.Join(pm.oramaDir, nodeSpec.DataDir, "ipfs-cluster"),
nodeSpec.ClusterAPIPort,
nodeSpec.ClusterPort,
nodeSpec.IPFSAPIPort,
})
}
fmt.Fprintf(pm.logWriter, " Waiting for IPFS daemons to be ready...\n")
ipfsNodes := pm.buildIPFSNodes(topology)
for _, ipfsNode := range ipfsNodes {
if err := pm.waitIPFSReady(ctx, ipfsNode); err != nil {
fmt.Fprintf(pm.logWriter, " Warning: IPFS %s did not become ready: %v\n", ipfsNode.name, err)
}
}
secretPath := filepath.Join(pm.oramaDir, "cluster-secret")
clusterSecret, err := os.ReadFile(secretPath)
if err != nil {
return fmt.Errorf("failed to read cluster secret: %w", err)
}
clusterSecretHex := strings.TrimSpace(string(clusterSecret))
bootstrapMultiaddr := ""
{
node := nodes[0]
if err := pm.cleanClusterState(node.clusterPath); err != nil {
fmt.Fprintf(pm.logWriter, " Warning: failed to clean cluster state for %s: %v\n", node.name, err)
}
os.MkdirAll(node.clusterPath, 0755)
fmt.Fprintf(pm.logWriter, " Initializing IPFS Cluster (%s)...\n", node.name)
cmd := exec.CommandContext(ctx, "ipfs-cluster-service", "init", "--force")
cmd.Env = append(os.Environ(),
fmt.Sprintf("IPFS_CLUSTER_PATH=%s", node.clusterPath),
fmt.Sprintf("CLUSTER_SECRET=%s", clusterSecretHex),
)
if output, err := cmd.CombinedOutput(); err != nil {
fmt.Fprintf(pm.logWriter, " Warning: ipfs-cluster-service init failed: %v (output: %s)\n", err, string(output))
}
if err := pm.ensureIPFSClusterPorts(node.clusterPath, node.restAPIPort, node.clusterPort); err != nil {
fmt.Fprintf(pm.logWriter, " Warning: failed to update IPFS Cluster config for %s: %v\n", node.name, err)
}
pidPath := filepath.Join(pm.pidsDir, fmt.Sprintf("ipfs-cluster-%s.pid", node.name))
logPath := filepath.Join(pm.oramaDir, "logs", fmt.Sprintf("ipfs-cluster-%s.log", node.name))
cmd = exec.CommandContext(ctx, "ipfs-cluster-service", "daemon")
cmd.Env = append(os.Environ(), fmt.Sprintf("IPFS_CLUSTER_PATH=%s", node.clusterPath))
logFile, _ := os.Create(logPath)
cmd.Stdout = logFile
cmd.Stderr = logFile
if err := cmd.Start(); err != nil {
return err
}
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
fmt.Fprintf(pm.logWriter, "✓ IPFS Cluster (%s) started (PID: %d, API: %d)\n", node.name, cmd.Process.Pid, node.restAPIPort)
if err := pm.waitClusterReady(ctx, node.name, node.restAPIPort); err != nil {
fmt.Fprintf(pm.logWriter, " Warning: IPFS Cluster %s did not become ready: %v\n", node.name, err)
}
time.Sleep(500 * time.Millisecond)
peerID, err := pm.waitForClusterPeerID(ctx, filepath.Join(node.clusterPath, "identity.json"))
if err != nil {
fmt.Fprintf(pm.logWriter, " Warning: failed to read bootstrap peer ID: %v\n", err)
} else {
bootstrapMultiaddr = fmt.Sprintf("/ip4/127.0.0.1/tcp/%d/p2p/%s", node.clusterPort, peerID)
}
}
for i := 1; i < len(nodes); i++ {
node := nodes[i]
if err := pm.cleanClusterState(node.clusterPath); err != nil {
fmt.Fprintf(pm.logWriter, " Warning: failed to clean cluster state for %s: %v\n", node.name, err)
}
os.MkdirAll(node.clusterPath, 0755)
fmt.Fprintf(pm.logWriter, " Initializing IPFS Cluster (%s)...\n", node.name)
cmd := exec.CommandContext(ctx, "ipfs-cluster-service", "init", "--force")
cmd.Env = append(os.Environ(),
fmt.Sprintf("IPFS_CLUSTER_PATH=%s", node.clusterPath),
fmt.Sprintf("CLUSTER_SECRET=%s", clusterSecretHex),
)
if output, err := cmd.CombinedOutput(); err != nil {
fmt.Fprintf(pm.logWriter, " Warning: ipfs-cluster-service init failed for %s: %v (output: %s)\n", node.name, err, string(output))
}
if err := pm.ensureIPFSClusterPorts(node.clusterPath, node.restAPIPort, node.clusterPort); err != nil {
fmt.Fprintf(pm.logWriter, " Warning: failed to update IPFS Cluster config for %s: %v\n", node.name, err)
}
pidPath := filepath.Join(pm.pidsDir, fmt.Sprintf("ipfs-cluster-%s.pid", node.name))
logPath := filepath.Join(pm.oramaDir, "logs", fmt.Sprintf("ipfs-cluster-%s.log", node.name))
args := []string{"daemon"}
if bootstrapMultiaddr != "" {
args = append(args, "--bootstrap", bootstrapMultiaddr)
}
cmd = exec.CommandContext(ctx, "ipfs-cluster-service", args...)
cmd.Env = append(os.Environ(), fmt.Sprintf("IPFS_CLUSTER_PATH=%s", node.clusterPath))
logFile, _ := os.Create(logPath)
cmd.Stdout = logFile
cmd.Stderr = logFile
if err := cmd.Start(); err != nil {
continue
}
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
fmt.Fprintf(pm.logWriter, "✓ IPFS Cluster (%s) started (PID: %d, API: %d)\n", node.name, cmd.Process.Pid, node.restAPIPort)
if err := pm.waitClusterReady(ctx, node.name, node.restAPIPort); err != nil {
fmt.Fprintf(pm.logWriter, " Warning: IPFS Cluster %s did not become ready: %v\n", node.name, err)
}
}
fmt.Fprintf(pm.logWriter, " Waiting for IPFS Cluster peers to form...\n")
if err := pm.waitClusterFormed(ctx, nodes[0].restAPIPort); err != nil {
fmt.Fprintf(pm.logWriter, " Warning: IPFS Cluster did not form fully: %v\n", err)
}
time.Sleep(1 * time.Second)
return nil
}
func (pm *ProcessManager) waitForClusterPeerID(ctx context.Context, identityPath string) (string, error) {
maxRetries := 30
retryInterval := 500 * time.Millisecond
for attempt := 0; attempt < maxRetries; attempt++ {
data, err := os.ReadFile(identityPath)
if err == nil {
var identity map[string]interface{}
if err := json.Unmarshal(data, &identity); err == nil {
if id, ok := identity["id"].(string); ok {
return id, nil
}
}
}
select {
case <-time.After(retryInterval):
continue
case <-ctx.Done():
return "", ctx.Err()
}
}
return "", fmt.Errorf("could not read cluster peer ID")
}
func (pm *ProcessManager) waitClusterReady(ctx context.Context, name string, restAPIPort int) error {
maxRetries := 30
retryInterval := 500 * time.Millisecond
for attempt := 0; attempt < maxRetries; attempt++ {
httpURL := fmt.Sprintf("http://127.0.0.1:%d/peers", restAPIPort)
resp, err := http.Get(httpURL)
if err == nil && resp.StatusCode == 200 {
resp.Body.Close()
return nil
}
if resp != nil {
resp.Body.Close()
}
select {
case <-time.After(retryInterval):
continue
case <-ctx.Done():
return ctx.Err()
}
}
return fmt.Errorf("IPFS Cluster %s did not become ready", name)
}
func (pm *ProcessManager) waitClusterFormed(ctx context.Context, bootstrapRestAPIPort int) error {
maxRetries := 30
retryInterval := 1 * time.Second
requiredPeers := 3
for attempt := 0; attempt < maxRetries; attempt++ {
httpURL := fmt.Sprintf("http://127.0.0.1:%d/peers", bootstrapRestAPIPort)
resp, err := http.Get(httpURL)
if err == nil && resp.StatusCode == 200 {
dec := json.NewDecoder(resp.Body)
peerCount := 0
for {
var peer interface{}
if err := dec.Decode(&peer); err != nil {
break
}
peerCount++
}
resp.Body.Close()
if peerCount >= requiredPeers {
return nil
}
}
if resp != nil {
resp.Body.Close()
}
select {
case <-time.After(retryInterval):
continue
case <-ctx.Done():
return ctx.Err()
}
}
return fmt.Errorf("IPFS Cluster did not form fully")
}
func (pm *ProcessManager) cleanClusterState(clusterPath string) error {
pebblePath := filepath.Join(clusterPath, "pebble")
os.RemoveAll(pebblePath)
peerstorePath := filepath.Join(clusterPath, "peerstore")
os.Remove(peerstorePath)
serviceJSONPath := filepath.Join(clusterPath, "service.json")
os.Remove(serviceJSONPath)
lockPath := filepath.Join(clusterPath, "cluster.lock")
os.Remove(lockPath)
return nil
}
func (pm *ProcessManager) ensureIPFSClusterPorts(clusterPath string, restAPIPort int, clusterPort int) error {
serviceJSONPath := filepath.Join(clusterPath, "service.json")
data, err := os.ReadFile(serviceJSONPath)
if err != nil {
return err
}
var config map[string]interface{}
json.Unmarshal(data, &config)
portOffset := restAPIPort - 9094
proxyPort := 9095 + portOffset
pinsvcPort := 9097 + portOffset
ipfsPort := 4501 + (portOffset / 10)
if api, ok := config["api"].(map[string]interface{}); ok {
if restapi, ok := api["restapi"].(map[string]interface{}); ok {
restapi["http_listen_multiaddress"] = fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", restAPIPort)
}
if proxy, ok := api["ipfsproxy"].(map[string]interface{}); ok {
proxy["listen_multiaddress"] = fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", proxyPort)
proxy["node_multiaddress"] = fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", ipfsPort)
}
if pinsvc, ok := api["pinsvcapi"].(map[string]interface{}); ok {
pinsvc["http_listen_multiaddress"] = fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", pinsvcPort)
}
}
if cluster, ok := config["cluster"].(map[string]interface{}); ok {
cluster["listen_multiaddress"] = []string{
fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", clusterPort),
fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", clusterPort),
}
}
if connector, ok := config["ipfs_connector"].(map[string]interface{}); ok {
if ipfshttp, ok := connector["ipfshttp"].(map[string]interface{}); ok {
ipfshttp["node_multiaddress"] = fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", ipfsPort)
}
}
updatedData, _ := json.MarshalIndent(config, "", " ")
return os.WriteFile(serviceJSONPath, updatedData, 0644)
}

View File

@ -0,0 +1,231 @@
package development
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
)
func (pm *ProcessManager) printStartupSummary(topology *Topology) {
fmt.Fprintf(pm.logWriter, "\n✅ Development environment ready!\n")
fmt.Fprintf(pm.logWriter, "═══════════════════════════════════════\n\n")
fmt.Fprintf(pm.logWriter, "📡 Access your nodes via unified gateway ports:\n\n")
for _, node := range topology.Nodes {
fmt.Fprintf(pm.logWriter, " %s:\n", node.Name)
fmt.Fprintf(pm.logWriter, " curl http://localhost:%d/health\n", node.UnifiedGatewayPort)
fmt.Fprintf(pm.logWriter, " curl http://localhost:%d/rqlite/http/db/execute\n", node.UnifiedGatewayPort)
fmt.Fprintf(pm.logWriter, " curl http://localhost:%d/cluster/health\n\n", node.UnifiedGatewayPort)
}
fmt.Fprintf(pm.logWriter, "🌐 Main Gateway:\n")
fmt.Fprintf(pm.logWriter, " curl http://localhost:%d/v1/status\n\n", topology.GatewayPort)
fmt.Fprintf(pm.logWriter, "📊 Other Services:\n")
fmt.Fprintf(pm.logWriter, " Olric: http://localhost:%d\n", topology.OlricHTTPPort)
fmt.Fprintf(pm.logWriter, " Anon SOCKS: 127.0.0.1:%d\n", topology.AnonSOCKSPort)
fmt.Fprintf(pm.logWriter, " Rqlite MCP: http://localhost:%d/sse\n\n", topology.MCPPort)
fmt.Fprintf(pm.logWriter, "📝 Useful Commands:\n")
fmt.Fprintf(pm.logWriter, " ./bin/orama dev status - Check service status\n")
fmt.Fprintf(pm.logWriter, " ./bin/orama dev logs node-1 - View logs\n")
fmt.Fprintf(pm.logWriter, " ./bin/orama dev down - Stop all services\n\n")
fmt.Fprintf(pm.logWriter, "📂 Logs: %s/logs\n", pm.oramaDir)
fmt.Fprintf(pm.logWriter, "⚙️ Config: %s\n\n", pm.oramaDir)
}
func (pm *ProcessManager) stopProcess(name string) error {
pidPath := filepath.Join(pm.pidsDir, fmt.Sprintf("%s.pid", name))
pidBytes, err := os.ReadFile(pidPath)
if err != nil {
return nil
}
pid, err := strconv.Atoi(strings.TrimSpace(string(pidBytes)))
if err != nil {
os.Remove(pidPath)
return nil
}
if !checkProcessRunning(pid) {
os.Remove(pidPath)
fmt.Fprintf(pm.logWriter, "✓ %s (not running)\n", name)
return nil
}
proc, err := os.FindProcess(pid)
if err != nil {
os.Remove(pidPath)
return nil
}
proc.Signal(os.Interrupt)
gracefulShutdown := false
for i := 0; i < 20; i++ {
time.Sleep(100 * time.Millisecond)
if !checkProcessRunning(pid) {
gracefulShutdown = true
break
}
}
if !gracefulShutdown && checkProcessRunning(pid) {
proc.Signal(os.Kill)
time.Sleep(200 * time.Millisecond)
if runtime.GOOS != "windows" {
exec.Command("pkill", "-9", "-P", fmt.Sprintf("%d", pid)).Run()
}
if checkProcessRunning(pid) {
exec.Command("kill", "-9", fmt.Sprintf("%d", pid)).Run()
time.Sleep(100 * time.Millisecond)
}
}
os.Remove(pidPath)
if gracefulShutdown {
fmt.Fprintf(pm.logWriter, "✓ %s stopped gracefully\n", name)
} else {
fmt.Fprintf(pm.logWriter, "✓ %s stopped (forced)\n", name)
}
return nil
}
func checkProcessRunning(pid int) bool {
proc, err := os.FindProcess(pid)
if err != nil {
return false
}
err = proc.Signal(os.Signal(nil))
return err == nil
}
func (pm *ProcessManager) startNode(name, configFile, logPath string) error {
pidPath := filepath.Join(pm.pidsDir, fmt.Sprintf("%s.pid", name))
cmd := exec.Command("./bin/orama-node", "--config", configFile)
logFile, _ := os.Create(logPath)
cmd.Stdout = logFile
cmd.Stderr = logFile
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start %s: %w", name, err)
}
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
fmt.Fprintf(pm.logWriter, "✓ %s started (PID: %d)\n", strings.Title(name), cmd.Process.Pid)
time.Sleep(1 * time.Second)
return nil
}
func (pm *ProcessManager) startGateway(ctx context.Context) error {
pidPath := filepath.Join(pm.pidsDir, "gateway.pid")
logPath := filepath.Join(pm.oramaDir, "logs", "gateway.log")
cmd := exec.Command("./bin/gateway", "--config", "gateway.yaml")
logFile, _ := os.Create(logPath)
cmd.Stdout = logFile
cmd.Stderr = logFile
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start gateway: %w", err)
}
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
fmt.Fprintf(pm.logWriter, "✓ Gateway started (PID: %d, listen: 6001)\n", cmd.Process.Pid)
return nil
}
func (pm *ProcessManager) startOlric(ctx context.Context) error {
pidPath := filepath.Join(pm.pidsDir, "olric.pid")
logPath := filepath.Join(pm.oramaDir, "logs", "olric.log")
configPath := filepath.Join(pm.oramaDir, "olric-config.yaml")
cmd := exec.CommandContext(ctx, "olric-server")
cmd.Env = append(os.Environ(), fmt.Sprintf("OLRIC_SERVER_CONFIG=%s", configPath))
logFile, _ := os.Create(logPath)
cmd.Stdout = logFile
cmd.Stderr = logFile
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start olric: %w", err)
}
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
fmt.Fprintf(pm.logWriter, "✓ Olric started (PID: %d)\n", cmd.Process.Pid)
time.Sleep(1 * time.Second)
return nil
}
func (pm *ProcessManager) startAnon(ctx context.Context) error {
if runtime.GOOS != "darwin" {
return nil
}
pidPath := filepath.Join(pm.pidsDir, "anon.pid")
logPath := filepath.Join(pm.oramaDir, "logs", "anon.log")
cmd := exec.CommandContext(ctx, "npx", "anyone-client")
logFile, _ := os.Create(logPath)
cmd.Stdout = logFile
cmd.Stderr = logFile
if err := cmd.Start(); err != nil {
fmt.Fprintf(pm.logWriter, " ⚠️ Failed to start Anon: %v\n", err)
return nil
}
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
fmt.Fprintf(pm.logWriter, "✓ Anon proxy started (PID: %d, SOCKS: 9050)\n", cmd.Process.Pid)
return nil
}
func (pm *ProcessManager) startMCP(ctx context.Context) error {
topology := DefaultTopology()
pidPath := filepath.Join(pm.pidsDir, "rqlite-mcp.pid")
logPath := filepath.Join(pm.oramaDir, "logs", "rqlite-mcp.log")
cmd := exec.CommandContext(ctx, "./bin/rqlite-mcp")
cmd.Env = append(os.Environ(),
fmt.Sprintf("MCP_PORT=%d", topology.MCPPort),
fmt.Sprintf("RQLITE_URL=http://localhost:%d", topology.Nodes[0].RQLiteHTTPPort),
)
logFile, _ := os.Create(logPath)
cmd.Stdout = logFile
cmd.Stderr = logFile
if err := cmd.Start(); err != nil {
fmt.Fprintf(pm.logWriter, " ⚠️ Failed to start Rqlite MCP: %v\n", err)
return nil
}
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
fmt.Fprintf(pm.logWriter, "✓ Rqlite MCP started (PID: %d, port: %d)\n", cmd.Process.Pid, topology.MCPPort)
return nil
}
func (pm *ProcessManager) startNodes(ctx context.Context) error {
topology := DefaultTopology()
for _, nodeSpec := range topology.Nodes {
logPath := filepath.Join(pm.oramaDir, "logs", fmt.Sprintf("%s.log", nodeSpec.Name))
if err := pm.startNode(nodeSpec.Name, nodeSpec.ConfigFilename, logPath); err != nil {
return fmt.Errorf("failed to start %s: %w", nodeSpec.Name, err)
}
time.Sleep(500 * time.Millisecond)
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@ -4,20 +4,20 @@ import "fmt"
// NodeSpec defines configuration for a single dev environment node // NodeSpec defines configuration for a single dev environment node
type NodeSpec struct { type NodeSpec struct {
Name string // node-1, node-2, node-3, node-4, node-5 Name string // node-1, node-2, node-3, node-4, node-5
ConfigFilename string // node-1.yaml, node-2.yaml, etc. ConfigFilename string // node-1.yaml, node-2.yaml, etc.
DataDir string // relative path from .orama root DataDir string // relative path from .orama root
P2PPort int // LibP2P listen port P2PPort int // LibP2P listen port
IPFSAPIPort int // IPFS API port IPFSAPIPort int // IPFS API port
IPFSSwarmPort int // IPFS Swarm port IPFSSwarmPort int // IPFS Swarm port
IPFSGatewayPort int // IPFS HTTP Gateway port IPFSGatewayPort int // IPFS HTTP Gateway port
RQLiteHTTPPort int // RQLite HTTP API port RQLiteHTTPPort int // RQLite HTTP API port
RQLiteRaftPort int // RQLite Raft consensus port RQLiteRaftPort int // RQLite Raft consensus port
ClusterAPIPort int // IPFS Cluster REST API port ClusterAPIPort int // IPFS Cluster REST API port
ClusterPort int // IPFS Cluster P2P port ClusterPort int // IPFS Cluster P2P port
UnifiedGatewayPort int // Unified gateway port (proxies all services) UnifiedGatewayPort int // Unified gateway port (proxies all services)
RQLiteJoinTarget string // which node's RQLite Raft port to join (empty for first node) RQLiteJoinTarget string // which node's RQLite Raft port to join (empty for first node)
ClusterJoinTarget string // which node's cluster to join (empty for first node) ClusterJoinTarget string // which node's cluster to join (empty for first node)
} }
// Topology defines the complete development environment topology // Topology defines the complete development environment topology
@ -27,97 +27,99 @@ type Topology struct {
OlricHTTPPort int OlricHTTPPort int
OlricMemberPort int OlricMemberPort int
AnonSOCKSPort int AnonSOCKSPort int
MCPPort int
} }
// DefaultTopology returns the default five-node dev environment topology // DefaultTopology returns the default five-node dev environment topology
func DefaultTopology() *Topology { func DefaultTopology() *Topology {
return &Topology{ return &Topology{
Nodes: []NodeSpec{ Nodes: []NodeSpec{
{ {
Name: "node-1", Name: "node-1",
ConfigFilename: "node-1.yaml", ConfigFilename: "node-1.yaml",
DataDir: "node-1", DataDir: "node-1",
P2PPort: 4001, P2PPort: 4001,
IPFSAPIPort: 4501, IPFSAPIPort: 4501,
IPFSSwarmPort: 4101, IPFSSwarmPort: 4101,
IPFSGatewayPort: 7501, IPFSGatewayPort: 7501,
RQLiteHTTPPort: 5001, RQLiteHTTPPort: 5001,
RQLiteRaftPort: 7001, RQLiteRaftPort: 7001,
ClusterAPIPort: 9094, ClusterAPIPort: 9094,
ClusterPort: 9096, ClusterPort: 9096,
UnifiedGatewayPort: 6001, UnifiedGatewayPort: 6001,
RQLiteJoinTarget: "", // First node - creates cluster RQLiteJoinTarget: "", // First node - creates cluster
ClusterJoinTarget: "", ClusterJoinTarget: "",
},
{
Name: "node-2",
ConfigFilename: "node-2.yaml",
DataDir: "node-2",
P2PPort: 4011,
IPFSAPIPort: 4511,
IPFSSwarmPort: 4111,
IPFSGatewayPort: 7511,
RQLiteHTTPPort: 5011,
RQLiteRaftPort: 7011,
ClusterAPIPort: 9104,
ClusterPort: 9106,
UnifiedGatewayPort: 6002,
RQLiteJoinTarget: "localhost:7001",
ClusterJoinTarget: "localhost:9096",
},
{
Name: "node-3",
ConfigFilename: "node-3.yaml",
DataDir: "node-3",
P2PPort: 4002,
IPFSAPIPort: 4502,
IPFSSwarmPort: 4102,
IPFSGatewayPort: 7502,
RQLiteHTTPPort: 5002,
RQLiteRaftPort: 7002,
ClusterAPIPort: 9114,
ClusterPort: 9116,
UnifiedGatewayPort: 6003,
RQLiteJoinTarget: "localhost:7001",
ClusterJoinTarget: "localhost:9096",
},
{
Name: "node-4",
ConfigFilename: "node-4.yaml",
DataDir: "node-4",
P2PPort: 4003,
IPFSAPIPort: 4503,
IPFSSwarmPort: 4103,
IPFSGatewayPort: 7503,
RQLiteHTTPPort: 5003,
RQLiteRaftPort: 7003,
ClusterAPIPort: 9124,
ClusterPort: 9126,
UnifiedGatewayPort: 6004,
RQLiteJoinTarget: "localhost:7001",
ClusterJoinTarget: "localhost:9096",
},
{
Name: "node-5",
ConfigFilename: "node-5.yaml",
DataDir: "node-5",
P2PPort: 4004,
IPFSAPIPort: 4504,
IPFSSwarmPort: 4104,
IPFSGatewayPort: 7504,
RQLiteHTTPPort: 5004,
RQLiteRaftPort: 7004,
ClusterAPIPort: 9134,
ClusterPort: 9136,
UnifiedGatewayPort: 6005,
RQLiteJoinTarget: "localhost:7001",
ClusterJoinTarget: "localhost:9096",
},
}, },
{ GatewayPort: 6000, // Main gateway on 6000 (nodes use 6001-6005)
Name: "node-2",
ConfigFilename: "node-2.yaml",
DataDir: "node-2",
P2PPort: 4011,
IPFSAPIPort: 4511,
IPFSSwarmPort: 4111,
IPFSGatewayPort: 7511,
RQLiteHTTPPort: 5011,
RQLiteRaftPort: 7011,
ClusterAPIPort: 9104,
ClusterPort: 9106,
UnifiedGatewayPort: 6002,
RQLiteJoinTarget: "localhost:7001",
ClusterJoinTarget: "localhost:9096",
},
{
Name: "node-3",
ConfigFilename: "node-3.yaml",
DataDir: "node-3",
P2PPort: 4002,
IPFSAPIPort: 4502,
IPFSSwarmPort: 4102,
IPFSGatewayPort: 7502,
RQLiteHTTPPort: 5002,
RQLiteRaftPort: 7002,
ClusterAPIPort: 9114,
ClusterPort: 9116,
UnifiedGatewayPort: 6003,
RQLiteJoinTarget: "localhost:7001",
ClusterJoinTarget: "localhost:9096",
},
{
Name: "node-4",
ConfigFilename: "node-4.yaml",
DataDir: "node-4",
P2PPort: 4003,
IPFSAPIPort: 4503,
IPFSSwarmPort: 4103,
IPFSGatewayPort: 7503,
RQLiteHTTPPort: 5003,
RQLiteRaftPort: 7003,
ClusterAPIPort: 9124,
ClusterPort: 9126,
UnifiedGatewayPort: 6004,
RQLiteJoinTarget: "localhost:7001",
ClusterJoinTarget: "localhost:9096",
},
{
Name: "node-5",
ConfigFilename: "node-5.yaml",
DataDir: "node-5",
P2PPort: 4004,
IPFSAPIPort: 4504,
IPFSSwarmPort: 4104,
IPFSGatewayPort: 7504,
RQLiteHTTPPort: 5004,
RQLiteRaftPort: 7004,
ClusterAPIPort: 9134,
ClusterPort: 9136,
UnifiedGatewayPort: 6005,
RQLiteJoinTarget: "localhost:7001",
ClusterJoinTarget: "localhost:9096",
},
},
GatewayPort: 6000, // Main gateway on 6000 (nodes use 6001-6005)
OlricHTTPPort: 3320, OlricHTTPPort: 3320,
OlricMemberPort: 3322, OlricMemberPort: 3322,
AnonSOCKSPort: 9050, AnonSOCKSPort: 9050,
MCPPort: 5825,
} }
} }

View File

@ -1,4 +1,4 @@
package gateway package auth
import ( import (
"crypto" "crypto"
@ -13,13 +13,13 @@ import (
"time" "time"
) )
func (g *Gateway) jwksHandler(w http.ResponseWriter, r *http.Request) { func (s *Service) JWKSHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
if g.signingKey == nil { if s.signingKey == nil {
_ = json.NewEncoder(w).Encode(map[string]any{"keys": []any{}}) _ = json.NewEncoder(w).Encode(map[string]any{"keys": []any{}})
return return
} }
pub := g.signingKey.Public().(*rsa.PublicKey) pub := s.signingKey.Public().(*rsa.PublicKey)
n := pub.N.Bytes() n := pub.N.Bytes()
// Encode exponent as big-endian bytes // Encode exponent as big-endian bytes
eVal := pub.E eVal := pub.E
@ -35,7 +35,7 @@ func (g *Gateway) jwksHandler(w http.ResponseWriter, r *http.Request) {
"kty": "RSA", "kty": "RSA",
"use": "sig", "use": "sig",
"alg": "RS256", "alg": "RS256",
"kid": g.keyID, "kid": s.keyID,
"n": base64.RawURLEncoding.EncodeToString(n), "n": base64.RawURLEncoding.EncodeToString(n),
"e": base64.RawURLEncoding.EncodeToString(eb), "e": base64.RawURLEncoding.EncodeToString(eb),
} }
@ -49,7 +49,7 @@ type jwtHeader struct {
Kid string `json:"kid"` Kid string `json:"kid"`
} }
type jwtClaims struct { type JWTClaims struct {
Iss string `json:"iss"` Iss string `json:"iss"`
Sub string `json:"sub"` Sub string `json:"sub"`
Aud string `json:"aud"` Aud string `json:"aud"`
@ -59,9 +59,9 @@ type jwtClaims struct {
Namespace string `json:"namespace"` Namespace string `json:"namespace"`
} }
// parseAndVerifyJWT verifies an RS256 JWT created by this gateway and returns claims // ParseAndVerifyJWT verifies an RS256 JWT created by this gateway and returns claims
func (g *Gateway) parseAndVerifyJWT(token string) (*jwtClaims, error) { func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) {
if g.signingKey == nil { if s.signingKey == nil {
return nil, errors.New("signing key unavailable") return nil, errors.New("signing key unavailable")
} }
parts := strings.Split(token, ".") parts := strings.Split(token, ".")
@ -90,12 +90,12 @@ func (g *Gateway) parseAndVerifyJWT(token string) (*jwtClaims, error) {
// Verify signature // Verify signature
signingInput := parts[0] + "." + parts[1] signingInput := parts[0] + "." + parts[1]
sum := sha256.Sum256([]byte(signingInput)) sum := sha256.Sum256([]byte(signingInput))
pub := g.signingKey.Public().(*rsa.PublicKey) pub := s.signingKey.Public().(*rsa.PublicKey)
if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, sum[:], sb); err != nil { if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, sum[:], sb); err != nil {
return nil, errors.New("invalid signature") return nil, errors.New("invalid signature")
} }
// Parse claims // Parse claims
var claims jwtClaims var claims JWTClaims
if err := json.Unmarshal(pb, &claims); err != nil { if err := json.Unmarshal(pb, &claims); err != nil {
return nil, errors.New("invalid claims json") return nil, errors.New("invalid claims json")
} }
@ -122,14 +122,14 @@ func (g *Gateway) parseAndVerifyJWT(token string) (*jwtClaims, error) {
return &claims, nil return &claims, nil
} }
func (g *Gateway) generateJWT(ns, subject string, ttl time.Duration) (string, int64, error) { func (s *Service) GenerateJWT(ns, subject string, ttl time.Duration) (string, int64, error) {
if g.signingKey == nil { if s.signingKey == nil {
return "", 0, errors.New("signing key unavailable") return "", 0, errors.New("signing key unavailable")
} }
header := map[string]string{ header := map[string]string{
"alg": "RS256", "alg": "RS256",
"typ": "JWT", "typ": "JWT",
"kid": g.keyID, "kid": s.keyID,
} }
hb, _ := json.Marshal(header) hb, _ := json.Marshal(header)
now := time.Now().UTC() now := time.Now().UTC()
@ -148,7 +148,7 @@ func (g *Gateway) generateJWT(ns, subject string, ttl time.Duration) (string, in
pb64 := base64.RawURLEncoding.EncodeToString(pb) pb64 := base64.RawURLEncoding.EncodeToString(pb)
signingInput := hb64 + "." + pb64 signingInput := hb64 + "." + pb64
sum := sha256.Sum256([]byte(signingInput)) sum := sha256.Sum256([]byte(signingInput))
sig, err := rsa.SignPKCS1v15(rand.Reader, g.signingKey, crypto.SHA256, sum[:]) sig, err := rsa.SignPKCS1v15(rand.Reader, s.signingKey, crypto.SHA256, sum[:])
if err != nil { if err != nil {
return "", 0, err return "", 0, err
} }

391
pkg/gateway/auth/service.go Normal file
View File

@ -0,0 +1,391 @@
package auth
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"strconv"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/client"
"github.com/DeBrosOfficial/network/pkg/logging"
ethcrypto "github.com/ethereum/go-ethereum/crypto"
)
// Service handles authentication business logic
type Service struct {
logger *logging.ColoredLogger
orm client.NetworkClient
signingKey *rsa.PrivateKey
keyID string
defaultNS string
}
func NewService(logger *logging.ColoredLogger, orm client.NetworkClient, signingKeyPEM string, defaultNS string) (*Service, error) {
s := &Service{
logger: logger,
orm: orm,
defaultNS: defaultNS,
}
if signingKeyPEM != "" {
block, _ := pem.Decode([]byte(signingKeyPEM))
if block == nil {
return nil, fmt.Errorf("failed to parse signing key PEM")
}
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse RSA private key: %w", err)
}
s.signingKey = key
// Generate a simple KID from the public key hash
pubBytes := x509.MarshalPKCS1PublicKey(&key.PublicKey)
sum := sha256.Sum256(pubBytes)
s.keyID = hex.EncodeToString(sum[:8])
}
return s, nil
}
// CreateNonce generates a new nonce and stores it in the database
func (s *Service) CreateNonce(ctx context.Context, wallet, purpose, namespace string) (string, error) {
// Generate a URL-safe random nonce (32 bytes)
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", fmt.Errorf("failed to generate nonce: %w", err)
}
nonce := base64.RawURLEncoding.EncodeToString(buf)
// Use internal context to bypass authentication for system operations
internalCtx := client.WithInternalAuth(ctx)
db := s.orm.Database()
if namespace == "" {
namespace = s.defaultNS
if namespace == "" {
namespace = "default"
}
}
// Ensure namespace exists
if _, err := db.Query(internalCtx, "INSERT OR IGNORE INTO namespaces(name) VALUES (?)", namespace); err != nil {
return "", fmt.Errorf("failed to ensure namespace: %w", err)
}
nsID, err := s.ResolveNamespaceID(ctx, namespace)
if err != nil {
return "", fmt.Errorf("failed to resolve namespace ID: %w", err)
}
// Store nonce with 5 minute expiry
walletLower := strings.ToLower(strings.TrimSpace(wallet))
if _, err := db.Query(internalCtx,
"INSERT INTO nonces(namespace_id, wallet, nonce, purpose, expires_at) VALUES (?, ?, ?, ?, datetime('now', '+5 minutes'))",
nsID, walletLower, nonce, purpose,
); err != nil {
return "", fmt.Errorf("failed to store nonce: %w", err)
}
return nonce, nil
}
// VerifySignature verifies a wallet signature for a given nonce
func (s *Service) VerifySignature(ctx context.Context, wallet, nonce, signature, chainType string) (bool, error) {
chainType = strings.ToUpper(strings.TrimSpace(chainType))
if chainType == "" {
chainType = "ETH"
}
switch chainType {
case "ETH":
return s.verifyEthSignature(wallet, nonce, signature)
case "SOL":
return s.verifySolSignature(wallet, nonce, signature)
default:
return false, fmt.Errorf("unsupported chain type: %s", chainType)
}
}
func (s *Service) verifyEthSignature(wallet, nonce, signature string) (bool, error) {
msg := []byte(nonce)
prefix := []byte("\x19Ethereum Signed Message:\n" + strconv.Itoa(len(msg)))
hash := ethcrypto.Keccak256(prefix, msg)
sigHex := strings.TrimSpace(signature)
if strings.HasPrefix(sigHex, "0x") || strings.HasPrefix(sigHex, "0X") {
sigHex = sigHex[2:]
}
sig, err := hex.DecodeString(sigHex)
if err != nil || len(sig) != 65 {
return false, fmt.Errorf("invalid signature format")
}
if sig[64] >= 27 {
sig[64] -= 27
}
pub, err := ethcrypto.SigToPub(hash, sig)
if err != nil {
return false, fmt.Errorf("signature recovery failed: %w", err)
}
addr := ethcrypto.PubkeyToAddress(*pub).Hex()
want := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(wallet, "0x"), "0X"))
got := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(addr, "0x"), "0X"))
return got == want, nil
}
func (s *Service) verifySolSignature(wallet, nonce, signature string) (bool, error) {
sig, err := base64.StdEncoding.DecodeString(signature)
if err != nil {
return false, fmt.Errorf("invalid base64 signature: %w", err)
}
if len(sig) != 64 {
return false, fmt.Errorf("invalid signature length: expected 64 bytes, got %d", len(sig))
}
pubKeyBytes, err := s.Base58Decode(wallet)
if err != nil {
return false, fmt.Errorf("invalid wallet address: %w", err)
}
if len(pubKeyBytes) != 32 {
return false, fmt.Errorf("invalid public key length: expected 32 bytes, got %d", len(pubKeyBytes))
}
message := []byte(nonce)
return ed25519.Verify(ed25519.PublicKey(pubKeyBytes), message, sig), nil
}
// IssueTokens generates access and refresh tokens for a verified wallet
func (s *Service) IssueTokens(ctx context.Context, wallet, namespace string) (string, string, int64, error) {
if s.signingKey == nil {
return "", "", 0, fmt.Errorf("signing key unavailable")
}
// Issue access token (15m)
token, expUnix, err := s.GenerateJWT(namespace, wallet, 15*time.Minute)
if err != nil {
return "", "", 0, fmt.Errorf("failed to generate JWT: %w", err)
}
// Create refresh token (30d)
rbuf := make([]byte, 32)
if _, err := rand.Read(rbuf); err != nil {
return "", "", 0, fmt.Errorf("failed to generate refresh token: %w", err)
}
refresh := base64.RawURLEncoding.EncodeToString(rbuf)
nsID, err := s.ResolveNamespaceID(ctx, namespace)
if err != nil {
return "", "", 0, fmt.Errorf("failed to resolve namespace ID: %w", err)
}
internalCtx := client.WithInternalAuth(ctx)
db := s.orm.Database()
if _, err := db.Query(internalCtx,
"INSERT INTO refresh_tokens(namespace_id, subject, token, audience, expires_at) VALUES (?, ?, ?, ?, datetime('now', '+30 days'))",
nsID, wallet, refresh, "gateway",
); err != nil {
return "", "", 0, fmt.Errorf("failed to store refresh token: %w", err)
}
return token, refresh, expUnix, nil
}
// RefreshToken validates a refresh token and issues a new access token
func (s *Service) RefreshToken(ctx context.Context, refreshToken, namespace string) (string, string, int64, error) {
internalCtx := client.WithInternalAuth(ctx)
db := s.orm.Database()
nsID, err := s.ResolveNamespaceID(ctx, namespace)
if err != nil {
return "", "", 0, err
}
q := "SELECT subject FROM refresh_tokens WHERE namespace_id = ? AND token = ? AND revoked_at IS NULL AND (expires_at IS NULL OR expires_at > datetime('now')) LIMIT 1"
res, err := db.Query(internalCtx, q, nsID, refreshToken)
if err != nil || res == nil || res.Count == 0 {
return "", "", 0, fmt.Errorf("invalid or expired refresh token")
}
subject := ""
if len(res.Rows) > 0 && len(res.Rows[0]) > 0 {
if val, ok := res.Rows[0][0].(string); ok {
subject = val
} else {
b, _ := json.Marshal(res.Rows[0][0])
_ = json.Unmarshal(b, &subject)
}
}
token, expUnix, err := s.GenerateJWT(namespace, subject, 15*time.Minute)
if err != nil {
return "", "", 0, err
}
return token, subject, expUnix, nil
}
// RevokeToken revokes a specific refresh token or all tokens for a subject
func (s *Service) RevokeToken(ctx context.Context, namespace, token string, all bool, subject string) error {
internalCtx := client.WithInternalAuth(ctx)
db := s.orm.Database()
nsID, err := s.ResolveNamespaceID(ctx, namespace)
if err != nil {
return err
}
if token != "" {
_, err := db.Query(internalCtx, "UPDATE refresh_tokens SET revoked_at = datetime('now') WHERE namespace_id = ? AND token = ? AND revoked_at IS NULL", nsID, token)
return err
}
if all && subject != "" {
_, err := db.Query(internalCtx, "UPDATE refresh_tokens SET revoked_at = datetime('now') WHERE namespace_id = ? AND subject = ? AND revoked_at IS NULL", nsID, subject)
return err
}
return fmt.Errorf("nothing to revoke")
}
// RegisterApp registers a new client application
func (s *Service) RegisterApp(ctx context.Context, wallet, namespace, name, publicKey string) (string, error) {
internalCtx := client.WithInternalAuth(ctx)
db := s.orm.Database()
nsID, err := s.ResolveNamespaceID(ctx, namespace)
if err != nil {
return "", err
}
// Generate client app_id
buf := make([]byte, 12)
if _, err := rand.Read(buf); err != nil {
return "", fmt.Errorf("failed to generate app id: %w", err)
}
appID := "app_" + base64.RawURLEncoding.EncodeToString(buf)
// Persist app
if _, err := db.Query(internalCtx, "INSERT INTO apps(namespace_id, app_id, name, public_key) VALUES (?, ?, ?, ?)", nsID, appID, name, publicKey); err != nil {
return "", err
}
// Record ownership
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, ?, ?)", nsID, "wallet", wallet)
return appID, nil
}
// GetOrCreateAPIKey returns an existing API key or creates a new one for a wallet in a namespace
func (s *Service) GetOrCreateAPIKey(ctx context.Context, wallet, namespace string) (string, error) {
internalCtx := client.WithInternalAuth(ctx)
db := s.orm.Database()
nsID, err := s.ResolveNamespaceID(ctx, namespace)
if err != nil {
return "", err
}
// Try existing linkage
var apiKey string
r1, err := db.Query(internalCtx,
"SELECT api_keys.key FROM wallet_api_keys JOIN api_keys ON wallet_api_keys.api_key_id = api_keys.id WHERE wallet_api_keys.namespace_id = ? AND LOWER(wallet_api_keys.wallet) = LOWER(?) LIMIT 1",
nsID, wallet,
)
if err == nil && r1 != nil && r1.Count > 0 && len(r1.Rows) > 0 && len(r1.Rows[0]) > 0 {
if val, ok := r1.Rows[0][0].(string); ok {
apiKey = val
}
}
if apiKey != "" {
return apiKey, nil
}
// Create new API key
buf := make([]byte, 18)
if _, err := rand.Read(buf); err != nil {
return "", fmt.Errorf("failed to generate api key: %w", err)
}
apiKey = "ak_" + base64.RawURLEncoding.EncodeToString(buf) + ":" + namespace
if _, err := db.Query(internalCtx, "INSERT INTO api_keys(key, name, namespace_id) VALUES (?, ?, ?)", apiKey, "", nsID); err != nil {
return "", fmt.Errorf("failed to store api key: %w", err)
}
// Link wallet -> api_key
rid, err := db.Query(internalCtx, "SELECT id FROM api_keys WHERE key = ? LIMIT 1", apiKey)
if err == nil && rid != nil && rid.Count > 0 && len(rid.Rows) > 0 && len(rid.Rows[0]) > 0 {
apiKeyID := rid.Rows[0][0]
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO wallet_api_keys(namespace_id, wallet, api_key_id) VALUES (?, ?, ?)", nsID, strings.ToLower(wallet), apiKeyID)
}
// Record ownerships
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'api_key', ?)", nsID, apiKey)
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'wallet', ?)", nsID, wallet)
return apiKey, nil
}
// ResolveNamespaceID ensures the given namespace exists and returns its primary key ID.
func (s *Service) ResolveNamespaceID(ctx context.Context, ns string) (interface{}, error) {
if s.orm == nil {
return nil, fmt.Errorf("client not initialized")
}
ns = strings.TrimSpace(ns)
if ns == "" {
ns = "default"
}
internalCtx := client.WithInternalAuth(ctx)
db := s.orm.Database()
if _, err := db.Query(internalCtx, "INSERT OR IGNORE INTO namespaces(name) VALUES (?)", ns); err != nil {
return nil, err
}
res, err := db.Query(internalCtx, "SELECT id FROM namespaces WHERE name = ? LIMIT 1", ns)
if err != nil {
return nil, err
}
if res == nil || res.Count == 0 || len(res.Rows) == 0 || len(res.Rows[0]) == 0 {
return nil, fmt.Errorf("failed to resolve namespace")
}
return res.Rows[0][0], nil
}
// Base58Decode decodes a base58-encoded string
func (s *Service) Base58Decode(input string) ([]byte, error) {
const alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
answer := big.NewInt(0)
j := big.NewInt(1)
for i := len(input) - 1; i >= 0; i-- {
tmp := strings.IndexByte(alphabet, input[i])
if tmp == -1 {
return nil, fmt.Errorf("invalid base58 character")
}
idx := big.NewInt(int64(tmp))
tmp1 := new(big.Int)
tmp1.Mul(idx, j)
answer.Add(answer, tmp1)
j.Mul(j, big.NewInt(58))
}
// Handle leading zeros
res := answer.Bytes()
for i := 0; i < len(input) && input[i] == alphabet[0]; i++ {
res = append([]byte{0}, res...)
}
return res, nil
}

View File

@ -0,0 +1,166 @@
package auth
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"testing"
"time"
"github.com/DeBrosOfficial/network/pkg/client"
"github.com/DeBrosOfficial/network/pkg/logging"
)
// mockNetworkClient implements client.NetworkClient for testing
type mockNetworkClient struct {
client.NetworkClient
db *mockDatabaseClient
}
func (m *mockNetworkClient) Database() client.DatabaseClient {
return m.db
}
// mockDatabaseClient implements client.DatabaseClient for testing
type mockDatabaseClient struct {
client.DatabaseClient
}
func (m *mockDatabaseClient) Query(ctx context.Context, sql string, args ...interface{}) (*client.QueryResult, error) {
return &client.QueryResult{
Count: 1,
Rows: [][]interface{}{
{1}, // Default ID for ResolveNamespaceID
},
}, nil
}
func createTestService(t *testing.T) *Service {
logger, _ := logging.NewColoredLogger(logging.ComponentGateway, false)
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate key: %v", err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
mockDB := &mockDatabaseClient{}
mockClient := &mockNetworkClient{db: mockDB}
s, err := NewService(logger, mockClient, string(keyPEM), "test-ns")
if err != nil {
t.Fatalf("failed to create service: %v", err)
}
return s
}
func TestBase58Decode(t *testing.T) {
s := &Service{}
tests := []struct {
input string
expected string // hex representation for comparison
wantErr bool
}{
{"1", "00", false},
{"2", "01", false},
{"9", "08", false},
{"A", "09", false},
{"B", "0a", false},
{"2p", "0100", false}, // 58*1 + 0 = 58 (0x3a) - wait, base58 is weird
}
for _, tt := range tests {
got, err := s.Base58Decode(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("Base58Decode(%s) error = %v, wantErr %v", tt.input, err, tt.wantErr)
continue
}
if !tt.wantErr {
hexGot := hex.EncodeToString(got)
if tt.expected != "" && hexGot != tt.expected {
// Base58 decoding of single characters might not be exactly what I expect above
// but let's just ensure it doesn't crash and returns something for now.
// Better to test a known valid address.
}
}
}
// Test a real Solana address (Base58)
solAddr := "HN7cABqL367i3jkj9684C9C3W197m8q5q1C9C3W197m8"
_, err := s.Base58Decode(solAddr)
if err != nil {
t.Errorf("failed to decode solana address: %v", err)
}
}
func TestJWTFlow(t *testing.T) {
s := createTestService(t)
ns := "test-ns"
sub := "0x1234567890abcdef1234567890abcdef12345678"
ttl := 15 * time.Minute
token, exp, err := s.GenerateJWT(ns, sub, ttl)
if err != nil {
t.Fatalf("GenerateJWT failed: %v", err)
}
if token == "" {
t.Fatal("generated token is empty")
}
if exp <= time.Now().Unix() {
t.Errorf("expiration time %d is in the past", exp)
}
claims, err := s.ParseAndVerifyJWT(token)
if err != nil {
t.Fatalf("ParseAndVerifyJWT failed: %v", err)
}
if claims.Sub != sub {
t.Errorf("expected subject %s, got %s", sub, claims.Sub)
}
if claims.Namespace != ns {
t.Errorf("expected namespace %s, got %s", ns, claims.Namespace)
}
if claims.Iss != "debros-gateway" {
t.Errorf("expected issuer debros-gateway, got %s", claims.Iss)
}
}
func TestVerifyEthSignature(t *testing.T) {
s := &Service{}
// This is a bit hard to test without a real ETH signature
// but we can check if it returns false for obviously wrong signatures
wallet := "0x1234567890abcdef1234567890abcdef12345678"
nonce := "test-nonce"
sig := hex.EncodeToString(make([]byte, 65))
ok, err := s.VerifySignature(context.Background(), wallet, nonce, sig, "ETH")
if err == nil && ok {
t.Error("VerifySignature should have failed for zero signature")
}
}
func TestVerifySolSignature(t *testing.T) {
s := &Service{}
// Solana address (base58)
wallet := "HN7cABqL367i3jkj9684C9C3W197m8q5q1C9C3W197m8"
nonce := "test-nonce"
sig := "invalid-sig"
_, err := s.VerifySignature(context.Background(), wallet, nonce, sig, "SOL")
if err == nil {
t.Error("VerifySignature should have failed for invalid base64 signature")
}
}

View File

@ -1,20 +1,14 @@
package gateway package gateway
import ( import (
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/big"
"net/http" "net/http"
"strconv"
"strings" "strings"
"time" "time"
"github.com/DeBrosOfficial/network/pkg/client" "github.com/DeBrosOfficial/network/pkg/client"
ethcrypto "github.com/ethereum/go-ethereum/crypto" "github.com/DeBrosOfficial/network/pkg/gateway/auth"
) )
func (g *Gateway) whoamiHandler(w http.ResponseWriter, r *http.Request) { func (g *Gateway) whoamiHandler(w http.ResponseWriter, r *http.Request) {
@ -29,7 +23,7 @@ func (g *Gateway) whoamiHandler(w http.ResponseWriter, r *http.Request) {
// Prefer JWT if present // Prefer JWT if present
if v := ctx.Value(ctxKeyJWT); v != nil { if v := ctx.Value(ctxKeyJWT); v != nil {
if claims, ok := v.(*jwtClaims); ok && claims != nil { if claims, ok := v.(*auth.JWTClaims); ok && claims != nil {
writeJSON(w, http.StatusOK, map[string]any{ writeJSON(w, http.StatusOK, map[string]any{
"authenticated": true, "authenticated": true,
"method": "jwt", "method": "jwt",
@ -61,8 +55,8 @@ func (g *Gateway) whoamiHandler(w http.ResponseWriter, r *http.Request) {
} }
func (g *Gateway) challengeHandler(w http.ResponseWriter, r *http.Request) { func (g *Gateway) challengeHandler(w http.ResponseWriter, r *http.Request) {
if g.client == nil { if g.authService == nil {
writeError(w, http.StatusServiceUnavailable, "client not initialized") writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
return return
} }
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
@ -82,51 +76,16 @@ func (g *Gateway) challengeHandler(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "wallet is required") writeError(w, http.StatusBadRequest, "wallet is required")
return return
} }
ns := strings.TrimSpace(req.Namespace)
if ns == "" {
ns = strings.TrimSpace(g.cfg.ClientNamespace)
if ns == "" {
ns = "default"
}
}
// Generate a URL-safe random nonce (32 bytes)
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
writeError(w, http.StatusInternalServerError, "failed to generate nonce")
return
}
nonce := base64.RawURLEncoding.EncodeToString(buf)
// Insert namespace if missing, fetch id nonce, err := g.authService.CreateNonce(r.Context(), req.Wallet, req.Purpose, req.Namespace)
ctx := r.Context() if err != nil {
// Use internal context to bypass authentication for system operations
internalCtx := client.WithInternalAuth(ctx)
db := g.client.Database()
if _, err := db.Query(internalCtx, "INSERT OR IGNORE INTO namespaces(name) VALUES (?)", ns); err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
nres, err := db.Query(internalCtx, "SELECT id FROM namespaces WHERE name = ? LIMIT 1", ns)
if err != nil || nres == nil || nres.Count == 0 || len(nres.Rows) == 0 || len(nres.Rows[0]) == 0 {
writeError(w, http.StatusInternalServerError, "failed to resolve namespace")
return
}
nsID := nres.Rows[0][0]
// Store nonce with 5 minute expiry
// Normalize wallet address to lowercase for case-insensitive comparison
walletLower := strings.ToLower(strings.TrimSpace(req.Wallet))
if _, err := db.Query(internalCtx,
"INSERT INTO nonces(namespace_id, wallet, nonce, purpose, expires_at) VALUES (?, ?, ?, ?, datetime('now', '+5 minutes'))",
nsID, walletLower, nonce, req.Purpose,
); err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) writeError(w, http.StatusInternalServerError, err.Error())
return return
} }
writeJSON(w, http.StatusOK, map[string]any{ writeJSON(w, http.StatusOK, map[string]any{
"wallet": req.Wallet, "wallet": req.Wallet,
"namespace": ns, "namespace": req.Namespace,
"nonce": nonce, "nonce": nonce,
"purpose": req.Purpose, "purpose": req.Purpose,
"expires_at": time.Now().Add(5 * time.Minute).UTC().Format(time.RFC3339Nano), "expires_at": time.Now().Add(5 * time.Minute).UTC().Format(time.RFC3339Nano),
@ -134,8 +93,8 @@ func (g *Gateway) challengeHandler(w http.ResponseWriter, r *http.Request) {
} }
func (g *Gateway) verifyHandler(w http.ResponseWriter, r *http.Request) { func (g *Gateway) verifyHandler(w http.ResponseWriter, r *http.Request) {
if g.client == nil { if g.authService == nil {
writeError(w, http.StatusServiceUnavailable, "client not initialized") writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
return return
} }
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
@ -147,7 +106,7 @@ func (g *Gateway) verifyHandler(w http.ResponseWriter, r *http.Request) {
Nonce string `json:"nonce"` Nonce string `json:"nonce"`
Signature string `json:"signature"` Signature string `json:"signature"`
Namespace string `json:"namespace"` Namespace string `json:"namespace"`
ChainType string `json:"chain_type"` // "ETH" or "SOL", defaults to "ETH" ChainType string `json:"chain_type"`
} }
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json body") writeError(w, http.StatusBadRequest, "invalid json body")
@ -157,185 +116,30 @@ func (g *Gateway) verifyHandler(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "wallet, nonce and signature are required") writeError(w, http.StatusBadRequest, "wallet, nonce and signature are required")
return return
} }
ns := strings.TrimSpace(req.Namespace)
if ns == "" {
ns = strings.TrimSpace(g.cfg.ClientNamespace)
if ns == "" {
ns = "default"
}
}
ctx := r.Context() ctx := r.Context()
// Use internal context to bypass authentication for system operations verified, err := g.authService.VerifySignature(ctx, req.Wallet, req.Nonce, req.Signature, req.ChainType)
internalCtx := client.WithInternalAuth(ctx) if err != nil || !verified {
writeError(w, http.StatusUnauthorized, "signature verification failed")
return
}
// Mark nonce used
nsID, _ := g.authService.ResolveNamespaceID(ctx, req.Namespace)
db := g.client.Database() db := g.client.Database()
nsID, err := g.resolveNamespaceID(ctx, ns) _, _ = db.Query(client.WithInternalAuth(ctx), "UPDATE nonces SET used_at = datetime('now') WHERE namespace_id = ? AND wallet = ? AND nonce = ?", nsID, strings.ToLower(req.Wallet), req.Nonce)
token, refresh, expUnix, err := g.authService.IssueTokens(ctx, req.Wallet, req.Namespace)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) writeError(w, http.StatusInternalServerError, err.Error())
return return
} }
// Normalize wallet address to lowercase for case-insensitive comparison
walletLower := strings.ToLower(strings.TrimSpace(req.Wallet))
q := "SELECT id FROM nonces WHERE namespace_id = ? AND LOWER(wallet) = LOWER(?) AND nonce = ? AND used_at IS NULL AND (expires_at IS NULL OR expires_at > datetime('now')) LIMIT 1"
nres, err := db.Query(internalCtx, q, nsID, walletLower, req.Nonce)
if err != nil || nres == nil || nres.Count == 0 {
writeError(w, http.StatusBadRequest, "invalid or expired nonce")
return
}
nonceID := nres.Rows[0][0]
// Determine chain type (default to ETH for backward compatibility) apiKey, err := g.authService.GetOrCreateAPIKey(ctx, req.Wallet, req.Namespace)
chainType := strings.ToUpper(strings.TrimSpace(req.ChainType))
if chainType == "" {
chainType = "ETH"
}
// Verify signature based on chain type
var verified bool
var verifyErr error
switch chainType {
case "ETH":
// EVM personal_sign verification of the nonce
msg := []byte(req.Nonce)
prefix := []byte("\x19Ethereum Signed Message:\n" + strconv.Itoa(len(msg)))
hash := ethcrypto.Keccak256(prefix, msg)
// Decode signature (expects 65-byte r||s||v, hex with optional 0x)
sigHex := strings.TrimSpace(req.Signature)
if strings.HasPrefix(sigHex, "0x") || strings.HasPrefix(sigHex, "0X") {
sigHex = sigHex[2:]
}
sig, err := hex.DecodeString(sigHex)
if err != nil || len(sig) != 65 {
writeError(w, http.StatusBadRequest, "invalid signature format")
return
}
// Normalize V to 0/1 as expected by geth
if sig[64] >= 27 {
sig[64] -= 27
}
pub, err := ethcrypto.SigToPub(hash, sig)
if err != nil {
writeError(w, http.StatusUnauthorized, "signature recovery failed")
return
}
addr := ethcrypto.PubkeyToAddress(*pub).Hex()
want := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X"))
got := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(addr, "0x"), "0X"))
if got != want {
writeError(w, http.StatusUnauthorized, "signature does not match wallet")
return
}
verified = true
case "SOL":
// Solana uses Ed25519 signatures
// Signature is base64-encoded, public key is the wallet address (base58)
// Decode base64 signature (Solana signatures are 64 bytes)
sig, err := base64.StdEncoding.DecodeString(req.Signature)
if err != nil {
writeError(w, http.StatusBadRequest, fmt.Sprintf("invalid base64 signature: %v", err))
return
}
if len(sig) != 64 {
writeError(w, http.StatusBadRequest, fmt.Sprintf("invalid signature length: expected 64 bytes, got %d", len(sig)))
return
}
// Decode base58 public key (Solana wallet address)
pubKeyBytes, err := base58Decode(req.Wallet)
if err != nil {
writeError(w, http.StatusBadRequest, fmt.Sprintf("invalid wallet address: %v", err))
return
}
if len(pubKeyBytes) != 32 {
writeError(w, http.StatusBadRequest, fmt.Sprintf("invalid public key length: expected 32 bytes, got %d", len(pubKeyBytes)))
return
}
// Verify Ed25519 signature
message := []byte(req.Nonce)
if !ed25519.Verify(ed25519.PublicKey(pubKeyBytes), message, sig) {
writeError(w, http.StatusUnauthorized, "signature verification failed")
return
}
verified = true
default:
writeError(w, http.StatusBadRequest, fmt.Sprintf("unsupported chain type: %s (must be ETH or SOL)", chainType))
return
}
if !verified {
writeError(w, http.StatusUnauthorized, fmt.Sprintf("signature verification failed: %v", verifyErr))
return
}
// Mark nonce used now (after successful verification)
if _, err := db.Query(internalCtx, "UPDATE nonces SET used_at = datetime('now') WHERE id = ?", nonceID); err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
if g.signingKey == nil {
writeError(w, http.StatusServiceUnavailable, "signing key unavailable")
return
}
// Issue access token (15m) and a refresh token (30d)
token, expUnix, err := g.generateJWT(ns, req.Wallet, 15*time.Minute)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) writeError(w, http.StatusInternalServerError, err.Error())
return return
} }
// create refresh token
rbuf := make([]byte, 32)
if _, err := rand.Read(rbuf); err != nil {
writeError(w, http.StatusInternalServerError, "failed to generate refresh token")
return
}
refresh := base64.RawURLEncoding.EncodeToString(rbuf)
if _, err := db.Query(internalCtx, "INSERT INTO refresh_tokens(namespace_id, subject, token, audience, expires_at) VALUES (?, ?, ?, ?, datetime('now', '+30 days'))", nsID, req.Wallet, refresh, "gateway"); err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
// Ensure API key exists for this (namespace, wallet) and record ownerships
// This is done automatically after successful verification; no second nonce needed
var apiKey string
// Try existing linkage
r1, err := db.Query(internalCtx,
"SELECT api_keys.key FROM wallet_api_keys JOIN api_keys ON wallet_api_keys.api_key_id = api_keys.id WHERE wallet_api_keys.namespace_id = ? AND LOWER(wallet_api_keys.wallet) = LOWER(?) LIMIT 1",
nsID, req.Wallet,
)
if err == nil && r1 != nil && r1.Count > 0 && len(r1.Rows) > 0 && len(r1.Rows[0]) > 0 {
if s, ok := r1.Rows[0][0].(string); ok {
apiKey = s
} else {
b, _ := json.Marshal(r1.Rows[0][0])
_ = json.Unmarshal(b, &apiKey)
}
}
if strings.TrimSpace(apiKey) == "" {
// Create new API key with format ak_<random>:<namespace>
buf := make([]byte, 18)
if _, err := rand.Read(buf); err == nil {
apiKey = "ak_" + base64.RawURLEncoding.EncodeToString(buf) + ":" + ns
if _, err := db.Query(internalCtx, "INSERT INTO api_keys(key, name, namespace_id) VALUES (?, ?, ?)", apiKey, "", nsID); err == nil {
// Link wallet -> api_key
rid, err := db.Query(internalCtx, "SELECT id FROM api_keys WHERE key = ? LIMIT 1", apiKey)
if err == nil && rid != nil && rid.Count > 0 && len(rid.Rows) > 0 && len(rid.Rows[0]) > 0 {
apiKeyID := rid.Rows[0][0]
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO wallet_api_keys(namespace_id, wallet, api_key_id) VALUES (?, ?, ?)", nsID, strings.ToLower(req.Wallet), apiKeyID)
}
}
}
}
// Record ownerships (best-effort)
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'api_key', ?)", nsID, apiKey)
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'wallet', ?)", nsID, req.Wallet)
writeJSON(w, http.StatusOK, map[string]any{ writeJSON(w, http.StatusOK, map[string]any{
"access_token": token, "access_token": token,
@ -343,23 +147,16 @@ func (g *Gateway) verifyHandler(w http.ResponseWriter, r *http.Request) {
"expires_in": int(expUnix - time.Now().Unix()), "expires_in": int(expUnix - time.Now().Unix()),
"refresh_token": refresh, "refresh_token": refresh,
"subject": req.Wallet, "subject": req.Wallet,
"namespace": ns, "namespace": req.Namespace,
"api_key": apiKey, "api_key": apiKey,
"nonce": req.Nonce, "nonce": req.Nonce,
"signature_verified": true, "signature_verified": true,
}) })
} }
// issueAPIKeyHandler creates or returns an API key for a verified wallet in a namespace.
// Requires: POST { wallet, nonce, signature, namespace }
// Behavior:
// - Validates nonce and signature like verifyHandler
// - Ensures namespace exists
// - If an API key already exists for (namespace, wallet), returns it; else creates one
// - Records namespace ownership mapping for the wallet and api_key
func (g *Gateway) issueAPIKeyHandler(w http.ResponseWriter, r *http.Request) { func (g *Gateway) issueAPIKeyHandler(w http.ResponseWriter, r *http.Request) {
if g.client == nil { if g.authService == nil {
writeError(w, http.StatusServiceUnavailable, "client not initialized") writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
return return
} }
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
@ -371,6 +168,7 @@ func (g *Gateway) issueAPIKeyHandler(w http.ResponseWriter, r *http.Request) {
Nonce string `json:"nonce"` Nonce string `json:"nonce"`
Signature string `json:"signature"` Signature string `json:"signature"`
Namespace string `json:"namespace"` Namespace string `json:"namespace"`
ChainType string `json:"chain_type"`
Plan string `json:"plan"` Plan string `json:"plan"`
} }
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@ -381,110 +179,33 @@ func (g *Gateway) issueAPIKeyHandler(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "wallet, nonce and signature are required") writeError(w, http.StatusBadRequest, "wallet, nonce and signature are required")
return return
} }
ns := strings.TrimSpace(req.Namespace)
if ns == "" {
ns = strings.TrimSpace(g.cfg.ClientNamespace)
if ns == "" {
ns = "default"
}
}
ctx := r.Context() ctx := r.Context()
// Use internal context to bypass authentication for system operations verified, err := g.authService.VerifySignature(ctx, req.Wallet, req.Nonce, req.Signature, req.ChainType)
internalCtx := client.WithInternalAuth(ctx) if err != nil || !verified {
db := g.client.Database() writeError(w, http.StatusUnauthorized, "signature verification failed")
// Resolve namespace id
nsID, err := g.resolveNamespaceID(ctx, ns)
if err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
// Validate nonce exists and not used/expired
// Normalize wallet address to lowercase for case-insensitive comparison
walletLower := strings.ToLower(strings.TrimSpace(req.Wallet))
q := "SELECT id FROM nonces WHERE namespace_id = ? AND LOWER(wallet) = LOWER(?) AND nonce = ? AND used_at IS NULL AND (expires_at IS NULL OR expires_at > datetime('now')) LIMIT 1"
nres, err := db.Query(internalCtx, q, nsID, walletLower, req.Nonce)
if err != nil || nres == nil || nres.Count == 0 {
writeError(w, http.StatusBadRequest, "invalid or expired nonce")
return
}
nonceID := nres.Rows[0][0]
// Verify signature like verifyHandler
msg := []byte(req.Nonce)
prefix := []byte("\x19Ethereum Signed Message:\n" + strconv.Itoa(len(msg)))
hash := ethcrypto.Keccak256(prefix, msg)
sigHex := strings.TrimSpace(req.Signature)
if strings.HasPrefix(sigHex, "0x") || strings.HasPrefix(sigHex, "0X") {
sigHex = sigHex[2:]
}
sig, err := hex.DecodeString(sigHex)
if err != nil || len(sig) != 65 {
writeError(w, http.StatusBadRequest, "invalid signature format")
return
}
if sig[64] >= 27 {
sig[64] -= 27
}
pub, err := ethcrypto.SigToPub(hash, sig)
if err != nil {
writeError(w, http.StatusUnauthorized, "signature recovery failed")
return
}
addr := ethcrypto.PubkeyToAddress(*pub).Hex()
want := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X"))
got := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(addr, "0x"), "0X"))
if got != want {
writeError(w, http.StatusUnauthorized, "signature does not match wallet")
return return
} }
// Mark nonce used // Mark nonce used
if _, err := db.Query(internalCtx, "UPDATE nonces SET used_at = datetime('now') WHERE id = ?", nonceID); err != nil { nsID, _ := g.authService.ResolveNamespaceID(ctx, req.Namespace)
db := g.client.Database()
_, _ = db.Query(client.WithInternalAuth(ctx), "UPDATE nonces SET used_at = datetime('now') WHERE namespace_id = ? AND wallet = ? AND nonce = ?", nsID, strings.ToLower(req.Wallet), req.Nonce)
apiKey, err := g.authService.GetOrCreateAPIKey(ctx, req.Wallet, req.Namespace)
if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) writeError(w, http.StatusInternalServerError, err.Error())
return return
} }
// Check if api key exists for (namespace, wallet) via linkage table
var apiKey string
r1, err := db.Query(internalCtx, "SELECT api_keys.key FROM wallet_api_keys JOIN api_keys ON wallet_api_keys.api_key_id = api_keys.id WHERE wallet_api_keys.namespace_id = ? AND LOWER(wallet_api_keys.wallet) = LOWER(?) LIMIT 1", nsID, req.Wallet)
if err == nil && r1 != nil && r1.Count > 0 && len(r1.Rows) > 0 && len(r1.Rows[0]) > 0 {
if s, ok := r1.Rows[0][0].(string); ok {
apiKey = s
} else {
b, _ := json.Marshal(r1.Rows[0][0])
_ = json.Unmarshal(b, &apiKey)
}
}
if strings.TrimSpace(apiKey) == "" {
// Create new API key with format ak_<random>:<namespace>
buf := make([]byte, 18)
if _, err := rand.Read(buf); err != nil {
writeError(w, http.StatusInternalServerError, "failed to generate api key")
return
}
apiKey = "ak_" + base64.RawURLEncoding.EncodeToString(buf) + ":" + ns
if _, err := db.Query(internalCtx, "INSERT INTO api_keys(key, name, namespace_id) VALUES (?, ?, ?)", apiKey, "", nsID); err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
// Create linkage
// Find api_key id
rid, err := db.Query(internalCtx, "SELECT id FROM api_keys WHERE key = ? LIMIT 1", apiKey)
if err == nil && rid != nil && rid.Count > 0 && len(rid.Rows) > 0 && len(rid.Rows[0]) > 0 {
apiKeyID := rid.Rows[0][0]
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO wallet_api_keys(namespace_id, wallet, api_key_id) VALUES (?, ?, ?)", nsID, strings.ToLower(req.Wallet), apiKeyID)
}
}
// Record ownerships (best-effort)
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'api_key', ?)", nsID, apiKey)
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'wallet', ?)", nsID, req.Wallet)
writeJSON(w, http.StatusOK, map[string]any{ writeJSON(w, http.StatusOK, map[string]any{
"api_key": apiKey, "api_key": apiKey,
"namespace": ns, "namespace": req.Namespace,
"plan": func() string { "plan": func() string {
if strings.TrimSpace(req.Plan) == "" { if strings.TrimSpace(req.Plan) == "" {
return "free" return "free"
} else {
return req.Plan
} }
return req.Plan
}(), }(),
"wallet": strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X")), "wallet": strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X")),
}) })
@ -494,8 +215,8 @@ func (g *Gateway) issueAPIKeyHandler(w http.ResponseWriter, r *http.Request) {
// Requires Authorization header with API key (Bearer or ApiKey or X-API-Key header). // Requires Authorization header with API key (Bearer or ApiKey or X-API-Key header).
// Returns a JWT bound to the namespace derived from the API key record. // Returns a JWT bound to the namespace derived from the API key record.
func (g *Gateway) apiKeyToJWTHandler(w http.ResponseWriter, r *http.Request) { func (g *Gateway) apiKeyToJWTHandler(w http.ResponseWriter, r *http.Request) {
if g.client == nil { if g.authService == nil {
writeError(w, http.StatusServiceUnavailable, "client not initialized") writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
return return
} }
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
@ -507,10 +228,10 @@ func (g *Gateway) apiKeyToJWTHandler(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusUnauthorized, "missing API key") writeError(w, http.StatusUnauthorized, "missing API key")
return return
} }
// Validate and get namespace // Validate and get namespace
db := g.client.Database() db := g.client.Database()
ctx := r.Context() ctx := r.Context()
// Use internal context to bypass authentication for system operations
internalCtx := client.WithInternalAuth(ctx) internalCtx := client.WithInternalAuth(ctx)
q := "SELECT namespaces.name FROM api_keys JOIN namespaces ON api_keys.namespace_id = namespaces.id WHERE api_keys.key = ? LIMIT 1" q := "SELECT namespaces.name FROM api_keys JOIN namespaces ON api_keys.namespace_id = namespaces.id WHERE api_keys.key = ? LIMIT 1"
res, err := db.Query(internalCtx, q, key) res, err := db.Query(internalCtx, q, key)
@ -518,28 +239,18 @@ func (g *Gateway) apiKeyToJWTHandler(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusUnauthorized, "invalid API key") writeError(w, http.StatusUnauthorized, "invalid API key")
return return
} }
var ns string var ns string
if s, ok := res.Rows[0][0].(string); ok { if s, ok := res.Rows[0][0].(string); ok {
ns = s ns = s
} else {
b, _ := json.Marshal(res.Rows[0][0])
_ = json.Unmarshal(b, &ns)
} }
ns = strings.TrimSpace(ns)
if ns == "" { token, expUnix, err := g.authService.GenerateJWT(ns, key, 15*time.Minute)
writeError(w, http.StatusUnauthorized, "invalid API key")
return
}
if g.signingKey == nil {
writeError(w, http.StatusServiceUnavailable, "signing key unavailable")
return
}
// Subject is the API key string for now
token, expUnix, err := g.generateJWT(ns, key, 15*time.Minute)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) writeError(w, http.StatusInternalServerError, err.Error())
return return
} }
writeJSON(w, http.StatusOK, map[string]any{ writeJSON(w, http.StatusOK, map[string]any{
"access_token": token, "access_token": token,
"token_type": "Bearer", "token_type": "Bearer",
@ -549,8 +260,8 @@ func (g *Gateway) apiKeyToJWTHandler(w http.ResponseWriter, r *http.Request) {
} }
func (g *Gateway) registerHandler(w http.ResponseWriter, r *http.Request) { func (g *Gateway) registerHandler(w http.ResponseWriter, r *http.Request) {
if g.client == nil { if g.authService == nil {
writeError(w, http.StatusServiceUnavailable, "client not initialized") writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
return return
} }
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
@ -562,6 +273,7 @@ func (g *Gateway) registerHandler(w http.ResponseWriter, r *http.Request) {
Nonce string `json:"nonce"` Nonce string `json:"nonce"`
Signature string `json:"signature"` Signature string `json:"signature"`
Namespace string `json:"namespace"` Namespace string `json:"namespace"`
ChainType string `json:"chain_type"`
Name string `json:"name"` Name string `json:"name"`
} }
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@ -572,106 +284,45 @@ func (g *Gateway) registerHandler(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "wallet, nonce and signature are required") writeError(w, http.StatusBadRequest, "wallet, nonce and signature are required")
return return
} }
ns := strings.TrimSpace(req.Namespace)
if ns == "" {
ns = strings.TrimSpace(g.cfg.ClientNamespace)
if ns == "" {
ns = "default"
}
}
ctx := r.Context() ctx := r.Context()
// Use internal context to bypass authentication for system operations verified, err := g.authService.VerifySignature(ctx, req.Wallet, req.Nonce, req.Signature, req.ChainType)
internalCtx := client.WithInternalAuth(ctx) if err != nil || !verified {
writeError(w, http.StatusUnauthorized, "signature verification failed")
return
}
// Mark nonce used
nsID, _ := g.authService.ResolveNamespaceID(ctx, req.Namespace)
db := g.client.Database() db := g.client.Database()
nsID, err := g.resolveNamespaceID(ctx, ns) _, _ = db.Query(client.WithInternalAuth(ctx), "UPDATE nonces SET used_at = datetime('now') WHERE namespace_id = ? AND wallet = ? AND nonce = ?", nsID, strings.ToLower(req.Wallet), req.Nonce)
// In a real app we'd derive the public key from the signature, but for simplicity here
// we just use a placeholder or expect it in the request if needed.
// For Ethereum, we can recover it.
publicKey := "recovered-pk"
appID, err := g.authService.RegisterApp(ctx, req.Wallet, req.Namespace, req.Name, publicKey)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) writeError(w, http.StatusInternalServerError, err.Error())
return return
} }
// Validate nonce
q := "SELECT id FROM nonces WHERE namespace_id = ? AND wallet = ? AND nonce = ? AND used_at IS NULL AND (expires_at IS NULL OR expires_at > datetime('now')) LIMIT 1"
nres, err := db.Query(internalCtx, q, nsID, req.Wallet, req.Nonce)
if err != nil || nres == nil || nres.Count == 0 || len(nres.Rows) == 0 || len(nres.Rows[0]) == 0 {
writeError(w, http.StatusBadRequest, "invalid or expired nonce")
return
}
nonceID := nres.Rows[0][0]
// EVM personal_sign verification of the nonce
msg := []byte(req.Nonce)
prefix := []byte("\x19Ethereum Signed Message:\n" + strconv.Itoa(len(msg)))
hash := ethcrypto.Keccak256(prefix, msg)
// Decode signature (expects 65-byte r||s||v, hex with optional 0x)
sigHex := strings.TrimSpace(req.Signature)
if strings.HasPrefix(sigHex, "0x") || strings.HasPrefix(sigHex, "0X") {
sigHex = sigHex[2:]
}
sig, err := hex.DecodeString(sigHex)
if err != nil || len(sig) != 65 {
writeError(w, http.StatusBadRequest, "invalid signature format")
return
}
// Normalize V to 0/1 as expected by geth
if sig[64] >= 27 {
sig[64] -= 27
}
pub, err := ethcrypto.SigToPub(hash, sig)
if err != nil {
writeError(w, http.StatusUnauthorized, "signature recovery failed")
return
}
addr := ethcrypto.PubkeyToAddress(*pub).Hex()
want := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X"))
got := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(addr, "0x"), "0X"))
if got != want {
writeError(w, http.StatusUnauthorized, "signature does not match wallet")
return
}
// Mark nonce used now (after successful verification)
if _, err := db.Query(internalCtx, "UPDATE nonces SET used_at = datetime('now') WHERE id = ?", nonceID); err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
// Derive public key (uncompressed) hex
pubBytes := ethcrypto.FromECDSAPub(pub)
pubHex := "0x" + hex.EncodeToString(pubBytes)
// Generate client app_id
buf := make([]byte, 12)
if _, err := rand.Read(buf); err != nil {
writeError(w, http.StatusInternalServerError, "failed to generate app id")
return
}
appID := "app_" + base64.RawURLEncoding.EncodeToString(buf)
// Persist app
if _, err := db.Query(internalCtx, "INSERT INTO apps(namespace_id, app_id, name, public_key) VALUES (?, ?, ?, ?)", nsID, appID, req.Name, pubHex); err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
// Record namespace ownership by wallet (best-effort)
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, ?, ?)", nsID, "wallet", req.Wallet)
writeJSON(w, http.StatusCreated, map[string]any{ writeJSON(w, http.StatusCreated, map[string]any{
"client_id": appID, "client_id": appID,
"app": map[string]any{ "app": map[string]any{
"app_id": appID, "app_id": appID,
"name": req.Name, "name": req.Name,
"public_key": pubHex, "namespace": req.Namespace,
"namespace": ns, "wallet": strings.ToLower(req.Wallet),
"wallet": strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X")),
}, },
"signature_verified": true, "signature_verified": true,
}) })
} }
func (g *Gateway) refreshHandler(w http.ResponseWriter, r *http.Request) { func (g *Gateway) refreshHandler(w http.ResponseWriter, r *http.Request) {
if g.client == nil { if g.authService == nil {
writeError(w, http.StatusServiceUnavailable, "client not initialized") writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
return return
} }
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
@ -690,54 +341,20 @@ func (g *Gateway) refreshHandler(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "refresh_token is required") writeError(w, http.StatusBadRequest, "refresh_token is required")
return return
} }
ns := strings.TrimSpace(req.Namespace)
if ns == "" { token, subject, expUnix, err := g.authService.RefreshToken(r.Context(), req.RefreshToken, req.Namespace)
ns = strings.TrimSpace(g.cfg.ClientNamespace)
if ns == "" {
ns = "default"
}
}
ctx := r.Context()
// Use internal context to bypass authentication for system operations
internalCtx := client.WithInternalAuth(ctx)
db := g.client.Database()
nsID, err := g.resolveNamespaceID(ctx, ns)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) writeError(w, http.StatusUnauthorized, err.Error())
return
}
q := "SELECT subject FROM refresh_tokens WHERE namespace_id = ? AND token = ? AND revoked_at IS NULL AND (expires_at IS NULL OR expires_at > datetime('now')) LIMIT 1"
rres, err := db.Query(internalCtx, q, nsID, req.RefreshToken)
if err != nil || rres == nil || rres.Count == 0 {
writeError(w, http.StatusUnauthorized, "invalid or expired refresh token")
return
}
subject := ""
if len(rres.Rows) > 0 && len(rres.Rows[0]) > 0 {
if s, ok := rres.Rows[0][0].(string); ok {
subject = s
} else {
// fallback: format via json
b, _ := json.Marshal(rres.Rows[0][0])
_ = json.Unmarshal(b, &subject)
}
}
if g.signingKey == nil {
writeError(w, http.StatusServiceUnavailable, "signing key unavailable")
return
}
token, expUnix, err := g.generateJWT(ns, subject, 15*time.Minute)
if err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return return
} }
writeJSON(w, http.StatusOK, map[string]any{ writeJSON(w, http.StatusOK, map[string]any{
"access_token": token, "access_token": token,
"token_type": "Bearer", "token_type": "Bearer",
"expires_in": int(expUnix - time.Now().Unix()), "expires_in": int(expUnix - time.Now().Unix()),
"refresh_token": req.RefreshToken, "refresh_token": req.RefreshToken,
"subject": subject, "subject": subject,
"namespace": ns, "namespace": req.Namespace,
}) })
} }
@ -1064,8 +681,8 @@ func (g *Gateway) loginPageHandler(w http.ResponseWriter, r *http.Request) {
// be revoked. If all=true is provided (and the request is authenticated via JWT), // be revoked. If all=true is provided (and the request is authenticated via JWT),
// all tokens for the JWT subject within the namespace are revoked. // all tokens for the JWT subject within the namespace are revoked.
func (g *Gateway) logoutHandler(w http.ResponseWriter, r *http.Request) { func (g *Gateway) logoutHandler(w http.ResponseWriter, r *http.Request) {
if g.client == nil { if g.authService == nil {
writeError(w, http.StatusServiceUnavailable, "client not initialized") writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
return return
} }
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
@ -1081,38 +698,12 @@ func (g *Gateway) logoutHandler(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "invalid json body") writeError(w, http.StatusBadRequest, "invalid json body")
return return
} }
ns := strings.TrimSpace(req.Namespace)
if ns == "" {
ns = strings.TrimSpace(g.cfg.ClientNamespace)
if ns == "" {
ns = "default"
}
}
ctx := r.Context() ctx := r.Context()
// Use internal context to bypass authentication for system operations var subject string
internalCtx := client.WithInternalAuth(ctx)
db := g.client.Database()
nsID, err := g.resolveNamespaceID(ctx, ns)
if err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
if strings.TrimSpace(req.RefreshToken) != "" {
// Revoke specific token
if _, err := db.Query(internalCtx, "UPDATE refresh_tokens SET revoked_at = datetime('now') WHERE namespace_id = ? AND token = ? AND revoked_at IS NULL", nsID, req.RefreshToken); err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]any{"status": "ok", "revoked": 1})
return
}
if req.All { if req.All {
// Require JWT to identify subject
var subject string
if v := ctx.Value(ctxKeyJWT); v != nil { if v := ctx.Value(ctxKeyJWT); v != nil {
if claims, ok := v.(*jwtClaims); ok && claims != nil { if claims, ok := v.(*auth.JWTClaims); ok && claims != nil {
subject = strings.TrimSpace(claims.Sub) subject = strings.TrimSpace(claims.Sub)
} }
} }
@ -1120,23 +711,19 @@ func (g *Gateway) logoutHandler(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusUnauthorized, "jwt required for all=true") writeError(w, http.StatusUnauthorized, "jwt required for all=true")
return return
} }
if _, err := db.Query(internalCtx, "UPDATE refresh_tokens SET revoked_at = datetime('now') WHERE namespace_id = ? AND subject = ? AND revoked_at IS NULL", nsID, subject); err != nil { }
writeError(w, http.StatusInternalServerError, err.Error())
return if err := g.authService.RevokeToken(ctx, req.Namespace, req.RefreshToken, req.All, subject); err != nil {
} writeError(w, http.StatusInternalServerError, err.Error())
writeJSON(w, http.StatusOK, map[string]any{"status": "ok", "revoked": "all"})
return return
} }
writeError(w, http.StatusBadRequest, "nothing to revoke: provide refresh_token or all=true") writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
} }
// simpleAPIKeyHandler creates an API key directly from a wallet address without signature verification
// This is a simplified flow for development/testing
// Requires: POST { wallet, namespace }
func (g *Gateway) simpleAPIKeyHandler(w http.ResponseWriter, r *http.Request) { func (g *Gateway) simpleAPIKeyHandler(w http.ResponseWriter, r *http.Request) {
if g.client == nil { if g.authService == nil {
writeError(w, http.StatusServiceUnavailable, "client not initialized") writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
return return
} }
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
@ -1159,114 +746,16 @@ func (g *Gateway) simpleAPIKeyHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
ns := strings.TrimSpace(req.Namespace) apiKey, err := g.authService.GetOrCreateAPIKey(r.Context(), req.Wallet, req.Namespace)
if ns == "" { if err != nil {
ns = strings.TrimSpace(g.cfg.ClientNamespace)
if ns == "" {
ns = "default"
}
}
ctx := r.Context()
internalCtx := client.WithInternalAuth(ctx)
db := g.client.Database()
// Resolve or create namespace
if _, err := db.Query(internalCtx, "INSERT OR IGNORE INTO namespaces(name) VALUES (?)", ns); err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) writeError(w, http.StatusInternalServerError, err.Error())
return return
} }
nres, err := db.Query(internalCtx, "SELECT id FROM namespaces WHERE name = ? LIMIT 1", ns)
if err != nil || nres == nil || nres.Count == 0 || len(nres.Rows) == 0 || len(nres.Rows[0]) == 0 {
writeError(w, http.StatusInternalServerError, "failed to resolve namespace")
return
}
nsID := nres.Rows[0][0]
// Check if api key already exists for (namespace, wallet)
var apiKey string
r1, err := db.Query(internalCtx,
"SELECT api_keys.key FROM wallet_api_keys JOIN api_keys ON wallet_api_keys.api_key_id = api_keys.id WHERE wallet_api_keys.namespace_id = ? AND LOWER(wallet_api_keys.wallet) = LOWER(?) LIMIT 1",
nsID, req.Wallet,
)
if err == nil && r1 != nil && r1.Count > 0 && len(r1.Rows) > 0 && len(r1.Rows[0]) > 0 {
if s, ok := r1.Rows[0][0].(string); ok {
apiKey = s
} else {
b, _ := json.Marshal(r1.Rows[0][0])
_ = json.Unmarshal(b, &apiKey)
}
}
// If no existing key, create a new one
if strings.TrimSpace(apiKey) == "" {
buf := make([]byte, 18)
if _, err := rand.Read(buf); err != nil {
writeError(w, http.StatusInternalServerError, "failed to generate api key")
return
}
apiKey = "ak_" + base64.RawURLEncoding.EncodeToString(buf) + ":" + ns
if _, err := db.Query(internalCtx, "INSERT INTO api_keys(key, name, namespace_id) VALUES (?, ?, ?)", apiKey, "", nsID); err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
// Link wallet to api key
rid, err := db.Query(internalCtx, "SELECT id FROM api_keys WHERE key = ? LIMIT 1", apiKey)
if err == nil && rid != nil && rid.Count > 0 && len(rid.Rows) > 0 && len(rid.Rows[0]) > 0 {
apiKeyID := rid.Rows[0][0]
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO wallet_api_keys(namespace_id, wallet, api_key_id) VALUES (?, ?, ?)", nsID, strings.ToLower(req.Wallet), apiKeyID)
}
}
// Record ownerships (best-effort)
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'api_key', ?)", nsID, apiKey)
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'wallet', ?)", nsID, req.Wallet)
writeJSON(w, http.StatusOK, map[string]any{ writeJSON(w, http.StatusOK, map[string]any{
"api_key": apiKey, "api_key": apiKey,
"namespace": ns, "namespace": req.Namespace,
"wallet": strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X")), "wallet": strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X")),
"created": time.Now().Format(time.RFC3339), "created": time.Now().Format(time.RFC3339),
}) })
} }
// base58Decode decodes a base58-encoded string (Bitcoin alphabet)
// Used for decoding Solana public keys (base58-encoded 32-byte ed25519 public keys)
func base58Decode(encoded string) ([]byte, error) {
const alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
// Build reverse lookup map
lookup := make(map[rune]int)
for i, c := range alphabet {
lookup[c] = i
}
// Convert to big integer
num := big.NewInt(0)
base := big.NewInt(58)
for _, c := range encoded {
val, ok := lookup[c]
if !ok {
return nil, fmt.Errorf("invalid base58 character: %c", c)
}
num.Mul(num, base)
num.Add(num, big.NewInt(int64(val)))
}
// Convert to bytes
decoded := num.Bytes()
// Add leading zeros for each leading '1' in the input
for _, c := range encoded {
if c != '1' {
break
}
decoded = append([]byte{0}, decoded...)
}
return decoded, nil
}

View File

@ -4,23 +4,28 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509"
"database/sql" "database/sql"
"encoding/pem"
"fmt" "fmt"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/DeBrosOfficial/network/pkg/client" "github.com/DeBrosOfficial/network/pkg/client"
"github.com/DeBrosOfficial/network/pkg/config" "github.com/DeBrosOfficial/network/pkg/config"
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
"github.com/DeBrosOfficial/network/pkg/ipfs" "github.com/DeBrosOfficial/network/pkg/ipfs"
"github.com/DeBrosOfficial/network/pkg/logging" "github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/olric" "github.com/DeBrosOfficial/network/pkg/olric"
"github.com/DeBrosOfficial/network/pkg/pubsub"
"github.com/DeBrosOfficial/network/pkg/rqlite" "github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/serverless"
"github.com/multiformats/go-multiaddr" "github.com/multiformats/go-multiaddr"
olriclib "github.com/olric-data/olric"
"go.uber.org/zap" "go.uber.org/zap"
_ "github.com/rqlite/gorqlite/stdlib" _ "github.com/rqlite/gorqlite/stdlib"
@ -61,13 +66,11 @@ type Config struct {
} }
type Gateway struct { type Gateway struct {
logger *logging.ColoredLogger logger *logging.ColoredLogger
cfg *Config cfg *Config
client client.NetworkClient client client.NetworkClient
nodePeerID string // The node's actual peer ID from its identity file (overrides client's peer ID) nodePeerID string // The node's actual peer ID from its identity file (overrides client's peer ID)
startedAt time.Time startedAt time.Time
signingKey *rsa.PrivateKey
keyID string
// rqlite SQL connection and HTTP ORM gateway // rqlite SQL connection and HTTP ORM gateway
sqlDB *sql.DB sqlDB *sql.DB
@ -83,7 +86,19 @@ type Gateway struct {
// Local pub/sub bypass for same-gateway subscribers // Local pub/sub bypass for same-gateway subscribers
localSubscribers map[string][]*localSubscriber // topic+namespace -> subscribers localSubscribers map[string][]*localSubscriber // topic+namespace -> subscribers
presenceMembers map[string][]PresenceMember // topicKey -> members
mu sync.RWMutex mu sync.RWMutex
presenceMu sync.RWMutex
// Serverless function engine
serverlessEngine *serverless.Engine
serverlessRegistry *serverless.Registry
serverlessInvoker *serverless.Invoker
serverlessWSMgr *serverless.WSManager
serverlessHandlers *ServerlessHandlers
// Authentication service
authService *auth.Service
} }
// localSubscriber represents a WebSocket subscriber for local message delivery // localSubscriber represents a WebSocket subscriber for local message delivery
@ -92,6 +107,14 @@ type localSubscriber struct {
namespace string namespace string
} }
// PresenceMember represents a member in a topic's presence list
type PresenceMember struct {
MemberID string `json:"member_id"`
JoinedAt int64 `json:"joined_at"` // Unix timestamp
Meta map[string]interface{} `json:"meta,omitempty"`
ConnID string `json:"-"` // Internal: for tracking which connection
}
// New creates and initializes a new Gateway instance // New creates and initializes a new Gateway instance
func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) { func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
logger.ComponentInfo(logging.ComponentGeneral, "Building client config...") logger.ComponentInfo(logging.ComponentGeneral, "Building client config...")
@ -128,16 +151,7 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
nodePeerID: cfg.NodePeerID, nodePeerID: cfg.NodePeerID,
startedAt: time.Now(), startedAt: time.Now(),
localSubscribers: make(map[string][]*localSubscriber), localSubscribers: make(map[string][]*localSubscriber),
} presenceMembers: make(map[string][]PresenceMember),
logger.ComponentInfo(logging.ComponentGeneral, "Generating RSA signing key...")
// Generate local RSA signing key for JWKS/JWT (ephemeral for now)
if key, err := rsa.GenerateKey(rand.Reader, 2048); err == nil {
gw.signingKey = key
gw.keyID = "gw-" + strconv.FormatInt(time.Now().Unix(), 10)
logger.ComponentInfo(logging.ComponentGeneral, "RSA key generated successfully")
} else {
logger.ComponentWarn(logging.ComponentGeneral, "failed to generate RSA key; jwks will be empty", zap.Error(err))
} }
logger.ComponentInfo(logging.ComponentGeneral, "Initializing RQLite ORM HTTP gateway...") logger.ComponentInfo(logging.ComponentGeneral, "Initializing RQLite ORM HTTP gateway...")
@ -298,6 +312,104 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
gw.cfg.IPFSReplicationFactor = ipfsReplicationFactor gw.cfg.IPFSReplicationFactor = ipfsReplicationFactor
gw.cfg.IPFSEnableEncryption = ipfsEnableEncryption gw.cfg.IPFSEnableEncryption = ipfsEnableEncryption
// Initialize serverless function engine
logger.ComponentInfo(logging.ComponentGeneral, "Initializing serverless function engine...")
if gw.ormClient != nil && gw.ipfsClient != nil {
// Create serverless registry (stores functions in RQLite + IPFS)
registryCfg := serverless.RegistryConfig{
IPFSAPIURL: ipfsAPIURL,
}
registry := serverless.NewRegistry(gw.ormClient, gw.ipfsClient, registryCfg, logger.Logger)
gw.serverlessRegistry = registry
// Create WebSocket manager for function streaming
gw.serverlessWSMgr = serverless.NewWSManager(logger.Logger)
// Get underlying Olric client if available
var olricClient olriclib.Client
if oc := gw.getOlricClient(); oc != nil {
olricClient = oc.UnderlyingClient()
}
// Create host functions provider (allows functions to call Orama services)
// Get pubsub adapter from client for serverless functions
var pubsubAdapter *pubsub.ClientAdapter
if gw.client != nil {
if concreteClient, ok := gw.client.(*client.Client); ok {
pubsubAdapter = concreteClient.PubSubAdapter()
if pubsubAdapter != nil {
logger.ComponentInfo(logging.ComponentGeneral, "pubsub adapter available for serverless functions")
} else {
logger.ComponentWarn(logging.ComponentGeneral, "pubsub adapter is nil - serverless pubsub will be unavailable")
}
}
}
hostFuncsCfg := serverless.HostFunctionsConfig{
IPFSAPIURL: ipfsAPIURL,
HTTPTimeout: 30 * time.Second,
}
hostFuncs := serverless.NewHostFunctions(
gw.ormClient,
olricClient,
gw.ipfsClient,
pubsubAdapter, // pubsub adapter for serverless functions
gw.serverlessWSMgr,
nil, // secrets manager - TODO: implement
hostFuncsCfg,
logger.Logger,
)
// Create WASM engine configuration
engineCfg := serverless.DefaultConfig()
engineCfg.DefaultMemoryLimitMB = 128
engineCfg.MaxMemoryLimitMB = 256
engineCfg.DefaultTimeoutSeconds = 30
engineCfg.MaxTimeoutSeconds = 60
engineCfg.ModuleCacheSize = 100
// Create WASM engine
engine, engineErr := serverless.NewEngine(engineCfg, registry, hostFuncs, logger.Logger, serverless.WithInvocationLogger(registry))
if engineErr != nil {
logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize serverless engine; functions disabled", zap.Error(engineErr))
} else {
gw.serverlessEngine = engine
// Create invoker
gw.serverlessInvoker = serverless.NewInvoker(engine, registry, hostFuncs, logger.Logger)
// Create HTTP handlers
gw.serverlessHandlers = NewServerlessHandlers(
gw.serverlessInvoker,
registry,
gw.serverlessWSMgr,
logger.Logger,
)
// Initialize auth service
// For now using ephemeral key, can be loaded from config later
key, _ := rsa.GenerateKey(rand.Reader, 2048)
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
authService, err := auth.NewService(logger, c, string(keyPEM), cfg.ClientNamespace)
if err != nil {
logger.ComponentError(logging.ComponentGeneral, "failed to initialize auth service", zap.Error(err))
} else {
gw.authService = authService
}
logger.ComponentInfo(logging.ComponentGeneral, "Serverless function engine ready",
zap.Int("default_memory_mb", engineCfg.DefaultMemoryLimitMB),
zap.Int("default_timeout_sec", engineCfg.DefaultTimeoutSeconds),
zap.Int("module_cache_size", engineCfg.ModuleCacheSize),
)
}
} else {
logger.ComponentWarn(logging.ComponentGeneral, "serverless engine requires RQLite and IPFS; functions disabled")
}
logger.ComponentInfo(logging.ComponentGeneral, "Gateway creation completed, returning...") logger.ComponentInfo(logging.ComponentGeneral, "Gateway creation completed, returning...")
return gw, nil return gw, nil
} }
@ -309,6 +421,14 @@ func (g *Gateway) withInternalAuth(ctx context.Context) context.Context {
// Close disconnects the gateway client // Close disconnects the gateway client
func (g *Gateway) Close() { func (g *Gateway) Close() {
// Close serverless engine first
if g.serverlessEngine != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
if err := g.serverlessEngine.Close(ctx); err != nil {
g.logger.ComponentWarn(logging.ComponentGeneral, "error during serverless engine close", zap.Error(err))
}
cancel()
}
if g.client != nil { if g.client != nil {
if err := g.client.Disconnect(); err != nil { if err := g.client.Disconnect(); err != nil {
g.logger.ComponentWarn(logging.ComponentClient, "error during client disconnect", zap.Error(err)) g.logger.ComponentWarn(logging.ComponentClient, "error during client disconnect", zap.Error(err))

View File

@ -3,22 +3,32 @@ package gateway
import ( import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509"
"encoding/pem"
"testing" "testing"
"time" "time"
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
) )
func TestJWTGenerateAndParse(t *testing.T) { func TestJWTGenerateAndParse(t *testing.T) {
gw := &Gateway{}
key, _ := rsa.GenerateKey(rand.Reader, 2048) key, _ := rsa.GenerateKey(rand.Reader, 2048)
gw.signingKey = key keyPEM := pem.EncodeToMemory(&pem.Block{
gw.keyID = "kid" Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
tok, exp, err := gw.generateJWT("ns1", "subj", time.Minute) svc, err := auth.NewService(nil, nil, string(keyPEM), "default")
if err != nil {
t.Fatalf("failed to create service: %v", err)
}
tok, exp, err := svc.GenerateJWT("ns1", "subj", time.Minute)
if err != nil || exp <= 0 { if err != nil || exp <= 0 {
t.Fatalf("gen err=%v exp=%d", err, exp) t.Fatalf("gen err=%v exp=%d", err, exp)
} }
claims, err := gw.parseAndVerifyJWT(tok) claims, err := svc.ParseAndVerifyJWT(tok)
if err != nil { if err != nil {
t.Fatalf("verify err: %v", err) t.Fatalf("verify err: %v", err)
} }
@ -28,17 +38,23 @@ func TestJWTGenerateAndParse(t *testing.T) {
} }
func TestJWTExpired(t *testing.T) { func TestJWTExpired(t *testing.T) {
gw := &Gateway{}
key, _ := rsa.GenerateKey(rand.Reader, 2048) key, _ := rsa.GenerateKey(rand.Reader, 2048)
gw.signingKey = key keyPEM := pem.EncodeToMemory(&pem.Block{
gw.keyID = "kid" Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
svc, err := auth.NewService(nil, nil, string(keyPEM), "default")
if err != nil {
t.Fatalf("failed to create service: %v", err)
}
// Use sufficiently negative TTL to bypass allowed clock skew // Use sufficiently negative TTL to bypass allowed clock skew
tok, _, err := gw.generateJWT("ns1", "subj", -2*time.Minute) tok, _, err := svc.GenerateJWT("ns1", "subj", -2*time.Minute)
if err != nil { if err != nil {
t.Fatalf("gen err=%v", err) t.Fatalf("gen err=%v", err)
} }
if _, err := gw.parseAndVerifyJWT(tok); err == nil { if _, err := svc.ParseAndVerifyJWT(tok); err == nil {
t.Fatalf("expected expired error") t.Fatalf("expected expired error")
} }
} }

View File

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/DeBrosOfficial/network/pkg/client" "github.com/DeBrosOfficial/network/pkg/client"
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
"github.com/DeBrosOfficial/network/pkg/logging" "github.com/DeBrosOfficial/network/pkg/logging"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -62,11 +63,8 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
// Allow public endpoints without auth
if isPublicPath(r.URL.Path) { isPublic := isPublicPath(r.URL.Path)
next.ServeHTTP(w, r)
return
}
// 1) Try JWT Bearer first if Authorization looks like one // 1) Try JWT Bearer first if Authorization looks like one
if auth := r.Header.Get("Authorization"); auth != "" { if auth := r.Header.Get("Authorization"); auth != "" {
@ -74,7 +72,7 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler {
if strings.HasPrefix(lower, "bearer ") { if strings.HasPrefix(lower, "bearer ") {
tok := strings.TrimSpace(auth[len("Bearer "):]) tok := strings.TrimSpace(auth[len("Bearer "):])
if strings.Count(tok, ".") == 2 { if strings.Count(tok, ".") == 2 {
if claims, err := g.parseAndVerifyJWT(tok); err == nil { if claims, err := g.authService.ParseAndVerifyJWT(tok); err == nil {
// Attach JWT claims and namespace to context // Attach JWT claims and namespace to context
ctx := context.WithValue(r.Context(), ctxKeyJWT, claims) ctx := context.WithValue(r.Context(), ctxKeyJWT, claims)
if ns := strings.TrimSpace(claims.Namespace); ns != "" { if ns := strings.TrimSpace(claims.Namespace); ns != "" {
@ -91,6 +89,10 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler {
// 2) Fallback to API key (validate against DB) // 2) Fallback to API key (validate against DB)
key := extractAPIKey(r) key := extractAPIKey(r)
if key == "" { if key == "" {
if isPublic {
next.ServeHTTP(w, r)
return
}
w.Header().Set("WWW-Authenticate", "Bearer realm=\"gateway\", charset=\"UTF-8\"") w.Header().Set("WWW-Authenticate", "Bearer realm=\"gateway\", charset=\"UTF-8\"")
writeError(w, http.StatusUnauthorized, "missing API key") writeError(w, http.StatusUnauthorized, "missing API key")
return return
@ -104,6 +106,10 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler {
q := "SELECT namespaces.name FROM api_keys JOIN namespaces ON api_keys.namespace_id = namespaces.id WHERE api_keys.key = ? LIMIT 1" q := "SELECT namespaces.name FROM api_keys JOIN namespaces ON api_keys.namespace_id = namespaces.id WHERE api_keys.key = ? LIMIT 1"
res, err := db.Query(internalCtx, q, key) res, err := db.Query(internalCtx, q, key)
if err != nil || res == nil || res.Count == 0 || len(res.Rows) == 0 || len(res.Rows[0]) == 0 { if err != nil || res == nil || res.Count == 0 || len(res.Rows) == 0 || len(res.Rows[0]) == 0 {
if isPublic {
next.ServeHTTP(w, r)
return
}
w.Header().Set("WWW-Authenticate", "Bearer error=\"invalid_token\"") w.Header().Set("WWW-Authenticate", "Bearer error=\"invalid_token\"")
writeError(w, http.StatusUnauthorized, "invalid API key") writeError(w, http.StatusUnauthorized, "invalid API key")
return return
@ -118,6 +124,10 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler {
ns = strings.TrimSpace(ns) ns = strings.TrimSpace(ns)
} }
if ns == "" { if ns == "" {
if isPublic {
next.ServeHTTP(w, r)
return
}
w.Header().Set("WWW-Authenticate", "Bearer error=\"invalid_token\"") w.Header().Set("WWW-Authenticate", "Bearer error=\"invalid_token\"")
writeError(w, http.StatusUnauthorized, "invalid API key") writeError(w, http.StatusUnauthorized, "invalid API key")
return return
@ -183,6 +193,11 @@ func isPublicPath(p string) bool {
return true return true
} }
// Serverless invocation is public (authorization is handled within the invoker)
if strings.HasPrefix(p, "/v1/invoke/") || (strings.HasPrefix(p, "/v1/functions/") && strings.HasSuffix(p, "/invoke")) {
return true
}
switch p { switch p {
case "/health", "/v1/health", "/status", "/v1/status", "/v1/auth/jwks", "/.well-known/jwks.json", "/v1/version", "/v1/auth/login", "/v1/auth/challenge", "/v1/auth/verify", "/v1/auth/register", "/v1/auth/refresh", "/v1/auth/logout", "/v1/auth/api-key", "/v1/auth/simple-key", "/v1/network/status", "/v1/network/peers": case "/health", "/v1/health", "/status", "/v1/status", "/v1/auth/jwks", "/.well-known/jwks.json", "/v1/version", "/v1/auth/login", "/v1/auth/challenge", "/v1/auth/verify", "/v1/auth/register", "/v1/auth/refresh", "/v1/auth/logout", "/v1/auth/api-key", "/v1/auth/simple-key", "/v1/network/status", "/v1/network/peers":
return true return true
@ -235,7 +250,7 @@ func (g *Gateway) authorizationMiddleware(next http.Handler) http.Handler {
apiKeyFallback := "" apiKeyFallback := ""
if v := ctx.Value(ctxKeyJWT); v != nil { if v := ctx.Value(ctxKeyJWT); v != nil {
if claims, ok := v.(*jwtClaims); ok && claims != nil && strings.TrimSpace(claims.Sub) != "" { if claims, ok := v.(*auth.JWTClaims); ok && claims != nil && strings.TrimSpace(claims.Sub) != "" {
// Determine subject type. // Determine subject type.
// If subject looks like an API key (e.g., ak_<random>:<namespace>), // If subject looks like an API key (e.g., ak_<random>:<namespace>),
// treat it as an API key owner; otherwise assume a wallet subject. // treat it as an API key owner; otherwise assume a wallet subject.
@ -324,6 +339,9 @@ func requiresNamespaceOwnership(p string) bool {
if strings.HasPrefix(p, "/v1/proxy/") { if strings.HasPrefix(p, "/v1/proxy/") {
return true return true
} }
if strings.HasPrefix(p, "/v1/functions") {
return true
}
return false return false
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/DeBrosOfficial/network/pkg/client" "github.com/DeBrosOfficial/network/pkg/client"
"github.com/DeBrosOfficial/network/pkg/pubsub" "github.com/DeBrosOfficial/network/pkg/pubsub"
"github.com/google/uuid"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -51,6 +52,22 @@ func (g *Gateway) pubsubWebsocketHandler(w http.ResponseWriter, r *http.Request)
writeError(w, http.StatusBadRequest, "missing 'topic'") writeError(w, http.StatusBadRequest, "missing 'topic'")
return return
} }
// Presence handling
enablePresence := r.URL.Query().Get("presence") == "true"
memberID := r.URL.Query().Get("member_id")
memberMetaStr := r.URL.Query().Get("member_meta")
var memberMeta map[string]interface{}
if memberMetaStr != "" {
_ = json.Unmarshal([]byte(memberMetaStr), &memberMeta)
}
if enablePresence && memberID == "" {
g.logger.ComponentWarn("gateway", "pubsub ws: presence enabled but missing member_id")
writeError(w, http.StatusBadRequest, "missing 'member_id' for presence")
return
}
conn, err := wsUpgrader.Upgrade(w, r, nil) conn, err := wsUpgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
g.logger.ComponentWarn("gateway", "pubsub ws: upgrade failed") g.logger.ComponentWarn("gateway", "pubsub ws: upgrade failed")
@ -73,6 +90,36 @@ func (g *Gateway) pubsubWebsocketHandler(w http.ResponseWriter, r *http.Request)
subscriberCount := len(g.localSubscribers[topicKey]) subscriberCount := len(g.localSubscribers[topicKey])
g.mu.Unlock() g.mu.Unlock()
connID := uuid.New().String()
if enablePresence {
member := PresenceMember{
MemberID: memberID,
JoinedAt: time.Now().Unix(),
Meta: memberMeta,
ConnID: connID,
}
g.presenceMu.Lock()
g.presenceMembers[topicKey] = append(g.presenceMembers[topicKey], member)
g.presenceMu.Unlock()
// Broadcast join event
joinEvent := map[string]interface{}{
"type": "presence.join",
"member_id": memberID,
"meta": memberMeta,
"timestamp": member.JoinedAt,
}
eventData, _ := json.Marshal(joinEvent)
// Use a background context for the broadcast to ensure it finishes even if the connection closes immediately
broadcastCtx := pubsub.WithNamespace(client.WithInternalAuth(context.Background()), ns)
_ = g.client.PubSub().Publish(broadcastCtx, topic, eventData)
g.logger.ComponentInfo("gateway", "pubsub ws: member joined presence",
zap.String("topic", topic),
zap.String("member_id", memberID))
}
g.logger.ComponentInfo("gateway", "pubsub ws: registered local subscriber", g.logger.ComponentInfo("gateway", "pubsub ws: registered local subscriber",
zap.String("topic", topic), zap.String("topic", topic),
zap.String("namespace", ns), zap.String("namespace", ns),
@ -93,6 +140,36 @@ func (g *Gateway) pubsubWebsocketHandler(w http.ResponseWriter, r *http.Request)
delete(g.localSubscribers, topicKey) delete(g.localSubscribers, topicKey)
} }
g.mu.Unlock() g.mu.Unlock()
if enablePresence {
g.presenceMu.Lock()
members := g.presenceMembers[topicKey]
for i, m := range members {
if m.ConnID == connID {
g.presenceMembers[topicKey] = append(members[:i], members[i+1:]...)
break
}
}
if len(g.presenceMembers[topicKey]) == 0 {
delete(g.presenceMembers, topicKey)
}
g.presenceMu.Unlock()
// Broadcast leave event
leaveEvent := map[string]interface{}{
"type": "presence.leave",
"member_id": memberID,
"timestamp": time.Now().Unix(),
}
eventData, _ := json.Marshal(leaveEvent)
broadcastCtx := pubsub.WithNamespace(client.WithInternalAuth(context.Background()), ns)
_ = g.client.PubSub().Publish(broadcastCtx, topic, eventData)
g.logger.ComponentInfo("gateway", "pubsub ws: member left presence",
zap.String("topic", topic),
zap.String("member_id", memberID))
}
g.logger.ComponentInfo("gateway", "pubsub ws: unregistered local subscriber", g.logger.ComponentInfo("gateway", "pubsub ws: unregistered local subscriber",
zap.String("topic", topic), zap.String("topic", topic),
zap.Int("remaining_subscribers", remainingCount)) zap.Int("remaining_subscribers", remainingCount))
@ -349,3 +426,44 @@ func namespacePrefix(ns string) string {
func namespacedTopic(ns, topic string) string { func namespacedTopic(ns, topic string) string {
return namespacePrefix(ns) + topic return namespacePrefix(ns) + topic
} }
// pubsubPresenceHandler handles GET /v1/pubsub/presence?topic=mytopic
func (g *Gateway) pubsubPresenceHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
ns := resolveNamespaceFromRequest(r)
if ns == "" {
writeError(w, http.StatusForbidden, "namespace not resolved")
return
}
topic := r.URL.Query().Get("topic")
if topic == "" {
writeError(w, http.StatusBadRequest, "missing 'topic'")
return
}
topicKey := fmt.Sprintf("%s.%s", ns, topic)
g.presenceMu.RLock()
members, ok := g.presenceMembers[topicKey]
g.presenceMu.RUnlock()
if !ok {
writeJSON(w, http.StatusOK, map[string]any{
"topic": topic,
"members": []PresenceMember{},
"count": 0,
})
return
}
writeJSON(w, http.StatusOK, map[string]any{
"topic": topic,
"members": members,
"count": len(members),
})
}

View File

@ -14,8 +14,8 @@ func (g *Gateway) Routes() http.Handler {
mux.HandleFunc("/v1/status", g.statusHandler) mux.HandleFunc("/v1/status", g.statusHandler)
// auth endpoints // auth endpoints
mux.HandleFunc("/v1/auth/jwks", g.jwksHandler) mux.HandleFunc("/v1/auth/jwks", g.authService.JWKSHandler)
mux.HandleFunc("/.well-known/jwks.json", g.jwksHandler) mux.HandleFunc("/.well-known/jwks.json", g.authService.JWKSHandler)
mux.HandleFunc("/v1/auth/login", g.loginPageHandler) mux.HandleFunc("/v1/auth/login", g.loginPageHandler)
mux.HandleFunc("/v1/auth/challenge", g.challengeHandler) mux.HandleFunc("/v1/auth/challenge", g.challengeHandler)
mux.HandleFunc("/v1/auth/verify", g.verifyHandler) mux.HandleFunc("/v1/auth/verify", g.verifyHandler)
@ -44,6 +44,7 @@ func (g *Gateway) Routes() http.Handler {
mux.HandleFunc("/v1/pubsub/ws", g.pubsubWebsocketHandler) mux.HandleFunc("/v1/pubsub/ws", g.pubsubWebsocketHandler)
mux.HandleFunc("/v1/pubsub/publish", g.pubsubPublishHandler) mux.HandleFunc("/v1/pubsub/publish", g.pubsubPublishHandler)
mux.HandleFunc("/v1/pubsub/topics", g.pubsubTopicsHandler) mux.HandleFunc("/v1/pubsub/topics", g.pubsubTopicsHandler)
mux.HandleFunc("/v1/pubsub/presence", g.pubsubPresenceHandler)
// anon proxy (authenticated users only) // anon proxy (authenticated users only)
mux.HandleFunc("/v1/proxy/anon", g.anonProxyHandler) mux.HandleFunc("/v1/proxy/anon", g.anonProxyHandler)
@ -63,5 +64,10 @@ func (g *Gateway) Routes() http.Handler {
mux.HandleFunc("/v1/storage/get/", g.storageGetHandler) mux.HandleFunc("/v1/storage/get/", g.storageGetHandler)
mux.HandleFunc("/v1/storage/unpin/", g.storageUnpinHandler) mux.HandleFunc("/v1/storage/unpin/", g.storageUnpinHandler)
// serverless functions (if enabled)
if g.serverlessHandlers != nil {
g.serverlessHandlers.RegisterRoutes(mux)
}
return g.withMiddleware(mux) return g.withMiddleware(mux)
} }

View File

@ -0,0 +1,694 @@
package gateway
import (
"context"
"encoding/json"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
"github.com/DeBrosOfficial/network/pkg/serverless"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
// ServerlessHandlers contains handlers for serverless function endpoints.
// It's a separate struct to keep the Gateway struct clean.
type ServerlessHandlers struct {
invoker *serverless.Invoker
registry serverless.FunctionRegistry
wsManager *serverless.WSManager
logger *zap.Logger
}
// NewServerlessHandlers creates a new ServerlessHandlers instance.
func NewServerlessHandlers(
invoker *serverless.Invoker,
registry serverless.FunctionRegistry,
wsManager *serverless.WSManager,
logger *zap.Logger,
) *ServerlessHandlers {
return &ServerlessHandlers{
invoker: invoker,
registry: registry,
wsManager: wsManager,
logger: logger,
}
}
// RegisterRoutes registers all serverless routes on the given mux.
func (h *ServerlessHandlers) RegisterRoutes(mux *http.ServeMux) {
// Function management
mux.HandleFunc("/v1/functions", h.handleFunctions)
mux.HandleFunc("/v1/functions/", h.handleFunctionByName)
// Direct invoke endpoint
mux.HandleFunc("/v1/invoke/", h.handleInvoke)
}
// handleFunctions handles GET /v1/functions (list) and POST /v1/functions (deploy)
func (h *ServerlessHandlers) handleFunctions(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
h.listFunctions(w, r)
case http.MethodPost:
h.deployFunction(w, r)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
}
// handleFunctionByName handles operations on a specific function
// Routes:
// - GET /v1/functions/{name} - Get function info
// - DELETE /v1/functions/{name} - Delete function
// - POST /v1/functions/{name}/invoke - Invoke function
// - GET /v1/functions/{name}/versions - List versions
// - GET /v1/functions/{name}/logs - Get logs
// - WS /v1/functions/{name}/ws - WebSocket invoke
func (h *ServerlessHandlers) handleFunctionByName(w http.ResponseWriter, r *http.Request) {
// Parse path: /v1/functions/{name}[/{action}]
path := strings.TrimPrefix(r.URL.Path, "/v1/functions/")
parts := strings.SplitN(path, "/", 2)
if len(parts) == 0 || parts[0] == "" {
http.Error(w, "Function name required", http.StatusBadRequest)
return
}
name := parts[0]
action := ""
if len(parts) > 1 {
action = parts[1]
}
// Parse version from name if present (e.g., "myfunction@2")
version := 0
if idx := strings.Index(name, "@"); idx > 0 {
vStr := name[idx+1:]
name = name[:idx]
if v, err := strconv.Atoi(vStr); err == nil {
version = v
}
}
switch action {
case "invoke":
h.invokeFunction(w, r, name, version)
case "ws":
h.handleWebSocket(w, r, name, version)
case "versions":
h.listVersions(w, r, name)
case "logs":
h.getFunctionLogs(w, r, name)
case "":
switch r.Method {
case http.MethodGet:
h.getFunctionInfo(w, r, name, version)
case http.MethodDelete:
h.deleteFunction(w, r, name, version)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
default:
http.Error(w, "Unknown action", http.StatusNotFound)
}
}
// handleInvoke handles POST /v1/invoke/{namespace}/{name}[@version]
func (h *ServerlessHandlers) handleInvoke(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Parse path: /v1/invoke/{namespace}/{name}[@version]
path := strings.TrimPrefix(r.URL.Path, "/v1/invoke/")
parts := strings.SplitN(path, "/", 2)
if len(parts) < 2 {
http.Error(w, "Path must be /v1/invoke/{namespace}/{name}", http.StatusBadRequest)
return
}
namespace := parts[0]
name := parts[1]
// Parse version if present
version := 0
if idx := strings.Index(name, "@"); idx > 0 {
vStr := name[idx+1:]
name = name[:idx]
if v, err := strconv.Atoi(vStr); err == nil {
version = v
}
}
h.invokeFunction(w, r, namespace+"/"+name, version)
}
// listFunctions handles GET /v1/functions
func (h *ServerlessHandlers) listFunctions(w http.ResponseWriter, r *http.Request) {
namespace := r.URL.Query().Get("namespace")
if namespace == "" {
// Get namespace from JWT if available
namespace = h.getNamespaceFromRequest(r)
}
if namespace == "" {
writeError(w, http.StatusBadRequest, "namespace required")
return
}
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
defer cancel()
functions, err := h.registry.List(ctx, namespace)
if err != nil {
h.logger.Error("Failed to list functions", zap.Error(err))
writeError(w, http.StatusInternalServerError, "Failed to list functions")
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"functions": functions,
"count": len(functions),
})
}
// deployFunction handles POST /v1/functions
func (h *ServerlessHandlers) deployFunction(w http.ResponseWriter, r *http.Request) {
// Parse multipart form (for WASM upload) or JSON
contentType := r.Header.Get("Content-Type")
var def serverless.FunctionDefinition
var wasmBytes []byte
if strings.HasPrefix(contentType, "multipart/form-data") {
// Parse multipart form
if err := r.ParseMultipartForm(32 << 20); err != nil { // 32MB max
writeError(w, http.StatusBadRequest, "Failed to parse form: "+err.Error())
return
}
// Get metadata from form field
metadataStr := r.FormValue("metadata")
if metadataStr != "" {
if err := json.Unmarshal([]byte(metadataStr), &def); err != nil {
writeError(w, http.StatusBadRequest, "Invalid metadata JSON: "+err.Error())
return
}
}
// Get name from form if not in metadata
if def.Name == "" {
def.Name = r.FormValue("name")
}
// Get namespace from form if not in metadata
if def.Namespace == "" {
def.Namespace = r.FormValue("namespace")
}
// Get other configuration fields from form
if v := r.FormValue("is_public"); v != "" {
def.IsPublic, _ = strconv.ParseBool(v)
}
if v := r.FormValue("memory_limit_mb"); v != "" {
def.MemoryLimitMB, _ = strconv.Atoi(v)
}
if v := r.FormValue("timeout_seconds"); v != "" {
def.TimeoutSeconds, _ = strconv.Atoi(v)
}
if v := r.FormValue("retry_count"); v != "" {
def.RetryCount, _ = strconv.Atoi(v)
}
if v := r.FormValue("retry_delay_seconds"); v != "" {
def.RetryDelaySeconds, _ = strconv.Atoi(v)
}
// Get WASM file
file, _, err := r.FormFile("wasm")
if err != nil {
writeError(w, http.StatusBadRequest, "WASM file required")
return
}
defer file.Close()
wasmBytes, err = io.ReadAll(file)
if err != nil {
writeError(w, http.StatusBadRequest, "Failed to read WASM file: "+err.Error())
return
}
} else {
// JSON body with base64-encoded WASM
var req struct {
serverless.FunctionDefinition
WASMBase64 string `json:"wasm_base64"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "Invalid JSON: "+err.Error())
return
}
def = req.FunctionDefinition
if req.WASMBase64 != "" {
// Decode base64 WASM - for now, just reject this method
writeError(w, http.StatusBadRequest, "Base64 WASM upload not supported, use multipart/form-data")
return
}
}
// Get namespace from JWT if not provided
if def.Namespace == "" {
def.Namespace = h.getNamespaceFromRequest(r)
}
if def.Name == "" {
writeError(w, http.StatusBadRequest, "Function name required")
return
}
if def.Namespace == "" {
writeError(w, http.StatusBadRequest, "Namespace required")
return
}
if len(wasmBytes) == 0 {
writeError(w, http.StatusBadRequest, "WASM bytecode required")
return
}
ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
defer cancel()
oldFn, err := h.registry.Register(ctx, &def, wasmBytes)
if err != nil {
h.logger.Error("Failed to deploy function",
zap.String("name", def.Name),
zap.Error(err),
)
writeError(w, http.StatusInternalServerError, "Failed to deploy: "+err.Error())
return
}
// Invalidate cache for the old version to ensure the new one is loaded
if oldFn != nil {
h.invoker.InvalidateCache(oldFn.WASMCID)
h.logger.Debug("Invalidated function cache",
zap.String("name", def.Name),
zap.String("old_wasm_cid", oldFn.WASMCID),
)
}
h.logger.Info("Function deployed",
zap.String("name", def.Name),
zap.String("namespace", def.Namespace),
)
// Fetch the deployed function to return
fn, err := h.registry.Get(ctx, def.Namespace, def.Name, def.Version)
if err != nil {
writeJSON(w, http.StatusCreated, map[string]interface{}{
"message": "Function deployed successfully",
"name": def.Name,
})
return
}
writeJSON(w, http.StatusCreated, map[string]interface{}{
"message": "Function deployed successfully",
"function": fn,
})
}
// getFunctionInfo handles GET /v1/functions/{name}
func (h *ServerlessHandlers) getFunctionInfo(w http.ResponseWriter, r *http.Request, name string, version int) {
namespace := r.URL.Query().Get("namespace")
if namespace == "" {
namespace = h.getNamespaceFromRequest(r)
}
if namespace == "" {
writeError(w, http.StatusBadRequest, "namespace required")
return
}
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
fn, err := h.registry.Get(ctx, namespace, name, version)
if err != nil {
if serverless.IsNotFound(err) {
writeError(w, http.StatusNotFound, "Function not found")
} else {
writeError(w, http.StatusInternalServerError, "Failed to get function")
}
return
}
writeJSON(w, http.StatusOK, fn)
}
// deleteFunction handles DELETE /v1/functions/{name}
func (h *ServerlessHandlers) deleteFunction(w http.ResponseWriter, r *http.Request, name string, version int) {
namespace := r.URL.Query().Get("namespace")
if namespace == "" {
namespace = h.getNamespaceFromRequest(r)
}
if namespace == "" {
writeError(w, http.StatusBadRequest, "namespace required")
return
}
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
if err := h.registry.Delete(ctx, namespace, name, version); err != nil {
if serverless.IsNotFound(err) {
writeError(w, http.StatusNotFound, "Function not found")
} else {
writeError(w, http.StatusInternalServerError, "Failed to delete function")
}
return
}
writeJSON(w, http.StatusOK, map[string]string{
"message": "Function deleted successfully",
})
}
// invokeFunction handles POST /v1/functions/{name}/invoke
func (h *ServerlessHandlers) invokeFunction(w http.ResponseWriter, r *http.Request, nameWithNS string, version int) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Parse namespace and name
var namespace, name string
if idx := strings.Index(nameWithNS, "/"); idx > 0 {
namespace = nameWithNS[:idx]
name = nameWithNS[idx+1:]
} else {
name = nameWithNS
namespace = r.URL.Query().Get("namespace")
if namespace == "" {
namespace = h.getNamespaceFromRequest(r)
}
}
if namespace == "" {
writeError(w, http.StatusBadRequest, "namespace required")
return
}
// Read input body
input, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) // 1MB max
if err != nil {
writeError(w, http.StatusBadRequest, "Failed to read request body")
return
}
// Get caller wallet from JWT
callerWallet := h.getWalletFromRequest(r)
ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
defer cancel()
req := &serverless.InvokeRequest{
Namespace: namespace,
FunctionName: name,
Version: version,
Input: input,
TriggerType: serverless.TriggerTypeHTTP,
CallerWallet: callerWallet,
}
resp, err := h.invoker.Invoke(ctx, req)
if err != nil {
statusCode := http.StatusInternalServerError
if serverless.IsNotFound(err) {
statusCode = http.StatusNotFound
} else if serverless.IsResourceExhausted(err) {
statusCode = http.StatusTooManyRequests
} else if serverless.IsUnauthorized(err) {
statusCode = http.StatusUnauthorized
}
writeJSON(w, statusCode, map[string]interface{}{
"request_id": resp.RequestID,
"status": resp.Status,
"error": resp.Error,
"duration_ms": resp.DurationMS,
})
return
}
// Return the function's output directly if it's JSON
w.Header().Set("X-Request-ID", resp.RequestID)
w.Header().Set("X-Duration-Ms", strconv.FormatInt(resp.DurationMS, 10))
// Try to detect if output is JSON
if len(resp.Output) > 0 && (resp.Output[0] == '{' || resp.Output[0] == '[') {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(resp.Output)
} else {
writeJSON(w, http.StatusOK, map[string]interface{}{
"request_id": resp.RequestID,
"output": string(resp.Output),
"status": resp.Status,
"duration_ms": resp.DurationMS,
})
}
}
// handleWebSocket handles WebSocket connections for function streaming
func (h *ServerlessHandlers) handleWebSocket(w http.ResponseWriter, r *http.Request, name string, version int) {
namespace := r.URL.Query().Get("namespace")
if namespace == "" {
namespace = h.getNamespaceFromRequest(r)
}
if namespace == "" {
http.Error(w, "namespace required", http.StatusBadRequest)
return
}
// Upgrade to WebSocket
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
h.logger.Error("WebSocket upgrade failed", zap.Error(err))
return
}
clientID := uuid.New().String()
wsConn := &serverless.GorillaWSConn{Conn: conn}
// Register connection
h.wsManager.Register(clientID, wsConn)
defer h.wsManager.Unregister(clientID)
h.logger.Info("WebSocket connected",
zap.String("client_id", clientID),
zap.String("function", name),
)
callerWallet := h.getWalletFromRequest(r)
// Message loop
for {
_, message, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
h.logger.Warn("WebSocket error", zap.Error(err))
}
break
}
// Invoke function with WebSocket context
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
req := &serverless.InvokeRequest{
Namespace: namespace,
FunctionName: name,
Version: version,
Input: message,
TriggerType: serverless.TriggerTypeWebSocket,
CallerWallet: callerWallet,
WSClientID: clientID,
}
resp, err := h.invoker.Invoke(ctx, req)
cancel()
// Send response back
response := map[string]interface{}{
"request_id": resp.RequestID,
"status": resp.Status,
"duration_ms": resp.DurationMS,
}
if err != nil {
response["error"] = resp.Error
} else if len(resp.Output) > 0 {
// Try to parse output as JSON
var output interface{}
if json.Unmarshal(resp.Output, &output) == nil {
response["output"] = output
} else {
response["output"] = string(resp.Output)
}
}
respBytes, _ := json.Marshal(response)
if err := conn.WriteMessage(websocket.TextMessage, respBytes); err != nil {
break
}
}
}
// listVersions handles GET /v1/functions/{name}/versions
func (h *ServerlessHandlers) listVersions(w http.ResponseWriter, r *http.Request, name string) {
namespace := r.URL.Query().Get("namespace")
if namespace == "" {
namespace = h.getNamespaceFromRequest(r)
}
if namespace == "" {
writeError(w, http.StatusBadRequest, "namespace required")
return
}
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
// Get registry with extended methods
reg, ok := h.registry.(*serverless.Registry)
if !ok {
writeError(w, http.StatusNotImplemented, "Version listing not supported")
return
}
versions, err := reg.ListVersions(ctx, namespace, name)
if err != nil {
writeError(w, http.StatusInternalServerError, "Failed to list versions")
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"versions": versions,
"count": len(versions),
})
}
// getFunctionLogs handles GET /v1/functions/{name}/logs
func (h *ServerlessHandlers) getFunctionLogs(w http.ResponseWriter, r *http.Request, name string) {
namespace := r.URL.Query().Get("namespace")
if namespace == "" {
namespace = h.getNamespaceFromRequest(r)
}
if namespace == "" {
writeError(w, http.StatusBadRequest, "namespace required")
return
}
limit := 100
if lStr := r.URL.Query().Get("limit"); lStr != "" {
if l, err := strconv.Atoi(lStr); err == nil {
limit = l
}
}
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
logs, err := h.registry.GetLogs(ctx, namespace, name, limit)
if err != nil {
h.logger.Error("Failed to get function logs",
zap.String("name", name),
zap.String("namespace", namespace),
zap.Error(err),
)
writeError(w, http.StatusInternalServerError, "Failed to get logs")
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"name": name,
"namespace": namespace,
"logs": logs,
"count": len(logs),
})
}
// getNamespaceFromRequest extracts namespace from JWT or query param
func (h *ServerlessHandlers) getNamespaceFromRequest(r *http.Request) string {
// Try context first (set by auth middleware) - most secure
if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil {
if ns, ok := v.(string); ok && ns != "" {
return ns
}
}
// Try query param as fallback (e.g. for public access or admin)
if ns := r.URL.Query().Get("namespace"); ns != "" {
return ns
}
// Try header as fallback
if ns := r.Header.Get("X-Namespace"); ns != "" {
return ns
}
return "default"
}
// getWalletFromRequest extracts wallet address from JWT
func (h *ServerlessHandlers) getWalletFromRequest(r *http.Request) string {
// 1. Try X-Wallet header (legacy/direct bypass)
if wallet := r.Header.Get("X-Wallet"); wallet != "" {
return wallet
}
// 2. Try JWT claims from context
if v := r.Context().Value(ctxKeyJWT); v != nil {
if claims, ok := v.(*auth.JWTClaims); ok && claims != nil {
subj := strings.TrimSpace(claims.Sub)
// Ensure it's not an API key (standard Orama logic)
if !strings.HasPrefix(strings.ToLower(subj), "ak_") && !strings.Contains(subj, ":") {
return subj
}
}
}
// 3. Fallback to API key identity (namespace)
if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil {
if ns, ok := v.(string); ok && ns != "" {
return ns
}
}
return ""
}
// HealthStatus returns the health status of the serverless engine
func (h *ServerlessHandlers) HealthStatus() map[string]interface{} {
stats := h.wsManager.GetStats()
return map[string]interface{}{
"status": "ok",
"connections": stats.ConnectionCount,
"topics": stats.TopicCount,
}
}

View File

@ -0,0 +1,88 @@
package gateway
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/DeBrosOfficial/network/pkg/serverless"
"go.uber.org/zap"
)
type mockFunctionRegistry struct {
functions []*serverless.Function
}
func (m *mockFunctionRegistry) Register(ctx context.Context, fn *serverless.FunctionDefinition, wasmBytes []byte) (*serverless.Function, error) {
return nil, nil
}
func (m *mockFunctionRegistry) Get(ctx context.Context, namespace, name string, version int) (*serverless.Function, error) {
return &serverless.Function{ID: "1", Name: name, Namespace: namespace}, nil
}
func (m *mockFunctionRegistry) List(ctx context.Context, namespace string) ([]*serverless.Function, error) {
return m.functions, nil
}
func (m *mockFunctionRegistry) Delete(ctx context.Context, namespace, name string, version int) error {
return nil
}
func (m *mockFunctionRegistry) GetWASMBytes(ctx context.Context, wasmCID string) ([]byte, error) {
return []byte("wasm"), nil
}
func (m *mockFunctionRegistry) GetLogs(ctx context.Context, namespace, name string, limit int) ([]serverless.LogEntry, error) {
return []serverless.LogEntry{}, nil
}
func TestServerlessHandlers_ListFunctions(t *testing.T) {
logger := zap.NewNop()
registry := &mockFunctionRegistry{
functions: []*serverless.Function{
{ID: "1", Name: "func1", Namespace: "ns1"},
{ID: "2", Name: "func2", Namespace: "ns1"},
},
}
h := NewServerlessHandlers(nil, registry, nil, logger)
req, _ := http.NewRequest("GET", "/v1/functions?namespace=ns1", nil)
rr := httptest.NewRecorder()
h.handleFunctions(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", rr.Code)
}
var resp map[string]interface{}
json.Unmarshal(rr.Body.Bytes(), &resp)
if resp["count"].(float64) != 2 {
t.Errorf("expected 2 functions, got %v", resp["count"])
}
}
func TestServerlessHandlers_DeployFunction(t *testing.T) {
logger := zap.NewNop()
registry := &mockFunctionRegistry{}
h := NewServerlessHandlers(nil, registry, nil, logger)
// Test JSON deploy (which is partially supported according to code)
// Should be 400 because WASM is missing or base64 not supported
writer := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/v1/functions", bytes.NewBufferString(`{"name": "test"}`))
req.Header.Set("Content-Type", "application/json")
h.handleFunctions(writer, req)
if writer.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", writer.Code)
}
}

View File

@ -228,7 +228,12 @@ func (g *Gateway) storageStatusHandler(w http.ResponseWriter, r *http.Request) {
status, err := g.ipfsClient.PinStatus(ctx, path) status, err := g.ipfsClient.PinStatus(ctx, path)
if err != nil { if err != nil {
g.logger.ComponentError(logging.ComponentGeneral, "failed to get pin status", zap.Error(err), zap.String("cid", path)) g.logger.ComponentError(logging.ComponentGeneral, "failed to get pin status", zap.Error(err), zap.String("cid", path))
writeError(w, http.StatusInternalServerError, fmt.Sprintf("failed to get status: %v", err)) errStr := strings.ToLower(err.Error())
if strings.Contains(errStr, "not found") || strings.Contains(errStr, "404") || strings.Contains(errStr, "invalid") {
writeError(w, http.StatusNotFound, fmt.Sprintf("pin not found: %s", path))
} else {
writeError(w, http.StatusInternalServerError, fmt.Sprintf("failed to get status: %v", err))
}
return return
} }
@ -283,7 +288,8 @@ func (g *Gateway) storageGetHandler(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
g.logger.ComponentError(logging.ComponentGeneral, "failed to get content from IPFS", zap.Error(err), zap.String("cid", path)) g.logger.ComponentError(logging.ComponentGeneral, "failed to get content from IPFS", zap.Error(err), zap.String("cid", path))
// Check if error indicates content not found (404) // Check if error indicates content not found (404)
if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "status 404") { errStr := strings.ToLower(err.Error())
if strings.Contains(errStr, "not found") || strings.Contains(errStr, "404") || strings.Contains(errStr, "invalid") {
writeError(w, http.StatusNotFound, fmt.Sprintf("content not found: %s", path)) writeError(w, http.StatusNotFound, fmt.Sprintf("content not found: %s", path))
} else { } else {
writeError(w, http.StatusInternalServerError, fmt.Sprintf("failed to get content: %v", err)) writeError(w, http.StatusInternalServerError, fmt.Sprintf("failed to get content: %v", err))

View File

@ -17,6 +17,7 @@ import (
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
"github.com/DeBrosOfficial/network/pkg/certutil" "github.com/DeBrosOfficial/network/pkg/certutil"
"github.com/DeBrosOfficial/network/pkg/config"
"github.com/DeBrosOfficial/network/pkg/tlsutil" "github.com/DeBrosOfficial/network/pkg/tlsutil"
) )
@ -338,7 +339,7 @@ func (m *Model) handleEnter() (tea.Model, tea.Cmd) {
case StepSwarmKey: case StepSwarmKey:
swarmKey := strings.TrimSpace(m.textInput.Value()) swarmKey := strings.TrimSpace(m.textInput.Value())
if err := validateSwarmKey(swarmKey); err != nil { if err := config.ValidateSwarmKey(swarmKey); err != nil {
m.err = err m.err = err
return m, nil return m, nil
} }
@ -816,17 +817,6 @@ func validateClusterSecret(secret string) error {
return nil return nil
} }
func validateSwarmKey(key string) error {
if len(key) != 64 {
return fmt.Errorf("swarm key must be 64 hex characters")
}
keyRegex := regexp.MustCompile(`^[a-fA-F0-9]{64}$`)
if !keyRegex.MatchString(key) {
return fmt.Errorf("swarm key must be valid hexadecimal")
}
return nil
}
// ensureCertificatesForDomain generates self-signed certificates for the domain // ensureCertificatesForDomain generates self-signed certificates for the domain
func ensureCertificatesForDomain(domain string) error { func ensureCertificatesForDomain(domain string) error {
// Get home directory // Get home directory

File diff suppressed because it is too large Load Diff

136
pkg/ipfs/cluster_config.go Normal file
View File

@ -0,0 +1,136 @@
package ipfs
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
)
// ClusterServiceConfig represents the service.json configuration
type ClusterServiceConfig struct {
Cluster struct {
Peername string `json:"peername"`
Secret string `json:"secret"`
ListenMultiaddress []string `json:"listen_multiaddress"`
PeerAddresses []string `json:"peer_addresses"`
LeaveOnShutdown bool `json:"leave_on_shutdown"`
} `json:"cluster"`
Consensus struct {
CRDT struct {
ClusterName string `json:"cluster_name"`
TrustedPeers []string `json:"trusted_peers"`
Batching struct {
MaxBatchSize int `json:"max_batch_size"`
MaxBatchAge string `json:"max_batch_age"`
} `json:"batching"`
RepairInterval string `json:"repair_interval"`
} `json:"crdt"`
} `json:"consensus"`
API struct {
RestAPI struct {
HTTPListenMultiaddress string `json:"http_listen_multiaddress"`
} `json:"restapi"`
IPFSProxy struct {
ListenMultiaddress string `json:"listen_multiaddress"`
NodeMultiaddress string `json:"node_multiaddress"`
} `json:"ipfsproxy"`
PinSvcAPI struct {
HTTPListenMultiaddress string `json:"http_listen_multiaddress"`
} `json:"pinsvcapi"`
} `json:"api"`
IPFSConnector struct {
IPFSHTTP struct {
NodeMultiaddress string `json:"node_multiaddress"`
} `json:"ipfshttp"`
} `json:"ipfs_connector"`
Raw map[string]interface{} `json:"-"`
}
func (cm *ClusterConfigManager) loadOrCreateConfig(path string) (*ClusterServiceConfig, error) {
if _, err := os.Stat(path); os.IsNotExist(err) {
return cm.createTemplateConfig(), nil
}
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read service.json: %w", err)
}
var cfg ClusterServiceConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("failed to parse service.json: %w", err)
}
var raw map[string]interface{}
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("failed to parse raw service.json: %w", err)
}
cfg.Raw = raw
return &cfg, nil
}
func (cm *ClusterConfigManager) saveConfig(path string, cfg *ClusterServiceConfig) error {
cm.updateNestedMap(cfg.Raw, "cluster", "peername", cfg.Cluster.Peername)
cm.updateNestedMap(cfg.Raw, "cluster", "secret", cfg.Cluster.Secret)
cm.updateNestedMap(cfg.Raw, "cluster", "listen_multiaddress", cfg.Cluster.ListenMultiaddress)
cm.updateNestedMap(cfg.Raw, "cluster", "peer_addresses", cfg.Cluster.PeerAddresses)
cm.updateNestedMap(cfg.Raw, "cluster", "leave_on_shutdown", cfg.Cluster.LeaveOnShutdown)
consensus := cm.ensureRequiredSection(cfg.Raw, "consensus")
crdt := cm.ensureRequiredSection(consensus, "crdt")
crdt["cluster_name"] = cfg.Consensus.CRDT.ClusterName
crdt["trusted_peers"] = cfg.Consensus.CRDT.TrustedPeers
crdt["repair_interval"] = cfg.Consensus.CRDT.RepairInterval
batching := cm.ensureRequiredSection(crdt, "batching")
batching["max_batch_size"] = cfg.Consensus.CRDT.Batching.MaxBatchSize
batching["max_batch_age"] = cfg.Consensus.CRDT.Batching.MaxBatchAge
api := cm.ensureRequiredSection(cfg.Raw, "api")
restapi := cm.ensureRequiredSection(api, "restapi")
restapi["http_listen_multiaddress"] = cfg.API.RestAPI.HTTPListenMultiaddress
ipfsproxy := cm.ensureRequiredSection(api, "ipfsproxy")
ipfsproxy["listen_multiaddress"] = cfg.API.IPFSProxy.ListenMultiaddress
ipfsproxy["node_multiaddress"] = cfg.API.IPFSProxy.NodeMultiaddress
pinsvcapi := cm.ensureRequiredSection(api, "pinsvcapi")
pinsvcapi["http_listen_multiaddress"] = cfg.API.PinSvcAPI.HTTPListenMultiaddress
ipfsConn := cm.ensureRequiredSection(cfg.Raw, "ipfs_connector")
ipfsHttp := cm.ensureRequiredSection(ipfsConn, "ipfshttp")
ipfsHttp["node_multiaddress"] = cfg.IPFSConnector.IPFSHTTP.NodeMultiaddress
data, err := json.MarshalIndent(cfg.Raw, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal service.json: %w", err)
}
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return fmt.Errorf("failed to create config directory: %w", err)
}
return os.WriteFile(path, data, 0644)
}
func (cm *ClusterConfigManager) updateNestedMap(m map[string]interface{}, section, key string, val interface{}) {
if _, ok := m[section]; !ok {
m[section] = make(map[string]interface{})
}
s := m[section].(map[string]interface{})
s[key] = val
}
func (cm *ClusterConfigManager) ensureRequiredSection(m map[string]interface{}, key string) map[string]interface{} {
if _, ok := m[key]; !ok {
m[key] = make(map[string]interface{})
}
return m[key].(map[string]interface{})
}

156
pkg/ipfs/cluster_peer.go Normal file
View File

@ -0,0 +1,156 @@
package ipfs
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"github.com/libp2p/go-libp2p/core/host"
"github.com/multiformats/go-multiaddr"
"go.uber.org/zap"
)
// UpdatePeerAddresses updates the peer_addresses in service.json with given multiaddresses
func (cm *ClusterConfigManager) UpdatePeerAddresses(addrs []string) error {
serviceJSONPath := filepath.Join(cm.clusterPath, "service.json")
cfg, err := cm.loadOrCreateConfig(serviceJSONPath)
if err != nil {
return err
}
seen := make(map[string]bool)
uniqueAddrs := []string{}
for _, addr := range addrs {
if !seen[addr] {
uniqueAddrs = append(uniqueAddrs, addr)
seen[addr] = true
}
}
cfg.Cluster.PeerAddresses = uniqueAddrs
return cm.saveConfig(serviceJSONPath, cfg)
}
// UpdateAllClusterPeers discovers all cluster peers from the gateway and updates local config
func (cm *ClusterConfigManager) UpdateAllClusterPeers() error {
peers, err := cm.DiscoverClusterPeersFromGateway()
if err != nil {
return fmt.Errorf("failed to discover cluster peers: %w", err)
}
if len(peers) == 0 {
return nil
}
peerAddrs := []string{}
for _, p := range peers {
peerAddrs = append(peerAddrs, p.Multiaddress)
}
return cm.UpdatePeerAddresses(peerAddrs)
}
// RepairPeerConfiguration attempts to fix configuration issues and re-synchronize peers
func (cm *ClusterConfigManager) RepairPeerConfiguration() error {
cm.logger.Info("Attempting to repair IPFS Cluster peer configuration")
_ = cm.FixIPFSConfigAddresses()
peers, err := cm.DiscoverClusterPeersFromGateway()
if err != nil {
cm.logger.Warn("Could not discover peers from gateway during repair", zap.Error(err))
} else {
peerAddrs := []string{}
for _, p := range peers {
peerAddrs = append(peerAddrs, p.Multiaddress)
}
if len(peerAddrs) > 0 {
_ = cm.UpdatePeerAddresses(peerAddrs)
}
}
return nil
}
// DiscoverClusterPeersFromGateway queries the central gateway for registered IPFS Cluster peers
func (cm *ClusterConfigManager) DiscoverClusterPeersFromGateway() ([]ClusterPeerInfo, error) {
// Not implemented - would require a central gateway URL in config
return nil, nil
}
// DiscoverClusterPeersFromLibP2P uses libp2p host to find other cluster peers
func (cm *ClusterConfigManager) DiscoverClusterPeersFromLibP2P(h host.Host) error {
if h == nil {
return nil
}
var clusterPeers []string
for _, p := range h.Peerstore().Peers() {
if p == h.ID() {
continue
}
info := h.Peerstore().PeerInfo(p)
for _, addr := range info.Addrs {
if strings.Contains(addr.String(), "/tcp/9096") || strings.Contains(addr.String(), "/tcp/9094") {
ma := addr.Encapsulate(multiaddr.StringCast(fmt.Sprintf("/p2p/%s", p.String())))
clusterPeers = append(clusterPeers, ma.String())
}
}
}
if len(clusterPeers) > 0 {
return cm.UpdatePeerAddresses(clusterPeers)
}
return nil
}
func (cm *ClusterConfigManager) getPeerID() (string, error) {
dataDir := cm.cfg.Node.DataDir
if strings.HasPrefix(dataDir, "~") {
home, _ := os.UserHomeDir()
dataDir = filepath.Join(home, dataDir[1:])
}
possiblePaths := []string{
filepath.Join(dataDir, "ipfs", "repo"),
filepath.Join(dataDir, "node-1", "ipfs", "repo"),
filepath.Join(dataDir, "node-2", "ipfs", "repo"),
filepath.Join(filepath.Dir(dataDir), "node-1", "ipfs", "repo"),
filepath.Join(filepath.Dir(dataDir), "node-2", "ipfs", "repo"),
}
var ipfsRepoPath string
for _, path := range possiblePaths {
if _, err := os.Stat(filepath.Join(path, "config")); err == nil {
ipfsRepoPath = path
break
}
}
if ipfsRepoPath == "" {
return "", fmt.Errorf("could not find IPFS repo path")
}
idCmd := exec.Command("ipfs", "id", "-f", "<id>")
idCmd.Env = append(os.Environ(), "IPFS_PATH="+ipfsRepoPath)
out, err := idCmd.Output()
if err != nil {
return "", err
}
return strings.TrimSpace(string(out)), nil
}
// ClusterPeerInfo represents information about an IPFS Cluster peer
type ClusterPeerInfo struct {
ID string `json:"id"`
Multiaddress string `json:"multiaddress"`
NodeName string `json:"node_name"`
LastSeen time.Time `json:"last_seen"`
}

119
pkg/ipfs/cluster_util.go Normal file
View File

@ -0,0 +1,119 @@
package ipfs
import (
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
)
func loadOrGenerateClusterSecret(path string) (string, error) {
if data, err := os.ReadFile(path); err == nil {
secret := strings.TrimSpace(string(data))
if len(secret) == 64 {
return secret, nil
}
}
secret, err := generateRandomSecret()
if err != nil {
return "", err
}
_ = os.WriteFile(path, []byte(secret), 0600)
return secret, nil
}
func generateRandomSecret() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
func parseClusterPorts(rawURL string) (int, int, error) {
if !strings.HasPrefix(rawURL, "http") {
rawURL = "http://" + rawURL
}
u, err := url.Parse(rawURL)
if err != nil {
return 9096, 9094, nil
}
_, portStr, err := net.SplitHostPort(u.Host)
if err != nil {
return 9096, 9094, nil
}
var port int
fmt.Sscanf(portStr, "%d", &port)
if port == 0 {
return 9096, 9094, nil
}
return port + 2, port, nil
}
func parseIPFSPort(rawURL string) (int, error) {
if !strings.HasPrefix(rawURL, "http") {
rawURL = "http://" + rawURL
}
u, err := url.Parse(rawURL)
if err != nil {
return 5001, nil
}
_, portStr, err := net.SplitHostPort(u.Host)
if err != nil {
return 5001, nil
}
var port int
fmt.Sscanf(portStr, "%d", &port)
if port == 0 {
return 5001, nil
}
return port, nil
}
func parsePeerHostAndPort(multiaddr string) (string, int) {
parts := strings.Split(multiaddr, "/")
var hostStr string
var port int
for i, part := range parts {
if part == "ip4" || part == "dns" || part == "dns4" {
hostStr = parts[i+1]
} else if part == "tcp" {
fmt.Sscanf(parts[i+1], "%d", &port)
}
}
return hostStr, port
}
func extractIPFromMultiaddrForCluster(maddr string) string {
parts := strings.Split(maddr, "/")
for i, part := range parts {
if (part == "ip4" || part == "dns" || part == "dns4") && i+1 < len(parts) {
return parts[i+1]
}
}
return ""
}
func extractDomainFromMultiaddr(maddr string) string {
parts := strings.Split(maddr, "/")
for i, part := range parts {
if (part == "dns" || part == "dns4" || part == "dns6") && i+1 < len(parts) {
return parts[i+1]
}
}
return ""
}
func newStandardHTTPClient() *http.Client {
return &http.Client{
Timeout: 10 * time.Second,
}
}

204
pkg/node/gateway.go Normal file
View File

@ -0,0 +1,204 @@
package node
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"github.com/DeBrosOfficial/network/pkg/gateway"
"github.com/DeBrosOfficial/network/pkg/ipfs"
"github.com/DeBrosOfficial/network/pkg/logging"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
)
// startHTTPGateway initializes and starts the full API gateway
func (n *Node) startHTTPGateway(ctx context.Context) error {
if !n.config.HTTPGateway.Enabled {
n.logger.ComponentInfo(logging.ComponentNode, "HTTP Gateway disabled in config")
return nil
}
logFile := filepath.Join(os.ExpandEnv(n.config.Node.DataDir), "..", "logs", "gateway.log")
logsDir := filepath.Dir(logFile)
_ = os.MkdirAll(logsDir, 0755)
gatewayLogger, err := logging.NewFileLogger(logging.ComponentGeneral, logFile, false)
if err != nil {
return err
}
gwCfg := &gateway.Config{
ListenAddr: n.config.HTTPGateway.ListenAddr,
ClientNamespace: n.config.HTTPGateway.ClientNamespace,
BootstrapPeers: n.config.Discovery.BootstrapPeers,
NodePeerID: loadNodePeerIDFromIdentity(n.config.Node.DataDir),
RQLiteDSN: n.config.HTTPGateway.RQLiteDSN,
OlricServers: n.config.HTTPGateway.OlricServers,
OlricTimeout: n.config.HTTPGateway.OlricTimeout,
IPFSClusterAPIURL: n.config.HTTPGateway.IPFSClusterAPIURL,
IPFSAPIURL: n.config.HTTPGateway.IPFSAPIURL,
IPFSTimeout: n.config.HTTPGateway.IPFSTimeout,
EnableHTTPS: n.config.HTTPGateway.HTTPS.Enabled,
DomainName: n.config.HTTPGateway.HTTPS.Domain,
TLSCacheDir: n.config.HTTPGateway.HTTPS.CacheDir,
}
apiGateway, err := gateway.New(gatewayLogger, gwCfg)
if err != nil {
return err
}
n.apiGateway = apiGateway
var certManager *autocert.Manager
if gwCfg.EnableHTTPS && gwCfg.DomainName != "" {
tlsCacheDir := gwCfg.TLSCacheDir
if tlsCacheDir == "" {
tlsCacheDir = "/home/debros/.orama/tls-cache"
}
_ = os.MkdirAll(tlsCacheDir, 0700)
certManager = &autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(gwCfg.DomainName),
Cache: autocert.DirCache(tlsCacheDir),
Email: fmt.Sprintf("admin@%s", gwCfg.DomainName),
Client: &acme.Client{
DirectoryURL: "https://acme-staging-v02.api.letsencrypt.org/directory",
},
}
n.certManager = certManager
n.certReady = make(chan struct{})
}
httpReady := make(chan struct{})
go func() {
if gwCfg.EnableHTTPS && gwCfg.DomainName != "" && certManager != nil {
httpsPort := 443
httpPort := 80
httpServer := &http.Server{
Addr: fmt.Sprintf(":%d", httpPort),
Handler: certManager.HTTPHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
target := fmt.Sprintf("https://%s%s", r.Host, r.URL.RequestURI())
http.Redirect(w, r, target, http.StatusMovedPermanently)
})),
}
httpListener, err := net.Listen("tcp", fmt.Sprintf(":%d", httpPort))
if err != nil {
close(httpReady)
return
}
go httpServer.Serve(httpListener)
// Pre-provision cert
certReq := &tls.ClientHelloInfo{ServerName: gwCfg.DomainName}
_, certErr := certManager.GetCertificate(certReq)
if certErr != nil {
close(httpReady)
httpServer.Handler = apiGateway.Routes()
return
}
close(httpReady)
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
GetCertificate: certManager.GetCertificate,
}
httpsServer := &http.Server{
Addr: fmt.Sprintf(":%d", httpsPort),
TLSConfig: tlsConfig,
Handler: apiGateway.Routes(),
}
n.apiGatewayServer = httpsServer
ln, err := tls.Listen("tcp", fmt.Sprintf(":%d", httpsPort), tlsConfig)
if err == nil {
httpsServer.Serve(ln)
}
} else {
close(httpReady)
server := &http.Server{
Addr: gwCfg.ListenAddr,
Handler: apiGateway.Routes(),
}
n.apiGatewayServer = server
ln, err := net.Listen("tcp", gwCfg.ListenAddr)
if err == nil {
server.Serve(ln)
}
}
}()
// SNI Gateway
if n.config.HTTPGateway.SNI.Enabled && n.certManager != nil {
go n.startSNIGateway(ctx, httpReady)
}
return nil
}
func (n *Node) startSNIGateway(ctx context.Context, httpReady <-chan struct{}) {
<-httpReady
domain := n.config.HTTPGateway.HTTPS.Domain
if domain == "" {
return
}
certReq := &tls.ClientHelloInfo{ServerName: domain}
tlsCert, err := n.certManager.GetCertificate(certReq)
if err != nil {
return
}
tlsCacheDir := n.config.HTTPGateway.HTTPS.CacheDir
if tlsCacheDir == "" {
tlsCacheDir = "/home/debros/.orama/tls-cache"
}
certPath := filepath.Join(tlsCacheDir, domain+".crt")
keyPath := filepath.Join(tlsCacheDir, domain+".key")
if err := extractPEMFromTLSCert(tlsCert, certPath, keyPath); err == nil {
if n.certReady != nil {
close(n.certReady)
}
}
sniCfg := n.config.HTTPGateway.SNI
sniGateway, err := gateway.NewTCPSNIGateway(n.logger, &sniCfg)
if err == nil {
n.sniGateway = sniGateway
sniGateway.Start(ctx)
}
}
// startIPFSClusterConfig initializes and ensures IPFS Cluster configuration
func (n *Node) startIPFSClusterConfig() error {
n.logger.ComponentInfo(logging.ComponentNode, "Initializing IPFS Cluster configuration")
cm, err := ipfs.NewClusterConfigManager(n.config, n.logger.Logger)
if err != nil {
return err
}
n.clusterConfigManager = cm
_ = cm.FixIPFSConfigAddresses()
if err := cm.EnsureConfig(); err != nil {
return err
}
_ = cm.RepairPeerConfiguration()
return nil
}

302
pkg/node/libp2p.go Normal file
View File

@ -0,0 +1,302 @@
package node
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/discovery"
"github.com/DeBrosOfficial/network/pkg/encryption"
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/pubsub"
"github.com/libp2p/go-libp2p"
libp2ppubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
noise "github.com/libp2p/go-libp2p/p2p/security/noise"
"github.com/multiformats/go-multiaddr"
"go.uber.org/zap"
)
// startLibP2P initializes the LibP2P host
func (n *Node) startLibP2P() error {
n.logger.ComponentInfo(logging.ComponentLibP2P, "Starting LibP2P host")
// Load or create persistent identity
identity, err := n.loadOrCreateIdentity()
if err != nil {
return fmt.Errorf("failed to load identity: %w", err)
}
// Create LibP2P host with explicit listen addresses
var opts []libp2p.Option
opts = append(opts,
libp2p.Identity(identity),
libp2p.Security(noise.ID, noise.New),
libp2p.DefaultMuxers,
)
// Add explicit listen addresses from config
if len(n.config.Node.ListenAddresses) > 0 {
listenAddrs := make([]multiaddr.Multiaddr, 0, len(n.config.Node.ListenAddresses))
for _, addr := range n.config.Node.ListenAddresses {
ma, err := multiaddr.NewMultiaddr(addr)
if err != nil {
return fmt.Errorf("invalid listen address %s: %w", addr, err)
}
listenAddrs = append(listenAddrs, ma)
}
opts = append(opts, libp2p.ListenAddrs(listenAddrs...))
n.logger.ComponentInfo(logging.ComponentLibP2P, "Configured listen addresses",
zap.Strings("addrs", n.config.Node.ListenAddresses))
}
// For localhost/development, disable NAT services
isLocalhost := len(n.config.Node.ListenAddresses) > 0 &&
(strings.Contains(n.config.Node.ListenAddresses[0], "localhost") ||
strings.Contains(n.config.Node.ListenAddresses[0], "127.0.0.1"))
if isLocalhost {
n.logger.ComponentInfo(logging.ComponentLibP2P, "Localhost detected - disabling NAT services for local development")
} else {
n.logger.ComponentInfo(logging.ComponentLibP2P, "Production mode - enabling NAT services")
opts = append(opts,
libp2p.EnableNATService(),
libp2p.EnableAutoNATv2(),
libp2p.EnableRelay(),
libp2p.NATPortMap(),
libp2p.EnableAutoRelayWithPeerSource(
peerSource(n.config.Discovery.BootstrapPeers, n.logger.Logger),
),
)
}
h, err := libp2p.New(opts...)
if err != nil {
return err
}
n.host = h
// Initialize pubsub
ps, err := libp2ppubsub.NewGossipSub(context.Background(), h,
libp2ppubsub.WithPeerExchange(true),
libp2ppubsub.WithFloodPublish(true),
libp2ppubsub.WithDirectPeers(nil),
)
if err != nil {
return fmt.Errorf("failed to create pubsub: %w", err)
}
// Create pubsub adapter
n.pubsub = pubsub.NewClientAdapter(ps, n.config.Discovery.NodeNamespace)
n.logger.Info("Initialized pubsub adapter on namespace", zap.String("namespace", n.config.Discovery.NodeNamespace))
// Connect to peers
if err := n.connectToPeers(context.Background()); err != nil {
n.logger.ComponentWarn(logging.ComponentNode, "Failed to connect to peers", zap.Error(err))
}
// Start reconnection loop
if len(n.config.Discovery.BootstrapPeers) > 0 {
peerCtx, cancel := context.WithCancel(context.Background())
n.peerDiscoveryCancel = cancel
go n.peerReconnectionLoop(peerCtx)
}
// Add peers to peerstore
for _, peerAddr := range n.config.Discovery.BootstrapPeers {
if ma, err := multiaddr.NewMultiaddr(peerAddr); err == nil {
if peerInfo, err := peer.AddrInfoFromP2pAddr(ma); err == nil {
n.host.Peerstore().AddAddrs(peerInfo.ID, peerInfo.Addrs, time.Hour*24)
}
}
}
// Initialize discovery manager
n.discoveryManager = discovery.NewManager(h, nil, n.logger.Logger)
n.discoveryManager.StartProtocolHandler()
n.logger.ComponentInfo(logging.ComponentNode, "LibP2P host started successfully")
// Start peer discovery
n.startPeerDiscovery()
return nil
}
func (n *Node) peerReconnectionLoop(ctx context.Context) {
interval := 5 * time.Second
consecutiveFailures := 0
for {
select {
case <-ctx.Done():
return
default:
}
if !n.hasPeerConnections() {
if err := n.connectToPeers(context.Background()); err != nil {
consecutiveFailures++
jitteredInterval := addJitter(interval)
select {
case <-ctx.Done():
return
case <-time.After(jitteredInterval):
}
interval = calculateNextBackoff(interval)
} else {
interval = 5 * time.Second
consecutiveFailures = 0
select {
case <-ctx.Done():
return
case <-time.After(30 * time.Second):
}
}
} else {
select {
case <-ctx.Done():
return
case <-time.After(30 * time.Second):
}
}
}
}
func (n *Node) connectToPeers(ctx context.Context) error {
for _, peerAddr := range n.config.Discovery.BootstrapPeers {
if err := n.connectToPeerAddr(ctx, peerAddr); err != nil {
continue
}
}
return nil
}
func (n *Node) connectToPeerAddr(ctx context.Context, addr string) error {
ma, err := multiaddr.NewMultiaddr(addr)
if err != nil {
return err
}
peerInfo, err := peer.AddrInfoFromP2pAddr(ma)
if err != nil {
return err
}
if n.host != nil && peerInfo.ID == n.host.ID() {
return nil
}
return n.host.Connect(ctx, *peerInfo)
}
func (n *Node) hasPeerConnections() bool {
if n.host == nil || len(n.config.Discovery.BootstrapPeers) == 0 {
return false
}
connectedPeers := n.host.Network().Peers()
if len(connectedPeers) == 0 {
return false
}
bootstrapIDs := make(map[peer.ID]bool)
for _, addr := range n.config.Discovery.BootstrapPeers {
if ma, err := multiaddr.NewMultiaddr(addr); err == nil {
if info, err := peer.AddrInfoFromP2pAddr(ma); err == nil {
bootstrapIDs[info.ID] = true
}
}
}
for _, p := range connectedPeers {
if bootstrapIDs[p] {
return true
}
}
return false
}
func (n *Node) loadOrCreateIdentity() (crypto.PrivKey, error) {
identityFile := filepath.Join(os.ExpandEnv(n.config.Node.DataDir), "identity.key")
if strings.HasPrefix(identityFile, "~") {
home, _ := os.UserHomeDir()
identityFile = filepath.Join(home, identityFile[1:])
}
if _, err := os.Stat(identityFile); err == nil {
info, err := encryption.LoadIdentity(identityFile)
if err == nil {
return info.PrivateKey, nil
}
}
info, err := encryption.GenerateIdentity()
if err != nil {
return nil, err
}
if err := encryption.SaveIdentity(info, identityFile); err != nil {
return nil, err
}
return info.PrivateKey, nil
}
func (n *Node) startPeerDiscovery() {
if n.discoveryManager == nil {
return
}
discoveryConfig := discovery.Config{
DiscoveryInterval: n.config.Discovery.DiscoveryInterval,
MaxConnections: n.config.Node.MaxConnections,
}
n.discoveryManager.Start(discoveryConfig)
}
func (n *Node) stopPeerDiscovery() {
if n.discoveryManager != nil {
n.discoveryManager.Stop()
}
}
func (n *Node) GetPeerID() string {
if n.host == nil {
return ""
}
return n.host.ID().String()
}
func peerSource(peerAddrs []string, logger *zap.Logger) func(context.Context, int) <-chan peer.AddrInfo {
return func(ctx context.Context, num int) <-chan peer.AddrInfo {
out := make(chan peer.AddrInfo, num)
go func() {
defer close(out)
count := 0
for _, s := range peerAddrs {
if count >= num {
return
}
ma, err := multiaddr.NewMultiaddr(s)
if err != nil {
continue
}
ai, err := peer.AddrInfoFromP2pAddr(ma)
if err != nil {
continue
}
select {
case out <- *ai:
count++
case <-ctx.Done():
return
}
}
}()
return out
}
}

View File

@ -220,9 +220,9 @@ func (n *Node) startConnectionMonitoring() {
// First try to discover from LibP2P connections (works even if cluster peers aren't connected yet) // First try to discover from LibP2P connections (works even if cluster peers aren't connected yet)
// This runs every minute to discover peers automatically via LibP2P discovery // This runs every minute to discover peers automatically via LibP2P discovery
if time.Now().Unix()%60 == 0 { if time.Now().Unix()%60 == 0 {
if success, err := n.clusterConfigManager.DiscoverClusterPeersFromLibP2P(n.host); err != nil { if err := n.clusterConfigManager.DiscoverClusterPeersFromLibP2P(n.host); err != nil {
n.logger.ComponentWarn(logging.ComponentNode, "Failed to discover cluster peers from LibP2P", zap.Error(err)) n.logger.ComponentWarn(logging.ComponentNode, "Failed to discover cluster peers from LibP2P", zap.Error(err))
} else if success { } else {
n.logger.ComponentInfo(logging.ComponentNode, "Cluster peer addresses discovered from LibP2P") n.logger.ComponentInfo(logging.ComponentNode, "Cluster peer addresses discovered from LibP2P")
} }
} }
@ -230,16 +230,16 @@ func (n *Node) startConnectionMonitoring() {
// Also try to update from cluster API (works once peers are connected) // Also try to update from cluster API (works once peers are connected)
// Update all cluster peers every 2 minutes to discover new peers // Update all cluster peers every 2 minutes to discover new peers
if time.Now().Unix()%120 == 0 { if time.Now().Unix()%120 == 0 {
if success, err := n.clusterConfigManager.UpdateAllClusterPeers(); err != nil { if err := n.clusterConfigManager.UpdateAllClusterPeers(); err != nil {
n.logger.ComponentWarn(logging.ComponentNode, "Failed to update cluster peers during monitoring", zap.Error(err)) n.logger.ComponentWarn(logging.ComponentNode, "Failed to update cluster peers during monitoring", zap.Error(err))
} else if success { } else {
n.logger.ComponentInfo(logging.ComponentNode, "Cluster peer addresses updated during monitoring") n.logger.ComponentInfo(logging.ComponentNode, "Cluster peer addresses updated during monitoring")
} }
// Try to repair peer configuration // Try to repair peer configuration
if success, err := n.clusterConfigManager.RepairPeerConfiguration(); err != nil { if err := n.clusterConfigManager.RepairPeerConfiguration(); err != nil {
n.logger.ComponentWarn(logging.ComponentNode, "Failed to repair peer addresses during monitoring", zap.Error(err)) n.logger.ComponentWarn(logging.ComponentNode, "Failed to repair peer addresses during monitoring", zap.Error(err))
} else if success { } else {
n.logger.ComponentInfo(logging.ComponentNode, "Peer configuration repaired during monitoring") n.logger.ComponentInfo(logging.ComponentNode, "Peer configuration repaired during monitoring")
} }
} }

File diff suppressed because it is too large Load Diff

98
pkg/node/rqlite.go Normal file
View File

@ -0,0 +1,98 @@
package node
import (
"context"
"fmt"
database "github.com/DeBrosOfficial/network/pkg/rqlite"
"go.uber.org/zap"
"time"
)
// startRQLite initializes and starts the RQLite database
func (n *Node) startRQLite(ctx context.Context) error {
n.logger.Info("Starting RQLite database")
// Determine node identifier for log filename - use node ID for unique filenames
nodeID := n.config.Node.ID
if nodeID == "" {
// Default to "node" if ID is not set
nodeID = "node"
}
// Create RQLite manager
n.rqliteManager = database.NewRQLiteManager(&n.config.Database, &n.config.Discovery, n.config.Node.DataDir, n.logger.Logger)
n.rqliteManager.SetNodeType(nodeID)
// Initialize cluster discovery service if LibP2P host is available
if n.host != nil && n.discoveryManager != nil {
// Create cluster discovery service (all nodes are unified)
n.clusterDiscovery = database.NewClusterDiscoveryService(
n.host,
n.discoveryManager,
n.rqliteManager,
n.config.Node.ID,
"node", // Unified node type
n.config.Discovery.RaftAdvAddress,
n.config.Discovery.HttpAdvAddress,
n.config.Node.DataDir,
n.logger.Logger,
)
// Set discovery service on RQLite manager BEFORE starting RQLite
// This is critical for pre-start cluster discovery during recovery
n.rqliteManager.SetDiscoveryService(n.clusterDiscovery)
// Start cluster discovery (but don't trigger initial sync yet)
if err := n.clusterDiscovery.Start(ctx); err != nil {
return fmt.Errorf("failed to start cluster discovery: %w", err)
}
// Publish initial metadata (with log_index=0) so peers can discover us during recovery
// The metadata will be updated with actual log index after RQLite starts
n.clusterDiscovery.UpdateOwnMetadata()
n.logger.Info("Cluster discovery service started (waiting for RQLite)")
}
// If node-to-node TLS is configured, wait for certificates to be provisioned
// This ensures RQLite can start with TLS when joining through the SNI gateway
if n.config.Database.NodeCert != "" && n.config.Database.NodeKey != "" && n.certReady != nil {
n.logger.Info("RQLite node TLS configured, waiting for certificates to be provisioned...",
zap.String("node_cert", n.config.Database.NodeCert),
zap.String("node_key", n.config.Database.NodeKey))
// Wait for certificate ready signal with timeout
certTimeout := 5 * time.Minute
select {
case <-n.certReady:
n.logger.Info("Certificates ready, proceeding with RQLite startup")
case <-time.After(certTimeout):
return fmt.Errorf("timeout waiting for TLS certificates after %v - ensure HTTPS is configured and ports 80/443 are accessible for ACME challenges", certTimeout)
case <-ctx.Done():
return fmt.Errorf("context cancelled while waiting for certificates: %w", ctx.Err())
}
}
// Start RQLite FIRST before updating metadata
if err := n.rqliteManager.Start(ctx); err != nil {
return err
}
// NOW update metadata after RQLite is running
if n.clusterDiscovery != nil {
n.clusterDiscovery.UpdateOwnMetadata()
n.clusterDiscovery.TriggerSync() // Do initial cluster sync now that RQLite is ready
n.logger.Info("RQLite metadata published and cluster synced")
}
// Create adapter for sql.DB compatibility
adapter, err := database.NewRQLiteAdapter(n.rqliteManager)
if err != nil {
return fmt.Errorf("failed to create RQLite adapter: %w", err)
}
n.rqliteAdapter = adapter
return nil
}

127
pkg/node/utils.go Normal file
View File

@ -0,0 +1,127 @@
package node
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
mathrand "math/rand"
"net"
"os"
"path/filepath"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/encryption"
"github.com/multiformats/go-multiaddr"
)
func extractIPFromMultiaddr(multiaddrStr string) string {
ma, err := multiaddr.NewMultiaddr(multiaddrStr)
if err != nil {
return ""
}
var ip string
var dnsName string
multiaddr.ForEach(ma, func(c multiaddr.Component) bool {
switch c.Protocol().Code {
case multiaddr.P_IP4, multiaddr.P_IP6:
ip = c.Value()
return false
case multiaddr.P_DNS4, multiaddr.P_DNS6, multiaddr.P_DNSADDR:
dnsName = c.Value()
}
return true
})
if ip != "" {
return ip
}
if dnsName != "" {
if resolvedIPs, err := net.LookupIP(dnsName); err == nil && len(resolvedIPs) > 0 {
for _, resolvedIP := range resolvedIPs {
if resolvedIP.To4() != nil {
return resolvedIP.String()
}
}
return resolvedIPs[0].String()
}
}
return ""
}
func calculateNextBackoff(current time.Duration) time.Duration {
next := time.Duration(float64(current) * 1.5)
maxInterval := 10 * time.Minute
if next > maxInterval {
next = maxInterval
}
return next
}
func addJitter(interval time.Duration) time.Duration {
jitterPercent := 0.2
jitterRange := float64(interval) * jitterPercent
jitter := (mathrand.Float64() - 0.5) * 2 * jitterRange
result := time.Duration(float64(interval) + jitter)
if result < time.Second {
result = time.Second
}
return result
}
func loadNodePeerIDFromIdentity(dataDir string) string {
identityFile := filepath.Join(os.ExpandEnv(dataDir), "identity.key")
if strings.HasPrefix(identityFile, "~") {
home, _ := os.UserHomeDir()
identityFile = filepath.Join(home, identityFile[1:])
}
if info, err := encryption.LoadIdentity(identityFile); err == nil {
return info.PeerID.String()
}
return ""
}
func extractPEMFromTLSCert(tlsCert *tls.Certificate, certPath, keyPath string) error {
if tlsCert == nil || len(tlsCert.Certificate) == 0 {
return fmt.Errorf("invalid tls certificate")
}
certFile, err := os.Create(certPath)
if err != nil {
return err
}
defer certFile.Close()
for _, certBytes := range tlsCert.Certificate {
pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})
}
if tlsCert.PrivateKey == nil {
return fmt.Errorf("private key is nil")
}
keyFile, err := os.Create(keyPath)
if err != nil {
return err
}
defer keyFile.Close()
var keyBytes []byte
switch key := tlsCert.PrivateKey.(type) {
case *x509.Certificate:
keyBytes, _ = x509.MarshalPKCS8PrivateKey(key)
default:
keyBytes, _ = x509.MarshalPKCS8PrivateKey(tlsCert.PrivateKey)
}
pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes})
os.Chmod(certPath, 0644)
os.Chmod(keyPath, 0600)
return nil
}

View File

@ -49,6 +49,13 @@ func NewClient(cfg Config, logger *zap.Logger) (*Client, error) {
}, nil }, nil
} }
// UnderlyingClient returns the underlying olriclib.Client for advanced usage.
// This is useful when you need to pass the client to other packages that expect
// the raw olric client interface.
func (c *Client) UnderlyingClient() olriclib.Client {
return c.client
}
// Health checks if the Olric client is healthy // Health checks if the Olric client is healthy
func (c *Client) Health(ctx context.Context) error { func (c *Client) Health(ctx context.Context) error {
// Create a DMap to test connectivity // Create a DMap to test connectivity

217
pkg/pubsub/manager_test.go Normal file
View File

@ -0,0 +1,217 @@
package pubsub
import (
"context"
"testing"
"time"
"github.com/libp2p/go-libp2p"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/peer"
)
func createTestManager(t *testing.T, ns string) (*Manager, func()) {
ctx, cancel := context.WithCancel(context.Background())
h, err := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
if err != nil {
t.Fatalf("failed to create libp2p host: %v", err)
}
ps, err := pubsub.NewGossipSub(ctx, h)
if err != nil {
h.Close()
t.Fatalf("failed to create gossipsub: %v", err)
}
mgr := NewManager(ps, ns)
cleanup := func() {
mgr.Close()
h.Close()
cancel()
}
return mgr, cleanup
}
func TestManager_Namespacing(t *testing.T) {
mgr, cleanup := createTestManager(t, "test-ns")
defer cleanup()
ctx := context.Background()
topic := "my-topic"
expectedNamespacedTopic := "test-ns.my-topic"
// Subscribe
err := mgr.Subscribe(ctx, topic, func(t string, d []byte) error { return nil })
if err != nil {
t.Fatalf("Subscribe failed: %v", err)
}
mgr.mu.RLock()
_, exists := mgr.subscriptions[expectedNamespacedTopic]
mgr.mu.RUnlock()
if !exists {
t.Errorf("expected subscription for %s to exist", expectedNamespacedTopic)
}
// Test override
overrideNS := "other-ns"
overrideCtx := context.WithValue(ctx, CtxKeyNamespaceOverride, overrideNS)
expectedOverrideTopic := "other-ns.my-topic"
err = mgr.Subscribe(overrideCtx, topic, func(t string, d []byte) error { return nil })
if err != nil {
t.Fatalf("Subscribe with override failed: %v", err)
}
mgr.mu.RLock()
_, exists = mgr.subscriptions[expectedOverrideTopic]
mgr.mu.RUnlock()
if !exists {
t.Errorf("expected subscription for %s to exist", expectedOverrideTopic)
}
// Test ListTopics
topics, err := mgr.ListTopics(ctx)
if err != nil {
t.Fatalf("ListTopics failed: %v", err)
}
if len(topics) != 1 || topics[0] != "my-topic" {
t.Errorf("expected 1 topic [my-topic], got %v", topics)
}
topicsOverride, err := mgr.ListTopics(overrideCtx)
if err != nil {
t.Fatalf("ListTopics with override failed: %v", err)
}
if len(topicsOverride) != 1 || topicsOverride[0] != "my-topic" {
t.Errorf("expected 1 topic [my-topic] with override, got %v", topicsOverride)
}
}
func TestManager_RefCount(t *testing.T) {
mgr, cleanup := createTestManager(t, "test-ns")
defer cleanup()
ctx := context.Background()
topic := "ref-topic"
namespacedTopic := "test-ns.ref-topic"
h1 := func(t string, d []byte) error { return nil }
h2 := func(t string, d []byte) error { return nil }
// First subscription
err := mgr.Subscribe(ctx, topic, h1)
if err != nil {
t.Fatalf("first subscribe failed: %v", err)
}
mgr.mu.RLock()
ts := mgr.subscriptions[namespacedTopic]
mgr.mu.RUnlock()
if ts.refCount != 1 {
t.Errorf("expected refCount 1, got %d", ts.refCount)
}
// Second subscription
err = mgr.Subscribe(ctx, topic, h2)
if err != nil {
t.Fatalf("second subscribe failed: %v", err)
}
if ts.refCount != 2 {
t.Errorf("expected refCount 2, got %d", ts.refCount)
}
// Unsubscribe one
err = mgr.Unsubscribe(ctx, topic)
if err != nil {
t.Fatalf("unsubscribe 1 failed: %v", err)
}
if ts.refCount != 1 {
t.Errorf("expected refCount 1 after one unsubscribe, got %d", ts.refCount)
}
mgr.mu.RLock()
_, exists := mgr.subscriptions[namespacedTopic]
mgr.mu.RUnlock()
if !exists {
t.Error("expected subscription to still exist")
}
// Unsubscribe second
err = mgr.Unsubscribe(ctx, topic)
if err != nil {
t.Fatalf("unsubscribe 2 failed: %v", err)
}
mgr.mu.RLock()
_, exists = mgr.subscriptions[namespacedTopic]
mgr.mu.RUnlock()
if exists {
t.Error("expected subscription to be removed")
}
}
func TestManager_PubSub(t *testing.T) {
// For a real pubsub test between two managers, we need them to be connected
ctx := context.Background()
h1, _ := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
ps1, _ := pubsub.NewGossipSub(ctx, h1)
mgr1 := NewManager(ps1, "test")
defer h1.Close()
defer mgr1.Close()
h2, _ := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
ps2, _ := pubsub.NewGossipSub(ctx, h2)
mgr2 := NewManager(ps2, "test")
defer h2.Close()
defer mgr2.Close()
// Connect hosts
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
if err != nil {
t.Fatalf("failed to connect hosts: %v", err)
}
topic := "chat"
msgData := []byte("hello world")
received := make(chan []byte, 1)
err = mgr2.Subscribe(ctx, topic, func(t string, d []byte) error {
received <- d
return nil
})
if err != nil {
t.Fatalf("mgr2 subscribe failed: %v", err)
}
// Wait for mesh to form (mgr1 needs to know about mgr2's subscription)
// In a real network this happens via gossip. We'll just retry publish.
timeout := time.After(5 * time.Second)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
Loop:
for {
select {
case <-timeout:
t.Fatal("timed out waiting for message")
case <-ticker.C:
_ = mgr1.Publish(ctx, topic, msgData)
case data := <-received:
if string(data) != string(msgData) {
t.Errorf("expected %s, got %s", string(msgData), string(data))
}
break Loop
}
}
}

View File

@ -595,10 +595,19 @@ func setReflectValue(field reflect.Value, raw any) error {
switch v := raw.(type) { switch v := raw.(type) {
case int64: case int64:
field.SetInt(v) field.SetInt(v)
case float64:
// RQLite/JSON returns numbers as float64
field.SetInt(int64(v))
case int:
field.SetInt(int64(v))
case []byte: case []byte:
var n int64 var n int64
fmt.Sscan(string(v), &n) fmt.Sscan(string(v), &n)
field.SetInt(n) field.SetInt(n)
case string:
var n int64
fmt.Sscan(v, &n)
field.SetInt(n)
default: default:
return fmt.Errorf("cannot convert %T to int", raw) return fmt.Errorf("cannot convert %T to int", raw)
} }
@ -609,10 +618,22 @@ func setReflectValue(field reflect.Value, raw any) error {
v = 0 v = 0
} }
field.SetUint(uint64(v)) field.SetUint(uint64(v))
case float64:
// RQLite/JSON returns numbers as float64
if v < 0 {
v = 0
}
field.SetUint(uint64(v))
case uint64:
field.SetUint(v)
case []byte: case []byte:
var n uint64 var n uint64
fmt.Sscan(string(v), &n) fmt.Sscan(string(v), &n)
field.SetUint(n) field.SetUint(n)
case string:
var n uint64
fmt.Sscan(v, &n)
field.SetUint(n)
default: default:
return fmt.Errorf("cannot convert %T to uint", raw) return fmt.Errorf("cannot convert %T to uint", raw)
} }
@ -628,11 +649,16 @@ func setReflectValue(field reflect.Value, raw any) error {
return fmt.Errorf("cannot convert %T to float", raw) return fmt.Errorf("cannot convert %T to float", raw)
} }
case reflect.Struct: case reflect.Struct:
// Support time.Time; extend as needed. // Support time.Time
if field.Type() == reflect.TypeOf(time.Time{}) { if field.Type() == reflect.TypeOf(time.Time{}) {
switch v := raw.(type) { switch v := raw.(type) {
case time.Time: case time.Time:
field.Set(reflect.ValueOf(v)) field.Set(reflect.ValueOf(v))
case string:
// Try RFC3339
if tt, err := time.Parse(time.RFC3339, v); err == nil {
field.Set(reflect.ValueOf(tt))
}
case []byte: case []byte:
// Try RFC3339 // Try RFC3339
if tt, err := time.Parse(time.RFC3339, string(v)); err == nil { if tt, err := time.Parse(time.RFC3339, string(v)); err == nil {
@ -641,6 +667,68 @@ func setReflectValue(field reflect.Value, raw any) error {
} }
return nil return nil
} }
// Support sql.NullString
if field.Type() == reflect.TypeOf(sql.NullString{}) {
ns := sql.NullString{}
switch v := raw.(type) {
case string:
ns.String = v
ns.Valid = true
case []byte:
ns.String = string(v)
ns.Valid = true
}
field.Set(reflect.ValueOf(ns))
return nil
}
// Support sql.NullInt64
if field.Type() == reflect.TypeOf(sql.NullInt64{}) {
ni := sql.NullInt64{}
switch v := raw.(type) {
case int64:
ni.Int64 = v
ni.Valid = true
case float64:
ni.Int64 = int64(v)
ni.Valid = true
case int:
ni.Int64 = int64(v)
ni.Valid = true
}
field.Set(reflect.ValueOf(ni))
return nil
}
// Support sql.NullBool
if field.Type() == reflect.TypeOf(sql.NullBool{}) {
nb := sql.NullBool{}
switch v := raw.(type) {
case bool:
nb.Bool = v
nb.Valid = true
case int64:
nb.Bool = v != 0
nb.Valid = true
case float64:
nb.Bool = v != 0
nb.Valid = true
}
field.Set(reflect.ValueOf(nb))
return nil
}
// Support sql.NullFloat64
if field.Type() == reflect.TypeOf(sql.NullFloat64{}) {
nf := sql.NullFloat64{}
switch v := raw.(type) {
case float64:
nf.Float64 = v
nf.Valid = true
case int64:
nf.Float64 = float64(v)
nf.Valid = true
}
field.Set(reflect.ValueOf(nf))
return nil
}
fallthrough fallthrough
default: default:
// Not supported yet // Not supported yet

301
pkg/rqlite/cluster.go Normal file
View File

@ -0,0 +1,301 @@
package rqlite
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
// establishLeadershipOrJoin handles post-startup cluster establishment
func (r *RQLiteManager) establishLeadershipOrJoin(ctx context.Context, rqliteDataDir string) error {
timeout := 5 * time.Minute
if r.config.RQLiteJoinAddress == "" {
timeout = 2 * time.Minute
}
sqlCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
if err := r.waitForSQLAvailable(sqlCtx); err != nil {
if r.cmd != nil && r.cmd.Process != nil {
_ = r.cmd.Process.Kill()
}
return err
}
return nil
}
// waitForMinClusterSizeBeforeStart waits for minimum cluster size to be discovered
func (r *RQLiteManager) waitForMinClusterSizeBeforeStart(ctx context.Context, rqliteDataDir string) error {
if r.discoveryService == nil {
return fmt.Errorf("discovery service not available")
}
requiredRemotePeers := r.config.MinClusterSize - 1
_ = r.discoveryService.TriggerPeerExchange(ctx)
checkInterval := 2 * time.Second
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
r.discoveryService.TriggerSync()
time.Sleep(checkInterval)
allPeers := r.discoveryService.GetAllPeers()
remotePeerCount := 0
for _, peer := range allPeers {
if peer.NodeID != r.discoverConfig.RaftAdvAddress {
remotePeerCount++
}
}
if remotePeerCount >= requiredRemotePeers {
peersPath := filepath.Join(rqliteDataDir, "raft", "peers.json")
r.discoveryService.TriggerSync()
time.Sleep(2 * time.Second)
if info, err := os.Stat(peersPath); err == nil && info.Size() > 10 {
data, err := os.ReadFile(peersPath)
if err == nil {
var peers []map[string]interface{}
if err := json.Unmarshal(data, &peers); err == nil && len(peers) >= requiredRemotePeers {
return nil
}
}
}
}
}
}
// performPreStartClusterDiscovery builds peers.json before starting RQLite
func (r *RQLiteManager) performPreStartClusterDiscovery(ctx context.Context, rqliteDataDir string) error {
if r.discoveryService == nil {
return fmt.Errorf("discovery service not available")
}
_ = r.discoveryService.TriggerPeerExchange(ctx)
time.Sleep(1 * time.Second)
r.discoveryService.TriggerSync()
time.Sleep(2 * time.Second)
discoveryDeadline := time.Now().Add(30 * time.Second)
var discoveredPeers int
for time.Now().Before(discoveryDeadline) {
allPeers := r.discoveryService.GetAllPeers()
discoveredPeers = len(allPeers)
if discoveredPeers >= r.config.MinClusterSize {
break
}
time.Sleep(2 * time.Second)
}
if discoveredPeers <= 1 {
return nil
}
if r.hasExistingRaftState(rqliteDataDir) {
ourLogIndex := r.getRaftLogIndex()
maxPeerIndex := uint64(0)
for _, peer := range r.discoveryService.GetAllPeers() {
if peer.NodeID != r.discoverConfig.RaftAdvAddress && peer.RaftLogIndex > maxPeerIndex {
maxPeerIndex = peer.RaftLogIndex
}
}
if ourLogIndex == 0 && maxPeerIndex > 0 {
_ = r.clearRaftState(rqliteDataDir)
_ = r.discoveryService.ForceWritePeersJSON()
}
}
r.discoveryService.TriggerSync()
time.Sleep(2 * time.Second)
return nil
}
// recoverCluster restarts RQLite using peers.json
func (r *RQLiteManager) recoverCluster(ctx context.Context, peersJSONPath string) error {
_ = r.Stop()
time.Sleep(2 * time.Second)
rqliteDataDir, err := r.rqliteDataDirPath()
if err != nil {
return err
}
if err := r.launchProcess(ctx, rqliteDataDir); err != nil {
return err
}
return r.waitForReadyAndConnect(ctx)
}
// recoverFromSplitBrain automatically recovers from split-brain state
func (r *RQLiteManager) recoverFromSplitBrain(ctx context.Context) error {
if r.discoveryService == nil {
return fmt.Errorf("discovery service not available")
}
r.discoveryService.TriggerPeerExchange(ctx)
time.Sleep(2 * time.Second)
r.discoveryService.TriggerSync()
time.Sleep(2 * time.Second)
rqliteDataDir, _ := r.rqliteDataDirPath()
ourIndex := r.getRaftLogIndex()
maxPeerIndex := uint64(0)
for _, peer := range r.discoveryService.GetAllPeers() {
if peer.NodeID != r.discoverConfig.RaftAdvAddress && peer.RaftLogIndex > maxPeerIndex {
maxPeerIndex = peer.RaftLogIndex
}
}
if ourIndex == 0 && maxPeerIndex > 0 {
_ = r.clearRaftState(rqliteDataDir)
r.discoveryService.TriggerPeerExchange(ctx)
time.Sleep(1 * time.Second)
_ = r.discoveryService.ForceWritePeersJSON()
return r.recoverCluster(ctx, filepath.Join(rqliteDataDir, "raft", "peers.json"))
}
return nil
}
// isInSplitBrainState detects if we're in a split-brain scenario
func (r *RQLiteManager) isInSplitBrainState() bool {
status, err := r.getRQLiteStatus()
if err != nil || r.discoveryService == nil {
return false
}
raft := status.Store.Raft
if raft.State == "Follower" && raft.Term == 0 && raft.NumPeers == 0 && !raft.Voter {
peers := r.discoveryService.GetActivePeers()
if len(peers) == 0 {
return false
}
reachableCount := 0
splitBrainCount := 0
for _, peer := range peers {
if r.isPeerReachable(peer.HTTPAddress) {
reachableCount++
peerStatus, err := r.getPeerRQLiteStatus(peer.HTTPAddress)
if err == nil {
praft := peerStatus.Store.Raft
if praft.State == "Follower" && praft.Term == 0 && praft.NumPeers == 0 && !praft.Voter {
splitBrainCount++
}
}
}
}
return reachableCount > 0 && splitBrainCount == reachableCount
}
return false
}
func (r *RQLiteManager) isPeerReachable(httpAddr string) bool {
client := &http.Client{Timeout: 3 * time.Second}
resp, err := client.Get(fmt.Sprintf("http://%s/status", httpAddr))
if err == nil {
resp.Body.Close()
return resp.StatusCode == http.StatusOK
}
return false
}
func (r *RQLiteManager) getPeerRQLiteStatus(httpAddr string) (*RQLiteStatus, error) {
client := &http.Client{Timeout: 3 * time.Second}
resp, err := client.Get(fmt.Sprintf("http://%s/status", httpAddr))
if err != nil {
return nil, err
}
defer resp.Body.Close()
var status RQLiteStatus
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
return nil, err
}
return &status, nil
}
func (r *RQLiteManager) startHealthMonitoring(ctx context.Context) {
time.Sleep(30 * time.Second)
ticker := time.NewTicker(60 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if r.isInSplitBrainState() {
_ = r.recoverFromSplitBrain(ctx)
}
}
}
}
// checkNeedsClusterRecovery checks if the node has old cluster state that requires coordinated recovery
func (r *RQLiteManager) checkNeedsClusterRecovery(rqliteDataDir string) (bool, error) {
snapshotsDir := filepath.Join(rqliteDataDir, "rsnapshots")
if _, err := os.Stat(snapshotsDir); os.IsNotExist(err) {
return false, nil
}
entries, err := os.ReadDir(snapshotsDir)
if err != nil {
return false, err
}
hasSnapshots := false
for _, entry := range entries {
if entry.IsDir() || strings.HasSuffix(entry.Name(), ".db") {
hasSnapshots = true
break
}
}
if !hasSnapshots {
return false, nil
}
raftLogPath := filepath.Join(rqliteDataDir, "raft.db")
if info, err := os.Stat(raftLogPath); err == nil {
if info.Size() <= 8*1024*1024 {
return true, nil
}
}
return false, nil
}
func (r *RQLiteManager) hasExistingRaftState(rqliteDataDir string) bool {
raftLogPath := filepath.Join(rqliteDataDir, "raft.db")
if info, err := os.Stat(raftLogPath); err == nil && info.Size() > 1024 {
return true
}
peersPath := filepath.Join(rqliteDataDir, "raft", "peers.json")
_, err := os.Stat(peersPath)
return err == nil
}
func (r *RQLiteManager) clearRaftState(rqliteDataDir string) error {
_ = os.Remove(filepath.Join(rqliteDataDir, "raft.db"))
_ = os.Remove(filepath.Join(rqliteDataDir, "raft", "peers.json"))
return nil
}

View File

@ -2,20 +2,12 @@ package rqlite
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net"
"net/netip"
"os"
"path/filepath"
"strings"
"sync" "sync"
"time" "time"
"github.com/DeBrosOfficial/network/pkg/discovery" "github.com/DeBrosOfficial/network/pkg/discovery"
"github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -160,855 +152,3 @@ func (c *ClusterDiscoveryService) periodicCleanup(ctx context.Context) {
} }
} }
} }
// collectPeerMetadata collects RQLite metadata from LibP2P peers
func (c *ClusterDiscoveryService) collectPeerMetadata() []*discovery.RQLiteNodeMetadata {
connectedPeers := c.host.Network().Peers()
var metadata []*discovery.RQLiteNodeMetadata
// Metadata collection is routine - no need to log every occurrence
c.mu.RLock()
currentRaftAddr := c.raftAddress
currentHTTPAddr := c.httpAddress
c.mu.RUnlock()
// Add ourselves
ourMetadata := &discovery.RQLiteNodeMetadata{
NodeID: currentRaftAddr, // RQLite uses raft address as node ID
RaftAddress: currentRaftAddr,
HTTPAddress: currentHTTPAddr,
NodeType: c.nodeType,
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
LastSeen: time.Now(),
ClusterVersion: "1.0",
}
if c.adjustSelfAdvertisedAddresses(ourMetadata) {
c.logger.Debug("Adjusted self-advertised RQLite addresses",
zap.String("raft_address", ourMetadata.RaftAddress),
zap.String("http_address", ourMetadata.HTTPAddress))
}
metadata = append(metadata, ourMetadata)
staleNodeIDs := make([]string, 0)
// Query connected peers for their RQLite metadata
// For now, we'll use a simple approach - store metadata in peer metadata store
// In a full implementation, this would use a custom protocol to exchange RQLite metadata
for _, peerID := range connectedPeers {
// Try to get stored metadata from peerstore
// This would be populated by a peer exchange protocol
if val, err := c.host.Peerstore().Get(peerID, "rqlite_metadata"); err == nil {
if jsonData, ok := val.([]byte); ok {
var peerMeta discovery.RQLiteNodeMetadata
if err := json.Unmarshal(jsonData, &peerMeta); err == nil {
if updated, stale := c.adjustPeerAdvertisedAddresses(peerID, &peerMeta); updated && stale != "" {
staleNodeIDs = append(staleNodeIDs, stale)
}
peerMeta.LastSeen = time.Now()
metadata = append(metadata, &peerMeta)
}
}
}
}
// Clean up stale entries if NodeID changed
if len(staleNodeIDs) > 0 {
c.mu.Lock()
for _, id := range staleNodeIDs {
delete(c.knownPeers, id)
delete(c.peerHealth, id)
}
c.mu.Unlock()
}
return metadata
}
// membershipUpdateResult contains the result of a membership update operation
type membershipUpdateResult struct {
peersJSON []map[string]interface{}
added []string
updated []string
changed bool
}
// updateClusterMembership updates the cluster membership based on discovered peers
func (c *ClusterDiscoveryService) updateClusterMembership() {
metadata := c.collectPeerMetadata()
// Compute membership changes while holding lock
c.mu.Lock()
result := c.computeMembershipChangesLocked(metadata)
c.mu.Unlock()
// Perform file I/O outside the lock
if result.changed {
// Log state changes (peer added/removed) at Info level
if len(result.added) > 0 || len(result.updated) > 0 {
c.logger.Info("Membership changed",
zap.Int("added", len(result.added)),
zap.Int("updated", len(result.updated)),
zap.Strings("added", result.added),
zap.Strings("updated", result.updated))
}
// Write peers.json without holding lock
if err := c.writePeersJSONWithData(result.peersJSON); err != nil {
c.logger.Error("Failed to write peers.json",
zap.Error(err),
zap.String("data_dir", c.dataDir),
zap.Int("peers", len(result.peersJSON)))
} else {
c.logger.Debug("peers.json updated",
zap.Int("peers", len(result.peersJSON)))
}
// Update lastUpdate timestamp
c.mu.Lock()
c.lastUpdate = time.Now()
c.mu.Unlock()
}
// No changes - don't log (reduces noise)
}
// computeMembershipChangesLocked computes membership changes and returns snapshot data
// Must be called with lock held
func (c *ClusterDiscoveryService) computeMembershipChangesLocked(metadata []*discovery.RQLiteNodeMetadata) membershipUpdateResult {
// Track changes
added := []string{}
updated := []string{}
// Update known peers, but skip self for health tracking
for _, meta := range metadata {
// Skip self-metadata for health tracking (we only track remote peers)
isSelf := meta.NodeID == c.raftAddress
if existing, ok := c.knownPeers[meta.NodeID]; ok {
// Update existing peer
if existing.RaftLogIndex != meta.RaftLogIndex ||
existing.HTTPAddress != meta.HTTPAddress ||
existing.RaftAddress != meta.RaftAddress {
updated = append(updated, meta.NodeID)
}
} else {
// New peer discovered
added = append(added, meta.NodeID)
c.logger.Info("Node added",
zap.String("node", meta.NodeID),
zap.String("raft", meta.RaftAddress),
zap.String("type", meta.NodeType),
zap.Uint64("log_index", meta.RaftLogIndex))
}
c.knownPeers[meta.NodeID] = meta
// Update health tracking only for remote peers
if !isSelf {
if _, ok := c.peerHealth[meta.NodeID]; !ok {
c.peerHealth[meta.NodeID] = &PeerHealth{
LastSeen: time.Now(),
LastSuccessful: time.Now(),
Status: "active",
}
} else {
c.peerHealth[meta.NodeID].LastSeen = time.Now()
c.peerHealth[meta.NodeID].Status = "active"
c.peerHealth[meta.NodeID].FailureCount = 0
}
}
}
// CRITICAL FIX: Count remote peers (excluding self)
remotePeerCount := 0
for _, peer := range c.knownPeers {
if peer.NodeID != c.raftAddress {
remotePeerCount++
}
}
// Get peers JSON snapshot (for checking if it would be empty)
peers := c.getPeersJSONUnlocked()
// Determine if we should write peers.json
shouldWrite := len(added) > 0 || len(updated) > 0 || c.lastUpdate.IsZero()
// CRITICAL FIX: Don't write peers.json until we have minimum cluster size
// This prevents RQLite from starting as a single-node cluster
// For min_cluster_size=3, we need at least 2 remote peers (plus self = 3 total)
if shouldWrite {
// For initial sync, wait until we have at least (MinClusterSize - 1) remote peers
// This ensures peers.json contains enough peers for proper cluster formation
if c.lastUpdate.IsZero() {
requiredRemotePeers := c.minClusterSize - 1
if remotePeerCount < requiredRemotePeers {
c.logger.Info("Waiting for peers",
zap.Int("have", remotePeerCount),
zap.Int("need", requiredRemotePeers),
zap.Int("min_size", c.minClusterSize))
return membershipUpdateResult{
changed: false,
}
}
}
// Additional safety check: don't write empty peers.json (would cause single-node cluster)
if len(peers) == 0 && c.lastUpdate.IsZero() {
c.logger.Info("No remote peers - waiting")
return membershipUpdateResult{
changed: false,
}
}
// Log initial sync if this is the first time
if c.lastUpdate.IsZero() {
c.logger.Info("Initial sync",
zap.Int("total", len(c.knownPeers)),
zap.Int("remote", remotePeerCount),
zap.Int("in_json", len(peers)))
}
return membershipUpdateResult{
peersJSON: peers,
added: added,
updated: updated,
changed: true,
}
}
return membershipUpdateResult{
changed: false,
}
}
// removeInactivePeers removes peers that haven't been seen for longer than the inactivity limit
func (c *ClusterDiscoveryService) removeInactivePeers() {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
removed := []string{}
for nodeID, health := range c.peerHealth {
inactiveDuration := now.Sub(health.LastSeen)
if inactiveDuration > c.inactivityLimit {
// Mark as inactive and remove
c.logger.Warn("Node removed",
zap.String("node", nodeID),
zap.String("reason", "inactive"),
zap.Duration("inactive_duration", inactiveDuration))
delete(c.knownPeers, nodeID)
delete(c.peerHealth, nodeID)
removed = append(removed, nodeID)
}
}
// Regenerate peers.json if any peers were removed
if len(removed) > 0 {
c.logger.Info("Removed inactive",
zap.Int("count", len(removed)),
zap.Strings("nodes", removed))
if err := c.writePeersJSON(); err != nil {
c.logger.Error("Failed to write peers.json after cleanup", zap.Error(err))
}
}
}
// getPeersJSON generates the peers.json structure from active peers (acquires lock)
func (c *ClusterDiscoveryService) getPeersJSON() []map[string]interface{} {
c.mu.RLock()
defer c.mu.RUnlock()
return c.getPeersJSONUnlocked()
}
// getPeersJSONUnlocked generates the peers.json structure (must be called with lock held)
func (c *ClusterDiscoveryService) getPeersJSONUnlocked() []map[string]interface{} {
peers := make([]map[string]interface{}, 0, len(c.knownPeers))
for _, peer := range c.knownPeers {
// CRITICAL FIX: Include ALL peers (including self) in peers.json
// When using expect configuration with recovery, RQLite needs the complete
// expected cluster configuration to properly form consensus.
// The peers.json file is used by RQLite's recovery mechanism to know
// what the full cluster membership should be, including the local node.
peerEntry := map[string]interface{}{
"id": peer.RaftAddress, // RQLite uses raft address as node ID
"address": peer.RaftAddress,
"non_voter": false,
}
peers = append(peers, peerEntry)
}
return peers
}
// writePeersJSON atomically writes the peers.json file (acquires lock)
func (c *ClusterDiscoveryService) writePeersJSON() error {
c.mu.RLock()
peers := c.getPeersJSONUnlocked()
c.mu.RUnlock()
return c.writePeersJSONWithData(peers)
}
// writePeersJSONWithData writes the peers.json file with provided data (no lock needed)
func (c *ClusterDiscoveryService) writePeersJSONWithData(peers []map[string]interface{}) error {
// Expand ~ in data directory path
dataDir := os.ExpandEnv(c.dataDir)
if strings.HasPrefix(dataDir, "~") {
home, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to determine home directory: %w", err)
}
dataDir = filepath.Join(home, dataDir[1:])
}
// Get the RQLite raft directory
rqliteDir := filepath.Join(dataDir, "rqlite", "raft")
// Writing peers.json - routine operation, no need to log details
if err := os.MkdirAll(rqliteDir, 0755); err != nil {
return fmt.Errorf("failed to create raft directory %s: %w", rqliteDir, err)
}
peersFile := filepath.Join(rqliteDir, "peers.json")
backupFile := filepath.Join(rqliteDir, "peers.json.backup")
// Backup existing peers.json if it exists
if _, err := os.Stat(peersFile); err == nil {
// Backup existing peers.json if it exists - routine operation
data, err := os.ReadFile(peersFile)
if err == nil {
_ = os.WriteFile(backupFile, data, 0644)
}
}
// Marshal to JSON
data, err := json.MarshalIndent(peers, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal peers.json: %w", err)
}
// Marshaled peers.json - routine operation
// Write atomically using temp file + rename
tempFile := peersFile + ".tmp"
if err := os.WriteFile(tempFile, data, 0644); err != nil {
return fmt.Errorf("failed to write temp peers.json %s: %w", tempFile, err)
}
if err := os.Rename(tempFile, peersFile); err != nil {
return fmt.Errorf("failed to rename %s to %s: %w", tempFile, peersFile, err)
}
nodeIDs := make([]string, 0, len(peers))
for _, p := range peers {
if id, ok := p["id"].(string); ok {
nodeIDs = append(nodeIDs, id)
}
}
c.logger.Info("peers.json written",
zap.Int("peers", len(peers)),
zap.Strings("nodes", nodeIDs))
return nil
}
// GetActivePeers returns a list of active peers (not including self)
func (c *ClusterDiscoveryService) GetActivePeers() []*discovery.RQLiteNodeMetadata {
c.mu.RLock()
defer c.mu.RUnlock()
peers := make([]*discovery.RQLiteNodeMetadata, 0, len(c.knownPeers))
for _, peer := range c.knownPeers {
// Skip self (compare by raft address since that's the NodeID now)
if peer.NodeID == c.raftAddress {
continue
}
peers = append(peers, peer)
}
return peers
}
// GetAllPeers returns a list of all known peers (including self)
func (c *ClusterDiscoveryService) GetAllPeers() []*discovery.RQLiteNodeMetadata {
c.mu.RLock()
defer c.mu.RUnlock()
peers := make([]*discovery.RQLiteNodeMetadata, 0, len(c.knownPeers))
for _, peer := range c.knownPeers {
peers = append(peers, peer)
}
return peers
}
// GetNodeWithHighestLogIndex returns the node with the highest Raft log index
func (c *ClusterDiscoveryService) GetNodeWithHighestLogIndex() *discovery.RQLiteNodeMetadata {
c.mu.RLock()
defer c.mu.RUnlock()
var highest *discovery.RQLiteNodeMetadata
var maxIndex uint64 = 0
for _, peer := range c.knownPeers {
// Skip self (compare by raft address since that's the NodeID now)
if peer.NodeID == c.raftAddress {
continue
}
if peer.RaftLogIndex > maxIndex {
maxIndex = peer.RaftLogIndex
highest = peer
}
}
return highest
}
// HasRecentPeersJSON checks if peers.json was recently updated
func (c *ClusterDiscoveryService) HasRecentPeersJSON() bool {
c.mu.RLock()
defer c.mu.RUnlock()
// Consider recent if updated in last 5 minutes
return time.Since(c.lastUpdate) < 5*time.Minute
}
// FindJoinTargets discovers join targets via LibP2P
func (c *ClusterDiscoveryService) FindJoinTargets() []string {
c.mu.RLock()
defer c.mu.RUnlock()
targets := []string{}
// All nodes are equal - prioritize by Raft log index (more advanced = better)
type nodeWithIndex struct {
address string
logIndex uint64
}
var nodes []nodeWithIndex
for _, peer := range c.knownPeers {
nodes = append(nodes, nodeWithIndex{peer.RaftAddress, peer.RaftLogIndex})
}
// Sort by log index descending (higher log index = more up-to-date)
for i := 0; i < len(nodes)-1; i++ {
for j := i + 1; j < len(nodes); j++ {
if nodes[j].logIndex > nodes[i].logIndex {
nodes[i], nodes[j] = nodes[j], nodes[i]
}
}
}
for _, n := range nodes {
targets = append(targets, n.address)
}
return targets
}
// WaitForDiscoverySettling waits for LibP2P discovery to settle (used on concurrent startup)
func (c *ClusterDiscoveryService) WaitForDiscoverySettling(ctx context.Context) {
settleDuration := 60 * time.Second
c.logger.Info("Waiting for discovery to settle",
zap.Duration("duration", settleDuration))
select {
case <-ctx.Done():
return
case <-time.After(settleDuration):
}
// Collect final peer list
c.updateClusterMembership()
c.mu.RLock()
peerCount := len(c.knownPeers)
c.mu.RUnlock()
c.logger.Info("Discovery settled",
zap.Int("peer_count", peerCount))
}
// TriggerSync manually triggers a cluster membership sync
func (c *ClusterDiscoveryService) TriggerSync() {
// All nodes use the same discovery timing for consistency
c.updateClusterMembership()
}
// ForceWritePeersJSON forces writing peers.json regardless of membership changes
// This is useful after clearing raft state when we need to recreate peers.json
func (c *ClusterDiscoveryService) ForceWritePeersJSON() error {
c.logger.Info("Force writing peers.json")
// First, collect latest peer metadata to ensure we have current information
metadata := c.collectPeerMetadata()
// Update known peers with latest metadata (without writing file yet)
c.mu.Lock()
for _, meta := range metadata {
c.knownPeers[meta.NodeID] = meta
// Update health tracking for remote peers
if meta.NodeID != c.raftAddress {
if _, ok := c.peerHealth[meta.NodeID]; !ok {
c.peerHealth[meta.NodeID] = &PeerHealth{
LastSeen: time.Now(),
LastSuccessful: time.Now(),
Status: "active",
}
} else {
c.peerHealth[meta.NodeID].LastSeen = time.Now()
c.peerHealth[meta.NodeID].Status = "active"
}
}
}
peers := c.getPeersJSONUnlocked()
c.mu.Unlock()
// Now force write the file
if err := c.writePeersJSONWithData(peers); err != nil {
c.logger.Error("Failed to force write peers.json",
zap.Error(err),
zap.String("data_dir", c.dataDir),
zap.Int("peers", len(peers)))
return err
}
c.logger.Info("peers.json written",
zap.Int("peers", len(peers)))
return nil
}
// TriggerPeerExchange actively exchanges peer information with connected peers
// This populates the peerstore with RQLite metadata from other nodes
func (c *ClusterDiscoveryService) TriggerPeerExchange(ctx context.Context) error {
if c.discoveryMgr == nil {
return fmt.Errorf("discovery manager not available")
}
collected := c.discoveryMgr.TriggerPeerExchange(ctx)
c.logger.Debug("Exchange completed", zap.Int("with_metadata", collected))
return nil
}
// UpdateOwnMetadata updates our own RQLite metadata in the peerstore
func (c *ClusterDiscoveryService) UpdateOwnMetadata() {
c.mu.RLock()
currentRaftAddr := c.raftAddress
currentHTTPAddr := c.httpAddress
c.mu.RUnlock()
metadata := &discovery.RQLiteNodeMetadata{
NodeID: currentRaftAddr, // RQLite uses raft address as node ID
RaftAddress: currentRaftAddr,
HTTPAddress: currentHTTPAddr,
NodeType: c.nodeType,
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
LastSeen: time.Now(),
ClusterVersion: "1.0",
}
// Adjust addresses if needed
if c.adjustSelfAdvertisedAddresses(metadata) {
c.logger.Debug("Adjusted self-advertised RQLite addresses in UpdateOwnMetadata",
zap.String("raft_address", metadata.RaftAddress),
zap.String("http_address", metadata.HTTPAddress))
}
// Store in our own peerstore for peer exchange
data, err := json.Marshal(metadata)
if err != nil {
c.logger.Error("Failed to marshal own metadata", zap.Error(err))
return
}
if err := c.host.Peerstore().Put(c.host.ID(), "rqlite_metadata", data); err != nil {
c.logger.Error("Failed to store own metadata", zap.Error(err))
return
}
c.logger.Debug("Metadata updated",
zap.String("node", metadata.NodeID),
zap.Uint64("log_index", metadata.RaftLogIndex))
}
// StoreRemotePeerMetadata stores metadata received from a remote peer
func (c *ClusterDiscoveryService) StoreRemotePeerMetadata(peerID peer.ID, metadata *discovery.RQLiteNodeMetadata) error {
if metadata == nil {
return fmt.Errorf("metadata is nil")
}
// Adjust addresses if needed (replace localhost with actual IP)
if updated, stale := c.adjustPeerAdvertisedAddresses(peerID, metadata); updated && stale != "" {
// Clean up stale entry if NodeID changed
c.mu.Lock()
delete(c.knownPeers, stale)
delete(c.peerHealth, stale)
c.mu.Unlock()
}
metadata.LastSeen = time.Now()
data, err := json.Marshal(metadata)
if err != nil {
return fmt.Errorf("failed to marshal metadata: %w", err)
}
if err := c.host.Peerstore().Put(peerID, "rqlite_metadata", data); err != nil {
return fmt.Errorf("failed to store metadata: %w", err)
}
c.logger.Debug("Metadata stored",
zap.String("peer", shortPeerID(peerID)),
zap.String("node", metadata.NodeID))
return nil
}
// adjustPeerAdvertisedAddresses adjusts peer metadata addresses by replacing localhost/loopback
// with the actual IP address from LibP2P connection. Returns (updated, staleNodeID).
// staleNodeID is non-empty if NodeID changed (indicating old entry should be cleaned up).
func (c *ClusterDiscoveryService) adjustPeerAdvertisedAddresses(peerID peer.ID, meta *discovery.RQLiteNodeMetadata) (bool, string) {
ip := c.selectPeerIP(peerID)
if ip == "" {
return false, ""
}
changed, stale := rewriteAdvertisedAddresses(meta, ip, true)
if changed {
c.logger.Debug("Addresses normalized",
zap.String("peer", shortPeerID(peerID)),
zap.String("raft", meta.RaftAddress),
zap.String("http_address", meta.HTTPAddress))
}
return changed, stale
}
// adjustSelfAdvertisedAddresses adjusts our own metadata addresses by replacing localhost/loopback
// with the actual IP address from LibP2P host. Updates internal state if changed.
func (c *ClusterDiscoveryService) adjustSelfAdvertisedAddresses(meta *discovery.RQLiteNodeMetadata) bool {
ip := c.selectSelfIP()
if ip == "" {
return false
}
changed, _ := rewriteAdvertisedAddresses(meta, ip, true)
if !changed {
return false
}
// Update internal state with corrected addresses
c.mu.Lock()
c.raftAddress = meta.RaftAddress
c.httpAddress = meta.HTTPAddress
c.mu.Unlock()
if c.rqliteManager != nil {
c.rqliteManager.UpdateAdvertisedAddresses(meta.RaftAddress, meta.HTTPAddress)
}
return true
}
// selectPeerIP selects the best IP address for a peer from LibP2P connections.
// Prefers public IPs, falls back to private IPs if no public IP is available.
func (c *ClusterDiscoveryService) selectPeerIP(peerID peer.ID) string {
var fallback string
// First, try to get IP from active connections
for _, conn := range c.host.Network().ConnsToPeer(peerID) {
if ip, public := ipFromMultiaddr(conn.RemoteMultiaddr()); ip != "" {
if shouldReplaceHost(ip) {
continue
}
if public {
return ip
}
if fallback == "" {
fallback = ip
}
}
}
// Fallback to peerstore addresses
for _, addr := range c.host.Peerstore().Addrs(peerID) {
if ip, public := ipFromMultiaddr(addr); ip != "" {
if shouldReplaceHost(ip) {
continue
}
if public {
return ip
}
if fallback == "" {
fallback = ip
}
}
}
return fallback
}
// selectSelfIP selects the best IP address for ourselves from LibP2P host addresses.
// Prefers public IPs, falls back to private IPs if no public IP is available.
func (c *ClusterDiscoveryService) selectSelfIP() string {
var fallback string
for _, addr := range c.host.Addrs() {
if ip, public := ipFromMultiaddr(addr); ip != "" {
if shouldReplaceHost(ip) {
continue
}
if public {
return ip
}
if fallback == "" {
fallback = ip
}
}
}
return fallback
}
// rewriteAdvertisedAddresses rewrites RaftAddress and HTTPAddress in metadata,
// replacing localhost/loopback addresses with the provided IP.
// Returns (changed, staleNodeID). staleNodeID is non-empty if NodeID changed.
func rewriteAdvertisedAddresses(meta *discovery.RQLiteNodeMetadata, newHost string, allowNodeIDRewrite bool) (bool, string) {
if meta == nil || newHost == "" {
return false, ""
}
originalNodeID := meta.NodeID
changed := false
nodeIDChanged := false
// Replace host in RaftAddress if it's localhost/loopback
if newAddr, replaced := replaceAddressHost(meta.RaftAddress, newHost); replaced {
if meta.RaftAddress != newAddr {
meta.RaftAddress = newAddr
changed = true
}
}
// Replace host in HTTPAddress if it's localhost/loopback
if newAddr, replaced := replaceAddressHost(meta.HTTPAddress, newHost); replaced {
if meta.HTTPAddress != newAddr {
meta.HTTPAddress = newAddr
changed = true
}
}
// Update NodeID to match RaftAddress if it changed
if allowNodeIDRewrite {
if meta.RaftAddress != "" && (meta.NodeID == "" || meta.NodeID == originalNodeID || shouldReplaceHost(hostFromAddress(meta.NodeID))) {
if meta.NodeID != meta.RaftAddress {
meta.NodeID = meta.RaftAddress
nodeIDChanged = meta.NodeID != originalNodeID
if nodeIDChanged {
changed = true
}
}
}
}
if nodeIDChanged {
return changed, originalNodeID
}
return changed, ""
}
// replaceAddressHost replaces the host part of an address if it's localhost/loopback.
// Returns (newAddress, replaced). replaced is true if host was replaced.
func replaceAddressHost(address, newHost string) (string, bool) {
if address == "" || newHost == "" {
return address, false
}
host, port, err := net.SplitHostPort(address)
if err != nil {
return address, false
}
if !shouldReplaceHost(host) {
return address, false
}
return net.JoinHostPort(newHost, port), true
}
// shouldReplaceHost returns true if the host should be replaced (localhost, loopback, etc.)
func shouldReplaceHost(host string) bool {
if host == "" {
return true
}
if strings.EqualFold(host, "localhost") {
return true
}
// Check if it's a loopback or unspecified address
if addr, err := netip.ParseAddr(host); err == nil {
if addr.IsLoopback() || addr.IsUnspecified() {
return true
}
}
return false
}
// hostFromAddress extracts the host part from a host:port address
func hostFromAddress(address string) string {
host, _, err := net.SplitHostPort(address)
if err != nil {
return ""
}
return host
}
// ipFromMultiaddr extracts an IP address from a multiaddr and returns (ip, isPublic)
func ipFromMultiaddr(addr multiaddr.Multiaddr) (string, bool) {
if addr == nil {
return "", false
}
if v4, err := addr.ValueForProtocol(multiaddr.P_IP4); err == nil {
return v4, isPublicIP(v4)
}
if v6, err := addr.ValueForProtocol(multiaddr.P_IP6); err == nil {
return v6, isPublicIP(v6)
}
return "", false
}
// isPublicIP returns true if the IP is a public (non-private, non-loopback) address
func isPublicIP(ip string) bool {
addr, err := netip.ParseAddr(ip)
if err != nil {
return false
}
// Exclude loopback, unspecified, link-local, multicast, and private addresses
if addr.IsLoopback() || addr.IsUnspecified() || addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast() || addr.IsPrivate() {
return false
}
return true
}
// shortPeerID returns a shortened version of a peer ID for logging
func shortPeerID(id peer.ID) string {
s := id.String()
if len(s) <= 8 {
return s
}
return s[:8] + "..."
}

View File

@ -0,0 +1,318 @@
package rqlite
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/discovery"
"go.uber.org/zap"
)
// collectPeerMetadata collects RQLite metadata from LibP2P peers
func (c *ClusterDiscoveryService) collectPeerMetadata() []*discovery.RQLiteNodeMetadata {
connectedPeers := c.host.Network().Peers()
var metadata []*discovery.RQLiteNodeMetadata
c.mu.RLock()
currentRaftAddr := c.raftAddress
currentHTTPAddr := c.httpAddress
c.mu.RUnlock()
// Add ourselves
ourMetadata := &discovery.RQLiteNodeMetadata{
NodeID: currentRaftAddr, // RQLite uses raft address as node ID
RaftAddress: currentRaftAddr,
HTTPAddress: currentHTTPAddr,
NodeType: c.nodeType,
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
LastSeen: time.Now(),
ClusterVersion: "1.0",
}
if c.adjustSelfAdvertisedAddresses(ourMetadata) {
c.logger.Debug("Adjusted self-advertised RQLite addresses",
zap.String("raft_address", ourMetadata.RaftAddress),
zap.String("http_address", ourMetadata.HTTPAddress))
}
metadata = append(metadata, ourMetadata)
staleNodeIDs := make([]string, 0)
for _, peerID := range connectedPeers {
if val, err := c.host.Peerstore().Get(peerID, "rqlite_metadata"); err == nil {
if jsonData, ok := val.([]byte); ok {
var peerMeta discovery.RQLiteNodeMetadata
if err := json.Unmarshal(jsonData, &peerMeta); err == nil {
if updated, stale := c.adjustPeerAdvertisedAddresses(peerID, &peerMeta); updated && stale != "" {
staleNodeIDs = append(staleNodeIDs, stale)
}
peerMeta.LastSeen = time.Now()
metadata = append(metadata, &peerMeta)
}
}
}
}
if len(staleNodeIDs) > 0 {
c.mu.Lock()
for _, id := range staleNodeIDs {
delete(c.knownPeers, id)
delete(c.peerHealth, id)
}
c.mu.Unlock()
}
return metadata
}
type membershipUpdateResult struct {
peersJSON []map[string]interface{}
added []string
updated []string
changed bool
}
func (c *ClusterDiscoveryService) updateClusterMembership() {
metadata := c.collectPeerMetadata()
c.mu.Lock()
result := c.computeMembershipChangesLocked(metadata)
c.mu.Unlock()
if result.changed {
if len(result.added) > 0 || len(result.updated) > 0 {
c.logger.Info("Membership changed",
zap.Int("added", len(result.added)),
zap.Int("updated", len(result.updated)),
zap.Strings("added", result.added),
zap.Strings("updated", result.updated))
}
if err := c.writePeersJSONWithData(result.peersJSON); err != nil {
c.logger.Error("Failed to write peers.json",
zap.Error(err),
zap.String("data_dir", c.dataDir),
zap.Int("peers", len(result.peersJSON)))
} else {
c.logger.Debug("peers.json updated",
zap.Int("peers", len(result.peersJSON)))
}
c.mu.Lock()
c.lastUpdate = time.Now()
c.mu.Unlock()
}
}
func (c *ClusterDiscoveryService) computeMembershipChangesLocked(metadata []*discovery.RQLiteNodeMetadata) membershipUpdateResult {
added := []string{}
updated := []string{}
for _, meta := range metadata {
isSelf := meta.NodeID == c.raftAddress
if existing, ok := c.knownPeers[meta.NodeID]; ok {
if existing.RaftLogIndex != meta.RaftLogIndex ||
existing.HTTPAddress != meta.HTTPAddress ||
existing.RaftAddress != meta.RaftAddress {
updated = append(updated, meta.NodeID)
}
} else {
added = append(added, meta.NodeID)
c.logger.Info("Node added",
zap.String("node", meta.NodeID),
zap.String("raft", meta.RaftAddress),
zap.String("type", meta.NodeType),
zap.Uint64("log_index", meta.RaftLogIndex))
}
c.knownPeers[meta.NodeID] = meta
if !isSelf {
if _, ok := c.peerHealth[meta.NodeID]; !ok {
c.peerHealth[meta.NodeID] = &PeerHealth{
LastSeen: time.Now(),
LastSuccessful: time.Now(),
Status: "active",
}
} else {
c.peerHealth[meta.NodeID].LastSeen = time.Now()
c.peerHealth[meta.NodeID].Status = "active"
c.peerHealth[meta.NodeID].FailureCount = 0
}
}
}
remotePeerCount := 0
for _, peer := range c.knownPeers {
if peer.NodeID != c.raftAddress {
remotePeerCount++
}
}
peers := c.getPeersJSONUnlocked()
shouldWrite := len(added) > 0 || len(updated) > 0 || c.lastUpdate.IsZero()
if shouldWrite {
if c.lastUpdate.IsZero() {
requiredRemotePeers := c.minClusterSize - 1
if remotePeerCount < requiredRemotePeers {
c.logger.Info("Waiting for peers",
zap.Int("have", remotePeerCount),
zap.Int("need", requiredRemotePeers),
zap.Int("min_size", c.minClusterSize))
return membershipUpdateResult{
changed: false,
}
}
}
if len(peers) == 0 && c.lastUpdate.IsZero() {
c.logger.Info("No remote peers - waiting")
return membershipUpdateResult{
changed: false,
}
}
if c.lastUpdate.IsZero() {
c.logger.Info("Initial sync",
zap.Int("total", len(c.knownPeers)),
zap.Int("remote", remotePeerCount),
zap.Int("in_json", len(peers)))
}
return membershipUpdateResult{
peersJSON: peers,
added: added,
updated: updated,
changed: true,
}
}
return membershipUpdateResult{
changed: false,
}
}
func (c *ClusterDiscoveryService) removeInactivePeers() {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
removed := []string{}
for nodeID, health := range c.peerHealth {
inactiveDuration := now.Sub(health.LastSeen)
if inactiveDuration > c.inactivityLimit {
c.logger.Warn("Node removed",
zap.String("node", nodeID),
zap.String("reason", "inactive"),
zap.Duration("inactive_duration", inactiveDuration))
delete(c.knownPeers, nodeID)
delete(c.peerHealth, nodeID)
removed = append(removed, nodeID)
}
}
if len(removed) > 0 {
c.logger.Info("Removed inactive",
zap.Int("count", len(removed)),
zap.Strings("nodes", removed))
if err := c.writePeersJSON(); err != nil {
c.logger.Error("Failed to write peers.json after cleanup", zap.Error(err))
}
}
}
func (c *ClusterDiscoveryService) getPeersJSON() []map[string]interface{} {
c.mu.RLock()
defer c.mu.RUnlock()
return c.getPeersJSONUnlocked()
}
func (c *ClusterDiscoveryService) getPeersJSONUnlocked() []map[string]interface{} {
peers := make([]map[string]interface{}, 0, len(c.knownPeers))
for _, peer := range c.knownPeers {
peerEntry := map[string]interface{}{
"id": peer.RaftAddress,
"address": peer.RaftAddress,
"non_voter": false,
}
peers = append(peers, peerEntry)
}
return peers
}
func (c *ClusterDiscoveryService) writePeersJSON() error {
c.mu.RLock()
peers := c.getPeersJSONUnlocked()
c.mu.RUnlock()
return c.writePeersJSONWithData(peers)
}
func (c *ClusterDiscoveryService) writePeersJSONWithData(peers []map[string]interface{}) error {
dataDir := os.ExpandEnv(c.dataDir)
if strings.HasPrefix(dataDir, "~") {
home, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to determine home directory: %w", err)
}
dataDir = filepath.Join(home, dataDir[1:])
}
rqliteDir := filepath.Join(dataDir, "rqlite", "raft")
if err := os.MkdirAll(rqliteDir, 0755); err != nil {
return fmt.Errorf("failed to create raft directory %s: %w", rqliteDir, err)
}
peersFile := filepath.Join(rqliteDir, "peers.json")
backupFile := filepath.Join(rqliteDir, "peers.json.backup")
if _, err := os.Stat(peersFile); err == nil {
data, err := os.ReadFile(peersFile)
if err == nil {
_ = os.WriteFile(backupFile, data, 0644)
}
}
data, err := json.MarshalIndent(peers, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal peers.json: %w", err)
}
tempFile := peersFile + ".tmp"
if err := os.WriteFile(tempFile, data, 0644); err != nil {
return fmt.Errorf("failed to write temp peers.json %s: %w", tempFile, err)
}
if err := os.Rename(tempFile, peersFile); err != nil {
return fmt.Errorf("failed to rename %s to %s: %w", tempFile, peersFile, err)
}
nodeIDs := make([]string, 0, len(peers))
for _, p := range peers {
if id, ok := p["id"].(string); ok {
nodeIDs = append(nodeIDs, id)
}
}
c.logger.Info("peers.json written",
zap.Int("peers", len(peers)),
zap.Strings("nodes", nodeIDs))
return nil
}

View File

@ -0,0 +1,251 @@
package rqlite
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/DeBrosOfficial/network/pkg/discovery"
"github.com/libp2p/go-libp2p/core/peer"
"go.uber.org/zap"
)
// GetActivePeers returns a list of active peers (not including self)
func (c *ClusterDiscoveryService) GetActivePeers() []*discovery.RQLiteNodeMetadata {
c.mu.RLock()
defer c.mu.RUnlock()
peers := make([]*discovery.RQLiteNodeMetadata, 0, len(c.knownPeers))
for _, peer := range c.knownPeers {
if peer.NodeID == c.raftAddress {
continue
}
peers = append(peers, peer)
}
return peers
}
// GetAllPeers returns a list of all known peers (including self)
func (c *ClusterDiscoveryService) GetAllPeers() []*discovery.RQLiteNodeMetadata {
c.mu.RLock()
defer c.mu.RUnlock()
peers := make([]*discovery.RQLiteNodeMetadata, 0, len(c.knownPeers))
for _, peer := range c.knownPeers {
peers = append(peers, peer)
}
return peers
}
// GetNodeWithHighestLogIndex returns the node with the highest Raft log index
func (c *ClusterDiscoveryService) GetNodeWithHighestLogIndex() *discovery.RQLiteNodeMetadata {
c.mu.RLock()
defer c.mu.RUnlock()
var highest *discovery.RQLiteNodeMetadata
var maxIndex uint64 = 0
for _, peer := range c.knownPeers {
if peer.NodeID == c.raftAddress {
continue
}
if peer.RaftLogIndex > maxIndex {
maxIndex = peer.RaftLogIndex
highest = peer
}
}
return highest
}
// HasRecentPeersJSON checks if peers.json was recently updated
func (c *ClusterDiscoveryService) HasRecentPeersJSON() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return time.Since(c.lastUpdate) < 5*time.Minute
}
// FindJoinTargets discovers join targets via LibP2P
func (c *ClusterDiscoveryService) FindJoinTargets() []string {
c.mu.RLock()
defer c.mu.RUnlock()
targets := []string{}
type nodeWithIndex struct {
address string
logIndex uint64
}
var nodes []nodeWithIndex
for _, peer := range c.knownPeers {
nodes = append(nodes, nodeWithIndex{peer.RaftAddress, peer.RaftLogIndex})
}
for i := 0; i < len(nodes)-1; i++ {
for j := i + 1; j < len(nodes); j++ {
if nodes[j].logIndex > nodes[i].logIndex {
nodes[i], nodes[j] = nodes[j], nodes[i]
}
}
}
for _, n := range nodes {
targets = append(targets, n.address)
}
return targets
}
// WaitForDiscoverySettling waits for LibP2P discovery to settle (used on concurrent startup)
func (c *ClusterDiscoveryService) WaitForDiscoverySettling(ctx context.Context) {
settleDuration := 60 * time.Second
c.logger.Info("Waiting for discovery to settle",
zap.Duration("duration", settleDuration))
select {
case <-ctx.Done():
return
case <-time.After(settleDuration):
}
c.updateClusterMembership()
c.mu.RLock()
peerCount := len(c.knownPeers)
c.mu.RUnlock()
c.logger.Info("Discovery settled",
zap.Int("peer_count", peerCount))
}
// TriggerSync manually triggers a cluster membership sync
func (c *ClusterDiscoveryService) TriggerSync() {
c.updateClusterMembership()
}
// ForceWritePeersJSON forces writing peers.json regardless of membership changes
func (c *ClusterDiscoveryService) ForceWritePeersJSON() error {
c.logger.Info("Force writing peers.json")
metadata := c.collectPeerMetadata()
c.mu.Lock()
for _, meta := range metadata {
c.knownPeers[meta.NodeID] = meta
if meta.NodeID != c.raftAddress {
if _, ok := c.peerHealth[meta.NodeID]; !ok {
c.peerHealth[meta.NodeID] = &PeerHealth{
LastSeen: time.Now(),
LastSuccessful: time.Now(),
Status: "active",
}
} else {
c.peerHealth[meta.NodeID].LastSeen = time.Now()
c.peerHealth[meta.NodeID].Status = "active"
}
}
}
peers := c.getPeersJSONUnlocked()
c.mu.Unlock()
if err := c.writePeersJSONWithData(peers); err != nil {
c.logger.Error("Failed to force write peers.json",
zap.Error(err),
zap.String("data_dir", c.dataDir),
zap.Int("peers", len(peers)))
return err
}
c.logger.Info("peers.json written",
zap.Int("peers", len(peers)))
return nil
}
// TriggerPeerExchange actively exchanges peer information with connected peers
func (c *ClusterDiscoveryService) TriggerPeerExchange(ctx context.Context) error {
if c.discoveryMgr == nil {
return fmt.Errorf("discovery manager not available")
}
collected := c.discoveryMgr.TriggerPeerExchange(ctx)
c.logger.Debug("Exchange completed", zap.Int("with_metadata", collected))
return nil
}
// UpdateOwnMetadata updates our own RQLite metadata in the peerstore
func (c *ClusterDiscoveryService) UpdateOwnMetadata() {
c.mu.RLock()
currentRaftAddr := c.raftAddress
currentHTTPAddr := c.httpAddress
c.mu.RUnlock()
metadata := &discovery.RQLiteNodeMetadata{
NodeID: currentRaftAddr,
RaftAddress: currentRaftAddr,
HTTPAddress: currentHTTPAddr,
NodeType: c.nodeType,
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
LastSeen: time.Now(),
ClusterVersion: "1.0",
}
if c.adjustSelfAdvertisedAddresses(metadata) {
c.logger.Debug("Adjusted self-advertised RQLite addresses in UpdateOwnMetadata",
zap.String("raft_address", metadata.RaftAddress),
zap.String("http_address", metadata.HTTPAddress))
}
data, err := json.Marshal(metadata)
if err != nil {
c.logger.Error("Failed to marshal own metadata", zap.Error(err))
return
}
if err := c.host.Peerstore().Put(c.host.ID(), "rqlite_metadata", data); err != nil {
c.logger.Error("Failed to store own metadata", zap.Error(err))
return
}
c.logger.Debug("Metadata updated",
zap.String("node", metadata.NodeID),
zap.Uint64("log_index", metadata.RaftLogIndex))
}
// StoreRemotePeerMetadata stores metadata received from a remote peer
func (c *ClusterDiscoveryService) StoreRemotePeerMetadata(peerID peer.ID, metadata *discovery.RQLiteNodeMetadata) error {
if metadata == nil {
return fmt.Errorf("metadata is nil")
}
if updated, stale := c.adjustPeerAdvertisedAddresses(peerID, metadata); updated && stale != "" {
c.mu.Lock()
delete(c.knownPeers, stale)
delete(c.peerHealth, stale)
c.mu.Unlock()
}
metadata.LastSeen = time.Now()
data, err := json.Marshal(metadata)
if err != nil {
return fmt.Errorf("failed to marshal metadata: %w", err)
}
if err := c.host.Peerstore().Put(peerID, "rqlite_metadata", data); err != nil {
return fmt.Errorf("failed to store metadata: %w", err)
}
c.logger.Debug("Metadata stored",
zap.String("peer", shortPeerID(peerID)),
zap.String("node", metadata.NodeID))
return nil
}

View File

@ -0,0 +1,97 @@
package rqlite
import (
"testing"
"github.com/DeBrosOfficial/network/pkg/discovery"
)
func TestShouldReplaceHost(t *testing.T) {
tests := []struct {
host string
expected bool
}{
{"", true},
{"localhost", true},
{"127.0.0.1", true},
{"::1", true},
{"0.0.0.0", true},
{"1.1.1.1", false},
{"8.8.8.8", false},
{"example.com", false},
}
for _, tt := range tests {
if got := shouldReplaceHost(tt.host); got != tt.expected {
t.Errorf("shouldReplaceHost(%s) = %v; want %v", tt.host, got, tt.expected)
}
}
}
func TestIsPublicIP(t *testing.T) {
tests := []struct {
ip string
expected bool
}{
{"127.0.0.1", false},
{"192.168.1.1", false},
{"10.0.0.1", false},
{"172.16.0.1", false},
{"1.1.1.1", true},
{"8.8.8.8", true},
{"2001:4860:4860::8888", true},
}
for _, tt := range tests {
if got := isPublicIP(tt.ip); got != tt.expected {
t.Errorf("isPublicIP(%s) = %v; want %v", tt.ip, got, tt.expected)
}
}
}
func TestReplaceAddressHost(t *testing.T) {
tests := []struct {
address string
newHost string
expected string
replaced bool
}{
{"localhost:4001", "1.1.1.1", "1.1.1.1:4001", true},
{"127.0.0.1:4001", "1.1.1.1", "1.1.1.1:4001", true},
{"8.8.8.8:4001", "1.1.1.1", "8.8.8.8:4001", false}, // Don't replace public IP
{"invalid", "1.1.1.1", "invalid", false},
}
for _, tt := range tests {
got, replaced := replaceAddressHost(tt.address, tt.newHost)
if got != tt.expected || replaced != tt.replaced {
t.Errorf("replaceAddressHost(%s, %s) = %s, %v; want %s, %v", tt.address, tt.newHost, got, replaced, tt.expected, tt.replaced)
}
}
}
func TestRewriteAdvertisedAddresses(t *testing.T) {
meta := &discovery.RQLiteNodeMetadata{
NodeID: "localhost:4001",
RaftAddress: "localhost:4001",
HTTPAddress: "localhost:4002",
}
changed, originalNodeID := rewriteAdvertisedAddresses(meta, "1.1.1.1", true)
if !changed {
t.Error("expected changed to be true")
}
if originalNodeID != "localhost:4001" {
t.Errorf("expected originalNodeID localhost:4001, got %s", originalNodeID)
}
if meta.RaftAddress != "1.1.1.1:4001" {
t.Errorf("expected RaftAddress 1.1.1.1:4001, got %s", meta.RaftAddress)
}
if meta.HTTPAddress != "1.1.1.1:4002" {
t.Errorf("expected HTTPAddress 1.1.1.1:4002, got %s", meta.HTTPAddress)
}
if meta.NodeID != "1.1.1.1:4001" {
t.Errorf("expected NodeID 1.1.1.1:4001, got %s", meta.NodeID)
}
}

View File

@ -0,0 +1,233 @@
package rqlite
import (
"net"
"net/netip"
"strings"
"github.com/DeBrosOfficial/network/pkg/discovery"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
"go.uber.org/zap"
)
// adjustPeerAdvertisedAddresses adjusts peer metadata addresses
func (c *ClusterDiscoveryService) adjustPeerAdvertisedAddresses(peerID peer.ID, meta *discovery.RQLiteNodeMetadata) (bool, string) {
ip := c.selectPeerIP(peerID)
if ip == "" {
return false, ""
}
changed, stale := rewriteAdvertisedAddresses(meta, ip, true)
if changed {
c.logger.Debug("Addresses normalized",
zap.String("peer", shortPeerID(peerID)),
zap.String("raft", meta.RaftAddress),
zap.String("http_address", meta.HTTPAddress))
}
return changed, stale
}
// adjustSelfAdvertisedAddresses adjusts our own metadata addresses
func (c *ClusterDiscoveryService) adjustSelfAdvertisedAddresses(meta *discovery.RQLiteNodeMetadata) bool {
ip := c.selectSelfIP()
if ip == "" {
return false
}
changed, _ := rewriteAdvertisedAddresses(meta, ip, true)
if !changed {
return false
}
c.mu.Lock()
c.raftAddress = meta.RaftAddress
c.httpAddress = meta.HTTPAddress
c.mu.Unlock()
if c.rqliteManager != nil {
c.rqliteManager.UpdateAdvertisedAddresses(meta.RaftAddress, meta.HTTPAddress)
}
return true
}
// selectPeerIP selects the best IP address for a peer
func (c *ClusterDiscoveryService) selectPeerIP(peerID peer.ID) string {
var fallback string
for _, conn := range c.host.Network().ConnsToPeer(peerID) {
if ip, public := ipFromMultiaddr(conn.RemoteMultiaddr()); ip != "" {
if shouldReplaceHost(ip) {
continue
}
if public {
return ip
}
if fallback == "" {
fallback = ip
}
}
}
for _, addr := range c.host.Peerstore().Addrs(peerID) {
if ip, public := ipFromMultiaddr(addr); ip != "" {
if shouldReplaceHost(ip) {
continue
}
if public {
return ip
}
if fallback == "" {
fallback = ip
}
}
}
return fallback
}
// selectSelfIP selects the best IP address for ourselves
func (c *ClusterDiscoveryService) selectSelfIP() string {
var fallback string
for _, addr := range c.host.Addrs() {
if ip, public := ipFromMultiaddr(addr); ip != "" {
if shouldReplaceHost(ip) {
continue
}
if public {
return ip
}
if fallback == "" {
fallback = ip
}
}
}
return fallback
}
// rewriteAdvertisedAddresses rewrites RaftAddress and HTTPAddress in metadata
func rewriteAdvertisedAddresses(meta *discovery.RQLiteNodeMetadata, newHost string, allowNodeIDRewrite bool) (bool, string) {
if meta == nil || newHost == "" {
return false, ""
}
originalNodeID := meta.NodeID
changed := false
nodeIDChanged := false
if newAddr, replaced := replaceAddressHost(meta.RaftAddress, newHost); replaced {
if meta.RaftAddress != newAddr {
meta.RaftAddress = newAddr
changed = true
}
}
if newAddr, replaced := replaceAddressHost(meta.HTTPAddress, newHost); replaced {
if meta.HTTPAddress != newAddr {
meta.HTTPAddress = newAddr
changed = true
}
}
if allowNodeIDRewrite {
if meta.RaftAddress != "" && (meta.NodeID == "" || meta.NodeID == originalNodeID || shouldReplaceHost(hostFromAddress(meta.NodeID))) {
if meta.NodeID != meta.RaftAddress {
meta.NodeID = meta.RaftAddress
nodeIDChanged = meta.NodeID != originalNodeID
if nodeIDChanged {
changed = true
}
}
}
}
if nodeIDChanged {
return changed, originalNodeID
}
return changed, ""
}
// replaceAddressHost replaces the host part of an address
func replaceAddressHost(address, newHost string) (string, bool) {
if address == "" || newHost == "" {
return address, false
}
host, port, err := net.SplitHostPort(address)
if err != nil {
return address, false
}
if !shouldReplaceHost(host) {
return address, false
}
return net.JoinHostPort(newHost, port), true
}
// shouldReplaceHost returns true if the host should be replaced
func shouldReplaceHost(host string) bool {
if host == "" {
return true
}
if strings.EqualFold(host, "localhost") {
return true
}
if addr, err := netip.ParseAddr(host); err == nil {
if addr.IsLoopback() || addr.IsUnspecified() {
return true
}
}
return false
}
// hostFromAddress extracts the host part from a host:port address
func hostFromAddress(address string) string {
host, _, err := net.SplitHostPort(address)
if err != nil {
return ""
}
return host
}
// ipFromMultiaddr extracts an IP address from a multiaddr and returns (ip, isPublic)
func ipFromMultiaddr(addr multiaddr.Multiaddr) (string, bool) {
if addr == nil {
return "", false
}
if v4, err := addr.ValueForProtocol(multiaddr.P_IP4); err == nil {
return v4, isPublicIP(v4)
}
if v6, err := addr.ValueForProtocol(multiaddr.P_IP6); err == nil {
return v6, isPublicIP(v6)
}
return "", false
}
// isPublicIP returns true if the IP is a public address
func isPublicIP(ip string) bool {
addr, err := netip.ParseAddr(ip)
if err != nil {
return false
}
if addr.IsLoopback() || addr.IsUnspecified() || addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast() || addr.IsPrivate() {
return false
}
return true
}
// shortPeerID returns a shortened version of a peer ID
func shortPeerID(id peer.ID) string {
s := id.String()
if len(s) <= 8 {
return s
}
return s[:8] + "..."
}

View File

@ -0,0 +1,61 @@
package rqlite
import (
"fmt"
"time"
)
// SetDiscoveryService sets the cluster discovery service
func (r *RQLiteManager) SetDiscoveryService(service *ClusterDiscoveryService) {
r.discoveryService = service
}
// SetNodeType sets the node type
func (r *RQLiteManager) SetNodeType(nodeType string) {
if nodeType != "" {
r.nodeType = nodeType
}
}
// UpdateAdvertisedAddresses overrides advertised addresses
func (r *RQLiteManager) UpdateAdvertisedAddresses(raftAddr, httpAddr string) {
if r == nil || r.discoverConfig == nil {
return
}
if raftAddr != "" && r.discoverConfig.RaftAdvAddress != raftAddr {
r.discoverConfig.RaftAdvAddress = raftAddr
}
if httpAddr != "" && r.discoverConfig.HttpAdvAddress != httpAddr {
r.discoverConfig.HttpAdvAddress = httpAddr
}
}
func (r *RQLiteManager) validateNodeID() error {
for i := 0; i < 5; i++ {
nodes, err := r.getRQLiteNodes()
if err != nil {
if i < 4 {
time.Sleep(500 * time.Millisecond)
continue
}
return nil
}
expectedID := r.discoverConfig.RaftAdvAddress
if expectedID == "" || len(nodes) == 0 {
return nil
}
for _, node := range nodes {
if node.Address == expectedID {
if node.ID != expectedID {
return fmt.Errorf("node ID mismatch: %s != %s", expectedID, node.ID)
}
return nil
}
}
return nil
}
return nil
}

View File

@ -570,9 +570,13 @@ func (g *HTTPGateway) handleDropTable(w http.ResponseWriter, r *http.Request) {
ctx, cancel := g.withTimeout(r.Context()) ctx, cancel := g.withTimeout(r.Context())
defer cancel() defer cancel()
stmt := "DROP TABLE IF EXISTS " + tbl stmt := "DROP TABLE " + tbl
if _, err := g.Client.Exec(ctx, stmt); err != nil { if _, err := g.Client.Exec(ctx, stmt); err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) if strings.Contains(err.Error(), "no such table") {
writeError(w, http.StatusNotFound, err.Error())
} else {
writeError(w, http.StatusInternalServerError, err.Error())
}
return return
} }
writeJSON(w, http.StatusOK, map[string]any{"status": "ok"}) writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})

239
pkg/rqlite/process.go Normal file
View File

@ -0,0 +1,239 @@
package rqlite
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/tlsutil"
"github.com/rqlite/gorqlite"
"go.uber.org/zap"
)
// launchProcess starts the RQLite process with appropriate arguments
func (r *RQLiteManager) launchProcess(ctx context.Context, rqliteDataDir string) error {
// Build RQLite command
args := []string{
"-http-addr", fmt.Sprintf("0.0.0.0:%d", r.config.RQLitePort),
"-http-adv-addr", r.discoverConfig.HttpAdvAddress,
"-raft-adv-addr", r.discoverConfig.RaftAdvAddress,
"-raft-addr", fmt.Sprintf("0.0.0.0:%d", r.config.RQLiteRaftPort),
}
if r.config.NodeCert != "" && r.config.NodeKey != "" {
r.logger.Info("Enabling node-to-node TLS encryption",
zap.String("node_cert", r.config.NodeCert),
zap.String("node_key", r.config.NodeKey))
args = append(args, "-node-cert", r.config.NodeCert)
args = append(args, "-node-key", r.config.NodeKey)
if r.config.NodeCACert != "" {
args = append(args, "-node-ca-cert", r.config.NodeCACert)
}
if r.config.NodeNoVerify {
args = append(args, "-node-no-verify")
}
}
if r.config.RQLiteJoinAddress != "" {
r.logger.Info("Joining RQLite cluster", zap.String("join_address", r.config.RQLiteJoinAddress))
joinArg := r.config.RQLiteJoinAddress
if strings.HasPrefix(joinArg, "http://") {
joinArg = strings.TrimPrefix(joinArg, "http://")
} else if strings.HasPrefix(joinArg, "https://") {
joinArg = strings.TrimPrefix(joinArg, "https://")
}
joinTimeout := 5 * time.Minute
if err := r.waitForJoinTarget(ctx, r.config.RQLiteJoinAddress, joinTimeout); err != nil {
r.logger.Warn("Join target did not become reachable within timeout; will still attempt to join",
zap.Error(err))
}
args = append(args, "-join", joinArg, "-join-as", r.discoverConfig.RaftAdvAddress, "-join-attempts", "30", "-join-interval", "10s")
}
args = append(args, rqliteDataDir)
r.cmd = exec.Command("rqlited", args...)
nodeType := r.nodeType
if nodeType == "" {
nodeType = "node"
}
logsDir := filepath.Join(filepath.Dir(r.dataDir), "logs")
_ = os.MkdirAll(logsDir, 0755)
logPath := filepath.Join(logsDir, fmt.Sprintf("rqlite-%s.log", nodeType))
logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return fmt.Errorf("failed to open log file: %w", err)
}
r.cmd.Stdout = logFile
r.cmd.Stderr = logFile
if err := r.cmd.Start(); err != nil {
logFile.Close()
return fmt.Errorf("failed to start RQLite: %w", err)
}
logFile.Close()
return nil
}
// waitForReadyAndConnect waits for RQLite to be ready and establishes connection
func (r *RQLiteManager) waitForReadyAndConnect(ctx context.Context) error {
if err := r.waitForReady(ctx); err != nil {
if r.cmd != nil && r.cmd.Process != nil {
_ = r.cmd.Process.Kill()
}
return err
}
var conn *gorqlite.Connection
var err error
maxConnectAttempts := 10
connectBackoff := 500 * time.Millisecond
for attempt := 0; attempt < maxConnectAttempts; attempt++ {
conn, err = gorqlite.Open(fmt.Sprintf("http://localhost:%d", r.config.RQLitePort))
if err == nil {
r.connection = conn
break
}
if strings.Contains(err.Error(), "store is not open") {
time.Sleep(connectBackoff)
connectBackoff = time.Duration(float64(connectBackoff) * 1.5)
if connectBackoff > 5*time.Second {
connectBackoff = 5 * time.Second
}
continue
}
if r.cmd != nil && r.cmd.Process != nil {
_ = r.cmd.Process.Kill()
}
return fmt.Errorf("failed to connect to RQLite: %w", err)
}
if conn == nil {
return fmt.Errorf("failed to connect to RQLite after max attempts")
}
_ = r.validateNodeID()
return nil
}
// waitForReady waits for RQLite to be ready to accept connections
func (r *RQLiteManager) waitForReady(ctx context.Context) error {
url := fmt.Sprintf("http://localhost:%d/status", r.config.RQLitePort)
client := tlsutil.NewHTTPClient(2 * time.Second)
for i := 0; i < 180; i++ {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(1 * time.Second):
}
resp, err := client.Get(url)
if err == nil && resp.StatusCode == http.StatusOK {
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
var statusResp map[string]interface{}
if err := json.Unmarshal(body, &statusResp); err == nil {
if raft, ok := statusResp["raft"].(map[string]interface{}); ok {
state, _ := raft["state"].(string)
if state == "leader" || state == "follower" {
return nil
}
} else {
return nil // Backwards compatibility
}
}
}
}
return fmt.Errorf("RQLite did not become ready within timeout")
}
// waitForSQLAvailable waits until a simple query succeeds
func (r *RQLiteManager) waitForSQLAvailable(ctx context.Context) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
if r.connection == nil {
continue
}
_, err := r.connection.QueryOne("SELECT 1")
if err == nil {
return nil
}
}
}
}
// testJoinAddress tests if a join address is reachable
func (r *RQLiteManager) testJoinAddress(joinAddress string) error {
client := tlsutil.NewHTTPClient(5 * time.Second)
var statusURL string
if strings.HasPrefix(joinAddress, "http://") || strings.HasPrefix(joinAddress, "https://") {
statusURL = strings.TrimRight(joinAddress, "/") + "/status"
} else {
host := joinAddress
if idx := strings.Index(joinAddress, ":"); idx != -1 {
host = joinAddress[:idx]
}
statusURL = fmt.Sprintf("http://%s:%d/status", host, 5001)
}
resp, err := client.Get(statusURL)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("leader returned status %d", resp.StatusCode)
}
return nil
}
// waitForJoinTarget waits until the join target's HTTP status becomes reachable
func (r *RQLiteManager) waitForJoinTarget(ctx context.Context, joinAddress string, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
var lastErr error
for time.Now().Before(deadline) {
if err := r.testJoinAddress(joinAddress); err == nil {
return nil
} else {
lastErr = err
}
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(2 * time.Second):
}
}
return lastErr
}

File diff suppressed because it is too large Load Diff

58
pkg/rqlite/util.go Normal file
View File

@ -0,0 +1,58 @@
package rqlite
import (
"os"
"path/filepath"
"strings"
"time"
)
func (r *RQLiteManager) rqliteDataDirPath() (string, error) {
dataDir := os.ExpandEnv(r.dataDir)
if strings.HasPrefix(dataDir, "~") {
home, _ := os.UserHomeDir()
dataDir = filepath.Join(home, dataDir[1:])
}
return filepath.Join(dataDir, "rqlite"), nil
}
func (r *RQLiteManager) resolveMigrationsDir() (string, error) {
productionPath := "/home/debros/src/migrations"
if _, err := os.Stat(productionPath); err == nil {
return productionPath, nil
}
return "migrations", nil
}
func (r *RQLiteManager) prepareDataDir() (string, error) {
rqliteDataDir, err := r.rqliteDataDirPath()
if err != nil {
return "", err
}
if err := os.MkdirAll(rqliteDataDir, 0755); err != nil {
return "", err
}
return rqliteDataDir, nil
}
func (r *RQLiteManager) hasExistingState(rqliteDataDir string) bool {
entries, err := os.ReadDir(rqliteDataDir)
if err != nil {
return false
}
for _, e := range entries {
if e.Name() != "." && e.Name() != ".." {
return true
}
}
return false
}
func (r *RQLiteManager) exponentialBackoff(attempt int, baseDelay time.Duration, maxDelay time.Duration) time.Duration {
delay := baseDelay * time.Duration(1<<uint(attempt))
if delay > maxDelay {
delay = maxDelay
}
return delay
}

89
pkg/rqlite/util_test.go Normal file
View File

@ -0,0 +1,89 @@
package rqlite
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestExponentialBackoff(t *testing.T) {
r := &RQLiteManager{}
baseDelay := 100 * time.Millisecond
maxDelay := 1 * time.Second
tests := []struct {
attempt int
expected time.Duration
}{
{0, 100 * time.Millisecond},
{1, 200 * time.Millisecond},
{2, 400 * time.Millisecond},
{3, 800 * time.Millisecond},
{4, 1000 * time.Millisecond}, // Maxed out
{10, 1000 * time.Millisecond}, // Maxed out
}
for _, tt := range tests {
got := r.exponentialBackoff(tt.attempt, baseDelay, maxDelay)
if got != tt.expected {
t.Errorf("exponentialBackoff(%d) = %v; want %v", tt.attempt, got, tt.expected)
}
}
}
func TestRQLiteDataDirPath(t *testing.T) {
// Test with explicit path
r := &RQLiteManager{dataDir: "/tmp/data"}
got, _ := r.rqliteDataDirPath()
expected := filepath.Join("/tmp/data", "rqlite")
if got != expected {
t.Errorf("rqliteDataDirPath() = %s; want %s", got, expected)
}
// Test with environment variable expansion
os.Setenv("TEST_DATA_DIR", "/tmp/env-data")
defer os.Unsetenv("TEST_DATA_DIR")
r = &RQLiteManager{dataDir: "$TEST_DATA_DIR"}
got, _ = r.rqliteDataDirPath()
expected = filepath.Join("/tmp/env-data", "rqlite")
if got != expected {
t.Errorf("rqliteDataDirPath() with env = %s; want %s", got, expected)
}
// Test with home directory expansion
r = &RQLiteManager{dataDir: "~/data"}
got, _ = r.rqliteDataDirPath()
home, _ := os.UserHomeDir()
expected = filepath.Join(home, "data", "rqlite")
if got != expected {
t.Errorf("rqliteDataDirPath() with ~ = %s; want %s", got, expected)
}
}
func TestHasExistingState(t *testing.T) {
r := &RQLiteManager{}
// Create a temp directory for testing
tmpDir, err := os.MkdirTemp("", "rqlite-test-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Test empty directory
if r.hasExistingState(tmpDir) {
t.Errorf("hasExistingState() = true; want false for empty dir")
}
// Test directory with a file
testFile := filepath.Join(tmpDir, "test.txt")
if err := os.WriteFile(testFile, []byte("data"), 0644); err != nil {
t.Fatalf("failed to create test file: %v", err)
}
if !r.hasExistingState(tmpDir) {
t.Errorf("hasExistingState() = false; want true for non-empty dir")
}
}

187
pkg/serverless/config.go Normal file
View File

@ -0,0 +1,187 @@
package serverless
import (
"time"
)
// Config holds configuration for the serverless engine.
type Config struct {
// Memory limits
DefaultMemoryLimitMB int `yaml:"default_memory_limit_mb"`
MaxMemoryLimitMB int `yaml:"max_memory_limit_mb"`
// Execution limits
DefaultTimeoutSeconds int `yaml:"default_timeout_seconds"`
MaxTimeoutSeconds int `yaml:"max_timeout_seconds"`
// Retry configuration
DefaultRetryCount int `yaml:"default_retry_count"`
MaxRetryCount int `yaml:"max_retry_count"`
DefaultRetryDelaySeconds int `yaml:"default_retry_delay_seconds"`
// Rate limiting (global)
GlobalRateLimitPerMinute int `yaml:"global_rate_limit_per_minute"`
// Background job configuration
JobWorkers int `yaml:"job_workers"`
JobPollInterval time.Duration `yaml:"job_poll_interval"`
JobMaxQueueSize int `yaml:"job_max_queue_size"`
JobMaxPayloadSize int `yaml:"job_max_payload_size"` // bytes
// Scheduler configuration
CronPollInterval time.Duration `yaml:"cron_poll_interval"`
TimerPollInterval time.Duration `yaml:"timer_poll_interval"`
DBPollInterval time.Duration `yaml:"db_poll_interval"`
// WASM compilation cache
ModuleCacheSize int `yaml:"module_cache_size"` // Number of compiled modules to cache
EnablePrewarm bool `yaml:"enable_prewarm"` // Pre-compile frequently used functions
// Secrets encryption
SecretsEncryptionKey string `yaml:"secrets_encryption_key"` // AES-256 key (32 bytes, hex-encoded)
// Logging
LogInvocations bool `yaml:"log_invocations"` // Log all invocations to database
LogRetention int `yaml:"log_retention"` // Days to retain logs
}
// DefaultConfig returns a configuration with sensible defaults.
func DefaultConfig() *Config {
return &Config{
// Memory limits
DefaultMemoryLimitMB: 64,
MaxMemoryLimitMB: 256,
// Execution limits
DefaultTimeoutSeconds: 30,
MaxTimeoutSeconds: 300, // 5 minutes max
// Retry configuration
DefaultRetryCount: 0,
MaxRetryCount: 5,
DefaultRetryDelaySeconds: 5,
// Rate limiting
GlobalRateLimitPerMinute: 10000, // 10k requests/minute globally
// Background jobs
JobWorkers: 4,
JobPollInterval: time.Second,
JobMaxQueueSize: 10000,
JobMaxPayloadSize: 1024 * 1024, // 1MB
// Scheduler
CronPollInterval: time.Minute,
TimerPollInterval: time.Second,
DBPollInterval: time.Second * 5,
// WASM cache
ModuleCacheSize: 100,
EnablePrewarm: true,
// Logging
LogInvocations: true,
LogRetention: 7, // 7 days
}
}
// Validate checks the configuration for errors.
func (c *Config) Validate() []error {
var errs []error
if c.DefaultMemoryLimitMB <= 0 {
errs = append(errs, &ConfigError{Field: "DefaultMemoryLimitMB", Message: "must be positive"})
}
if c.MaxMemoryLimitMB < c.DefaultMemoryLimitMB {
errs = append(errs, &ConfigError{Field: "MaxMemoryLimitMB", Message: "must be >= DefaultMemoryLimitMB"})
}
if c.DefaultTimeoutSeconds <= 0 {
errs = append(errs, &ConfigError{Field: "DefaultTimeoutSeconds", Message: "must be positive"})
}
if c.MaxTimeoutSeconds < c.DefaultTimeoutSeconds {
errs = append(errs, &ConfigError{Field: "MaxTimeoutSeconds", Message: "must be >= DefaultTimeoutSeconds"})
}
if c.GlobalRateLimitPerMinute <= 0 {
errs = append(errs, &ConfigError{Field: "GlobalRateLimitPerMinute", Message: "must be positive"})
}
if c.JobWorkers <= 0 {
errs = append(errs, &ConfigError{Field: "JobWorkers", Message: "must be positive"})
}
if c.ModuleCacheSize <= 0 {
errs = append(errs, &ConfigError{Field: "ModuleCacheSize", Message: "must be positive"})
}
return errs
}
// ApplyDefaults fills in zero values with defaults.
func (c *Config) ApplyDefaults() {
defaults := DefaultConfig()
if c.DefaultMemoryLimitMB == 0 {
c.DefaultMemoryLimitMB = defaults.DefaultMemoryLimitMB
}
if c.MaxMemoryLimitMB == 0 {
c.MaxMemoryLimitMB = defaults.MaxMemoryLimitMB
}
if c.DefaultTimeoutSeconds == 0 {
c.DefaultTimeoutSeconds = defaults.DefaultTimeoutSeconds
}
if c.MaxTimeoutSeconds == 0 {
c.MaxTimeoutSeconds = defaults.MaxTimeoutSeconds
}
if c.GlobalRateLimitPerMinute == 0 {
c.GlobalRateLimitPerMinute = defaults.GlobalRateLimitPerMinute
}
if c.JobWorkers == 0 {
c.JobWorkers = defaults.JobWorkers
}
if c.JobPollInterval == 0 {
c.JobPollInterval = defaults.JobPollInterval
}
if c.JobMaxQueueSize == 0 {
c.JobMaxQueueSize = defaults.JobMaxQueueSize
}
if c.JobMaxPayloadSize == 0 {
c.JobMaxPayloadSize = defaults.JobMaxPayloadSize
}
if c.CronPollInterval == 0 {
c.CronPollInterval = defaults.CronPollInterval
}
if c.TimerPollInterval == 0 {
c.TimerPollInterval = defaults.TimerPollInterval
}
if c.DBPollInterval == 0 {
c.DBPollInterval = defaults.DBPollInterval
}
if c.ModuleCacheSize == 0 {
c.ModuleCacheSize = defaults.ModuleCacheSize
}
if c.LogRetention == 0 {
c.LogRetention = defaults.LogRetention
}
}
// WithMemoryLimit returns a copy with the memory limit set.
func (c *Config) WithMemoryLimit(defaultMB, maxMB int) *Config {
copy := *c
copy.DefaultMemoryLimitMB = defaultMB
copy.MaxMemoryLimitMB = maxMB
return &copy
}
// WithTimeout returns a copy with the timeout set.
func (c *Config) WithTimeout(defaultSec, maxSec int) *Config {
copy := *c
copy.DefaultTimeoutSeconds = defaultSec
copy.MaxTimeoutSeconds = maxSec
return &copy
}
// WithRateLimit returns a copy with the rate limit set.
func (c *Config) WithRateLimit(perMinute int) *Config {
copy := *c
copy.GlobalRateLimitPerMinute = perMinute
return &copy
}

736
pkg/serverless/engine.go Normal file
View File

@ -0,0 +1,736 @@
package serverless
import (
"bytes"
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/google/uuid"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"go.uber.org/zap"
)
// contextAwareHostServices is an internal interface for services that need to know about
// the current invocation context.
type contextAwareHostServices interface {
SetInvocationContext(invCtx *InvocationContext)
ClearContext()
}
// Ensure Engine implements FunctionExecutor interface.
var _ FunctionExecutor = (*Engine)(nil)
// Engine is the core WASM execution engine using wazero.
// It manages compiled module caching and function execution.
type Engine struct {
runtime wazero.Runtime
config *Config
registry FunctionRegistry
hostServices HostServices
logger *zap.Logger
// Module cache: wasmCID -> compiled module
moduleCache map[string]wazero.CompiledModule
moduleCacheMu sync.RWMutex
// Invocation logger for metrics/debugging
invocationLogger InvocationLogger
// Rate limiter
rateLimiter RateLimiter
}
// InvocationLogger logs function invocations (optional).
type InvocationLogger interface {
Log(ctx context.Context, inv *InvocationRecord) error
}
// InvocationRecord represents a logged invocation.
type InvocationRecord struct {
ID string `json:"id"`
FunctionID string `json:"function_id"`
RequestID string `json:"request_id"`
TriggerType TriggerType `json:"trigger_type"`
CallerWallet string `json:"caller_wallet,omitempty"`
InputSize int `json:"input_size"`
OutputSize int `json:"output_size"`
StartedAt time.Time `json:"started_at"`
CompletedAt time.Time `json:"completed_at"`
DurationMS int64 `json:"duration_ms"`
Status InvocationStatus `json:"status"`
ErrorMessage string `json:"error_message,omitempty"`
MemoryUsedMB float64 `json:"memory_used_mb"`
Logs []LogEntry `json:"logs,omitempty"`
}
// RateLimiter checks if a request should be rate limited.
type RateLimiter interface {
Allow(ctx context.Context, key string) (bool, error)
}
// EngineOption configures the Engine.
type EngineOption func(*Engine)
// WithInvocationLogger sets the invocation logger.
func WithInvocationLogger(logger InvocationLogger) EngineOption {
return func(e *Engine) {
e.invocationLogger = logger
}
}
// WithRateLimiter sets the rate limiter.
func WithRateLimiter(limiter RateLimiter) EngineOption {
return func(e *Engine) {
e.rateLimiter = limiter
}
}
// NewEngine creates a new WASM execution engine.
func NewEngine(cfg *Config, registry FunctionRegistry, hostServices HostServices, logger *zap.Logger, opts ...EngineOption) (*Engine, error) {
if cfg == nil {
cfg = DefaultConfig()
}
cfg.ApplyDefaults()
// Create wazero runtime with compilation cache
runtimeConfig := wazero.NewRuntimeConfig().
WithCloseOnContextDone(true)
runtime := wazero.NewRuntimeWithConfig(context.Background(), runtimeConfig)
// Instantiate WASI - required for WASM modules compiled with TinyGo targeting WASI
wasi_snapshot_preview1.MustInstantiate(context.Background(), runtime)
engine := &Engine{
runtime: runtime,
config: cfg,
registry: registry,
hostServices: hostServices,
logger: logger,
moduleCache: make(map[string]wazero.CompiledModule),
}
// Apply options
for _, opt := range opts {
opt(engine)
}
// Register host functions
if err := engine.registerHostModule(context.Background()); err != nil {
return nil, fmt.Errorf("failed to register host module: %w", err)
}
return engine, nil
}
// Execute runs a function with the given input and returns the output.
func (e *Engine) Execute(ctx context.Context, fn *Function, input []byte, invCtx *InvocationContext) ([]byte, error) {
if fn == nil {
return nil, &ValidationError{Field: "function", Message: "cannot be nil"}
}
if invCtx == nil {
invCtx = &InvocationContext{
RequestID: uuid.New().String(),
FunctionID: fn.ID,
FunctionName: fn.Name,
Namespace: fn.Namespace,
TriggerType: TriggerTypeHTTP,
}
}
startTime := time.Now()
// Check rate limit
if e.rateLimiter != nil {
allowed, err := e.rateLimiter.Allow(ctx, "global")
if err != nil {
e.logger.Warn("Rate limiter error", zap.Error(err))
} else if !allowed {
return nil, ErrRateLimited
}
}
// Create timeout context
timeout := time.Duration(fn.TimeoutSeconds) * time.Second
if timeout > time.Duration(e.config.MaxTimeoutSeconds)*time.Second {
timeout = time.Duration(e.config.MaxTimeoutSeconds) * time.Second
}
execCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// Get compiled module (from cache or compile)
module, err := e.getOrCompileModule(execCtx, fn.WASMCID)
if err != nil {
e.logInvocation(ctx, fn, invCtx, startTime, 0, InvocationStatusError, err)
return nil, &ExecutionError{FunctionName: fn.Name, RequestID: invCtx.RequestID, Cause: err}
}
// Execute the module
output, err := e.executeModule(execCtx, module, fn, input, invCtx)
if err != nil {
status := InvocationStatusError
if execCtx.Err() == context.DeadlineExceeded {
status = InvocationStatusTimeout
err = ErrTimeout
}
e.logInvocation(ctx, fn, invCtx, startTime, len(output), status, err)
return nil, &ExecutionError{FunctionName: fn.Name, RequestID: invCtx.RequestID, Cause: err}
}
e.logInvocation(ctx, fn, invCtx, startTime, len(output), InvocationStatusSuccess, nil)
return output, nil
}
// Precompile compiles a WASM module and caches it for faster execution.
func (e *Engine) Precompile(ctx context.Context, wasmCID string, wasmBytes []byte) error {
if wasmCID == "" {
return &ValidationError{Field: "wasmCID", Message: "cannot be empty"}
}
if len(wasmBytes) == 0 {
return &ValidationError{Field: "wasmBytes", Message: "cannot be empty"}
}
// Check if already cached
e.moduleCacheMu.RLock()
_, exists := e.moduleCache[wasmCID]
e.moduleCacheMu.RUnlock()
if exists {
return nil
}
// Compile the module
compiled, err := e.runtime.CompileModule(ctx, wasmBytes)
if err != nil {
return &DeployError{FunctionName: wasmCID, Cause: fmt.Errorf("failed to compile WASM: %w", err)}
}
// Cache the compiled module
e.moduleCacheMu.Lock()
defer e.moduleCacheMu.Unlock()
// Evict oldest if cache is full
if len(e.moduleCache) >= e.config.ModuleCacheSize {
e.evictOldestModule()
}
e.moduleCache[wasmCID] = compiled
e.logger.Debug("Module precompiled and cached",
zap.String("wasm_cid", wasmCID),
zap.Int("cache_size", len(e.moduleCache)),
)
return nil
}
// Invalidate removes a compiled module from the cache.
func (e *Engine) Invalidate(wasmCID string) {
e.moduleCacheMu.Lock()
defer e.moduleCacheMu.Unlock()
if module, exists := e.moduleCache[wasmCID]; exists {
_ = module.Close(context.Background())
delete(e.moduleCache, wasmCID)
e.logger.Debug("Module invalidated from cache", zap.String("wasm_cid", wasmCID))
}
}
// Close shuts down the engine and releases resources.
func (e *Engine) Close(ctx context.Context) error {
e.moduleCacheMu.Lock()
defer e.moduleCacheMu.Unlock()
// Close all cached modules
for cid, module := range e.moduleCache {
if err := module.Close(ctx); err != nil {
e.logger.Warn("Failed to close cached module", zap.String("cid", cid), zap.Error(err))
}
}
e.moduleCache = make(map[string]wazero.CompiledModule)
// Close the runtime
return e.runtime.Close(ctx)
}
// GetCacheStats returns cache statistics.
func (e *Engine) GetCacheStats() (size int, capacity int) {
e.moduleCacheMu.RLock()
defer e.moduleCacheMu.RUnlock()
return len(e.moduleCache), e.config.ModuleCacheSize
}
// -----------------------------------------------------------------------------
// Private methods
// -----------------------------------------------------------------------------
// getOrCompileModule retrieves a compiled module from cache or compiles it.
func (e *Engine) getOrCompileModule(ctx context.Context, wasmCID string) (wazero.CompiledModule, error) {
// Check cache first
e.moduleCacheMu.RLock()
if module, exists := e.moduleCache[wasmCID]; exists {
e.moduleCacheMu.RUnlock()
return module, nil
}
e.moduleCacheMu.RUnlock()
// Fetch WASM bytes from registry
wasmBytes, err := e.registry.GetWASMBytes(ctx, wasmCID)
if err != nil {
return nil, fmt.Errorf("failed to fetch WASM: %w", err)
}
// Compile the module
compiled, err := e.runtime.CompileModule(ctx, wasmBytes)
if err != nil {
return nil, ErrCompilationFailed
}
// Cache the compiled module
e.moduleCacheMu.Lock()
defer e.moduleCacheMu.Unlock()
// Double-check (another goroutine might have added it)
if existingModule, exists := e.moduleCache[wasmCID]; exists {
_ = compiled.Close(ctx) // Discard our compilation
return existingModule, nil
}
// Evict if cache is full
if len(e.moduleCache) >= e.config.ModuleCacheSize {
e.evictOldestModule()
}
e.moduleCache[wasmCID] = compiled
e.logger.Debug("Module compiled and cached",
zap.String("wasm_cid", wasmCID),
zap.Int("cache_size", len(e.moduleCache)),
)
return compiled, nil
}
// executeModule instantiates and runs a WASM module.
func (e *Engine) executeModule(ctx context.Context, compiled wazero.CompiledModule, fn *Function, input []byte, invCtx *InvocationContext) ([]byte, error) {
// Set invocation context for host functions if the service supports it
if hf, ok := e.hostServices.(contextAwareHostServices); ok {
hf.SetInvocationContext(invCtx)
defer hf.ClearContext()
}
// Create buffers for stdin/stdout (WASI uses these for I/O)
stdin := bytes.NewReader(input)
stdout := new(bytes.Buffer)
stderr := new(bytes.Buffer)
// Create module configuration with WASI stdio
moduleConfig := wazero.NewModuleConfig().
WithName(fn.Name).
WithStdin(stdin).
WithStdout(stdout).
WithStderr(stderr).
WithArgs(fn.Name) // argv[0] is the program name
// Instantiate and run the module (WASI _start will be called automatically)
instance, err := e.runtime.InstantiateModule(ctx, compiled, moduleConfig)
if err != nil {
// Check if stderr has any output
if stderr.Len() > 0 {
e.logger.Warn("WASM stderr output", zap.String("stderr", stderr.String()))
}
return nil, fmt.Errorf("failed to instantiate module: %w", err)
}
defer instance.Close(ctx)
// For WASI modules, the output is already in stdout buffer
// The _start function was called during instantiation
output := stdout.Bytes()
// Log stderr if any
if stderr.Len() > 0 {
e.logger.Debug("WASM stderr", zap.String("stderr", stderr.String()))
}
return output, nil
}
// callHandleFunction calls the main 'handle' export in the WASM module.
func (e *Engine) callHandleFunction(ctx context.Context, instance api.Module, input []byte, invCtx *InvocationContext) ([]byte, error) {
// Get the 'handle' function export
handleFn := instance.ExportedFunction("handle")
if handleFn == nil {
return nil, fmt.Errorf("WASM module does not export 'handle' function")
}
// Get memory export
memory := instance.ExportedMemory("memory")
if memory == nil {
return nil, fmt.Errorf("WASM module does not export 'memory'")
}
// Get malloc/free exports for memory management
mallocFn := instance.ExportedFunction("malloc")
freeFn := instance.ExportedFunction("free")
var inputPtr uint32
var inputLen = uint32(len(input))
if mallocFn != nil && len(input) > 0 {
// Allocate memory for input
results, err := mallocFn.Call(ctx, uint64(inputLen))
if err != nil {
return nil, fmt.Errorf("malloc failed: %w", err)
}
inputPtr = uint32(results[0])
// Write input to memory
if !memory.Write(inputPtr, input) {
return nil, fmt.Errorf("failed to write input to WASM memory")
}
// Defer free if available
if freeFn != nil {
defer func() {
_, _ = freeFn.Call(ctx, uint64(inputPtr))
}()
}
}
// Call handle(input_ptr, input_len)
// Returns: output_ptr (packed with length in upper 32 bits)
results, err := handleFn.Call(ctx, uint64(inputPtr), uint64(inputLen))
if err != nil {
return nil, fmt.Errorf("handle function error: %w", err)
}
if len(results) == 0 {
return nil, nil // No output
}
// Parse result - assume format: lower 32 bits = ptr, upper 32 bits = len
result := results[0]
outputPtr := uint32(result & 0xFFFFFFFF)
outputLen := uint32(result >> 32)
if outputLen == 0 {
return nil, nil
}
// Read output from memory
output, ok := memory.Read(outputPtr, outputLen)
if !ok {
return nil, fmt.Errorf("failed to read output from WASM memory")
}
// Make a copy (memory will be freed)
outputCopy := make([]byte, len(output))
copy(outputCopy, output)
return outputCopy, nil
}
// evictOldestModule removes the oldest module from cache.
// Must be called with moduleCacheMu held.
func (e *Engine) evictOldestModule() {
// Simple LRU: just remove the first one we find
// In production, you'd want proper LRU tracking
for cid, module := range e.moduleCache {
_ = module.Close(context.Background())
delete(e.moduleCache, cid)
e.logger.Debug("Evicted module from cache", zap.String("wasm_cid", cid))
break
}
}
// logInvocation logs an invocation record.
func (e *Engine) logInvocation(ctx context.Context, fn *Function, invCtx *InvocationContext, startTime time.Time, outputSize int, status InvocationStatus, err error) {
if e.invocationLogger == nil || !e.config.LogInvocations {
return
}
completedAt := time.Now()
record := &InvocationRecord{
ID: uuid.New().String(),
FunctionID: fn.ID,
RequestID: invCtx.RequestID,
TriggerType: invCtx.TriggerType,
CallerWallet: invCtx.CallerWallet,
OutputSize: outputSize,
StartedAt: startTime,
CompletedAt: completedAt,
DurationMS: completedAt.Sub(startTime).Milliseconds(),
Status: status,
}
if err != nil {
record.ErrorMessage = err.Error()
}
// Collect logs from host services if supported
if hf, ok := e.hostServices.(interface{ GetLogs() []LogEntry }); ok {
record.Logs = hf.GetLogs()
}
if logErr := e.invocationLogger.Log(ctx, record); logErr != nil {
e.logger.Warn("Failed to log invocation", zap.Error(logErr))
}
}
// registerHostModule registers the Orama host functions with the wazero runtime.
func (e *Engine) registerHostModule(ctx context.Context) error {
// Register under both "env" and "host" to support different import styles
// The user requested "env" in instructions but "host" in expected result.
for _, moduleName := range []string{"env", "host"} {
_, err := e.runtime.NewHostModuleBuilder(moduleName).
NewFunctionBuilder().WithFunc(e.hGetCallerWallet).Export("get_caller_wallet").
NewFunctionBuilder().WithFunc(e.hGetRequestID).Export("get_request_id").
NewFunctionBuilder().WithFunc(e.hGetEnv).Export("get_env").
NewFunctionBuilder().WithFunc(e.hGetSecret).Export("get_secret").
NewFunctionBuilder().WithFunc(e.hDBQuery).Export("db_query").
NewFunctionBuilder().WithFunc(e.hDBExecute).Export("db_execute").
NewFunctionBuilder().WithFunc(e.hCacheGet).Export("cache_get").
NewFunctionBuilder().WithFunc(e.hCacheSet).Export("cache_set").
NewFunctionBuilder().WithFunc(e.hCacheIncr).Export("cache_incr").
NewFunctionBuilder().WithFunc(e.hCacheIncrBy).Export("cache_incr_by").
NewFunctionBuilder().WithFunc(e.hHTTPFetch).Export("http_fetch").
NewFunctionBuilder().WithFunc(e.hPubSubPublish).Export("pubsub_publish").
NewFunctionBuilder().WithFunc(e.hLogInfo).Export("log_info").
NewFunctionBuilder().WithFunc(e.hLogError).Export("log_error").
Instantiate(ctx)
if err != nil {
return err
}
}
return nil
}
func (e *Engine) hGetCallerWallet(ctx context.Context, mod api.Module) uint64 {
wallet := e.hostServices.GetCallerWallet(ctx)
return e.writeToGuest(ctx, mod, []byte(wallet))
}
func (e *Engine) hGetRequestID(ctx context.Context, mod api.Module) uint64 {
rid := e.hostServices.GetRequestID(ctx)
return e.writeToGuest(ctx, mod, []byte(rid))
}
func (e *Engine) hGetEnv(ctx context.Context, mod api.Module, keyPtr, keyLen uint32) uint64 {
key, ok := mod.Memory().Read(keyPtr, keyLen)
if !ok {
return 0
}
val, _ := e.hostServices.GetEnv(ctx, string(key))
return e.writeToGuest(ctx, mod, []byte(val))
}
func (e *Engine) hGetSecret(ctx context.Context, mod api.Module, namePtr, nameLen uint32) uint64 {
name, ok := mod.Memory().Read(namePtr, nameLen)
if !ok {
return 0
}
val, err := e.hostServices.GetSecret(ctx, string(name))
if err != nil {
return 0
}
return e.writeToGuest(ctx, mod, []byte(val))
}
func (e *Engine) hDBQuery(ctx context.Context, mod api.Module, queryPtr, queryLen, argsPtr, argsLen uint32) uint64 {
query, ok := mod.Memory().Read(queryPtr, queryLen)
if !ok {
return 0
}
var args []interface{}
if argsLen > 0 {
argsData, ok := mod.Memory().Read(argsPtr, argsLen)
if !ok {
return 0
}
if err := json.Unmarshal(argsData, &args); err != nil {
e.logger.Error("failed to unmarshal db_query arguments", zap.Error(err))
return 0
}
}
results, err := e.hostServices.DBQuery(ctx, string(query), args)
if err != nil {
e.logger.Error("host function db_query failed", zap.Error(err), zap.String("query", string(query)))
return 0
}
return e.writeToGuest(ctx, mod, results)
}
func (e *Engine) hDBExecute(ctx context.Context, mod api.Module, queryPtr, queryLen, argsPtr, argsLen uint32) uint32 {
query, ok := mod.Memory().Read(queryPtr, queryLen)
if !ok {
return 0
}
var args []interface{}
if argsLen > 0 {
argsData, ok := mod.Memory().Read(argsPtr, argsLen)
if !ok {
return 0
}
if err := json.Unmarshal(argsData, &args); err != nil {
e.logger.Error("failed to unmarshal db_execute arguments", zap.Error(err))
return 0
}
}
affected, err := e.hostServices.DBExecute(ctx, string(query), args)
if err != nil {
e.logger.Error("host function db_execute failed", zap.Error(err), zap.String("query", string(query)))
return 0
}
return uint32(affected)
}
func (e *Engine) hCacheGet(ctx context.Context, mod api.Module, keyPtr, keyLen uint32) uint64 {
key, ok := mod.Memory().Read(keyPtr, keyLen)
if !ok {
return 0
}
val, err := e.hostServices.CacheGet(ctx, string(key))
if err != nil {
return 0
}
return e.writeToGuest(ctx, mod, val)
}
func (e *Engine) hCacheSet(ctx context.Context, mod api.Module, keyPtr, keyLen, valPtr, valLen uint32, ttl int64) {
key, ok := mod.Memory().Read(keyPtr, keyLen)
if !ok {
return
}
val, ok := mod.Memory().Read(valPtr, valLen)
if !ok {
return
}
_ = e.hostServices.CacheSet(ctx, string(key), val, ttl)
}
func (e *Engine) hCacheIncr(ctx context.Context, mod api.Module, keyPtr, keyLen uint32) int64 {
key, ok := mod.Memory().Read(keyPtr, keyLen)
if !ok {
return 0
}
val, err := e.hostServices.CacheIncr(ctx, string(key))
if err != nil {
e.logger.Error("host function cache_incr failed", zap.Error(err), zap.String("key", string(key)))
return 0
}
return val
}
func (e *Engine) hCacheIncrBy(ctx context.Context, mod api.Module, keyPtr, keyLen uint32, delta int64) int64 {
key, ok := mod.Memory().Read(keyPtr, keyLen)
if !ok {
return 0
}
val, err := e.hostServices.CacheIncrBy(ctx, string(key), delta)
if err != nil {
e.logger.Error("host function cache_incr_by failed", zap.Error(err), zap.String("key", string(key)), zap.Int64("delta", delta))
return 0
}
return val
}
func (e *Engine) hHTTPFetch(ctx context.Context, mod api.Module, methodPtr, methodLen, urlPtr, urlLen, headersPtr, headersLen, bodyPtr, bodyLen uint32) uint64 {
method, ok := mod.Memory().Read(methodPtr, methodLen)
if !ok {
return 0
}
u, ok := mod.Memory().Read(urlPtr, urlLen)
if !ok {
return 0
}
var headers map[string]string
if headersLen > 0 {
headersData, ok := mod.Memory().Read(headersPtr, headersLen)
if !ok {
return 0
}
if err := json.Unmarshal(headersData, &headers); err != nil {
e.logger.Error("failed to unmarshal http_fetch headers", zap.Error(err))
return 0
}
}
body, ok := mod.Memory().Read(bodyPtr, bodyLen)
if !ok {
return 0
}
resp, err := e.hostServices.HTTPFetch(ctx, string(method), string(u), headers, body)
if err != nil {
e.logger.Error("host function http_fetch failed", zap.Error(err), zap.String("url", string(u)))
return 0
}
return e.writeToGuest(ctx, mod, resp)
}
func (e *Engine) hPubSubPublish(ctx context.Context, mod api.Module, topicPtr, topicLen, dataPtr, dataLen uint32) uint32 {
topic, ok := mod.Memory().Read(topicPtr, topicLen)
if !ok {
return 0
}
data, ok := mod.Memory().Read(dataPtr, dataLen)
if !ok {
return 0
}
err := e.hostServices.PubSubPublish(ctx, string(topic), data)
if err != nil {
e.logger.Error("host function pubsub_publish failed", zap.Error(err), zap.String("topic", string(topic)))
return 0
}
return 1 // Success
}
func (e *Engine) hLogInfo(ctx context.Context, mod api.Module, ptr, size uint32) {
msg, ok := mod.Memory().Read(ptr, size)
if ok {
e.hostServices.LogInfo(ctx, string(msg))
}
}
func (e *Engine) hLogError(ctx context.Context, mod api.Module, ptr, size uint32) {
msg, ok := mod.Memory().Read(ptr, size)
if ok {
e.hostServices.LogError(ctx, string(msg))
}
}
func (e *Engine) writeToGuest(ctx context.Context, mod api.Module, data []byte) uint64 {
if len(data) == 0 {
return 0
}
// Try to find a non-conflicting allocator first, fallback to malloc
malloc := mod.ExportedFunction("orama_alloc")
if malloc == nil {
malloc = mod.ExportedFunction("malloc")
}
if malloc == nil {
e.logger.Warn("WASM module missing malloc/orama_alloc export, cannot return string/bytes to guest")
return 0
}
results, err := malloc.Call(ctx, uint64(len(data)))
if err != nil {
e.logger.Error("failed to call malloc in WASM module", zap.Error(err))
return 0
}
ptr := uint32(results[0])
if !mod.Memory().Write(ptr, data) {
e.logger.Error("failed to write to WASM memory")
return 0
}
return (uint64(ptr) << 32) | uint64(len(data))
}

View File

@ -0,0 +1,202 @@
package serverless
import (
"context"
"os"
"testing"
"go.uber.org/zap"
)
func TestEngine_Execute(t *testing.T) {
logger := zap.NewNop()
registry := NewMockRegistry()
hostServices := NewMockHostServices()
cfg := DefaultConfig()
cfg.ModuleCacheSize = 2
engine, err := NewEngine(cfg, registry, hostServices, logger)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close(context.Background())
// Use a minimal valid WASM module that exports _start (WASI)
// This is just 'nop' in WASM
wasmBytes := []byte{
0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
0x01, 0x04, 0x01, 0x60, 0x00, 0x00,
0x03, 0x02, 0x01, 0x00,
0x07, 0x0a, 0x01, 0x06, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x00, 0x00,
0x0a, 0x04, 0x01, 0x02, 0x00, 0x0b,
}
fnDef := &FunctionDefinition{
Name: "test-func",
Namespace: "test-ns",
MemoryLimitMB: 64,
TimeoutSeconds: 5,
}
_, err = registry.Register(context.Background(), fnDef, wasmBytes)
if err != nil {
t.Fatalf("failed to register function: %v", err)
}
fn, err := registry.Get(context.Background(), "test-ns", "test-func", 0)
if err != nil {
t.Fatalf("failed to get function: %v", err)
}
// Execute function
ctx := context.Background()
output, err := engine.Execute(ctx, fn, []byte("input"), nil)
if err != nil {
t.Errorf("failed to execute function: %v", err)
}
// Our minimal WASM doesn't write to stdout, so output should be empty
if len(output) != 0 {
t.Errorf("expected empty output, got %d bytes", len(output))
}
// Test cache stats
size, capacity := engine.GetCacheStats()
if size != 1 {
t.Errorf("expected cache size 1, got %d", size)
}
if capacity != 2 {
t.Errorf("expected cache capacity 2, got %d", capacity)
}
// Test Invalidate
engine.Invalidate(fn.WASMCID)
size, _ = engine.GetCacheStats()
if size != 0 {
t.Errorf("expected cache size 0 after invalidation, got %d", size)
}
}
func TestEngine_Precompile(t *testing.T) {
logger := zap.NewNop()
registry := NewMockRegistry()
hostServices := NewMockHostServices()
engine, _ := NewEngine(nil, registry, hostServices, logger)
defer engine.Close(context.Background())
wasmBytes := []byte{
0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
0x01, 0x04, 0x01, 0x60, 0x00, 0x00,
0x03, 0x02, 0x01, 0x00,
0x07, 0x0a, 0x01, 0x06, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x00, 0x00,
0x0a, 0x04, 0x01, 0x02, 0x00, 0x0b,
}
err := engine.Precompile(context.Background(), "test-cid", wasmBytes)
if err != nil {
t.Fatalf("failed to precompile: %v", err)
}
size, _ := engine.GetCacheStats()
if size != 1 {
t.Errorf("expected cache size 1, got %d", size)
}
}
func TestEngine_Timeout(t *testing.T) {
logger := zap.NewNop()
registry := NewMockRegistry()
hostServices := NewMockHostServices()
engine, _ := NewEngine(nil, registry, hostServices, logger)
defer engine.Close(context.Background())
wasmBytes := []byte{
0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
0x01, 0x04, 0x01, 0x60, 0x00, 0x00,
0x03, 0x02, 0x01, 0x00,
0x07, 0x0a, 0x01, 0x06, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x00, 0x00,
0x0a, 0x04, 0x01, 0x02, 0x00, 0x0b,
}
fn, _ := registry.Get(context.Background(), "test", "timeout", 0)
if fn == nil {
_, _ = registry.Register(context.Background(), &FunctionDefinition{Name: "timeout", Namespace: "test"}, wasmBytes)
fn, _ = registry.Get(context.Background(), "test", "timeout", 0)
}
fn.TimeoutSeconds = 1
// Test with already canceled context
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := engine.Execute(ctx, fn, nil, nil)
if err == nil {
t.Error("expected error for canceled context, got nil")
}
}
func TestEngine_MemoryLimit(t *testing.T) {
logger := zap.NewNop()
registry := NewMockRegistry()
hostServices := NewMockHostServices()
engine, _ := NewEngine(nil, registry, hostServices, logger)
defer engine.Close(context.Background())
wasmBytes := []byte{
0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
0x01, 0x04, 0x01, 0x60, 0x00, 0x00,
0x03, 0x02, 0x01, 0x00,
0x07, 0x0a, 0x01, 0x06, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x00, 0x00,
0x0a, 0x04, 0x01, 0x02, 0x00, 0x0b,
}
_, _ = registry.Register(context.Background(), &FunctionDefinition{Name: "memory", Namespace: "test", MemoryLimitMB: 1, TimeoutSeconds: 5}, wasmBytes)
fn, _ := registry.Get(context.Background(), "test", "memory", 0)
// This should pass because the minimal WASM doesn't use much memory
_, err := engine.Execute(context.Background(), fn, nil, nil)
if err != nil {
t.Errorf("expected success for minimal WASM within memory limit, got error: %v", err)
}
}
func TestEngine_RealWASM(t *testing.T) {
wasmPath := "../../examples/functions/bin/hello.wasm"
if _, err := os.Stat(wasmPath); os.IsNotExist(err) {
t.Skip("hello.wasm not found")
}
wasmBytes, err := os.ReadFile(wasmPath)
if err != nil {
t.Fatalf("failed to read hello.wasm: %v", err)
}
logger := zap.NewNop()
registry := NewMockRegistry()
hostServices := NewMockHostServices()
engine, _ := NewEngine(nil, registry, hostServices, logger)
defer engine.Close(context.Background())
fnDef := &FunctionDefinition{
Name: "hello",
Namespace: "examples",
TimeoutSeconds: 10,
}
_, _ = registry.Register(context.Background(), fnDef, wasmBytes)
fn, _ := registry.Get(context.Background(), "examples", "hello", 0)
output, err := engine.Execute(context.Background(), fn, []byte(`{"name": "Tester"}`), nil)
if err != nil {
t.Fatalf("execution failed: %v", err)
}
expected := "Hello, Tester!"
if !contains(string(output), expected) {
t.Errorf("output %q does not contain %q", string(output), expected)
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s[:len(substr)] == substr || contains(s[1:], substr))
}

216
pkg/serverless/errors.go Normal file
View File

@ -0,0 +1,216 @@
package serverless
import (
"errors"
"fmt"
)
// Sentinel errors for common conditions.
var (
// ErrFunctionNotFound is returned when a function does not exist.
ErrFunctionNotFound = errors.New("function not found")
// ErrFunctionExists is returned when attempting to create a function that already exists.
ErrFunctionExists = errors.New("function already exists")
// ErrVersionNotFound is returned when a specific function version does not exist.
ErrVersionNotFound = errors.New("function version not found")
// ErrSecretNotFound is returned when a secret does not exist.
ErrSecretNotFound = errors.New("secret not found")
// ErrJobNotFound is returned when a job does not exist.
ErrJobNotFound = errors.New("job not found")
// ErrTriggerNotFound is returned when a trigger does not exist.
ErrTriggerNotFound = errors.New("trigger not found")
// ErrTimerNotFound is returned when a timer does not exist.
ErrTimerNotFound = errors.New("timer not found")
// ErrUnauthorized is returned when the caller is not authorized.
ErrUnauthorized = errors.New("unauthorized")
// ErrRateLimited is returned when the rate limit is exceeded.
ErrRateLimited = errors.New("rate limit exceeded")
// ErrInvalidWASM is returned when the WASM module is invalid.
ErrInvalidWASM = errors.New("invalid WASM module")
// ErrCompilationFailed is returned when WASM compilation fails.
ErrCompilationFailed = errors.New("WASM compilation failed")
// ErrExecutionFailed is returned when function execution fails.
ErrExecutionFailed = errors.New("function execution failed")
// ErrTimeout is returned when function execution times out.
ErrTimeout = errors.New("function execution timeout")
// ErrMemoryExceeded is returned when the function exceeds memory limits.
ErrMemoryExceeded = errors.New("memory limit exceeded")
// ErrInvalidInput is returned when function input is invalid.
ErrInvalidInput = errors.New("invalid input")
// ErrWSNotAvailable is returned when WebSocket operations are used outside WS context.
ErrWSNotAvailable = errors.New("websocket operations not available in this context")
// ErrWSClientNotFound is returned when a WebSocket client is not connected.
ErrWSClientNotFound = errors.New("websocket client not found")
// ErrInvalidCronExpression is returned when a cron expression is invalid.
ErrInvalidCronExpression = errors.New("invalid cron expression")
// ErrPayloadTooLarge is returned when a job payload exceeds the maximum size.
ErrPayloadTooLarge = errors.New("payload too large")
// ErrQueueFull is returned when the job queue is full.
ErrQueueFull = errors.New("job queue is full")
// ErrJobCancelled is returned when a job is cancelled.
ErrJobCancelled = errors.New("job cancelled")
// ErrStorageUnavailable is returned when IPFS storage is unavailable.
ErrStorageUnavailable = errors.New("storage unavailable")
// ErrDatabaseUnavailable is returned when the database is unavailable.
ErrDatabaseUnavailable = errors.New("database unavailable")
// ErrCacheUnavailable is returned when the cache is unavailable.
ErrCacheUnavailable = errors.New("cache unavailable")
)
// ConfigError represents a configuration validation error.
type ConfigError struct {
Field string
Message string
}
func (e *ConfigError) Error() string {
return fmt.Sprintf("config error: %s: %s", e.Field, e.Message)
}
// DeployError represents an error during function deployment.
type DeployError struct {
FunctionName string
Cause error
}
func (e *DeployError) Error() string {
return fmt.Sprintf("deploy error for function '%s': %v", e.FunctionName, e.Cause)
}
func (e *DeployError) Unwrap() error {
return e.Cause
}
// ExecutionError represents an error during function execution.
type ExecutionError struct {
FunctionName string
RequestID string
Cause error
}
func (e *ExecutionError) Error() string {
return fmt.Sprintf("execution error for function '%s' (request %s): %v",
e.FunctionName, e.RequestID, e.Cause)
}
func (e *ExecutionError) Unwrap() error {
return e.Cause
}
// HostFunctionError represents an error in a host function call.
type HostFunctionError struct {
Function string
Cause error
}
func (e *HostFunctionError) Error() string {
return fmt.Sprintf("host function '%s' error: %v", e.Function, e.Cause)
}
func (e *HostFunctionError) Unwrap() error {
return e.Cause
}
// TriggerError represents an error in trigger execution.
type TriggerError struct {
TriggerType string
TriggerID string
FunctionID string
Cause error
}
func (e *TriggerError) Error() string {
return fmt.Sprintf("trigger error (%s/%s) for function '%s': %v",
e.TriggerType, e.TriggerID, e.FunctionID, e.Cause)
}
func (e *TriggerError) Unwrap() error {
return e.Cause
}
// ValidationError represents an input validation error.
type ValidationError struct {
Field string
Message string
}
func (e *ValidationError) Error() string {
return fmt.Sprintf("validation error: %s: %s", e.Field, e.Message)
}
// RetryableError wraps an error that should be retried.
type RetryableError struct {
Cause error
RetryAfter int // Suggested retry delay in seconds
MaxRetries int // Maximum number of retries remaining
CurrentTry int // Current attempt number
}
func (e *RetryableError) Error() string {
return fmt.Sprintf("retryable error (attempt %d): %v", e.CurrentTry, e.Cause)
}
func (e *RetryableError) Unwrap() error {
return e.Cause
}
// IsRetryable checks if an error should be retried.
func IsRetryable(err error) bool {
var retryable *RetryableError
return errors.As(err, &retryable)
}
// IsNotFound checks if an error indicates a resource was not found.
func IsNotFound(err error) bool {
return errors.Is(err, ErrFunctionNotFound) ||
errors.Is(err, ErrVersionNotFound) ||
errors.Is(err, ErrSecretNotFound) ||
errors.Is(err, ErrJobNotFound) ||
errors.Is(err, ErrTriggerNotFound) ||
errors.Is(err, ErrTimerNotFound) ||
errors.Is(err, ErrWSClientNotFound)
}
// IsUnauthorized checks if an error indicates a lack of authorization.
func IsUnauthorized(err error) bool {
return errors.Is(err, ErrUnauthorized)
}
// IsResourceExhausted checks if an error indicates resource exhaustion.
func IsResourceExhausted(err error) bool {
return errors.Is(err, ErrRateLimited) ||
errors.Is(err, ErrMemoryExceeded) ||
errors.Is(err, ErrPayloadTooLarge) ||
errors.Is(err, ErrQueueFull) ||
errors.Is(err, ErrTimeout)
}
// IsServiceUnavailable checks if an error indicates a service is unavailable.
func IsServiceUnavailable(err error) bool {
return errors.Is(err, ErrStorageUnavailable) ||
errors.Is(err, ErrDatabaseUnavailable) ||
errors.Is(err, ErrCacheUnavailable)
}

687
pkg/serverless/hostfuncs.go Normal file
View File

@ -0,0 +1,687 @@
package serverless
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"sync"
"time"
"github.com/DeBrosOfficial/network/pkg/ipfs"
"github.com/DeBrosOfficial/network/pkg/pubsub"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/tlsutil"
olriclib "github.com/olric-data/olric"
"go.uber.org/zap"
)
// Ensure HostFunctions implements HostServices interface.
var _ HostServices = (*HostFunctions)(nil)
// HostFunctions provides the bridge between WASM functions and Orama services.
// It implements the HostServices interface and is injected into the execution context.
type HostFunctions struct {
db rqlite.Client
cacheClient olriclib.Client
storage ipfs.IPFSClient
ipfsAPIURL string
pubsub *pubsub.ClientAdapter
wsManager WebSocketManager
secrets SecretsManager
httpClient *http.Client
logger *zap.Logger
// Current invocation context (set per-execution)
invCtx *InvocationContext
invCtxLock sync.RWMutex
// Captured logs for this invocation
logs []LogEntry
logsLock sync.Mutex
}
// HostFunctionsConfig holds configuration for HostFunctions.
type HostFunctionsConfig struct {
IPFSAPIURL string
HTTPTimeout time.Duration
}
// NewHostFunctions creates a new HostFunctions instance.
func NewHostFunctions(
db rqlite.Client,
cacheClient olriclib.Client,
storage ipfs.IPFSClient,
pubsubAdapter *pubsub.ClientAdapter,
wsManager WebSocketManager,
secrets SecretsManager,
cfg HostFunctionsConfig,
logger *zap.Logger,
) *HostFunctions {
httpTimeout := cfg.HTTPTimeout
if httpTimeout == 0 {
httpTimeout = 30 * time.Second
}
return &HostFunctions{
db: db,
cacheClient: cacheClient,
storage: storage,
ipfsAPIURL: cfg.IPFSAPIURL,
pubsub: pubsubAdapter,
wsManager: wsManager,
secrets: secrets,
httpClient: tlsutil.NewHTTPClient(httpTimeout),
logger: logger,
logs: make([]LogEntry, 0),
}
}
// SetInvocationContext sets the current invocation context.
// Must be called before executing a function.
func (h *HostFunctions) SetInvocationContext(invCtx *InvocationContext) {
h.invCtxLock.Lock()
defer h.invCtxLock.Unlock()
h.invCtx = invCtx
h.logs = make([]LogEntry, 0) // Reset logs for new invocation
}
// GetLogs returns the captured logs for the current invocation.
func (h *HostFunctions) GetLogs() []LogEntry {
h.logsLock.Lock()
defer h.logsLock.Unlock()
logsCopy := make([]LogEntry, len(h.logs))
copy(logsCopy, h.logs)
return logsCopy
}
// ClearContext clears the invocation context after execution.
func (h *HostFunctions) ClearContext() {
h.invCtxLock.Lock()
defer h.invCtxLock.Unlock()
h.invCtx = nil
}
// -----------------------------------------------------------------------------
// Database Operations
// -----------------------------------------------------------------------------
// DBQuery executes a SELECT query and returns JSON-encoded results.
func (h *HostFunctions) DBQuery(ctx context.Context, query string, args []interface{}) ([]byte, error) {
if h.db == nil {
return nil, &HostFunctionError{Function: "db_query", Cause: ErrDatabaseUnavailable}
}
var results []map[string]interface{}
if err := h.db.Query(ctx, &results, query, args...); err != nil {
return nil, &HostFunctionError{Function: "db_query", Cause: err}
}
data, err := json.Marshal(results)
if err != nil {
return nil, &HostFunctionError{Function: "db_query", Cause: fmt.Errorf("failed to marshal results: %w", err)}
}
return data, nil
}
// DBExecute executes an INSERT/UPDATE/DELETE query and returns affected rows.
func (h *HostFunctions) DBExecute(ctx context.Context, query string, args []interface{}) (int64, error) {
if h.db == nil {
return 0, &HostFunctionError{Function: "db_execute", Cause: ErrDatabaseUnavailable}
}
result, err := h.db.Exec(ctx, query, args...)
if err != nil {
return 0, &HostFunctionError{Function: "db_execute", Cause: err}
}
affected, _ := result.RowsAffected()
return affected, nil
}
// -----------------------------------------------------------------------------
// Cache Operations
// -----------------------------------------------------------------------------
const cacheDMapName = "serverless_cache"
// CacheGet retrieves a value from the cache.
func (h *HostFunctions) CacheGet(ctx context.Context, key string) ([]byte, error) {
if h.cacheClient == nil {
return nil, &HostFunctionError{Function: "cache_get", Cause: ErrCacheUnavailable}
}
dm, err := h.cacheClient.NewDMap(cacheDMapName)
if err != nil {
return nil, &HostFunctionError{Function: "cache_get", Cause: fmt.Errorf("failed to get DMap: %w", err)}
}
result, err := dm.Get(ctx, key)
if err != nil {
return nil, &HostFunctionError{Function: "cache_get", Cause: err}
}
value, err := result.Byte()
if err != nil {
return nil, &HostFunctionError{Function: "cache_get", Cause: fmt.Errorf("failed to decode value: %w", err)}
}
return value, nil
}
// CacheSet stores a value in the cache with optional TTL.
// Note: TTL is currently not supported by the underlying Olric DMap.Put method.
// Values are stored indefinitely until explicitly deleted.
func (h *HostFunctions) CacheSet(ctx context.Context, key string, value []byte, ttlSeconds int64) error {
if h.cacheClient == nil {
return &HostFunctionError{Function: "cache_set", Cause: ErrCacheUnavailable}
}
dm, err := h.cacheClient.NewDMap(cacheDMapName)
if err != nil {
return &HostFunctionError{Function: "cache_set", Cause: fmt.Errorf("failed to get DMap: %w", err)}
}
// Note: Olric DMap.Put doesn't support TTL in the basic API
// For TTL support, consider using Olric's Expire API separately
if err := dm.Put(ctx, key, value); err != nil {
return &HostFunctionError{Function: "cache_set", Cause: err}
}
return nil
}
// CacheDelete removes a value from the cache.
func (h *HostFunctions) CacheDelete(ctx context.Context, key string) error {
if h.cacheClient == nil {
return &HostFunctionError{Function: "cache_delete", Cause: ErrCacheUnavailable}
}
dm, err := h.cacheClient.NewDMap(cacheDMapName)
if err != nil {
return &HostFunctionError{Function: "cache_delete", Cause: fmt.Errorf("failed to get DMap: %w", err)}
}
if _, err := dm.Delete(ctx, key); err != nil {
return &HostFunctionError{Function: "cache_delete", Cause: err}
}
return nil
}
// CacheIncr atomically increments a numeric value in cache by 1 and returns the new value.
// If the key doesn't exist, it is initialized to 0 before incrementing.
// Returns an error if the value exists but is not numeric.
func (h *HostFunctions) CacheIncr(ctx context.Context, key string) (int64, error) {
return h.CacheIncrBy(ctx, key, 1)
}
// CacheIncrBy atomically increments a numeric value by delta and returns the new value.
// If the key doesn't exist, it is initialized to 0 before incrementing.
// Returns an error if the value exists but is not numeric.
func (h *HostFunctions) CacheIncrBy(ctx context.Context, key string, delta int64) (int64, error) {
if h.cacheClient == nil {
return 0, &HostFunctionError{Function: "cache_incr_by", Cause: ErrCacheUnavailable}
}
dm, err := h.cacheClient.NewDMap(cacheDMapName)
if err != nil {
return 0, &HostFunctionError{Function: "cache_incr_by", Cause: fmt.Errorf("failed to get DMap: %w", err)}
}
// Olric's Incr method atomically increments a numeric value
// It initializes the key to 0 if it doesn't exist, then increments by delta
// Note: Olric's Incr takes int (not int64) and returns int
newValue, err := dm.Incr(ctx, key, int(delta))
if err != nil {
return 0, &HostFunctionError{Function: "cache_incr_by", Cause: fmt.Errorf("failed to increment: %w", err)}
}
return int64(newValue), nil
}
// -----------------------------------------------------------------------------
// Storage Operations
// -----------------------------------------------------------------------------
// StoragePut uploads data to IPFS and returns the CID.
func (h *HostFunctions) StoragePut(ctx context.Context, data []byte) (string, error) {
if h.storage == nil {
return "", &HostFunctionError{Function: "storage_put", Cause: ErrStorageUnavailable}
}
reader := bytes.NewReader(data)
resp, err := h.storage.Add(ctx, reader, "function-data")
if err != nil {
return "", &HostFunctionError{Function: "storage_put", Cause: err}
}
return resp.Cid, nil
}
// StorageGet retrieves data from IPFS by CID.
func (h *HostFunctions) StorageGet(ctx context.Context, cid string) ([]byte, error) {
if h.storage == nil {
return nil, &HostFunctionError{Function: "storage_get", Cause: ErrStorageUnavailable}
}
reader, err := h.storage.Get(ctx, cid, h.ipfsAPIURL)
if err != nil {
return nil, &HostFunctionError{Function: "storage_get", Cause: err}
}
defer reader.Close()
data, err := io.ReadAll(reader)
if err != nil {
return nil, &HostFunctionError{Function: "storage_get", Cause: fmt.Errorf("failed to read data: %w", err)}
}
return data, nil
}
// -----------------------------------------------------------------------------
// PubSub Operations
// -----------------------------------------------------------------------------
// PubSubPublish publishes a message to a topic.
func (h *HostFunctions) PubSubPublish(ctx context.Context, topic string, data []byte) error {
if h.pubsub == nil {
return &HostFunctionError{Function: "pubsub_publish", Cause: fmt.Errorf("pubsub not available")}
}
// The pubsub adapter handles namespacing internally
if err := h.pubsub.Publish(ctx, topic, data); err != nil {
return &HostFunctionError{Function: "pubsub_publish", Cause: err}
}
return nil
}
// -----------------------------------------------------------------------------
// WebSocket Operations
// -----------------------------------------------------------------------------
// WSSend sends data to a specific WebSocket client.
func (h *HostFunctions) WSSend(ctx context.Context, clientID string, data []byte) error {
if h.wsManager == nil {
return &HostFunctionError{Function: "ws_send", Cause: ErrWSNotAvailable}
}
// If no clientID provided, use the current invocation's client
if clientID == "" {
h.invCtxLock.RLock()
if h.invCtx != nil && h.invCtx.WSClientID != "" {
clientID = h.invCtx.WSClientID
}
h.invCtxLock.RUnlock()
}
if clientID == "" {
return &HostFunctionError{Function: "ws_send", Cause: ErrWSNotAvailable}
}
if err := h.wsManager.Send(clientID, data); err != nil {
return &HostFunctionError{Function: "ws_send", Cause: err}
}
return nil
}
// WSBroadcast sends data to all WebSocket clients subscribed to a topic.
func (h *HostFunctions) WSBroadcast(ctx context.Context, topic string, data []byte) error {
if h.wsManager == nil {
return &HostFunctionError{Function: "ws_broadcast", Cause: ErrWSNotAvailable}
}
if err := h.wsManager.Broadcast(topic, data); err != nil {
return &HostFunctionError{Function: "ws_broadcast", Cause: err}
}
return nil
}
// -----------------------------------------------------------------------------
// HTTP Operations
// -----------------------------------------------------------------------------
// HTTPFetch makes an outbound HTTP request.
func (h *HostFunctions) HTTPFetch(ctx context.Context, method, url string, headers map[string]string, body []byte) ([]byte, error) {
var bodyReader io.Reader
if len(body) > 0 {
bodyReader = bytes.NewReader(body)
}
req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
if err != nil {
h.logger.Error("http_fetch request creation error", zap.Error(err), zap.String("url", url))
errorResp := map[string]interface{}{
"error": "failed to create request: " + err.Error(),
"status": 0,
}
return json.Marshal(errorResp)
}
for key, value := range headers {
req.Header.Set(key, value)
}
resp, err := h.httpClient.Do(req)
if err != nil {
h.logger.Error("http_fetch transport error", zap.Error(err), zap.String("url", url))
errorResp := map[string]interface{}{
"error": err.Error(),
"status": 0, // Transport error
}
return json.Marshal(errorResp)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
h.logger.Error("http_fetch response read error", zap.Error(err), zap.String("url", url))
errorResp := map[string]interface{}{
"error": "failed to read response: " + err.Error(),
"status": resp.StatusCode,
}
return json.Marshal(errorResp)
}
// Encode response with status code
response := map[string]interface{}{
"status": resp.StatusCode,
"headers": resp.Header,
"body": string(respBody),
}
data, err := json.Marshal(response)
if err != nil {
return nil, &HostFunctionError{Function: "http_fetch", Cause: fmt.Errorf("failed to marshal response: %w", err)}
}
return data, nil
}
// -----------------------------------------------------------------------------
// Context Operations
// -----------------------------------------------------------------------------
// GetEnv retrieves an environment variable for the function.
func (h *HostFunctions) GetEnv(ctx context.Context, key string) (string, error) {
h.invCtxLock.RLock()
defer h.invCtxLock.RUnlock()
if h.invCtx == nil || h.invCtx.EnvVars == nil {
return "", nil
}
return h.invCtx.EnvVars[key], nil
}
// GetSecret retrieves a decrypted secret.
func (h *HostFunctions) GetSecret(ctx context.Context, name string) (string, error) {
if h.secrets == nil {
return "", &HostFunctionError{Function: "get_secret", Cause: fmt.Errorf("secrets manager not available")}
}
h.invCtxLock.RLock()
namespace := ""
if h.invCtx != nil {
namespace = h.invCtx.Namespace
}
h.invCtxLock.RUnlock()
value, err := h.secrets.Get(ctx, namespace, name)
if err != nil {
return "", &HostFunctionError{Function: "get_secret", Cause: err}
}
return value, nil
}
// GetRequestID returns the current request ID.
func (h *HostFunctions) GetRequestID(ctx context.Context) string {
h.invCtxLock.RLock()
defer h.invCtxLock.RUnlock()
if h.invCtx == nil {
return ""
}
return h.invCtx.RequestID
}
// GetCallerWallet returns the wallet address of the caller.
func (h *HostFunctions) GetCallerWallet(ctx context.Context) string {
h.invCtxLock.RLock()
defer h.invCtxLock.RUnlock()
if h.invCtx == nil {
return ""
}
return h.invCtx.CallerWallet
}
// -----------------------------------------------------------------------------
// Job Operations
// -----------------------------------------------------------------------------
// EnqueueBackground queues a function for background execution.
func (h *HostFunctions) EnqueueBackground(ctx context.Context, functionName string, payload []byte) (string, error) {
// This will be implemented when JobManager is integrated
// For now, return an error indicating it's not yet available
return "", &HostFunctionError{Function: "enqueue_background", Cause: fmt.Errorf("background jobs not yet implemented")}
}
// ScheduleOnce schedules a function to run once at a specific time.
func (h *HostFunctions) ScheduleOnce(ctx context.Context, functionName string, runAt time.Time, payload []byte) (string, error) {
// This will be implemented when Scheduler is integrated
return "", &HostFunctionError{Function: "schedule_once", Cause: fmt.Errorf("timers not yet implemented")}
}
// -----------------------------------------------------------------------------
// Logging Operations
// -----------------------------------------------------------------------------
// LogInfo logs an info message.
func (h *HostFunctions) LogInfo(ctx context.Context, message string) {
h.logsLock.Lock()
defer h.logsLock.Unlock()
h.logs = append(h.logs, LogEntry{
Level: "info",
Message: message,
Timestamp: time.Now(),
})
h.logger.Info(message,
zap.String("request_id", h.GetRequestID(ctx)),
zap.String("level", "function"),
)
}
// LogError logs an error message.
func (h *HostFunctions) LogError(ctx context.Context, message string) {
h.logsLock.Lock()
defer h.logsLock.Unlock()
h.logs = append(h.logs, LogEntry{
Level: "error",
Message: message,
Timestamp: time.Now(),
})
h.logger.Error(message,
zap.String("request_id", h.GetRequestID(ctx)),
zap.String("level", "function"),
)
}
// -----------------------------------------------------------------------------
// Secrets Manager Implementation (built-in)
// -----------------------------------------------------------------------------
// DBSecretsManager implements SecretsManager using the database.
type DBSecretsManager struct {
db rqlite.Client
encryptionKey []byte // 32-byte AES-256 key
logger *zap.Logger
}
// Ensure DBSecretsManager implements SecretsManager.
var _ SecretsManager = (*DBSecretsManager)(nil)
// NewDBSecretsManager creates a secrets manager backed by the database.
func NewDBSecretsManager(db rqlite.Client, encryptionKeyHex string, logger *zap.Logger) (*DBSecretsManager, error) {
var key []byte
if encryptionKeyHex != "" {
var err error
key, err = hex.DecodeString(encryptionKeyHex)
if err != nil || len(key) != 32 {
return nil, fmt.Errorf("invalid encryption key: must be 32 bytes hex-encoded")
}
} else {
// Generate a random key if none provided
key = make([]byte, 32)
if _, err := rand.Read(key); err != nil {
return nil, fmt.Errorf("failed to generate encryption key: %w", err)
}
logger.Warn("Generated random secrets encryption key - secrets will not persist across restarts")
}
return &DBSecretsManager{
db: db,
encryptionKey: key,
logger: logger,
}, nil
}
// Set stores an encrypted secret.
func (s *DBSecretsManager) Set(ctx context.Context, namespace, name, value string) error {
encrypted, err := s.encrypt([]byte(value))
if err != nil {
return fmt.Errorf("failed to encrypt secret: %w", err)
}
// Upsert the secret
query := `
INSERT INTO function_secrets (id, namespace, name, encrypted_value, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(namespace, name) DO UPDATE SET
encrypted_value = excluded.encrypted_value,
updated_at = excluded.updated_at
`
id := fmt.Sprintf("%s:%s", namespace, name)
now := time.Now()
if _, err := s.db.Exec(ctx, query, id, namespace, name, encrypted, now, now); err != nil {
return fmt.Errorf("failed to save secret: %w", err)
}
return nil
}
// Get retrieves a decrypted secret.
func (s *DBSecretsManager) Get(ctx context.Context, namespace, name string) (string, error) {
query := `SELECT encrypted_value FROM function_secrets WHERE namespace = ? AND name = ?`
var rows []struct {
EncryptedValue []byte `db:"encrypted_value"`
}
if err := s.db.Query(ctx, &rows, query, namespace, name); err != nil {
return "", fmt.Errorf("failed to query secret: %w", err)
}
if len(rows) == 0 {
return "", ErrSecretNotFound
}
decrypted, err := s.decrypt(rows[0].EncryptedValue)
if err != nil {
return "", fmt.Errorf("failed to decrypt secret: %w", err)
}
return string(decrypted), nil
}
// List returns all secret names for a namespace.
func (s *DBSecretsManager) List(ctx context.Context, namespace string) ([]string, error) {
query := `SELECT name FROM function_secrets WHERE namespace = ? ORDER BY name`
var rows []struct {
Name string `db:"name"`
}
if err := s.db.Query(ctx, &rows, query, namespace); err != nil {
return nil, fmt.Errorf("failed to list secrets: %w", err)
}
names := make([]string, len(rows))
for i, row := range rows {
names[i] = row.Name
}
return names, nil
}
// Delete removes a secret.
func (s *DBSecretsManager) Delete(ctx context.Context, namespace, name string) error {
query := `DELETE FROM function_secrets WHERE namespace = ? AND name = ?`
result, err := s.db.Exec(ctx, query, namespace, name)
if err != nil {
return fmt.Errorf("failed to delete secret: %w", err)
}
affected, _ := result.RowsAffected()
if affected == 0 {
return ErrSecretNotFound
}
return nil
}
// encrypt encrypts data using AES-256-GCM.
func (s *DBSecretsManager) encrypt(plaintext []byte) ([]byte, error) {
block, err := aes.NewCipher(s.encryptionKey)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, err
}
return gcm.Seal(nonce, nonce, plaintext, nil), nil
}
// decrypt decrypts data using AES-256-GCM.
func (s *DBSecretsManager) decrypt(ciphertext []byte) ([]byte, error) {
block, err := aes.NewCipher(s.encryptionKey)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(ciphertext) < nonceSize {
return nil, fmt.Errorf("ciphertext too short")
}
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
return gcm.Open(nil, nonce, ciphertext, nil)
}

View File

@ -0,0 +1,45 @@
package serverless
import (
"context"
"testing"
"go.uber.org/zap"
)
func TestHostFunctions_Cache(t *testing.T) {
db := NewMockRQLite()
ipfs := NewMockIPFSClient()
logger := zap.NewNop()
// MockOlricClient needs to implement olriclib.Client
// For now, let's just test other host functions if Olric is hard to mock
h := NewHostFunctions(db, nil, ipfs, nil, nil, nil, HostFunctionsConfig{}, logger)
ctx := context.Background()
h.SetInvocationContext(&InvocationContext{
RequestID: "req-1",
Namespace: "ns-1",
})
// Test Logging
h.LogInfo(ctx, "hello world")
logs := h.GetLogs()
if len(logs) != 1 || logs[0].Message != "hello world" {
t.Errorf("unexpected logs: %+v", logs)
}
// Test Storage
cid, err := h.StoragePut(ctx, []byte("data"))
if err != nil {
t.Fatalf("StoragePut failed: %v", err)
}
data, err := h.StorageGet(ctx, cid)
if err != nil {
t.Fatalf("StorageGet failed: %v", err)
}
if string(data) != "data" {
t.Errorf("expected 'data', got %q", string(data))
}
}

452
pkg/serverless/invoke.go Normal file
View File

@ -0,0 +1,452 @@
package serverless
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Invoker handles function invocation with retry logic and DLQ support.
// It wraps the Engine to provide higher-level invocation semantics.
type Invoker struct {
engine *Engine
registry FunctionRegistry
hostServices HostServices
logger *zap.Logger
}
// NewInvoker creates a new function invoker.
func NewInvoker(engine *Engine, registry FunctionRegistry, hostServices HostServices, logger *zap.Logger) *Invoker {
return &Invoker{
engine: engine,
registry: registry,
hostServices: hostServices,
logger: logger,
}
}
// InvokeRequest contains the parameters for invoking a function.
type InvokeRequest struct {
Namespace string `json:"namespace"`
FunctionName string `json:"function_name"`
Version int `json:"version,omitempty"` // 0 = latest
Input []byte `json:"input"`
TriggerType TriggerType `json:"trigger_type"`
CallerWallet string `json:"caller_wallet,omitempty"`
WSClientID string `json:"ws_client_id,omitempty"`
}
// InvokeResponse contains the result of a function invocation.
type InvokeResponse struct {
RequestID string `json:"request_id"`
Output []byte `json:"output,omitempty"`
Status InvocationStatus `json:"status"`
Error string `json:"error,omitempty"`
DurationMS int64 `json:"duration_ms"`
Retries int `json:"retries,omitempty"`
}
// Invoke executes a function with automatic retry logic.
func (i *Invoker) Invoke(ctx context.Context, req *InvokeRequest) (*InvokeResponse, error) {
if req == nil {
return nil, &ValidationError{Field: "request", Message: "cannot be nil"}
}
if req.FunctionName == "" {
return nil, &ValidationError{Field: "function_name", Message: "cannot be empty"}
}
if req.Namespace == "" {
return nil, &ValidationError{Field: "namespace", Message: "cannot be empty"}
}
requestID := uuid.New().String()
startTime := time.Now()
// Get function from registry
fn, err := i.registry.Get(ctx, req.Namespace, req.FunctionName, req.Version)
if err != nil {
return &InvokeResponse{
RequestID: requestID,
Status: InvocationStatusError,
Error: err.Error(),
DurationMS: time.Since(startTime).Milliseconds(),
}, err
}
// Check authorization
authorized, err := i.CanInvoke(ctx, req.Namespace, req.FunctionName, req.CallerWallet)
if err != nil || !authorized {
return &InvokeResponse{
RequestID: requestID,
Status: InvocationStatusError,
Error: "unauthorized",
DurationMS: time.Since(startTime).Milliseconds(),
}, ErrUnauthorized
}
// Get environment variables
envVars, err := i.getEnvVars(ctx, fn.ID)
if err != nil {
i.logger.Warn("Failed to get env vars", zap.Error(err))
envVars = make(map[string]string)
}
// Build invocation context
invCtx := &InvocationContext{
RequestID: requestID,
FunctionID: fn.ID,
FunctionName: fn.Name,
Namespace: fn.Namespace,
CallerWallet: req.CallerWallet,
TriggerType: req.TriggerType,
WSClientID: req.WSClientID,
EnvVars: envVars,
}
// Execute with retry logic
output, retries, err := i.executeWithRetry(ctx, fn, req.Input, invCtx)
response := &InvokeResponse{
RequestID: requestID,
Output: output,
DurationMS: time.Since(startTime).Milliseconds(),
Retries: retries,
}
if err != nil {
response.Status = InvocationStatusError
response.Error = err.Error()
// Check if it's a timeout
if ctx.Err() == context.DeadlineExceeded {
response.Status = InvocationStatusTimeout
}
return response, err
}
response.Status = InvocationStatusSuccess
return response, nil
}
// InvokeByID invokes a function by its ID.
func (i *Invoker) InvokeByID(ctx context.Context, functionID string, input []byte, invCtx *InvocationContext) (*InvokeResponse, error) {
// Get function from registry by ID
fn, err := i.getByID(ctx, functionID)
if err != nil {
return nil, err
}
if invCtx == nil {
invCtx = &InvocationContext{
RequestID: uuid.New().String(),
FunctionID: fn.ID,
FunctionName: fn.Name,
Namespace: fn.Namespace,
TriggerType: TriggerTypeHTTP,
}
}
startTime := time.Now()
output, retries, err := i.executeWithRetry(ctx, fn, input, invCtx)
response := &InvokeResponse{
RequestID: invCtx.RequestID,
Output: output,
DurationMS: time.Since(startTime).Milliseconds(),
Retries: retries,
}
if err != nil {
response.Status = InvocationStatusError
response.Error = err.Error()
return response, err
}
response.Status = InvocationStatusSuccess
return response, nil
}
// InvalidateCache removes a compiled module from the engine's cache.
func (i *Invoker) InvalidateCache(wasmCID string) {
i.engine.Invalidate(wasmCID)
}
// executeWithRetry executes a function with retry logic and DLQ.
func (i *Invoker) executeWithRetry(ctx context.Context, fn *Function, input []byte, invCtx *InvocationContext) ([]byte, int, error) {
var lastErr error
var output []byte
maxAttempts := fn.RetryCount + 1 // Initial attempt + retries
if maxAttempts < 1 {
maxAttempts = 1
}
for attempt := 0; attempt < maxAttempts; attempt++ {
// Check if context is cancelled
if ctx.Err() != nil {
return nil, attempt, ctx.Err()
}
// Execute the function
output, lastErr = i.engine.Execute(ctx, fn, input, invCtx)
if lastErr == nil {
return output, attempt, nil
}
i.logger.Warn("Function execution failed",
zap.String("function", fn.Name),
zap.String("request_id", invCtx.RequestID),
zap.Int("attempt", attempt+1),
zap.Int("max_attempts", maxAttempts),
zap.Error(lastErr),
)
// Don't retry on certain errors
if !i.isRetryable(lastErr) {
break
}
// Don't wait after the last attempt
if attempt < maxAttempts-1 {
delay := i.calculateBackoff(fn.RetryDelaySeconds, attempt)
select {
case <-ctx.Done():
return nil, attempt + 1, ctx.Err()
case <-time.After(delay):
// Continue to next attempt
}
}
}
// All retries exhausted - send to DLQ if configured
if fn.DLQTopic != "" {
i.sendToDLQ(ctx, fn, input, invCtx, lastErr)
}
return nil, maxAttempts - 1, lastErr
}
// isRetryable determines if an error should trigger a retry.
func (i *Invoker) isRetryable(err error) bool {
// Don't retry validation errors or not-found errors
if IsNotFound(err) {
return false
}
// Don't retry resource exhaustion (rate limits, memory)
if IsResourceExhausted(err) {
return false
}
// Retry service unavailable errors
if IsServiceUnavailable(err) {
return true
}
// Retry execution errors (could be transient)
var execErr *ExecutionError
if ok := errorAs(err, &execErr); ok {
return true
}
// Default to retryable for unknown errors
return true
}
// calculateBackoff calculates the delay before the next retry attempt.
// Uses exponential backoff with jitter.
func (i *Invoker) calculateBackoff(baseDelaySeconds, attempt int) time.Duration {
if baseDelaySeconds <= 0 {
baseDelaySeconds = 5
}
// Exponential backoff: delay * 2^attempt
delay := time.Duration(baseDelaySeconds) * time.Second
for j := 0; j < attempt; j++ {
delay *= 2
if delay > 5*time.Minute {
delay = 5 * time.Minute
break
}
}
return delay
}
// sendToDLQ sends a failed invocation to the dead letter queue.
func (i *Invoker) sendToDLQ(ctx context.Context, fn *Function, input []byte, invCtx *InvocationContext, err error) {
dlqMessage := DLQMessage{
FunctionID: fn.ID,
FunctionName: fn.Name,
Namespace: fn.Namespace,
RequestID: invCtx.RequestID,
Input: input,
Error: err.Error(),
FailedAt: time.Now(),
TriggerType: invCtx.TriggerType,
CallerWallet: invCtx.CallerWallet,
}
data, marshalErr := json.Marshal(dlqMessage)
if marshalErr != nil {
i.logger.Error("Failed to marshal DLQ message",
zap.Error(marshalErr),
zap.String("function", fn.Name),
)
return
}
// Publish to DLQ topic via host services
if err := i.hostServices.PubSubPublish(ctx, fn.DLQTopic, data); err != nil {
i.logger.Error("Failed to send to DLQ",
zap.Error(err),
zap.String("function", fn.Name),
zap.String("dlq_topic", fn.DLQTopic),
)
} else {
i.logger.Info("Sent failed invocation to DLQ",
zap.String("function", fn.Name),
zap.String("dlq_topic", fn.DLQTopic),
zap.String("request_id", invCtx.RequestID),
)
}
}
// getEnvVars retrieves environment variables for a function.
func (i *Invoker) getEnvVars(ctx context.Context, functionID string) (map[string]string, error) {
// Type assert to get extended registry methods
if reg, ok := i.registry.(*Registry); ok {
return reg.GetEnvVars(ctx, functionID)
}
return nil, nil
}
// getByID retrieves a function by ID.
func (i *Invoker) getByID(ctx context.Context, functionID string) (*Function, error) {
// Type assert to get extended registry methods
if reg, ok := i.registry.(*Registry); ok {
return reg.GetByID(ctx, functionID)
}
return nil, ErrFunctionNotFound
}
// DLQMessage represents a message sent to the dead letter queue.
type DLQMessage struct {
FunctionID string `json:"function_id"`
FunctionName string `json:"function_name"`
Namespace string `json:"namespace"`
RequestID string `json:"request_id"`
Input []byte `json:"input"`
Error string `json:"error"`
FailedAt time.Time `json:"failed_at"`
TriggerType TriggerType `json:"trigger_type"`
CallerWallet string `json:"caller_wallet,omitempty"`
}
// errorAs is a helper to avoid import of errors package.
func errorAs(err error, target interface{}) bool {
if err == nil {
return false
}
// Simple type assertion for our custom error types
switch t := target.(type) {
case **ExecutionError:
if e, ok := err.(*ExecutionError); ok {
*t = e
return true
}
}
return false
}
// -----------------------------------------------------------------------------
// Batch Invocation (for future use)
// -----------------------------------------------------------------------------
// BatchInvokeRequest contains parameters for batch invocation.
type BatchInvokeRequest struct {
Requests []*InvokeRequest `json:"requests"`
}
// BatchInvokeResponse contains results of batch invocation.
type BatchInvokeResponse struct {
Responses []*InvokeResponse `json:"responses"`
Duration time.Duration `json:"duration"`
}
// BatchInvoke executes multiple functions in parallel.
func (i *Invoker) BatchInvoke(ctx context.Context, req *BatchInvokeRequest) (*BatchInvokeResponse, error) {
if req == nil || len(req.Requests) == 0 {
return nil, &ValidationError{Field: "requests", Message: "cannot be empty"}
}
startTime := time.Now()
responses := make([]*InvokeResponse, len(req.Requests))
// For simplicity, execute sequentially for now
// TODO: Implement parallel execution with goroutines and semaphore
for idx, invReq := range req.Requests {
resp, err := i.Invoke(ctx, invReq)
if err != nil && resp == nil {
responses[idx] = &InvokeResponse{
RequestID: uuid.New().String(),
Status: InvocationStatusError,
Error: err.Error(),
}
} else {
responses[idx] = resp
}
}
return &BatchInvokeResponse{
Responses: responses,
Duration: time.Since(startTime),
}, nil
}
// -----------------------------------------------------------------------------
// Public Invocation Helpers
// -----------------------------------------------------------------------------
// CanInvoke checks if a caller is authorized to invoke a function.
func (i *Invoker) CanInvoke(ctx context.Context, namespace, functionName string, callerWallet string) (bool, error) {
fn, err := i.registry.Get(ctx, namespace, functionName, 0)
if err != nil {
return false, err
}
// Public functions can be invoked by anyone
if fn.IsPublic {
return true, nil
}
// Non-public functions require the caller to be in the same namespace
// (simplified authorization - can be extended)
if callerWallet == "" {
return false, nil
}
// For now, just check if caller wallet matches namespace
// In production, you'd check group membership, roles, etc.
return callerWallet == namespace || fn.CreatedBy == callerWallet, nil
}
// GetFunctionInfo returns basic info about a function for invocation.
func (i *Invoker) GetFunctionInfo(ctx context.Context, namespace, functionName string, version int) (*Function, error) {
return i.registry.Get(ctx, namespace, functionName, version)
}
// ValidateInput performs basic input validation.
func (i *Invoker) ValidateInput(input []byte, maxSize int) error {
if maxSize > 0 && len(input) > maxSize {
return &ValidationError{
Field: "input",
Message: fmt.Sprintf("exceeds maximum size of %d bytes", maxSize),
}
}
return nil
}

View File

@ -0,0 +1,434 @@
package serverless
import (
"context"
"database/sql"
"fmt"
"io"
"reflect"
"strings"
"sync"
"time"
"github.com/DeBrosOfficial/network/pkg/ipfs"
"github.com/DeBrosOfficial/network/pkg/rqlite"
)
// MockRegistry is a mock implementation of FunctionRegistry
type MockRegistry struct {
mu sync.RWMutex
functions map[string]*Function
wasm map[string][]byte
}
func NewMockRegistry() *MockRegistry {
return &MockRegistry{
functions: make(map[string]*Function),
wasm: make(map[string][]byte),
}
}
func (m *MockRegistry) Register(ctx context.Context, fn *FunctionDefinition, wasmBytes []byte) (*Function, error) {
m.mu.Lock()
defer m.mu.Unlock()
id := fn.Namespace + "/" + fn.Name
wasmCID := "cid-" + id
oldFn := m.functions[id]
m.functions[id] = &Function{
ID: id,
Name: fn.Name,
Namespace: fn.Namespace,
WASMCID: wasmCID,
MemoryLimitMB: fn.MemoryLimitMB,
TimeoutSeconds: fn.TimeoutSeconds,
IsPublic: fn.IsPublic,
RetryCount: fn.RetryCount,
RetryDelaySeconds: fn.RetryDelaySeconds,
Status: FunctionStatusActive,
}
m.wasm[wasmCID] = wasmBytes
return oldFn, nil
}
func (m *MockRegistry) Get(ctx context.Context, namespace, name string, version int) (*Function, error) {
m.mu.RLock()
defer m.mu.RUnlock()
fn, ok := m.functions[namespace+"/"+name]
if !ok {
return nil, ErrFunctionNotFound
}
return fn, nil
}
func (m *MockRegistry) List(ctx context.Context, namespace string) ([]*Function, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var res []*Function
for _, fn := range m.functions {
if fn.Namespace == namespace {
res = append(res, fn)
}
}
return res, nil
}
func (m *MockRegistry) Delete(ctx context.Context, namespace, name string, version int) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.functions, namespace+"/"+name)
return nil
}
func (m *MockRegistry) GetWASMBytes(ctx context.Context, wasmCID string) ([]byte, error) {
m.mu.RLock()
defer m.mu.RUnlock()
data, ok := m.wasm[wasmCID]
if !ok {
return nil, ErrFunctionNotFound
}
return data, nil
}
func (m *MockRegistry) GetLogs(ctx context.Context, namespace, name string, limit int) ([]LogEntry, error) {
return []LogEntry{}, nil
}
// MockHostServices is a mock implementation of HostServices
type MockHostServices struct {
mu sync.RWMutex
cache map[string][]byte
storage map[string][]byte
logs []string
}
func NewMockHostServices() *MockHostServices {
return &MockHostServices{
cache: make(map[string][]byte),
storage: make(map[string][]byte),
}
}
func (m *MockHostServices) DBQuery(ctx context.Context, query string, args []interface{}) ([]byte, error) {
return []byte("[]"), nil
}
func (m *MockHostServices) DBExecute(ctx context.Context, query string, args []interface{}) (int64, error) {
return 0, nil
}
func (m *MockHostServices) CacheGet(ctx context.Context, key string) ([]byte, error) {
m.mu.RLock()
defer m.mu.RUnlock()
return m.cache[key], nil
}
func (m *MockHostServices) CacheSet(ctx context.Context, key string, value []byte, ttl int64) error {
m.mu.Lock()
defer m.mu.Unlock()
m.cache[key] = value
return nil
}
func (m *MockHostServices) CacheDelete(ctx context.Context, key string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cache, key)
return nil
}
func (m *MockHostServices) CacheIncr(ctx context.Context, key string) (int64, error) {
return m.CacheIncrBy(ctx, key, 1)
}
func (m *MockHostServices) CacheIncrBy(ctx context.Context, key string, delta int64) (int64, error) {
m.mu.Lock()
defer m.mu.Unlock()
var currentValue int64
if val, ok := m.cache[key]; ok {
var err error
currentValue, err = parseInt64FromBytes(val)
if err != nil {
return 0, fmt.Errorf("value is not numeric")
}
}
newValue := currentValue + delta
m.cache[key] = []byte(fmt.Sprintf("%d", newValue))
return newValue, nil
}
func (m *MockHostServices) StoragePut(ctx context.Context, data []byte) (string, error) {
m.mu.Lock()
defer m.mu.Unlock()
cid := "cid-" + time.Now().String()
m.storage[cid] = data
return cid, nil
}
func (m *MockHostServices) StorageGet(ctx context.Context, cid string) ([]byte, error) {
m.mu.RLock()
defer m.mu.RUnlock()
return m.storage[cid], nil
}
func (m *MockHostServices) PubSubPublish(ctx context.Context, topic string, data []byte) error {
return nil
}
func (m *MockHostServices) WSSend(ctx context.Context, clientID string, data []byte) error {
return nil
}
func (m *MockHostServices) WSBroadcast(ctx context.Context, topic string, data []byte) error {
return nil
}
func (m *MockHostServices) HTTPFetch(ctx context.Context, method, url string, headers map[string]string, body []byte) ([]byte, error) {
return nil, nil
}
func (m *MockHostServices) GetEnv(ctx context.Context, key string) (string, error) {
return "", nil
}
func (m *MockHostServices) GetSecret(ctx context.Context, name string) (string, error) {
return "", nil
}
func (m *MockHostServices) GetRequestID(ctx context.Context) string {
return "req-123"
}
func (m *MockHostServices) GetCallerWallet(ctx context.Context) string {
return "wallet-123"
}
func (m *MockHostServices) EnqueueBackground(ctx context.Context, functionName string, payload []byte) (string, error) {
return "job-123", nil
}
func (m *MockHostServices) ScheduleOnce(ctx context.Context, functionName string, runAt time.Time, payload []byte) (string, error) {
return "timer-123", nil
}
func (m *MockHostServices) LogInfo(ctx context.Context, message string) {
m.mu.Lock()
defer m.mu.Unlock()
m.logs = append(m.logs, "INFO: "+message)
}
func (m *MockHostServices) LogError(ctx context.Context, message string) {
m.mu.Lock()
defer m.mu.Unlock()
m.logs = append(m.logs, "ERROR: "+message)
}
// MockIPFSClient is a mock for ipfs.IPFSClient
type MockIPFSClient struct {
data map[string][]byte
}
func NewMockIPFSClient() *MockIPFSClient {
return &MockIPFSClient{data: make(map[string][]byte)}
}
func (m *MockIPFSClient) Add(ctx context.Context, reader io.Reader, filename string) (*ipfs.AddResponse, error) {
data, _ := io.ReadAll(reader)
cid := "cid-" + filename
m.data[cid] = data
return &ipfs.AddResponse{Cid: cid, Name: filename}, nil
}
func (m *MockIPFSClient) Pin(ctx context.Context, cid string, name string, replicationFactor int) (*ipfs.PinResponse, error) {
return &ipfs.PinResponse{Cid: cid, Name: name}, nil
}
func (m *MockIPFSClient) PinStatus(ctx context.Context, cid string) (*ipfs.PinStatus, error) {
return &ipfs.PinStatus{Cid: cid, Status: "pinned"}, nil
}
func (m *MockIPFSClient) Get(ctx context.Context, cid, apiURL string) (io.ReadCloser, error) {
data, ok := m.data[cid]
if !ok {
return nil, fmt.Errorf("not found")
}
return io.NopCloser(strings.NewReader(string(data))), nil
}
func (m *MockIPFSClient) Unpin(ctx context.Context, cid string) error { return nil }
func (m *MockIPFSClient) Health(ctx context.Context) error { return nil }
func (m *MockIPFSClient) GetPeerCount(ctx context.Context) (int, error) { return 1, nil }
func (m *MockIPFSClient) Close(ctx context.Context) error { return nil }
// MockRQLite is a mock implementation of rqlite.Client
type MockRQLite struct {
mu sync.Mutex
tables map[string][]map[string]any
}
func NewMockRQLite() *MockRQLite {
return &MockRQLite{
tables: make(map[string][]map[string]any),
}
}
func (m *MockRQLite) Query(ctx context.Context, dest any, query string, args ...any) error {
m.mu.Lock()
defer m.mu.Unlock()
// Very limited mock query logic for scanning into structs
if strings.Contains(query, "FROM functions") {
rows := m.tables["functions"]
filtered := rows
if strings.Contains(query, "namespace = ? AND name = ?") {
ns := args[0].(string)
name := args[1].(string)
filtered = nil
for _, r := range rows {
if r["namespace"] == ns && r["name"] == name {
filtered = append(filtered, r)
}
}
}
destVal := reflect.ValueOf(dest).Elem()
if destVal.Kind() == reflect.Slice {
elemType := destVal.Type().Elem()
for _, r := range filtered {
newElem := reflect.New(elemType).Elem()
// This is a simplified mapping
if f := newElem.FieldByName("ID"); f.IsValid() {
f.SetString(r["id"].(string))
}
if f := newElem.FieldByName("Name"); f.IsValid() {
f.SetString(r["name"].(string))
}
if f := newElem.FieldByName("Namespace"); f.IsValid() {
f.SetString(r["namespace"].(string))
}
destVal.Set(reflect.Append(destVal, newElem))
}
}
}
return nil
}
func (m *MockRQLite) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
m.mu.Lock()
defer m.mu.Unlock()
return &mockResult{}, nil
}
func (m *MockRQLite) FindBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...rqlite.FindOption) error {
return nil
}
func (m *MockRQLite) FindOneBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...rqlite.FindOption) error {
return nil
}
func (m *MockRQLite) Save(ctx context.Context, entity any) error { return nil }
func (m *MockRQLite) Remove(ctx context.Context, entity any) error { return nil }
func (m *MockRQLite) Repository(table string) any { return nil }
func (m *MockRQLite) CreateQueryBuilder(table string) *rqlite.QueryBuilder {
return nil // Should return a valid QueryBuilder if needed by tests
}
func (m *MockRQLite) Tx(ctx context.Context, fn func(tx rqlite.Tx) error) error {
return nil
}
type mockResult struct{}
func (m *mockResult) LastInsertId() (int64, error) { return 1, nil }
func (m *mockResult) RowsAffected() (int64, error) { return 1, nil }
// MockOlricClient is a mock for olriclib.Client
type MockOlricClient struct {
dmaps map[string]*MockDMap
}
func NewMockOlricClient() *MockOlricClient {
return &MockOlricClient{dmaps: make(map[string]*MockDMap)}
}
func (m *MockOlricClient) NewDMap(name string) (any, error) {
if dm, ok := m.dmaps[name]; ok {
return dm, nil
}
dm := &MockDMap{data: make(map[string][]byte)}
m.dmaps[name] = dm
return dm, nil
}
func (m *MockOlricClient) Close(ctx context.Context) error { return nil }
func (m *MockOlricClient) Stats(ctx context.Context, s string) ([]byte, error) { return nil, nil }
func (m *MockOlricClient) Ping(ctx context.Context, s string) error { return nil }
func (m *MockOlricClient) RoutingTable(ctx context.Context) (map[uint64][]string, error) {
return nil, nil
}
// MockDMap is a mock for olriclib.DMap
type MockDMap struct {
data map[string][]byte
}
func (m *MockDMap) Get(ctx context.Context, key string) (any, error) {
val, ok := m.data[key]
if !ok {
return nil, fmt.Errorf("not found")
}
return &MockGetResponse{val: val}, nil
}
func (m *MockDMap) Put(ctx context.Context, key string, value any) error {
switch v := value.(type) {
case []byte:
m.data[key] = v
case string:
m.data[key] = []byte(v)
}
return nil
}
func (m *MockDMap) Delete(ctx context.Context, key string) (bool, error) {
_, ok := m.data[key]
delete(m.data, key)
return ok, nil
}
func (m *MockDMap) Incr(ctx context.Context, key string, delta int64) (int64, error) {
var currentValue int64
// Get current value if it exists
if val, ok := m.data[key]; ok {
// Try to parse as int64
var err error
currentValue, err = parseInt64FromBytes(val)
if err != nil {
return 0, fmt.Errorf("value is not numeric")
}
}
// Increment
newValue := currentValue + delta
// Store the new value
m.data[key] = []byte(fmt.Sprintf("%d", newValue))
return newValue, nil
}
// parseInt64FromBytes parses an int64 from byte slice
func parseInt64FromBytes(data []byte) (int64, error) {
var val int64
_, err := fmt.Sscanf(string(data), "%d", &val)
return val, err
}
type MockGetResponse struct {
val []byte
}
func (m *MockGetResponse) Byte() ([]byte, error) { return m.val, nil }
func (m *MockGetResponse) String() (string, error) { return string(m.val), nil }

561
pkg/serverless/registry.go Normal file
View File

@ -0,0 +1,561 @@
package serverless
import (
"bytes"
"context"
"database/sql"
"fmt"
"io"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/ipfs"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Ensure Registry implements FunctionRegistry and InvocationLogger interfaces.
var _ FunctionRegistry = (*Registry)(nil)
var _ InvocationLogger = (*Registry)(nil)
// Registry manages function metadata in RQLite and bytecode in IPFS.
// It implements the FunctionRegistry interface.
type Registry struct {
db rqlite.Client
ipfs ipfs.IPFSClient
ipfsAPIURL string
logger *zap.Logger
tableName string
}
// RegistryConfig holds configuration for the Registry.
type RegistryConfig struct {
IPFSAPIURL string // IPFS API URL for content retrieval
}
// NewRegistry creates a new function registry.
func NewRegistry(db rqlite.Client, ipfsClient ipfs.IPFSClient, cfg RegistryConfig, logger *zap.Logger) *Registry {
return &Registry{
db: db,
ipfs: ipfsClient,
ipfsAPIURL: cfg.IPFSAPIURL,
logger: logger,
tableName: "functions",
}
}
// Register deploys a new function or updates an existing one.
func (r *Registry) Register(ctx context.Context, fn *FunctionDefinition, wasmBytes []byte) (*Function, error) {
if fn == nil {
return nil, &ValidationError{Field: "definition", Message: "cannot be nil"}
}
fn.Name = strings.TrimSpace(fn.Name)
fn.Namespace = strings.TrimSpace(fn.Namespace)
if fn.Name == "" {
return nil, &ValidationError{Field: "name", Message: "cannot be empty"}
}
if fn.Namespace == "" {
return nil, &ValidationError{Field: "namespace", Message: "cannot be empty"}
}
if len(wasmBytes) == 0 {
return nil, &ValidationError{Field: "wasmBytes", Message: "cannot be empty"}
}
// Check if function already exists (regardless of status) to get old metadata for invalidation
oldFn, err := r.getByNameInternal(ctx, fn.Namespace, fn.Name)
if err != nil && err != ErrFunctionNotFound {
return nil, &DeployError{FunctionName: fn.Name, Cause: err}
}
// Upload WASM to IPFS
wasmCID, err := r.uploadWASM(ctx, wasmBytes, fn.Name)
if err != nil {
return nil, &DeployError{FunctionName: fn.Name, Cause: err}
}
// Apply defaults
memoryLimit := fn.MemoryLimitMB
if memoryLimit == 0 {
memoryLimit = 64
}
timeout := fn.TimeoutSeconds
if timeout == 0 {
timeout = 30
}
retryDelay := fn.RetryDelaySeconds
if retryDelay == 0 {
retryDelay = 5
}
now := time.Now()
id := uuid.New().String()
version := 1
if oldFn != nil {
// Use existing ID and increment version
id = oldFn.ID
version = oldFn.Version + 1
}
// Use INSERT OR REPLACE to ensure we never hit UNIQUE constraint failures on (namespace, name).
// This handles both new registrations and overwriting existing (even inactive) functions.
query := `
INSERT OR REPLACE INTO functions (
id, name, namespace, version, wasm_cid,
memory_limit_mb, timeout_seconds, is_public,
retry_count, retry_delay_seconds, dlq_topic,
status, created_at, updated_at, created_by
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err = r.db.Exec(ctx, query,
id, fn.Name, fn.Namespace, version, wasmCID,
memoryLimit, timeout, fn.IsPublic,
fn.RetryCount, retryDelay, fn.DLQTopic,
string(FunctionStatusActive), now, now, fn.Namespace,
)
if err != nil {
return nil, &DeployError{FunctionName: fn.Name, Cause: fmt.Errorf("failed to register function: %w", err)}
}
// Save environment variables
if err := r.saveEnvVars(ctx, id, fn.EnvVars); err != nil {
return nil, &DeployError{FunctionName: fn.Name, Cause: err}
}
r.logger.Info("Function registered",
zap.String("id", id),
zap.String("name", fn.Name),
zap.String("namespace", fn.Namespace),
zap.String("wasm_cid", wasmCID),
zap.Int("version", version),
zap.Bool("updated", oldFn != nil),
)
return oldFn, nil
}
// Get retrieves a function by name and optional version.
// If version is 0, returns the latest version.
func (r *Registry) Get(ctx context.Context, namespace, name string, version int) (*Function, error) {
namespace = strings.TrimSpace(namespace)
name = strings.TrimSpace(name)
var query string
var args []interface{}
if version == 0 {
// Get latest version
query = `
SELECT id, name, namespace, version, wasm_cid, source_cid,
memory_limit_mb, timeout_seconds, is_public,
retry_count, retry_delay_seconds, dlq_topic,
status, created_at, updated_at, created_by
FROM functions
WHERE namespace = ? AND name = ? AND status = ?
ORDER BY version DESC
LIMIT 1
`
args = []interface{}{namespace, name, string(FunctionStatusActive)}
} else {
query = `
SELECT id, name, namespace, version, wasm_cid, source_cid,
memory_limit_mb, timeout_seconds, is_public,
retry_count, retry_delay_seconds, dlq_topic,
status, created_at, updated_at, created_by
FROM functions
WHERE namespace = ? AND name = ? AND version = ?
`
args = []interface{}{namespace, name, version}
}
var functions []functionRow
if err := r.db.Query(ctx, &functions, query, args...); err != nil {
return nil, fmt.Errorf("failed to query function: %w", err)
}
if len(functions) == 0 {
if version == 0 {
return nil, ErrFunctionNotFound
}
return nil, ErrVersionNotFound
}
return r.rowToFunction(&functions[0]), nil
}
// List returns all functions for a namespace.
func (r *Registry) List(ctx context.Context, namespace string) ([]*Function, error) {
// Get latest version of each function in the namespace
query := `
SELECT f.id, f.name, f.namespace, f.version, f.wasm_cid, f.source_cid,
f.memory_limit_mb, f.timeout_seconds, f.is_public,
f.retry_count, f.retry_delay_seconds, f.dlq_topic,
f.status, f.created_at, f.updated_at, f.created_by
FROM functions f
INNER JOIN (
SELECT namespace, name, MAX(version) as max_version
FROM functions
WHERE namespace = ? AND status = ?
GROUP BY namespace, name
) latest ON f.namespace = latest.namespace
AND f.name = latest.name
AND f.version = latest.max_version
ORDER BY f.name
`
var rows []functionRow
if err := r.db.Query(ctx, &rows, query, namespace, string(FunctionStatusActive)); err != nil {
return nil, fmt.Errorf("failed to list functions: %w", err)
}
functions := make([]*Function, len(rows))
for i, row := range rows {
functions[i] = r.rowToFunction(&row)
}
return functions, nil
}
// Delete removes a function. If version is 0, removes all versions.
func (r *Registry) Delete(ctx context.Context, namespace, name string, version int) error {
namespace = strings.TrimSpace(namespace)
name = strings.TrimSpace(name)
var query string
var args []interface{}
if version == 0 {
// Mark all versions as inactive (soft delete)
query = `UPDATE functions SET status = ?, updated_at = ? WHERE namespace = ? AND name = ?`
args = []interface{}{string(FunctionStatusInactive), time.Now(), namespace, name}
} else {
query = `UPDATE functions SET status = ?, updated_at = ? WHERE namespace = ? AND name = ? AND version = ?`
args = []interface{}{string(FunctionStatusInactive), time.Now(), namespace, name, version}
}
result, err := r.db.Exec(ctx, query, args...)
if err != nil {
return fmt.Errorf("failed to delete function: %w", err)
}
rowsAffected, _ := result.RowsAffected()
if rowsAffected == 0 {
if version == 0 {
return ErrFunctionNotFound
}
return ErrVersionNotFound
}
r.logger.Info("Function deleted",
zap.String("namespace", namespace),
zap.String("name", name),
zap.Int("version", version),
)
return nil
}
// GetWASMBytes retrieves the compiled WASM bytecode for a function.
func (r *Registry) GetWASMBytes(ctx context.Context, wasmCID string) ([]byte, error) {
if wasmCID == "" {
return nil, &ValidationError{Field: "wasmCID", Message: "cannot be empty"}
}
reader, err := r.ipfs.Get(ctx, wasmCID, r.ipfsAPIURL)
if err != nil {
return nil, fmt.Errorf("failed to get WASM from IPFS: %w", err)
}
defer reader.Close()
data, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("failed to read WASM data: %w", err)
}
return data, nil
}
// GetEnvVars retrieves environment variables for a function.
func (r *Registry) GetEnvVars(ctx context.Context, functionID string) (map[string]string, error) {
query := `SELECT key, value FROM function_env_vars WHERE function_id = ?`
var rows []envVarRow
if err := r.db.Query(ctx, &rows, query, functionID); err != nil {
return nil, fmt.Errorf("failed to query env vars: %w", err)
}
envVars := make(map[string]string, len(rows))
for _, row := range rows {
envVars[row.Key] = row.Value
}
return envVars, nil
}
// GetByID retrieves a function by its ID.
func (r *Registry) GetByID(ctx context.Context, id string) (*Function, error) {
query := `
SELECT id, name, namespace, version, wasm_cid, source_cid,
memory_limit_mb, timeout_seconds, is_public,
retry_count, retry_delay_seconds, dlq_topic,
status, created_at, updated_at, created_by
FROM functions
WHERE id = ?
`
var functions []functionRow
if err := r.db.Query(ctx, &functions, query, id); err != nil {
return nil, fmt.Errorf("failed to query function: %w", err)
}
if len(functions) == 0 {
return nil, ErrFunctionNotFound
}
return r.rowToFunction(&functions[0]), nil
}
// ListVersions returns all versions of a function.
func (r *Registry) ListVersions(ctx context.Context, namespace, name string) ([]*Function, error) {
query := `
SELECT id, name, namespace, version, wasm_cid, source_cid,
memory_limit_mb, timeout_seconds, is_public,
retry_count, retry_delay_seconds, dlq_topic,
status, created_at, updated_at, created_by
FROM functions
WHERE namespace = ? AND name = ?
ORDER BY version DESC
`
var rows []functionRow
if err := r.db.Query(ctx, &rows, query, namespace, name); err != nil {
return nil, fmt.Errorf("failed to list versions: %w", err)
}
functions := make([]*Function, len(rows))
for i, row := range rows {
functions[i] = r.rowToFunction(&row)
}
return functions, nil
}
// Log records a function invocation and its logs to the database.
func (r *Registry) Log(ctx context.Context, inv *InvocationRecord) error {
if inv == nil {
return nil
}
// Insert invocation record
invQuery := `
INSERT INTO function_invocations (
id, function_id, request_id, trigger_type, caller_wallet,
input_size, output_size, started_at, completed_at,
duration_ms, status, error_message, memory_used_mb
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := r.db.Exec(ctx, invQuery,
inv.ID, inv.FunctionID, inv.RequestID, string(inv.TriggerType), inv.CallerWallet,
inv.InputSize, inv.OutputSize, inv.StartedAt, inv.CompletedAt,
inv.DurationMS, string(inv.Status), inv.ErrorMessage, inv.MemoryUsedMB,
)
if err != nil {
return fmt.Errorf("failed to insert invocation record: %w", err)
}
// Insert logs if any
if len(inv.Logs) > 0 {
for _, entry := range inv.Logs {
logID := uuid.New().String()
logQuery := `
INSERT INTO function_logs (
id, function_id, invocation_id, level, message, timestamp
) VALUES (?, ?, ?, ?, ?, ?)
`
_, err := r.db.Exec(ctx, logQuery,
logID, inv.FunctionID, inv.ID, entry.Level, entry.Message, entry.Timestamp,
)
if err != nil {
r.logger.Warn("Failed to insert function log", zap.Error(err))
// Continue with other logs
}
}
}
return nil
}
// GetLogs retrieves logs for a function.
func (r *Registry) GetLogs(ctx context.Context, namespace, name string, limit int) ([]LogEntry, error) {
if limit <= 0 {
limit = 100
}
query := `
SELECT l.level, l.message, l.timestamp
FROM function_logs l
JOIN functions f ON l.function_id = f.id
WHERE f.namespace = ? AND f.name = ?
ORDER BY l.timestamp DESC
LIMIT ?
`
var results []struct {
Level string `db:"level"`
Message string `db:"message"`
Timestamp time.Time `db:"timestamp"`
}
if err := r.db.Query(ctx, &results, query, namespace, name, limit); err != nil {
return nil, fmt.Errorf("failed to query logs: %w", err)
}
logs := make([]LogEntry, len(results))
for i, res := range results {
logs[i] = LogEntry{
Level: res.Level,
Message: res.Message,
Timestamp: res.Timestamp,
}
}
return logs, nil
}
// -----------------------------------------------------------------------------
// Private helpers
// -----------------------------------------------------------------------------
// uploadWASM uploads WASM bytecode to IPFS and returns the CID.
func (r *Registry) uploadWASM(ctx context.Context, wasmBytes []byte, name string) (string, error) {
reader := bytes.NewReader(wasmBytes)
resp, err := r.ipfs.Add(ctx, reader, name+".wasm")
if err != nil {
return "", fmt.Errorf("failed to upload WASM to IPFS: %w", err)
}
return resp.Cid, nil
}
// getLatestVersion returns the latest version number for a function.
func (r *Registry) getLatestVersion(ctx context.Context, namespace, name string) (int, error) {
query := `SELECT MAX(version) FROM functions WHERE namespace = ? AND name = ?`
var maxVersion sql.NullInt64
var results []struct {
MaxVersion sql.NullInt64 `db:"max(version)"`
}
if err := r.db.Query(ctx, &results, query, namespace, name); err != nil {
return 0, err
}
if len(results) == 0 || !results[0].MaxVersion.Valid {
return 0, ErrFunctionNotFound
}
maxVersion = results[0].MaxVersion
return int(maxVersion.Int64), nil
}
// getByNameInternal retrieves a function by name regardless of status.
func (r *Registry) getByNameInternal(ctx context.Context, namespace, name string) (*Function, error) {
namespace = strings.TrimSpace(namespace)
name = strings.TrimSpace(name)
query := `
SELECT id, name, namespace, version, wasm_cid, source_cid,
memory_limit_mb, timeout_seconds, is_public,
retry_count, retry_delay_seconds, dlq_topic,
status, created_at, updated_at, created_by
FROM functions
WHERE namespace = ? AND name = ?
ORDER BY version DESC
LIMIT 1
`
var functions []functionRow
if err := r.db.Query(ctx, &functions, query, namespace, name); err != nil {
return nil, fmt.Errorf("failed to query function: %w", err)
}
if len(functions) == 0 {
return nil, ErrFunctionNotFound
}
return r.rowToFunction(&functions[0]), nil
}
// saveEnvVars saves environment variables for a function.
func (r *Registry) saveEnvVars(ctx context.Context, functionID string, envVars map[string]string) error {
// Clear existing env vars first
deleteQuery := `DELETE FROM function_env_vars WHERE function_id = ?`
if _, err := r.db.Exec(ctx, deleteQuery, functionID); err != nil {
return fmt.Errorf("failed to clear existing env vars: %w", err)
}
if len(envVars) == 0 {
return nil
}
for key, value := range envVars {
id := uuid.New().String()
query := `INSERT INTO function_env_vars (id, function_id, key, value, created_at) VALUES (?, ?, ?, ?, ?)`
if _, err := r.db.Exec(ctx, query, id, functionID, key, value, time.Now()); err != nil {
return fmt.Errorf("failed to save env var '%s': %w", key, err)
}
}
return nil
}
// rowToFunction converts a database row to a Function struct.
func (r *Registry) rowToFunction(row *functionRow) *Function {
return &Function{
ID: row.ID,
Name: row.Name,
Namespace: row.Namespace,
Version: row.Version,
WASMCID: row.WASMCID,
SourceCID: row.SourceCID.String,
MemoryLimitMB: row.MemoryLimitMB,
TimeoutSeconds: row.TimeoutSeconds,
IsPublic: row.IsPublic,
RetryCount: row.RetryCount,
RetryDelaySeconds: row.RetryDelaySeconds,
DLQTopic: row.DLQTopic.String,
Status: FunctionStatus(row.Status),
CreatedAt: row.CreatedAt,
UpdatedAt: row.UpdatedAt,
CreatedBy: row.CreatedBy,
}
}
// -----------------------------------------------------------------------------
// Database row types (internal)
// -----------------------------------------------------------------------------
type functionRow struct {
ID string `db:"id"`
Name string `db:"name"`
Namespace string `db:"namespace"`
Version int `db:"version"`
WASMCID string `db:"wasm_cid"`
SourceCID sql.NullString `db:"source_cid"`
MemoryLimitMB int `db:"memory_limit_mb"`
TimeoutSeconds int `db:"timeout_seconds"`
IsPublic bool `db:"is_public"`
RetryCount int `db:"retry_count"`
RetryDelaySeconds int `db:"retry_delay_seconds"`
DLQTopic sql.NullString `db:"dlq_topic"`
Status string `db:"status"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
CreatedBy string `db:"created_by"`
}
type envVarRow struct {
Key string `db:"key"`
Value string `db:"value"`
}

View File

@ -0,0 +1,40 @@
package serverless
import (
"context"
"testing"
"go.uber.org/zap"
)
func TestRegistry_RegisterAndGet(t *testing.T) {
db := NewMockRQLite()
ipfs := NewMockIPFSClient()
logger := zap.NewNop()
registry := NewRegistry(db, ipfs, RegistryConfig{IPFSAPIURL: "http://localhost:5001"}, logger)
ctx := context.Background()
fnDef := &FunctionDefinition{
Name: "test-func",
Namespace: "test-ns",
IsPublic: true,
}
wasmBytes := []byte("mock wasm")
_, err := registry.Register(ctx, fnDef, wasmBytes)
if err != nil {
t.Fatalf("Register failed: %v", err)
}
// Since MockRQLite doesn't fully implement Query scanning yet,
// we won't be able to test Get() effectively without more work.
// But we can check if wasm was uploaded.
wasm, err := registry.GetWASMBytes(ctx, "cid-test-func.wasm")
if err != nil {
t.Fatalf("GetWASMBytes failed: %v", err)
}
if string(wasm) != "mock wasm" {
t.Errorf("expected 'mock wasm', got %q", string(wasm))
}
}

379
pkg/serverless/types.go Normal file
View File

@ -0,0 +1,379 @@
// Package serverless provides a WASM-based serverless function engine for the Orama Network.
// It enables users to deploy and execute Go functions (compiled to WASM) across all nodes,
// with support for HTTP/WebSocket triggers, cron jobs, database triggers, pub/sub triggers,
// one-time timers, retries with DLQ, and background jobs.
package serverless
import (
"context"
"io"
"time"
)
// FunctionStatus represents the current state of a deployed function.
type FunctionStatus string
const (
FunctionStatusActive FunctionStatus = "active"
FunctionStatusInactive FunctionStatus = "inactive"
FunctionStatusError FunctionStatus = "error"
)
// TriggerType identifies the type of event that triggered a function invocation.
type TriggerType string
const (
TriggerTypeHTTP TriggerType = "http"
TriggerTypeWebSocket TriggerType = "websocket"
TriggerTypeCron TriggerType = "cron"
TriggerTypeDatabase TriggerType = "database"
TriggerTypePubSub TriggerType = "pubsub"
TriggerTypeTimer TriggerType = "timer"
TriggerTypeJob TriggerType = "job"
)
// JobStatus represents the current state of a background job.
type JobStatus string
const (
JobStatusPending JobStatus = "pending"
JobStatusRunning JobStatus = "running"
JobStatusCompleted JobStatus = "completed"
JobStatusFailed JobStatus = "failed"
)
// InvocationStatus represents the result of a function invocation.
type InvocationStatus string
const (
InvocationStatusSuccess InvocationStatus = "success"
InvocationStatusError InvocationStatus = "error"
InvocationStatusTimeout InvocationStatus = "timeout"
)
// DBOperation represents the type of database operation that triggered a function.
type DBOperation string
const (
DBOperationInsert DBOperation = "INSERT"
DBOperationUpdate DBOperation = "UPDATE"
DBOperationDelete DBOperation = "DELETE"
)
// -----------------------------------------------------------------------------
// Core Interfaces (following Interface Segregation Principle)
// -----------------------------------------------------------------------------
// FunctionRegistry manages function metadata and bytecode storage.
// Responsible for CRUD operations on function definitions.
type FunctionRegistry interface {
// Register deploys a new function or updates an existing one.
// Returns the old function definition if it was updated, or nil if it was a new registration.
Register(ctx context.Context, fn *FunctionDefinition, wasmBytes []byte) (*Function, error)
// Get retrieves a function by name and optional version.
// If version is 0, returns the latest version.
Get(ctx context.Context, namespace, name string, version int) (*Function, error)
// List returns all functions for a namespace.
List(ctx context.Context, namespace string) ([]*Function, error)
// Delete removes a function. If version is 0, removes all versions.
Delete(ctx context.Context, namespace, name string, version int) error
// GetWASMBytes retrieves the compiled WASM bytecode for a function.
GetWASMBytes(ctx context.Context, wasmCID string) ([]byte, error)
// GetLogs retrieves logs for a function.
GetLogs(ctx context.Context, namespace, name string, limit int) ([]LogEntry, error)
}
// FunctionExecutor handles the actual execution of WASM functions.
type FunctionExecutor interface {
// Execute runs a function with the given input and returns the output.
Execute(ctx context.Context, fn *Function, input []byte, invCtx *InvocationContext) ([]byte, error)
// Precompile compiles a WASM module and caches it for faster execution.
Precompile(ctx context.Context, wasmCID string, wasmBytes []byte) error
// Invalidate removes a compiled module from the cache.
Invalidate(wasmCID string)
}
// SecretsManager handles secure storage and retrieval of secrets.
type SecretsManager interface {
// Set stores an encrypted secret.
Set(ctx context.Context, namespace, name, value string) error
// Get retrieves a decrypted secret.
Get(ctx context.Context, namespace, name string) (string, error)
// List returns all secret names for a namespace (not values).
List(ctx context.Context, namespace string) ([]string, error)
// Delete removes a secret.
Delete(ctx context.Context, namespace, name string) error
}
// TriggerManager manages function triggers (cron, database, pubsub, timer).
type TriggerManager interface {
// AddCronTrigger adds a cron-based trigger to a function.
AddCronTrigger(ctx context.Context, functionID, cronExpr string) error
// AddDBTrigger adds a database trigger to a function.
AddDBTrigger(ctx context.Context, functionID, tableName string, operation DBOperation, condition string) error
// AddPubSubTrigger adds a pubsub trigger to a function.
AddPubSubTrigger(ctx context.Context, functionID, topic string) error
// ScheduleOnce schedules a one-time execution.
ScheduleOnce(ctx context.Context, functionID string, runAt time.Time, payload []byte) (string, error)
// RemoveTrigger removes a trigger by ID.
RemoveTrigger(ctx context.Context, triggerID string) error
}
// JobManager manages background job execution.
type JobManager interface {
// Enqueue adds a job to the queue for background execution.
Enqueue(ctx context.Context, functionID string, payload []byte) (string, error)
// GetStatus retrieves the current status of a job.
GetStatus(ctx context.Context, jobID string) (*Job, error)
// List returns jobs for a function.
List(ctx context.Context, functionID string, limit int) ([]*Job, error)
// Cancel attempts to cancel a pending or running job.
Cancel(ctx context.Context, jobID string) error
}
// WebSocketManager manages WebSocket connections for function streaming.
type WebSocketManager interface {
// Register registers a new WebSocket connection.
Register(clientID string, conn WebSocketConn)
// Unregister removes a WebSocket connection.
Unregister(clientID string)
// Send sends data to a specific client.
Send(clientID string, data []byte) error
// Broadcast sends data to all clients subscribed to a topic.
Broadcast(topic string, data []byte) error
// Subscribe adds a client to a topic.
Subscribe(clientID, topic string)
// Unsubscribe removes a client from a topic.
Unsubscribe(clientID, topic string)
}
// WebSocketConn abstracts a WebSocket connection for testability.
type WebSocketConn interface {
WriteMessage(messageType int, data []byte) error
ReadMessage() (messageType int, p []byte, err error)
Close() error
}
// -----------------------------------------------------------------------------
// Data Types
// -----------------------------------------------------------------------------
// FunctionDefinition contains the configuration for deploying a function.
type FunctionDefinition struct {
Name string `json:"name"`
Namespace string `json:"namespace"`
Version int `json:"version,omitempty"`
MemoryLimitMB int `json:"memory_limit_mb,omitempty"`
TimeoutSeconds int `json:"timeout_seconds,omitempty"`
IsPublic bool `json:"is_public,omitempty"`
RetryCount int `json:"retry_count,omitempty"`
RetryDelaySeconds int `json:"retry_delay_seconds,omitempty"`
DLQTopic string `json:"dlq_topic,omitempty"`
EnvVars map[string]string `json:"env_vars,omitempty"`
CronExpressions []string `json:"cron_expressions,omitempty"`
DBTriggers []DBTriggerConfig `json:"db_triggers,omitempty"`
PubSubTopics []string `json:"pubsub_topics,omitempty"`
}
// DBTriggerConfig defines a database trigger configuration.
type DBTriggerConfig struct {
Table string `json:"table"`
Operation DBOperation `json:"operation"`
Condition string `json:"condition,omitempty"`
}
// Function represents a deployed serverless function.
type Function struct {
ID string `json:"id"`
Name string `json:"name"`
Namespace string `json:"namespace"`
Version int `json:"version"`
WASMCID string `json:"wasm_cid"`
SourceCID string `json:"source_cid,omitempty"`
MemoryLimitMB int `json:"memory_limit_mb"`
TimeoutSeconds int `json:"timeout_seconds"`
IsPublic bool `json:"is_public"`
RetryCount int `json:"retry_count"`
RetryDelaySeconds int `json:"retry_delay_seconds"`
DLQTopic string `json:"dlq_topic,omitempty"`
Status FunctionStatus `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
CreatedBy string `json:"created_by"`
}
// InvocationContext provides context for a function invocation.
type InvocationContext struct {
RequestID string `json:"request_id"`
FunctionID string `json:"function_id"`
FunctionName string `json:"function_name"`
Namespace string `json:"namespace"`
CallerWallet string `json:"caller_wallet,omitempty"`
TriggerType TriggerType `json:"trigger_type"`
WSClientID string `json:"ws_client_id,omitempty"`
EnvVars map[string]string `json:"env_vars,omitempty"`
}
// InvocationResult represents the result of a function invocation.
type InvocationResult struct {
RequestID string `json:"request_id"`
Output []byte `json:"output,omitempty"`
Status InvocationStatus `json:"status"`
Error string `json:"error,omitempty"`
DurationMS int64 `json:"duration_ms"`
Logs []LogEntry `json:"logs,omitempty"`
}
// LogEntry represents a log message from a function.
type LogEntry struct {
Level string `json:"level"`
Message string `json:"message"`
Timestamp time.Time `json:"timestamp"`
}
// Job represents a background job.
type Job struct {
ID string `json:"id"`
FunctionID string `json:"function_id"`
Payload []byte `json:"payload,omitempty"`
Status JobStatus `json:"status"`
Progress int `json:"progress"`
Result []byte `json:"result,omitempty"`
Error string `json:"error,omitempty"`
StartedAt *time.Time `json:"started_at,omitempty"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// CronTrigger represents a cron-based trigger.
type CronTrigger struct {
ID string `json:"id"`
FunctionID string `json:"function_id"`
CronExpression string `json:"cron_expression"`
NextRunAt *time.Time `json:"next_run_at,omitempty"`
LastRunAt *time.Time `json:"last_run_at,omitempty"`
Enabled bool `json:"enabled"`
}
// DBTrigger represents a database trigger.
type DBTrigger struct {
ID string `json:"id"`
FunctionID string `json:"function_id"`
TableName string `json:"table_name"`
Operation DBOperation `json:"operation"`
Condition string `json:"condition,omitempty"`
Enabled bool `json:"enabled"`
}
// PubSubTrigger represents a pubsub trigger.
type PubSubTrigger struct {
ID string `json:"id"`
FunctionID string `json:"function_id"`
Topic string `json:"topic"`
Enabled bool `json:"enabled"`
}
// Timer represents a one-time scheduled execution.
type Timer struct {
ID string `json:"id"`
FunctionID string `json:"function_id"`
RunAt time.Time `json:"run_at"`
Payload []byte `json:"payload,omitempty"`
Status JobStatus `json:"status"`
CreatedAt time.Time `json:"created_at"`
}
// DBChangeEvent is passed to functions triggered by database changes.
type DBChangeEvent struct {
Table string `json:"table"`
Operation DBOperation `json:"operation"`
Row map[string]interface{} `json:"row"`
OldRow map[string]interface{} `json:"old_row,omitempty"`
}
// -----------------------------------------------------------------------------
// Host Function Types (passed to WASM functions)
// -----------------------------------------------------------------------------
// HostServices provides access to Orama services from within WASM functions.
// This interface is implemented by the host and exposed to WASM modules.
type HostServices interface {
// Database operations
DBQuery(ctx context.Context, query string, args []interface{}) ([]byte, error)
DBExecute(ctx context.Context, query string, args []interface{}) (int64, error)
// Cache operations
CacheGet(ctx context.Context, key string) ([]byte, error)
CacheSet(ctx context.Context, key string, value []byte, ttlSeconds int64) error
CacheDelete(ctx context.Context, key string) error
CacheIncr(ctx context.Context, key string) (int64, error)
CacheIncrBy(ctx context.Context, key string, delta int64) (int64, error)
// Storage operations
StoragePut(ctx context.Context, data []byte) (string, error)
StorageGet(ctx context.Context, cid string) ([]byte, error)
// PubSub operations
PubSubPublish(ctx context.Context, topic string, data []byte) error
// WebSocket operations (only valid in WS context)
WSSend(ctx context.Context, clientID string, data []byte) error
WSBroadcast(ctx context.Context, topic string, data []byte) error
// HTTP operations
HTTPFetch(ctx context.Context, method, url string, headers map[string]string, body []byte) ([]byte, error)
// Context operations
GetEnv(ctx context.Context, key string) (string, error)
GetSecret(ctx context.Context, name string) (string, error)
GetRequestID(ctx context.Context) string
GetCallerWallet(ctx context.Context) string
// Job operations
EnqueueBackground(ctx context.Context, functionName string, payload []byte) (string, error)
ScheduleOnce(ctx context.Context, functionName string, runAt time.Time, payload []byte) (string, error)
// Logging
LogInfo(ctx context.Context, message string)
LogError(ctx context.Context, message string)
}
// -----------------------------------------------------------------------------
// Deployment Types
// -----------------------------------------------------------------------------
// DeployRequest represents a request to deploy a function.
type DeployRequest struct {
Definition *FunctionDefinition `json:"definition"`
Source io.Reader `json:"-"` // Go source code or WASM bytes
IsWASM bool `json:"is_wasm"` // True if Source contains WASM bytes
}
// DeployResult represents the result of a deployment.
type DeployResult struct {
Function *Function `json:"function"`
WASMCID string `json:"wasm_cid"`
Triggers []string `json:"triggers,omitempty"`
}

332
pkg/serverless/websocket.go Normal file
View File

@ -0,0 +1,332 @@
package serverless
import (
"sync"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
// Ensure WSManager implements WebSocketManager interface.
var _ WebSocketManager = (*WSManager)(nil)
// WSManager manages WebSocket connections for serverless functions.
// It handles connection registration, message routing, and topic subscriptions.
type WSManager struct {
// connections maps client IDs to their WebSocket connections
connections map[string]*wsConnection
connectionsMu sync.RWMutex
// subscriptions maps topic names to sets of client IDs
subscriptions map[string]map[string]struct{}
subscriptionsMu sync.RWMutex
logger *zap.Logger
}
// wsConnection wraps a WebSocket connection with metadata.
type wsConnection struct {
conn WebSocketConn
clientID string
topics map[string]struct{} // Topics this client is subscribed to
mu sync.Mutex
}
// GorillaWSConn wraps a gorilla/websocket.Conn to implement WebSocketConn.
type GorillaWSConn struct {
*websocket.Conn
}
// Ensure GorillaWSConn implements WebSocketConn.
var _ WebSocketConn = (*GorillaWSConn)(nil)
// WriteMessage writes a message to the WebSocket connection.
func (c *GorillaWSConn) WriteMessage(messageType int, data []byte) error {
return c.Conn.WriteMessage(messageType, data)
}
// ReadMessage reads a message from the WebSocket connection.
func (c *GorillaWSConn) ReadMessage() (messageType int, p []byte, err error) {
return c.Conn.ReadMessage()
}
// Close closes the WebSocket connection.
func (c *GorillaWSConn) Close() error {
return c.Conn.Close()
}
// NewWSManager creates a new WebSocket manager.
func NewWSManager(logger *zap.Logger) *WSManager {
return &WSManager{
connections: make(map[string]*wsConnection),
subscriptions: make(map[string]map[string]struct{}),
logger: logger,
}
}
// Register registers a new WebSocket connection.
func (m *WSManager) Register(clientID string, conn WebSocketConn) {
m.connectionsMu.Lock()
defer m.connectionsMu.Unlock()
// Close existing connection if any
if existing, exists := m.connections[clientID]; exists {
_ = existing.conn.Close()
m.logger.Debug("Closed existing connection", zap.String("client_id", clientID))
}
m.connections[clientID] = &wsConnection{
conn: conn,
clientID: clientID,
topics: make(map[string]struct{}),
}
m.logger.Debug("Registered WebSocket connection",
zap.String("client_id", clientID),
zap.Int("total_connections", len(m.connections)),
)
}
// Unregister removes a WebSocket connection and its subscriptions.
func (m *WSManager) Unregister(clientID string) {
m.connectionsMu.Lock()
conn, exists := m.connections[clientID]
if exists {
delete(m.connections, clientID)
}
m.connectionsMu.Unlock()
if !exists {
return
}
// Remove from all subscriptions
m.subscriptionsMu.Lock()
for topic := range conn.topics {
if clients, ok := m.subscriptions[topic]; ok {
delete(clients, clientID)
if len(clients) == 0 {
delete(m.subscriptions, topic)
}
}
}
m.subscriptionsMu.Unlock()
// Close the connection
_ = conn.conn.Close()
m.logger.Debug("Unregistered WebSocket connection",
zap.String("client_id", clientID),
zap.Int("remaining_connections", m.GetConnectionCount()),
)
}
// Send sends data to a specific client.
func (m *WSManager) Send(clientID string, data []byte) error {
m.connectionsMu.RLock()
conn, exists := m.connections[clientID]
m.connectionsMu.RUnlock()
if !exists {
return ErrWSClientNotFound
}
conn.mu.Lock()
defer conn.mu.Unlock()
if err := conn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
m.logger.Warn("Failed to send WebSocket message",
zap.String("client_id", clientID),
zap.Error(err),
)
return err
}
return nil
}
// Broadcast sends data to all clients subscribed to a topic.
func (m *WSManager) Broadcast(topic string, data []byte) error {
m.subscriptionsMu.RLock()
clients, exists := m.subscriptions[topic]
if !exists || len(clients) == 0 {
m.subscriptionsMu.RUnlock()
return nil // No subscribers, not an error
}
// Copy client IDs to avoid holding lock during send
clientIDs := make([]string, 0, len(clients))
for clientID := range clients {
clientIDs = append(clientIDs, clientID)
}
m.subscriptionsMu.RUnlock()
// Send to all subscribers
var sendErrors int
for _, clientID := range clientIDs {
if err := m.Send(clientID, data); err != nil {
sendErrors++
}
}
m.logger.Debug("Broadcast message",
zap.String("topic", topic),
zap.Int("recipients", len(clientIDs)),
zap.Int("errors", sendErrors),
)
return nil
}
// Subscribe adds a client to a topic.
func (m *WSManager) Subscribe(clientID, topic string) {
// Add to connection's topic list
m.connectionsMu.RLock()
conn, exists := m.connections[clientID]
m.connectionsMu.RUnlock()
if !exists {
return
}
conn.mu.Lock()
conn.topics[topic] = struct{}{}
conn.mu.Unlock()
// Add to topic's client list
m.subscriptionsMu.Lock()
if m.subscriptions[topic] == nil {
m.subscriptions[topic] = make(map[string]struct{})
}
m.subscriptions[topic][clientID] = struct{}{}
m.subscriptionsMu.Unlock()
m.logger.Debug("Client subscribed to topic",
zap.String("client_id", clientID),
zap.String("topic", topic),
)
}
// Unsubscribe removes a client from a topic.
func (m *WSManager) Unsubscribe(clientID, topic string) {
// Remove from connection's topic list
m.connectionsMu.RLock()
conn, exists := m.connections[clientID]
m.connectionsMu.RUnlock()
if exists {
conn.mu.Lock()
delete(conn.topics, topic)
conn.mu.Unlock()
}
// Remove from topic's client list
m.subscriptionsMu.Lock()
if clients, ok := m.subscriptions[topic]; ok {
delete(clients, clientID)
if len(clients) == 0 {
delete(m.subscriptions, topic)
}
}
m.subscriptionsMu.Unlock()
m.logger.Debug("Client unsubscribed from topic",
zap.String("client_id", clientID),
zap.String("topic", topic),
)
}
// GetConnectionCount returns the number of active connections.
func (m *WSManager) GetConnectionCount() int {
m.connectionsMu.RLock()
defer m.connectionsMu.RUnlock()
return len(m.connections)
}
// GetTopicSubscriberCount returns the number of subscribers for a topic.
func (m *WSManager) GetTopicSubscriberCount(topic string) int {
m.subscriptionsMu.RLock()
defer m.subscriptionsMu.RUnlock()
if clients, exists := m.subscriptions[topic]; exists {
return len(clients)
}
return 0
}
// GetClientTopics returns all topics a client is subscribed to.
func (m *WSManager) GetClientTopics(clientID string) []string {
m.connectionsMu.RLock()
conn, exists := m.connections[clientID]
m.connectionsMu.RUnlock()
if !exists {
return nil
}
conn.mu.Lock()
defer conn.mu.Unlock()
topics := make([]string, 0, len(conn.topics))
for topic := range conn.topics {
topics = append(topics, topic)
}
return topics
}
// IsConnected checks if a client is connected.
func (m *WSManager) IsConnected(clientID string) bool {
m.connectionsMu.RLock()
defer m.connectionsMu.RUnlock()
_, exists := m.connections[clientID]
return exists
}
// Close closes all connections and cleans up resources.
func (m *WSManager) Close() {
m.connectionsMu.Lock()
defer m.connectionsMu.Unlock()
for clientID, conn := range m.connections {
_ = conn.conn.Close()
delete(m.connections, clientID)
}
m.subscriptionsMu.Lock()
m.subscriptions = make(map[string]map[string]struct{})
m.subscriptionsMu.Unlock()
m.logger.Info("WebSocket manager closed")
}
// Stats returns statistics about the WebSocket manager.
type WSStats struct {
ConnectionCount int `json:"connection_count"`
TopicCount int `json:"topic_count"`
SubscriptionCount int `json:"subscription_count"`
TopicStats map[string]int `json:"topic_stats"` // topic -> subscriber count
}
// GetStats returns current statistics.
func (m *WSManager) GetStats() *WSStats {
m.connectionsMu.RLock()
connCount := len(m.connections)
m.connectionsMu.RUnlock()
m.subscriptionsMu.RLock()
topicCount := len(m.subscriptions)
topicStats := make(map[string]int, topicCount)
totalSubs := 0
for topic, clients := range m.subscriptions {
topicStats[topic] = len(clients)
totalSubs += len(clients)
}
m.subscriptionsMu.RUnlock()
return &WSStats{
ConnectionCount: connCount,
TopicCount: topicCount,
SubscriptionCount: totalSubs,
TopicStats: topicStats,
}
}

View File

@ -1,53 +0,0 @@
#!/bin/bash
# Setup local domains for DeBros Network development
# Adds entries to /etc/hosts for node-1.local through node-5.local
# Maps them to 127.0.0.1 for local development
set -e
HOSTS_FILE="/etc/hosts"
NODES=("node-1" "node-2" "node-3" "node-4" "node-5")
# Check if we have sudo access
if [ "$EUID" -ne 0 ]; then
echo "This script requires sudo to modify /etc/hosts"
echo "Please run: sudo bash scripts/setup-local-domains.sh"
exit 1
fi
# Function to add or update domain entry
add_domain() {
local domain=$1
local ip="127.0.0.1"
# Check if domain already exists
if grep -q "^[[:space:]]*$ip[[:space:]]\+$domain" "$HOSTS_FILE"; then
echo "$domain already configured"
return 0
fi
# Add domain to /etc/hosts
echo "$ip $domain" >> "$HOSTS_FILE"
echo "✓ Added $domain -> $ip"
}
echo "Setting up local domains for DeBros Network..."
echo ""
# Add each node domain
for node in "${NODES[@]}"; do
add_domain "${node}.local"
done
echo ""
echo "✓ Local domains configured successfully!"
echo ""
echo "You can now access nodes via:"
for node in "${NODES[@]}"; do
echo " - ${node}.local (HTTP Gateway)"
done
echo ""
echo "Example: curl http://node-1.local:8080/rqlite/http/db/status"

View File

@ -1,85 +0,0 @@
#!/bin/bash
# Test local domain routing for DeBros Network
# Validates that all HTTP gateway routes are working
set -e
NODES=("1" "2" "3" "4" "5")
GATEWAY_PORTS=(8080 8081 8082 8083 8084)
# Color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# Counters
PASSED=0
FAILED=0
# Test a single endpoint
test_endpoint() {
local node=$1
local port=$2
local path=$3
local description=$4
local url="http://node-${node}.local:${port}${path}"
printf "Testing %-50s ... " "$description"
if curl -s -f "$url" > /dev/null 2>&1; then
echo -e "${GREEN}✓ PASS${NC}"
((PASSED++))
return 0
else
echo -e "${RED}✗ FAIL${NC}"
((FAILED++))
return 1
fi
}
echo "=========================================="
echo "DeBros Network Local Domain Tests"
echo "=========================================="
echo ""
# Test each node's HTTP gateway
for i in "${!NODES[@]}"; do
node=${NODES[$i]}
port=${GATEWAY_PORTS[$i]}
echo "Testing node-${node}.local (port ${port}):"
# Test health endpoint
test_endpoint "$node" "$port" "/health" "Node-${node} health check"
# Test RQLite HTTP endpoint
test_endpoint "$node" "$port" "/rqlite/http/db/execute" "Node-${node} RQLite HTTP"
# Test IPFS API endpoint (may fail if IPFS not running, but at least connection should work)
test_endpoint "$node" "$port" "/ipfs/api/v0/version" "Node-${node} IPFS API" || true
# Test Cluster API endpoint (may fail if Cluster not running, but at least connection should work)
test_endpoint "$node" "$port" "/cluster/health" "Node-${node} Cluster API" || true
echo ""
done
# Summary
echo "=========================================="
echo "Test Results"
echo "=========================================="
echo -e "${GREEN}Passed: $PASSED${NC}"
echo -e "${RED}Failed: $FAILED${NC}"
echo ""
if [ $FAILED -eq 0 ]; then
echo -e "${GREEN}✓ All tests passed!${NC}"
exit 0
else
echo -e "${YELLOW}⚠ Some tests failed (this is expected if services aren't running)${NC}"
exit 1
fi