mirror of
https://github.com/DeBrosOfficial/network.git
synced 2026-01-30 22:43:04 +00:00
commit
d34404ec87
92
.cursor/rules/network.mdc
Normal file
92
.cursor/rules/network.mdc
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
---
|
||||||
|
alwaysApply: true
|
||||||
|
---
|
||||||
|
|
||||||
|
# AI Instructions
|
||||||
|
|
||||||
|
You have access to the **network** MCP (Model Context Protocol) server for this project. This MCP provides deep, pre-analyzed context about the codebase that is far more accurate than default file searching.
|
||||||
|
|
||||||
|
## IMPORTANT: Always Use MCP First
|
||||||
|
|
||||||
|
**Before making any code changes or answering questions about this codebase, ALWAYS consult the MCP tools first.**
|
||||||
|
|
||||||
|
The MCP has pre-indexed the entire codebase with semantic understanding, embeddings, and structural analysis. While you can use your own file search capabilities, the MCP provides much better context because:
|
||||||
|
- It understands code semantics, not just text matching
|
||||||
|
- It has pre-analyzed the architecture, patterns, and relationships
|
||||||
|
- It can answer questions about intent and purpose, not just content
|
||||||
|
|
||||||
|
## Available MCP Tools
|
||||||
|
|
||||||
|
### Code Understanding
|
||||||
|
- `network_ask_question` - Ask natural language questions about the codebase. Use this for "how does X work?", "where is Y implemented?", "what does Z do?" questions. The MCP will search relevant code and provide informed answers.
|
||||||
|
- `network_search_code` - Semantic code search. Find code by meaning, not just text. Great for finding implementations, patterns, or related functionality.
|
||||||
|
- `network_get_architecture` - Get the full project architecture overview including tech stack, design patterns, domain entities, and API endpoints.
|
||||||
|
- `network_get_file_summary` - Get a detailed summary of what a specific file does, its purpose, exports, and responsibilities.
|
||||||
|
- `network_find_function` - Find a specific function or method definition by name across the codebase.
|
||||||
|
- `network_list_functions` - List all functions defined in a specific file.
|
||||||
|
|
||||||
|
### Skills (Learned Procedures)
|
||||||
|
Skills are reusable procedures that the agent has learned about this specific project (e.g., "how to deploy", "how to run tests", "how to add a new API endpoint").
|
||||||
|
|
||||||
|
- `network_list_skills` - List all learned skills for this project.
|
||||||
|
- `network_get_skill` - Get detailed information about a specific skill including its step-by-step procedure.
|
||||||
|
- `network_execute_skill` - Get the procedure for a learned skill so you can execute it step by step. Returns prerequisites, warnings, and commands to run.
|
||||||
|
- `network_learn_skill` - Teach the agent a new skill. The agent will explore, discover, and memorize how to perform this task.
|
||||||
|
- `network_get_learning_status` - Check the status of an ongoing skill learning session.
|
||||||
|
- `network_answer_question` - Answer a question that the learning agent asked during skill learning.
|
||||||
|
- `network_cancel_learning` - Cancel an active learning session.
|
||||||
|
- `network_forget_skill` - Delete a learned skill.
|
||||||
|
- `network_update_skill` - Update a learned skill with corrections or new information (e.g., 'Use .env.prod instead of .env', 'Add step to backup database first', 'The port should be 3000 not 8080').
|
||||||
|
|
||||||
|
#### Skill Learning Workflow (IMPORTANT)
|
||||||
|
|
||||||
|
When learning a skill, follow this **collaborative, goal-oriented workflow**. You (Cursor) are the executor, the MCP agent provides guidance:
|
||||||
|
|
||||||
|
**Goal-Oriented Learning**: The agent identifies specific GOALS (pieces of information to gather) and tracks progress by goal completion, not by iterations.
|
||||||
|
|
||||||
|
1. **Start Learning**: Call `learn_skill` with name and detailed description
|
||||||
|
2. **Monitor Progress**: Call `get_learning_status` to check progress
|
||||||
|
3. **Handle Status Responses**:
|
||||||
|
- `active` → Learning in progress, check again in a few seconds
|
||||||
|
- `waiting_input` → The agent has a question. Read it and call `answer_question` with your response
|
||||||
|
- `waiting_execution` → **IMPORTANT**: The agent needs you to run a command!
|
||||||
|
- Read the `pendingExecution.command` from the response
|
||||||
|
- **Execute the command yourself** using your terminal access
|
||||||
|
- Call `answer_question` with the command output
|
||||||
|
- `completed` → Skill learned successfully!
|
||||||
|
- `failed` → Check errors and try again
|
||||||
|
4. **Repeat** steps 2-3 until status is `completed`
|
||||||
|
|
||||||
|
**Key Insight**: The MCP agent runs on the server and cannot SSH to remote servers directly. When it needs remote access, it generates the SSH command for YOU to execute. You have terminal access - use it!
|
||||||
|
|
||||||
|
**User Override Commands**: If the agent gets stuck, you can include these keywords in your answer:
|
||||||
|
- `COMPLETE` or `SKIP` - Skip to synthesis phase and generate the skill from current data
|
||||||
|
- `PHASE:synthesizing` - Force transition to drafting phase
|
||||||
|
- `GOAL:goal_id=value` - Directly provide a goal's value (e.g., `GOAL:cluster_secret=abc123`)
|
||||||
|
- `I have provided X` - Tell the agent it already has certain information
|
||||||
|
|
||||||
|
**Example for `waiting_execution`**:
|
||||||
|
```
|
||||||
|
// Status response shows:
|
||||||
|
// pendingExecution: { command: "ssh root@192.168.1.1 'ls -la /home/user/.orama'" }
|
||||||
|
//
|
||||||
|
// You should:
|
||||||
|
// 1. Run the command in your terminal
|
||||||
|
// 2. Get the output
|
||||||
|
// 3. Call answer_question with the output
|
||||||
|
```
|
||||||
|
|
||||||
|
## Recommended Workflow
|
||||||
|
|
||||||
|
1. **For questions:** Use `network_ask_question` or `network_search_code` to understand the codebase.
|
||||||
|
---
|
||||||
|
|
||||||
|
# DeBros Network Gateway
|
||||||
|
|
||||||
|
This project is a high-performance, edge-focused API gateway and reverse proxy designed to bridge decentralized web services with standard HTTP clients. It serves as a comprehensive middleware layer that facilitates wallet-based authentication, distributed caching via Olric, IPFS storage interaction, and serverless execution of WebAssembly (Wasm) functions. Additionally, it provides specialized utility services such as push notifications and an anonymizing proxy, ensuring secure and monitored communication between users and decentralized infrastructure.
|
||||||
|
|
||||||
|
**Architecture:** API Gateway / Edge Middleware
|
||||||
|
|
||||||
|
## Tech Stack
|
||||||
|
- **backend:** Go
|
||||||
|
|
||||||
@ -1,68 +0,0 @@
|
|||||||
// Project-local debug tasks
|
|
||||||
//
|
|
||||||
// For more documentation on how to configure debug tasks,
|
|
||||||
// see: https://zed.dev/docs/debugger
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"label": "Gateway Go (Delve)",
|
|
||||||
"adapter": "Delve",
|
|
||||||
"request": "launch",
|
|
||||||
"mode": "debug",
|
|
||||||
"program": "./cmd/gateway",
|
|
||||||
"env": {
|
|
||||||
"GATEWAY_ADDR": ":6001",
|
|
||||||
"GATEWAY_BOOTSTRAP_PEERS": "/ip4/localhost/tcp/4001/p2p/12D3KooWSHHwEY6cga3ng7tD1rzStAU58ogQXVMX3LZJ6Gqf6dee",
|
|
||||||
"GATEWAY_NAMESPACE": "default",
|
|
||||||
"GATEWAY_API_KEY": "ak_iGustrsFk9H8uXpwczCATe5U:default"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"label": "E2E Test Go (Delve)",
|
|
||||||
"adapter": "Delve",
|
|
||||||
"request": "launch",
|
|
||||||
"mode": "test",
|
|
||||||
"buildFlags": "-tags e2e",
|
|
||||||
"program": "./e2e",
|
|
||||||
"env": {
|
|
||||||
"GATEWAY_API_KEY": "ak_iGustrsFk9H8uXpwczCATe5U:default"
|
|
||||||
},
|
|
||||||
"args": ["-test.v"]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"adapter": "Delve",
|
|
||||||
"label": "Gateway Go 6001 Port (Delve)",
|
|
||||||
"request": "launch",
|
|
||||||
"mode": "debug",
|
|
||||||
"program": "./cmd/gateway",
|
|
||||||
"env": {
|
|
||||||
"GATEWAY_ADDR": ":6001",
|
|
||||||
"GATEWAY_BOOTSTRAP_PEERS": "/ip4/localhost/tcp/4001/p2p/12D3KooWSHHwEY6cga3ng7tD1rzStAU58ogQXVMX3LZJ6Gqf6dee",
|
|
||||||
"GATEWAY_NAMESPACE": "default",
|
|
||||||
"GATEWAY_API_KEY": "ak_iGustrsFk9H8uXpwczCATe5U:default"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"adapter": "Delve",
|
|
||||||
"label": "Network CLI - peers (Delve)",
|
|
||||||
"request": "launch",
|
|
||||||
"mode": "debug",
|
|
||||||
"program": "./cmd/cli",
|
|
||||||
"args": ["peers"]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"adapter": "Delve",
|
|
||||||
"label": "Network CLI - PubSub Subscribe (Delve)",
|
|
||||||
"request": "launch",
|
|
||||||
"mode": "debug",
|
|
||||||
"program": "./cmd/cli",
|
|
||||||
"args": ["pubsub", "subscribe", "monitoring"]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"adapter": "Delve",
|
|
||||||
"label": "Node Go (Delve)",
|
|
||||||
"request": "launch",
|
|
||||||
"mode": "debug",
|
|
||||||
"program": "./cmd/node",
|
|
||||||
"args": ["--config", "configs/node.yaml"]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
53
CHANGELOG.md
53
CHANGELOG.md
@ -13,6 +13,59 @@ The format is based on [Keep a Changelog][keepachangelog] and adheres to [Semant
|
|||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
## [0.82.0] - 2026-01-03
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- Added PubSub Presence feature, allowing clients to track members connected to a topic via WebSocket.
|
||||||
|
- Added a new tool, `rqlite-mcp`, which implements the Model Communication Protocol (MCP) for Rqlite, enabling AI models to interact with the database using tools.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
- Updated the development environment to include and manage the new `rqlite-mcp` service.
|
||||||
|
|
||||||
|
### Deprecated
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
\n
|
||||||
|
## [0.81.0] - 2025-12-31
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- Implemented a new, robust authentication service within the Gateway for handling wallet-based challenges, signature verification (ETH/SOL), JWT issuance, and API key management.
|
||||||
|
- Introduced automatic recovery logic for RQLite to detect and recover from split-brain scenarios and ensure cluster stability during restarts.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
- Refactored the production installation command (`dbn prod install`) by moving installer logic and utility functions into a dedicated `pkg/cli/utils` package for better modularity and maintainability.
|
||||||
|
- Reworked the core logic for starting and managing LibP2P, RQLite, and the HTTP Gateway within the Node, including improved peer reconnection and cluster configuration synchronization.
|
||||||
|
|
||||||
|
### Deprecated
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
- Corrected IPFS Cluster configuration logic to properly handle port assignments and ensure correct IPFS API addresses are used, resolving potential connection issues between cluster components.
|
||||||
|
|
||||||
|
## [0.80.0] - 2025-12-29
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- Implemented the core Serverless Functions Engine, allowing users to deploy and execute WASM-based functions (e.g., Go compiled with TinyGo).
|
||||||
|
- Added new database migration (004) to support serverless functions, including tables for functions, secrets, cron triggers, database triggers, pubsub triggers, timers, jobs, and invocation logs.
|
||||||
|
- Added new API endpoints for managing and invoking serverless functions (`/v1/functions`, `/v1/invoke`, `/v1/functions/{name}/invoke`, `/v1/functions/{name}/ws`).
|
||||||
|
- Introduced `WSPubSubClient` for E2E testing of WebSocket PubSub functionality.
|
||||||
|
- Added examples and a build script for creating WASM serverless functions (Echo, Hello, Counter).
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
- Updated Go version requirement from 1.23.8 to 1.24.0 in `go.mod`.
|
||||||
|
- Refactored RQLite client to improve data type handling and conversion, especially for `sql.Null*` types and number parsing.
|
||||||
|
- Improved RQLite cluster discovery logic to safely handle new nodes joining an existing cluster without clearing existing Raft state unless necessary (log index 0).
|
||||||
|
|
||||||
|
### Deprecated
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
- Corrected an issue in the `install` command where dry-run summaries were missing newlines.
|
||||||
|
|
||||||
## [0.72.1] - 2025-12-09
|
## [0.72.1] - 2025-12-09
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|||||||
10
Makefile
10
Makefile
@ -19,7 +19,7 @@ test-e2e:
|
|||||||
|
|
||||||
.PHONY: build clean test run-node run-node2 run-node3 run-example deps tidy fmt vet lint clear-ports install-hooks kill
|
.PHONY: build clean test run-node run-node2 run-node3 run-example deps tidy fmt vet lint clear-ports install-hooks kill
|
||||||
|
|
||||||
VERSION := 0.72.1
|
VERSION := 0.82.0
|
||||||
COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo unknown)
|
COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo unknown)
|
||||||
DATE ?= $(shell date -u +%Y-%m-%dT%H:%M:%SZ)
|
DATE ?= $(shell date -u +%Y-%m-%dT%H:%M:%SZ)
|
||||||
LDFLAGS := -X 'main.version=$(VERSION)' -X 'main.commit=$(COMMIT)' -X 'main.date=$(DATE)'
|
LDFLAGS := -X 'main.version=$(VERSION)' -X 'main.commit=$(COMMIT)' -X 'main.date=$(DATE)'
|
||||||
@ -31,6 +31,7 @@ build: deps
|
|||||||
go build -ldflags "$(LDFLAGS)" -o bin/identity ./cmd/identity
|
go build -ldflags "$(LDFLAGS)" -o bin/identity ./cmd/identity
|
||||||
go build -ldflags "$(LDFLAGS)" -o bin/orama-node ./cmd/node
|
go build -ldflags "$(LDFLAGS)" -o bin/orama-node ./cmd/node
|
||||||
go build -ldflags "$(LDFLAGS)" -o bin/orama cmd/cli/main.go
|
go build -ldflags "$(LDFLAGS)" -o bin/orama cmd/cli/main.go
|
||||||
|
go build -ldflags "$(LDFLAGS)" -o bin/rqlite-mcp ./cmd/rqlite-mcp
|
||||||
# Inject gateway build metadata via pkg path variables
|
# Inject gateway build metadata via pkg path variables
|
||||||
go build -ldflags "$(LDFLAGS) -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildVersion=$(VERSION)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildCommit=$(COMMIT)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildTime=$(DATE)'" -o bin/gateway ./cmd/gateway
|
go build -ldflags "$(LDFLAGS) -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildVersion=$(VERSION)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildCommit=$(COMMIT)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildTime=$(DATE)'" -o bin/gateway ./cmd/gateway
|
||||||
@echo "Build complete! Run ./bin/orama version"
|
@echo "Build complete! Run ./bin/orama version"
|
||||||
@ -71,14 +72,9 @@ run-gateway:
|
|||||||
@echo "Note: Config must be in ~/.orama/data/gateway.yaml"
|
@echo "Note: Config must be in ~/.orama/data/gateway.yaml"
|
||||||
go run ./cmd/orama-gateway
|
go run ./cmd/orama-gateway
|
||||||
|
|
||||||
# Setup local domain names for development
|
|
||||||
setup-domains:
|
|
||||||
@echo "Setting up local domains..."
|
|
||||||
@sudo bash scripts/setup-local-domains.sh
|
|
||||||
|
|
||||||
# Development environment target
|
# Development environment target
|
||||||
# Uses orama dev up to start full stack with dependency and port checking
|
# Uses orama dev up to start full stack with dependency and port checking
|
||||||
dev: build setup-domains
|
dev: build
|
||||||
@./bin/orama dev up
|
@./bin/orama dev up
|
||||||
|
|
||||||
# Graceful shutdown of all dev services
|
# Graceful shutdown of all dev services
|
||||||
|
|||||||
65
README.md
65
README.md
@ -26,27 +26,25 @@ make stop
|
|||||||
|
|
||||||
After running `make dev`, test service health using these curl requests:
|
After running `make dev`, test service health using these curl requests:
|
||||||
|
|
||||||
> **Note:** Local domains (node-1.local, etc.) require running `sudo make setup-domains` first. Alternatively, use `localhost` with port numbers.
|
|
||||||
|
|
||||||
### Node Unified Gateways
|
### Node Unified Gateways
|
||||||
|
|
||||||
Each node is accessible via a single unified gateway port:
|
Each node is accessible via a single unified gateway port:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Node-1 (port 6001)
|
# Node-1 (port 6001)
|
||||||
curl http://node-1.local:6001/health
|
curl http://localhost:6001/health
|
||||||
|
|
||||||
# Node-2 (port 6002)
|
# Node-2 (port 6002)
|
||||||
curl http://node-2.local:6002/health
|
curl http://localhost:6002/health
|
||||||
|
|
||||||
# Node-3 (port 6003)
|
# Node-3 (port 6003)
|
||||||
curl http://node-3.local:6003/health
|
curl http://localhost:6003/health
|
||||||
|
|
||||||
# Node-4 (port 6004)
|
# Node-4 (port 6004)
|
||||||
curl http://node-4.local:6004/health
|
curl http://localhost:6004/health
|
||||||
|
|
||||||
# Node-5 (port 6005)
|
# Node-5 (port 6005)
|
||||||
curl http://node-5.local:6005/health
|
curl http://localhost:6005/health
|
||||||
```
|
```
|
||||||
|
|
||||||
## Network Architecture
|
## Network Architecture
|
||||||
@ -129,6 +127,54 @@ make build
|
|||||||
./bin/orama auth logout
|
./bin/orama auth logout
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Serverless Functions (WASM)
|
||||||
|
|
||||||
|
Orama supports high-performance serverless function execution using WebAssembly (WASM). Functions are isolated, secure, and can interact with network services like the distributed cache.
|
||||||
|
|
||||||
|
### 1. Build Functions
|
||||||
|
|
||||||
|
Functions must be compiled to WASM. We recommend using [TinyGo](https://tinygo.org/).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build example functions to examples/functions/bin/
|
||||||
|
./examples/functions/build.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Deployment
|
||||||
|
|
||||||
|
Deploy your compiled `.wasm` file to the network via the Gateway.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Deploy a function
|
||||||
|
curl -X POST http://localhost:6001/v1/functions \
|
||||||
|
-H "Authorization: Bearer <your_api_key>" \
|
||||||
|
-F "name=hello-world" \
|
||||||
|
-F "namespace=default" \
|
||||||
|
-F "wasm=@./examples/functions/bin/hello.wasm"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Invocation
|
||||||
|
|
||||||
|
Trigger your function with a JSON payload. The function receives the payload via `stdin` and returns its response via `stdout`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Invoke via HTTP
|
||||||
|
curl -X POST http://localhost:6001/v1/functions/hello-world/invoke \
|
||||||
|
-H "Authorization: Bearer <your_api_key>" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"name": "Developer"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Management
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# List all functions in a namespace
|
||||||
|
curl http://localhost:6001/v1/functions?namespace=default
|
||||||
|
|
||||||
|
# Delete a function
|
||||||
|
curl -X DELETE http://localhost:6001/v1/functions/hello-world?namespace=default
|
||||||
|
```
|
||||||
|
|
||||||
## Production Deployment
|
## Production Deployment
|
||||||
|
|
||||||
### Prerequisites
|
### Prerequisites
|
||||||
@ -262,6 +308,11 @@ sudo orama install
|
|||||||
- `POST /v1/pubsub/publish` - Publish message
|
- `POST /v1/pubsub/publish` - Publish message
|
||||||
- `GET /v1/pubsub/topics` - List topics
|
- `GET /v1/pubsub/topics` - List topics
|
||||||
- `GET /v1/pubsub/ws?topic=<name>` - WebSocket subscribe
|
- `GET /v1/pubsub/ws?topic=<name>` - WebSocket subscribe
|
||||||
|
- `POST /v1/functions` - Deploy function (multipart/form-data)
|
||||||
|
- `POST /v1/functions/{name}/invoke` - Invoke function
|
||||||
|
- `GET /v1/functions` - List functions
|
||||||
|
- `DELETE /v1/functions/{name}` - Delete function
|
||||||
|
- `GET /v1/functions/{name}/logs` - Get function logs
|
||||||
|
|
||||||
See `openapi/gateway.yaml` for complete API specification.
|
See `openapi/gateway.yaml` for complete API specification.
|
||||||
|
|
||||||
|
|||||||
320
cmd/rqlite-mcp/main.go
Normal file
320
cmd/rqlite-mcp/main.go
Normal file
@ -0,0 +1,320 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rqlite/gorqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MCP JSON-RPC types
|
||||||
|
type JSONRPCRequest struct {
|
||||||
|
JSONRPC string `json:"jsonrpc"`
|
||||||
|
ID any `json:"id,omitempty"`
|
||||||
|
Method string `json:"method"`
|
||||||
|
Params json.RawMessage `json:"params,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type JSONRPCResponse struct {
|
||||||
|
JSONRPC string `json:"jsonrpc"`
|
||||||
|
ID any `json:"id"`
|
||||||
|
Result any `json:"result,omitempty"`
|
||||||
|
Error *ResponseError `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResponseError struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool definition
|
||||||
|
type Tool struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
InputSchema any `json:"inputSchema"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool call types
|
||||||
|
type CallToolRequest struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments json.RawMessage `json:"arguments"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TextContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CallToolResult struct {
|
||||||
|
Content []TextContent `json:"content"`
|
||||||
|
IsError bool `json:"isError,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MCPServer struct {
|
||||||
|
conn *gorqlite.Connection
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMCPServer(rqliteURL string) (*MCPServer, error) {
|
||||||
|
conn, err := gorqlite.Open(rqliteURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &MCPServer{
|
||||||
|
conn: conn,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MCPServer) handleRequest(req JSONRPCRequest) JSONRPCResponse {
|
||||||
|
var resp JSONRPCResponse
|
||||||
|
resp.JSONRPC = "2.0"
|
||||||
|
resp.ID = req.ID
|
||||||
|
|
||||||
|
// Debug logging disabled to prevent excessive disk writes
|
||||||
|
// log.Printf("Received method: %s", req.Method)
|
||||||
|
|
||||||
|
switch req.Method {
|
||||||
|
case "initialize":
|
||||||
|
resp.Result = map[string]any{
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"capabilities": map[string]any{
|
||||||
|
"tools": map[string]any{},
|
||||||
|
},
|
||||||
|
"serverInfo": map[string]any{
|
||||||
|
"name": "rqlite-mcp",
|
||||||
|
"version": "0.1.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
case "notifications/initialized":
|
||||||
|
// This is a notification, no response needed
|
||||||
|
return JSONRPCResponse{}
|
||||||
|
|
||||||
|
case "tools/list":
|
||||||
|
// Debug logging disabled to prevent excessive disk writes
|
||||||
|
tools := []Tool{
|
||||||
|
{
|
||||||
|
Name: "list_tables",
|
||||||
|
Description: "List all tables in the Rqlite database",
|
||||||
|
InputSchema: map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "query",
|
||||||
|
Description: "Run a SELECT query on the Rqlite database",
|
||||||
|
InputSchema: map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"sql": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The SQL SELECT query to run",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"sql"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "execute",
|
||||||
|
Description: "Run an INSERT, UPDATE, or DELETE statement on the Rqlite database",
|
||||||
|
InputSchema: map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"sql": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The SQL statement (INSERT, UPDATE, DELETE) to run",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"sql"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp.Result = map[string]any{"tools": tools}
|
||||||
|
|
||||||
|
case "tools/call":
|
||||||
|
var callReq CallToolRequest
|
||||||
|
if err := json.Unmarshal(req.Params, &callReq); err != nil {
|
||||||
|
resp.Error = &ResponseError{Code: -32700, Message: "Parse error"}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
resp.Result = s.handleToolCall(callReq)
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Debug logging disabled to prevent excessive disk writes
|
||||||
|
resp.Error = &ResponseError{Code: -32601, Message: "Method not found"}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MCPServer) handleToolCall(req CallToolRequest) CallToolResult {
|
||||||
|
// Debug logging disabled to prevent excessive disk writes
|
||||||
|
// log.Printf("Tool call: %s", req.Name)
|
||||||
|
|
||||||
|
switch req.Name {
|
||||||
|
case "list_tables":
|
||||||
|
rows, err := s.conn.QueryOne("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
|
||||||
|
if err != nil {
|
||||||
|
return errorResult(fmt.Sprintf("Error listing tables: %v", err))
|
||||||
|
}
|
||||||
|
var tables []string
|
||||||
|
for rows.Next() {
|
||||||
|
slice, err := rows.Slice()
|
||||||
|
if err == nil && len(slice) > 0 {
|
||||||
|
tables = append(tables, fmt.Sprint(slice[0]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(tables) == 0 {
|
||||||
|
return textResult("No tables found")
|
||||||
|
}
|
||||||
|
return textResult(strings.Join(tables, "\n"))
|
||||||
|
|
||||||
|
case "query":
|
||||||
|
var args struct {
|
||||||
|
SQL string `json:"sql"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(req.Arguments, &args); err != nil {
|
||||||
|
return errorResult(fmt.Sprintf("Invalid arguments: %v", err))
|
||||||
|
}
|
||||||
|
// Debug logging disabled to prevent excessive disk writes
|
||||||
|
rows, err := s.conn.QueryOne(args.SQL)
|
||||||
|
if err != nil {
|
||||||
|
return errorResult(fmt.Sprintf("Query error: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result strings.Builder
|
||||||
|
cols := rows.Columns()
|
||||||
|
result.WriteString(strings.Join(cols, " | ") + "\n")
|
||||||
|
result.WriteString(strings.Repeat("-", len(cols)*10) + "\n")
|
||||||
|
|
||||||
|
rowCount := 0
|
||||||
|
for rows.Next() {
|
||||||
|
vals, err := rows.Slice()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rowCount++
|
||||||
|
for i, v := range vals {
|
||||||
|
if i > 0 {
|
||||||
|
result.WriteString(" | ")
|
||||||
|
}
|
||||||
|
result.WriteString(fmt.Sprint(v))
|
||||||
|
}
|
||||||
|
result.WriteString("\n")
|
||||||
|
}
|
||||||
|
result.WriteString(fmt.Sprintf("\n(%d rows)", rowCount))
|
||||||
|
return textResult(result.String())
|
||||||
|
|
||||||
|
case "execute":
|
||||||
|
var args struct {
|
||||||
|
SQL string `json:"sql"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(req.Arguments, &args); err != nil {
|
||||||
|
return errorResult(fmt.Sprintf("Invalid arguments: %v", err))
|
||||||
|
}
|
||||||
|
// Debug logging disabled to prevent excessive disk writes
|
||||||
|
res, err := s.conn.WriteOne(args.SQL)
|
||||||
|
if err != nil {
|
||||||
|
return errorResult(fmt.Sprintf("Execution error: %v", err))
|
||||||
|
}
|
||||||
|
return textResult(fmt.Sprintf("Rows affected: %d", res.RowsAffected))
|
||||||
|
|
||||||
|
default:
|
||||||
|
return errorResult(fmt.Sprintf("Unknown tool: %s", req.Name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func textResult(text string) CallToolResult {
|
||||||
|
return CallToolResult{
|
||||||
|
Content: []TextContent{
|
||||||
|
{
|
||||||
|
Type: "text",
|
||||||
|
Text: text,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func errorResult(text string) CallToolResult {
|
||||||
|
return CallToolResult{
|
||||||
|
Content: []TextContent{
|
||||||
|
{
|
||||||
|
Type: "text",
|
||||||
|
Text: text,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
IsError: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Log to stderr so stdout is clean for JSON-RPC
|
||||||
|
log.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
rqliteURL := "http://localhost:5001"
|
||||||
|
if u := os.Getenv("RQLITE_URL"); u != "" {
|
||||||
|
rqliteURL = u
|
||||||
|
}
|
||||||
|
|
||||||
|
var server *MCPServer
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Retry connecting to rqlite
|
||||||
|
maxRetries := 30
|
||||||
|
for i := 0; i < maxRetries; i++ {
|
||||||
|
server, err = NewMCPServer(rqliteURL)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if i%5 == 0 {
|
||||||
|
log.Printf("Waiting for Rqlite at %s... (%d/%d)", rqliteURL, i+1, maxRetries)
|
||||||
|
}
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to connect to Rqlite after %d retries: %v", maxRetries, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("MCP Rqlite server started (stdio transport)")
|
||||||
|
log.Printf("Connected to Rqlite at %s", rqliteURL)
|
||||||
|
|
||||||
|
// Read JSON-RPC requests from stdin, write responses to stdout
|
||||||
|
scanner := bufio.NewScanner(os.Stdin)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var req JSONRPCRequest
|
||||||
|
if err := json.Unmarshal([]byte(line), &req); err != nil {
|
||||||
|
// Debug logging disabled to prevent excessive disk writes
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := server.handleRequest(req)
|
||||||
|
|
||||||
|
// Don't send response for notifications (no ID)
|
||||||
|
if req.ID == nil && strings.HasPrefix(req.Method, "notifications/") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
respData, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
// Debug logging disabled to prevent excessive disk writes
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(string(respData))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
// Debug logging disabled to prevent excessive disk writes
|
||||||
|
}
|
||||||
|
}
|
||||||
313
e2e/env.go
313
e2e/env.go
@ -6,13 +6,16 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -20,6 +23,7 @@ import (
|
|||||||
"github.com/DeBrosOfficial/network/pkg/client"
|
"github.com/DeBrosOfficial/network/pkg/client"
|
||||||
"github.com/DeBrosOfficial/network/pkg/config"
|
"github.com/DeBrosOfficial/network/pkg/config"
|
||||||
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
@ -135,14 +139,26 @@ func GetRQLiteNodes() []string {
|
|||||||
|
|
||||||
// queryAPIKeyFromRQLite queries the SQLite database directly for an API key
|
// queryAPIKeyFromRQLite queries the SQLite database directly for an API key
|
||||||
func queryAPIKeyFromRQLite() (string, error) {
|
func queryAPIKeyFromRQLite() (string, error) {
|
||||||
// Build database path from bootstrap/node config
|
// 1. Check environment variable first
|
||||||
|
if envKey := os.Getenv("DEBROS_API_KEY"); envKey != "" {
|
||||||
|
return envKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Build database path from bootstrap/node config
|
||||||
homeDir, err := os.UserHomeDir()
|
homeDir, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to get home directory: %w", err)
|
return "", fmt.Errorf("failed to get home directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try all node data directories
|
// Try all node data directories (both production and development paths)
|
||||||
dbPaths := []string{
|
dbPaths := []string{
|
||||||
|
// Development paths (~/.orama/node-x/...)
|
||||||
|
filepath.Join(homeDir, ".orama", "node-1", "rqlite", "db.sqlite"),
|
||||||
|
filepath.Join(homeDir, ".orama", "node-2", "rqlite", "db.sqlite"),
|
||||||
|
filepath.Join(homeDir, ".orama", "node-3", "rqlite", "db.sqlite"),
|
||||||
|
filepath.Join(homeDir, ".orama", "node-4", "rqlite", "db.sqlite"),
|
||||||
|
filepath.Join(homeDir, ".orama", "node-5", "rqlite", "db.sqlite"),
|
||||||
|
// Production paths (~/.orama/data/node-x/...)
|
||||||
filepath.Join(homeDir, ".orama", "data", "node-1", "rqlite", "db.sqlite"),
|
filepath.Join(homeDir, ".orama", "data", "node-1", "rqlite", "db.sqlite"),
|
||||||
filepath.Join(homeDir, ".orama", "data", "node-2", "rqlite", "db.sqlite"),
|
filepath.Join(homeDir, ".orama", "data", "node-2", "rqlite", "db.sqlite"),
|
||||||
filepath.Join(homeDir, ".orama", "data", "node-3", "rqlite", "db.sqlite"),
|
filepath.Join(homeDir, ".orama", "data", "node-3", "rqlite", "db.sqlite"),
|
||||||
@ -644,3 +660,296 @@ func CleanupCacheEntry(t *testing.T, dmapName, key string) {
|
|||||||
t.Logf("warning: delete cache entry returned status %d", status)
|
t.Logf("warning: delete cache entry returned status %d", status)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// WebSocket PubSub Client for E2E Tests
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
// WSPubSubClient is a WebSocket-based PubSub client that connects to the gateway
|
||||||
|
type WSPubSubClient struct {
|
||||||
|
t *testing.T
|
||||||
|
conn *websocket.Conn
|
||||||
|
topic string
|
||||||
|
handlers []func(topic string, data []byte) error
|
||||||
|
msgChan chan []byte
|
||||||
|
doneChan chan struct{}
|
||||||
|
mu sync.RWMutex
|
||||||
|
writeMu sync.Mutex // Protects concurrent writes to WebSocket
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// WSPubSubMessage represents a message received from the gateway
|
||||||
|
type WSPubSubMessage struct {
|
||||||
|
Data string `json:"data"` // base64 encoded
|
||||||
|
Timestamp int64 `json:"timestamp"` // unix milliseconds
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWSPubSubClient creates a new WebSocket PubSub client connected to a topic
|
||||||
|
func NewWSPubSubClient(t *testing.T, topic string) (*WSPubSubClient, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Build WebSocket URL
|
||||||
|
gatewayURL := GetGatewayURL()
|
||||||
|
wsURL := strings.Replace(gatewayURL, "http://", "ws://", 1)
|
||||||
|
wsURL = strings.Replace(wsURL, "https://", "wss://", 1)
|
||||||
|
|
||||||
|
u, err := url.Parse(wsURL + "/v1/pubsub/ws")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse WebSocket URL: %w", err)
|
||||||
|
}
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("topic", topic)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
// Set up headers with authentication
|
||||||
|
headers := http.Header{}
|
||||||
|
if apiKey := GetAPIKey(); apiKey != "" {
|
||||||
|
headers.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to WebSocket
|
||||||
|
dialer := websocket.Dialer{
|
||||||
|
HandshakeTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, resp, err := dialer.Dial(u.String(), headers)
|
||||||
|
if err != nil {
|
||||||
|
if resp != nil {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
return nil, fmt.Errorf("websocket dial failed (status %d): %w - body: %s", resp.StatusCode, err, string(body))
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("websocket dial failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &WSPubSubClient{
|
||||||
|
t: t,
|
||||||
|
conn: conn,
|
||||||
|
topic: topic,
|
||||||
|
handlers: make([]func(topic string, data []byte) error, 0),
|
||||||
|
msgChan: make(chan []byte, 128),
|
||||||
|
doneChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start reader goroutine
|
||||||
|
go client.readLoop()
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWSPubSubPresenceClient creates a new WebSocket PubSub client with presence parameters
|
||||||
|
func NewWSPubSubPresenceClient(t *testing.T, topic, memberID string, meta map[string]interface{}) (*WSPubSubClient, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Build WebSocket URL
|
||||||
|
gatewayURL := GetGatewayURL()
|
||||||
|
wsURL := strings.Replace(gatewayURL, "http://", "ws://", 1)
|
||||||
|
wsURL = strings.Replace(wsURL, "https://", "wss://", 1)
|
||||||
|
|
||||||
|
u, err := url.Parse(wsURL + "/v1/pubsub/ws")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse WebSocket URL: %w", err)
|
||||||
|
}
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("topic", topic)
|
||||||
|
q.Set("presence", "true")
|
||||||
|
q.Set("member_id", memberID)
|
||||||
|
if meta != nil {
|
||||||
|
metaJSON, _ := json.Marshal(meta)
|
||||||
|
q.Set("member_meta", string(metaJSON))
|
||||||
|
}
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
// Set up headers with authentication
|
||||||
|
headers := http.Header{}
|
||||||
|
if apiKey := GetAPIKey(); apiKey != "" {
|
||||||
|
headers.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to WebSocket
|
||||||
|
dialer := websocket.Dialer{
|
||||||
|
HandshakeTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, resp, err := dialer.Dial(u.String(), headers)
|
||||||
|
if err != nil {
|
||||||
|
if resp != nil {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
return nil, fmt.Errorf("websocket dial failed (status %d): %w - body: %s", resp.StatusCode, err, string(body))
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("websocket dial failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &WSPubSubClient{
|
||||||
|
t: t,
|
||||||
|
conn: conn,
|
||||||
|
topic: topic,
|
||||||
|
handlers: make([]func(topic string, data []byte) error, 0),
|
||||||
|
msgChan: make(chan []byte, 128),
|
||||||
|
doneChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start reader goroutine
|
||||||
|
go client.readLoop()
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readLoop reads messages from the WebSocket and dispatches to handlers
|
||||||
|
func (c *WSPubSubClient) readLoop() {
|
||||||
|
defer close(c.doneChan)
|
||||||
|
|
||||||
|
for {
|
||||||
|
_, message, err := c.conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
c.mu.RLock()
|
||||||
|
closed := c.closed
|
||||||
|
c.mu.RUnlock()
|
||||||
|
if !closed {
|
||||||
|
// Only log if not intentionally closed
|
||||||
|
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||||
|
c.t.Logf("websocket read error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the message envelope
|
||||||
|
var msg WSPubSubMessage
|
||||||
|
if err := json.Unmarshal(message, &msg); err != nil {
|
||||||
|
c.t.Logf("failed to unmarshal message: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode base64 data
|
||||||
|
data, err := base64.StdEncoding.DecodeString(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
c.t.Logf("failed to decode base64 data: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send to message channel
|
||||||
|
select {
|
||||||
|
case c.msgChan <- data:
|
||||||
|
default:
|
||||||
|
c.t.Logf("message channel full, dropping message")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dispatch to handlers
|
||||||
|
c.mu.RLock()
|
||||||
|
handlers := make([]func(topic string, data []byte) error, len(c.handlers))
|
||||||
|
copy(handlers, c.handlers)
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, handler := range handlers {
|
||||||
|
if err := handler(msg.Topic, data); err != nil {
|
||||||
|
c.t.Logf("handler error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe adds a message handler
|
||||||
|
func (c *WSPubSubClient) Subscribe(handler func(topic string, data []byte) error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.handlers = append(c.handlers, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish sends a message to the topic
|
||||||
|
func (c *WSPubSubClient) Publish(data []byte) error {
|
||||||
|
c.mu.RLock()
|
||||||
|
closed := c.closed
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
return fmt.Errorf("client is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Protect concurrent writes to WebSocket
|
||||||
|
c.writeMu.Lock()
|
||||||
|
defer c.writeMu.Unlock()
|
||||||
|
|
||||||
|
return c.conn.WriteMessage(websocket.TextMessage, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReceiveWithTimeout waits for a message with timeout
|
||||||
|
func (c *WSPubSubClient) ReceiveWithTimeout(timeout time.Duration) ([]byte, error) {
|
||||||
|
select {
|
||||||
|
case msg := <-c.msgChan:
|
||||||
|
return msg, nil
|
||||||
|
case <-time.After(timeout):
|
||||||
|
return nil, fmt.Errorf("timeout waiting for message")
|
||||||
|
case <-c.doneChan:
|
||||||
|
return nil, fmt.Errorf("connection closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the WebSocket connection
|
||||||
|
func (c *WSPubSubClient) Close() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.closed = true
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
// Send close message
|
||||||
|
_ = c.conn.WriteMessage(websocket.CloseMessage,
|
||||||
|
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||||
|
|
||||||
|
// Close connection
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Topic returns the topic this client is subscribed to
|
||||||
|
func (c *WSPubSubClient) Topic() string {
|
||||||
|
return c.topic
|
||||||
|
}
|
||||||
|
|
||||||
|
// WSPubSubClientPair represents a publisher and subscriber pair for testing
|
||||||
|
type WSPubSubClientPair struct {
|
||||||
|
Publisher *WSPubSubClient
|
||||||
|
Subscriber *WSPubSubClient
|
||||||
|
Topic string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWSPubSubClientPair creates a publisher and subscriber pair for a topic
|
||||||
|
func NewWSPubSubClientPair(t *testing.T, topic string) (*WSPubSubClientPair, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Create subscriber first
|
||||||
|
sub, err := NewWSPubSubClient(t, topic)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create subscriber: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Small delay to ensure subscriber is registered
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Create publisher
|
||||||
|
pub, err := NewWSPubSubClient(t, topic)
|
||||||
|
if err != nil {
|
||||||
|
sub.Close()
|
||||||
|
return nil, fmt.Errorf("failed to create publisher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &WSPubSubClientPair{
|
||||||
|
Publisher: pub,
|
||||||
|
Subscriber: sub,
|
||||||
|
Topic: topic,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes both publisher and subscriber
|
||||||
|
func (p *WSPubSubClientPair) Close() {
|
||||||
|
if p.Publisher != nil {
|
||||||
|
p.Publisher.Close()
|
||||||
|
}
|
||||||
|
if p.Subscriber != nil {
|
||||||
|
p.Subscriber.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -3,82 +3,46 @@
|
|||||||
package e2e
|
package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newMessageCollector(ctx context.Context, buffer int) (chan []byte, func(string, []byte) error) {
|
// TestPubSub_SubscribePublish tests basic pub/sub functionality via WebSocket
|
||||||
if buffer <= 0 {
|
|
||||||
buffer = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
ch := make(chan []byte, buffer)
|
|
||||||
handler := func(_ string, data []byte) error {
|
|
||||||
copied := append([]byte(nil), data...)
|
|
||||||
select {
|
|
||||||
case ch <- copied:
|
|
||||||
case <-ctx.Done():
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return ch, handler
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitForMessage(ctx context.Context, ch <-chan []byte) ([]byte, error) {
|
|
||||||
select {
|
|
||||||
case msg := <-ch:
|
|
||||||
return msg, nil
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, fmt.Errorf("context finished while waiting for pubsub message: %w", ctx.Err())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPubSub_SubscribePublish(t *testing.T) {
|
func TestPubSub_SubscribePublish(t *testing.T) {
|
||||||
SkipIfMissingGateway(t)
|
SkipIfMissingGateway(t)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Create two clients
|
|
||||||
client1 := NewNetworkClient(t)
|
|
||||||
client2 := NewNetworkClient(t)
|
|
||||||
|
|
||||||
if err := client1.Connect(); err != nil {
|
|
||||||
t.Fatalf("client1 connect failed: %v", err)
|
|
||||||
}
|
|
||||||
defer client1.Disconnect()
|
|
||||||
|
|
||||||
if err := client2.Connect(); err != nil {
|
|
||||||
t.Fatalf("client2 connect failed: %v", err)
|
|
||||||
}
|
|
||||||
defer client2.Disconnect()
|
|
||||||
|
|
||||||
topic := GenerateTopic()
|
topic := GenerateTopic()
|
||||||
message := "test-message-from-client1"
|
message := "test-message-from-publisher"
|
||||||
|
|
||||||
// Subscribe on client2
|
// Create subscriber first
|
||||||
messageCh, handler := newMessageCollector(ctx, 1)
|
subscriber, err := NewWSPubSubClient(t, topic)
|
||||||
if err := client2.PubSub().Subscribe(ctx, topic, handler); err != nil {
|
if err != nil {
|
||||||
t.Fatalf("subscribe failed: %v", err)
|
t.Fatalf("failed to create subscriber: %v", err)
|
||||||
}
|
}
|
||||||
defer client2.PubSub().Unsubscribe(ctx, topic)
|
defer subscriber.Close()
|
||||||
|
|
||||||
// Give subscription time to propagate and mesh to form
|
// Give subscriber time to register
|
||||||
Delay(2000)
|
Delay(200)
|
||||||
|
|
||||||
// Publish from client1
|
// Create publisher
|
||||||
if err := client1.PubSub().Publish(ctx, topic, []byte(message)); err != nil {
|
publisher, err := NewWSPubSubClient(t, topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create publisher: %v", err)
|
||||||
|
}
|
||||||
|
defer publisher.Close()
|
||||||
|
|
||||||
|
// Give connections time to stabilize
|
||||||
|
Delay(200)
|
||||||
|
|
||||||
|
// Publish message
|
||||||
|
if err := publisher.Publish([]byte(message)); err != nil {
|
||||||
t.Fatalf("publish failed: %v", err)
|
t.Fatalf("publish failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Receive message on client2
|
// Receive message on subscriber
|
||||||
recvCtx, recvCancel := context.WithTimeout(ctx, 10*time.Second)
|
msg, err := subscriber.ReceiveWithTimeout(10 * time.Second)
|
||||||
defer recvCancel()
|
|
||||||
|
|
||||||
msg, err := waitForMessage(recvCtx, messageCh)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receive failed: %v", err)
|
t.Fatalf("receive failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -88,154 +52,126 @@ func TestPubSub_SubscribePublish(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestPubSub_MultipleSubscribers tests that multiple subscribers receive the same message
|
||||||
func TestPubSub_MultipleSubscribers(t *testing.T) {
|
func TestPubSub_MultipleSubscribers(t *testing.T) {
|
||||||
SkipIfMissingGateway(t)
|
SkipIfMissingGateway(t)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Create three clients
|
|
||||||
clientPub := NewNetworkClient(t)
|
|
||||||
clientSub1 := NewNetworkClient(t)
|
|
||||||
clientSub2 := NewNetworkClient(t)
|
|
||||||
|
|
||||||
if err := clientPub.Connect(); err != nil {
|
|
||||||
t.Fatalf("publisher connect failed: %v", err)
|
|
||||||
}
|
|
||||||
defer clientPub.Disconnect()
|
|
||||||
|
|
||||||
if err := clientSub1.Connect(); err != nil {
|
|
||||||
t.Fatalf("subscriber1 connect failed: %v", err)
|
|
||||||
}
|
|
||||||
defer clientSub1.Disconnect()
|
|
||||||
|
|
||||||
if err := clientSub2.Connect(); err != nil {
|
|
||||||
t.Fatalf("subscriber2 connect failed: %v", err)
|
|
||||||
}
|
|
||||||
defer clientSub2.Disconnect()
|
|
||||||
|
|
||||||
topic := GenerateTopic()
|
topic := GenerateTopic()
|
||||||
message1 := "message-for-sub1"
|
message1 := "message-1"
|
||||||
message2 := "message-for-sub2"
|
message2 := "message-2"
|
||||||
|
|
||||||
// Subscribe on both clients
|
// Create two subscribers
|
||||||
sub1Ch, sub1Handler := newMessageCollector(ctx, 4)
|
sub1, err := NewWSPubSubClient(t, topic)
|
||||||
if err := clientSub1.PubSub().Subscribe(ctx, topic, sub1Handler); err != nil {
|
if err != nil {
|
||||||
t.Fatalf("subscribe1 failed: %v", err)
|
t.Fatalf("failed to create subscriber1: %v", err)
|
||||||
}
|
}
|
||||||
defer clientSub1.PubSub().Unsubscribe(ctx, topic)
|
defer sub1.Close()
|
||||||
|
|
||||||
sub2Ch, sub2Handler := newMessageCollector(ctx, 4)
|
sub2, err := NewWSPubSubClient(t, topic)
|
||||||
if err := clientSub2.PubSub().Subscribe(ctx, topic, sub2Handler); err != nil {
|
if err != nil {
|
||||||
t.Fatalf("subscribe2 failed: %v", err)
|
t.Fatalf("failed to create subscriber2: %v", err)
|
||||||
}
|
}
|
||||||
defer clientSub2.PubSub().Unsubscribe(ctx, topic)
|
defer sub2.Close()
|
||||||
|
|
||||||
// Give subscriptions time to propagate
|
// Give subscribers time to register
|
||||||
Delay(500)
|
Delay(200)
|
||||||
|
|
||||||
|
// Create publisher
|
||||||
|
publisher, err := NewWSPubSubClient(t, topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create publisher: %v", err)
|
||||||
|
}
|
||||||
|
defer publisher.Close()
|
||||||
|
|
||||||
|
// Give connections time to stabilize
|
||||||
|
Delay(200)
|
||||||
|
|
||||||
// Publish first message
|
// Publish first message
|
||||||
if err := clientPub.PubSub().Publish(ctx, topic, []byte(message1)); err != nil {
|
if err := publisher.Publish([]byte(message1)); err != nil {
|
||||||
t.Fatalf("publish1 failed: %v", err)
|
t.Fatalf("publish1 failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Both subscribers should receive first message
|
// Both subscribers should receive first message
|
||||||
recvCtx, recvCancel := context.WithTimeout(ctx, 10*time.Second)
|
msg1a, err := sub1.ReceiveWithTimeout(10 * time.Second)
|
||||||
defer recvCancel()
|
|
||||||
|
|
||||||
msg1a, err := waitForMessage(recvCtx, sub1Ch)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("sub1 receive1 failed: %v", err)
|
t.Fatalf("sub1 receive1 failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(msg1a) != message1 {
|
if string(msg1a) != message1 {
|
||||||
t.Fatalf("sub1: expected %q, got %q", message1, string(msg1a))
|
t.Fatalf("sub1: expected %q, got %q", message1, string(msg1a))
|
||||||
}
|
}
|
||||||
|
|
||||||
msg1b, err := waitForMessage(recvCtx, sub2Ch)
|
msg1b, err := sub2.ReceiveWithTimeout(10 * time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("sub2 receive1 failed: %v", err)
|
t.Fatalf("sub2 receive1 failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(msg1b) != message1 {
|
if string(msg1b) != message1 {
|
||||||
t.Fatalf("sub2: expected %q, got %q", message1, string(msg1b))
|
t.Fatalf("sub2: expected %q, got %q", message1, string(msg1b))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish second message
|
// Publish second message
|
||||||
if err := clientPub.PubSub().Publish(ctx, topic, []byte(message2)); err != nil {
|
if err := publisher.Publish([]byte(message2)); err != nil {
|
||||||
t.Fatalf("publish2 failed: %v", err)
|
t.Fatalf("publish2 failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Both subscribers should receive second message
|
// Both subscribers should receive second message
|
||||||
recvCtx2, recvCancel2 := context.WithTimeout(ctx, 10*time.Second)
|
msg2a, err := sub1.ReceiveWithTimeout(10 * time.Second)
|
||||||
defer recvCancel2()
|
|
||||||
|
|
||||||
msg2a, err := waitForMessage(recvCtx2, sub1Ch)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("sub1 receive2 failed: %v", err)
|
t.Fatalf("sub1 receive2 failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(msg2a) != message2 {
|
if string(msg2a) != message2 {
|
||||||
t.Fatalf("sub1: expected %q, got %q", message2, string(msg2a))
|
t.Fatalf("sub1: expected %q, got %q", message2, string(msg2a))
|
||||||
}
|
}
|
||||||
|
|
||||||
msg2b, err := waitForMessage(recvCtx2, sub2Ch)
|
msg2b, err := sub2.ReceiveWithTimeout(10 * time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("sub2 receive2 failed: %v", err)
|
t.Fatalf("sub2 receive2 failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(msg2b) != message2 {
|
if string(msg2b) != message2 {
|
||||||
t.Fatalf("sub2: expected %q, got %q", message2, string(msg2b))
|
t.Fatalf("sub2: expected %q, got %q", message2, string(msg2b))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestPubSub_Deduplication tests that multiple identical messages are all received
|
||||||
func TestPubSub_Deduplication(t *testing.T) {
|
func TestPubSub_Deduplication(t *testing.T) {
|
||||||
SkipIfMissingGateway(t)
|
SkipIfMissingGateway(t)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Create two clients
|
|
||||||
clientPub := NewNetworkClient(t)
|
|
||||||
clientSub := NewNetworkClient(t)
|
|
||||||
|
|
||||||
if err := clientPub.Connect(); err != nil {
|
|
||||||
t.Fatalf("publisher connect failed: %v", err)
|
|
||||||
}
|
|
||||||
defer clientPub.Disconnect()
|
|
||||||
|
|
||||||
if err := clientSub.Connect(); err != nil {
|
|
||||||
t.Fatalf("subscriber connect failed: %v", err)
|
|
||||||
}
|
|
||||||
defer clientSub.Disconnect()
|
|
||||||
|
|
||||||
topic := GenerateTopic()
|
topic := GenerateTopic()
|
||||||
message := "duplicate-test-message"
|
message := "duplicate-test-message"
|
||||||
|
|
||||||
// Subscribe on client
|
// Create subscriber
|
||||||
messageCh, handler := newMessageCollector(ctx, 3)
|
subscriber, err := NewWSPubSubClient(t, topic)
|
||||||
if err := clientSub.PubSub().Subscribe(ctx, topic, handler); err != nil {
|
if err != nil {
|
||||||
t.Fatalf("subscribe failed: %v", err)
|
t.Fatalf("failed to create subscriber: %v", err)
|
||||||
}
|
}
|
||||||
defer clientSub.PubSub().Unsubscribe(ctx, topic)
|
defer subscriber.Close()
|
||||||
|
|
||||||
// Give subscription time to propagate and mesh to form
|
// Give subscriber time to register
|
||||||
Delay(2000)
|
Delay(200)
|
||||||
|
|
||||||
|
// Create publisher
|
||||||
|
publisher, err := NewWSPubSubClient(t, topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create publisher: %v", err)
|
||||||
|
}
|
||||||
|
defer publisher.Close()
|
||||||
|
|
||||||
|
// Give connections time to stabilize
|
||||||
|
Delay(200)
|
||||||
|
|
||||||
// Publish the same message multiple times
|
// Publish the same message multiple times
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
if err := clientPub.PubSub().Publish(ctx, topic, []byte(message)); err != nil {
|
if err := publisher.Publish([]byte(message)); err != nil {
|
||||||
t.Fatalf("publish %d failed: %v", i, err)
|
t.Fatalf("publish %d failed: %v", i, err)
|
||||||
}
|
}
|
||||||
|
// Small delay between publishes
|
||||||
|
Delay(50)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Receive messages - should get all (no dedup filter on subscribe)
|
// Receive messages - should get all (no dedup filter)
|
||||||
recvCtx, recvCancel := context.WithTimeout(ctx, 5*time.Second)
|
|
||||||
defer recvCancel()
|
|
||||||
|
|
||||||
receivedCount := 0
|
receivedCount := 0
|
||||||
for receivedCount < 3 {
|
for receivedCount < 3 {
|
||||||
if _, err := waitForMessage(recvCtx, messageCh); err != nil {
|
_, err := subscriber.ReceiveWithTimeout(5 * time.Second)
|
||||||
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
receivedCount++
|
receivedCount++
|
||||||
@ -244,40 +180,35 @@ func TestPubSub_Deduplication(t *testing.T) {
|
|||||||
if receivedCount < 1 {
|
if receivedCount < 1 {
|
||||||
t.Fatalf("expected to receive at least 1 message, got %d", receivedCount)
|
t.Fatalf("expected to receive at least 1 message, got %d", receivedCount)
|
||||||
}
|
}
|
||||||
|
t.Logf("received %d messages", receivedCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestPubSub_ConcurrentPublish tests concurrent message publishing
|
||||||
func TestPubSub_ConcurrentPublish(t *testing.T) {
|
func TestPubSub_ConcurrentPublish(t *testing.T) {
|
||||||
SkipIfMissingGateway(t)
|
SkipIfMissingGateway(t)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Create clients
|
|
||||||
clientPub := NewNetworkClient(t)
|
|
||||||
clientSub := NewNetworkClient(t)
|
|
||||||
|
|
||||||
if err := clientPub.Connect(); err != nil {
|
|
||||||
t.Fatalf("publisher connect failed: %v", err)
|
|
||||||
}
|
|
||||||
defer clientPub.Disconnect()
|
|
||||||
|
|
||||||
if err := clientSub.Connect(); err != nil {
|
|
||||||
t.Fatalf("subscriber connect failed: %v", err)
|
|
||||||
}
|
|
||||||
defer clientSub.Disconnect()
|
|
||||||
|
|
||||||
topic := GenerateTopic()
|
topic := GenerateTopic()
|
||||||
numMessages := 10
|
numMessages := 10
|
||||||
|
|
||||||
// Subscribe
|
// Create subscriber
|
||||||
messageCh, handler := newMessageCollector(ctx, numMessages)
|
subscriber, err := NewWSPubSubClient(t, topic)
|
||||||
if err := clientSub.PubSub().Subscribe(ctx, topic, handler); err != nil {
|
if err != nil {
|
||||||
t.Fatalf("subscribe failed: %v", err)
|
t.Fatalf("failed to create subscriber: %v", err)
|
||||||
}
|
}
|
||||||
defer clientSub.PubSub().Unsubscribe(ctx, topic)
|
defer subscriber.Close()
|
||||||
|
|
||||||
// Give subscription time to propagate and mesh to form
|
// Give subscriber time to register
|
||||||
Delay(2000)
|
Delay(200)
|
||||||
|
|
||||||
|
// Create publisher
|
||||||
|
publisher, err := NewWSPubSubClient(t, topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create publisher: %v", err)
|
||||||
|
}
|
||||||
|
defer publisher.Close()
|
||||||
|
|
||||||
|
// Give connections time to stabilize
|
||||||
|
Delay(200)
|
||||||
|
|
||||||
// Publish multiple messages concurrently
|
// Publish multiple messages concurrently
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
@ -286,7 +217,7 @@ func TestPubSub_ConcurrentPublish(t *testing.T) {
|
|||||||
go func(idx int) {
|
go func(idx int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
msg := fmt.Sprintf("concurrent-msg-%d", idx)
|
msg := fmt.Sprintf("concurrent-msg-%d", idx)
|
||||||
if err := clientPub.PubSub().Publish(ctx, topic, []byte(msg)); err != nil {
|
if err := publisher.Publish([]byte(msg)); err != nil {
|
||||||
t.Logf("publish %d failed: %v", idx, err)
|
t.Logf("publish %d failed: %v", idx, err)
|
||||||
}
|
}
|
||||||
}(i)
|
}(i)
|
||||||
@ -294,12 +225,10 @@ func TestPubSub_ConcurrentPublish(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
// Receive messages
|
// Receive messages
|
||||||
recvCtx, recvCancel := context.WithTimeout(ctx, 10*time.Second)
|
|
||||||
defer recvCancel()
|
|
||||||
|
|
||||||
receivedCount := 0
|
receivedCount := 0
|
||||||
for receivedCount < numMessages {
|
for receivedCount < numMessages {
|
||||||
if _, err := waitForMessage(recvCtx, messageCh); err != nil {
|
_, err := subscriber.ReceiveWithTimeout(10 * time.Second)
|
||||||
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
receivedCount++
|
receivedCount++
|
||||||
@ -310,107 +239,110 @@ func TestPubSub_ConcurrentPublish(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestPubSub_TopicIsolation tests that messages are isolated to their topics
|
||||||
func TestPubSub_TopicIsolation(t *testing.T) {
|
func TestPubSub_TopicIsolation(t *testing.T) {
|
||||||
SkipIfMissingGateway(t)
|
SkipIfMissingGateway(t)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Create clients
|
|
||||||
clientPub := NewNetworkClient(t)
|
|
||||||
clientSub := NewNetworkClient(t)
|
|
||||||
|
|
||||||
if err := clientPub.Connect(); err != nil {
|
|
||||||
t.Fatalf("publisher connect failed: %v", err)
|
|
||||||
}
|
|
||||||
defer clientPub.Disconnect()
|
|
||||||
|
|
||||||
if err := clientSub.Connect(); err != nil {
|
|
||||||
t.Fatalf("subscriber connect failed: %v", err)
|
|
||||||
}
|
|
||||||
defer clientSub.Disconnect()
|
|
||||||
|
|
||||||
topic1 := GenerateTopic()
|
topic1 := GenerateTopic()
|
||||||
topic2 := GenerateTopic()
|
topic2 := GenerateTopic()
|
||||||
|
msg1 := "message-on-topic1"
|
||||||
// Subscribe to topic1
|
|
||||||
messageCh, handler := newMessageCollector(ctx, 2)
|
|
||||||
if err := clientSub.PubSub().Subscribe(ctx, topic1, handler); err != nil {
|
|
||||||
t.Fatalf("subscribe1 failed: %v", err)
|
|
||||||
}
|
|
||||||
defer clientSub.PubSub().Unsubscribe(ctx, topic1)
|
|
||||||
|
|
||||||
// Give subscription time to propagate and mesh to form
|
|
||||||
Delay(2000)
|
|
||||||
|
|
||||||
// Publish to topic2
|
|
||||||
msg2 := "message-on-topic2"
|
msg2 := "message-on-topic2"
|
||||||
if err := clientPub.PubSub().Publish(ctx, topic2, []byte(msg2)); err != nil {
|
|
||||||
|
// Create subscriber for topic1
|
||||||
|
sub1, err := NewWSPubSubClient(t, topic1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create subscriber1: %v", err)
|
||||||
|
}
|
||||||
|
defer sub1.Close()
|
||||||
|
|
||||||
|
// Create subscriber for topic2
|
||||||
|
sub2, err := NewWSPubSubClient(t, topic2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create subscriber2: %v", err)
|
||||||
|
}
|
||||||
|
defer sub2.Close()
|
||||||
|
|
||||||
|
// Give subscribers time to register
|
||||||
|
Delay(200)
|
||||||
|
|
||||||
|
// Create publishers
|
||||||
|
pub1, err := NewWSPubSubClient(t, topic1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create publisher1: %v", err)
|
||||||
|
}
|
||||||
|
defer pub1.Close()
|
||||||
|
|
||||||
|
pub2, err := NewWSPubSubClient(t, topic2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create publisher2: %v", err)
|
||||||
|
}
|
||||||
|
defer pub2.Close()
|
||||||
|
|
||||||
|
// Give connections time to stabilize
|
||||||
|
Delay(200)
|
||||||
|
|
||||||
|
// Publish to topic2 first
|
||||||
|
if err := pub2.Publish([]byte(msg2)); err != nil {
|
||||||
t.Fatalf("publish2 failed: %v", err)
|
t.Fatalf("publish2 failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish to topic1
|
// Publish to topic1
|
||||||
msg1 := "message-on-topic1"
|
if err := pub1.Publish([]byte(msg1)); err != nil {
|
||||||
if err := clientPub.PubSub().Publish(ctx, topic1, []byte(msg1)); err != nil {
|
|
||||||
t.Fatalf("publish1 failed: %v", err)
|
t.Fatalf("publish1 failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Receive on sub1 - should get msg1 only
|
// Sub1 should receive msg1 only
|
||||||
recvCtx, recvCancel := context.WithTimeout(ctx, 10*time.Second)
|
received1, err := sub1.ReceiveWithTimeout(10 * time.Second)
|
||||||
defer recvCancel()
|
|
||||||
|
|
||||||
msg, err := waitForMessage(recvCtx, messageCh)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receive failed: %v", err)
|
t.Fatalf("sub1 receive failed: %v", err)
|
||||||
|
}
|
||||||
|
if string(received1) != msg1 {
|
||||||
|
t.Fatalf("sub1: expected %q, got %q", msg1, string(received1))
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(msg) != msg1 {
|
// Sub2 should receive msg2 only
|
||||||
t.Fatalf("expected %q, got %q", msg1, string(msg))
|
received2, err := sub2.ReceiveWithTimeout(10 * time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sub2 receive failed: %v", err)
|
||||||
|
}
|
||||||
|
if string(received2) != msg2 {
|
||||||
|
t.Fatalf("sub2: expected %q, got %q", msg2, string(received2))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestPubSub_EmptyMessage tests sending and receiving empty messages
|
||||||
func TestPubSub_EmptyMessage(t *testing.T) {
|
func TestPubSub_EmptyMessage(t *testing.T) {
|
||||||
SkipIfMissingGateway(t)
|
SkipIfMissingGateway(t)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Create clients
|
|
||||||
clientPub := NewNetworkClient(t)
|
|
||||||
clientSub := NewNetworkClient(t)
|
|
||||||
|
|
||||||
if err := clientPub.Connect(); err != nil {
|
|
||||||
t.Fatalf("publisher connect failed: %v", err)
|
|
||||||
}
|
|
||||||
defer clientPub.Disconnect()
|
|
||||||
|
|
||||||
if err := clientSub.Connect(); err != nil {
|
|
||||||
t.Fatalf("subscriber connect failed: %v", err)
|
|
||||||
}
|
|
||||||
defer clientSub.Disconnect()
|
|
||||||
|
|
||||||
topic := GenerateTopic()
|
topic := GenerateTopic()
|
||||||
|
|
||||||
// Subscribe
|
// Create subscriber
|
||||||
messageCh, handler := newMessageCollector(ctx, 1)
|
subscriber, err := NewWSPubSubClient(t, topic)
|
||||||
if err := clientSub.PubSub().Subscribe(ctx, topic, handler); err != nil {
|
if err != nil {
|
||||||
t.Fatalf("subscribe failed: %v", err)
|
t.Fatalf("failed to create subscriber: %v", err)
|
||||||
}
|
}
|
||||||
defer clientSub.PubSub().Unsubscribe(ctx, topic)
|
defer subscriber.Close()
|
||||||
|
|
||||||
// Give subscription time to propagate and mesh to form
|
// Give subscriber time to register
|
||||||
Delay(2000)
|
Delay(200)
|
||||||
|
|
||||||
|
// Create publisher
|
||||||
|
publisher, err := NewWSPubSubClient(t, topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create publisher: %v", err)
|
||||||
|
}
|
||||||
|
defer publisher.Close()
|
||||||
|
|
||||||
|
// Give connections time to stabilize
|
||||||
|
Delay(200)
|
||||||
|
|
||||||
// Publish empty message
|
// Publish empty message
|
||||||
if err := clientPub.PubSub().Publish(ctx, topic, []byte("")); err != nil {
|
if err := publisher.Publish([]byte("")); err != nil {
|
||||||
t.Fatalf("publish empty failed: %v", err)
|
t.Fatalf("publish empty failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Receive on sub - should get empty message
|
// Receive on subscriber - should get empty message
|
||||||
recvCtx, recvCancel := context.WithTimeout(ctx, 10*time.Second)
|
msg, err := subscriber.ReceiveWithTimeout(10 * time.Second)
|
||||||
defer recvCancel()
|
|
||||||
|
|
||||||
msg, err := waitForMessage(recvCtx, messageCh)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receive failed: %v", err)
|
t.Fatalf("receive failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -419,3 +351,111 @@ func TestPubSub_EmptyMessage(t *testing.T) {
|
|||||||
t.Fatalf("expected empty message, got %q", string(msg))
|
t.Fatalf("expected empty message, got %q", string(msg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestPubSub_LargeMessage tests sending and receiving large messages
|
||||||
|
func TestPubSub_LargeMessage(t *testing.T) {
|
||||||
|
SkipIfMissingGateway(t)
|
||||||
|
|
||||||
|
topic := GenerateTopic()
|
||||||
|
|
||||||
|
// Create a large message (100KB)
|
||||||
|
largeMessage := make([]byte, 100*1024)
|
||||||
|
for i := range largeMessage {
|
||||||
|
largeMessage[i] = byte(i % 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create subscriber
|
||||||
|
subscriber, err := NewWSPubSubClient(t, topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create subscriber: %v", err)
|
||||||
|
}
|
||||||
|
defer subscriber.Close()
|
||||||
|
|
||||||
|
// Give subscriber time to register
|
||||||
|
Delay(200)
|
||||||
|
|
||||||
|
// Create publisher
|
||||||
|
publisher, err := NewWSPubSubClient(t, topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create publisher: %v", err)
|
||||||
|
}
|
||||||
|
defer publisher.Close()
|
||||||
|
|
||||||
|
// Give connections time to stabilize
|
||||||
|
Delay(200)
|
||||||
|
|
||||||
|
// Publish large message
|
||||||
|
if err := publisher.Publish(largeMessage); err != nil {
|
||||||
|
t.Fatalf("publish large message failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive on subscriber
|
||||||
|
msg, err := subscriber.ReceiveWithTimeout(30 * time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("receive failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msg) != len(largeMessage) {
|
||||||
|
t.Fatalf("expected message of length %d, got %d", len(largeMessage), len(msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify content
|
||||||
|
for i := range msg {
|
||||||
|
if msg[i] != largeMessage[i] {
|
||||||
|
t.Fatalf("message content mismatch at byte %d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPubSub_RapidPublish tests rapid message publishing
|
||||||
|
func TestPubSub_RapidPublish(t *testing.T) {
|
||||||
|
SkipIfMissingGateway(t)
|
||||||
|
|
||||||
|
topic := GenerateTopic()
|
||||||
|
numMessages := 50
|
||||||
|
|
||||||
|
// Create subscriber
|
||||||
|
subscriber, err := NewWSPubSubClient(t, topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create subscriber: %v", err)
|
||||||
|
}
|
||||||
|
defer subscriber.Close()
|
||||||
|
|
||||||
|
// Give subscriber time to register
|
||||||
|
Delay(200)
|
||||||
|
|
||||||
|
// Create publisher
|
||||||
|
publisher, err := NewWSPubSubClient(t, topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create publisher: %v", err)
|
||||||
|
}
|
||||||
|
defer publisher.Close()
|
||||||
|
|
||||||
|
// Give connections time to stabilize
|
||||||
|
Delay(200)
|
||||||
|
|
||||||
|
// Publish messages rapidly
|
||||||
|
for i := 0; i < numMessages; i++ {
|
||||||
|
msg := fmt.Sprintf("rapid-msg-%d", i)
|
||||||
|
if err := publisher.Publish([]byte(msg)); err != nil {
|
||||||
|
t.Fatalf("publish %d failed: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive messages
|
||||||
|
receivedCount := 0
|
||||||
|
for receivedCount < numMessages {
|
||||||
|
_, err := subscriber.ReceiveWithTimeout(10 * time.Second)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
receivedCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow some message loss due to buffering
|
||||||
|
minExpected := numMessages * 80 / 100 // 80% minimum
|
||||||
|
if receivedCount < minExpected {
|
||||||
|
t.Fatalf("expected at least %d messages, got %d", minExpected, receivedCount)
|
||||||
|
}
|
||||||
|
t.Logf("received %d/%d messages (%.1f%%)", receivedCount, numMessages, float64(receivedCount)*100/float64(numMessages))
|
||||||
|
}
|
||||||
|
|||||||
122
e2e/pubsub_presence_test.go
Normal file
122
e2e/pubsub_presence_test.go
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
//go:build e2e
|
||||||
|
|
||||||
|
package e2e
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPubSub_Presence(t *testing.T) {
|
||||||
|
SkipIfMissingGateway(t)
|
||||||
|
|
||||||
|
topic := GenerateTopic()
|
||||||
|
memberID := "user123"
|
||||||
|
memberMeta := map[string]interface{}{"name": "Alice"}
|
||||||
|
|
||||||
|
// 1. Subscribe with presence
|
||||||
|
client1, err := NewWSPubSubPresenceClient(t, topic, memberID, memberMeta)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create presence client: %v", err)
|
||||||
|
}
|
||||||
|
defer client1.Close()
|
||||||
|
|
||||||
|
// Wait for join event
|
||||||
|
msg, err := client1.ReceiveWithTimeout(5 * time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("did not receive join event: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var event map[string]interface{}
|
||||||
|
if err := json.Unmarshal(msg, &event); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal event: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if event["type"] != "presence.join" {
|
||||||
|
t.Fatalf("expected presence.join event, got %v", event["type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if event["member_id"] != memberID {
|
||||||
|
t.Fatalf("expected member_id %s, got %v", memberID, event["member_id"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Query presence endpoint
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := &HTTPRequest{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
URL: fmt.Sprintf("%s/v1/pubsub/presence?topic=%s", GetGatewayURL(), topic),
|
||||||
|
}
|
||||||
|
|
||||||
|
body, status, err := req.Do(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("presence query failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
if err := DecodeJSON(body, &resp); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp["count"] != float64(1) {
|
||||||
|
t.Fatalf("expected count 1, got %v", resp["count"])
|
||||||
|
}
|
||||||
|
|
||||||
|
members := resp["members"].([]interface{})
|
||||||
|
if len(members) != 1 {
|
||||||
|
t.Fatalf("expected 1 member, got %d", len(members))
|
||||||
|
}
|
||||||
|
|
||||||
|
member := members[0].(map[string]interface{})
|
||||||
|
if member["member_id"] != memberID {
|
||||||
|
t.Fatalf("expected member_id %s, got %v", memberID, member["member_id"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Subscribe second member
|
||||||
|
memberID2 := "user456"
|
||||||
|
client2, err := NewWSPubSubPresenceClient(t, topic, memberID2, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create second presence client: %v", err)
|
||||||
|
}
|
||||||
|
// We'll close client2 later to test leave event
|
||||||
|
|
||||||
|
// Client1 should receive join event for Client2
|
||||||
|
msg2, err := client1.ReceiveWithTimeout(5 * time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client1 did not receive join event for client2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(msg2, &event); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal event: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if event["type"] != "presence.join" || event["member_id"] != memberID2 {
|
||||||
|
t.Fatalf("expected presence.join for %s, got %v for %v", memberID2, event["type"], event["member_id"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Disconnect client2 and verify leave event
|
||||||
|
client2.Close()
|
||||||
|
|
||||||
|
msg3, err := client1.ReceiveWithTimeout(5 * time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client1 did not receive leave event for client2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(msg3, &event); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal event: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if event["type"] != "presence.leave" || event["member_id"] != memberID2 {
|
||||||
|
t.Fatalf("expected presence.leave for %s, got %v for %v", memberID2, event["type"], event["member_id"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
123
e2e/serverless_test.go
Normal file
123
e2e/serverless_test.go
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
//go:build e2e
|
||||||
|
|
||||||
|
package e2e
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServerless_DeployAndInvoke(t *testing.T) {
|
||||||
|
SkipIfMissingGateway(t)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
wasmPath := "../examples/functions/bin/hello.wasm"
|
||||||
|
if _, err := os.Stat(wasmPath); os.IsNotExist(err) {
|
||||||
|
t.Skip("hello.wasm not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
wasmBytes, err := os.ReadFile(wasmPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read hello.wasm: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcName := "e2e-hello"
|
||||||
|
namespace := "default"
|
||||||
|
|
||||||
|
// 1. Deploy function
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&buf)
|
||||||
|
|
||||||
|
// Add metadata
|
||||||
|
_ = writer.WriteField("name", funcName)
|
||||||
|
_ = writer.WriteField("namespace", namespace)
|
||||||
|
|
||||||
|
// Add WASM file
|
||||||
|
part, err := writer.CreateFormFile("wasm", funcName+".wasm")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create form file: %v", err)
|
||||||
|
}
|
||||||
|
part.Write(wasmBytes)
|
||||||
|
writer.Close()
|
||||||
|
|
||||||
|
deployReq, _ := http.NewRequestWithContext(ctx, "POST", GetGatewayURL()+"/v1/functions", &buf)
|
||||||
|
deployReq.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
|
||||||
|
if apiKey := GetAPIKey(); apiKey != "" {
|
||||||
|
deployReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := NewHTTPClient(1 * time.Minute)
|
||||||
|
resp, err := client.Do(deployReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("deploy request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("deploy failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Invoke function
|
||||||
|
invokePayload := []byte(`{"name": "E2E Tester"}`)
|
||||||
|
invokeReq, _ := http.NewRequestWithContext(ctx, "POST", GetGatewayURL()+"/v1/functions/"+funcName+"/invoke", bytes.NewReader(invokePayload))
|
||||||
|
invokeReq.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
if apiKey := GetAPIKey(); apiKey != "" {
|
||||||
|
invokeReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = client.Do(invokeReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("invoke request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("invoke failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
output, _ := io.ReadAll(resp.Body)
|
||||||
|
expected := "Hello, E2E Tester!"
|
||||||
|
if !bytes.Contains(output, []byte(expected)) {
|
||||||
|
t.Errorf("output %q does not contain %q", string(output), expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. List functions
|
||||||
|
listReq, _ := http.NewRequestWithContext(ctx, "GET", GetGatewayURL()+"/v1/functions?namespace="+namespace, nil)
|
||||||
|
if apiKey := GetAPIKey(); apiKey != "" {
|
||||||
|
listReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
|
resp, err = client.Do(listReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("list failed with status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Delete function
|
||||||
|
deleteReq, _ := http.NewRequestWithContext(ctx, "DELETE", GetGatewayURL()+"/v1/functions/"+funcName+"?namespace="+namespace, nil)
|
||||||
|
if apiKey := GetAPIKey(); apiKey != "" {
|
||||||
|
deleteReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
|
resp, err = client.Do(deleteReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("delete request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("delete failed with status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
158
example.http
Normal file
158
example.http
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
### Orama Network Gateway API Examples
|
||||||
|
# This file is designed for the VS Code "REST Client" extension.
|
||||||
|
# It demonstrates the core capabilities of the DeBros Network Gateway.
|
||||||
|
|
||||||
|
@baseUrl = http://localhost:6001
|
||||||
|
@apiKey = ak_X32jj2fiin8zzv0hmBKTC5b5:default
|
||||||
|
@contentType = application/json
|
||||||
|
|
||||||
|
############################################################
|
||||||
|
### 1. SYSTEM & HEALTH
|
||||||
|
############################################################
|
||||||
|
|
||||||
|
# @name HealthCheck
|
||||||
|
GET {{baseUrl}}/v1/health
|
||||||
|
X-API-Key: {{apiKey}}
|
||||||
|
|
||||||
|
###
|
||||||
|
|
||||||
|
# @name SystemStatus
|
||||||
|
# Returns the full status of the gateway and connected services
|
||||||
|
GET {{baseUrl}}/v1/status
|
||||||
|
X-API-Key: {{apiKey}}
|
||||||
|
|
||||||
|
###
|
||||||
|
|
||||||
|
# @name NetworkStatus
|
||||||
|
# Returns the P2P network status and PeerID
|
||||||
|
GET {{baseUrl}}/v1/network/status
|
||||||
|
X-API-Key: {{apiKey}}
|
||||||
|
|
||||||
|
|
||||||
|
############################################################
|
||||||
|
### 2. DISTRIBUTED CACHE (OLRIC)
|
||||||
|
############################################################
|
||||||
|
|
||||||
|
# @name CachePut
|
||||||
|
# Stores a value in the distributed cache (DMap)
|
||||||
|
POST {{baseUrl}}/v1/cache/put
|
||||||
|
X-API-Key: {{apiKey}}
|
||||||
|
Content-Type: {{contentType}}
|
||||||
|
|
||||||
|
{
|
||||||
|
"dmap": "demo-cache",
|
||||||
|
"key": "video-demo",
|
||||||
|
"value": "Hello from REST Client!"
|
||||||
|
}
|
||||||
|
|
||||||
|
###
|
||||||
|
|
||||||
|
# @name CacheGet
|
||||||
|
# Retrieves a value from the distributed cache
|
||||||
|
POST {{baseUrl}}/v1/cache/get
|
||||||
|
X-API-Key: {{apiKey}}
|
||||||
|
Content-Type: {{contentType}}
|
||||||
|
|
||||||
|
{
|
||||||
|
"dmap": "demo-cache",
|
||||||
|
"key": "video-demo"
|
||||||
|
}
|
||||||
|
|
||||||
|
###
|
||||||
|
|
||||||
|
# @name CacheScan
|
||||||
|
# Scans for keys in a specific DMap
|
||||||
|
POST {{baseUrl}}/v1/cache/scan
|
||||||
|
X-API-Key: {{apiKey}}
|
||||||
|
Content-Type: {{contentType}}
|
||||||
|
|
||||||
|
{
|
||||||
|
"dmap": "demo-cache"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
############################################################
|
||||||
|
### 3. DECENTRALIZED STORAGE (IPFS)
|
||||||
|
############################################################
|
||||||
|
|
||||||
|
# @name StorageUpload
|
||||||
|
# Uploads a file to IPFS (Multipart)
|
||||||
|
POST {{baseUrl}}/v1/storage/upload
|
||||||
|
X-API-Key: {{apiKey}}
|
||||||
|
Content-Type: multipart/form-data; boundary=boundary
|
||||||
|
|
||||||
|
--boundary
|
||||||
|
Content-Disposition: form-data; name="file"; filename="demo.txt"
|
||||||
|
Content-Type: text/plain
|
||||||
|
|
||||||
|
This is a demonstration of decentralized storage on the Sonr Network.
|
||||||
|
--boundary--
|
||||||
|
|
||||||
|
###
|
||||||
|
|
||||||
|
# @name StorageStatus
|
||||||
|
# Check the pinning status and replication of a CID
|
||||||
|
# Replace {cid} with the CID returned from the upload above
|
||||||
|
@demoCid = bafkreid76y6x6v2n5o4n6n5o4n6n5o4n6n5o4n6n5o4
|
||||||
|
GET {{baseUrl}}/v1/storage/status/{{demoCid}}
|
||||||
|
X-API-Key: {{apiKey}}
|
||||||
|
|
||||||
|
###
|
||||||
|
|
||||||
|
# @name StorageDownload
|
||||||
|
# Retrieve content directly from IPFS via the gateway
|
||||||
|
GET {{baseUrl}}/v1/storage/get/{{demoCid}}
|
||||||
|
X-API-Key: {{apiKey}}
|
||||||
|
|
||||||
|
|
||||||
|
############################################################
|
||||||
|
### 4. REAL-TIME PUB/SUB
|
||||||
|
############################################################
|
||||||
|
|
||||||
|
# @name ListTopics
|
||||||
|
# Lists all active topics in the current namespace
|
||||||
|
GET {{baseUrl}}/v1/pubsub/topics
|
||||||
|
X-API-Key: {{apiKey}}
|
||||||
|
|
||||||
|
###
|
||||||
|
|
||||||
|
# @name PublishMessage
|
||||||
|
# Publishes a base64 encoded message to a topic
|
||||||
|
POST {{baseUrl}}/v1/pubsub/publish
|
||||||
|
X-API-Key: {{apiKey}}
|
||||||
|
Content-Type: {{contentType}}
|
||||||
|
|
||||||
|
{
|
||||||
|
"topic": "network-updates",
|
||||||
|
"data_base64": "U29uciBOZXR3b3JrIGlzIGF3ZXNvbWUh"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
############################################################
|
||||||
|
### 5. SERVERLESS FUNCTIONS
|
||||||
|
############################################################
|
||||||
|
|
||||||
|
# @name ListFunctions
|
||||||
|
# Lists all deployed serverless functions
|
||||||
|
GET {{baseUrl}}/v1/functions
|
||||||
|
X-API-Key: {{apiKey}}
|
||||||
|
|
||||||
|
###
|
||||||
|
|
||||||
|
# @name InvokeFunction
|
||||||
|
# Invokes a deployed function by name
|
||||||
|
# Path: /v1/invoke/{namespace}/{functionName}
|
||||||
|
POST {{baseUrl}}/v1/invoke/default/hello
|
||||||
|
X-API-Key: {{apiKey}}
|
||||||
|
Content-Type: {{contentType}}
|
||||||
|
|
||||||
|
{
|
||||||
|
"name": "Developer"
|
||||||
|
}
|
||||||
|
|
||||||
|
###
|
||||||
|
|
||||||
|
# @name WhoAmI
|
||||||
|
# Validates the API Key and returns caller identity
|
||||||
|
GET {{baseUrl}}/v1/auth/whoami
|
||||||
|
X-API-Key: {{apiKey}}
|
||||||
42
examples/functions/build.sh
Executable file
42
examples/functions/build.sh
Executable file
@ -0,0 +1,42 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Build all example functions to WASM using TinyGo
|
||||||
|
#
|
||||||
|
# Prerequisites:
|
||||||
|
# - TinyGo installed: https://tinygo.org/getting-started/install/
|
||||||
|
# - On macOS: brew install tinygo
|
||||||
|
#
|
||||||
|
# Usage: ./build.sh
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
OUTPUT_DIR="$SCRIPT_DIR/bin"
|
||||||
|
|
||||||
|
# Check if TinyGo is installed
|
||||||
|
if ! command -v tinygo &> /dev/null; then
|
||||||
|
echo "Error: TinyGo is not installed."
|
||||||
|
echo "Install it with: brew install tinygo (macOS) or see https://tinygo.org/getting-started/install/"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
mkdir -p "$OUTPUT_DIR"
|
||||||
|
|
||||||
|
echo "Building example functions to WASM..."
|
||||||
|
echo
|
||||||
|
|
||||||
|
# Build each function
|
||||||
|
for dir in "$SCRIPT_DIR"/*/; do
|
||||||
|
if [ -f "$dir/main.go" ]; then
|
||||||
|
name=$(basename "$dir")
|
||||||
|
echo "Building $name..."
|
||||||
|
cd "$dir"
|
||||||
|
tinygo build -o "$OUTPUT_DIR/$name.wasm" -target wasi main.go
|
||||||
|
echo " -> $OUTPUT_DIR/$name.wasm"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
echo
|
||||||
|
echo "Done! WASM files are in $OUTPUT_DIR/"
|
||||||
|
ls -lh "$OUTPUT_DIR"/*.wasm 2>/dev/null || echo "No WASM files built."
|
||||||
|
|
||||||
66
examples/functions/counter/main.go
Normal file
66
examples/functions/counter/main.go
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
// Example: Counter function with Olric cache
|
||||||
|
// This function demonstrates using the distributed cache to maintain state.
|
||||||
|
// Compile with: tinygo build -o counter.wasm -target wasi main.go
|
||||||
|
//
|
||||||
|
// Note: This example shows the CONCEPT. Actual host function integration
|
||||||
|
// requires the host function bindings to be exposed to the WASM module.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Read input from stdin
|
||||||
|
var input []byte
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
for {
|
||||||
|
n, err := os.Stdin.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
input = append(input, buf[:n]...)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse input
|
||||||
|
var payload struct {
|
||||||
|
Action string `json:"action"` // "increment", "decrement", "get", "reset"
|
||||||
|
CounterID string `json:"counter_id"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, &payload); err != nil {
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"error": "Invalid JSON input",
|
||||||
|
}
|
||||||
|
output, _ := json.Marshal(response)
|
||||||
|
os.Stdout.Write(output)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if payload.CounterID == "" {
|
||||||
|
payload.CounterID = "default"
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: In the real implementation, this would use host functions:
|
||||||
|
// - cache_get(key) to read the counter
|
||||||
|
// - cache_put(key, value, ttl) to write the counter
|
||||||
|
//
|
||||||
|
// For this example, we just simulate the logic:
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"counter_id": payload.CounterID,
|
||||||
|
"action": payload.Action,
|
||||||
|
"message": "Counter operations require cache host functions",
|
||||||
|
"example": map[string]interface{}{
|
||||||
|
"increment": "cache_put('counter:' + counter_id, current + 1)",
|
||||||
|
"decrement": "cache_put('counter:' + counter_id, current - 1)",
|
||||||
|
"get": "cache_get('counter:' + counter_id)",
|
||||||
|
"reset": "cache_put('counter:' + counter_id, 0)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
output, _ := json.Marshal(response)
|
||||||
|
os.Stdout.Write(output)
|
||||||
|
}
|
||||||
|
|
||||||
50
examples/functions/echo/main.go
Normal file
50
examples/functions/echo/main.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
// Example: Echo function
|
||||||
|
// This is a simple serverless function that echoes back the input.
|
||||||
|
// Compile with: tinygo build -o echo.wasm -target wasi main.go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Input is read from stdin, output is written to stdout.
|
||||||
|
// The Orama serverless engine passes the invocation payload via stdin
|
||||||
|
// and expects the response on stdout.
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Read all input from stdin
|
||||||
|
var input []byte
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
for {
|
||||||
|
n, err := os.Stdin.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
input = append(input, buf[:n]...)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse input as JSON (optional - could also just echo raw bytes)
|
||||||
|
var payload map[string]interface{}
|
||||||
|
if err := json.Unmarshal(input, &payload); err != nil {
|
||||||
|
// Not JSON, just echo the raw input
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"echo": string(input),
|
||||||
|
}
|
||||||
|
output, _ := json.Marshal(response)
|
||||||
|
os.Stdout.Write(output)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create response
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"echo": payload,
|
||||||
|
"message": "Echo function received your input!",
|
||||||
|
}
|
||||||
|
|
||||||
|
output, _ := json.Marshal(response)
|
||||||
|
os.Stdout.Write(output)
|
||||||
|
}
|
||||||
|
|
||||||
42
examples/functions/hello/main.go
Normal file
42
examples/functions/hello/main.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
// Example: Hello function
|
||||||
|
// This is a simple serverless function that returns a greeting.
|
||||||
|
// Compile with: tinygo build -o hello.wasm -target wasi main.go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Read input from stdin
|
||||||
|
var input []byte
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
for {
|
||||||
|
n, err := os.Stdin.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
input = append(input, buf[:n]...)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse input to get name
|
||||||
|
var payload struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, &payload); err != nil || payload.Name == "" {
|
||||||
|
payload.Name = "World"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create greeting response
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"greeting": "Hello, " + payload.Name + "!",
|
||||||
|
"message": "This is a serverless function running on Orama Network",
|
||||||
|
}
|
||||||
|
|
||||||
|
output, _ := json.Marshal(response)
|
||||||
|
os.Stdout.Write(output)
|
||||||
|
}
|
||||||
|
|
||||||
7
go.mod
7
go.mod
@ -1,6 +1,6 @@
|
|||||||
module github.com/DeBrosOfficial/network
|
module github.com/DeBrosOfficial/network
|
||||||
|
|
||||||
go 1.23.8
|
go 1.24.0
|
||||||
|
|
||||||
toolchain go1.24.1
|
toolchain go1.24.1
|
||||||
|
|
||||||
@ -10,6 +10,7 @@ require (
|
|||||||
github.com/charmbracelet/lipgloss v1.0.0
|
github.com/charmbracelet/lipgloss v1.0.0
|
||||||
github.com/ethereum/go-ethereum v1.13.14
|
github.com/ethereum/go-ethereum v1.13.14
|
||||||
github.com/go-chi/chi/v5 v5.2.3
|
github.com/go-chi/chi/v5 v5.2.3
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/libp2p/go-libp2p v0.41.1
|
github.com/libp2p/go-libp2p v0.41.1
|
||||||
github.com/libp2p/go-libp2p-pubsub v0.14.2
|
github.com/libp2p/go-libp2p-pubsub v0.14.2
|
||||||
@ -18,6 +19,7 @@ require (
|
|||||||
github.com/multiformats/go-multiaddr v0.15.0
|
github.com/multiformats/go-multiaddr v0.15.0
|
||||||
github.com/olric-data/olric v0.7.0
|
github.com/olric-data/olric v0.7.0
|
||||||
github.com/rqlite/gorqlite v0.0.0-20250609141355-ac86a4a1c9a8
|
github.com/rqlite/gorqlite v0.0.0-20250609141355-ac86a4a1c9a8
|
||||||
|
github.com/tetratelabs/wazero v1.11.0
|
||||||
go.uber.org/zap v1.27.0
|
go.uber.org/zap v1.27.0
|
||||||
golang.org/x/crypto v0.40.0
|
golang.org/x/crypto v0.40.0
|
||||||
golang.org/x/net v0.42.0
|
golang.org/x/net v0.42.0
|
||||||
@ -54,7 +56,6 @@ require (
|
|||||||
github.com/google/btree v1.1.3 // indirect
|
github.com/google/btree v1.1.3 // indirect
|
||||||
github.com/google/gopacket v1.1.19 // indirect
|
github.com/google/gopacket v1.1.19 // indirect
|
||||||
github.com/google/pprof v0.0.0-20250208200701-d0013a598941 // indirect
|
github.com/google/pprof v0.0.0-20250208200701-d0013a598941 // indirect
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
|
||||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||||
github.com/hashicorp/go-immutable-radix v1.3.1 // indirect
|
github.com/hashicorp/go-immutable-radix v1.3.1 // indirect
|
||||||
github.com/hashicorp/go-metrics v0.5.4 // indirect
|
github.com/hashicorp/go-metrics v0.5.4 // indirect
|
||||||
@ -154,7 +155,7 @@ require (
|
|||||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 // indirect
|
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 // indirect
|
||||||
golang.org/x/mod v0.26.0 // indirect
|
golang.org/x/mod v0.26.0 // indirect
|
||||||
golang.org/x/sync v0.16.0 // indirect
|
golang.org/x/sync v0.16.0 // indirect
|
||||||
golang.org/x/sys v0.34.0 // indirect
|
golang.org/x/sys v0.38.0 // indirect
|
||||||
golang.org/x/text v0.27.0 // indirect
|
golang.org/x/text v0.27.0 // indirect
|
||||||
golang.org/x/tools v0.35.0 // indirect
|
golang.org/x/tools v0.35.0 // indirect
|
||||||
google.golang.org/protobuf v1.36.6 // indirect
|
google.golang.org/protobuf v1.36.6 // indirect
|
||||||
|
|||||||
6
go.sum
6
go.sum
@ -487,6 +487,8 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
|
|||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
|
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
|
||||||
|
github.com/tetratelabs/wazero v1.11.0 h1:+gKemEuKCTevU4d7ZTzlsvgd1uaToIDtlQlmNbwqYhA=
|
||||||
|
github.com/tetratelabs/wazero v1.11.0/go.mod h1:eV28rsN8Q+xwjogd7f4/Pp4xFxO7uOGbLcD/LzB1wiU=
|
||||||
github.com/tidwall/btree v1.1.0/go.mod h1:TzIRzen6yHbibdSfK6t8QimqbUnoxUSrZfeW7Uob0q4=
|
github.com/tidwall/btree v1.1.0/go.mod h1:TzIRzen6yHbibdSfK6t8QimqbUnoxUSrZfeW7Uob0q4=
|
||||||
github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI=
|
github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI=
|
||||||
github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY=
|
github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY=
|
||||||
@ -627,8 +629,8 @@ golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||||
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||||
|
|||||||
243
migrations/004_serverless_functions.sql
Normal file
243
migrations/004_serverless_functions.sql
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
-- Orama Network - Serverless Functions Engine (Phase 4)
|
||||||
|
-- WASM-based serverless function execution with triggers, jobs, and secrets
|
||||||
|
|
||||||
|
BEGIN;
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- FUNCTIONS TABLE
|
||||||
|
-- Core function registry with versioning support
|
||||||
|
-- =============================================================================
|
||||||
|
CREATE TABLE IF NOT EXISTS functions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
namespace TEXT NOT NULL,
|
||||||
|
version INTEGER NOT NULL DEFAULT 1,
|
||||||
|
wasm_cid TEXT NOT NULL,
|
||||||
|
source_cid TEXT,
|
||||||
|
memory_limit_mb INTEGER NOT NULL DEFAULT 64,
|
||||||
|
timeout_seconds INTEGER NOT NULL DEFAULT 30,
|
||||||
|
is_public BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
retry_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
retry_delay_seconds INTEGER NOT NULL DEFAULT 5,
|
||||||
|
dlq_topic TEXT,
|
||||||
|
status TEXT NOT NULL DEFAULT 'active',
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
created_by TEXT NOT NULL,
|
||||||
|
UNIQUE(namespace, name)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_functions_namespace ON functions(namespace);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_functions_name ON functions(namespace, name);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_functions_status ON functions(status);
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- FUNCTION ENVIRONMENT VARIABLES
|
||||||
|
-- Non-sensitive configuration per function
|
||||||
|
-- =============================================================================
|
||||||
|
CREATE TABLE IF NOT EXISTS function_env_vars (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
function_id TEXT NOT NULL,
|
||||||
|
key TEXT NOT NULL,
|
||||||
|
value TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
UNIQUE(function_id, key),
|
||||||
|
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_env_vars_function ON function_env_vars(function_id);
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- FUNCTION SECRETS
|
||||||
|
-- Encrypted secrets per namespace (shared across functions in namespace)
|
||||||
|
-- =============================================================================
|
||||||
|
CREATE TABLE IF NOT EXISTS function_secrets (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
namespace TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
encrypted_value BLOB NOT NULL,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
UNIQUE(namespace, name)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_secrets_namespace ON function_secrets(namespace);
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- CRON TRIGGERS
|
||||||
|
-- Scheduled function execution using cron expressions
|
||||||
|
-- =============================================================================
|
||||||
|
CREATE TABLE IF NOT EXISTS function_cron_triggers (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
function_id TEXT NOT NULL,
|
||||||
|
cron_expression TEXT NOT NULL,
|
||||||
|
next_run_at TIMESTAMP,
|
||||||
|
last_run_at TIMESTAMP,
|
||||||
|
last_status TEXT,
|
||||||
|
last_error TEXT,
|
||||||
|
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_cron_triggers_function ON function_cron_triggers(function_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_cron_triggers_next_run ON function_cron_triggers(next_run_at)
|
||||||
|
WHERE enabled = TRUE;
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- DATABASE TRIGGERS
|
||||||
|
-- Trigger functions on database changes (INSERT/UPDATE/DELETE)
|
||||||
|
-- =============================================================================
|
||||||
|
CREATE TABLE IF NOT EXISTS function_db_triggers (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
function_id TEXT NOT NULL,
|
||||||
|
table_name TEXT NOT NULL,
|
||||||
|
operation TEXT NOT NULL CHECK(operation IN ('INSERT', 'UPDATE', 'DELETE')),
|
||||||
|
condition TEXT,
|
||||||
|
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_db_triggers_function ON function_db_triggers(function_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_db_triggers_table ON function_db_triggers(table_name, operation)
|
||||||
|
WHERE enabled = TRUE;
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- PUBSUB TRIGGERS
|
||||||
|
-- Trigger functions on pubsub messages
|
||||||
|
-- =============================================================================
|
||||||
|
CREATE TABLE IF NOT EXISTS function_pubsub_triggers (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
function_id TEXT NOT NULL,
|
||||||
|
topic TEXT NOT NULL,
|
||||||
|
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_pubsub_triggers_function ON function_pubsub_triggers(function_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_pubsub_triggers_topic ON function_pubsub_triggers(topic)
|
||||||
|
WHERE enabled = TRUE;
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- ONE-TIME TIMERS
|
||||||
|
-- Schedule functions to run once at a specific time
|
||||||
|
-- =============================================================================
|
||||||
|
CREATE TABLE IF NOT EXISTS function_timers (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
function_id TEXT NOT NULL,
|
||||||
|
run_at TIMESTAMP NOT NULL,
|
||||||
|
payload TEXT,
|
||||||
|
status TEXT NOT NULL DEFAULT 'pending' CHECK(status IN ('pending', 'running', 'completed', 'failed')),
|
||||||
|
error TEXT,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
completed_at TIMESTAMP,
|
||||||
|
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_timers_function ON function_timers(function_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_timers_pending ON function_timers(run_at)
|
||||||
|
WHERE status = 'pending';
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- BACKGROUND JOBS
|
||||||
|
-- Long-running async function execution
|
||||||
|
-- =============================================================================
|
||||||
|
CREATE TABLE IF NOT EXISTS function_jobs (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
function_id TEXT NOT NULL,
|
||||||
|
payload TEXT,
|
||||||
|
status TEXT NOT NULL DEFAULT 'pending' CHECK(status IN ('pending', 'running', 'completed', 'failed', 'cancelled')),
|
||||||
|
progress INTEGER NOT NULL DEFAULT 0 CHECK(progress >= 0 AND progress <= 100),
|
||||||
|
result TEXT,
|
||||||
|
error TEXT,
|
||||||
|
started_at TIMESTAMP,
|
||||||
|
completed_at TIMESTAMP,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_jobs_function ON function_jobs(function_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_jobs_status ON function_jobs(status);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_jobs_pending ON function_jobs(created_at)
|
||||||
|
WHERE status = 'pending';
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- INVOCATION LOGS
|
||||||
|
-- Record of all function invocations for debugging and metrics
|
||||||
|
-- =============================================================================
|
||||||
|
CREATE TABLE IF NOT EXISTS function_invocations (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
function_id TEXT NOT NULL,
|
||||||
|
request_id TEXT NOT NULL,
|
||||||
|
trigger_type TEXT NOT NULL,
|
||||||
|
caller_wallet TEXT,
|
||||||
|
input_size INTEGER,
|
||||||
|
output_size INTEGER,
|
||||||
|
started_at TIMESTAMP NOT NULL,
|
||||||
|
completed_at TIMESTAMP,
|
||||||
|
duration_ms INTEGER,
|
||||||
|
status TEXT CHECK(status IN ('success', 'error', 'timeout')),
|
||||||
|
error_message TEXT,
|
||||||
|
memory_used_mb REAL,
|
||||||
|
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_invocations_function ON function_invocations(function_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_invocations_request ON function_invocations(request_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_invocations_time ON function_invocations(started_at);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_invocations_status ON function_invocations(function_id, status);
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- FUNCTION LOGS
|
||||||
|
-- Captured log output from function execution
|
||||||
|
-- =============================================================================
|
||||||
|
CREATE TABLE IF NOT EXISTS function_logs (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
function_id TEXT NOT NULL,
|
||||||
|
invocation_id TEXT NOT NULL,
|
||||||
|
level TEXT NOT NULL CHECK(level IN ('info', 'warn', 'error', 'debug')),
|
||||||
|
message TEXT NOT NULL,
|
||||||
|
timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (function_id) REFERENCES functions(id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (invocation_id) REFERENCES function_invocations(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_logs_invocation ON function_logs(invocation_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_logs_function ON function_logs(function_id, timestamp);
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- DB CHANGE TRACKING
|
||||||
|
-- Track last processed row for database triggers (CDC-like)
|
||||||
|
-- =============================================================================
|
||||||
|
CREATE TABLE IF NOT EXISTS function_db_change_tracking (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
trigger_id TEXT NOT NULL UNIQUE,
|
||||||
|
last_row_id INTEGER,
|
||||||
|
last_updated_at TIMESTAMP,
|
||||||
|
last_check_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (trigger_id) REFERENCES function_db_triggers(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- RATE LIMITING
|
||||||
|
-- Track request counts for rate limiting
|
||||||
|
-- =============================================================================
|
||||||
|
CREATE TABLE IF NOT EXISTS function_rate_limits (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
window_key TEXT NOT NULL,
|
||||||
|
count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
window_start TIMESTAMP NOT NULL,
|
||||||
|
UNIQUE(window_key, window_start)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_function_rate_limits_window ON function_rate_limits(window_key, window_start);
|
||||||
|
|
||||||
|
-- =============================================================================
|
||||||
|
-- MIGRATION VERSION TRACKING
|
||||||
|
-- =============================================================================
|
||||||
|
INSERT OR IGNORE INTO schema_migrations(version) VALUES (4);
|
||||||
|
|
||||||
|
COMMIT;
|
||||||
|
|
||||||
@ -2,8 +2,6 @@ package cli
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"encoding/hex"
|
|
||||||
"errors"
|
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@ -11,269 +9,12 @@ import (
|
|||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/config"
|
"github.com/DeBrosOfficial/network/pkg/cli/utils"
|
||||||
"github.com/DeBrosOfficial/network/pkg/environments/production"
|
"github.com/DeBrosOfficial/network/pkg/environments/production"
|
||||||
"github.com/DeBrosOfficial/network/pkg/installer"
|
|
||||||
"github.com/multiformats/go-multiaddr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// IPFSPeerInfo holds IPFS peer information for configuring Peering.Peers
|
|
||||||
type IPFSPeerInfo struct {
|
|
||||||
PeerID string
|
|
||||||
Addrs []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPFSClusterPeerInfo contains IPFS Cluster peer information for cluster discovery
|
|
||||||
type IPFSClusterPeerInfo struct {
|
|
||||||
PeerID string
|
|
||||||
Addrs []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateSwarmKey validates that a swarm key is 64 hex characters
|
|
||||||
func validateSwarmKey(key string) error {
|
|
||||||
key = strings.TrimSpace(key)
|
|
||||||
if len(key) != 64 {
|
|
||||||
return fmt.Errorf("swarm key must be 64 hex characters (32 bytes), got %d", len(key))
|
|
||||||
}
|
|
||||||
if _, err := hex.DecodeString(key); err != nil {
|
|
||||||
return fmt.Errorf("swarm key must be valid hexadecimal: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// runInteractiveInstaller launches the TUI installer
|
|
||||||
func runInteractiveInstaller() {
|
|
||||||
config, err := installer.Run()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert TUI config to install args and run installation
|
|
||||||
var args []string
|
|
||||||
args = append(args, "--vps-ip", config.VpsIP)
|
|
||||||
args = append(args, "--domain", config.Domain)
|
|
||||||
args = append(args, "--branch", config.Branch)
|
|
||||||
|
|
||||||
if config.NoPull {
|
|
||||||
args = append(args, "--no-pull")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !config.IsFirstNode {
|
|
||||||
if config.JoinAddress != "" {
|
|
||||||
args = append(args, "--join", config.JoinAddress)
|
|
||||||
}
|
|
||||||
if config.ClusterSecret != "" {
|
|
||||||
args = append(args, "--cluster-secret", config.ClusterSecret)
|
|
||||||
}
|
|
||||||
if config.SwarmKeyHex != "" {
|
|
||||||
args = append(args, "--swarm-key", config.SwarmKeyHex)
|
|
||||||
}
|
|
||||||
if len(config.Peers) > 0 {
|
|
||||||
args = append(args, "--peers", strings.Join(config.Peers, ","))
|
|
||||||
}
|
|
||||||
// Pass IPFS peer info for Peering.Peers configuration
|
|
||||||
if config.IPFSPeerID != "" {
|
|
||||||
args = append(args, "--ipfs-peer", config.IPFSPeerID)
|
|
||||||
}
|
|
||||||
if len(config.IPFSSwarmAddrs) > 0 {
|
|
||||||
args = append(args, "--ipfs-addrs", strings.Join(config.IPFSSwarmAddrs, ","))
|
|
||||||
}
|
|
||||||
// Pass IPFS Cluster peer info for cluster peer_addresses configuration
|
|
||||||
if config.IPFSClusterPeerID != "" {
|
|
||||||
args = append(args, "--ipfs-cluster-peer", config.IPFSClusterPeerID)
|
|
||||||
}
|
|
||||||
if len(config.IPFSClusterAddrs) > 0 {
|
|
||||||
args = append(args, "--ipfs-cluster-addrs", strings.Join(config.IPFSClusterAddrs, ","))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-run with collected args
|
|
||||||
handleProdInstall(args)
|
|
||||||
}
|
|
||||||
|
|
||||||
// showDryRunSummary displays what would be done during installation without making changes
|
|
||||||
func showDryRunSummary(vpsIP, domain, branch string, peers []string, joinAddress string, isFirstNode bool, oramaDir string) {
|
|
||||||
fmt.Printf("\n" + strings.Repeat("=", 70) + "\n")
|
|
||||||
fmt.Printf("DRY RUN - No changes will be made\n")
|
|
||||||
fmt.Printf(strings.Repeat("=", 70) + "\n\n")
|
|
||||||
|
|
||||||
fmt.Printf("📋 Installation Summary:\n")
|
|
||||||
fmt.Printf(" VPS IP: %s\n", vpsIP)
|
|
||||||
fmt.Printf(" Domain: %s\n", domain)
|
|
||||||
fmt.Printf(" Branch: %s\n", branch)
|
|
||||||
if isFirstNode {
|
|
||||||
fmt.Printf(" Node Type: First node (creates new cluster)\n")
|
|
||||||
} else {
|
|
||||||
fmt.Printf(" Node Type: Joining existing cluster\n")
|
|
||||||
if joinAddress != "" {
|
|
||||||
fmt.Printf(" Join Address: %s\n", joinAddress)
|
|
||||||
}
|
|
||||||
if len(peers) > 0 {
|
|
||||||
fmt.Printf(" Peers: %d peer(s)\n", len(peers))
|
|
||||||
for _, peer := range peers {
|
|
||||||
fmt.Printf(" - %s\n", peer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("\n📁 Directories that would be created:\n")
|
|
||||||
fmt.Printf(" %s/configs/\n", oramaDir)
|
|
||||||
fmt.Printf(" %s/secrets/\n", oramaDir)
|
|
||||||
fmt.Printf(" %s/data/ipfs/repo/\n", oramaDir)
|
|
||||||
fmt.Printf(" %s/data/ipfs-cluster/\n", oramaDir)
|
|
||||||
fmt.Printf(" %s/data/rqlite/\n", oramaDir)
|
|
||||||
fmt.Printf(" %s/logs/\n", oramaDir)
|
|
||||||
fmt.Printf(" %s/tls-cache/\n", oramaDir)
|
|
||||||
|
|
||||||
fmt.Printf("\n🔧 Binaries that would be installed:\n")
|
|
||||||
fmt.Printf(" - Go (if not present)\n")
|
|
||||||
fmt.Printf(" - RQLite 8.43.0\n")
|
|
||||||
fmt.Printf(" - IPFS/Kubo 0.38.2\n")
|
|
||||||
fmt.Printf(" - IPFS Cluster (latest)\n")
|
|
||||||
fmt.Printf(" - Olric 0.7.0\n")
|
|
||||||
fmt.Printf(" - anyone-client (npm)\n")
|
|
||||||
fmt.Printf(" - DeBros binaries (built from %s branch)\n", branch)
|
|
||||||
|
|
||||||
fmt.Printf("\n🔐 Secrets that would be generated:\n")
|
|
||||||
fmt.Printf(" - Cluster secret (64-hex)\n")
|
|
||||||
fmt.Printf(" - IPFS swarm key\n")
|
|
||||||
fmt.Printf(" - Node identity (Ed25519 keypair)\n")
|
|
||||||
|
|
||||||
fmt.Printf("\n📝 Configuration files that would be created:\n")
|
|
||||||
fmt.Printf(" - %s/configs/node.yaml\n", oramaDir)
|
|
||||||
fmt.Printf(" - %s/configs/olric/config.yaml\n", oramaDir)
|
|
||||||
|
|
||||||
fmt.Printf("\n⚙️ Systemd services that would be created:\n")
|
|
||||||
fmt.Printf(" - debros-ipfs.service\n")
|
|
||||||
fmt.Printf(" - debros-ipfs-cluster.service\n")
|
|
||||||
fmt.Printf(" - debros-olric.service\n")
|
|
||||||
fmt.Printf(" - debros-node.service (includes embedded gateway + RQLite)\n")
|
|
||||||
fmt.Printf(" - debros-anyone-client.service\n")
|
|
||||||
|
|
||||||
fmt.Printf("\n🌐 Ports that would be used:\n")
|
|
||||||
fmt.Printf(" External (must be open in firewall):\n")
|
|
||||||
fmt.Printf(" - 80 (HTTP for ACME/Let's Encrypt)\n")
|
|
||||||
fmt.Printf(" - 443 (HTTPS gateway)\n")
|
|
||||||
fmt.Printf(" - 4101 (IPFS swarm)\n")
|
|
||||||
fmt.Printf(" - 7001 (RQLite Raft)\n")
|
|
||||||
fmt.Printf(" Internal (localhost only):\n")
|
|
||||||
fmt.Printf(" - 4501 (IPFS API)\n")
|
|
||||||
fmt.Printf(" - 5001 (RQLite HTTP)\n")
|
|
||||||
fmt.Printf(" - 6001 (Unified gateway)\n")
|
|
||||||
fmt.Printf(" - 8080 (IPFS gateway)\n")
|
|
||||||
fmt.Printf(" - 9050 (Anyone SOCKS5)\n")
|
|
||||||
fmt.Printf(" - 9094 (IPFS Cluster API)\n")
|
|
||||||
fmt.Printf(" - 3320/3322 (Olric)\n")
|
|
||||||
|
|
||||||
fmt.Printf("\n" + strings.Repeat("=", 70) + "\n")
|
|
||||||
fmt.Printf("To proceed with installation, run without --dry-run\n")
|
|
||||||
fmt.Printf(strings.Repeat("=", 70) + "\n\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateGeneratedConfig loads and validates the generated node configuration
|
|
||||||
func validateGeneratedConfig(oramaDir string) error {
|
|
||||||
configPath := filepath.Join(oramaDir, "configs", "node.yaml")
|
|
||||||
|
|
||||||
// Check if config file exists
|
|
||||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
|
||||||
return fmt.Errorf("configuration file not found at %s", configPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load the config file
|
|
||||||
file, err := os.Open(configPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to open config file: %w", err)
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
var cfg config.Config
|
|
||||||
if err := config.DecodeStrict(file, &cfg); err != nil {
|
|
||||||
return fmt.Errorf("failed to parse config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate the configuration
|
|
||||||
if errs := cfg.Validate(); len(errs) > 0 {
|
|
||||||
var errMsgs []string
|
|
||||||
for _, e := range errs {
|
|
||||||
errMsgs = append(errMsgs, e.Error())
|
|
||||||
}
|
|
||||||
return fmt.Errorf("configuration validation errors:\n - %s", strings.Join(errMsgs, "\n - "))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateDNSRecord validates that the domain points to the expected IP address
|
|
||||||
// Returns nil if DNS is valid, warning message if DNS doesn't match but continues,
|
|
||||||
// or error if DNS lookup fails completely
|
|
||||||
func validateDNSRecord(domain, expectedIP string) error {
|
|
||||||
if domain == "" {
|
|
||||||
return nil // No domain provided, skip validation
|
|
||||||
}
|
|
||||||
|
|
||||||
ips, err := net.LookupIP(domain)
|
|
||||||
if err != nil {
|
|
||||||
// DNS lookup failed - this is a warning, not a fatal error
|
|
||||||
// The user might be setting up DNS after installation
|
|
||||||
fmt.Printf(" ⚠️ DNS lookup failed for %s: %v\n", domain, err)
|
|
||||||
fmt.Printf(" Make sure DNS is configured before enabling HTTPS\n")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if any resolved IP matches the expected IP
|
|
||||||
for _, ip := range ips {
|
|
||||||
if ip.String() == expectedIP {
|
|
||||||
fmt.Printf(" ✓ DNS validated: %s → %s\n", domain, expectedIP)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// DNS doesn't point to expected IP - warn but continue
|
|
||||||
resolvedIPs := make([]string, len(ips))
|
|
||||||
for i, ip := range ips {
|
|
||||||
resolvedIPs[i] = ip.String()
|
|
||||||
}
|
|
||||||
fmt.Printf(" ⚠️ DNS mismatch: %s resolves to %v, expected %s\n", domain, resolvedIPs, expectedIP)
|
|
||||||
fmt.Printf(" HTTPS certificate generation may fail until DNS is updated\n")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// normalizePeers normalizes and validates peer multiaddrs
|
|
||||||
func normalizePeers(peersStr string) ([]string, error) {
|
|
||||||
if peersStr == "" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Split by comma and trim whitespace
|
|
||||||
rawPeers := strings.Split(peersStr, ",")
|
|
||||||
peers := make([]string, 0, len(rawPeers))
|
|
||||||
seen := make(map[string]bool)
|
|
||||||
|
|
||||||
for _, peer := range rawPeers {
|
|
||||||
peer = strings.TrimSpace(peer)
|
|
||||||
if peer == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate multiaddr format
|
|
||||||
if _, err := multiaddr.NewMultiaddr(peer); err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid multiaddr %q: %w", peer, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deduplicate
|
|
||||||
if !seen[peer] {
|
|
||||||
peers = append(peers, peer)
|
|
||||||
seen[peer] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return peers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleProdCommand handles production environment commands
|
// HandleProdCommand handles production environment commands
|
||||||
func HandleProdCommand(args []string) {
|
func HandleProdCommand(args []string) {
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
@ -368,294 +109,6 @@ func showProdHelp() {
|
|||||||
fmt.Printf(" orama logs node --follow\n")
|
fmt.Printf(" orama logs node --follow\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleProdInstall(args []string) {
|
|
||||||
// Parse arguments using flag.FlagSet
|
|
||||||
fs := flag.NewFlagSet("install", flag.ContinueOnError)
|
|
||||||
fs.SetOutput(os.Stderr)
|
|
||||||
|
|
||||||
force := fs.Bool("force", false, "Reconfigure all settings")
|
|
||||||
skipResourceChecks := fs.Bool("ignore-resource-checks", false, "Skip disk/RAM/CPU prerequisite validation")
|
|
||||||
vpsIP := fs.String("vps-ip", "", "VPS public IP address")
|
|
||||||
domain := fs.String("domain", "", "Domain for this node (e.g., node-123.orama.network)")
|
|
||||||
peersStr := fs.String("peers", "", "Comma-separated peer multiaddrs to connect to")
|
|
||||||
joinAddress := fs.String("join", "", "RQLite join address (IP:port) to join existing cluster")
|
|
||||||
branch := fs.String("branch", "main", "Git branch to use (main or nightly)")
|
|
||||||
clusterSecret := fs.String("cluster-secret", "", "Hex-encoded 32-byte cluster secret (for joining existing cluster)")
|
|
||||||
swarmKey := fs.String("swarm-key", "", "64-hex IPFS swarm key (for joining existing private network)")
|
|
||||||
ipfsPeerID := fs.String("ipfs-peer", "", "IPFS peer ID to connect to (auto-discovered from peer domain)")
|
|
||||||
ipfsAddrs := fs.String("ipfs-addrs", "", "Comma-separated IPFS swarm addresses (auto-discovered from peer domain)")
|
|
||||||
ipfsClusterPeerID := fs.String("ipfs-cluster-peer", "", "IPFS Cluster peer ID to connect to (auto-discovered from peer domain)")
|
|
||||||
ipfsClusterAddrs := fs.String("ipfs-cluster-addrs", "", "Comma-separated IPFS Cluster addresses (auto-discovered from peer domain)")
|
|
||||||
interactive := fs.Bool("interactive", false, "Run interactive TUI installer")
|
|
||||||
dryRun := fs.Bool("dry-run", false, "Show what would be done without making changes")
|
|
||||||
noPull := fs.Bool("no-pull", false, "Skip git clone/pull, use existing /home/debros/src")
|
|
||||||
|
|
||||||
if err := fs.Parse(args); err != nil {
|
|
||||||
if err == flag.ErrHelp {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Failed to parse flags: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Launch TUI installer if --interactive flag or no required args provided
|
|
||||||
if *interactive || (*vpsIP == "" && len(args) == 0) {
|
|
||||||
runInteractiveInstaller()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate branch
|
|
||||||
if *branch != "main" && *branch != "nightly" {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Invalid branch: %s (must be 'main' or 'nightly')\n", *branch)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Normalize and validate peers
|
|
||||||
peers, err := normalizePeers(*peersStr)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Invalid peers: %v\n", err)
|
|
||||||
fmt.Fprintf(os.Stderr, " Example: --peers /ip4/10.0.0.1/tcp/4001/p2p/Qm...,/ip4/10.0.0.2/tcp/4001/p2p/Qm...\n")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate setup requirements
|
|
||||||
if os.Geteuid() != 0 {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Production install must be run as root (use sudo)\n")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate VPS IP is provided
|
|
||||||
if *vpsIP == "" {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ --vps-ip is required\n")
|
|
||||||
fmt.Fprintf(os.Stderr, " Usage: sudo orama install --vps-ip <public_ip>\n")
|
|
||||||
fmt.Fprintf(os.Stderr, " Or run: sudo orama install --interactive\n")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine if this is the first node (creates new cluster) or joining existing cluster
|
|
||||||
isFirstNode := len(peers) == 0 && *joinAddress == ""
|
|
||||||
if isFirstNode {
|
|
||||||
fmt.Printf("ℹ️ First node detected - will create new cluster\n")
|
|
||||||
} else {
|
|
||||||
fmt.Printf("ℹ️ Joining existing cluster\n")
|
|
||||||
// Cluster secret is required when joining
|
|
||||||
if *clusterSecret == "" {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ --cluster-secret is required when joining an existing cluster\n")
|
|
||||||
fmt.Fprintf(os.Stderr, " Provide the 64-hex secret from an existing node (cat ~/.orama/secrets/cluster-secret)\n")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
if err := production.ValidateClusterSecret(*clusterSecret); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Invalid --cluster-secret: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
// Swarm key is required when joining
|
|
||||||
if *swarmKey == "" {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ --swarm-key is required when joining an existing cluster\n")
|
|
||||||
fmt.Fprintf(os.Stderr, " Provide the 64-hex swarm key from an existing node:\n")
|
|
||||||
fmt.Fprintf(os.Stderr, " cat ~/.orama/secrets/swarm.key | tail -1\n")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
if err := validateSwarmKey(*swarmKey); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Invalid --swarm-key: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
oramaHome := "/home/debros"
|
|
||||||
oramaDir := oramaHome + "/.orama"
|
|
||||||
|
|
||||||
// If cluster secret was provided, save it to secrets directory before setup
|
|
||||||
if *clusterSecret != "" {
|
|
||||||
secretsDir := filepath.Join(oramaDir, "secrets")
|
|
||||||
if err := os.MkdirAll(secretsDir, 0755); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Failed to create secrets directory: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
secretPath := filepath.Join(secretsDir, "cluster-secret")
|
|
||||||
if err := os.WriteFile(secretPath, []byte(*clusterSecret), 0600); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Failed to save cluster secret: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
fmt.Printf(" ✓ Cluster secret saved\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
// If swarm key was provided, save it to secrets directory in full format
|
|
||||||
if *swarmKey != "" {
|
|
||||||
secretsDir := filepath.Join(oramaDir, "secrets")
|
|
||||||
if err := os.MkdirAll(secretsDir, 0755); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Failed to create secrets directory: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
// Convert 64-hex key to full swarm.key format
|
|
||||||
swarmKeyContent := fmt.Sprintf("/key/swarm/psk/1.0.0/\n/base16/\n%s\n", strings.ToUpper(*swarmKey))
|
|
||||||
swarmKeyPath := filepath.Join(secretsDir, "swarm.key")
|
|
||||||
if err := os.WriteFile(swarmKeyPath, []byte(swarmKeyContent), 0600); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Failed to save swarm key: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
fmt.Printf(" ✓ Swarm key saved\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store IPFS peer info for later use in IPFS configuration
|
|
||||||
var ipfsPeerInfo *IPFSPeerInfo
|
|
||||||
if *ipfsPeerID != "" && *ipfsAddrs != "" {
|
|
||||||
ipfsPeerInfo = &IPFSPeerInfo{
|
|
||||||
PeerID: *ipfsPeerID,
|
|
||||||
Addrs: strings.Split(*ipfsAddrs, ","),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store IPFS Cluster peer info for cluster peer discovery
|
|
||||||
var ipfsClusterPeerInfo *IPFSClusterPeerInfo
|
|
||||||
if *ipfsClusterPeerID != "" {
|
|
||||||
var addrs []string
|
|
||||||
if *ipfsClusterAddrs != "" {
|
|
||||||
addrs = strings.Split(*ipfsClusterAddrs, ",")
|
|
||||||
}
|
|
||||||
ipfsClusterPeerInfo = &IPFSClusterPeerInfo{
|
|
||||||
PeerID: *ipfsClusterPeerID,
|
|
||||||
Addrs: addrs,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
setup := production.NewProductionSetup(oramaHome, os.Stdout, *force, *branch, *noPull, *skipResourceChecks)
|
|
||||||
|
|
||||||
// Inform user if skipping git pull
|
|
||||||
if *noPull {
|
|
||||||
fmt.Printf(" ⚠️ --no-pull flag enabled: Skipping git clone/pull\n")
|
|
||||||
fmt.Printf(" Using existing repository at /home/debros/src\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check port availability before proceeding
|
|
||||||
if err := ensurePortsAvailable("install", defaultPorts()); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate DNS if domain is provided
|
|
||||||
if *domain != "" {
|
|
||||||
fmt.Printf("\n🌐 Pre-flight DNS validation...\n")
|
|
||||||
validateDNSRecord(*domain, *vpsIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dry-run mode: show what would be done and exit
|
|
||||||
if *dryRun {
|
|
||||||
showDryRunSummary(*vpsIP, *domain, *branch, peers, *joinAddress, isFirstNode, oramaDir)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save branch preference for future upgrades
|
|
||||||
if err := production.SaveBranchPreference(oramaDir, *branch); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "⚠️ Warning: Failed to save branch preference: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Phase 1: Check prerequisites
|
|
||||||
fmt.Printf("\n📋 Phase 1: Checking prerequisites...\n")
|
|
||||||
if err := setup.Phase1CheckPrerequisites(); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Prerequisites check failed: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Phase 2: Provision environment
|
|
||||||
fmt.Printf("\n🛠️ Phase 2: Provisioning environment...\n")
|
|
||||||
if err := setup.Phase2ProvisionEnvironment(); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Environment provisioning failed: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Phase 2b: Install binaries
|
|
||||||
fmt.Printf("\nPhase 2b: Installing binaries...\n")
|
|
||||||
if err := setup.Phase2bInstallBinaries(); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Binary installation failed: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Phase 3: Generate secrets FIRST (before service initialization)
|
|
||||||
// This ensures cluster secret and swarm key exist before repos are seeded
|
|
||||||
fmt.Printf("\n🔐 Phase 3: Generating secrets...\n")
|
|
||||||
if err := setup.Phase3GenerateSecrets(); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Secret generation failed: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Phase 4: Generate configs (BEFORE service initialization)
|
|
||||||
// This ensures node.yaml exists before services try to access it
|
|
||||||
fmt.Printf("\n⚙️ Phase 4: Generating configurations...\n")
|
|
||||||
enableHTTPS := *domain != ""
|
|
||||||
if err := setup.Phase4GenerateConfigs(peers, *vpsIP, enableHTTPS, *domain, *joinAddress); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Configuration generation failed: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate generated configuration
|
|
||||||
fmt.Printf(" Validating generated configuration...\n")
|
|
||||||
if err := validateGeneratedConfig(oramaDir); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Configuration validation failed: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
fmt.Printf(" ✓ Configuration validated\n")
|
|
||||||
|
|
||||||
// Phase 2c: Initialize services (after config is in place)
|
|
||||||
fmt.Printf("\nPhase 2c: Initializing services...\n")
|
|
||||||
var prodIPFSPeer *production.IPFSPeerInfo
|
|
||||||
if ipfsPeerInfo != nil {
|
|
||||||
prodIPFSPeer = &production.IPFSPeerInfo{
|
|
||||||
PeerID: ipfsPeerInfo.PeerID,
|
|
||||||
Addrs: ipfsPeerInfo.Addrs,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var prodIPFSClusterPeer *production.IPFSClusterPeerInfo
|
|
||||||
if ipfsClusterPeerInfo != nil {
|
|
||||||
prodIPFSClusterPeer = &production.IPFSClusterPeerInfo{
|
|
||||||
PeerID: ipfsClusterPeerInfo.PeerID,
|
|
||||||
Addrs: ipfsClusterPeerInfo.Addrs,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := setup.Phase2cInitializeServices(peers, *vpsIP, prodIPFSPeer, prodIPFSClusterPeer); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Service initialization failed: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Phase 5: Create systemd services
|
|
||||||
fmt.Printf("\n🔧 Phase 5: Creating systemd services...\n")
|
|
||||||
if err := setup.Phase5CreateSystemdServices(enableHTTPS); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "❌ Service creation failed: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log completion with actual peer ID
|
|
||||||
setup.LogSetupComplete(setup.NodePeerID)
|
|
||||||
fmt.Printf("✅ Production installation complete!\n\n")
|
|
||||||
|
|
||||||
// For first node, print important secrets and identifiers
|
|
||||||
if isFirstNode {
|
|
||||||
fmt.Printf("📋 Save these for joining future nodes:\n\n")
|
|
||||||
|
|
||||||
// Print cluster secret
|
|
||||||
clusterSecretPath := filepath.Join(oramaDir, "secrets", "cluster-secret")
|
|
||||||
if clusterSecretData, err := os.ReadFile(clusterSecretPath); err == nil {
|
|
||||||
fmt.Printf(" Cluster Secret (--cluster-secret):\n")
|
|
||||||
fmt.Printf(" %s\n\n", string(clusterSecretData))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print swarm key
|
|
||||||
swarmKeyPath := filepath.Join(oramaDir, "secrets", "swarm.key")
|
|
||||||
if swarmKeyData, err := os.ReadFile(swarmKeyPath); err == nil {
|
|
||||||
swarmKeyContent := strings.TrimSpace(string(swarmKeyData))
|
|
||||||
lines := strings.Split(swarmKeyContent, "\n")
|
|
||||||
if len(lines) >= 3 {
|
|
||||||
// Extract just the hex part (last line)
|
|
||||||
fmt.Printf(" IPFS Swarm Key (--swarm-key, last line only):\n")
|
|
||||||
fmt.Printf(" %s\n\n", lines[len(lines)-1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print peer ID
|
|
||||||
fmt.Printf(" Node Peer ID:\n")
|
|
||||||
fmt.Printf(" %s\n\n", setup.NodePeerID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleProdUpgrade(args []string) {
|
func handleProdUpgrade(args []string) {
|
||||||
// Parse arguments using flag.FlagSet
|
// Parse arguments using flag.FlagSet
|
||||||
fs := flag.NewFlagSet("upgrade", flag.ContinueOnError)
|
fs := flag.NewFlagSet("upgrade", flag.ContinueOnError)
|
||||||
@ -767,7 +220,7 @@ func handleProdUpgrade(args []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check port availability after stopping services
|
// Check port availability after stopping services
|
||||||
if err := ensurePortsAvailable("prod upgrade", defaultPorts()); err != nil {
|
if err := utils.EnsurePortsAvailable("prod upgrade", utils.DefaultPorts()); err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
@ -945,7 +398,7 @@ func handleProdUpgrade(args []string) {
|
|||||||
fmt.Fprintf(os.Stderr, " ⚠️ Warning: Failed to reload systemd daemon: %v\n", err)
|
fmt.Fprintf(os.Stderr, " ⚠️ Warning: Failed to reload systemd daemon: %v\n", err)
|
||||||
}
|
}
|
||||||
// Restart services to apply changes - use getProductionServices to only restart existing services
|
// Restart services to apply changes - use getProductionServices to only restart existing services
|
||||||
services := getProductionServices()
|
services := utils.GetProductionServices()
|
||||||
if len(services) == 0 {
|
if len(services) == 0 {
|
||||||
fmt.Printf(" ⚠️ No services found to restart\n")
|
fmt.Printf(" ⚠️ No services found to restart\n")
|
||||||
} else {
|
} else {
|
||||||
@ -991,10 +444,9 @@ func handleProdStatus() {
|
|||||||
fmt.Printf("Services:\n")
|
fmt.Printf("Services:\n")
|
||||||
found := false
|
found := false
|
||||||
for _, svc := range serviceNames {
|
for _, svc := range serviceNames {
|
||||||
cmd := exec.Command("systemctl", "is-active", "--quiet", svc)
|
active, _ := utils.IsServiceActive(svc)
|
||||||
err := cmd.Run()
|
|
||||||
status := "❌ Inactive"
|
status := "❌ Inactive"
|
||||||
if err == nil {
|
if active {
|
||||||
status = "✅ Active"
|
status = "✅ Active"
|
||||||
found = true
|
found = true
|
||||||
}
|
}
|
||||||
@ -1016,52 +468,6 @@ func handleProdStatus() {
|
|||||||
fmt.Printf("\nView logs with: dbn prod logs <service>\n")
|
fmt.Printf("\nView logs with: dbn prod logs <service>\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolveServiceName resolves service aliases to actual systemd service names
|
|
||||||
func resolveServiceName(alias string) ([]string, error) {
|
|
||||||
// Service alias mapping (unified - no bootstrap/node distinction)
|
|
||||||
aliases := map[string][]string{
|
|
||||||
"node": {"debros-node"},
|
|
||||||
"ipfs": {"debros-ipfs"},
|
|
||||||
"cluster": {"debros-ipfs-cluster"},
|
|
||||||
"ipfs-cluster": {"debros-ipfs-cluster"},
|
|
||||||
"gateway": {"debros-gateway"},
|
|
||||||
"olric": {"debros-olric"},
|
|
||||||
"rqlite": {"debros-node"}, // RQLite logs are in node logs
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if it's an alias
|
|
||||||
if serviceNames, ok := aliases[strings.ToLower(alias)]; ok {
|
|
||||||
// Filter to only existing services
|
|
||||||
var existing []string
|
|
||||||
for _, svc := range serviceNames {
|
|
||||||
unitPath := filepath.Join("/etc/systemd/system", svc+".service")
|
|
||||||
if _, err := os.Stat(unitPath); err == nil {
|
|
||||||
existing = append(existing, svc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(existing) == 0 {
|
|
||||||
return nil, fmt.Errorf("no services found for alias %q", alias)
|
|
||||||
}
|
|
||||||
return existing, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if it's already a full service name
|
|
||||||
unitPath := filepath.Join("/etc/systemd/system", alias+".service")
|
|
||||||
if _, err := os.Stat(unitPath); err == nil {
|
|
||||||
return []string{alias}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try without .service suffix
|
|
||||||
if !strings.HasSuffix(alias, ".service") {
|
|
||||||
unitPath = filepath.Join("/etc/systemd/system", alias+".service")
|
|
||||||
if _, err := os.Stat(unitPath); err == nil {
|
|
||||||
return []string{alias}, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("service %q not found. Use: node, ipfs, cluster, gateway, olric, or full service name", alias)
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleProdLogs(args []string) {
|
func handleProdLogs(args []string) {
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
fmt.Fprintf(os.Stderr, "Usage: dbn prod logs <service> [--follow]\n")
|
fmt.Fprintf(os.Stderr, "Usage: dbn prod logs <service> [--follow]\n")
|
||||||
@ -1079,7 +485,7 @@ func handleProdLogs(args []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Resolve service alias to actual service names
|
// Resolve service alias to actual service names
|
||||||
serviceNames, err := resolveServiceName(serviceAlias)
|
serviceNames, err := utils.ResolveServiceName(serviceAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
||||||
fmt.Fprintf(os.Stderr, "\nAvailable service aliases: node, ipfs, cluster, gateway, olric\n")
|
fmt.Fprintf(os.Stderr, "\nAvailable service aliases: node, ipfs, cluster, gateway, olric\n")
|
||||||
@ -1109,7 +515,7 @@ func handleProdLogs(args []string) {
|
|||||||
} else {
|
} else {
|
||||||
for i, svc := range serviceNames {
|
for i, svc := range serviceNames {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
fmt.Printf("\n" + strings.Repeat("=", 70) + "\n\n")
|
fmt.Print("\n" + strings.Repeat("=", 70) + "\n\n")
|
||||||
}
|
}
|
||||||
fmt.Printf("📋 Logs for %s:\n\n", svc)
|
fmt.Printf("📋 Logs for %s:\n\n", svc)
|
||||||
cmd := exec.Command("journalctl", "-u", svc, "-n", "50")
|
cmd := exec.Command("journalctl", "-u", svc, "-n", "50")
|
||||||
@ -1138,145 +544,6 @@ func handleProdLogs(args []string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// errServiceNotFound marks units that systemd does not know about.
|
|
||||||
var errServiceNotFound = errors.New("service not found")
|
|
||||||
|
|
||||||
type portSpec struct {
|
|
||||||
Name string
|
|
||||||
Port int
|
|
||||||
}
|
|
||||||
|
|
||||||
var servicePorts = map[string][]portSpec{
|
|
||||||
"debros-gateway": {{"Gateway API", 6001}},
|
|
||||||
"debros-olric": {{"Olric HTTP", 3320}, {"Olric Memberlist", 3322}},
|
|
||||||
"debros-node": {{"RQLite HTTP", 5001}, {"RQLite Raft", 7001}},
|
|
||||||
"debros-ipfs": {{"IPFS API", 4501}, {"IPFS Gateway", 8080}, {"IPFS Swarm", 4101}},
|
|
||||||
"debros-ipfs-cluster": {{"IPFS Cluster API", 9094}},
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultPorts is used for fresh installs/upgrades before unit files exist.
|
|
||||||
func defaultPorts() []portSpec {
|
|
||||||
return []portSpec{
|
|
||||||
{"IPFS Swarm", 4001},
|
|
||||||
{"IPFS API", 4501},
|
|
||||||
{"IPFS Gateway", 8080},
|
|
||||||
{"Gateway API", 6001},
|
|
||||||
{"RQLite HTTP", 5001},
|
|
||||||
{"RQLite Raft", 7001},
|
|
||||||
{"IPFS Cluster API", 9094},
|
|
||||||
{"Olric HTTP", 3320},
|
|
||||||
{"Olric Memberlist", 3322},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func isServiceActive(service string) (bool, error) {
|
|
||||||
cmd := exec.Command("systemctl", "is-active", "--quiet", service)
|
|
||||||
if err := cmd.Run(); err != nil {
|
|
||||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
|
||||||
switch exitErr.ExitCode() {
|
|
||||||
case 3:
|
|
||||||
return false, nil
|
|
||||||
case 4:
|
|
||||||
return false, errServiceNotFound
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isServiceEnabled(service string) (bool, error) {
|
|
||||||
cmd := exec.Command("systemctl", "is-enabled", "--quiet", service)
|
|
||||||
if err := cmd.Run(); err != nil {
|
|
||||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
|
||||||
switch exitErr.ExitCode() {
|
|
||||||
case 1:
|
|
||||||
return false, nil // Service is disabled
|
|
||||||
case 4:
|
|
||||||
return false, errServiceNotFound
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func collectPortsForServices(services []string, skipActive bool) ([]portSpec, error) {
|
|
||||||
seen := make(map[int]portSpec)
|
|
||||||
for _, svc := range services {
|
|
||||||
if skipActive {
|
|
||||||
active, err := isServiceActive(svc)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to check %s: %w", svc, err)
|
|
||||||
}
|
|
||||||
if active {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, spec := range servicePorts[svc] {
|
|
||||||
if _, ok := seen[spec.Port]; !ok {
|
|
||||||
seen[spec.Port] = spec
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ports := make([]portSpec, 0, len(seen))
|
|
||||||
for _, spec := range seen {
|
|
||||||
ports = append(ports, spec)
|
|
||||||
}
|
|
||||||
return ports, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ensurePortsAvailable(action string, ports []portSpec) error {
|
|
||||||
for _, spec := range ports {
|
|
||||||
ln, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", spec.Port))
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, syscall.EADDRINUSE) || strings.Contains(err.Error(), "address already in use") {
|
|
||||||
return fmt.Errorf("%s cannot continue: %s (port %d) is already in use", action, spec.Name, spec.Port)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("%s cannot continue: failed to inspect %s (port %d): %w", action, spec.Name, spec.Port, err)
|
|
||||||
}
|
|
||||||
_ = ln.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getProductionServices returns a list of all DeBros production service names that exist
|
|
||||||
func getProductionServices() []string {
|
|
||||||
// Unified service names (no bootstrap/node distinction)
|
|
||||||
allServices := []string{
|
|
||||||
"debros-gateway",
|
|
||||||
"debros-node",
|
|
||||||
"debros-olric",
|
|
||||||
"debros-ipfs-cluster",
|
|
||||||
"debros-ipfs",
|
|
||||||
"debros-anyone-client",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter to only existing services by checking if unit file exists
|
|
||||||
var existing []string
|
|
||||||
for _, svc := range allServices {
|
|
||||||
unitPath := filepath.Join("/etc/systemd/system", svc+".service")
|
|
||||||
if _, err := os.Stat(unitPath); err == nil {
|
|
||||||
existing = append(existing, svc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return existing
|
|
||||||
}
|
|
||||||
|
|
||||||
func isServiceMasked(service string) (bool, error) {
|
|
||||||
cmd := exec.Command("systemctl", "is-enabled", service)
|
|
||||||
output, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
outputStr := string(output)
|
|
||||||
if strings.Contains(outputStr, "masked") {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleProdStart() {
|
func handleProdStart() {
|
||||||
if os.Geteuid() != 0 {
|
if os.Geteuid() != 0 {
|
||||||
fmt.Fprintf(os.Stderr, "❌ Production commands must be run as root (use sudo)\n")
|
fmt.Fprintf(os.Stderr, "❌ Production commands must be run as root (use sudo)\n")
|
||||||
@ -1285,7 +552,7 @@ func handleProdStart() {
|
|||||||
|
|
||||||
fmt.Printf("Starting all DeBros production services...\n")
|
fmt.Printf("Starting all DeBros production services...\n")
|
||||||
|
|
||||||
services := getProductionServices()
|
services := utils.GetProductionServices()
|
||||||
if len(services) == 0 {
|
if len(services) == 0 {
|
||||||
fmt.Printf(" ⚠️ No DeBros services found\n")
|
fmt.Printf(" ⚠️ No DeBros services found\n")
|
||||||
return
|
return
|
||||||
@ -1301,7 +568,7 @@ func handleProdStart() {
|
|||||||
inactive := make([]string, 0, len(services))
|
inactive := make([]string, 0, len(services))
|
||||||
for _, svc := range services {
|
for _, svc := range services {
|
||||||
// Check if service is masked and unmask it
|
// Check if service is masked and unmask it
|
||||||
masked, err := isServiceMasked(svc)
|
masked, err := utils.IsServiceMasked(svc)
|
||||||
if err == nil && masked {
|
if err == nil && masked {
|
||||||
fmt.Printf(" ⚠️ %s is masked, unmasking...\n", svc)
|
fmt.Printf(" ⚠️ %s is masked, unmasking...\n", svc)
|
||||||
if err := exec.Command("systemctl", "unmask", svc).Run(); err != nil {
|
if err := exec.Command("systemctl", "unmask", svc).Run(); err != nil {
|
||||||
@ -1311,7 +578,7 @@ func handleProdStart() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
active, err := isServiceActive(svc)
|
active, err := utils.IsServiceActive(svc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf(" ⚠️ Unable to check %s: %v\n", svc, err)
|
fmt.Printf(" ⚠️ Unable to check %s: %v\n", svc, err)
|
||||||
continue
|
continue
|
||||||
@ -1319,7 +586,7 @@ func handleProdStart() {
|
|||||||
if active {
|
if active {
|
||||||
fmt.Printf(" ℹ️ %s already running\n", svc)
|
fmt.Printf(" ℹ️ %s already running\n", svc)
|
||||||
// Re-enable if disabled (in case it was stopped with 'dbn prod stop')
|
// Re-enable if disabled (in case it was stopped with 'dbn prod stop')
|
||||||
enabled, err := isServiceEnabled(svc)
|
enabled, err := utils.IsServiceEnabled(svc)
|
||||||
if err == nil && !enabled {
|
if err == nil && !enabled {
|
||||||
if err := exec.Command("systemctl", "enable", svc).Run(); err != nil {
|
if err := exec.Command("systemctl", "enable", svc).Run(); err != nil {
|
||||||
fmt.Printf(" ⚠️ Failed to re-enable %s: %v\n", svc, err)
|
fmt.Printf(" ⚠️ Failed to re-enable %s: %v\n", svc, err)
|
||||||
@ -1338,12 +605,12 @@ func handleProdStart() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check port availability for services we're about to start
|
// Check port availability for services we're about to start
|
||||||
ports, err := collectPortsForServices(inactive, false)
|
ports, err := utils.CollectPortsForServices(inactive, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
if err := ensurePortsAvailable("prod start", ports); err != nil {
|
if err := utils.EnsurePortsAvailable("prod start", ports); err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
@ -1351,7 +618,7 @@ func handleProdStart() {
|
|||||||
// Enable and start inactive services
|
// Enable and start inactive services
|
||||||
for _, svc := range inactive {
|
for _, svc := range inactive {
|
||||||
// Re-enable the service first (in case it was disabled by 'dbn prod stop')
|
// Re-enable the service first (in case it was disabled by 'dbn prod stop')
|
||||||
enabled, err := isServiceEnabled(svc)
|
enabled, err := utils.IsServiceEnabled(svc)
|
||||||
if err == nil && !enabled {
|
if err == nil && !enabled {
|
||||||
if err := exec.Command("systemctl", "enable", svc).Run(); err != nil {
|
if err := exec.Command("systemctl", "enable", svc).Run(); err != nil {
|
||||||
fmt.Printf(" ⚠️ Failed to enable %s: %v\n", svc, err)
|
fmt.Printf(" ⚠️ Failed to enable %s: %v\n", svc, err)
|
||||||
@ -1385,7 +652,7 @@ func handleProdStop() {
|
|||||||
|
|
||||||
fmt.Printf("Stopping all DeBros production services...\n")
|
fmt.Printf("Stopping all DeBros production services...\n")
|
||||||
|
|
||||||
services := getProductionServices()
|
services := utils.GetProductionServices()
|
||||||
if len(services) == 0 {
|
if len(services) == 0 {
|
||||||
fmt.Printf(" ⚠️ No DeBros services found\n")
|
fmt.Printf(" ⚠️ No DeBros services found\n")
|
||||||
return
|
return
|
||||||
@ -1424,7 +691,7 @@ func handleProdStop() {
|
|||||||
|
|
||||||
hadError := false
|
hadError := false
|
||||||
for _, svc := range services {
|
for _, svc := range services {
|
||||||
active, err := isServiceActive(svc)
|
active, err := utils.IsServiceActive(svc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf(" ⚠️ Unable to check %s: %v\n", svc, err)
|
fmt.Printf(" ⚠️ Unable to check %s: %v\n", svc, err)
|
||||||
hadError = true
|
hadError = true
|
||||||
@ -1441,7 +708,7 @@ func handleProdStop() {
|
|||||||
} else {
|
} else {
|
||||||
// Wait and verify again
|
// Wait and verify again
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
if stillActive, _ := isServiceActive(svc); stillActive {
|
if stillActive, _ := utils.IsServiceActive(svc); stillActive {
|
||||||
fmt.Printf(" ❌ %s restarted itself (Restart=always)\n", svc)
|
fmt.Printf(" ❌ %s restarted itself (Restart=always)\n", svc)
|
||||||
hadError = true
|
hadError = true
|
||||||
} else {
|
} else {
|
||||||
@ -1451,7 +718,7 @@ func handleProdStop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Disable the service to prevent it from auto-starting on boot
|
// Disable the service to prevent it from auto-starting on boot
|
||||||
enabled, err := isServiceEnabled(svc)
|
enabled, err := utils.IsServiceEnabled(svc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf(" ⚠️ Unable to check if %s is enabled: %v\n", svc, err)
|
fmt.Printf(" ⚠️ Unable to check if %s is enabled: %v\n", svc, err)
|
||||||
// Continue anyway - try to disable
|
// Continue anyway - try to disable
|
||||||
@ -1486,7 +753,7 @@ func handleProdRestart() {
|
|||||||
|
|
||||||
fmt.Printf("Restarting all DeBros production services...\n")
|
fmt.Printf("Restarting all DeBros production services...\n")
|
||||||
|
|
||||||
services := getProductionServices()
|
services := utils.GetProductionServices()
|
||||||
if len(services) == 0 {
|
if len(services) == 0 {
|
||||||
fmt.Printf(" ⚠️ No DeBros services found\n")
|
fmt.Printf(" ⚠️ No DeBros services found\n")
|
||||||
return
|
return
|
||||||
@ -1495,7 +762,7 @@ func handleProdRestart() {
|
|||||||
// Stop all active services first
|
// Stop all active services first
|
||||||
fmt.Printf(" Stopping services...\n")
|
fmt.Printf(" Stopping services...\n")
|
||||||
for _, svc := range services {
|
for _, svc := range services {
|
||||||
active, err := isServiceActive(svc)
|
active, err := utils.IsServiceActive(svc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf(" ⚠️ Unable to check %s: %v\n", svc, err)
|
fmt.Printf(" ⚠️ Unable to check %s: %v\n", svc, err)
|
||||||
continue
|
continue
|
||||||
@ -1512,12 +779,12 @@ func handleProdRestart() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check port availability before restarting
|
// Check port availability before restarting
|
||||||
ports, err := collectPortsForServices(services, false)
|
ports, err := utils.CollectPortsForServices(services, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
if err := ensurePortsAvailable("prod restart", ports); err != nil {
|
if err := utils.EnsurePortsAvailable("prod restart", ports); err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,6 +2,8 @@ package cli
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/cli/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestProdCommandFlagParsing verifies that prod command flags are parsed correctly
|
// TestProdCommandFlagParsing verifies that prod command flags are parsed correctly
|
||||||
@ -156,7 +158,7 @@ func TestNormalizePeers(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
peers, err := normalizePeers(tt.input)
|
peers, err := utils.NormalizePeers(tt.input)
|
||||||
|
|
||||||
if tt.expectError && err == nil {
|
if tt.expectError && err == nil {
|
||||||
t.Errorf("expected error but got none")
|
t.Errorf("expected error but got none")
|
||||||
|
|||||||
264
pkg/cli/prod_install.go
Normal file
264
pkg/cli/prod_install.go
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/cli/utils"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/environments/production"
|
||||||
|
)
|
||||||
|
|
||||||
|
func handleProdInstall(args []string) {
|
||||||
|
// Parse arguments using flag.FlagSet
|
||||||
|
fs := flag.NewFlagSet("install", flag.ContinueOnError)
|
||||||
|
fs.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
vpsIP := fs.String("vps-ip", "", "Public IP of this VPS (required)")
|
||||||
|
domain := fs.String("domain", "", "Domain name for HTTPS (optional, e.g. gateway.example.com)")
|
||||||
|
branch := fs.String("branch", "main", "Git branch to use (main or nightly)")
|
||||||
|
noPull := fs.Bool("no-pull", false, "Skip git clone/pull, use existing repository in /home/debros/src")
|
||||||
|
force := fs.Bool("force", false, "Force reconfiguration even if already installed")
|
||||||
|
dryRun := fs.Bool("dry-run", false, "Show what would be done without making changes")
|
||||||
|
skipResourceChecks := fs.Bool("skip-checks", false, "Skip minimum resource checks (RAM/CPU)")
|
||||||
|
|
||||||
|
// Cluster join flags
|
||||||
|
joinAddress := fs.String("join", "", "Join an existing cluster (e.g. 1.2.3.4:7001)")
|
||||||
|
clusterSecret := fs.String("cluster-secret", "", "Cluster secret for IPFS Cluster (required if joining)")
|
||||||
|
swarmKey := fs.String("swarm-key", "", "IPFS Swarm key (required if joining)")
|
||||||
|
peersStr := fs.String("peers", "", "Comma-separated list of bootstrap peer multiaddrs")
|
||||||
|
|
||||||
|
// IPFS/Cluster specific info for Peering configuration
|
||||||
|
ipfsPeerID := fs.String("ipfs-peer", "", "Peer ID of existing IPFS node to peer with")
|
||||||
|
ipfsAddrs := fs.String("ipfs-addrs", "", "Comma-separated multiaddrs of existing IPFS node")
|
||||||
|
ipfsClusterPeerID := fs.String("ipfs-cluster-peer", "", "Peer ID of existing IPFS Cluster node")
|
||||||
|
ipfsClusterAddrs := fs.String("ipfs-cluster-addrs", "", "Comma-separated multiaddrs of existing IPFS Cluster node")
|
||||||
|
|
||||||
|
if err := fs.Parse(args); err != nil {
|
||||||
|
if err == flag.ErrHelp {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Failed to parse flags: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate required flags
|
||||||
|
if *vpsIP == "" && !*dryRun {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Error: --vps-ip is required for installation\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " Example: dbn prod install --vps-ip 1.2.3.4\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if os.Geteuid() != 0 && !*dryRun {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Production installation must be run as root (use sudo)\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
oramaHome := "/home/debros"
|
||||||
|
oramaDir := oramaHome + "/.orama"
|
||||||
|
fmt.Printf("🚀 Starting production installation...\n\n")
|
||||||
|
|
||||||
|
isFirstNode := *joinAddress == ""
|
||||||
|
peers, err := utils.NormalizePeers(*peersStr)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Invalid peers: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If cluster secret was provided, save it to secrets directory before setup
|
||||||
|
if *clusterSecret != "" {
|
||||||
|
secretsDir := filepath.Join(oramaDir, "secrets")
|
||||||
|
if err := os.MkdirAll(secretsDir, 0755); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Failed to create secrets directory: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
secretPath := filepath.Join(secretsDir, "cluster-secret")
|
||||||
|
if err := os.WriteFile(secretPath, []byte(*clusterSecret), 0600); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Failed to save cluster secret: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Printf(" ✓ Cluster secret saved\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If swarm key was provided, save it to secrets directory in full format
|
||||||
|
if *swarmKey != "" {
|
||||||
|
secretsDir := filepath.Join(oramaDir, "secrets")
|
||||||
|
if err := os.MkdirAll(secretsDir, 0755); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Failed to create secrets directory: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
// Convert 64-hex key to full swarm.key format
|
||||||
|
swarmKeyContent := fmt.Sprintf("/key/swarm/psk/1.0.0/\n/base16/\n%s\n", strings.ToUpper(*swarmKey))
|
||||||
|
swarmKeyPath := filepath.Join(secretsDir, "swarm.key")
|
||||||
|
if err := os.WriteFile(swarmKeyPath, []byte(swarmKeyContent), 0600); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Failed to save swarm key: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Printf(" ✓ Swarm key saved\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store IPFS peer info for peering
|
||||||
|
var ipfsPeerInfo *utils.IPFSPeerInfo
|
||||||
|
if *ipfsPeerID != "" {
|
||||||
|
var addrs []string
|
||||||
|
if *ipfsAddrs != "" {
|
||||||
|
addrs = strings.Split(*ipfsAddrs, ",")
|
||||||
|
}
|
||||||
|
ipfsPeerInfo = &utils.IPFSPeerInfo{
|
||||||
|
PeerID: *ipfsPeerID,
|
||||||
|
Addrs: addrs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store IPFS Cluster peer info for cluster peer discovery
|
||||||
|
var ipfsClusterPeerInfo *utils.IPFSClusterPeerInfo
|
||||||
|
if *ipfsClusterPeerID != "" {
|
||||||
|
var addrs []string
|
||||||
|
if *ipfsClusterAddrs != "" {
|
||||||
|
addrs = strings.Split(*ipfsClusterAddrs, ",")
|
||||||
|
}
|
||||||
|
ipfsClusterPeerInfo = &utils.IPFSClusterPeerInfo{
|
||||||
|
PeerID: *ipfsClusterPeerID,
|
||||||
|
Addrs: addrs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setup := production.NewProductionSetup(oramaHome, os.Stdout, *force, *branch, *noPull, *skipResourceChecks)
|
||||||
|
|
||||||
|
// Inform user if skipping git pull
|
||||||
|
if *noPull {
|
||||||
|
fmt.Printf(" ⚠️ --no-pull flag enabled: Skipping git clone/pull\n")
|
||||||
|
fmt.Printf(" Using existing repository at /home/debros/src\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check port availability before proceeding
|
||||||
|
if err := utils.EnsurePortsAvailable("install", utils.DefaultPorts()); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate DNS if domain is provided
|
||||||
|
if *domain != "" {
|
||||||
|
fmt.Printf("\n🌐 Pre-flight DNS validation...\n")
|
||||||
|
utils.ValidateDNSRecord(*domain, *vpsIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dry-run mode: show what would be done and exit
|
||||||
|
if *dryRun {
|
||||||
|
utils.ShowDryRunSummary(*vpsIP, *domain, *branch, peers, *joinAddress, isFirstNode, oramaDir)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save branch preference for future upgrades
|
||||||
|
if err := production.SaveBranchPreference(oramaDir, *branch); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "⚠️ Warning: Failed to save branch preference: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 1: Check prerequisites
|
||||||
|
fmt.Printf("\n📋 Phase 1: Checking prerequisites...\n")
|
||||||
|
if err := setup.Phase1CheckPrerequisites(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Prerequisites check failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 2: Provision environment
|
||||||
|
fmt.Printf("\n🛠️ Phase 2: Provisioning environment...\n")
|
||||||
|
if err := setup.Phase2ProvisionEnvironment(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Environment provisioning failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 2b: Install binaries
|
||||||
|
fmt.Printf("\nPhase 2b: Installing binaries...\n")
|
||||||
|
if err := setup.Phase2bInstallBinaries(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Binary installation failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 3: Generate secrets FIRST (before service initialization)
|
||||||
|
// This ensures cluster secret and swarm key exist before repos are seeded
|
||||||
|
fmt.Printf("\n🔐 Phase 3: Generating secrets...\n")
|
||||||
|
if err := setup.Phase3GenerateSecrets(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Secret generation failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 4: Generate configs (BEFORE service initialization)
|
||||||
|
// This ensures node.yaml exists before services try to access it
|
||||||
|
fmt.Printf("\n⚙️ Phase 4: Generating configurations...\n")
|
||||||
|
enableHTTPS := *domain != ""
|
||||||
|
if err := setup.Phase4GenerateConfigs(peers, *vpsIP, enableHTTPS, *domain, *joinAddress); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Configuration generation failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate generated configuration
|
||||||
|
fmt.Printf(" Validating generated configuration...\n")
|
||||||
|
if err := utils.ValidateGeneratedConfig(oramaDir); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Configuration validation failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Printf(" ✓ Configuration validated\n")
|
||||||
|
|
||||||
|
// Phase 2c: Initialize services (after config is in place)
|
||||||
|
fmt.Printf("\nPhase 2c: Initializing services...\n")
|
||||||
|
var prodIPFSPeer *production.IPFSPeerInfo
|
||||||
|
if ipfsPeerInfo != nil {
|
||||||
|
prodIPFSPeer = &production.IPFSPeerInfo{
|
||||||
|
PeerID: ipfsPeerInfo.PeerID,
|
||||||
|
Addrs: ipfsPeerInfo.Addrs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var prodIPFSClusterPeer *production.IPFSClusterPeerInfo
|
||||||
|
if ipfsClusterPeerInfo != nil {
|
||||||
|
prodIPFSClusterPeer = &production.IPFSClusterPeerInfo{
|
||||||
|
PeerID: ipfsClusterPeerInfo.PeerID,
|
||||||
|
Addrs: ipfsClusterPeerInfo.Addrs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := setup.Phase2cInitializeServices(peers, *vpsIP, prodIPFSPeer, prodIPFSClusterPeer); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Service initialization failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 5: Create systemd services
|
||||||
|
fmt.Printf("\n🔧 Phase 5: Creating systemd services...\n")
|
||||||
|
if err := setup.Phase5CreateSystemdServices(enableHTTPS); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "❌ Service creation failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log completion with actual peer ID
|
||||||
|
setup.LogSetupComplete(setup.NodePeerID)
|
||||||
|
fmt.Printf("✅ Production installation complete!\n\n")
|
||||||
|
|
||||||
|
// For first node, print important secrets and identifiers
|
||||||
|
if isFirstNode {
|
||||||
|
fmt.Printf("📋 Save these for joining future nodes:\n\n")
|
||||||
|
|
||||||
|
// Print cluster secret
|
||||||
|
clusterSecretPath := filepath.Join(oramaDir, "secrets", "cluster-secret")
|
||||||
|
if clusterSecretData, err := os.ReadFile(clusterSecretPath); err == nil {
|
||||||
|
fmt.Printf(" Cluster Secret (--cluster-secret):\n")
|
||||||
|
fmt.Printf(" %s\n\n", string(clusterSecretData))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print swarm key
|
||||||
|
swarmKeyPath := filepath.Join(oramaDir, "secrets", "swarm.key")
|
||||||
|
if swarmKeyData, err := os.ReadFile(swarmKeyPath); err == nil {
|
||||||
|
swarmKeyContent := strings.TrimSpace(string(swarmKeyData))
|
||||||
|
lines := strings.Split(swarmKeyContent, "\n")
|
||||||
|
if len(lines) >= 3 {
|
||||||
|
// Extract just the hex part (last line)
|
||||||
|
fmt.Printf(" IPFS Swarm Key (--swarm-key, last line only):\n")
|
||||||
|
fmt.Printf(" %s\n\n", lines[len(lines)-1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print peer ID
|
||||||
|
fmt.Printf(" Node Peer ID:\n")
|
||||||
|
fmt.Printf(" %s\n\n", setup.NodePeerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
97
pkg/cli/utils/install.go
Normal file
97
pkg/cli/utils/install.go
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IPFSPeerInfo holds IPFS peer information for configuring Peering.Peers
|
||||||
|
type IPFSPeerInfo struct {
|
||||||
|
PeerID string
|
||||||
|
Addrs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPFSClusterPeerInfo contains IPFS Cluster peer information for cluster discovery
|
||||||
|
type IPFSClusterPeerInfo struct {
|
||||||
|
PeerID string
|
||||||
|
Addrs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShowDryRunSummary displays what would be done during installation without making changes
|
||||||
|
func ShowDryRunSummary(vpsIP, domain, branch string, peers []string, joinAddress string, isFirstNode bool, oramaDir string) {
|
||||||
|
fmt.Print("\n" + strings.Repeat("=", 70) + "\n")
|
||||||
|
fmt.Printf("DRY RUN - No changes will be made\n")
|
||||||
|
fmt.Print(strings.Repeat("=", 70) + "\n\n")
|
||||||
|
|
||||||
|
fmt.Printf("📋 Installation Summary:\n")
|
||||||
|
fmt.Printf(" VPS IP: %s\n", vpsIP)
|
||||||
|
fmt.Printf(" Domain: %s\n", domain)
|
||||||
|
fmt.Printf(" Branch: %s\n", branch)
|
||||||
|
if isFirstNode {
|
||||||
|
fmt.Printf(" Node Type: First node (creates new cluster)\n")
|
||||||
|
} else {
|
||||||
|
fmt.Printf(" Node Type: Joining existing cluster\n")
|
||||||
|
if joinAddress != "" {
|
||||||
|
fmt.Printf(" Join Address: %s\n", joinAddress)
|
||||||
|
}
|
||||||
|
if len(peers) > 0 {
|
||||||
|
fmt.Printf(" Peers: %d peer(s)\n", len(peers))
|
||||||
|
for _, peer := range peers {
|
||||||
|
fmt.Printf(" - %s\n", peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("\n📁 Directories that would be created:\n")
|
||||||
|
fmt.Printf(" %s/configs/\n", oramaDir)
|
||||||
|
fmt.Printf(" %s/secrets/\n", oramaDir)
|
||||||
|
fmt.Printf(" %s/data/ipfs/repo/\n", oramaDir)
|
||||||
|
fmt.Printf(" %s/data/ipfs-cluster/\n", oramaDir)
|
||||||
|
fmt.Printf(" %s/data/rqlite/\n", oramaDir)
|
||||||
|
fmt.Printf(" %s/logs/\n", oramaDir)
|
||||||
|
fmt.Printf(" %s/tls-cache/\n", oramaDir)
|
||||||
|
|
||||||
|
fmt.Printf("\n🔧 Binaries that would be installed:\n")
|
||||||
|
fmt.Printf(" - Go (if not present)\n")
|
||||||
|
fmt.Printf(" - RQLite 8.43.0\n")
|
||||||
|
fmt.Printf(" - IPFS/Kubo 0.38.2\n")
|
||||||
|
fmt.Printf(" - IPFS Cluster (latest)\n")
|
||||||
|
fmt.Printf(" - Olric 0.7.0\n")
|
||||||
|
fmt.Printf(" - anyone-client (npm)\n")
|
||||||
|
fmt.Printf(" - DeBros binaries (built from %s branch)\n", branch)
|
||||||
|
|
||||||
|
fmt.Printf("\n🔐 Secrets that would be generated:\n")
|
||||||
|
fmt.Printf(" - Cluster secret (64-hex)\n")
|
||||||
|
fmt.Printf(" - IPFS swarm key\n")
|
||||||
|
fmt.Printf(" - Node identity (Ed25519 keypair)\n")
|
||||||
|
|
||||||
|
fmt.Printf("\n📝 Configuration files that would be created:\n")
|
||||||
|
fmt.Printf(" - %s/configs/node.yaml\n", oramaDir)
|
||||||
|
fmt.Printf(" - %s/configs/olric/config.yaml\n", oramaDir)
|
||||||
|
|
||||||
|
fmt.Printf("\n⚙️ Systemd services that would be created:\n")
|
||||||
|
fmt.Printf(" - debros-ipfs.service\n")
|
||||||
|
fmt.Printf(" - debros-ipfs-cluster.service\n")
|
||||||
|
fmt.Printf(" - debros-olric.service\n")
|
||||||
|
fmt.Printf(" - debros-node.service (includes embedded gateway + RQLite)\n")
|
||||||
|
fmt.Printf(" - debros-anyone-client.service\n")
|
||||||
|
|
||||||
|
fmt.Printf("\n🌐 Ports that would be used:\n")
|
||||||
|
fmt.Printf(" External (must be open in firewall):\n")
|
||||||
|
fmt.Printf(" - 80 (HTTP for ACME/Let's Encrypt)\n")
|
||||||
|
fmt.Printf(" - 443 (HTTPS gateway)\n")
|
||||||
|
fmt.Printf(" - 4101 (IPFS swarm)\n")
|
||||||
|
fmt.Printf(" - 7001 (RQLite Raft)\n")
|
||||||
|
fmt.Printf(" Internal (localhost only):\n")
|
||||||
|
fmt.Printf(" - 4501 (IPFS API)\n")
|
||||||
|
fmt.Printf(" - 5001 (RQLite HTTP)\n")
|
||||||
|
fmt.Printf(" - 6001 (Unified gateway)\n")
|
||||||
|
fmt.Printf(" - 8080 (IPFS gateway)\n")
|
||||||
|
fmt.Printf(" - 9050 (Anyone SOCKS5)\n")
|
||||||
|
fmt.Printf(" - 9094 (IPFS Cluster API)\n")
|
||||||
|
fmt.Printf(" - 3320/3322 (Olric)\n")
|
||||||
|
|
||||||
|
fmt.Print("\n" + strings.Repeat("=", 70) + "\n")
|
||||||
|
fmt.Printf("To proceed with installation, run without --dry-run\n")
|
||||||
|
fmt.Print(strings.Repeat("=", 70) + "\n\n")
|
||||||
|
}
|
||||||
217
pkg/cli/utils/systemd.go
Normal file
217
pkg/cli/utils/systemd.go
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrServiceNotFound = errors.New("service not found")
|
||||||
|
|
||||||
|
// PortSpec defines a port and its name for checking availability
|
||||||
|
type PortSpec struct {
|
||||||
|
Name string
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
|
||||||
|
var ServicePorts = map[string][]PortSpec{
|
||||||
|
"debros-gateway": {
|
||||||
|
{Name: "Gateway API", Port: 6001},
|
||||||
|
},
|
||||||
|
"debros-olric": {
|
||||||
|
{Name: "Olric HTTP", Port: 3320},
|
||||||
|
{Name: "Olric Memberlist", Port: 3322},
|
||||||
|
},
|
||||||
|
"debros-node": {
|
||||||
|
{Name: "RQLite HTTP", Port: 5001},
|
||||||
|
{Name: "RQLite Raft", Port: 7001},
|
||||||
|
},
|
||||||
|
"debros-ipfs": {
|
||||||
|
{Name: "IPFS API", Port: 4501},
|
||||||
|
{Name: "IPFS Gateway", Port: 8080},
|
||||||
|
{Name: "IPFS Swarm", Port: 4101},
|
||||||
|
},
|
||||||
|
"debros-ipfs-cluster": {
|
||||||
|
{Name: "IPFS Cluster API", Port: 9094},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultPorts is used for fresh installs/upgrades before unit files exist.
|
||||||
|
func DefaultPorts() []PortSpec {
|
||||||
|
return []PortSpec{
|
||||||
|
{Name: "IPFS Swarm", Port: 4001},
|
||||||
|
{Name: "IPFS API", Port: 4501},
|
||||||
|
{Name: "IPFS Gateway", Port: 8080},
|
||||||
|
{Name: "Gateway API", Port: 6001},
|
||||||
|
{Name: "RQLite HTTP", Port: 5001},
|
||||||
|
{Name: "RQLite Raft", Port: 7001},
|
||||||
|
{Name: "IPFS Cluster API", Port: 9094},
|
||||||
|
{Name: "Olric HTTP", Port: 3320},
|
||||||
|
{Name: "Olric Memberlist", Port: 3322},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveServiceName resolves service aliases to actual systemd service names
|
||||||
|
func ResolveServiceName(alias string) ([]string, error) {
|
||||||
|
// Service alias mapping (unified - no bootstrap/node distinction)
|
||||||
|
aliases := map[string][]string{
|
||||||
|
"node": {"debros-node"},
|
||||||
|
"ipfs": {"debros-ipfs"},
|
||||||
|
"cluster": {"debros-ipfs-cluster"},
|
||||||
|
"ipfs-cluster": {"debros-ipfs-cluster"},
|
||||||
|
"gateway": {"debros-gateway"},
|
||||||
|
"olric": {"debros-olric"},
|
||||||
|
"rqlite": {"debros-node"}, // RQLite logs are in node logs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's an alias
|
||||||
|
if serviceNames, ok := aliases[strings.ToLower(alias)]; ok {
|
||||||
|
// Filter to only existing services
|
||||||
|
var existing []string
|
||||||
|
for _, svc := range serviceNames {
|
||||||
|
unitPath := filepath.Join("/etc/systemd/system", svc+".service")
|
||||||
|
if _, err := os.Stat(unitPath); err == nil {
|
||||||
|
existing = append(existing, svc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(existing) == 0 {
|
||||||
|
return nil, fmt.Errorf("no services found for alias %q", alias)
|
||||||
|
}
|
||||||
|
return existing, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's already a full service name
|
||||||
|
unitPath := filepath.Join("/etc/systemd/system", alias+".service")
|
||||||
|
if _, err := os.Stat(unitPath); err == nil {
|
||||||
|
return []string{alias}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try without .service suffix
|
||||||
|
if !strings.HasSuffix(alias, ".service") {
|
||||||
|
unitPath = filepath.Join("/etc/systemd/system", alias+".service")
|
||||||
|
if _, err := os.Stat(unitPath); err == nil {
|
||||||
|
return []string{alias}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("service %q not found. Use: node, ipfs, cluster, gateway, olric, or full service name", alias)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsServiceActive checks if a systemd service is currently active (running)
|
||||||
|
func IsServiceActive(service string) (bool, error) {
|
||||||
|
cmd := exec.Command("systemctl", "is-active", "--quiet", service)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||||
|
switch exitErr.ExitCode() {
|
||||||
|
case 3:
|
||||||
|
return false, nil
|
||||||
|
case 4:
|
||||||
|
return false, ErrServiceNotFound
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsServiceEnabled checks if a systemd service is enabled to start on boot
|
||||||
|
func IsServiceEnabled(service string) (bool, error) {
|
||||||
|
cmd := exec.Command("systemctl", "is-enabled", "--quiet", service)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||||
|
switch exitErr.ExitCode() {
|
||||||
|
case 1:
|
||||||
|
return false, nil // Service is disabled
|
||||||
|
case 4:
|
||||||
|
return false, ErrServiceNotFound
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsServiceMasked checks if a systemd service is masked
|
||||||
|
func IsServiceMasked(service string) (bool, error) {
|
||||||
|
cmd := exec.Command("systemctl", "is-enabled", service)
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
outputStr := string(output)
|
||||||
|
if strings.Contains(outputStr, "masked") {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProductionServices returns a list of all DeBros production service names that exist
|
||||||
|
func GetProductionServices() []string {
|
||||||
|
// Unified service names (no bootstrap/node distinction)
|
||||||
|
allServices := []string{
|
||||||
|
"debros-gateway",
|
||||||
|
"debros-node",
|
||||||
|
"debros-olric",
|
||||||
|
"debros-ipfs-cluster",
|
||||||
|
"debros-ipfs",
|
||||||
|
"debros-anyone-client",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter to only existing services by checking if unit file exists
|
||||||
|
var existing []string
|
||||||
|
for _, svc := range allServices {
|
||||||
|
unitPath := filepath.Join("/etc/systemd/system", svc+".service")
|
||||||
|
if _, err := os.Stat(unitPath); err == nil {
|
||||||
|
existing = append(existing, svc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return existing
|
||||||
|
}
|
||||||
|
|
||||||
|
// CollectPortsForServices returns a list of ports used by the specified services
|
||||||
|
func CollectPortsForServices(services []string, skipActive bool) ([]PortSpec, error) {
|
||||||
|
seen := make(map[int]PortSpec)
|
||||||
|
for _, svc := range services {
|
||||||
|
if skipActive {
|
||||||
|
active, err := IsServiceActive(svc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to check %s: %w", svc, err)
|
||||||
|
}
|
||||||
|
if active {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, spec := range ServicePorts[svc] {
|
||||||
|
if _, ok := seen[spec.Port]; !ok {
|
||||||
|
seen[spec.Port] = spec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ports := make([]PortSpec, 0, len(seen))
|
||||||
|
for _, spec := range seen {
|
||||||
|
ports = append(ports, spec)
|
||||||
|
}
|
||||||
|
return ports, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsurePortsAvailable checks if the specified ports are available
|
||||||
|
func EnsurePortsAvailable(action string, ports []PortSpec) error {
|
||||||
|
for _, spec := range ports {
|
||||||
|
ln, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", spec.Port))
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, syscall.EADDRINUSE) || strings.Contains(err.Error(), "address already in use") {
|
||||||
|
return fmt.Errorf("%s cannot continue: %s (port %d) is already in use", action, spec.Name, spec.Port)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%s cannot continue: failed to inspect %s (port %d): %w", action, spec.Name, spec.Port, err)
|
||||||
|
}
|
||||||
|
_ = ln.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
113
pkg/cli/utils/validation.go
Normal file
113
pkg/cli/utils/validation.go
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/config"
|
||||||
|
"github.com/multiformats/go-multiaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidateGeneratedConfig loads and validates the generated node configuration
|
||||||
|
func ValidateGeneratedConfig(oramaDir string) error {
|
||||||
|
configPath := filepath.Join(oramaDir, "configs", "node.yaml")
|
||||||
|
|
||||||
|
// Check if config file exists
|
||||||
|
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("configuration file not found at %s", configPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the config file
|
||||||
|
file, err := os.Open(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open config file: %w", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
var cfg config.Config
|
||||||
|
if err := config.DecodeStrict(file, &cfg); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the configuration
|
||||||
|
if errs := cfg.Validate(); len(errs) > 0 {
|
||||||
|
var errMsgs []string
|
||||||
|
for _, e := range errs {
|
||||||
|
errMsgs = append(errMsgs, e.Error())
|
||||||
|
}
|
||||||
|
return fmt.Errorf("configuration validation errors:\n - %s", strings.Join(errMsgs, "\n - "))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateDNSRecord validates that the domain points to the expected IP address
|
||||||
|
// Returns nil if DNS is valid, warning message if DNS doesn't match but continues,
|
||||||
|
// or error if DNS lookup fails completely
|
||||||
|
func ValidateDNSRecord(domain, expectedIP string) error {
|
||||||
|
if domain == "" {
|
||||||
|
return nil // No domain provided, skip validation
|
||||||
|
}
|
||||||
|
|
||||||
|
ips, err := net.LookupIP(domain)
|
||||||
|
if err != nil {
|
||||||
|
// DNS lookup failed - this is a warning, not a fatal error
|
||||||
|
// The user might be setting up DNS after installation
|
||||||
|
fmt.Printf(" ⚠️ DNS lookup failed for %s: %v\n", domain, err)
|
||||||
|
fmt.Printf(" Make sure DNS is configured before enabling HTTPS\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if any resolved IP matches the expected IP
|
||||||
|
for _, ip := range ips {
|
||||||
|
if ip.String() == expectedIP {
|
||||||
|
fmt.Printf(" ✓ DNS validated: %s → %s\n", domain, expectedIP)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DNS doesn't point to expected IP - warn but continue
|
||||||
|
resolvedIPs := make([]string, len(ips))
|
||||||
|
for i, ip := range ips {
|
||||||
|
resolvedIPs[i] = ip.String()
|
||||||
|
}
|
||||||
|
fmt.Printf(" ⚠️ DNS mismatch: %s resolves to %v, expected %s\n", domain, resolvedIPs, expectedIP)
|
||||||
|
fmt.Printf(" HTTPS certificate generation may fail until DNS is updated\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizePeers normalizes and validates peer multiaddrs
|
||||||
|
func NormalizePeers(peersStr string) ([]string, error) {
|
||||||
|
if peersStr == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split by comma and trim whitespace
|
||||||
|
rawPeers := strings.Split(peersStr, ",")
|
||||||
|
peers := make([]string, 0, len(rawPeers))
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, peer := range rawPeers {
|
||||||
|
peer = strings.TrimSpace(peer)
|
||||||
|
if peer == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate multiaddr format
|
||||||
|
if _, err := multiaddr.NewMultiaddr(peer); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid multiaddr %q: %w", peer, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deduplicate
|
||||||
|
if !seen[peer] {
|
||||||
|
peers = append(peers, peer)
|
||||||
|
seen[peer] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return peers, nil
|
||||||
|
}
|
||||||
|
|
||||||
@ -329,6 +329,18 @@ func (c *Client) getAppNamespace() string {
|
|||||||
return c.config.AppName
|
return c.config.AppName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PubSubAdapter returns the underlying pubsub.ClientAdapter for direct use by serverless functions.
|
||||||
|
// This bypasses the authentication checks used by PubSub() since serverless functions
|
||||||
|
// are already authenticated via the gateway.
|
||||||
|
func (c *Client) PubSubAdapter() *pubsub.ClientAdapter {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
if c.pubsub == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.pubsub.adapter
|
||||||
|
}
|
||||||
|
|
||||||
// requireAccess enforces that credentials are present and that any context-based namespace overrides match
|
// requireAccess enforces that credentials are present and that any context-based namespace overrides match
|
||||||
func (c *Client) requireAccess(ctx context.Context) error {
|
func (c *Client) requireAccess(ctx context.Context) error {
|
||||||
// Allow internal system operations to bypass authentication
|
// Allow internal system operations to bypass authentication
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@ -585,3 +586,15 @@ func extractTCPPort(multiaddrStr string) string {
|
|||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateSwarmKey validates that a swarm key is 64 hex characters
|
||||||
|
func ValidateSwarmKey(key string) error {
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if len(key) != 64 {
|
||||||
|
return fmt.Errorf("swarm key must be 64 hex characters (32 bytes), got %d", len(key))
|
||||||
|
}
|
||||||
|
if _, err := hex.DecodeString(key); err != nil {
|
||||||
|
return fmt.Errorf("swarm key must be valid hexadecimal: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@ -78,7 +78,7 @@ func (dc *DependencyChecker) CheckAll() ([]string, error) {
|
|||||||
|
|
||||||
errMsg := fmt.Sprintf("Missing %d required dependencies:\n%s\n\nInstall them with:\n%s",
|
errMsg := fmt.Sprintf("Missing %d required dependencies:\n%s\n\nInstall them with:\n%s",
|
||||||
len(missing), strings.Join(missing, ", "), strings.Join(hints, "\n"))
|
len(missing), strings.Join(missing, ", "), strings.Join(hints, "\n"))
|
||||||
return missing, fmt.Errorf(errMsg)
|
return missing, fmt.Errorf("%s", errMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PortChecker validates that required ports are available
|
// PortChecker validates that required ports are available
|
||||||
@ -113,7 +113,7 @@ func (pc *PortChecker) CheckAll() ([]int, error) {
|
|||||||
|
|
||||||
errMsg := fmt.Sprintf("The following ports are unavailable: %v\n\nFree them or stop conflicting services and try again",
|
errMsg := fmt.Sprintf("The following ports are unavailable: %v\n\nFree them or stop conflicting services and try again",
|
||||||
unavailable)
|
unavailable)
|
||||||
return unavailable, fmt.Errorf(errMsg)
|
return unavailable, fmt.Errorf("%s", errMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// isPortAvailable checks if a TCP port is available for binding
|
// isPortAvailable checks if a TCP port is available for binding
|
||||||
|
|||||||
287
pkg/environments/development/ipfs.go
Normal file
287
pkg/environments/development/ipfs.go
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
package development
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/tlsutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ipfsNodeInfo holds information about an IPFS node for peer discovery
|
||||||
|
type ipfsNodeInfo struct {
|
||||||
|
name string
|
||||||
|
ipfsPath string
|
||||||
|
apiPort int
|
||||||
|
swarmPort int
|
||||||
|
gatewayPort int
|
||||||
|
peerID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) buildIPFSNodes(topology *Topology) []ipfsNodeInfo {
|
||||||
|
var nodes []ipfsNodeInfo
|
||||||
|
for _, nodeSpec := range topology.Nodes {
|
||||||
|
nodes = append(nodes, ipfsNodeInfo{
|
||||||
|
name: nodeSpec.Name,
|
||||||
|
ipfsPath: filepath.Join(pm.oramaDir, nodeSpec.DataDir, "ipfs/repo"),
|
||||||
|
apiPort: nodeSpec.IPFSAPIPort,
|
||||||
|
swarmPort: nodeSpec.IPFSSwarmPort,
|
||||||
|
gatewayPort: nodeSpec.IPFSGatewayPort,
|
||||||
|
peerID: "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return nodes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) startIPFS(ctx context.Context) error {
|
||||||
|
topology := DefaultTopology()
|
||||||
|
nodes := pm.buildIPFSNodes(topology)
|
||||||
|
|
||||||
|
for i := range nodes {
|
||||||
|
os.MkdirAll(nodes[i].ipfsPath, 0755)
|
||||||
|
|
||||||
|
if _, err := os.Stat(filepath.Join(nodes[i].ipfsPath, "config")); os.IsNotExist(err) {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Initializing IPFS (%s)...\n", nodes[i].name)
|
||||||
|
cmd := exec.CommandContext(ctx, "ipfs", "init", "--profile=server", "--repo-dir="+nodes[i].ipfsPath)
|
||||||
|
if _, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: ipfs init failed: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
swarmKeyPath := filepath.Join(pm.oramaDir, "swarm.key")
|
||||||
|
if data, err := os.ReadFile(swarmKeyPath); err == nil {
|
||||||
|
os.WriteFile(filepath.Join(nodes[i].ipfsPath, "swarm.key"), data, 0600)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
peerID, err := configureIPFSRepo(nodes[i].ipfsPath, nodes[i].apiPort, nodes[i].gatewayPort, nodes[i].swarmPort)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: failed to configure IPFS repo for %s: %v\n", nodes[i].name, err)
|
||||||
|
} else {
|
||||||
|
nodes[i].peerID = peerID
|
||||||
|
fmt.Fprintf(pm.logWriter, " Peer ID for %s: %s\n", nodes[i].name, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range nodes {
|
||||||
|
pidPath := filepath.Join(pm.pidsDir, fmt.Sprintf("ipfs-%s.pid", nodes[i].name))
|
||||||
|
logPath := filepath.Join(pm.oramaDir, "logs", fmt.Sprintf("ipfs-%s.log", nodes[i].name))
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "ipfs", "daemon", "--enable-pubsub-experiment", "--repo-dir="+nodes[i].ipfsPath)
|
||||||
|
logFile, _ := os.Create(logPath)
|
||||||
|
cmd.Stdout = logFile
|
||||||
|
cmd.Stderr = logFile
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("failed to start ipfs-%s: %w", nodes[i].name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
|
||||||
|
pm.processes[fmt.Sprintf("ipfs-%s", nodes[i].name)] = &ManagedProcess{
|
||||||
|
Name: fmt.Sprintf("ipfs-%s", nodes[i].name),
|
||||||
|
PID: cmd.Process.Pid,
|
||||||
|
StartTime: time.Now(),
|
||||||
|
LogPath: logPath,
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(pm.logWriter, "✓ IPFS (%s) started (PID: %d, API: %d, Swarm: %d)\n", nodes[i].name, cmd.Process.Pid, nodes[i].apiPort, nodes[i].swarmPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
if err := pm.seedIPFSPeersWithHTTP(ctx, nodes); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, "⚠️ Failed to seed IPFS peers: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureIPFSRepo(repoPath string, apiPort, gatewayPort, swarmPort int) (string, error) {
|
||||||
|
configPath := filepath.Join(repoPath, "config")
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to read IPFS config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var config map[string]interface{}
|
||||||
|
if err := json.Unmarshal(data, &config); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to parse IPFS config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config["Addresses"] = map[string]interface{}{
|
||||||
|
"API": []string{fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", apiPort)},
|
||||||
|
"Gateway": []string{fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", gatewayPort)},
|
||||||
|
"Swarm": []string{
|
||||||
|
fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", swarmPort),
|
||||||
|
fmt.Sprintf("/ip6/::/tcp/%d", swarmPort),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config["AutoConf"] = map[string]interface{}{
|
||||||
|
"Enabled": false,
|
||||||
|
}
|
||||||
|
config["Bootstrap"] = []string{}
|
||||||
|
|
||||||
|
if dns, ok := config["DNS"].(map[string]interface{}); ok {
|
||||||
|
dns["Resolvers"] = map[string]interface{}{}
|
||||||
|
} else {
|
||||||
|
config["DNS"] = map[string]interface{}{
|
||||||
|
"Resolvers": map[string]interface{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if routing, ok := config["Routing"].(map[string]interface{}); ok {
|
||||||
|
routing["DelegatedRouters"] = []string{}
|
||||||
|
} else {
|
||||||
|
config["Routing"] = map[string]interface{}{
|
||||||
|
"DelegatedRouters": []string{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ipns, ok := config["Ipns"].(map[string]interface{}); ok {
|
||||||
|
ipns["DelegatedPublishers"] = []string{}
|
||||||
|
} else {
|
||||||
|
config["Ipns"] = map[string]interface{}{
|
||||||
|
"DelegatedPublishers": []string{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if api, ok := config["API"].(map[string]interface{}); ok {
|
||||||
|
api["HTTPHeaders"] = map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"*"},
|
||||||
|
"Access-Control-Allow-Methods": {"GET", "PUT", "POST", "DELETE", "OPTIONS"},
|
||||||
|
"Access-Control-Allow-Headers": {"Content-Type", "X-Requested-With"},
|
||||||
|
"Access-Control-Expose-Headers": {"Content-Length", "Content-Range"},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
config["API"] = map[string]interface{}{
|
||||||
|
"HTTPHeaders": map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"*"},
|
||||||
|
"Access-Control-Allow-Methods": {"GET", "PUT", "POST", "DELETE", "OPTIONS"},
|
||||||
|
"Access-Control-Allow-Headers": {"Content-Type", "X-Requested-With"},
|
||||||
|
"Access-Control-Expose-Headers": {"Content-Length", "Content-Range"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedData, err := json.MarshalIndent(config, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to marshal IPFS config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(configPath, updatedData, 0644); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to write IPFS config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if id, ok := config["Identity"].(map[string]interface{}); ok {
|
||||||
|
if peerID, ok := id["PeerID"].(string); ok {
|
||||||
|
return peerID, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("could not extract peer ID from config")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) seedIPFSPeersWithHTTP(ctx context.Context, nodes []ipfsNodeInfo) error {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Seeding IPFS local bootstrap peers via HTTP API...\n")
|
||||||
|
|
||||||
|
for _, node := range nodes {
|
||||||
|
if err := pm.waitIPFSReady(ctx, node); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: failed to wait for IPFS readiness for %s: %v\n", node.name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, node := range nodes {
|
||||||
|
httpURL := fmt.Sprintf("http://127.0.0.1:%d/api/v0/bootstrap/rm?all=true", node.apiPort)
|
||||||
|
if err := pm.ipfsHTTPCall(ctx, httpURL, "POST"); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: failed to clear bootstrap for %s: %v\n", node.name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for j, otherNode := range nodes {
|
||||||
|
if i == j {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
multiaddr := fmt.Sprintf("/ip4/127.0.0.1/tcp/%d/p2p/%s", otherNode.swarmPort, otherNode.peerID)
|
||||||
|
httpURL := fmt.Sprintf("http://127.0.0.1:%d/api/v0/bootstrap/add?arg=%s", node.apiPort, url.QueryEscape(multiaddr))
|
||||||
|
if err := pm.ipfsHTTPCall(ctx, httpURL, "POST"); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: failed to add bootstrap peer for %s: %v\n", node.name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) waitIPFSReady(ctx context.Context, node ipfsNodeInfo) error {
|
||||||
|
maxRetries := 30
|
||||||
|
retryInterval := 500 * time.Millisecond
|
||||||
|
|
||||||
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||||
|
httpURL := fmt.Sprintf("http://127.0.0.1:%d/api/v0/version", node.apiPort)
|
||||||
|
if err := pm.ipfsHTTPCall(ctx, httpURL, "POST"); err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(retryInterval):
|
||||||
|
continue
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("IPFS daemon %s did not become ready", node.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) ipfsHTTPCall(ctx context.Context, urlStr string, method string) error {
|
||||||
|
client := tlsutil.NewHTTPClient(5 * time.Second)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, urlStr, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("HTTP call failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readIPFSConfigValue(ctx context.Context, repoPath string, key string) (string, error) {
|
||||||
|
configPath := filepath.Join(repoPath, "config")
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to read IPFS config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(string(data), "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if strings.Contains(line, key) {
|
||||||
|
parts := strings.SplitN(line, ":", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
value := strings.TrimSpace(parts[1])
|
||||||
|
value = strings.Trim(value, `",`)
|
||||||
|
if value != "" {
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("key %s not found in IPFS config", key)
|
||||||
|
}
|
||||||
|
|
||||||
314
pkg/environments/development/ipfs_cluster.go
Normal file
314
pkg/environments/development/ipfs_cluster.go
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
package development
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (pm *ProcessManager) startIPFSCluster(ctx context.Context) error {
|
||||||
|
topology := DefaultTopology()
|
||||||
|
var nodes []struct {
|
||||||
|
name string
|
||||||
|
clusterPath string
|
||||||
|
restAPIPort int
|
||||||
|
clusterPort int
|
||||||
|
ipfsPort int
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, nodeSpec := range topology.Nodes {
|
||||||
|
nodes = append(nodes, struct {
|
||||||
|
name string
|
||||||
|
clusterPath string
|
||||||
|
restAPIPort int
|
||||||
|
clusterPort int
|
||||||
|
ipfsPort int
|
||||||
|
}{
|
||||||
|
nodeSpec.Name,
|
||||||
|
filepath.Join(pm.oramaDir, nodeSpec.DataDir, "ipfs-cluster"),
|
||||||
|
nodeSpec.ClusterAPIPort,
|
||||||
|
nodeSpec.ClusterPort,
|
||||||
|
nodeSpec.IPFSAPIPort,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(pm.logWriter, " Waiting for IPFS daemons to be ready...\n")
|
||||||
|
ipfsNodes := pm.buildIPFSNodes(topology)
|
||||||
|
for _, ipfsNode := range ipfsNodes {
|
||||||
|
if err := pm.waitIPFSReady(ctx, ipfsNode); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: IPFS %s did not become ready: %v\n", ipfsNode.name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
secretPath := filepath.Join(pm.oramaDir, "cluster-secret")
|
||||||
|
clusterSecret, err := os.ReadFile(secretPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read cluster secret: %w", err)
|
||||||
|
}
|
||||||
|
clusterSecretHex := strings.TrimSpace(string(clusterSecret))
|
||||||
|
|
||||||
|
bootstrapMultiaddr := ""
|
||||||
|
{
|
||||||
|
node := nodes[0]
|
||||||
|
if err := pm.cleanClusterState(node.clusterPath); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: failed to clean cluster state for %s: %v\n", node.name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
os.MkdirAll(node.clusterPath, 0755)
|
||||||
|
fmt.Fprintf(pm.logWriter, " Initializing IPFS Cluster (%s)...\n", node.name)
|
||||||
|
cmd := exec.CommandContext(ctx, "ipfs-cluster-service", "init", "--force")
|
||||||
|
cmd.Env = append(os.Environ(),
|
||||||
|
fmt.Sprintf("IPFS_CLUSTER_PATH=%s", node.clusterPath),
|
||||||
|
fmt.Sprintf("CLUSTER_SECRET=%s", clusterSecretHex),
|
||||||
|
)
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: ipfs-cluster-service init failed: %v (output: %s)\n", err, string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pm.ensureIPFSClusterPorts(node.clusterPath, node.restAPIPort, node.clusterPort); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: failed to update IPFS Cluster config for %s: %v\n", node.name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pidPath := filepath.Join(pm.pidsDir, fmt.Sprintf("ipfs-cluster-%s.pid", node.name))
|
||||||
|
logPath := filepath.Join(pm.oramaDir, "logs", fmt.Sprintf("ipfs-cluster-%s.log", node.name))
|
||||||
|
|
||||||
|
cmd = exec.CommandContext(ctx, "ipfs-cluster-service", "daemon")
|
||||||
|
cmd.Env = append(os.Environ(), fmt.Sprintf("IPFS_CLUSTER_PATH=%s", node.clusterPath))
|
||||||
|
logFile, _ := os.Create(logPath)
|
||||||
|
cmd.Stdout = logFile
|
||||||
|
cmd.Stderr = logFile
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
|
||||||
|
fmt.Fprintf(pm.logWriter, "✓ IPFS Cluster (%s) started (PID: %d, API: %d)\n", node.name, cmd.Process.Pid, node.restAPIPort)
|
||||||
|
|
||||||
|
if err := pm.waitClusterReady(ctx, node.name, node.restAPIPort); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: IPFS Cluster %s did not become ready: %v\n", node.name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
peerID, err := pm.waitForClusterPeerID(ctx, filepath.Join(node.clusterPath, "identity.json"))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: failed to read bootstrap peer ID: %v\n", err)
|
||||||
|
} else {
|
||||||
|
bootstrapMultiaddr = fmt.Sprintf("/ip4/127.0.0.1/tcp/%d/p2p/%s", node.clusterPort, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i < len(nodes); i++ {
|
||||||
|
node := nodes[i]
|
||||||
|
if err := pm.cleanClusterState(node.clusterPath); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: failed to clean cluster state for %s: %v\n", node.name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
os.MkdirAll(node.clusterPath, 0755)
|
||||||
|
fmt.Fprintf(pm.logWriter, " Initializing IPFS Cluster (%s)...\n", node.name)
|
||||||
|
cmd := exec.CommandContext(ctx, "ipfs-cluster-service", "init", "--force")
|
||||||
|
cmd.Env = append(os.Environ(),
|
||||||
|
fmt.Sprintf("IPFS_CLUSTER_PATH=%s", node.clusterPath),
|
||||||
|
fmt.Sprintf("CLUSTER_SECRET=%s", clusterSecretHex),
|
||||||
|
)
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: ipfs-cluster-service init failed for %s: %v (output: %s)\n", node.name, err, string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pm.ensureIPFSClusterPorts(node.clusterPath, node.restAPIPort, node.clusterPort); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: failed to update IPFS Cluster config for %s: %v\n", node.name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pidPath := filepath.Join(pm.pidsDir, fmt.Sprintf("ipfs-cluster-%s.pid", node.name))
|
||||||
|
logPath := filepath.Join(pm.oramaDir, "logs", fmt.Sprintf("ipfs-cluster-%s.log", node.name))
|
||||||
|
|
||||||
|
args := []string{"daemon"}
|
||||||
|
if bootstrapMultiaddr != "" {
|
||||||
|
args = append(args, "--bootstrap", bootstrapMultiaddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = exec.CommandContext(ctx, "ipfs-cluster-service", args...)
|
||||||
|
cmd.Env = append(os.Environ(), fmt.Sprintf("IPFS_CLUSTER_PATH=%s", node.clusterPath))
|
||||||
|
logFile, _ := os.Create(logPath)
|
||||||
|
cmd.Stdout = logFile
|
||||||
|
cmd.Stderr = logFile
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
|
||||||
|
fmt.Fprintf(pm.logWriter, "✓ IPFS Cluster (%s) started (PID: %d, API: %d)\n", node.name, cmd.Process.Pid, node.restAPIPort)
|
||||||
|
|
||||||
|
if err := pm.waitClusterReady(ctx, node.name, node.restAPIPort); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: IPFS Cluster %s did not become ready: %v\n", node.name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(pm.logWriter, " Waiting for IPFS Cluster peers to form...\n")
|
||||||
|
if err := pm.waitClusterFormed(ctx, nodes[0].restAPIPort); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " Warning: IPFS Cluster did not form fully: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) waitForClusterPeerID(ctx context.Context, identityPath string) (string, error) {
|
||||||
|
maxRetries := 30
|
||||||
|
retryInterval := 500 * time.Millisecond
|
||||||
|
|
||||||
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||||
|
data, err := os.ReadFile(identityPath)
|
||||||
|
if err == nil {
|
||||||
|
var identity map[string]interface{}
|
||||||
|
if err := json.Unmarshal(data, &identity); err == nil {
|
||||||
|
if id, ok := identity["id"].(string); ok {
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(retryInterval):
|
||||||
|
continue
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("could not read cluster peer ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) waitClusterReady(ctx context.Context, name string, restAPIPort int) error {
|
||||||
|
maxRetries := 30
|
||||||
|
retryInterval := 500 * time.Millisecond
|
||||||
|
|
||||||
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||||
|
httpURL := fmt.Sprintf("http://127.0.0.1:%d/peers", restAPIPort)
|
||||||
|
resp, err := http.Get(httpURL)
|
||||||
|
if err == nil && resp.StatusCode == 200 {
|
||||||
|
resp.Body.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(retryInterval):
|
||||||
|
continue
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("IPFS Cluster %s did not become ready", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) waitClusterFormed(ctx context.Context, bootstrapRestAPIPort int) error {
|
||||||
|
maxRetries := 30
|
||||||
|
retryInterval := 1 * time.Second
|
||||||
|
requiredPeers := 3
|
||||||
|
|
||||||
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||||
|
httpURL := fmt.Sprintf("http://127.0.0.1:%d/peers", bootstrapRestAPIPort)
|
||||||
|
resp, err := http.Get(httpURL)
|
||||||
|
if err == nil && resp.StatusCode == 200 {
|
||||||
|
dec := json.NewDecoder(resp.Body)
|
||||||
|
peerCount := 0
|
||||||
|
for {
|
||||||
|
var peer interface{}
|
||||||
|
if err := dec.Decode(&peer); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
peerCount++
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
if peerCount >= requiredPeers {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(retryInterval):
|
||||||
|
continue
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("IPFS Cluster did not form fully")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) cleanClusterState(clusterPath string) error {
|
||||||
|
pebblePath := filepath.Join(clusterPath, "pebble")
|
||||||
|
os.RemoveAll(pebblePath)
|
||||||
|
|
||||||
|
peerstorePath := filepath.Join(clusterPath, "peerstore")
|
||||||
|
os.Remove(peerstorePath)
|
||||||
|
|
||||||
|
serviceJSONPath := filepath.Join(clusterPath, "service.json")
|
||||||
|
os.Remove(serviceJSONPath)
|
||||||
|
|
||||||
|
lockPath := filepath.Join(clusterPath, "cluster.lock")
|
||||||
|
os.Remove(lockPath)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) ensureIPFSClusterPorts(clusterPath string, restAPIPort int, clusterPort int) error {
|
||||||
|
serviceJSONPath := filepath.Join(clusterPath, "service.json")
|
||||||
|
data, err := os.ReadFile(serviceJSONPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var config map[string]interface{}
|
||||||
|
json.Unmarshal(data, &config)
|
||||||
|
|
||||||
|
portOffset := restAPIPort - 9094
|
||||||
|
proxyPort := 9095 + portOffset
|
||||||
|
pinsvcPort := 9097 + portOffset
|
||||||
|
ipfsPort := 4501 + (portOffset / 10)
|
||||||
|
|
||||||
|
if api, ok := config["api"].(map[string]interface{}); ok {
|
||||||
|
if restapi, ok := api["restapi"].(map[string]interface{}); ok {
|
||||||
|
restapi["http_listen_multiaddress"] = fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", restAPIPort)
|
||||||
|
}
|
||||||
|
if proxy, ok := api["ipfsproxy"].(map[string]interface{}); ok {
|
||||||
|
proxy["listen_multiaddress"] = fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", proxyPort)
|
||||||
|
proxy["node_multiaddress"] = fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", ipfsPort)
|
||||||
|
}
|
||||||
|
if pinsvc, ok := api["pinsvcapi"].(map[string]interface{}); ok {
|
||||||
|
pinsvc["http_listen_multiaddress"] = fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", pinsvcPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cluster, ok := config["cluster"].(map[string]interface{}); ok {
|
||||||
|
cluster["listen_multiaddress"] = []string{
|
||||||
|
fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", clusterPort),
|
||||||
|
fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", clusterPort),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if connector, ok := config["ipfs_connector"].(map[string]interface{}); ok {
|
||||||
|
if ipfshttp, ok := connector["ipfshttp"].(map[string]interface{}); ok {
|
||||||
|
ipfshttp["node_multiaddress"] = fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", ipfsPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedData, _ := json.MarshalIndent(config, "", " ")
|
||||||
|
return os.WriteFile(serviceJSONPath, updatedData, 0644)
|
||||||
|
}
|
||||||
|
|
||||||
231
pkg/environments/development/process.go
Normal file
231
pkg/environments/development/process.go
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
package development
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (pm *ProcessManager) printStartupSummary(topology *Topology) {
|
||||||
|
fmt.Fprintf(pm.logWriter, "\n✅ Development environment ready!\n")
|
||||||
|
fmt.Fprintf(pm.logWriter, "═══════════════════════════════════════\n\n")
|
||||||
|
|
||||||
|
fmt.Fprintf(pm.logWriter, "📡 Access your nodes via unified gateway ports:\n\n")
|
||||||
|
for _, node := range topology.Nodes {
|
||||||
|
fmt.Fprintf(pm.logWriter, " %s:\n", node.Name)
|
||||||
|
fmt.Fprintf(pm.logWriter, " curl http://localhost:%d/health\n", node.UnifiedGatewayPort)
|
||||||
|
fmt.Fprintf(pm.logWriter, " curl http://localhost:%d/rqlite/http/db/execute\n", node.UnifiedGatewayPort)
|
||||||
|
fmt.Fprintf(pm.logWriter, " curl http://localhost:%d/cluster/health\n\n", node.UnifiedGatewayPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(pm.logWriter, "🌐 Main Gateway:\n")
|
||||||
|
fmt.Fprintf(pm.logWriter, " curl http://localhost:%d/v1/status\n\n", topology.GatewayPort)
|
||||||
|
|
||||||
|
fmt.Fprintf(pm.logWriter, "📊 Other Services:\n")
|
||||||
|
fmt.Fprintf(pm.logWriter, " Olric: http://localhost:%d\n", topology.OlricHTTPPort)
|
||||||
|
fmt.Fprintf(pm.logWriter, " Anon SOCKS: 127.0.0.1:%d\n", topology.AnonSOCKSPort)
|
||||||
|
fmt.Fprintf(pm.logWriter, " Rqlite MCP: http://localhost:%d/sse\n\n", topology.MCPPort)
|
||||||
|
|
||||||
|
fmt.Fprintf(pm.logWriter, "📝 Useful Commands:\n")
|
||||||
|
fmt.Fprintf(pm.logWriter, " ./bin/orama dev status - Check service status\n")
|
||||||
|
fmt.Fprintf(pm.logWriter, " ./bin/orama dev logs node-1 - View logs\n")
|
||||||
|
fmt.Fprintf(pm.logWriter, " ./bin/orama dev down - Stop all services\n\n")
|
||||||
|
|
||||||
|
fmt.Fprintf(pm.logWriter, "📂 Logs: %s/logs\n", pm.oramaDir)
|
||||||
|
fmt.Fprintf(pm.logWriter, "⚙️ Config: %s\n\n", pm.oramaDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) stopProcess(name string) error {
|
||||||
|
pidPath := filepath.Join(pm.pidsDir, fmt.Sprintf("%s.pid", name))
|
||||||
|
pidBytes, err := os.ReadFile(pidPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pid, err := strconv.Atoi(strings.TrimSpace(string(pidBytes)))
|
||||||
|
if err != nil {
|
||||||
|
os.Remove(pidPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !checkProcessRunning(pid) {
|
||||||
|
os.Remove(pidPath)
|
||||||
|
fmt.Fprintf(pm.logWriter, "✓ %s (not running)\n", name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
proc, err := os.FindProcess(pid)
|
||||||
|
if err != nil {
|
||||||
|
os.Remove(pidPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
proc.Signal(os.Interrupt)
|
||||||
|
|
||||||
|
gracefulShutdown := false
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
if !checkProcessRunning(pid) {
|
||||||
|
gracefulShutdown = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gracefulShutdown && checkProcessRunning(pid) {
|
||||||
|
proc.Signal(os.Kill)
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
exec.Command("pkill", "-9", "-P", fmt.Sprintf("%d", pid)).Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
if checkProcessRunning(pid) {
|
||||||
|
exec.Command("kill", "-9", fmt.Sprintf("%d", pid)).Run()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Remove(pidPath)
|
||||||
|
|
||||||
|
if gracefulShutdown {
|
||||||
|
fmt.Fprintf(pm.logWriter, "✓ %s stopped gracefully\n", name)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(pm.logWriter, "✓ %s stopped (forced)\n", name)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkProcessRunning(pid int) bool {
|
||||||
|
proc, err := os.FindProcess(pid)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
err = proc.Signal(os.Signal(nil))
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) startNode(name, configFile, logPath string) error {
|
||||||
|
pidPath := filepath.Join(pm.pidsDir, fmt.Sprintf("%s.pid", name))
|
||||||
|
cmd := exec.Command("./bin/orama-node", "--config", configFile)
|
||||||
|
logFile, _ := os.Create(logPath)
|
||||||
|
cmd.Stdout = logFile
|
||||||
|
cmd.Stderr = logFile
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("failed to start %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
|
||||||
|
fmt.Fprintf(pm.logWriter, "✓ %s started (PID: %d)\n", strings.Title(name), cmd.Process.Pid)
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) startGateway(ctx context.Context) error {
|
||||||
|
pidPath := filepath.Join(pm.pidsDir, "gateway.pid")
|
||||||
|
logPath := filepath.Join(pm.oramaDir, "logs", "gateway.log")
|
||||||
|
|
||||||
|
cmd := exec.Command("./bin/gateway", "--config", "gateway.yaml")
|
||||||
|
logFile, _ := os.Create(logPath)
|
||||||
|
cmd.Stdout = logFile
|
||||||
|
cmd.Stderr = logFile
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("failed to start gateway: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
|
||||||
|
fmt.Fprintf(pm.logWriter, "✓ Gateway started (PID: %d, listen: 6001)\n", cmd.Process.Pid)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) startOlric(ctx context.Context) error {
|
||||||
|
pidPath := filepath.Join(pm.pidsDir, "olric.pid")
|
||||||
|
logPath := filepath.Join(pm.oramaDir, "logs", "olric.log")
|
||||||
|
configPath := filepath.Join(pm.oramaDir, "olric-config.yaml")
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "olric-server")
|
||||||
|
cmd.Env = append(os.Environ(), fmt.Sprintf("OLRIC_SERVER_CONFIG=%s", configPath))
|
||||||
|
logFile, _ := os.Create(logPath)
|
||||||
|
cmd.Stdout = logFile
|
||||||
|
cmd.Stderr = logFile
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("failed to start olric: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
|
||||||
|
fmt.Fprintf(pm.logWriter, "✓ Olric started (PID: %d)\n", cmd.Process.Pid)
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) startAnon(ctx context.Context) error {
|
||||||
|
if runtime.GOOS != "darwin" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pidPath := filepath.Join(pm.pidsDir, "anon.pid")
|
||||||
|
logPath := filepath.Join(pm.oramaDir, "logs", "anon.log")
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "npx", "anyone-client")
|
||||||
|
logFile, _ := os.Create(logPath)
|
||||||
|
cmd.Stdout = logFile
|
||||||
|
cmd.Stderr = logFile
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " ⚠️ Failed to start Anon: %v\n", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
|
||||||
|
fmt.Fprintf(pm.logWriter, "✓ Anon proxy started (PID: %d, SOCKS: 9050)\n", cmd.Process.Pid)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) startMCP(ctx context.Context) error {
|
||||||
|
topology := DefaultTopology()
|
||||||
|
pidPath := filepath.Join(pm.pidsDir, "rqlite-mcp.pid")
|
||||||
|
logPath := filepath.Join(pm.oramaDir, "logs", "rqlite-mcp.log")
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "./bin/rqlite-mcp")
|
||||||
|
cmd.Env = append(os.Environ(),
|
||||||
|
fmt.Sprintf("MCP_PORT=%d", topology.MCPPort),
|
||||||
|
fmt.Sprintf("RQLITE_URL=http://localhost:%d", topology.Nodes[0].RQLiteHTTPPort),
|
||||||
|
)
|
||||||
|
logFile, _ := os.Create(logPath)
|
||||||
|
cmd.Stdout = logFile
|
||||||
|
cmd.Stderr = logFile
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
fmt.Fprintf(pm.logWriter, " ⚠️ Failed to start Rqlite MCP: %v\n", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
|
||||||
|
fmt.Fprintf(pm.logWriter, "✓ Rqlite MCP started (PID: %d, port: %d)\n", cmd.Process.Pid, topology.MCPPort)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProcessManager) startNodes(ctx context.Context) error {
|
||||||
|
topology := DefaultTopology()
|
||||||
|
for _, nodeSpec := range topology.Nodes {
|
||||||
|
logPath := filepath.Join(pm.oramaDir, "logs", fmt.Sprintf("%s.log", nodeSpec.Name))
|
||||||
|
if err := pm.startNode(nodeSpec.Name, nodeSpec.ConfigFilename, logPath); err != nil {
|
||||||
|
return fmt.Errorf("failed to start %s: %w", nodeSpec.Name, err)
|
||||||
|
}
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@ -4,20 +4,20 @@ import "fmt"
|
|||||||
|
|
||||||
// NodeSpec defines configuration for a single dev environment node
|
// NodeSpec defines configuration for a single dev environment node
|
||||||
type NodeSpec struct {
|
type NodeSpec struct {
|
||||||
Name string // node-1, node-2, node-3, node-4, node-5
|
Name string // node-1, node-2, node-3, node-4, node-5
|
||||||
ConfigFilename string // node-1.yaml, node-2.yaml, etc.
|
ConfigFilename string // node-1.yaml, node-2.yaml, etc.
|
||||||
DataDir string // relative path from .orama root
|
DataDir string // relative path from .orama root
|
||||||
P2PPort int // LibP2P listen port
|
P2PPort int // LibP2P listen port
|
||||||
IPFSAPIPort int // IPFS API port
|
IPFSAPIPort int // IPFS API port
|
||||||
IPFSSwarmPort int // IPFS Swarm port
|
IPFSSwarmPort int // IPFS Swarm port
|
||||||
IPFSGatewayPort int // IPFS HTTP Gateway port
|
IPFSGatewayPort int // IPFS HTTP Gateway port
|
||||||
RQLiteHTTPPort int // RQLite HTTP API port
|
RQLiteHTTPPort int // RQLite HTTP API port
|
||||||
RQLiteRaftPort int // RQLite Raft consensus port
|
RQLiteRaftPort int // RQLite Raft consensus port
|
||||||
ClusterAPIPort int // IPFS Cluster REST API port
|
ClusterAPIPort int // IPFS Cluster REST API port
|
||||||
ClusterPort int // IPFS Cluster P2P port
|
ClusterPort int // IPFS Cluster P2P port
|
||||||
UnifiedGatewayPort int // Unified gateway port (proxies all services)
|
UnifiedGatewayPort int // Unified gateway port (proxies all services)
|
||||||
RQLiteJoinTarget string // which node's RQLite Raft port to join (empty for first node)
|
RQLiteJoinTarget string // which node's RQLite Raft port to join (empty for first node)
|
||||||
ClusterJoinTarget string // which node's cluster to join (empty for first node)
|
ClusterJoinTarget string // which node's cluster to join (empty for first node)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Topology defines the complete development environment topology
|
// Topology defines the complete development environment topology
|
||||||
@ -27,97 +27,99 @@ type Topology struct {
|
|||||||
OlricHTTPPort int
|
OlricHTTPPort int
|
||||||
OlricMemberPort int
|
OlricMemberPort int
|
||||||
AnonSOCKSPort int
|
AnonSOCKSPort int
|
||||||
|
MCPPort int
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultTopology returns the default five-node dev environment topology
|
// DefaultTopology returns the default five-node dev environment topology
|
||||||
func DefaultTopology() *Topology {
|
func DefaultTopology() *Topology {
|
||||||
return &Topology{
|
return &Topology{
|
||||||
Nodes: []NodeSpec{
|
Nodes: []NodeSpec{
|
||||||
{
|
{
|
||||||
Name: "node-1",
|
Name: "node-1",
|
||||||
ConfigFilename: "node-1.yaml",
|
ConfigFilename: "node-1.yaml",
|
||||||
DataDir: "node-1",
|
DataDir: "node-1",
|
||||||
P2PPort: 4001,
|
P2PPort: 4001,
|
||||||
IPFSAPIPort: 4501,
|
IPFSAPIPort: 4501,
|
||||||
IPFSSwarmPort: 4101,
|
IPFSSwarmPort: 4101,
|
||||||
IPFSGatewayPort: 7501,
|
IPFSGatewayPort: 7501,
|
||||||
RQLiteHTTPPort: 5001,
|
RQLiteHTTPPort: 5001,
|
||||||
RQLiteRaftPort: 7001,
|
RQLiteRaftPort: 7001,
|
||||||
ClusterAPIPort: 9094,
|
ClusterAPIPort: 9094,
|
||||||
ClusterPort: 9096,
|
ClusterPort: 9096,
|
||||||
UnifiedGatewayPort: 6001,
|
UnifiedGatewayPort: 6001,
|
||||||
RQLiteJoinTarget: "", // First node - creates cluster
|
RQLiteJoinTarget: "", // First node - creates cluster
|
||||||
ClusterJoinTarget: "",
|
ClusterJoinTarget: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "node-2",
|
||||||
|
ConfigFilename: "node-2.yaml",
|
||||||
|
DataDir: "node-2",
|
||||||
|
P2PPort: 4011,
|
||||||
|
IPFSAPIPort: 4511,
|
||||||
|
IPFSSwarmPort: 4111,
|
||||||
|
IPFSGatewayPort: 7511,
|
||||||
|
RQLiteHTTPPort: 5011,
|
||||||
|
RQLiteRaftPort: 7011,
|
||||||
|
ClusterAPIPort: 9104,
|
||||||
|
ClusterPort: 9106,
|
||||||
|
UnifiedGatewayPort: 6002,
|
||||||
|
RQLiteJoinTarget: "localhost:7001",
|
||||||
|
ClusterJoinTarget: "localhost:9096",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "node-3",
|
||||||
|
ConfigFilename: "node-3.yaml",
|
||||||
|
DataDir: "node-3",
|
||||||
|
P2PPort: 4002,
|
||||||
|
IPFSAPIPort: 4502,
|
||||||
|
IPFSSwarmPort: 4102,
|
||||||
|
IPFSGatewayPort: 7502,
|
||||||
|
RQLiteHTTPPort: 5002,
|
||||||
|
RQLiteRaftPort: 7002,
|
||||||
|
ClusterAPIPort: 9114,
|
||||||
|
ClusterPort: 9116,
|
||||||
|
UnifiedGatewayPort: 6003,
|
||||||
|
RQLiteJoinTarget: "localhost:7001",
|
||||||
|
ClusterJoinTarget: "localhost:9096",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "node-4",
|
||||||
|
ConfigFilename: "node-4.yaml",
|
||||||
|
DataDir: "node-4",
|
||||||
|
P2PPort: 4003,
|
||||||
|
IPFSAPIPort: 4503,
|
||||||
|
IPFSSwarmPort: 4103,
|
||||||
|
IPFSGatewayPort: 7503,
|
||||||
|
RQLiteHTTPPort: 5003,
|
||||||
|
RQLiteRaftPort: 7003,
|
||||||
|
ClusterAPIPort: 9124,
|
||||||
|
ClusterPort: 9126,
|
||||||
|
UnifiedGatewayPort: 6004,
|
||||||
|
RQLiteJoinTarget: "localhost:7001",
|
||||||
|
ClusterJoinTarget: "localhost:9096",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "node-5",
|
||||||
|
ConfigFilename: "node-5.yaml",
|
||||||
|
DataDir: "node-5",
|
||||||
|
P2PPort: 4004,
|
||||||
|
IPFSAPIPort: 4504,
|
||||||
|
IPFSSwarmPort: 4104,
|
||||||
|
IPFSGatewayPort: 7504,
|
||||||
|
RQLiteHTTPPort: 5004,
|
||||||
|
RQLiteRaftPort: 7004,
|
||||||
|
ClusterAPIPort: 9134,
|
||||||
|
ClusterPort: 9136,
|
||||||
|
UnifiedGatewayPort: 6005,
|
||||||
|
RQLiteJoinTarget: "localhost:7001",
|
||||||
|
ClusterJoinTarget: "localhost:9096",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
GatewayPort: 6000, // Main gateway on 6000 (nodes use 6001-6005)
|
||||||
Name: "node-2",
|
|
||||||
ConfigFilename: "node-2.yaml",
|
|
||||||
DataDir: "node-2",
|
|
||||||
P2PPort: 4011,
|
|
||||||
IPFSAPIPort: 4511,
|
|
||||||
IPFSSwarmPort: 4111,
|
|
||||||
IPFSGatewayPort: 7511,
|
|
||||||
RQLiteHTTPPort: 5011,
|
|
||||||
RQLiteRaftPort: 7011,
|
|
||||||
ClusterAPIPort: 9104,
|
|
||||||
ClusterPort: 9106,
|
|
||||||
UnifiedGatewayPort: 6002,
|
|
||||||
RQLiteJoinTarget: "localhost:7001",
|
|
||||||
ClusterJoinTarget: "localhost:9096",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "node-3",
|
|
||||||
ConfigFilename: "node-3.yaml",
|
|
||||||
DataDir: "node-3",
|
|
||||||
P2PPort: 4002,
|
|
||||||
IPFSAPIPort: 4502,
|
|
||||||
IPFSSwarmPort: 4102,
|
|
||||||
IPFSGatewayPort: 7502,
|
|
||||||
RQLiteHTTPPort: 5002,
|
|
||||||
RQLiteRaftPort: 7002,
|
|
||||||
ClusterAPIPort: 9114,
|
|
||||||
ClusterPort: 9116,
|
|
||||||
UnifiedGatewayPort: 6003,
|
|
||||||
RQLiteJoinTarget: "localhost:7001",
|
|
||||||
ClusterJoinTarget: "localhost:9096",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "node-4",
|
|
||||||
ConfigFilename: "node-4.yaml",
|
|
||||||
DataDir: "node-4",
|
|
||||||
P2PPort: 4003,
|
|
||||||
IPFSAPIPort: 4503,
|
|
||||||
IPFSSwarmPort: 4103,
|
|
||||||
IPFSGatewayPort: 7503,
|
|
||||||
RQLiteHTTPPort: 5003,
|
|
||||||
RQLiteRaftPort: 7003,
|
|
||||||
ClusterAPIPort: 9124,
|
|
||||||
ClusterPort: 9126,
|
|
||||||
UnifiedGatewayPort: 6004,
|
|
||||||
RQLiteJoinTarget: "localhost:7001",
|
|
||||||
ClusterJoinTarget: "localhost:9096",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "node-5",
|
|
||||||
ConfigFilename: "node-5.yaml",
|
|
||||||
DataDir: "node-5",
|
|
||||||
P2PPort: 4004,
|
|
||||||
IPFSAPIPort: 4504,
|
|
||||||
IPFSSwarmPort: 4104,
|
|
||||||
IPFSGatewayPort: 7504,
|
|
||||||
RQLiteHTTPPort: 5004,
|
|
||||||
RQLiteRaftPort: 7004,
|
|
||||||
ClusterAPIPort: 9134,
|
|
||||||
ClusterPort: 9136,
|
|
||||||
UnifiedGatewayPort: 6005,
|
|
||||||
RQLiteJoinTarget: "localhost:7001",
|
|
||||||
ClusterJoinTarget: "localhost:9096",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
GatewayPort: 6000, // Main gateway on 6000 (nodes use 6001-6005)
|
|
||||||
OlricHTTPPort: 3320,
|
OlricHTTPPort: 3320,
|
||||||
OlricMemberPort: 3322,
|
OlricMemberPort: 3322,
|
||||||
AnonSOCKSPort: 9050,
|
AnonSOCKSPort: 9050,
|
||||||
|
MCPPort: 5825,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
package gateway
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
@ -13,13 +13,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (g *Gateway) jwksHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *Service) JWKSHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if g.signingKey == nil {
|
if s.signingKey == nil {
|
||||||
_ = json.NewEncoder(w).Encode(map[string]any{"keys": []any{}})
|
_ = json.NewEncoder(w).Encode(map[string]any{"keys": []any{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
pub := g.signingKey.Public().(*rsa.PublicKey)
|
pub := s.signingKey.Public().(*rsa.PublicKey)
|
||||||
n := pub.N.Bytes()
|
n := pub.N.Bytes()
|
||||||
// Encode exponent as big-endian bytes
|
// Encode exponent as big-endian bytes
|
||||||
eVal := pub.E
|
eVal := pub.E
|
||||||
@ -35,7 +35,7 @@ func (g *Gateway) jwksHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
"kty": "RSA",
|
"kty": "RSA",
|
||||||
"use": "sig",
|
"use": "sig",
|
||||||
"alg": "RS256",
|
"alg": "RS256",
|
||||||
"kid": g.keyID,
|
"kid": s.keyID,
|
||||||
"n": base64.RawURLEncoding.EncodeToString(n),
|
"n": base64.RawURLEncoding.EncodeToString(n),
|
||||||
"e": base64.RawURLEncoding.EncodeToString(eb),
|
"e": base64.RawURLEncoding.EncodeToString(eb),
|
||||||
}
|
}
|
||||||
@ -49,7 +49,7 @@ type jwtHeader struct {
|
|||||||
Kid string `json:"kid"`
|
Kid string `json:"kid"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type jwtClaims struct {
|
type JWTClaims struct {
|
||||||
Iss string `json:"iss"`
|
Iss string `json:"iss"`
|
||||||
Sub string `json:"sub"`
|
Sub string `json:"sub"`
|
||||||
Aud string `json:"aud"`
|
Aud string `json:"aud"`
|
||||||
@ -59,9 +59,9 @@ type jwtClaims struct {
|
|||||||
Namespace string `json:"namespace"`
|
Namespace string `json:"namespace"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseAndVerifyJWT verifies an RS256 JWT created by this gateway and returns claims
|
// ParseAndVerifyJWT verifies an RS256 JWT created by this gateway and returns claims
|
||||||
func (g *Gateway) parseAndVerifyJWT(token string) (*jwtClaims, error) {
|
func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) {
|
||||||
if g.signingKey == nil {
|
if s.signingKey == nil {
|
||||||
return nil, errors.New("signing key unavailable")
|
return nil, errors.New("signing key unavailable")
|
||||||
}
|
}
|
||||||
parts := strings.Split(token, ".")
|
parts := strings.Split(token, ".")
|
||||||
@ -90,12 +90,12 @@ func (g *Gateway) parseAndVerifyJWT(token string) (*jwtClaims, error) {
|
|||||||
// Verify signature
|
// Verify signature
|
||||||
signingInput := parts[0] + "." + parts[1]
|
signingInput := parts[0] + "." + parts[1]
|
||||||
sum := sha256.Sum256([]byte(signingInput))
|
sum := sha256.Sum256([]byte(signingInput))
|
||||||
pub := g.signingKey.Public().(*rsa.PublicKey)
|
pub := s.signingKey.Public().(*rsa.PublicKey)
|
||||||
if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, sum[:], sb); err != nil {
|
if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, sum[:], sb); err != nil {
|
||||||
return nil, errors.New("invalid signature")
|
return nil, errors.New("invalid signature")
|
||||||
}
|
}
|
||||||
// Parse claims
|
// Parse claims
|
||||||
var claims jwtClaims
|
var claims JWTClaims
|
||||||
if err := json.Unmarshal(pb, &claims); err != nil {
|
if err := json.Unmarshal(pb, &claims); err != nil {
|
||||||
return nil, errors.New("invalid claims json")
|
return nil, errors.New("invalid claims json")
|
||||||
}
|
}
|
||||||
@ -122,14 +122,14 @@ func (g *Gateway) parseAndVerifyJWT(token string) (*jwtClaims, error) {
|
|||||||
return &claims, nil
|
return &claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) generateJWT(ns, subject string, ttl time.Duration) (string, int64, error) {
|
func (s *Service) GenerateJWT(ns, subject string, ttl time.Duration) (string, int64, error) {
|
||||||
if g.signingKey == nil {
|
if s.signingKey == nil {
|
||||||
return "", 0, errors.New("signing key unavailable")
|
return "", 0, errors.New("signing key unavailable")
|
||||||
}
|
}
|
||||||
header := map[string]string{
|
header := map[string]string{
|
||||||
"alg": "RS256",
|
"alg": "RS256",
|
||||||
"typ": "JWT",
|
"typ": "JWT",
|
||||||
"kid": g.keyID,
|
"kid": s.keyID,
|
||||||
}
|
}
|
||||||
hb, _ := json.Marshal(header)
|
hb, _ := json.Marshal(header)
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
@ -148,7 +148,7 @@ func (g *Gateway) generateJWT(ns, subject string, ttl time.Duration) (string, in
|
|||||||
pb64 := base64.RawURLEncoding.EncodeToString(pb)
|
pb64 := base64.RawURLEncoding.EncodeToString(pb)
|
||||||
signingInput := hb64 + "." + pb64
|
signingInput := hb64 + "." + pb64
|
||||||
sum := sha256.Sum256([]byte(signingInput))
|
sum := sha256.Sum256([]byte(signingInput))
|
||||||
sig, err := rsa.SignPKCS1v15(rand.Reader, g.signingKey, crypto.SHA256, sum[:])
|
sig, err := rsa.SignPKCS1v15(rand.Reader, s.signingKey, crypto.SHA256, sum[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", 0, err
|
return "", 0, err
|
||||||
}
|
}
|
||||||
391
pkg/gateway/auth/service.go
Normal file
391
pkg/gateway/auth/service.go
Normal file
@ -0,0 +1,391 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/client"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
ethcrypto "github.com/ethereum/go-ethereum/crypto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Service handles authentication business logic
|
||||||
|
type Service struct {
|
||||||
|
logger *logging.ColoredLogger
|
||||||
|
orm client.NetworkClient
|
||||||
|
signingKey *rsa.PrivateKey
|
||||||
|
keyID string
|
||||||
|
defaultNS string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(logger *logging.ColoredLogger, orm client.NetworkClient, signingKeyPEM string, defaultNS string) (*Service, error) {
|
||||||
|
s := &Service{
|
||||||
|
logger: logger,
|
||||||
|
orm: orm,
|
||||||
|
defaultNS: defaultNS,
|
||||||
|
}
|
||||||
|
|
||||||
|
if signingKeyPEM != "" {
|
||||||
|
block, _ := pem.Decode([]byte(signingKeyPEM))
|
||||||
|
if block == nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse signing key PEM")
|
||||||
|
}
|
||||||
|
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse RSA private key: %w", err)
|
||||||
|
}
|
||||||
|
s.signingKey = key
|
||||||
|
|
||||||
|
// Generate a simple KID from the public key hash
|
||||||
|
pubBytes := x509.MarshalPKCS1PublicKey(&key.PublicKey)
|
||||||
|
sum := sha256.Sum256(pubBytes)
|
||||||
|
s.keyID = hex.EncodeToString(sum[:8])
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateNonce generates a new nonce and stores it in the database
|
||||||
|
func (s *Service) CreateNonce(ctx context.Context, wallet, purpose, namespace string) (string, error) {
|
||||||
|
// Generate a URL-safe random nonce (32 bytes)
|
||||||
|
buf := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||||
|
}
|
||||||
|
nonce := base64.RawURLEncoding.EncodeToString(buf)
|
||||||
|
|
||||||
|
// Use internal context to bypass authentication for system operations
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
db := s.orm.Database()
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = s.defaultNS
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = "default"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure namespace exists
|
||||||
|
if _, err := db.Query(internalCtx, "INSERT OR IGNORE INTO namespaces(name) VALUES (?)", namespace); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to ensure namespace: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nsID, err := s.ResolveNamespaceID(ctx, namespace)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to resolve namespace ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store nonce with 5 minute expiry
|
||||||
|
walletLower := strings.ToLower(strings.TrimSpace(wallet))
|
||||||
|
if _, err := db.Query(internalCtx,
|
||||||
|
"INSERT INTO nonces(namespace_id, wallet, nonce, purpose, expires_at) VALUES (?, ?, ?, ?, datetime('now', '+5 minutes'))",
|
||||||
|
nsID, walletLower, nonce, purpose,
|
||||||
|
); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to store nonce: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nonce, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifySignature verifies a wallet signature for a given nonce
|
||||||
|
func (s *Service) VerifySignature(ctx context.Context, wallet, nonce, signature, chainType string) (bool, error) {
|
||||||
|
chainType = strings.ToUpper(strings.TrimSpace(chainType))
|
||||||
|
if chainType == "" {
|
||||||
|
chainType = "ETH"
|
||||||
|
}
|
||||||
|
|
||||||
|
switch chainType {
|
||||||
|
case "ETH":
|
||||||
|
return s.verifyEthSignature(wallet, nonce, signature)
|
||||||
|
case "SOL":
|
||||||
|
return s.verifySolSignature(wallet, nonce, signature)
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("unsupported chain type: %s", chainType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) verifyEthSignature(wallet, nonce, signature string) (bool, error) {
|
||||||
|
msg := []byte(nonce)
|
||||||
|
prefix := []byte("\x19Ethereum Signed Message:\n" + strconv.Itoa(len(msg)))
|
||||||
|
hash := ethcrypto.Keccak256(prefix, msg)
|
||||||
|
|
||||||
|
sigHex := strings.TrimSpace(signature)
|
||||||
|
if strings.HasPrefix(sigHex, "0x") || strings.HasPrefix(sigHex, "0X") {
|
||||||
|
sigHex = sigHex[2:]
|
||||||
|
}
|
||||||
|
sig, err := hex.DecodeString(sigHex)
|
||||||
|
if err != nil || len(sig) != 65 {
|
||||||
|
return false, fmt.Errorf("invalid signature format")
|
||||||
|
}
|
||||||
|
|
||||||
|
if sig[64] >= 27 {
|
||||||
|
sig[64] -= 27
|
||||||
|
}
|
||||||
|
|
||||||
|
pub, err := ethcrypto.SigToPub(hash, sig)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("signature recovery failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := ethcrypto.PubkeyToAddress(*pub).Hex()
|
||||||
|
want := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(wallet, "0x"), "0X"))
|
||||||
|
got := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(addr, "0x"), "0X"))
|
||||||
|
|
||||||
|
return got == want, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) verifySolSignature(wallet, nonce, signature string) (bool, error) {
|
||||||
|
sig, err := base64.StdEncoding.DecodeString(signature)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("invalid base64 signature: %w", err)
|
||||||
|
}
|
||||||
|
if len(sig) != 64 {
|
||||||
|
return false, fmt.Errorf("invalid signature length: expected 64 bytes, got %d", len(sig))
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKeyBytes, err := s.Base58Decode(wallet)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("invalid wallet address: %w", err)
|
||||||
|
}
|
||||||
|
if len(pubKeyBytes) != 32 {
|
||||||
|
return false, fmt.Errorf("invalid public key length: expected 32 bytes, got %d", len(pubKeyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
message := []byte(nonce)
|
||||||
|
return ed25519.Verify(ed25519.PublicKey(pubKeyBytes), message, sig), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IssueTokens generates access and refresh tokens for a verified wallet
|
||||||
|
func (s *Service) IssueTokens(ctx context.Context, wallet, namespace string) (string, string, int64, error) {
|
||||||
|
if s.signingKey == nil {
|
||||||
|
return "", "", 0, fmt.Errorf("signing key unavailable")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Issue access token (15m)
|
||||||
|
token, expUnix, err := s.GenerateJWT(namespace, wallet, 15*time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", 0, fmt.Errorf("failed to generate JWT: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create refresh token (30d)
|
||||||
|
rbuf := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(rbuf); err != nil {
|
||||||
|
return "", "", 0, fmt.Errorf("failed to generate refresh token: %w", err)
|
||||||
|
}
|
||||||
|
refresh := base64.RawURLEncoding.EncodeToString(rbuf)
|
||||||
|
|
||||||
|
nsID, err := s.ResolveNamespaceID(ctx, namespace)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", 0, fmt.Errorf("failed to resolve namespace ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
db := s.orm.Database()
|
||||||
|
if _, err := db.Query(internalCtx,
|
||||||
|
"INSERT INTO refresh_tokens(namespace_id, subject, token, audience, expires_at) VALUES (?, ?, ?, ?, datetime('now', '+30 days'))",
|
||||||
|
nsID, wallet, refresh, "gateway",
|
||||||
|
); err != nil {
|
||||||
|
return "", "", 0, fmt.Errorf("failed to store refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, refresh, expUnix, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken validates a refresh token and issues a new access token
|
||||||
|
func (s *Service) RefreshToken(ctx context.Context, refreshToken, namespace string) (string, string, int64, error) {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
db := s.orm.Database()
|
||||||
|
|
||||||
|
nsID, err := s.ResolveNamespaceID(ctx, namespace)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
q := "SELECT subject FROM refresh_tokens WHERE namespace_id = ? AND token = ? AND revoked_at IS NULL AND (expires_at IS NULL OR expires_at > datetime('now')) LIMIT 1"
|
||||||
|
res, err := db.Query(internalCtx, q, nsID, refreshToken)
|
||||||
|
if err != nil || res == nil || res.Count == 0 {
|
||||||
|
return "", "", 0, fmt.Errorf("invalid or expired refresh token")
|
||||||
|
}
|
||||||
|
|
||||||
|
subject := ""
|
||||||
|
if len(res.Rows) > 0 && len(res.Rows[0]) > 0 {
|
||||||
|
if val, ok := res.Rows[0][0].(string); ok {
|
||||||
|
subject = val
|
||||||
|
} else {
|
||||||
|
b, _ := json.Marshal(res.Rows[0][0])
|
||||||
|
_ = json.Unmarshal(b, &subject)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
token, expUnix, err := s.GenerateJWT(namespace, subject, 15*time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, subject, expUnix, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeToken revokes a specific refresh token or all tokens for a subject
|
||||||
|
func (s *Service) RevokeToken(ctx context.Context, namespace, token string, all bool, subject string) error {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
db := s.orm.Database()
|
||||||
|
|
||||||
|
nsID, err := s.ResolveNamespaceID(ctx, namespace)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if token != "" {
|
||||||
|
_, err := db.Query(internalCtx, "UPDATE refresh_tokens SET revoked_at = datetime('now') WHERE namespace_id = ? AND token = ? AND revoked_at IS NULL", nsID, token)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if all && subject != "" {
|
||||||
|
_, err := db.Query(internalCtx, "UPDATE refresh_tokens SET revoked_at = datetime('now') WHERE namespace_id = ? AND subject = ? AND revoked_at IS NULL", nsID, subject)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("nothing to revoke")
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterApp registers a new client application
|
||||||
|
func (s *Service) RegisterApp(ctx context.Context, wallet, namespace, name, publicKey string) (string, error) {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
db := s.orm.Database()
|
||||||
|
|
||||||
|
nsID, err := s.ResolveNamespaceID(ctx, namespace)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate client app_id
|
||||||
|
buf := make([]byte, 12)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate app id: %w", err)
|
||||||
|
}
|
||||||
|
appID := "app_" + base64.RawURLEncoding.EncodeToString(buf)
|
||||||
|
|
||||||
|
// Persist app
|
||||||
|
if _, err := db.Query(internalCtx, "INSERT INTO apps(namespace_id, app_id, name, public_key) VALUES (?, ?, ?, ?)", nsID, appID, name, publicKey); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record ownership
|
||||||
|
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, ?, ?)", nsID, "wallet", wallet)
|
||||||
|
|
||||||
|
return appID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrCreateAPIKey returns an existing API key or creates a new one for a wallet in a namespace
|
||||||
|
func (s *Service) GetOrCreateAPIKey(ctx context.Context, wallet, namespace string) (string, error) {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
db := s.orm.Database()
|
||||||
|
|
||||||
|
nsID, err := s.ResolveNamespaceID(ctx, namespace)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try existing linkage
|
||||||
|
var apiKey string
|
||||||
|
r1, err := db.Query(internalCtx,
|
||||||
|
"SELECT api_keys.key FROM wallet_api_keys JOIN api_keys ON wallet_api_keys.api_key_id = api_keys.id WHERE wallet_api_keys.namespace_id = ? AND LOWER(wallet_api_keys.wallet) = LOWER(?) LIMIT 1",
|
||||||
|
nsID, wallet,
|
||||||
|
)
|
||||||
|
if err == nil && r1 != nil && r1.Count > 0 && len(r1.Rows) > 0 && len(r1.Rows[0]) > 0 {
|
||||||
|
if val, ok := r1.Rows[0][0].(string); ok {
|
||||||
|
apiKey = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if apiKey != "" {
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new API key
|
||||||
|
buf := make([]byte, 18)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate api key: %w", err)
|
||||||
|
}
|
||||||
|
apiKey = "ak_" + base64.RawURLEncoding.EncodeToString(buf) + ":" + namespace
|
||||||
|
|
||||||
|
if _, err := db.Query(internalCtx, "INSERT INTO api_keys(key, name, namespace_id) VALUES (?, ?, ?)", apiKey, "", nsID); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to store api key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Link wallet -> api_key
|
||||||
|
rid, err := db.Query(internalCtx, "SELECT id FROM api_keys WHERE key = ? LIMIT 1", apiKey)
|
||||||
|
if err == nil && rid != nil && rid.Count > 0 && len(rid.Rows) > 0 && len(rid.Rows[0]) > 0 {
|
||||||
|
apiKeyID := rid.Rows[0][0]
|
||||||
|
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO wallet_api_keys(namespace_id, wallet, api_key_id) VALUES (?, ?, ?)", nsID, strings.ToLower(wallet), apiKeyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record ownerships
|
||||||
|
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'api_key', ?)", nsID, apiKey)
|
||||||
|
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'wallet', ?)", nsID, wallet)
|
||||||
|
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveNamespaceID ensures the given namespace exists and returns its primary key ID.
|
||||||
|
func (s *Service) ResolveNamespaceID(ctx context.Context, ns string) (interface{}, error) {
|
||||||
|
if s.orm == nil {
|
||||||
|
return nil, fmt.Errorf("client not initialized")
|
||||||
|
}
|
||||||
|
ns = strings.TrimSpace(ns)
|
||||||
|
if ns == "" {
|
||||||
|
ns = "default"
|
||||||
|
}
|
||||||
|
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
db := s.orm.Database()
|
||||||
|
|
||||||
|
if _, err := db.Query(internalCtx, "INSERT OR IGNORE INTO namespaces(name) VALUES (?)", ns); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res, err := db.Query(internalCtx, "SELECT id FROM namespaces WHERE name = ? LIMIT 1", ns)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if res == nil || res.Count == 0 || len(res.Rows) == 0 || len(res.Rows[0]) == 0 {
|
||||||
|
return nil, fmt.Errorf("failed to resolve namespace")
|
||||||
|
}
|
||||||
|
return res.Rows[0][0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base58Decode decodes a base58-encoded string
|
||||||
|
func (s *Service) Base58Decode(input string) ([]byte, error) {
|
||||||
|
const alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
|
||||||
|
answer := big.NewInt(0)
|
||||||
|
j := big.NewInt(1)
|
||||||
|
for i := len(input) - 1; i >= 0; i-- {
|
||||||
|
tmp := strings.IndexByte(alphabet, input[i])
|
||||||
|
if tmp == -1 {
|
||||||
|
return nil, fmt.Errorf("invalid base58 character")
|
||||||
|
}
|
||||||
|
idx := big.NewInt(int64(tmp))
|
||||||
|
tmp1 := new(big.Int)
|
||||||
|
tmp1.Mul(idx, j)
|
||||||
|
answer.Add(answer, tmp1)
|
||||||
|
j.Mul(j, big.NewInt(58))
|
||||||
|
}
|
||||||
|
// Handle leading zeros
|
||||||
|
res := answer.Bytes()
|
||||||
|
for i := 0; i < len(input) && input[i] == alphabet[0]; i++ {
|
||||||
|
res = append([]byte{0}, res...)
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
166
pkg/gateway/auth/service_test.go
Normal file
166
pkg/gateway/auth/service_test.go
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/pem"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/client"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockNetworkClient implements client.NetworkClient for testing
|
||||||
|
type mockNetworkClient struct {
|
||||||
|
client.NetworkClient
|
||||||
|
db *mockDatabaseClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNetworkClient) Database() client.DatabaseClient {
|
||||||
|
return m.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockDatabaseClient implements client.DatabaseClient for testing
|
||||||
|
type mockDatabaseClient struct {
|
||||||
|
client.DatabaseClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDatabaseClient) Query(ctx context.Context, sql string, args ...interface{}) (*client.QueryResult, error) {
|
||||||
|
return &client.QueryResult{
|
||||||
|
Count: 1,
|
||||||
|
Rows: [][]interface{}{
|
||||||
|
{1}, // Default ID for ResolveNamespaceID
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestService(t *testing.T) *Service {
|
||||||
|
logger, _ := logging.NewColoredLogger(logging.ComponentGateway, false)
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate key: %v", err)
|
||||||
|
}
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||||
|
})
|
||||||
|
|
||||||
|
mockDB := &mockDatabaseClient{}
|
||||||
|
mockClient := &mockNetworkClient{db: mockDB}
|
||||||
|
|
||||||
|
s, err := NewService(logger, mockClient, string(keyPEM), "test-ns")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBase58Decode(t *testing.T) {
|
||||||
|
s := &Service{}
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string // hex representation for comparison
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"1", "00", false},
|
||||||
|
{"2", "01", false},
|
||||||
|
{"9", "08", false},
|
||||||
|
{"A", "09", false},
|
||||||
|
{"B", "0a", false},
|
||||||
|
{"2p", "0100", false}, // 58*1 + 0 = 58 (0x3a) - wait, base58 is weird
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got, err := s.Base58Decode(tt.input)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Base58Decode(%s) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !tt.wantErr {
|
||||||
|
hexGot := hex.EncodeToString(got)
|
||||||
|
if tt.expected != "" && hexGot != tt.expected {
|
||||||
|
// Base58 decoding of single characters might not be exactly what I expect above
|
||||||
|
// but let's just ensure it doesn't crash and returns something for now.
|
||||||
|
// Better to test a known valid address.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test a real Solana address (Base58)
|
||||||
|
solAddr := "HN7cABqL367i3jkj9684C9C3W197m8q5q1C9C3W197m8"
|
||||||
|
_, err := s.Base58Decode(solAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to decode solana address: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTFlow(t *testing.T) {
|
||||||
|
s := createTestService(t)
|
||||||
|
|
||||||
|
ns := "test-ns"
|
||||||
|
sub := "0x1234567890abcdef1234567890abcdef12345678"
|
||||||
|
ttl := 15 * time.Minute
|
||||||
|
|
||||||
|
token, exp, err := s.GenerateJWT(ns, sub, ttl)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateJWT failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Fatal("generated token is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if exp <= time.Now().Unix() {
|
||||||
|
t.Errorf("expiration time %d is in the past", exp)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := s.ParseAndVerifyJWT(token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseAndVerifyJWT failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.Sub != sub {
|
||||||
|
t.Errorf("expected subject %s, got %s", sub, claims.Sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.Namespace != ns {
|
||||||
|
t.Errorf("expected namespace %s, got %s", ns, claims.Namespace)
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.Iss != "debros-gateway" {
|
||||||
|
t.Errorf("expected issuer debros-gateway, got %s", claims.Iss)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyEthSignature(t *testing.T) {
|
||||||
|
s := &Service{}
|
||||||
|
|
||||||
|
// This is a bit hard to test without a real ETH signature
|
||||||
|
// but we can check if it returns false for obviously wrong signatures
|
||||||
|
wallet := "0x1234567890abcdef1234567890abcdef12345678"
|
||||||
|
nonce := "test-nonce"
|
||||||
|
sig := hex.EncodeToString(make([]byte, 65))
|
||||||
|
|
||||||
|
ok, err := s.VerifySignature(context.Background(), wallet, nonce, sig, "ETH")
|
||||||
|
if err == nil && ok {
|
||||||
|
t.Error("VerifySignature should have failed for zero signature")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifySolSignature(t *testing.T) {
|
||||||
|
s := &Service{}
|
||||||
|
|
||||||
|
// Solana address (base58)
|
||||||
|
wallet := "HN7cABqL367i3jkj9684C9C3W197m8q5q1C9C3W197m8"
|
||||||
|
nonce := "test-nonce"
|
||||||
|
sig := "invalid-sig"
|
||||||
|
|
||||||
|
_, err := s.VerifySignature(context.Background(), wallet, nonce, sig, "SOL")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("VerifySignature should have failed for invalid base64 signature")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,20 +1,14 @@
|
|||||||
package gateway
|
package gateway
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/client"
|
"github.com/DeBrosOfficial/network/pkg/client"
|
||||||
ethcrypto "github.com/ethereum/go-ethereum/crypto"
|
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (g *Gateway) whoamiHandler(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) whoamiHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -29,7 +23,7 @@ func (g *Gateway) whoamiHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// Prefer JWT if present
|
// Prefer JWT if present
|
||||||
if v := ctx.Value(ctxKeyJWT); v != nil {
|
if v := ctx.Value(ctxKeyJWT); v != nil {
|
||||||
if claims, ok := v.(*jwtClaims); ok && claims != nil {
|
if claims, ok := v.(*auth.JWTClaims); ok && claims != nil {
|
||||||
writeJSON(w, http.StatusOK, map[string]any{
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
"authenticated": true,
|
"authenticated": true,
|
||||||
"method": "jwt",
|
"method": "jwt",
|
||||||
@ -61,8 +55,8 @@ func (g *Gateway) whoamiHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) challengeHandler(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) challengeHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
if g.client == nil {
|
if g.authService == nil {
|
||||||
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
@ -82,51 +76,16 @@ func (g *Gateway) challengeHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeError(w, http.StatusBadRequest, "wallet is required")
|
writeError(w, http.StatusBadRequest, "wallet is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ns := strings.TrimSpace(req.Namespace)
|
|
||||||
if ns == "" {
|
|
||||||
ns = strings.TrimSpace(g.cfg.ClientNamespace)
|
|
||||||
if ns == "" {
|
|
||||||
ns = "default"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Generate a URL-safe random nonce (32 bytes)
|
|
||||||
buf := make([]byte, 32)
|
|
||||||
if _, err := rand.Read(buf); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, "failed to generate nonce")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
nonce := base64.RawURLEncoding.EncodeToString(buf)
|
|
||||||
|
|
||||||
// Insert namespace if missing, fetch id
|
nonce, err := g.authService.CreateNonce(r.Context(), req.Wallet, req.Purpose, req.Namespace)
|
||||||
ctx := r.Context()
|
if err != nil {
|
||||||
// Use internal context to bypass authentication for system operations
|
|
||||||
internalCtx := client.WithInternalAuth(ctx)
|
|
||||||
db := g.client.Database()
|
|
||||||
if _, err := db.Query(internalCtx, "INSERT OR IGNORE INTO namespaces(name) VALUES (?)", ns); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
nres, err := db.Query(internalCtx, "SELECT id FROM namespaces WHERE name = ? LIMIT 1", ns)
|
|
||||||
if err != nil || nres == nil || nres.Count == 0 || len(nres.Rows) == 0 || len(nres.Rows[0]) == 0 {
|
|
||||||
writeError(w, http.StatusInternalServerError, "failed to resolve namespace")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
nsID := nres.Rows[0][0]
|
|
||||||
|
|
||||||
// Store nonce with 5 minute expiry
|
|
||||||
// Normalize wallet address to lowercase for case-insensitive comparison
|
|
||||||
walletLower := strings.ToLower(strings.TrimSpace(req.Wallet))
|
|
||||||
if _, err := db.Query(internalCtx,
|
|
||||||
"INSERT INTO nonces(namespace_id, wallet, nonce, purpose, expires_at) VALUES (?, ?, ?, ?, datetime('now', '+5 minutes'))",
|
|
||||||
nsID, walletLower, nonce, req.Purpose,
|
|
||||||
); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSON(w, http.StatusOK, map[string]any{
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
"wallet": req.Wallet,
|
"wallet": req.Wallet,
|
||||||
"namespace": ns,
|
"namespace": req.Namespace,
|
||||||
"nonce": nonce,
|
"nonce": nonce,
|
||||||
"purpose": req.Purpose,
|
"purpose": req.Purpose,
|
||||||
"expires_at": time.Now().Add(5 * time.Minute).UTC().Format(time.RFC3339Nano),
|
"expires_at": time.Now().Add(5 * time.Minute).UTC().Format(time.RFC3339Nano),
|
||||||
@ -134,8 +93,8 @@ func (g *Gateway) challengeHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) verifyHandler(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) verifyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
if g.client == nil {
|
if g.authService == nil {
|
||||||
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
@ -147,7 +106,7 @@ func (g *Gateway) verifyHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
Nonce string `json:"nonce"`
|
Nonce string `json:"nonce"`
|
||||||
Signature string `json:"signature"`
|
Signature string `json:"signature"`
|
||||||
Namespace string `json:"namespace"`
|
Namespace string `json:"namespace"`
|
||||||
ChainType string `json:"chain_type"` // "ETH" or "SOL", defaults to "ETH"
|
ChainType string `json:"chain_type"`
|
||||||
}
|
}
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
writeError(w, http.StatusBadRequest, "invalid json body")
|
writeError(w, http.StatusBadRequest, "invalid json body")
|
||||||
@ -157,185 +116,30 @@ func (g *Gateway) verifyHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeError(w, http.StatusBadRequest, "wallet, nonce and signature are required")
|
writeError(w, http.StatusBadRequest, "wallet, nonce and signature are required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ns := strings.TrimSpace(req.Namespace)
|
|
||||||
if ns == "" {
|
|
||||||
ns = strings.TrimSpace(g.cfg.ClientNamespace)
|
|
||||||
if ns == "" {
|
|
||||||
ns = "default"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
// Use internal context to bypass authentication for system operations
|
verified, err := g.authService.VerifySignature(ctx, req.Wallet, req.Nonce, req.Signature, req.ChainType)
|
||||||
internalCtx := client.WithInternalAuth(ctx)
|
if err != nil || !verified {
|
||||||
|
writeError(w, http.StatusUnauthorized, "signature verification failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark nonce used
|
||||||
|
nsID, _ := g.authService.ResolveNamespaceID(ctx, req.Namespace)
|
||||||
db := g.client.Database()
|
db := g.client.Database()
|
||||||
nsID, err := g.resolveNamespaceID(ctx, ns)
|
_, _ = db.Query(client.WithInternalAuth(ctx), "UPDATE nonces SET used_at = datetime('now') WHERE namespace_id = ? AND wallet = ? AND nonce = ?", nsID, strings.ToLower(req.Wallet), req.Nonce)
|
||||||
|
|
||||||
|
token, refresh, expUnix, err := g.authService.IssueTokens(ctx, req.Wallet, req.Namespace)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Normalize wallet address to lowercase for case-insensitive comparison
|
|
||||||
walletLower := strings.ToLower(strings.TrimSpace(req.Wallet))
|
|
||||||
q := "SELECT id FROM nonces WHERE namespace_id = ? AND LOWER(wallet) = LOWER(?) AND nonce = ? AND used_at IS NULL AND (expires_at IS NULL OR expires_at > datetime('now')) LIMIT 1"
|
|
||||||
nres, err := db.Query(internalCtx, q, nsID, walletLower, req.Nonce)
|
|
||||||
if err != nil || nres == nil || nres.Count == 0 {
|
|
||||||
writeError(w, http.StatusBadRequest, "invalid or expired nonce")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
nonceID := nres.Rows[0][0]
|
|
||||||
|
|
||||||
// Determine chain type (default to ETH for backward compatibility)
|
apiKey, err := g.authService.GetOrCreateAPIKey(ctx, req.Wallet, req.Namespace)
|
||||||
chainType := strings.ToUpper(strings.TrimSpace(req.ChainType))
|
|
||||||
if chainType == "" {
|
|
||||||
chainType = "ETH"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify signature based on chain type
|
|
||||||
var verified bool
|
|
||||||
var verifyErr error
|
|
||||||
|
|
||||||
switch chainType {
|
|
||||||
case "ETH":
|
|
||||||
// EVM personal_sign verification of the nonce
|
|
||||||
msg := []byte(req.Nonce)
|
|
||||||
prefix := []byte("\x19Ethereum Signed Message:\n" + strconv.Itoa(len(msg)))
|
|
||||||
hash := ethcrypto.Keccak256(prefix, msg)
|
|
||||||
|
|
||||||
// Decode signature (expects 65-byte r||s||v, hex with optional 0x)
|
|
||||||
sigHex := strings.TrimSpace(req.Signature)
|
|
||||||
if strings.HasPrefix(sigHex, "0x") || strings.HasPrefix(sigHex, "0X") {
|
|
||||||
sigHex = sigHex[2:]
|
|
||||||
}
|
|
||||||
sig, err := hex.DecodeString(sigHex)
|
|
||||||
if err != nil || len(sig) != 65 {
|
|
||||||
writeError(w, http.StatusBadRequest, "invalid signature format")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Normalize V to 0/1 as expected by geth
|
|
||||||
if sig[64] >= 27 {
|
|
||||||
sig[64] -= 27
|
|
||||||
}
|
|
||||||
pub, err := ethcrypto.SigToPub(hash, sig)
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusUnauthorized, "signature recovery failed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
addr := ethcrypto.PubkeyToAddress(*pub).Hex()
|
|
||||||
want := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X"))
|
|
||||||
got := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(addr, "0x"), "0X"))
|
|
||||||
if got != want {
|
|
||||||
writeError(w, http.StatusUnauthorized, "signature does not match wallet")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
verified = true
|
|
||||||
|
|
||||||
case "SOL":
|
|
||||||
// Solana uses Ed25519 signatures
|
|
||||||
// Signature is base64-encoded, public key is the wallet address (base58)
|
|
||||||
|
|
||||||
// Decode base64 signature (Solana signatures are 64 bytes)
|
|
||||||
sig, err := base64.StdEncoding.DecodeString(req.Signature)
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusBadRequest, fmt.Sprintf("invalid base64 signature: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(sig) != 64 {
|
|
||||||
writeError(w, http.StatusBadRequest, fmt.Sprintf("invalid signature length: expected 64 bytes, got %d", len(sig)))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode base58 public key (Solana wallet address)
|
|
||||||
pubKeyBytes, err := base58Decode(req.Wallet)
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusBadRequest, fmt.Sprintf("invalid wallet address: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(pubKeyBytes) != 32 {
|
|
||||||
writeError(w, http.StatusBadRequest, fmt.Sprintf("invalid public key length: expected 32 bytes, got %d", len(pubKeyBytes)))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify Ed25519 signature
|
|
||||||
message := []byte(req.Nonce)
|
|
||||||
if !ed25519.Verify(ed25519.PublicKey(pubKeyBytes), message, sig) {
|
|
||||||
writeError(w, http.StatusUnauthorized, "signature verification failed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
verified = true
|
|
||||||
|
|
||||||
default:
|
|
||||||
writeError(w, http.StatusBadRequest, fmt.Sprintf("unsupported chain type: %s (must be ETH or SOL)", chainType))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !verified {
|
|
||||||
writeError(w, http.StatusUnauthorized, fmt.Sprintf("signature verification failed: %v", verifyErr))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark nonce used now (after successful verification)
|
|
||||||
if _, err := db.Query(internalCtx, "UPDATE nonces SET used_at = datetime('now') WHERE id = ?", nonceID); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if g.signingKey == nil {
|
|
||||||
writeError(w, http.StatusServiceUnavailable, "signing key unavailable")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Issue access token (15m) and a refresh token (30d)
|
|
||||||
token, expUnix, err := g.generateJWT(ns, req.Wallet, 15*time.Minute)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// create refresh token
|
|
||||||
rbuf := make([]byte, 32)
|
|
||||||
if _, err := rand.Read(rbuf); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, "failed to generate refresh token")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
refresh := base64.RawURLEncoding.EncodeToString(rbuf)
|
|
||||||
if _, err := db.Query(internalCtx, "INSERT INTO refresh_tokens(namespace_id, subject, token, audience, expires_at) VALUES (?, ?, ?, ?, datetime('now', '+30 days'))", nsID, req.Wallet, refresh, "gateway"); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure API key exists for this (namespace, wallet) and record ownerships
|
|
||||||
// This is done automatically after successful verification; no second nonce needed
|
|
||||||
var apiKey string
|
|
||||||
|
|
||||||
// Try existing linkage
|
|
||||||
r1, err := db.Query(internalCtx,
|
|
||||||
"SELECT api_keys.key FROM wallet_api_keys JOIN api_keys ON wallet_api_keys.api_key_id = api_keys.id WHERE wallet_api_keys.namespace_id = ? AND LOWER(wallet_api_keys.wallet) = LOWER(?) LIMIT 1",
|
|
||||||
nsID, req.Wallet,
|
|
||||||
)
|
|
||||||
if err == nil && r1 != nil && r1.Count > 0 && len(r1.Rows) > 0 && len(r1.Rows[0]) > 0 {
|
|
||||||
if s, ok := r1.Rows[0][0].(string); ok {
|
|
||||||
apiKey = s
|
|
||||||
} else {
|
|
||||||
b, _ := json.Marshal(r1.Rows[0][0])
|
|
||||||
_ = json.Unmarshal(b, &apiKey)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.TrimSpace(apiKey) == "" {
|
|
||||||
// Create new API key with format ak_<random>:<namespace>
|
|
||||||
buf := make([]byte, 18)
|
|
||||||
if _, err := rand.Read(buf); err == nil {
|
|
||||||
apiKey = "ak_" + base64.RawURLEncoding.EncodeToString(buf) + ":" + ns
|
|
||||||
if _, err := db.Query(internalCtx, "INSERT INTO api_keys(key, name, namespace_id) VALUES (?, ?, ?)", apiKey, "", nsID); err == nil {
|
|
||||||
// Link wallet -> api_key
|
|
||||||
rid, err := db.Query(internalCtx, "SELECT id FROM api_keys WHERE key = ? LIMIT 1", apiKey)
|
|
||||||
if err == nil && rid != nil && rid.Count > 0 && len(rid.Rows) > 0 && len(rid.Rows[0]) > 0 {
|
|
||||||
apiKeyID := rid.Rows[0][0]
|
|
||||||
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO wallet_api_keys(namespace_id, wallet, api_key_id) VALUES (?, ?, ?)", nsID, strings.ToLower(req.Wallet), apiKeyID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Record ownerships (best-effort)
|
|
||||||
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'api_key', ?)", nsID, apiKey)
|
|
||||||
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'wallet', ?)", nsID, req.Wallet)
|
|
||||||
|
|
||||||
writeJSON(w, http.StatusOK, map[string]any{
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
"access_token": token,
|
"access_token": token,
|
||||||
@ -343,23 +147,16 @@ func (g *Gateway) verifyHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
"expires_in": int(expUnix - time.Now().Unix()),
|
"expires_in": int(expUnix - time.Now().Unix()),
|
||||||
"refresh_token": refresh,
|
"refresh_token": refresh,
|
||||||
"subject": req.Wallet,
|
"subject": req.Wallet,
|
||||||
"namespace": ns,
|
"namespace": req.Namespace,
|
||||||
"api_key": apiKey,
|
"api_key": apiKey,
|
||||||
"nonce": req.Nonce,
|
"nonce": req.Nonce,
|
||||||
"signature_verified": true,
|
"signature_verified": true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// issueAPIKeyHandler creates or returns an API key for a verified wallet in a namespace.
|
|
||||||
// Requires: POST { wallet, nonce, signature, namespace }
|
|
||||||
// Behavior:
|
|
||||||
// - Validates nonce and signature like verifyHandler
|
|
||||||
// - Ensures namespace exists
|
|
||||||
// - If an API key already exists for (namespace, wallet), returns it; else creates one
|
|
||||||
// - Records namespace ownership mapping for the wallet and api_key
|
|
||||||
func (g *Gateway) issueAPIKeyHandler(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) issueAPIKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
if g.client == nil {
|
if g.authService == nil {
|
||||||
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
@ -371,6 +168,7 @@ func (g *Gateway) issueAPIKeyHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
Nonce string `json:"nonce"`
|
Nonce string `json:"nonce"`
|
||||||
Signature string `json:"signature"`
|
Signature string `json:"signature"`
|
||||||
Namespace string `json:"namespace"`
|
Namespace string `json:"namespace"`
|
||||||
|
ChainType string `json:"chain_type"`
|
||||||
Plan string `json:"plan"`
|
Plan string `json:"plan"`
|
||||||
}
|
}
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
@ -381,110 +179,33 @@ func (g *Gateway) issueAPIKeyHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeError(w, http.StatusBadRequest, "wallet, nonce and signature are required")
|
writeError(w, http.StatusBadRequest, "wallet, nonce and signature are required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ns := strings.TrimSpace(req.Namespace)
|
|
||||||
if ns == "" {
|
|
||||||
ns = strings.TrimSpace(g.cfg.ClientNamespace)
|
|
||||||
if ns == "" {
|
|
||||||
ns = "default"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
// Use internal context to bypass authentication for system operations
|
verified, err := g.authService.VerifySignature(ctx, req.Wallet, req.Nonce, req.Signature, req.ChainType)
|
||||||
internalCtx := client.WithInternalAuth(ctx)
|
if err != nil || !verified {
|
||||||
db := g.client.Database()
|
writeError(w, http.StatusUnauthorized, "signature verification failed")
|
||||||
// Resolve namespace id
|
|
||||||
nsID, err := g.resolveNamespaceID(ctx, ns)
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Validate nonce exists and not used/expired
|
|
||||||
// Normalize wallet address to lowercase for case-insensitive comparison
|
|
||||||
walletLower := strings.ToLower(strings.TrimSpace(req.Wallet))
|
|
||||||
q := "SELECT id FROM nonces WHERE namespace_id = ? AND LOWER(wallet) = LOWER(?) AND nonce = ? AND used_at IS NULL AND (expires_at IS NULL OR expires_at > datetime('now')) LIMIT 1"
|
|
||||||
nres, err := db.Query(internalCtx, q, nsID, walletLower, req.Nonce)
|
|
||||||
if err != nil || nres == nil || nres.Count == 0 {
|
|
||||||
writeError(w, http.StatusBadRequest, "invalid or expired nonce")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
nonceID := nres.Rows[0][0]
|
|
||||||
// Verify signature like verifyHandler
|
|
||||||
msg := []byte(req.Nonce)
|
|
||||||
prefix := []byte("\x19Ethereum Signed Message:\n" + strconv.Itoa(len(msg)))
|
|
||||||
hash := ethcrypto.Keccak256(prefix, msg)
|
|
||||||
sigHex := strings.TrimSpace(req.Signature)
|
|
||||||
if strings.HasPrefix(sigHex, "0x") || strings.HasPrefix(sigHex, "0X") {
|
|
||||||
sigHex = sigHex[2:]
|
|
||||||
}
|
|
||||||
sig, err := hex.DecodeString(sigHex)
|
|
||||||
if err != nil || len(sig) != 65 {
|
|
||||||
writeError(w, http.StatusBadRequest, "invalid signature format")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if sig[64] >= 27 {
|
|
||||||
sig[64] -= 27
|
|
||||||
}
|
|
||||||
pub, err := ethcrypto.SigToPub(hash, sig)
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusUnauthorized, "signature recovery failed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
addr := ethcrypto.PubkeyToAddress(*pub).Hex()
|
|
||||||
want := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X"))
|
|
||||||
got := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(addr, "0x"), "0X"))
|
|
||||||
if got != want {
|
|
||||||
writeError(w, http.StatusUnauthorized, "signature does not match wallet")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark nonce used
|
// Mark nonce used
|
||||||
if _, err := db.Query(internalCtx, "UPDATE nonces SET used_at = datetime('now') WHERE id = ?", nonceID); err != nil {
|
nsID, _ := g.authService.ResolveNamespaceID(ctx, req.Namespace)
|
||||||
|
db := g.client.Database()
|
||||||
|
_, _ = db.Query(client.WithInternalAuth(ctx), "UPDATE nonces SET used_at = datetime('now') WHERE namespace_id = ? AND wallet = ? AND nonce = ?", nsID, strings.ToLower(req.Wallet), req.Nonce)
|
||||||
|
|
||||||
|
apiKey, err := g.authService.GetOrCreateAPIKey(ctx, req.Wallet, req.Namespace)
|
||||||
|
if err != nil {
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Check if api key exists for (namespace, wallet) via linkage table
|
|
||||||
var apiKey string
|
|
||||||
r1, err := db.Query(internalCtx, "SELECT api_keys.key FROM wallet_api_keys JOIN api_keys ON wallet_api_keys.api_key_id = api_keys.id WHERE wallet_api_keys.namespace_id = ? AND LOWER(wallet_api_keys.wallet) = LOWER(?) LIMIT 1", nsID, req.Wallet)
|
|
||||||
if err == nil && r1 != nil && r1.Count > 0 && len(r1.Rows) > 0 && len(r1.Rows[0]) > 0 {
|
|
||||||
if s, ok := r1.Rows[0][0].(string); ok {
|
|
||||||
apiKey = s
|
|
||||||
} else {
|
|
||||||
b, _ := json.Marshal(r1.Rows[0][0])
|
|
||||||
_ = json.Unmarshal(b, &apiKey)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(apiKey) == "" {
|
|
||||||
// Create new API key with format ak_<random>:<namespace>
|
|
||||||
buf := make([]byte, 18)
|
|
||||||
if _, err := rand.Read(buf); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, "failed to generate api key")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
apiKey = "ak_" + base64.RawURLEncoding.EncodeToString(buf) + ":" + ns
|
|
||||||
if _, err := db.Query(internalCtx, "INSERT INTO api_keys(key, name, namespace_id) VALUES (?, ?, ?)", apiKey, "", nsID); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Create linkage
|
|
||||||
// Find api_key id
|
|
||||||
rid, err := db.Query(internalCtx, "SELECT id FROM api_keys WHERE key = ? LIMIT 1", apiKey)
|
|
||||||
if err == nil && rid != nil && rid.Count > 0 && len(rid.Rows) > 0 && len(rid.Rows[0]) > 0 {
|
|
||||||
apiKeyID := rid.Rows[0][0]
|
|
||||||
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO wallet_api_keys(namespace_id, wallet, api_key_id) VALUES (?, ?, ?)", nsID, strings.ToLower(req.Wallet), apiKeyID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Record ownerships (best-effort)
|
|
||||||
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'api_key', ?)", nsID, apiKey)
|
|
||||||
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'wallet', ?)", nsID, req.Wallet)
|
|
||||||
|
|
||||||
writeJSON(w, http.StatusOK, map[string]any{
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
"api_key": apiKey,
|
"api_key": apiKey,
|
||||||
"namespace": ns,
|
"namespace": req.Namespace,
|
||||||
"plan": func() string {
|
"plan": func() string {
|
||||||
if strings.TrimSpace(req.Plan) == "" {
|
if strings.TrimSpace(req.Plan) == "" {
|
||||||
return "free"
|
return "free"
|
||||||
} else {
|
|
||||||
return req.Plan
|
|
||||||
}
|
}
|
||||||
|
return req.Plan
|
||||||
}(),
|
}(),
|
||||||
"wallet": strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X")),
|
"wallet": strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X")),
|
||||||
})
|
})
|
||||||
@ -494,8 +215,8 @@ func (g *Gateway) issueAPIKeyHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Requires Authorization header with API key (Bearer or ApiKey or X-API-Key header).
|
// Requires Authorization header with API key (Bearer or ApiKey or X-API-Key header).
|
||||||
// Returns a JWT bound to the namespace derived from the API key record.
|
// Returns a JWT bound to the namespace derived from the API key record.
|
||||||
func (g *Gateway) apiKeyToJWTHandler(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) apiKeyToJWTHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
if g.client == nil {
|
if g.authService == nil {
|
||||||
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
@ -507,10 +228,10 @@ func (g *Gateway) apiKeyToJWTHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeError(w, http.StatusUnauthorized, "missing API key")
|
writeError(w, http.StatusUnauthorized, "missing API key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate and get namespace
|
// Validate and get namespace
|
||||||
db := g.client.Database()
|
db := g.client.Database()
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
// Use internal context to bypass authentication for system operations
|
|
||||||
internalCtx := client.WithInternalAuth(ctx)
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
q := "SELECT namespaces.name FROM api_keys JOIN namespaces ON api_keys.namespace_id = namespaces.id WHERE api_keys.key = ? LIMIT 1"
|
q := "SELECT namespaces.name FROM api_keys JOIN namespaces ON api_keys.namespace_id = namespaces.id WHERE api_keys.key = ? LIMIT 1"
|
||||||
res, err := db.Query(internalCtx, q, key)
|
res, err := db.Query(internalCtx, q, key)
|
||||||
@ -518,28 +239,18 @@ func (g *Gateway) apiKeyToJWTHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeError(w, http.StatusUnauthorized, "invalid API key")
|
writeError(w, http.StatusUnauthorized, "invalid API key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var ns string
|
var ns string
|
||||||
if s, ok := res.Rows[0][0].(string); ok {
|
if s, ok := res.Rows[0][0].(string); ok {
|
||||||
ns = s
|
ns = s
|
||||||
} else {
|
|
||||||
b, _ := json.Marshal(res.Rows[0][0])
|
|
||||||
_ = json.Unmarshal(b, &ns)
|
|
||||||
}
|
}
|
||||||
ns = strings.TrimSpace(ns)
|
|
||||||
if ns == "" {
|
token, expUnix, err := g.authService.GenerateJWT(ns, key, 15*time.Minute)
|
||||||
writeError(w, http.StatusUnauthorized, "invalid API key")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if g.signingKey == nil {
|
|
||||||
writeError(w, http.StatusServiceUnavailable, "signing key unavailable")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Subject is the API key string for now
|
|
||||||
token, expUnix, err := g.generateJWT(ns, key, 15*time.Minute)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSON(w, http.StatusOK, map[string]any{
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
"access_token": token,
|
"access_token": token,
|
||||||
"token_type": "Bearer",
|
"token_type": "Bearer",
|
||||||
@ -549,8 +260,8 @@ func (g *Gateway) apiKeyToJWTHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) registerHandler(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) registerHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
if g.client == nil {
|
if g.authService == nil {
|
||||||
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
@ -562,6 +273,7 @@ func (g *Gateway) registerHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
Nonce string `json:"nonce"`
|
Nonce string `json:"nonce"`
|
||||||
Signature string `json:"signature"`
|
Signature string `json:"signature"`
|
||||||
Namespace string `json:"namespace"`
|
Namespace string `json:"namespace"`
|
||||||
|
ChainType string `json:"chain_type"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
@ -572,106 +284,45 @@ func (g *Gateway) registerHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeError(w, http.StatusBadRequest, "wallet, nonce and signature are required")
|
writeError(w, http.StatusBadRequest, "wallet, nonce and signature are required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ns := strings.TrimSpace(req.Namespace)
|
|
||||||
if ns == "" {
|
|
||||||
ns = strings.TrimSpace(g.cfg.ClientNamespace)
|
|
||||||
if ns == "" {
|
|
||||||
ns = "default"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
// Use internal context to bypass authentication for system operations
|
verified, err := g.authService.VerifySignature(ctx, req.Wallet, req.Nonce, req.Signature, req.ChainType)
|
||||||
internalCtx := client.WithInternalAuth(ctx)
|
if err != nil || !verified {
|
||||||
|
writeError(w, http.StatusUnauthorized, "signature verification failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark nonce used
|
||||||
|
nsID, _ := g.authService.ResolveNamespaceID(ctx, req.Namespace)
|
||||||
db := g.client.Database()
|
db := g.client.Database()
|
||||||
nsID, err := g.resolveNamespaceID(ctx, ns)
|
_, _ = db.Query(client.WithInternalAuth(ctx), "UPDATE nonces SET used_at = datetime('now') WHERE namespace_id = ? AND wallet = ? AND nonce = ?", nsID, strings.ToLower(req.Wallet), req.Nonce)
|
||||||
|
|
||||||
|
// In a real app we'd derive the public key from the signature, but for simplicity here
|
||||||
|
// we just use a placeholder or expect it in the request if needed.
|
||||||
|
// For Ethereum, we can recover it.
|
||||||
|
publicKey := "recovered-pk"
|
||||||
|
|
||||||
|
appID, err := g.authService.RegisterApp(ctx, req.Wallet, req.Namespace, req.Name, publicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Validate nonce
|
|
||||||
q := "SELECT id FROM nonces WHERE namespace_id = ? AND wallet = ? AND nonce = ? AND used_at IS NULL AND (expires_at IS NULL OR expires_at > datetime('now')) LIMIT 1"
|
|
||||||
nres, err := db.Query(internalCtx, q, nsID, req.Wallet, req.Nonce)
|
|
||||||
if err != nil || nres == nil || nres.Count == 0 || len(nres.Rows) == 0 || len(nres.Rows[0]) == 0 {
|
|
||||||
writeError(w, http.StatusBadRequest, "invalid or expired nonce")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
nonceID := nres.Rows[0][0]
|
|
||||||
|
|
||||||
// EVM personal_sign verification of the nonce
|
|
||||||
msg := []byte(req.Nonce)
|
|
||||||
prefix := []byte("\x19Ethereum Signed Message:\n" + strconv.Itoa(len(msg)))
|
|
||||||
hash := ethcrypto.Keccak256(prefix, msg)
|
|
||||||
|
|
||||||
// Decode signature (expects 65-byte r||s||v, hex with optional 0x)
|
|
||||||
sigHex := strings.TrimSpace(req.Signature)
|
|
||||||
if strings.HasPrefix(sigHex, "0x") || strings.HasPrefix(sigHex, "0X") {
|
|
||||||
sigHex = sigHex[2:]
|
|
||||||
}
|
|
||||||
sig, err := hex.DecodeString(sigHex)
|
|
||||||
if err != nil || len(sig) != 65 {
|
|
||||||
writeError(w, http.StatusBadRequest, "invalid signature format")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Normalize V to 0/1 as expected by geth
|
|
||||||
if sig[64] >= 27 {
|
|
||||||
sig[64] -= 27
|
|
||||||
}
|
|
||||||
pub, err := ethcrypto.SigToPub(hash, sig)
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusUnauthorized, "signature recovery failed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
addr := ethcrypto.PubkeyToAddress(*pub).Hex()
|
|
||||||
want := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X"))
|
|
||||||
got := strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(addr, "0x"), "0X"))
|
|
||||||
if got != want {
|
|
||||||
writeError(w, http.StatusUnauthorized, "signature does not match wallet")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark nonce used now (after successful verification)
|
|
||||||
if _, err := db.Query(internalCtx, "UPDATE nonces SET used_at = datetime('now') WHERE id = ?", nonceID); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Derive public key (uncompressed) hex
|
|
||||||
pubBytes := ethcrypto.FromECDSAPub(pub)
|
|
||||||
pubHex := "0x" + hex.EncodeToString(pubBytes)
|
|
||||||
|
|
||||||
// Generate client app_id
|
|
||||||
buf := make([]byte, 12)
|
|
||||||
if _, err := rand.Read(buf); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, "failed to generate app id")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
appID := "app_" + base64.RawURLEncoding.EncodeToString(buf)
|
|
||||||
|
|
||||||
// Persist app
|
|
||||||
if _, err := db.Query(internalCtx, "INSERT INTO apps(namespace_id, app_id, name, public_key) VALUES (?, ?, ?, ?)", nsID, appID, req.Name, pubHex); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Record namespace ownership by wallet (best-effort)
|
|
||||||
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, ?, ?)", nsID, "wallet", req.Wallet)
|
|
||||||
|
|
||||||
writeJSON(w, http.StatusCreated, map[string]any{
|
writeJSON(w, http.StatusCreated, map[string]any{
|
||||||
"client_id": appID,
|
"client_id": appID,
|
||||||
"app": map[string]any{
|
"app": map[string]any{
|
||||||
"app_id": appID,
|
"app_id": appID,
|
||||||
"name": req.Name,
|
"name": req.Name,
|
||||||
"public_key": pubHex,
|
"namespace": req.Namespace,
|
||||||
"namespace": ns,
|
"wallet": strings.ToLower(req.Wallet),
|
||||||
"wallet": strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X")),
|
|
||||||
},
|
},
|
||||||
"signature_verified": true,
|
"signature_verified": true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) refreshHandler(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) refreshHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
if g.client == nil {
|
if g.authService == nil {
|
||||||
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
@ -690,54 +341,20 @@ func (g *Gateway) refreshHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeError(w, http.StatusBadRequest, "refresh_token is required")
|
writeError(w, http.StatusBadRequest, "refresh_token is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ns := strings.TrimSpace(req.Namespace)
|
|
||||||
if ns == "" {
|
token, subject, expUnix, err := g.authService.RefreshToken(r.Context(), req.RefreshToken, req.Namespace)
|
||||||
ns = strings.TrimSpace(g.cfg.ClientNamespace)
|
|
||||||
if ns == "" {
|
|
||||||
ns = "default"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ctx := r.Context()
|
|
||||||
// Use internal context to bypass authentication for system operations
|
|
||||||
internalCtx := client.WithInternalAuth(ctx)
|
|
||||||
db := g.client.Database()
|
|
||||||
nsID, err := g.resolveNamespaceID(ctx, ns)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
writeError(w, http.StatusUnauthorized, err.Error())
|
||||||
return
|
|
||||||
}
|
|
||||||
q := "SELECT subject FROM refresh_tokens WHERE namespace_id = ? AND token = ? AND revoked_at IS NULL AND (expires_at IS NULL OR expires_at > datetime('now')) LIMIT 1"
|
|
||||||
rres, err := db.Query(internalCtx, q, nsID, req.RefreshToken)
|
|
||||||
if err != nil || rres == nil || rres.Count == 0 {
|
|
||||||
writeError(w, http.StatusUnauthorized, "invalid or expired refresh token")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
subject := ""
|
|
||||||
if len(rres.Rows) > 0 && len(rres.Rows[0]) > 0 {
|
|
||||||
if s, ok := rres.Rows[0][0].(string); ok {
|
|
||||||
subject = s
|
|
||||||
} else {
|
|
||||||
// fallback: format via json
|
|
||||||
b, _ := json.Marshal(rres.Rows[0][0])
|
|
||||||
_ = json.Unmarshal(b, &subject)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if g.signingKey == nil {
|
|
||||||
writeError(w, http.StatusServiceUnavailable, "signing key unavailable")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
token, expUnix, err := g.generateJWT(ns, subject, 15*time.Minute)
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSON(w, http.StatusOK, map[string]any{
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
"access_token": token,
|
"access_token": token,
|
||||||
"token_type": "Bearer",
|
"token_type": "Bearer",
|
||||||
"expires_in": int(expUnix - time.Now().Unix()),
|
"expires_in": int(expUnix - time.Now().Unix()),
|
||||||
"refresh_token": req.RefreshToken,
|
"refresh_token": req.RefreshToken,
|
||||||
"subject": subject,
|
"subject": subject,
|
||||||
"namespace": ns,
|
"namespace": req.Namespace,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1064,8 +681,8 @@ func (g *Gateway) loginPageHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
// be revoked. If all=true is provided (and the request is authenticated via JWT),
|
// be revoked. If all=true is provided (and the request is authenticated via JWT),
|
||||||
// all tokens for the JWT subject within the namespace are revoked.
|
// all tokens for the JWT subject within the namespace are revoked.
|
||||||
func (g *Gateway) logoutHandler(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) logoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
if g.client == nil {
|
if g.authService == nil {
|
||||||
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
@ -1081,38 +698,12 @@ func (g *Gateway) logoutHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeError(w, http.StatusBadRequest, "invalid json body")
|
writeError(w, http.StatusBadRequest, "invalid json body")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ns := strings.TrimSpace(req.Namespace)
|
|
||||||
if ns == "" {
|
|
||||||
ns = strings.TrimSpace(g.cfg.ClientNamespace)
|
|
||||||
if ns == "" {
|
|
||||||
ns = "default"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
// Use internal context to bypass authentication for system operations
|
var subject string
|
||||||
internalCtx := client.WithInternalAuth(ctx)
|
|
||||||
db := g.client.Database()
|
|
||||||
nsID, err := g.resolveNamespaceID(ctx, ns)
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.TrimSpace(req.RefreshToken) != "" {
|
|
||||||
// Revoke specific token
|
|
||||||
if _, err := db.Query(internalCtx, "UPDATE refresh_tokens SET revoked_at = datetime('now') WHERE namespace_id = ? AND token = ? AND revoked_at IS NULL", nsID, req.RefreshToken); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"status": "ok", "revoked": 1})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.All {
|
if req.All {
|
||||||
// Require JWT to identify subject
|
|
||||||
var subject string
|
|
||||||
if v := ctx.Value(ctxKeyJWT); v != nil {
|
if v := ctx.Value(ctxKeyJWT); v != nil {
|
||||||
if claims, ok := v.(*jwtClaims); ok && claims != nil {
|
if claims, ok := v.(*auth.JWTClaims); ok && claims != nil {
|
||||||
subject = strings.TrimSpace(claims.Sub)
|
subject = strings.TrimSpace(claims.Sub)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1120,23 +711,19 @@ func (g *Gateway) logoutHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeError(w, http.StatusUnauthorized, "jwt required for all=true")
|
writeError(w, http.StatusUnauthorized, "jwt required for all=true")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, err := db.Query(internalCtx, "UPDATE refresh_tokens SET revoked_at = datetime('now') WHERE namespace_id = ? AND subject = ? AND revoked_at IS NULL", nsID, subject); err != nil {
|
}
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
if err := g.authService.RevokeToken(ctx, req.Namespace, req.RefreshToken, req.All, subject); err != nil {
|
||||||
}
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"status": "ok", "revoked": "all"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeError(w, http.StatusBadRequest, "nothing to revoke: provide refresh_token or all=true")
|
writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// simpleAPIKeyHandler creates an API key directly from a wallet address without signature verification
|
|
||||||
// This is a simplified flow for development/testing
|
|
||||||
// Requires: POST { wallet, namespace }
|
|
||||||
func (g *Gateway) simpleAPIKeyHandler(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) simpleAPIKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
if g.client == nil {
|
if g.authService == nil {
|
||||||
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
@ -1159,114 +746,16 @@ func (g *Gateway) simpleAPIKeyHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ns := strings.TrimSpace(req.Namespace)
|
apiKey, err := g.authService.GetOrCreateAPIKey(r.Context(), req.Wallet, req.Namespace)
|
||||||
if ns == "" {
|
if err != nil {
|
||||||
ns = strings.TrimSpace(g.cfg.ClientNamespace)
|
|
||||||
if ns == "" {
|
|
||||||
ns = "default"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
internalCtx := client.WithInternalAuth(ctx)
|
|
||||||
db := g.client.Database()
|
|
||||||
|
|
||||||
// Resolve or create namespace
|
|
||||||
if _, err := db.Query(internalCtx, "INSERT OR IGNORE INTO namespaces(name) VALUES (?)", ns); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nres, err := db.Query(internalCtx, "SELECT id FROM namespaces WHERE name = ? LIMIT 1", ns)
|
|
||||||
if err != nil || nres == nil || nres.Count == 0 || len(nres.Rows) == 0 || len(nres.Rows[0]) == 0 {
|
|
||||||
writeError(w, http.StatusInternalServerError, "failed to resolve namespace")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
nsID := nres.Rows[0][0]
|
|
||||||
|
|
||||||
// Check if api key already exists for (namespace, wallet)
|
|
||||||
var apiKey string
|
|
||||||
r1, err := db.Query(internalCtx,
|
|
||||||
"SELECT api_keys.key FROM wallet_api_keys JOIN api_keys ON wallet_api_keys.api_key_id = api_keys.id WHERE wallet_api_keys.namespace_id = ? AND LOWER(wallet_api_keys.wallet) = LOWER(?) LIMIT 1",
|
|
||||||
nsID, req.Wallet,
|
|
||||||
)
|
|
||||||
if err == nil && r1 != nil && r1.Count > 0 && len(r1.Rows) > 0 && len(r1.Rows[0]) > 0 {
|
|
||||||
if s, ok := r1.Rows[0][0].(string); ok {
|
|
||||||
apiKey = s
|
|
||||||
} else {
|
|
||||||
b, _ := json.Marshal(r1.Rows[0][0])
|
|
||||||
_ = json.Unmarshal(b, &apiKey)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no existing key, create a new one
|
|
||||||
if strings.TrimSpace(apiKey) == "" {
|
|
||||||
buf := make([]byte, 18)
|
|
||||||
if _, err := rand.Read(buf); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, "failed to generate api key")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
apiKey = "ak_" + base64.RawURLEncoding.EncodeToString(buf) + ":" + ns
|
|
||||||
|
|
||||||
if _, err := db.Query(internalCtx, "INSERT INTO api_keys(key, name, namespace_id) VALUES (?, ?, ?)", apiKey, "", nsID); err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Link wallet to api key
|
|
||||||
rid, err := db.Query(internalCtx, "SELECT id FROM api_keys WHERE key = ? LIMIT 1", apiKey)
|
|
||||||
if err == nil && rid != nil && rid.Count > 0 && len(rid.Rows) > 0 && len(rid.Rows[0]) > 0 {
|
|
||||||
apiKeyID := rid.Rows[0][0]
|
|
||||||
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO wallet_api_keys(namespace_id, wallet, api_key_id) VALUES (?, ?, ?)", nsID, strings.ToLower(req.Wallet), apiKeyID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Record ownerships (best-effort)
|
|
||||||
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'api_key', ?)", nsID, apiKey)
|
|
||||||
_, _ = db.Query(internalCtx, "INSERT OR IGNORE INTO namespace_ownership(namespace_id, owner_type, owner_id) VALUES (?, 'wallet', ?)", nsID, req.Wallet)
|
|
||||||
|
|
||||||
writeJSON(w, http.StatusOK, map[string]any{
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
"api_key": apiKey,
|
"api_key": apiKey,
|
||||||
"namespace": ns,
|
"namespace": req.Namespace,
|
||||||
"wallet": strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X")),
|
"wallet": strings.ToLower(strings.TrimPrefix(strings.TrimPrefix(req.Wallet, "0x"), "0X")),
|
||||||
"created": time.Now().Format(time.RFC3339),
|
"created": time.Now().Format(time.RFC3339),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// base58Decode decodes a base58-encoded string (Bitcoin alphabet)
|
|
||||||
// Used for decoding Solana public keys (base58-encoded 32-byte ed25519 public keys)
|
|
||||||
func base58Decode(encoded string) ([]byte, error) {
|
|
||||||
const alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
|
|
||||||
|
|
||||||
// Build reverse lookup map
|
|
||||||
lookup := make(map[rune]int)
|
|
||||||
for i, c := range alphabet {
|
|
||||||
lookup[c] = i
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to big integer
|
|
||||||
num := big.NewInt(0)
|
|
||||||
base := big.NewInt(58)
|
|
||||||
|
|
||||||
for _, c := range encoded {
|
|
||||||
val, ok := lookup[c]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("invalid base58 character: %c", c)
|
|
||||||
}
|
|
||||||
num.Mul(num, base)
|
|
||||||
num.Add(num, big.NewInt(int64(val)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to bytes
|
|
||||||
decoded := num.Bytes()
|
|
||||||
|
|
||||||
// Add leading zeros for each leading '1' in the input
|
|
||||||
for _, c := range encoded {
|
|
||||||
if c != '1' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
decoded = append([]byte{0}, decoded...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return decoded, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@ -4,23 +4,28 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/client"
|
"github.com/DeBrosOfficial/network/pkg/client"
|
||||||
"github.com/DeBrosOfficial/network/pkg/config"
|
"github.com/DeBrosOfficial/network/pkg/config"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||||
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
||||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
"github.com/DeBrosOfficial/network/pkg/olric"
|
"github.com/DeBrosOfficial/network/pkg/olric"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||||
"github.com/multiformats/go-multiaddr"
|
"github.com/multiformats/go-multiaddr"
|
||||||
|
olriclib "github.com/olric-data/olric"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
_ "github.com/rqlite/gorqlite/stdlib"
|
_ "github.com/rqlite/gorqlite/stdlib"
|
||||||
@ -61,13 +66,11 @@ type Config struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Gateway struct {
|
type Gateway struct {
|
||||||
logger *logging.ColoredLogger
|
logger *logging.ColoredLogger
|
||||||
cfg *Config
|
cfg *Config
|
||||||
client client.NetworkClient
|
client client.NetworkClient
|
||||||
nodePeerID string // The node's actual peer ID from its identity file (overrides client's peer ID)
|
nodePeerID string // The node's actual peer ID from its identity file (overrides client's peer ID)
|
||||||
startedAt time.Time
|
startedAt time.Time
|
||||||
signingKey *rsa.PrivateKey
|
|
||||||
keyID string
|
|
||||||
|
|
||||||
// rqlite SQL connection and HTTP ORM gateway
|
// rqlite SQL connection and HTTP ORM gateway
|
||||||
sqlDB *sql.DB
|
sqlDB *sql.DB
|
||||||
@ -83,7 +86,19 @@ type Gateway struct {
|
|||||||
|
|
||||||
// Local pub/sub bypass for same-gateway subscribers
|
// Local pub/sub bypass for same-gateway subscribers
|
||||||
localSubscribers map[string][]*localSubscriber // topic+namespace -> subscribers
|
localSubscribers map[string][]*localSubscriber // topic+namespace -> subscribers
|
||||||
|
presenceMembers map[string][]PresenceMember // topicKey -> members
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
|
presenceMu sync.RWMutex
|
||||||
|
|
||||||
|
// Serverless function engine
|
||||||
|
serverlessEngine *serverless.Engine
|
||||||
|
serverlessRegistry *serverless.Registry
|
||||||
|
serverlessInvoker *serverless.Invoker
|
||||||
|
serverlessWSMgr *serverless.WSManager
|
||||||
|
serverlessHandlers *ServerlessHandlers
|
||||||
|
|
||||||
|
// Authentication service
|
||||||
|
authService *auth.Service
|
||||||
}
|
}
|
||||||
|
|
||||||
// localSubscriber represents a WebSocket subscriber for local message delivery
|
// localSubscriber represents a WebSocket subscriber for local message delivery
|
||||||
@ -92,6 +107,14 @@ type localSubscriber struct {
|
|||||||
namespace string
|
namespace string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PresenceMember represents a member in a topic's presence list
|
||||||
|
type PresenceMember struct {
|
||||||
|
MemberID string `json:"member_id"`
|
||||||
|
JoinedAt int64 `json:"joined_at"` // Unix timestamp
|
||||||
|
Meta map[string]interface{} `json:"meta,omitempty"`
|
||||||
|
ConnID string `json:"-"` // Internal: for tracking which connection
|
||||||
|
}
|
||||||
|
|
||||||
// New creates and initializes a new Gateway instance
|
// New creates and initializes a new Gateway instance
|
||||||
func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
|
func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
|
||||||
logger.ComponentInfo(logging.ComponentGeneral, "Building client config...")
|
logger.ComponentInfo(logging.ComponentGeneral, "Building client config...")
|
||||||
@ -128,16 +151,7 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
|
|||||||
nodePeerID: cfg.NodePeerID,
|
nodePeerID: cfg.NodePeerID,
|
||||||
startedAt: time.Now(),
|
startedAt: time.Now(),
|
||||||
localSubscribers: make(map[string][]*localSubscriber),
|
localSubscribers: make(map[string][]*localSubscriber),
|
||||||
}
|
presenceMembers: make(map[string][]PresenceMember),
|
||||||
|
|
||||||
logger.ComponentInfo(logging.ComponentGeneral, "Generating RSA signing key...")
|
|
||||||
// Generate local RSA signing key for JWKS/JWT (ephemeral for now)
|
|
||||||
if key, err := rsa.GenerateKey(rand.Reader, 2048); err == nil {
|
|
||||||
gw.signingKey = key
|
|
||||||
gw.keyID = "gw-" + strconv.FormatInt(time.Now().Unix(), 10)
|
|
||||||
logger.ComponentInfo(logging.ComponentGeneral, "RSA key generated successfully")
|
|
||||||
} else {
|
|
||||||
logger.ComponentWarn(logging.ComponentGeneral, "failed to generate RSA key; jwks will be empty", zap.Error(err))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.ComponentInfo(logging.ComponentGeneral, "Initializing RQLite ORM HTTP gateway...")
|
logger.ComponentInfo(logging.ComponentGeneral, "Initializing RQLite ORM HTTP gateway...")
|
||||||
@ -298,6 +312,104 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
|
|||||||
gw.cfg.IPFSReplicationFactor = ipfsReplicationFactor
|
gw.cfg.IPFSReplicationFactor = ipfsReplicationFactor
|
||||||
gw.cfg.IPFSEnableEncryption = ipfsEnableEncryption
|
gw.cfg.IPFSEnableEncryption = ipfsEnableEncryption
|
||||||
|
|
||||||
|
// Initialize serverless function engine
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Initializing serverless function engine...")
|
||||||
|
if gw.ormClient != nil && gw.ipfsClient != nil {
|
||||||
|
// Create serverless registry (stores functions in RQLite + IPFS)
|
||||||
|
registryCfg := serverless.RegistryConfig{
|
||||||
|
IPFSAPIURL: ipfsAPIURL,
|
||||||
|
}
|
||||||
|
registry := serverless.NewRegistry(gw.ormClient, gw.ipfsClient, registryCfg, logger.Logger)
|
||||||
|
gw.serverlessRegistry = registry
|
||||||
|
|
||||||
|
// Create WebSocket manager for function streaming
|
||||||
|
gw.serverlessWSMgr = serverless.NewWSManager(logger.Logger)
|
||||||
|
|
||||||
|
// Get underlying Olric client if available
|
||||||
|
var olricClient olriclib.Client
|
||||||
|
if oc := gw.getOlricClient(); oc != nil {
|
||||||
|
olricClient = oc.UnderlyingClient()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create host functions provider (allows functions to call Orama services)
|
||||||
|
// Get pubsub adapter from client for serverless functions
|
||||||
|
var pubsubAdapter *pubsub.ClientAdapter
|
||||||
|
if gw.client != nil {
|
||||||
|
if concreteClient, ok := gw.client.(*client.Client); ok {
|
||||||
|
pubsubAdapter = concreteClient.PubSubAdapter()
|
||||||
|
if pubsubAdapter != nil {
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "pubsub adapter available for serverless functions")
|
||||||
|
} else {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "pubsub adapter is nil - serverless pubsub will be unavailable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hostFuncsCfg := serverless.HostFunctionsConfig{
|
||||||
|
IPFSAPIURL: ipfsAPIURL,
|
||||||
|
HTTPTimeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
hostFuncs := serverless.NewHostFunctions(
|
||||||
|
gw.ormClient,
|
||||||
|
olricClient,
|
||||||
|
gw.ipfsClient,
|
||||||
|
pubsubAdapter, // pubsub adapter for serverless functions
|
||||||
|
gw.serverlessWSMgr,
|
||||||
|
nil, // secrets manager - TODO: implement
|
||||||
|
hostFuncsCfg,
|
||||||
|
logger.Logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create WASM engine configuration
|
||||||
|
engineCfg := serverless.DefaultConfig()
|
||||||
|
engineCfg.DefaultMemoryLimitMB = 128
|
||||||
|
engineCfg.MaxMemoryLimitMB = 256
|
||||||
|
engineCfg.DefaultTimeoutSeconds = 30
|
||||||
|
engineCfg.MaxTimeoutSeconds = 60
|
||||||
|
engineCfg.ModuleCacheSize = 100
|
||||||
|
|
||||||
|
// Create WASM engine
|
||||||
|
engine, engineErr := serverless.NewEngine(engineCfg, registry, hostFuncs, logger.Logger, serverless.WithInvocationLogger(registry))
|
||||||
|
if engineErr != nil {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize serverless engine; functions disabled", zap.Error(engineErr))
|
||||||
|
} else {
|
||||||
|
gw.serverlessEngine = engine
|
||||||
|
|
||||||
|
// Create invoker
|
||||||
|
gw.serverlessInvoker = serverless.NewInvoker(engine, registry, hostFuncs, logger.Logger)
|
||||||
|
|
||||||
|
// Create HTTP handlers
|
||||||
|
gw.serverlessHandlers = NewServerlessHandlers(
|
||||||
|
gw.serverlessInvoker,
|
||||||
|
registry,
|
||||||
|
gw.serverlessWSMgr,
|
||||||
|
logger.Logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Initialize auth service
|
||||||
|
// For now using ephemeral key, can be loaded from config later
|
||||||
|
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||||
|
})
|
||||||
|
authService, err := auth.NewService(logger, c, string(keyPEM), cfg.ClientNamespace)
|
||||||
|
if err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentGeneral, "failed to initialize auth service", zap.Error(err))
|
||||||
|
} else {
|
||||||
|
gw.authService = authService
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Serverless function engine ready",
|
||||||
|
zap.Int("default_memory_mb", engineCfg.DefaultMemoryLimitMB),
|
||||||
|
zap.Int("default_timeout_sec", engineCfg.DefaultTimeoutSeconds),
|
||||||
|
zap.Int("module_cache_size", engineCfg.ModuleCacheSize),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "serverless engine requires RQLite and IPFS; functions disabled")
|
||||||
|
}
|
||||||
|
|
||||||
logger.ComponentInfo(logging.ComponentGeneral, "Gateway creation completed, returning...")
|
logger.ComponentInfo(logging.ComponentGeneral, "Gateway creation completed, returning...")
|
||||||
return gw, nil
|
return gw, nil
|
||||||
}
|
}
|
||||||
@ -309,6 +421,14 @@ func (g *Gateway) withInternalAuth(ctx context.Context) context.Context {
|
|||||||
|
|
||||||
// Close disconnects the gateway client
|
// Close disconnects the gateway client
|
||||||
func (g *Gateway) Close() {
|
func (g *Gateway) Close() {
|
||||||
|
// Close serverless engine first
|
||||||
|
if g.serverlessEngine != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
if err := g.serverlessEngine.Close(ctx); err != nil {
|
||||||
|
g.logger.ComponentWarn(logging.ComponentGeneral, "error during serverless engine close", zap.Error(err))
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
if g.client != nil {
|
if g.client != nil {
|
||||||
if err := g.client.Disconnect(); err != nil {
|
if err := g.client.Disconnect(); err != nil {
|
||||||
g.logger.ComponentWarn(logging.ComponentClient, "error during client disconnect", zap.Error(err))
|
g.logger.ComponentWarn(logging.ComponentClient, "error during client disconnect", zap.Error(err))
|
||||||
|
|||||||
@ -3,22 +3,32 @@ package gateway
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestJWTGenerateAndParse(t *testing.T) {
|
func TestJWTGenerateAndParse(t *testing.T) {
|
||||||
gw := &Gateway{}
|
|
||||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
gw.signingKey = key
|
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
gw.keyID = "kid"
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||||
|
})
|
||||||
|
|
||||||
tok, exp, err := gw.generateJWT("ns1", "subj", time.Minute)
|
svc, err := auth.NewService(nil, nil, string(keyPEM), "default")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tok, exp, err := svc.GenerateJWT("ns1", "subj", time.Minute)
|
||||||
if err != nil || exp <= 0 {
|
if err != nil || exp <= 0 {
|
||||||
t.Fatalf("gen err=%v exp=%d", err, exp)
|
t.Fatalf("gen err=%v exp=%d", err, exp)
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, err := gw.parseAndVerifyJWT(tok)
|
claims, err := svc.ParseAndVerifyJWT(tok)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("verify err: %v", err)
|
t.Fatalf("verify err: %v", err)
|
||||||
}
|
}
|
||||||
@ -28,17 +38,23 @@ func TestJWTGenerateAndParse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestJWTExpired(t *testing.T) {
|
func TestJWTExpired(t *testing.T) {
|
||||||
gw := &Gateway{}
|
|
||||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
gw.signingKey = key
|
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
gw.keyID = "kid"
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||||
|
})
|
||||||
|
|
||||||
|
svc, err := auth.NewService(nil, nil, string(keyPEM), "default")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Use sufficiently negative TTL to bypass allowed clock skew
|
// Use sufficiently negative TTL to bypass allowed clock skew
|
||||||
tok, _, err := gw.generateJWT("ns1", "subj", -2*time.Minute)
|
tok, _, err := svc.GenerateJWT("ns1", "subj", -2*time.Minute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("gen err=%v", err)
|
t.Fatalf("gen err=%v", err)
|
||||||
}
|
}
|
||||||
if _, err := gw.parseAndVerifyJWT(tok); err == nil {
|
if _, err := svc.ParseAndVerifyJWT(tok); err == nil {
|
||||||
t.Fatalf("expected expired error")
|
t.Fatalf("expected expired error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/client"
|
"github.com/DeBrosOfficial/network/pkg/client"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
@ -62,11 +63,8 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler {
|
|||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Allow public endpoints without auth
|
|
||||||
if isPublicPath(r.URL.Path) {
|
isPublic := isPublicPath(r.URL.Path)
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1) Try JWT Bearer first if Authorization looks like one
|
// 1) Try JWT Bearer first if Authorization looks like one
|
||||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||||
@ -74,7 +72,7 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler {
|
|||||||
if strings.HasPrefix(lower, "bearer ") {
|
if strings.HasPrefix(lower, "bearer ") {
|
||||||
tok := strings.TrimSpace(auth[len("Bearer "):])
|
tok := strings.TrimSpace(auth[len("Bearer "):])
|
||||||
if strings.Count(tok, ".") == 2 {
|
if strings.Count(tok, ".") == 2 {
|
||||||
if claims, err := g.parseAndVerifyJWT(tok); err == nil {
|
if claims, err := g.authService.ParseAndVerifyJWT(tok); err == nil {
|
||||||
// Attach JWT claims and namespace to context
|
// Attach JWT claims and namespace to context
|
||||||
ctx := context.WithValue(r.Context(), ctxKeyJWT, claims)
|
ctx := context.WithValue(r.Context(), ctxKeyJWT, claims)
|
||||||
if ns := strings.TrimSpace(claims.Namespace); ns != "" {
|
if ns := strings.TrimSpace(claims.Namespace); ns != "" {
|
||||||
@ -91,6 +89,10 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler {
|
|||||||
// 2) Fallback to API key (validate against DB)
|
// 2) Fallback to API key (validate against DB)
|
||||||
key := extractAPIKey(r)
|
key := extractAPIKey(r)
|
||||||
if key == "" {
|
if key == "" {
|
||||||
|
if isPublic {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
w.Header().Set("WWW-Authenticate", "Bearer realm=\"gateway\", charset=\"UTF-8\"")
|
w.Header().Set("WWW-Authenticate", "Bearer realm=\"gateway\", charset=\"UTF-8\"")
|
||||||
writeError(w, http.StatusUnauthorized, "missing API key")
|
writeError(w, http.StatusUnauthorized, "missing API key")
|
||||||
return
|
return
|
||||||
@ -104,6 +106,10 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler {
|
|||||||
q := "SELECT namespaces.name FROM api_keys JOIN namespaces ON api_keys.namespace_id = namespaces.id WHERE api_keys.key = ? LIMIT 1"
|
q := "SELECT namespaces.name FROM api_keys JOIN namespaces ON api_keys.namespace_id = namespaces.id WHERE api_keys.key = ? LIMIT 1"
|
||||||
res, err := db.Query(internalCtx, q, key)
|
res, err := db.Query(internalCtx, q, key)
|
||||||
if err != nil || res == nil || res.Count == 0 || len(res.Rows) == 0 || len(res.Rows[0]) == 0 {
|
if err != nil || res == nil || res.Count == 0 || len(res.Rows) == 0 || len(res.Rows[0]) == 0 {
|
||||||
|
if isPublic {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
w.Header().Set("WWW-Authenticate", "Bearer error=\"invalid_token\"")
|
w.Header().Set("WWW-Authenticate", "Bearer error=\"invalid_token\"")
|
||||||
writeError(w, http.StatusUnauthorized, "invalid API key")
|
writeError(w, http.StatusUnauthorized, "invalid API key")
|
||||||
return
|
return
|
||||||
@ -118,6 +124,10 @@ func (g *Gateway) authMiddleware(next http.Handler) http.Handler {
|
|||||||
ns = strings.TrimSpace(ns)
|
ns = strings.TrimSpace(ns)
|
||||||
}
|
}
|
||||||
if ns == "" {
|
if ns == "" {
|
||||||
|
if isPublic {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
w.Header().Set("WWW-Authenticate", "Bearer error=\"invalid_token\"")
|
w.Header().Set("WWW-Authenticate", "Bearer error=\"invalid_token\"")
|
||||||
writeError(w, http.StatusUnauthorized, "invalid API key")
|
writeError(w, http.StatusUnauthorized, "invalid API key")
|
||||||
return
|
return
|
||||||
@ -183,6 +193,11 @@ func isPublicPath(p string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Serverless invocation is public (authorization is handled within the invoker)
|
||||||
|
if strings.HasPrefix(p, "/v1/invoke/") || (strings.HasPrefix(p, "/v1/functions/") && strings.HasSuffix(p, "/invoke")) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
switch p {
|
switch p {
|
||||||
case "/health", "/v1/health", "/status", "/v1/status", "/v1/auth/jwks", "/.well-known/jwks.json", "/v1/version", "/v1/auth/login", "/v1/auth/challenge", "/v1/auth/verify", "/v1/auth/register", "/v1/auth/refresh", "/v1/auth/logout", "/v1/auth/api-key", "/v1/auth/simple-key", "/v1/network/status", "/v1/network/peers":
|
case "/health", "/v1/health", "/status", "/v1/status", "/v1/auth/jwks", "/.well-known/jwks.json", "/v1/version", "/v1/auth/login", "/v1/auth/challenge", "/v1/auth/verify", "/v1/auth/register", "/v1/auth/refresh", "/v1/auth/logout", "/v1/auth/api-key", "/v1/auth/simple-key", "/v1/network/status", "/v1/network/peers":
|
||||||
return true
|
return true
|
||||||
@ -235,7 +250,7 @@ func (g *Gateway) authorizationMiddleware(next http.Handler) http.Handler {
|
|||||||
apiKeyFallback := ""
|
apiKeyFallback := ""
|
||||||
|
|
||||||
if v := ctx.Value(ctxKeyJWT); v != nil {
|
if v := ctx.Value(ctxKeyJWT); v != nil {
|
||||||
if claims, ok := v.(*jwtClaims); ok && claims != nil && strings.TrimSpace(claims.Sub) != "" {
|
if claims, ok := v.(*auth.JWTClaims); ok && claims != nil && strings.TrimSpace(claims.Sub) != "" {
|
||||||
// Determine subject type.
|
// Determine subject type.
|
||||||
// If subject looks like an API key (e.g., ak_<random>:<namespace>),
|
// If subject looks like an API key (e.g., ak_<random>:<namespace>),
|
||||||
// treat it as an API key owner; otherwise assume a wallet subject.
|
// treat it as an API key owner; otherwise assume a wallet subject.
|
||||||
@ -324,6 +339,9 @@ func requiresNamespaceOwnership(p string) bool {
|
|||||||
if strings.HasPrefix(p, "/v1/proxy/") {
|
if strings.HasPrefix(p, "/v1/proxy/") {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
if strings.HasPrefix(p, "/v1/functions") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/client"
|
"github.com/DeBrosOfficial/network/pkg/client"
|
||||||
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||||
|
"github.com/google/uuid"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
@ -51,6 +52,22 @@ func (g *Gateway) pubsubWebsocketHandler(w http.ResponseWriter, r *http.Request)
|
|||||||
writeError(w, http.StatusBadRequest, "missing 'topic'")
|
writeError(w, http.StatusBadRequest, "missing 'topic'")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Presence handling
|
||||||
|
enablePresence := r.URL.Query().Get("presence") == "true"
|
||||||
|
memberID := r.URL.Query().Get("member_id")
|
||||||
|
memberMetaStr := r.URL.Query().Get("member_meta")
|
||||||
|
var memberMeta map[string]interface{}
|
||||||
|
if memberMetaStr != "" {
|
||||||
|
_ = json.Unmarshal([]byte(memberMetaStr), &memberMeta)
|
||||||
|
}
|
||||||
|
|
||||||
|
if enablePresence && memberID == "" {
|
||||||
|
g.logger.ComponentWarn("gateway", "pubsub ws: presence enabled but missing member_id")
|
||||||
|
writeError(w, http.StatusBadRequest, "missing 'member_id' for presence")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
conn, err := wsUpgrader.Upgrade(w, r, nil)
|
conn, err := wsUpgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.logger.ComponentWarn("gateway", "pubsub ws: upgrade failed")
|
g.logger.ComponentWarn("gateway", "pubsub ws: upgrade failed")
|
||||||
@ -73,6 +90,36 @@ func (g *Gateway) pubsubWebsocketHandler(w http.ResponseWriter, r *http.Request)
|
|||||||
subscriberCount := len(g.localSubscribers[topicKey])
|
subscriberCount := len(g.localSubscribers[topicKey])
|
||||||
g.mu.Unlock()
|
g.mu.Unlock()
|
||||||
|
|
||||||
|
connID := uuid.New().String()
|
||||||
|
if enablePresence {
|
||||||
|
member := PresenceMember{
|
||||||
|
MemberID: memberID,
|
||||||
|
JoinedAt: time.Now().Unix(),
|
||||||
|
Meta: memberMeta,
|
||||||
|
ConnID: connID,
|
||||||
|
}
|
||||||
|
|
||||||
|
g.presenceMu.Lock()
|
||||||
|
g.presenceMembers[topicKey] = append(g.presenceMembers[topicKey], member)
|
||||||
|
g.presenceMu.Unlock()
|
||||||
|
|
||||||
|
// Broadcast join event
|
||||||
|
joinEvent := map[string]interface{}{
|
||||||
|
"type": "presence.join",
|
||||||
|
"member_id": memberID,
|
||||||
|
"meta": memberMeta,
|
||||||
|
"timestamp": member.JoinedAt,
|
||||||
|
}
|
||||||
|
eventData, _ := json.Marshal(joinEvent)
|
||||||
|
// Use a background context for the broadcast to ensure it finishes even if the connection closes immediately
|
||||||
|
broadcastCtx := pubsub.WithNamespace(client.WithInternalAuth(context.Background()), ns)
|
||||||
|
_ = g.client.PubSub().Publish(broadcastCtx, topic, eventData)
|
||||||
|
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub ws: member joined presence",
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.String("member_id", memberID))
|
||||||
|
}
|
||||||
|
|
||||||
g.logger.ComponentInfo("gateway", "pubsub ws: registered local subscriber",
|
g.logger.ComponentInfo("gateway", "pubsub ws: registered local subscriber",
|
||||||
zap.String("topic", topic),
|
zap.String("topic", topic),
|
||||||
zap.String("namespace", ns),
|
zap.String("namespace", ns),
|
||||||
@ -93,6 +140,36 @@ func (g *Gateway) pubsubWebsocketHandler(w http.ResponseWriter, r *http.Request)
|
|||||||
delete(g.localSubscribers, topicKey)
|
delete(g.localSubscribers, topicKey)
|
||||||
}
|
}
|
||||||
g.mu.Unlock()
|
g.mu.Unlock()
|
||||||
|
|
||||||
|
if enablePresence {
|
||||||
|
g.presenceMu.Lock()
|
||||||
|
members := g.presenceMembers[topicKey]
|
||||||
|
for i, m := range members {
|
||||||
|
if m.ConnID == connID {
|
||||||
|
g.presenceMembers[topicKey] = append(members[:i], members[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(g.presenceMembers[topicKey]) == 0 {
|
||||||
|
delete(g.presenceMembers, topicKey)
|
||||||
|
}
|
||||||
|
g.presenceMu.Unlock()
|
||||||
|
|
||||||
|
// Broadcast leave event
|
||||||
|
leaveEvent := map[string]interface{}{
|
||||||
|
"type": "presence.leave",
|
||||||
|
"member_id": memberID,
|
||||||
|
"timestamp": time.Now().Unix(),
|
||||||
|
}
|
||||||
|
eventData, _ := json.Marshal(leaveEvent)
|
||||||
|
broadcastCtx := pubsub.WithNamespace(client.WithInternalAuth(context.Background()), ns)
|
||||||
|
_ = g.client.PubSub().Publish(broadcastCtx, topic, eventData)
|
||||||
|
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub ws: member left presence",
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.String("member_id", memberID))
|
||||||
|
}
|
||||||
|
|
||||||
g.logger.ComponentInfo("gateway", "pubsub ws: unregistered local subscriber",
|
g.logger.ComponentInfo("gateway", "pubsub ws: unregistered local subscriber",
|
||||||
zap.String("topic", topic),
|
zap.String("topic", topic),
|
||||||
zap.Int("remaining_subscribers", remainingCount))
|
zap.Int("remaining_subscribers", remainingCount))
|
||||||
@ -349,3 +426,44 @@ func namespacePrefix(ns string) string {
|
|||||||
func namespacedTopic(ns, topic string) string {
|
func namespacedTopic(ns, topic string) string {
|
||||||
return namespacePrefix(ns) + topic
|
return namespacePrefix(ns) + topic
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pubsubPresenceHandler handles GET /v1/pubsub/presence?topic=mytopic
|
||||||
|
func (g *Gateway) pubsubPresenceHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ns := resolveNamespaceFromRequest(r)
|
||||||
|
if ns == "" {
|
||||||
|
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topic := r.URL.Query().Get("topic")
|
||||||
|
if topic == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "missing 'topic'")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topicKey := fmt.Sprintf("%s.%s", ns, topic)
|
||||||
|
|
||||||
|
g.presenceMu.RLock()
|
||||||
|
members, ok := g.presenceMembers[topicKey]
|
||||||
|
g.presenceMu.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
|
"topic": topic,
|
||||||
|
"members": []PresenceMember{},
|
||||||
|
"count": 0,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
|
"topic": topic,
|
||||||
|
"members": members,
|
||||||
|
"count": len(members),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@ -14,8 +14,8 @@ func (g *Gateway) Routes() http.Handler {
|
|||||||
mux.HandleFunc("/v1/status", g.statusHandler)
|
mux.HandleFunc("/v1/status", g.statusHandler)
|
||||||
|
|
||||||
// auth endpoints
|
// auth endpoints
|
||||||
mux.HandleFunc("/v1/auth/jwks", g.jwksHandler)
|
mux.HandleFunc("/v1/auth/jwks", g.authService.JWKSHandler)
|
||||||
mux.HandleFunc("/.well-known/jwks.json", g.jwksHandler)
|
mux.HandleFunc("/.well-known/jwks.json", g.authService.JWKSHandler)
|
||||||
mux.HandleFunc("/v1/auth/login", g.loginPageHandler)
|
mux.HandleFunc("/v1/auth/login", g.loginPageHandler)
|
||||||
mux.HandleFunc("/v1/auth/challenge", g.challengeHandler)
|
mux.HandleFunc("/v1/auth/challenge", g.challengeHandler)
|
||||||
mux.HandleFunc("/v1/auth/verify", g.verifyHandler)
|
mux.HandleFunc("/v1/auth/verify", g.verifyHandler)
|
||||||
@ -44,6 +44,7 @@ func (g *Gateway) Routes() http.Handler {
|
|||||||
mux.HandleFunc("/v1/pubsub/ws", g.pubsubWebsocketHandler)
|
mux.HandleFunc("/v1/pubsub/ws", g.pubsubWebsocketHandler)
|
||||||
mux.HandleFunc("/v1/pubsub/publish", g.pubsubPublishHandler)
|
mux.HandleFunc("/v1/pubsub/publish", g.pubsubPublishHandler)
|
||||||
mux.HandleFunc("/v1/pubsub/topics", g.pubsubTopicsHandler)
|
mux.HandleFunc("/v1/pubsub/topics", g.pubsubTopicsHandler)
|
||||||
|
mux.HandleFunc("/v1/pubsub/presence", g.pubsubPresenceHandler)
|
||||||
|
|
||||||
// anon proxy (authenticated users only)
|
// anon proxy (authenticated users only)
|
||||||
mux.HandleFunc("/v1/proxy/anon", g.anonProxyHandler)
|
mux.HandleFunc("/v1/proxy/anon", g.anonProxyHandler)
|
||||||
@ -63,5 +64,10 @@ func (g *Gateway) Routes() http.Handler {
|
|||||||
mux.HandleFunc("/v1/storage/get/", g.storageGetHandler)
|
mux.HandleFunc("/v1/storage/get/", g.storageGetHandler)
|
||||||
mux.HandleFunc("/v1/storage/unpin/", g.storageUnpinHandler)
|
mux.HandleFunc("/v1/storage/unpin/", g.storageUnpinHandler)
|
||||||
|
|
||||||
|
// serverless functions (if enabled)
|
||||||
|
if g.serverlessHandlers != nil {
|
||||||
|
g.serverlessHandlers.RegisterRoutes(mux)
|
||||||
|
}
|
||||||
|
|
||||||
return g.withMiddleware(mux)
|
return g.withMiddleware(mux)
|
||||||
}
|
}
|
||||||
|
|||||||
694
pkg/gateway/serverless_handlers.go
Normal file
694
pkg/gateway/serverless_handlers.go
Normal file
@ -0,0 +1,694 @@
|
|||||||
|
package gateway
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServerlessHandlers contains handlers for serverless function endpoints.
|
||||||
|
// It's a separate struct to keep the Gateway struct clean.
|
||||||
|
type ServerlessHandlers struct {
|
||||||
|
invoker *serverless.Invoker
|
||||||
|
registry serverless.FunctionRegistry
|
||||||
|
wsManager *serverless.WSManager
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServerlessHandlers creates a new ServerlessHandlers instance.
|
||||||
|
func NewServerlessHandlers(
|
||||||
|
invoker *serverless.Invoker,
|
||||||
|
registry serverless.FunctionRegistry,
|
||||||
|
wsManager *serverless.WSManager,
|
||||||
|
logger *zap.Logger,
|
||||||
|
) *ServerlessHandlers {
|
||||||
|
return &ServerlessHandlers{
|
||||||
|
invoker: invoker,
|
||||||
|
registry: registry,
|
||||||
|
wsManager: wsManager,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes registers all serverless routes on the given mux.
|
||||||
|
func (h *ServerlessHandlers) RegisterRoutes(mux *http.ServeMux) {
|
||||||
|
// Function management
|
||||||
|
mux.HandleFunc("/v1/functions", h.handleFunctions)
|
||||||
|
mux.HandleFunc("/v1/functions/", h.handleFunctionByName)
|
||||||
|
|
||||||
|
// Direct invoke endpoint
|
||||||
|
mux.HandleFunc("/v1/invoke/", h.handleInvoke)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFunctions handles GET /v1/functions (list) and POST /v1/functions (deploy)
|
||||||
|
func (h *ServerlessHandlers) handleFunctions(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
h.listFunctions(w, r)
|
||||||
|
case http.MethodPost:
|
||||||
|
h.deployFunction(w, r)
|
||||||
|
default:
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFunctionByName handles operations on a specific function
|
||||||
|
// Routes:
|
||||||
|
// - GET /v1/functions/{name} - Get function info
|
||||||
|
// - DELETE /v1/functions/{name} - Delete function
|
||||||
|
// - POST /v1/functions/{name}/invoke - Invoke function
|
||||||
|
// - GET /v1/functions/{name}/versions - List versions
|
||||||
|
// - GET /v1/functions/{name}/logs - Get logs
|
||||||
|
// - WS /v1/functions/{name}/ws - WebSocket invoke
|
||||||
|
func (h *ServerlessHandlers) handleFunctionByName(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Parse path: /v1/functions/{name}[/{action}]
|
||||||
|
path := strings.TrimPrefix(r.URL.Path, "/v1/functions/")
|
||||||
|
parts := strings.SplitN(path, "/", 2)
|
||||||
|
|
||||||
|
if len(parts) == 0 || parts[0] == "" {
|
||||||
|
http.Error(w, "Function name required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
name := parts[0]
|
||||||
|
action := ""
|
||||||
|
if len(parts) > 1 {
|
||||||
|
action = parts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse version from name if present (e.g., "myfunction@2")
|
||||||
|
version := 0
|
||||||
|
if idx := strings.Index(name, "@"); idx > 0 {
|
||||||
|
vStr := name[idx+1:]
|
||||||
|
name = name[:idx]
|
||||||
|
if v, err := strconv.Atoi(vStr); err == nil {
|
||||||
|
version = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch action {
|
||||||
|
case "invoke":
|
||||||
|
h.invokeFunction(w, r, name, version)
|
||||||
|
case "ws":
|
||||||
|
h.handleWebSocket(w, r, name, version)
|
||||||
|
case "versions":
|
||||||
|
h.listVersions(w, r, name)
|
||||||
|
case "logs":
|
||||||
|
h.getFunctionLogs(w, r, name)
|
||||||
|
case "":
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
h.getFunctionInfo(w, r, name, version)
|
||||||
|
case http.MethodDelete:
|
||||||
|
h.deleteFunction(w, r, name, version)
|
||||||
|
default:
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
http.Error(w, "Unknown action", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleInvoke handles POST /v1/invoke/{namespace}/{name}[@version]
|
||||||
|
func (h *ServerlessHandlers) handleInvoke(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse path: /v1/invoke/{namespace}/{name}[@version]
|
||||||
|
path := strings.TrimPrefix(r.URL.Path, "/v1/invoke/")
|
||||||
|
parts := strings.SplitN(path, "/", 2)
|
||||||
|
|
||||||
|
if len(parts) < 2 {
|
||||||
|
http.Error(w, "Path must be /v1/invoke/{namespace}/{name}", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace := parts[0]
|
||||||
|
name := parts[1]
|
||||||
|
|
||||||
|
// Parse version if present
|
||||||
|
version := 0
|
||||||
|
if idx := strings.Index(name, "@"); idx > 0 {
|
||||||
|
vStr := name[idx+1:]
|
||||||
|
name = name[:idx]
|
||||||
|
if v, err := strconv.Atoi(vStr); err == nil {
|
||||||
|
version = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.invokeFunction(w, r, namespace+"/"+name, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// listFunctions handles GET /v1/functions
|
||||||
|
func (h *ServerlessHandlers) listFunctions(w http.ResponseWriter, r *http.Request) {
|
||||||
|
namespace := r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
// Get namespace from JWT if available
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
functions, err := h.registry.List(ctx, namespace)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("Failed to list functions", zap.Error(err))
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to list functions")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"functions": functions,
|
||||||
|
"count": len(functions),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// deployFunction handles POST /v1/functions
|
||||||
|
func (h *ServerlessHandlers) deployFunction(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Parse multipart form (for WASM upload) or JSON
|
||||||
|
contentType := r.Header.Get("Content-Type")
|
||||||
|
|
||||||
|
var def serverless.FunctionDefinition
|
||||||
|
var wasmBytes []byte
|
||||||
|
|
||||||
|
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||||
|
// Parse multipart form
|
||||||
|
if err := r.ParseMultipartForm(32 << 20); err != nil { // 32MB max
|
||||||
|
writeError(w, http.StatusBadRequest, "Failed to parse form: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get metadata from form field
|
||||||
|
metadataStr := r.FormValue("metadata")
|
||||||
|
if metadataStr != "" {
|
||||||
|
if err := json.Unmarshal([]byte(metadataStr), &def); err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "Invalid metadata JSON: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get name from form if not in metadata
|
||||||
|
if def.Name == "" {
|
||||||
|
def.Name = r.FormValue("name")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get namespace from form if not in metadata
|
||||||
|
if def.Namespace == "" {
|
||||||
|
def.Namespace = r.FormValue("namespace")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get other configuration fields from form
|
||||||
|
if v := r.FormValue("is_public"); v != "" {
|
||||||
|
def.IsPublic, _ = strconv.ParseBool(v)
|
||||||
|
}
|
||||||
|
if v := r.FormValue("memory_limit_mb"); v != "" {
|
||||||
|
def.MemoryLimitMB, _ = strconv.Atoi(v)
|
||||||
|
}
|
||||||
|
if v := r.FormValue("timeout_seconds"); v != "" {
|
||||||
|
def.TimeoutSeconds, _ = strconv.Atoi(v)
|
||||||
|
}
|
||||||
|
if v := r.FormValue("retry_count"); v != "" {
|
||||||
|
def.RetryCount, _ = strconv.Atoi(v)
|
||||||
|
}
|
||||||
|
if v := r.FormValue("retry_delay_seconds"); v != "" {
|
||||||
|
def.RetryDelaySeconds, _ = strconv.Atoi(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get WASM file
|
||||||
|
file, _, err := r.FormFile("wasm")
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "WASM file required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
wasmBytes, err = io.ReadAll(file)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "Failed to read WASM file: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// JSON body with base64-encoded WASM
|
||||||
|
var req struct {
|
||||||
|
serverless.FunctionDefinition
|
||||||
|
WASMBase64 string `json:"wasm_base64"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "Invalid JSON: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
def = req.FunctionDefinition
|
||||||
|
|
||||||
|
if req.WASMBase64 != "" {
|
||||||
|
// Decode base64 WASM - for now, just reject this method
|
||||||
|
writeError(w, http.StatusBadRequest, "Base64 WASM upload not supported, use multipart/form-data")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get namespace from JWT if not provided
|
||||||
|
if def.Namespace == "" {
|
||||||
|
def.Namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if def.Name == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "Function name required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if def.Namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "Namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(wasmBytes) == 0 {
|
||||||
|
writeError(w, http.StatusBadRequest, "WASM bytecode required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
oldFn, err := h.registry.Register(ctx, &def, wasmBytes)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("Failed to deploy function",
|
||||||
|
zap.String("name", def.Name),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to deploy: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidate cache for the old version to ensure the new one is loaded
|
||||||
|
if oldFn != nil {
|
||||||
|
h.invoker.InvalidateCache(oldFn.WASMCID)
|
||||||
|
h.logger.Debug("Invalidated function cache",
|
||||||
|
zap.String("name", def.Name),
|
||||||
|
zap.String("old_wasm_cid", oldFn.WASMCID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Info("Function deployed",
|
||||||
|
zap.String("name", def.Name),
|
||||||
|
zap.String("namespace", def.Namespace),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Fetch the deployed function to return
|
||||||
|
fn, err := h.registry.Get(ctx, def.Namespace, def.Name, def.Version)
|
||||||
|
if err != nil {
|
||||||
|
writeJSON(w, http.StatusCreated, map[string]interface{}{
|
||||||
|
"message": "Function deployed successfully",
|
||||||
|
"name": def.Name,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusCreated, map[string]interface{}{
|
||||||
|
"message": "Function deployed successfully",
|
||||||
|
"function": fn,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFunctionInfo handles GET /v1/functions/{name}
|
||||||
|
func (h *ServerlessHandlers) getFunctionInfo(w http.ResponseWriter, r *http.Request, name string, version int) {
|
||||||
|
namespace := r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
fn, err := h.registry.Get(ctx, namespace, name, version)
|
||||||
|
if err != nil {
|
||||||
|
if serverless.IsNotFound(err) {
|
||||||
|
writeError(w, http.StatusNotFound, "Function not found")
|
||||||
|
} else {
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to get function")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteFunction handles DELETE /v1/functions/{name}
|
||||||
|
func (h *ServerlessHandlers) deleteFunction(w http.ResponseWriter, r *http.Request, name string, version int) {
|
||||||
|
namespace := r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := h.registry.Delete(ctx, namespace, name, version); err != nil {
|
||||||
|
if serverless.IsNotFound(err) {
|
||||||
|
writeError(w, http.StatusNotFound, "Function not found")
|
||||||
|
} else {
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to delete function")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]string{
|
||||||
|
"message": "Function deleted successfully",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// invokeFunction handles POST /v1/functions/{name}/invoke
|
||||||
|
func (h *ServerlessHandlers) invokeFunction(w http.ResponseWriter, r *http.Request, nameWithNS string, version int) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse namespace and name
|
||||||
|
var namespace, name string
|
||||||
|
if idx := strings.Index(nameWithNS, "/"); idx > 0 {
|
||||||
|
namespace = nameWithNS[:idx]
|
||||||
|
name = nameWithNS[idx+1:]
|
||||||
|
} else {
|
||||||
|
name = nameWithNS
|
||||||
|
namespace = r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read input body
|
||||||
|
input, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) // 1MB max
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "Failed to read request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get caller wallet from JWT
|
||||||
|
callerWallet := h.getWalletFromRequest(r)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := &serverless.InvokeRequest{
|
||||||
|
Namespace: namespace,
|
||||||
|
FunctionName: name,
|
||||||
|
Version: version,
|
||||||
|
Input: input,
|
||||||
|
TriggerType: serverless.TriggerTypeHTTP,
|
||||||
|
CallerWallet: callerWallet,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.invoker.Invoke(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
statusCode := http.StatusInternalServerError
|
||||||
|
if serverless.IsNotFound(err) {
|
||||||
|
statusCode = http.StatusNotFound
|
||||||
|
} else if serverless.IsResourceExhausted(err) {
|
||||||
|
statusCode = http.StatusTooManyRequests
|
||||||
|
} else if serverless.IsUnauthorized(err) {
|
||||||
|
statusCode = http.StatusUnauthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, statusCode, map[string]interface{}{
|
||||||
|
"request_id": resp.RequestID,
|
||||||
|
"status": resp.Status,
|
||||||
|
"error": resp.Error,
|
||||||
|
"duration_ms": resp.DurationMS,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the function's output directly if it's JSON
|
||||||
|
w.Header().Set("X-Request-ID", resp.RequestID)
|
||||||
|
w.Header().Set("X-Duration-Ms", strconv.FormatInt(resp.DurationMS, 10))
|
||||||
|
|
||||||
|
// Try to detect if output is JSON
|
||||||
|
if len(resp.Output) > 0 && (resp.Output[0] == '{' || resp.Output[0] == '[') {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write(resp.Output)
|
||||||
|
} else {
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"request_id": resp.RequestID,
|
||||||
|
"output": string(resp.Output),
|
||||||
|
"status": resp.Status,
|
||||||
|
"duration_ms": resp.DurationMS,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleWebSocket handles WebSocket connections for function streaming
|
||||||
|
func (h *ServerlessHandlers) handleWebSocket(w http.ResponseWriter, r *http.Request, name string, version int) {
|
||||||
|
namespace := r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
http.Error(w, "namespace required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upgrade to WebSocket
|
||||||
|
upgrader := websocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("WebSocket upgrade failed", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientID := uuid.New().String()
|
||||||
|
wsConn := &serverless.GorillaWSConn{Conn: conn}
|
||||||
|
|
||||||
|
// Register connection
|
||||||
|
h.wsManager.Register(clientID, wsConn)
|
||||||
|
defer h.wsManager.Unregister(clientID)
|
||||||
|
|
||||||
|
h.logger.Info("WebSocket connected",
|
||||||
|
zap.String("client_id", clientID),
|
||||||
|
zap.String("function", name),
|
||||||
|
)
|
||||||
|
|
||||||
|
callerWallet := h.getWalletFromRequest(r)
|
||||||
|
|
||||||
|
// Message loop
|
||||||
|
for {
|
||||||
|
_, message, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||||
|
h.logger.Warn("WebSocket error", zap.Error(err))
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invoke function with WebSocket context
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
|
||||||
|
req := &serverless.InvokeRequest{
|
||||||
|
Namespace: namespace,
|
||||||
|
FunctionName: name,
|
||||||
|
Version: version,
|
||||||
|
Input: message,
|
||||||
|
TriggerType: serverless.TriggerTypeWebSocket,
|
||||||
|
CallerWallet: callerWallet,
|
||||||
|
WSClientID: clientID,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.invoker.Invoke(ctx, req)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
// Send response back
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"request_id": resp.RequestID,
|
||||||
|
"status": resp.Status,
|
||||||
|
"duration_ms": resp.DurationMS,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
response["error"] = resp.Error
|
||||||
|
} else if len(resp.Output) > 0 {
|
||||||
|
// Try to parse output as JSON
|
||||||
|
var output interface{}
|
||||||
|
if json.Unmarshal(resp.Output, &output) == nil {
|
||||||
|
response["output"] = output
|
||||||
|
} else {
|
||||||
|
response["output"] = string(resp.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
respBytes, _ := json.Marshal(response)
|
||||||
|
if err := conn.WriteMessage(websocket.TextMessage, respBytes); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// listVersions handles GET /v1/functions/{name}/versions
|
||||||
|
func (h *ServerlessHandlers) listVersions(w http.ResponseWriter, r *http.Request, name string) {
|
||||||
|
namespace := r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Get registry with extended methods
|
||||||
|
reg, ok := h.registry.(*serverless.Registry)
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusNotImplemented, "Version listing not supported")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
versions, err := reg.ListVersions(ctx, namespace, name)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to list versions")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"versions": versions,
|
||||||
|
"count": len(versions),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFunctionLogs handles GET /v1/functions/{name}/logs
|
||||||
|
func (h *ServerlessHandlers) getFunctionLogs(w http.ResponseWriter, r *http.Request, name string) {
|
||||||
|
namespace := r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := 100
|
||||||
|
if lStr := r.URL.Query().Get("limit"); lStr != "" {
|
||||||
|
if l, err := strconv.Atoi(lStr); err == nil {
|
||||||
|
limit = l
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
logs, err := h.registry.GetLogs(ctx, namespace, name, limit)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("Failed to get function logs",
|
||||||
|
zap.String("name", name),
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to get logs")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"name": name,
|
||||||
|
"namespace": namespace,
|
||||||
|
"logs": logs,
|
||||||
|
"count": len(logs),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNamespaceFromRequest extracts namespace from JWT or query param
|
||||||
|
func (h *ServerlessHandlers) getNamespaceFromRequest(r *http.Request) string {
|
||||||
|
// Try context first (set by auth middleware) - most secure
|
||||||
|
if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil {
|
||||||
|
if ns, ok := v.(string); ok && ns != "" {
|
||||||
|
return ns
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try query param as fallback (e.g. for public access or admin)
|
||||||
|
if ns := r.URL.Query().Get("namespace"); ns != "" {
|
||||||
|
return ns
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try header as fallback
|
||||||
|
if ns := r.Header.Get("X-Namespace"); ns != "" {
|
||||||
|
return ns
|
||||||
|
}
|
||||||
|
|
||||||
|
return "default"
|
||||||
|
}
|
||||||
|
|
||||||
|
// getWalletFromRequest extracts wallet address from JWT
|
||||||
|
func (h *ServerlessHandlers) getWalletFromRequest(r *http.Request) string {
|
||||||
|
// 1. Try X-Wallet header (legacy/direct bypass)
|
||||||
|
if wallet := r.Header.Get("X-Wallet"); wallet != "" {
|
||||||
|
return wallet
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Try JWT claims from context
|
||||||
|
if v := r.Context().Value(ctxKeyJWT); v != nil {
|
||||||
|
if claims, ok := v.(*auth.JWTClaims); ok && claims != nil {
|
||||||
|
subj := strings.TrimSpace(claims.Sub)
|
||||||
|
// Ensure it's not an API key (standard Orama logic)
|
||||||
|
if !strings.HasPrefix(strings.ToLower(subj), "ak_") && !strings.Contains(subj, ":") {
|
||||||
|
return subj
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Fallback to API key identity (namespace)
|
||||||
|
if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil {
|
||||||
|
if ns, ok := v.(string); ok && ns != "" {
|
||||||
|
return ns
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthStatus returns the health status of the serverless engine
|
||||||
|
func (h *ServerlessHandlers) HealthStatus() map[string]interface{} {
|
||||||
|
stats := h.wsManager.GetStats()
|
||||||
|
return map[string]interface{}{
|
||||||
|
"status": "ok",
|
||||||
|
"connections": stats.ConnectionCount,
|
||||||
|
"topics": stats.TopicCount,
|
||||||
|
}
|
||||||
|
}
|
||||||
88
pkg/gateway/serverless_handlers_test.go
Normal file
88
pkg/gateway/serverless_handlers_test.go
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
package gateway
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockFunctionRegistry struct {
|
||||||
|
functions []*serverless.Function
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFunctionRegistry) Register(ctx context.Context, fn *serverless.FunctionDefinition, wasmBytes []byte) (*serverless.Function, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFunctionRegistry) Get(ctx context.Context, namespace, name string, version int) (*serverless.Function, error) {
|
||||||
|
return &serverless.Function{ID: "1", Name: name, Namespace: namespace}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFunctionRegistry) List(ctx context.Context, namespace string) ([]*serverless.Function, error) {
|
||||||
|
return m.functions, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFunctionRegistry) Delete(ctx context.Context, namespace, name string, version int) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFunctionRegistry) GetWASMBytes(ctx context.Context, wasmCID string) ([]byte, error) {
|
||||||
|
return []byte("wasm"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFunctionRegistry) GetLogs(ctx context.Context, namespace, name string, limit int) ([]serverless.LogEntry, error) {
|
||||||
|
return []serverless.LogEntry{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerlessHandlers_ListFunctions(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
registry := &mockFunctionRegistry{
|
||||||
|
functions: []*serverless.Function{
|
||||||
|
{ID: "1", Name: "func1", Namespace: "ns1"},
|
||||||
|
{ID: "2", Name: "func2", Namespace: "ns1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewServerlessHandlers(nil, registry, nil, logger)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("GET", "/v1/functions?namespace=ns1", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.handleFunctions(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(rr.Body.Bytes(), &resp)
|
||||||
|
|
||||||
|
if resp["count"].(float64) != 2 {
|
||||||
|
t.Errorf("expected 2 functions, got %v", resp["count"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerlessHandlers_DeployFunction(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
registry := &mockFunctionRegistry{}
|
||||||
|
|
||||||
|
h := NewServerlessHandlers(nil, registry, nil, logger)
|
||||||
|
|
||||||
|
// Test JSON deploy (which is partially supported according to code)
|
||||||
|
// Should be 400 because WASM is missing or base64 not supported
|
||||||
|
writer := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/v1/functions", bytes.NewBufferString(`{"name": "test"}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.handleFunctions(writer, req)
|
||||||
|
|
||||||
|
if writer.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", writer.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -228,7 +228,12 @@ func (g *Gateway) storageStatusHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
status, err := g.ipfsClient.PinStatus(ctx, path)
|
status, err := g.ipfsClient.PinStatus(ctx, path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.logger.ComponentError(logging.ComponentGeneral, "failed to get pin status", zap.Error(err), zap.String("cid", path))
|
g.logger.ComponentError(logging.ComponentGeneral, "failed to get pin status", zap.Error(err), zap.String("cid", path))
|
||||||
writeError(w, http.StatusInternalServerError, fmt.Sprintf("failed to get status: %v", err))
|
errStr := strings.ToLower(err.Error())
|
||||||
|
if strings.Contains(errStr, "not found") || strings.Contains(errStr, "404") || strings.Contains(errStr, "invalid") {
|
||||||
|
writeError(w, http.StatusNotFound, fmt.Sprintf("pin not found: %s", path))
|
||||||
|
} else {
|
||||||
|
writeError(w, http.StatusInternalServerError, fmt.Sprintf("failed to get status: %v", err))
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -283,7 +288,8 @@ func (g *Gateway) storageGetHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
g.logger.ComponentError(logging.ComponentGeneral, "failed to get content from IPFS", zap.Error(err), zap.String("cid", path))
|
g.logger.ComponentError(logging.ComponentGeneral, "failed to get content from IPFS", zap.Error(err), zap.String("cid", path))
|
||||||
// Check if error indicates content not found (404)
|
// Check if error indicates content not found (404)
|
||||||
if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "status 404") {
|
errStr := strings.ToLower(err.Error())
|
||||||
|
if strings.Contains(errStr, "not found") || strings.Contains(errStr, "404") || strings.Contains(errStr, "invalid") {
|
||||||
writeError(w, http.StatusNotFound, fmt.Sprintf("content not found: %s", path))
|
writeError(w, http.StatusNotFound, fmt.Sprintf("content not found: %s", path))
|
||||||
} else {
|
} else {
|
||||||
writeError(w, http.StatusInternalServerError, fmt.Sprintf("failed to get content: %v", err))
|
writeError(w, http.StatusInternalServerError, fmt.Sprintf("failed to get content: %v", err))
|
||||||
|
|||||||
@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/certutil"
|
"github.com/DeBrosOfficial/network/pkg/certutil"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/config"
|
||||||
"github.com/DeBrosOfficial/network/pkg/tlsutil"
|
"github.com/DeBrosOfficial/network/pkg/tlsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -338,7 +339,7 @@ func (m *Model) handleEnter() (tea.Model, tea.Cmd) {
|
|||||||
|
|
||||||
case StepSwarmKey:
|
case StepSwarmKey:
|
||||||
swarmKey := strings.TrimSpace(m.textInput.Value())
|
swarmKey := strings.TrimSpace(m.textInput.Value())
|
||||||
if err := validateSwarmKey(swarmKey); err != nil {
|
if err := config.ValidateSwarmKey(swarmKey); err != nil {
|
||||||
m.err = err
|
m.err = err
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
@ -816,17 +817,6 @@ func validateClusterSecret(secret string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateSwarmKey(key string) error {
|
|
||||||
if len(key) != 64 {
|
|
||||||
return fmt.Errorf("swarm key must be 64 hex characters")
|
|
||||||
}
|
|
||||||
keyRegex := regexp.MustCompile(`^[a-fA-F0-9]{64}$`)
|
|
||||||
if !keyRegex.MatchString(key) {
|
|
||||||
return fmt.Errorf("swarm key must be valid hexadecimal")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ensureCertificatesForDomain generates self-signed certificates for the domain
|
// ensureCertificatesForDomain generates self-signed certificates for the domain
|
||||||
func ensureCertificatesForDomain(domain string) error {
|
func ensureCertificatesForDomain(domain string) error {
|
||||||
// Get home directory
|
// Get home directory
|
||||||
|
|||||||
1171
pkg/ipfs/cluster.go
1171
pkg/ipfs/cluster.go
File diff suppressed because it is too large
Load Diff
136
pkg/ipfs/cluster_config.go
Normal file
136
pkg/ipfs/cluster_config.go
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
package ipfs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClusterServiceConfig represents the service.json configuration
|
||||||
|
type ClusterServiceConfig struct {
|
||||||
|
Cluster struct {
|
||||||
|
Peername string `json:"peername"`
|
||||||
|
Secret string `json:"secret"`
|
||||||
|
ListenMultiaddress []string `json:"listen_multiaddress"`
|
||||||
|
PeerAddresses []string `json:"peer_addresses"`
|
||||||
|
LeaveOnShutdown bool `json:"leave_on_shutdown"`
|
||||||
|
} `json:"cluster"`
|
||||||
|
|
||||||
|
Consensus struct {
|
||||||
|
CRDT struct {
|
||||||
|
ClusterName string `json:"cluster_name"`
|
||||||
|
TrustedPeers []string `json:"trusted_peers"`
|
||||||
|
Batching struct {
|
||||||
|
MaxBatchSize int `json:"max_batch_size"`
|
||||||
|
MaxBatchAge string `json:"max_batch_age"`
|
||||||
|
} `json:"batching"`
|
||||||
|
RepairInterval string `json:"repair_interval"`
|
||||||
|
} `json:"crdt"`
|
||||||
|
} `json:"consensus"`
|
||||||
|
|
||||||
|
API struct {
|
||||||
|
RestAPI struct {
|
||||||
|
HTTPListenMultiaddress string `json:"http_listen_multiaddress"`
|
||||||
|
} `json:"restapi"`
|
||||||
|
IPFSProxy struct {
|
||||||
|
ListenMultiaddress string `json:"listen_multiaddress"`
|
||||||
|
NodeMultiaddress string `json:"node_multiaddress"`
|
||||||
|
} `json:"ipfsproxy"`
|
||||||
|
PinSvcAPI struct {
|
||||||
|
HTTPListenMultiaddress string `json:"http_listen_multiaddress"`
|
||||||
|
} `json:"pinsvcapi"`
|
||||||
|
} `json:"api"`
|
||||||
|
|
||||||
|
IPFSConnector struct {
|
||||||
|
IPFSHTTP struct {
|
||||||
|
NodeMultiaddress string `json:"node_multiaddress"`
|
||||||
|
} `json:"ipfshttp"`
|
||||||
|
} `json:"ipfs_connector"`
|
||||||
|
|
||||||
|
Raw map[string]interface{} `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *ClusterConfigManager) loadOrCreateConfig(path string) (*ClusterServiceConfig, error) {
|
||||||
|
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||||
|
return cm.createTemplateConfig(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read service.json: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg ClusterServiceConfig
|
||||||
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse service.json: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var raw map[string]interface{}
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse raw service.json: %w", err)
|
||||||
|
}
|
||||||
|
cfg.Raw = raw
|
||||||
|
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *ClusterConfigManager) saveConfig(path string, cfg *ClusterServiceConfig) error {
|
||||||
|
cm.updateNestedMap(cfg.Raw, "cluster", "peername", cfg.Cluster.Peername)
|
||||||
|
cm.updateNestedMap(cfg.Raw, "cluster", "secret", cfg.Cluster.Secret)
|
||||||
|
cm.updateNestedMap(cfg.Raw, "cluster", "listen_multiaddress", cfg.Cluster.ListenMultiaddress)
|
||||||
|
cm.updateNestedMap(cfg.Raw, "cluster", "peer_addresses", cfg.Cluster.PeerAddresses)
|
||||||
|
cm.updateNestedMap(cfg.Raw, "cluster", "leave_on_shutdown", cfg.Cluster.LeaveOnShutdown)
|
||||||
|
|
||||||
|
consensus := cm.ensureRequiredSection(cfg.Raw, "consensus")
|
||||||
|
crdt := cm.ensureRequiredSection(consensus, "crdt")
|
||||||
|
crdt["cluster_name"] = cfg.Consensus.CRDT.ClusterName
|
||||||
|
crdt["trusted_peers"] = cfg.Consensus.CRDT.TrustedPeers
|
||||||
|
crdt["repair_interval"] = cfg.Consensus.CRDT.RepairInterval
|
||||||
|
|
||||||
|
batching := cm.ensureRequiredSection(crdt, "batching")
|
||||||
|
batching["max_batch_size"] = cfg.Consensus.CRDT.Batching.MaxBatchSize
|
||||||
|
batching["max_batch_age"] = cfg.Consensus.CRDT.Batching.MaxBatchAge
|
||||||
|
|
||||||
|
api := cm.ensureRequiredSection(cfg.Raw, "api")
|
||||||
|
restapi := cm.ensureRequiredSection(api, "restapi")
|
||||||
|
restapi["http_listen_multiaddress"] = cfg.API.RestAPI.HTTPListenMultiaddress
|
||||||
|
|
||||||
|
ipfsproxy := cm.ensureRequiredSection(api, "ipfsproxy")
|
||||||
|
ipfsproxy["listen_multiaddress"] = cfg.API.IPFSProxy.ListenMultiaddress
|
||||||
|
ipfsproxy["node_multiaddress"] = cfg.API.IPFSProxy.NodeMultiaddress
|
||||||
|
|
||||||
|
pinsvcapi := cm.ensureRequiredSection(api, "pinsvcapi")
|
||||||
|
pinsvcapi["http_listen_multiaddress"] = cfg.API.PinSvcAPI.HTTPListenMultiaddress
|
||||||
|
|
||||||
|
ipfsConn := cm.ensureRequiredSection(cfg.Raw, "ipfs_connector")
|
||||||
|
ipfsHttp := cm.ensureRequiredSection(ipfsConn, "ipfshttp")
|
||||||
|
ipfsHttp["node_multiaddress"] = cfg.IPFSConnector.IPFSHTTP.NodeMultiaddress
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(cfg.Raw, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal service.json: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||||
|
return fmt.Errorf("failed to create config directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.WriteFile(path, data, 0644)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *ClusterConfigManager) updateNestedMap(m map[string]interface{}, section, key string, val interface{}) {
|
||||||
|
if _, ok := m[section]; !ok {
|
||||||
|
m[section] = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
s := m[section].(map[string]interface{})
|
||||||
|
s[key] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *ClusterConfigManager) ensureRequiredSection(m map[string]interface{}, key string) map[string]interface{} {
|
||||||
|
if _, ok := m[key]; !ok {
|
||||||
|
m[key] = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
return m[key].(map[string]interface{})
|
||||||
|
}
|
||||||
|
|
||||||
156
pkg/ipfs/cluster_peer.go
Normal file
156
pkg/ipfs/cluster_peer.go
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
package ipfs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/libp2p/go-libp2p/core/host"
|
||||||
|
"github.com/multiformats/go-multiaddr"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpdatePeerAddresses updates the peer_addresses in service.json with given multiaddresses
|
||||||
|
func (cm *ClusterConfigManager) UpdatePeerAddresses(addrs []string) error {
|
||||||
|
serviceJSONPath := filepath.Join(cm.clusterPath, "service.json")
|
||||||
|
cfg, err := cm.loadOrCreateConfig(serviceJSONPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
uniqueAddrs := []string{}
|
||||||
|
for _, addr := range addrs {
|
||||||
|
if !seen[addr] {
|
||||||
|
uniqueAddrs = append(uniqueAddrs, addr)
|
||||||
|
seen[addr] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Cluster.PeerAddresses = uniqueAddrs
|
||||||
|
return cm.saveConfig(serviceJSONPath, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAllClusterPeers discovers all cluster peers from the gateway and updates local config
|
||||||
|
func (cm *ClusterConfigManager) UpdateAllClusterPeers() error {
|
||||||
|
peers, err := cm.DiscoverClusterPeersFromGateway()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to discover cluster peers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(peers) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
peerAddrs := []string{}
|
||||||
|
for _, p := range peers {
|
||||||
|
peerAddrs = append(peerAddrs, p.Multiaddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cm.UpdatePeerAddresses(peerAddrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RepairPeerConfiguration attempts to fix configuration issues and re-synchronize peers
|
||||||
|
func (cm *ClusterConfigManager) RepairPeerConfiguration() error {
|
||||||
|
cm.logger.Info("Attempting to repair IPFS Cluster peer configuration")
|
||||||
|
|
||||||
|
_ = cm.FixIPFSConfigAddresses()
|
||||||
|
|
||||||
|
peers, err := cm.DiscoverClusterPeersFromGateway()
|
||||||
|
if err != nil {
|
||||||
|
cm.logger.Warn("Could not discover peers from gateway during repair", zap.Error(err))
|
||||||
|
} else {
|
||||||
|
peerAddrs := []string{}
|
||||||
|
for _, p := range peers {
|
||||||
|
peerAddrs = append(peerAddrs, p.Multiaddress)
|
||||||
|
}
|
||||||
|
if len(peerAddrs) > 0 {
|
||||||
|
_ = cm.UpdatePeerAddresses(peerAddrs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DiscoverClusterPeersFromGateway queries the central gateway for registered IPFS Cluster peers
|
||||||
|
func (cm *ClusterConfigManager) DiscoverClusterPeersFromGateway() ([]ClusterPeerInfo, error) {
|
||||||
|
// Not implemented - would require a central gateway URL in config
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DiscoverClusterPeersFromLibP2P uses libp2p host to find other cluster peers
|
||||||
|
func (cm *ClusterConfigManager) DiscoverClusterPeersFromLibP2P(h host.Host) error {
|
||||||
|
if h == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var clusterPeers []string
|
||||||
|
for _, p := range h.Peerstore().Peers() {
|
||||||
|
if p == h.ID() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
info := h.Peerstore().PeerInfo(p)
|
||||||
|
for _, addr := range info.Addrs {
|
||||||
|
if strings.Contains(addr.String(), "/tcp/9096") || strings.Contains(addr.String(), "/tcp/9094") {
|
||||||
|
ma := addr.Encapsulate(multiaddr.StringCast(fmt.Sprintf("/p2p/%s", p.String())))
|
||||||
|
clusterPeers = append(clusterPeers, ma.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(clusterPeers) > 0 {
|
||||||
|
return cm.UpdatePeerAddresses(clusterPeers)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *ClusterConfigManager) getPeerID() (string, error) {
|
||||||
|
dataDir := cm.cfg.Node.DataDir
|
||||||
|
if strings.HasPrefix(dataDir, "~") {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
dataDir = filepath.Join(home, dataDir[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
possiblePaths := []string{
|
||||||
|
filepath.Join(dataDir, "ipfs", "repo"),
|
||||||
|
filepath.Join(dataDir, "node-1", "ipfs", "repo"),
|
||||||
|
filepath.Join(dataDir, "node-2", "ipfs", "repo"),
|
||||||
|
filepath.Join(filepath.Dir(dataDir), "node-1", "ipfs", "repo"),
|
||||||
|
filepath.Join(filepath.Dir(dataDir), "node-2", "ipfs", "repo"),
|
||||||
|
}
|
||||||
|
|
||||||
|
var ipfsRepoPath string
|
||||||
|
for _, path := range possiblePaths {
|
||||||
|
if _, err := os.Stat(filepath.Join(path, "config")); err == nil {
|
||||||
|
ipfsRepoPath = path
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ipfsRepoPath == "" {
|
||||||
|
return "", fmt.Errorf("could not find IPFS repo path")
|
||||||
|
}
|
||||||
|
|
||||||
|
idCmd := exec.Command("ipfs", "id", "-f", "<id>")
|
||||||
|
idCmd.Env = append(os.Environ(), "IPFS_PATH="+ipfsRepoPath)
|
||||||
|
out, err := idCmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(string(out)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClusterPeerInfo represents information about an IPFS Cluster peer
|
||||||
|
type ClusterPeerInfo struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Multiaddress string `json:"multiaddress"`
|
||||||
|
NodeName string `json:"node_name"`
|
||||||
|
LastSeen time.Time `json:"last_seen"`
|
||||||
|
}
|
||||||
|
|
||||||
119
pkg/ipfs/cluster_util.go
Normal file
119
pkg/ipfs/cluster_util.go
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
package ipfs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func loadOrGenerateClusterSecret(path string) (string, error) {
|
||||||
|
if data, err := os.ReadFile(path); err == nil {
|
||||||
|
secret := strings.TrimSpace(string(data))
|
||||||
|
if len(secret) == 64 {
|
||||||
|
return secret, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
secret, err := generateRandomSecret()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = os.WriteFile(path, []byte(secret), 0600)
|
||||||
|
return secret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateRandomSecret() (string, error) {
|
||||||
|
bytes := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseClusterPorts(rawURL string) (int, int, error) {
|
||||||
|
if !strings.HasPrefix(rawURL, "http") {
|
||||||
|
rawURL = "http://" + rawURL
|
||||||
|
}
|
||||||
|
u, err := url.Parse(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
return 9096, 9094, nil
|
||||||
|
}
|
||||||
|
_, portStr, err := net.SplitHostPort(u.Host)
|
||||||
|
if err != nil {
|
||||||
|
return 9096, 9094, nil
|
||||||
|
}
|
||||||
|
var port int
|
||||||
|
fmt.Sscanf(portStr, "%d", &port)
|
||||||
|
if port == 0 {
|
||||||
|
return 9096, 9094, nil
|
||||||
|
}
|
||||||
|
return port + 2, port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseIPFSPort(rawURL string) (int, error) {
|
||||||
|
if !strings.HasPrefix(rawURL, "http") {
|
||||||
|
rawURL = "http://" + rawURL
|
||||||
|
}
|
||||||
|
u, err := url.Parse(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
return 5001, nil
|
||||||
|
}
|
||||||
|
_, portStr, err := net.SplitHostPort(u.Host)
|
||||||
|
if err != nil {
|
||||||
|
return 5001, nil
|
||||||
|
}
|
||||||
|
var port int
|
||||||
|
fmt.Sscanf(portStr, "%d", &port)
|
||||||
|
if port == 0 {
|
||||||
|
return 5001, nil
|
||||||
|
}
|
||||||
|
return port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePeerHostAndPort(multiaddr string) (string, int) {
|
||||||
|
parts := strings.Split(multiaddr, "/")
|
||||||
|
var hostStr string
|
||||||
|
var port int
|
||||||
|
for i, part := range parts {
|
||||||
|
if part == "ip4" || part == "dns" || part == "dns4" {
|
||||||
|
hostStr = parts[i+1]
|
||||||
|
} else if part == "tcp" {
|
||||||
|
fmt.Sscanf(parts[i+1], "%d", &port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return hostStr, port
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractIPFromMultiaddrForCluster(maddr string) string {
|
||||||
|
parts := strings.Split(maddr, "/")
|
||||||
|
for i, part := range parts {
|
||||||
|
if (part == "ip4" || part == "dns" || part == "dns4") && i+1 < len(parts) {
|
||||||
|
return parts[i+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractDomainFromMultiaddr(maddr string) string {
|
||||||
|
parts := strings.Split(maddr, "/")
|
||||||
|
for i, part := range parts {
|
||||||
|
if (part == "dns" || part == "dns4" || part == "dns6") && i+1 < len(parts) {
|
||||||
|
return parts[i+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStandardHTTPClient() *http.Client {
|
||||||
|
return &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
204
pkg/node/gateway.go
Normal file
204
pkg/node/gateway.go
Normal file
@ -0,0 +1,204 @@
|
|||||||
|
package node
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/gateway"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
"golang.org/x/crypto/acme"
|
||||||
|
"golang.org/x/crypto/acme/autocert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// startHTTPGateway initializes and starts the full API gateway
|
||||||
|
func (n *Node) startHTTPGateway(ctx context.Context) error {
|
||||||
|
if !n.config.HTTPGateway.Enabled {
|
||||||
|
n.logger.ComponentInfo(logging.ComponentNode, "HTTP Gateway disabled in config")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logFile := filepath.Join(os.ExpandEnv(n.config.Node.DataDir), "..", "logs", "gateway.log")
|
||||||
|
logsDir := filepath.Dir(logFile)
|
||||||
|
_ = os.MkdirAll(logsDir, 0755)
|
||||||
|
|
||||||
|
gatewayLogger, err := logging.NewFileLogger(logging.ComponentGeneral, logFile, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
gwCfg := &gateway.Config{
|
||||||
|
ListenAddr: n.config.HTTPGateway.ListenAddr,
|
||||||
|
ClientNamespace: n.config.HTTPGateway.ClientNamespace,
|
||||||
|
BootstrapPeers: n.config.Discovery.BootstrapPeers,
|
||||||
|
NodePeerID: loadNodePeerIDFromIdentity(n.config.Node.DataDir),
|
||||||
|
RQLiteDSN: n.config.HTTPGateway.RQLiteDSN,
|
||||||
|
OlricServers: n.config.HTTPGateway.OlricServers,
|
||||||
|
OlricTimeout: n.config.HTTPGateway.OlricTimeout,
|
||||||
|
IPFSClusterAPIURL: n.config.HTTPGateway.IPFSClusterAPIURL,
|
||||||
|
IPFSAPIURL: n.config.HTTPGateway.IPFSAPIURL,
|
||||||
|
IPFSTimeout: n.config.HTTPGateway.IPFSTimeout,
|
||||||
|
EnableHTTPS: n.config.HTTPGateway.HTTPS.Enabled,
|
||||||
|
DomainName: n.config.HTTPGateway.HTTPS.Domain,
|
||||||
|
TLSCacheDir: n.config.HTTPGateway.HTTPS.CacheDir,
|
||||||
|
}
|
||||||
|
|
||||||
|
apiGateway, err := gateway.New(gatewayLogger, gwCfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
n.apiGateway = apiGateway
|
||||||
|
|
||||||
|
var certManager *autocert.Manager
|
||||||
|
if gwCfg.EnableHTTPS && gwCfg.DomainName != "" {
|
||||||
|
tlsCacheDir := gwCfg.TLSCacheDir
|
||||||
|
if tlsCacheDir == "" {
|
||||||
|
tlsCacheDir = "/home/debros/.orama/tls-cache"
|
||||||
|
}
|
||||||
|
_ = os.MkdirAll(tlsCacheDir, 0700)
|
||||||
|
|
||||||
|
certManager = &autocert.Manager{
|
||||||
|
Prompt: autocert.AcceptTOS,
|
||||||
|
HostPolicy: autocert.HostWhitelist(gwCfg.DomainName),
|
||||||
|
Cache: autocert.DirCache(tlsCacheDir),
|
||||||
|
Email: fmt.Sprintf("admin@%s", gwCfg.DomainName),
|
||||||
|
Client: &acme.Client{
|
||||||
|
DirectoryURL: "https://acme-staging-v02.api.letsencrypt.org/directory",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
n.certManager = certManager
|
||||||
|
n.certReady = make(chan struct{})
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReady := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if gwCfg.EnableHTTPS && gwCfg.DomainName != "" && certManager != nil {
|
||||||
|
httpsPort := 443
|
||||||
|
httpPort := 80
|
||||||
|
|
||||||
|
httpServer := &http.Server{
|
||||||
|
Addr: fmt.Sprintf(":%d", httpPort),
|
||||||
|
Handler: certManager.HTTPHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
target := fmt.Sprintf("https://%s%s", r.Host, r.URL.RequestURI())
|
||||||
|
http.Redirect(w, r, target, http.StatusMovedPermanently)
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
|
||||||
|
httpListener, err := net.Listen("tcp", fmt.Sprintf(":%d", httpPort))
|
||||||
|
if err != nil {
|
||||||
|
close(httpReady)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go httpServer.Serve(httpListener)
|
||||||
|
|
||||||
|
// Pre-provision cert
|
||||||
|
certReq := &tls.ClientHelloInfo{ServerName: gwCfg.DomainName}
|
||||||
|
_, certErr := certManager.GetCertificate(certReq)
|
||||||
|
|
||||||
|
if certErr != nil {
|
||||||
|
close(httpReady)
|
||||||
|
httpServer.Handler = apiGateway.Routes()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
close(httpReady)
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
GetCertificate: certManager.GetCertificate,
|
||||||
|
}
|
||||||
|
|
||||||
|
httpsServer := &http.Server{
|
||||||
|
Addr: fmt.Sprintf(":%d", httpsPort),
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
Handler: apiGateway.Routes(),
|
||||||
|
}
|
||||||
|
n.apiGatewayServer = httpsServer
|
||||||
|
|
||||||
|
ln, err := tls.Listen("tcp", fmt.Sprintf(":%d", httpsPort), tlsConfig)
|
||||||
|
if err == nil {
|
||||||
|
httpsServer.Serve(ln)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
close(httpReady)
|
||||||
|
server := &http.Server{
|
||||||
|
Addr: gwCfg.ListenAddr,
|
||||||
|
Handler: apiGateway.Routes(),
|
||||||
|
}
|
||||||
|
n.apiGatewayServer = server
|
||||||
|
ln, err := net.Listen("tcp", gwCfg.ListenAddr)
|
||||||
|
if err == nil {
|
||||||
|
server.Serve(ln)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// SNI Gateway
|
||||||
|
if n.config.HTTPGateway.SNI.Enabled && n.certManager != nil {
|
||||||
|
go n.startSNIGateway(ctx, httpReady)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) startSNIGateway(ctx context.Context, httpReady <-chan struct{}) {
|
||||||
|
<-httpReady
|
||||||
|
domain := n.config.HTTPGateway.HTTPS.Domain
|
||||||
|
if domain == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
certReq := &tls.ClientHelloInfo{ServerName: domain}
|
||||||
|
tlsCert, err := n.certManager.GetCertificate(certReq)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsCacheDir := n.config.HTTPGateway.HTTPS.CacheDir
|
||||||
|
if tlsCacheDir == "" {
|
||||||
|
tlsCacheDir = "/home/debros/.orama/tls-cache"
|
||||||
|
}
|
||||||
|
|
||||||
|
certPath := filepath.Join(tlsCacheDir, domain+".crt")
|
||||||
|
keyPath := filepath.Join(tlsCacheDir, domain+".key")
|
||||||
|
|
||||||
|
if err := extractPEMFromTLSCert(tlsCert, certPath, keyPath); err == nil {
|
||||||
|
if n.certReady != nil {
|
||||||
|
close(n.certReady)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sniCfg := n.config.HTTPGateway.SNI
|
||||||
|
sniGateway, err := gateway.NewTCPSNIGateway(n.logger, &sniCfg)
|
||||||
|
if err == nil {
|
||||||
|
n.sniGateway = sniGateway
|
||||||
|
sniGateway.Start(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startIPFSClusterConfig initializes and ensures IPFS Cluster configuration
|
||||||
|
func (n *Node) startIPFSClusterConfig() error {
|
||||||
|
n.logger.ComponentInfo(logging.ComponentNode, "Initializing IPFS Cluster configuration")
|
||||||
|
|
||||||
|
cm, err := ipfs.NewClusterConfigManager(n.config, n.logger.Logger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
n.clusterConfigManager = cm
|
||||||
|
|
||||||
|
_ = cm.FixIPFSConfigAddresses()
|
||||||
|
if err := cm.EnsureConfig(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = cm.RepairPeerConfiguration()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
302
pkg/node/libp2p.go
Normal file
302
pkg/node/libp2p.go
Normal file
@ -0,0 +1,302 @@
|
|||||||
|
package node
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/discovery"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/encryption"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||||
|
"github.com/libp2p/go-libp2p"
|
||||||
|
libp2ppubsub "github.com/libp2p/go-libp2p-pubsub"
|
||||||
|
"github.com/libp2p/go-libp2p/core/crypto"
|
||||||
|
"github.com/libp2p/go-libp2p/core/peer"
|
||||||
|
noise "github.com/libp2p/go-libp2p/p2p/security/noise"
|
||||||
|
"github.com/multiformats/go-multiaddr"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// startLibP2P initializes the LibP2P host
|
||||||
|
func (n *Node) startLibP2P() error {
|
||||||
|
n.logger.ComponentInfo(logging.ComponentLibP2P, "Starting LibP2P host")
|
||||||
|
|
||||||
|
// Load or create persistent identity
|
||||||
|
identity, err := n.loadOrCreateIdentity()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load identity: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create LibP2P host with explicit listen addresses
|
||||||
|
var opts []libp2p.Option
|
||||||
|
opts = append(opts,
|
||||||
|
libp2p.Identity(identity),
|
||||||
|
libp2p.Security(noise.ID, noise.New),
|
||||||
|
libp2p.DefaultMuxers,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Add explicit listen addresses from config
|
||||||
|
if len(n.config.Node.ListenAddresses) > 0 {
|
||||||
|
listenAddrs := make([]multiaddr.Multiaddr, 0, len(n.config.Node.ListenAddresses))
|
||||||
|
for _, addr := range n.config.Node.ListenAddresses {
|
||||||
|
ma, err := multiaddr.NewMultiaddr(addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid listen address %s: %w", addr, err)
|
||||||
|
}
|
||||||
|
listenAddrs = append(listenAddrs, ma)
|
||||||
|
}
|
||||||
|
opts = append(opts, libp2p.ListenAddrs(listenAddrs...))
|
||||||
|
n.logger.ComponentInfo(logging.ComponentLibP2P, "Configured listen addresses",
|
||||||
|
zap.Strings("addrs", n.config.Node.ListenAddresses))
|
||||||
|
}
|
||||||
|
|
||||||
|
// For localhost/development, disable NAT services
|
||||||
|
isLocalhost := len(n.config.Node.ListenAddresses) > 0 &&
|
||||||
|
(strings.Contains(n.config.Node.ListenAddresses[0], "localhost") ||
|
||||||
|
strings.Contains(n.config.Node.ListenAddresses[0], "127.0.0.1"))
|
||||||
|
|
||||||
|
if isLocalhost {
|
||||||
|
n.logger.ComponentInfo(logging.ComponentLibP2P, "Localhost detected - disabling NAT services for local development")
|
||||||
|
} else {
|
||||||
|
n.logger.ComponentInfo(logging.ComponentLibP2P, "Production mode - enabling NAT services")
|
||||||
|
opts = append(opts,
|
||||||
|
libp2p.EnableNATService(),
|
||||||
|
libp2p.EnableAutoNATv2(),
|
||||||
|
libp2p.EnableRelay(),
|
||||||
|
libp2p.NATPortMap(),
|
||||||
|
libp2p.EnableAutoRelayWithPeerSource(
|
||||||
|
peerSource(n.config.Discovery.BootstrapPeers, n.logger.Logger),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
h, err := libp2p.New(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
n.host = h
|
||||||
|
|
||||||
|
// Initialize pubsub
|
||||||
|
ps, err := libp2ppubsub.NewGossipSub(context.Background(), h,
|
||||||
|
libp2ppubsub.WithPeerExchange(true),
|
||||||
|
libp2ppubsub.WithFloodPublish(true),
|
||||||
|
libp2ppubsub.WithDirectPeers(nil),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create pubsub: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create pubsub adapter
|
||||||
|
n.pubsub = pubsub.NewClientAdapter(ps, n.config.Discovery.NodeNamespace)
|
||||||
|
n.logger.Info("Initialized pubsub adapter on namespace", zap.String("namespace", n.config.Discovery.NodeNamespace))
|
||||||
|
|
||||||
|
// Connect to peers
|
||||||
|
if err := n.connectToPeers(context.Background()); err != nil {
|
||||||
|
n.logger.ComponentWarn(logging.ComponentNode, "Failed to connect to peers", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start reconnection loop
|
||||||
|
if len(n.config.Discovery.BootstrapPeers) > 0 {
|
||||||
|
peerCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
n.peerDiscoveryCancel = cancel
|
||||||
|
|
||||||
|
go n.peerReconnectionLoop(peerCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add peers to peerstore
|
||||||
|
for _, peerAddr := range n.config.Discovery.BootstrapPeers {
|
||||||
|
if ma, err := multiaddr.NewMultiaddr(peerAddr); err == nil {
|
||||||
|
if peerInfo, err := peer.AddrInfoFromP2pAddr(ma); err == nil {
|
||||||
|
n.host.Peerstore().AddAddrs(peerInfo.ID, peerInfo.Addrs, time.Hour*24)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize discovery manager
|
||||||
|
n.discoveryManager = discovery.NewManager(h, nil, n.logger.Logger)
|
||||||
|
n.discoveryManager.StartProtocolHandler()
|
||||||
|
|
||||||
|
n.logger.ComponentInfo(logging.ComponentNode, "LibP2P host started successfully")
|
||||||
|
|
||||||
|
// Start peer discovery
|
||||||
|
n.startPeerDiscovery()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) peerReconnectionLoop(ctx context.Context) {
|
||||||
|
interval := 5 * time.Second
|
||||||
|
consecutiveFailures := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if !n.hasPeerConnections() {
|
||||||
|
if err := n.connectToPeers(context.Background()); err != nil {
|
||||||
|
consecutiveFailures++
|
||||||
|
jitteredInterval := addJitter(interval)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(jitteredInterval):
|
||||||
|
}
|
||||||
|
|
||||||
|
interval = calculateNextBackoff(interval)
|
||||||
|
} else {
|
||||||
|
interval = 5 * time.Second
|
||||||
|
consecutiveFailures = 0
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(30 * time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(30 * time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) connectToPeers(ctx context.Context) error {
|
||||||
|
for _, peerAddr := range n.config.Discovery.BootstrapPeers {
|
||||||
|
if err := n.connectToPeerAddr(ctx, peerAddr); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) connectToPeerAddr(ctx context.Context, addr string) error {
|
||||||
|
ma, err := multiaddr.NewMultiaddr(addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
peerInfo, err := peer.AddrInfoFromP2pAddr(ma)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n.host != nil && peerInfo.ID == n.host.ID() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return n.host.Connect(ctx, *peerInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) hasPeerConnections() bool {
|
||||||
|
if n.host == nil || len(n.config.Discovery.BootstrapPeers) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
connectedPeers := n.host.Network().Peers()
|
||||||
|
if len(connectedPeers) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
bootstrapIDs := make(map[peer.ID]bool)
|
||||||
|
for _, addr := range n.config.Discovery.BootstrapPeers {
|
||||||
|
if ma, err := multiaddr.NewMultiaddr(addr); err == nil {
|
||||||
|
if info, err := peer.AddrInfoFromP2pAddr(ma); err == nil {
|
||||||
|
bootstrapIDs[info.ID] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range connectedPeers {
|
||||||
|
if bootstrapIDs[p] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) loadOrCreateIdentity() (crypto.PrivKey, error) {
|
||||||
|
identityFile := filepath.Join(os.ExpandEnv(n.config.Node.DataDir), "identity.key")
|
||||||
|
if strings.HasPrefix(identityFile, "~") {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
identityFile = filepath.Join(home, identityFile[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(identityFile); err == nil {
|
||||||
|
info, err := encryption.LoadIdentity(identityFile)
|
||||||
|
if err == nil {
|
||||||
|
return info.PrivateKey, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := encryption.GenerateIdentity()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := encryption.SaveIdentity(info, identityFile); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return info.PrivateKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) startPeerDiscovery() {
|
||||||
|
if n.discoveryManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
discoveryConfig := discovery.Config{
|
||||||
|
DiscoveryInterval: n.config.Discovery.DiscoveryInterval,
|
||||||
|
MaxConnections: n.config.Node.MaxConnections,
|
||||||
|
}
|
||||||
|
n.discoveryManager.Start(discoveryConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) stopPeerDiscovery() {
|
||||||
|
if n.discoveryManager != nil {
|
||||||
|
n.discoveryManager.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) GetPeerID() string {
|
||||||
|
if n.host == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return n.host.ID().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerSource(peerAddrs []string, logger *zap.Logger) func(context.Context, int) <-chan peer.AddrInfo {
|
||||||
|
return func(ctx context.Context, num int) <-chan peer.AddrInfo {
|
||||||
|
out := make(chan peer.AddrInfo, num)
|
||||||
|
go func() {
|
||||||
|
defer close(out)
|
||||||
|
count := 0
|
||||||
|
for _, s := range peerAddrs {
|
||||||
|
if count >= num {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ma, err := multiaddr.NewMultiaddr(s)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ai, err := peer.AddrInfoFromP2pAddr(ma)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case out <- *ai:
|
||||||
|
count++
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@ -220,9 +220,9 @@ func (n *Node) startConnectionMonitoring() {
|
|||||||
// First try to discover from LibP2P connections (works even if cluster peers aren't connected yet)
|
// First try to discover from LibP2P connections (works even if cluster peers aren't connected yet)
|
||||||
// This runs every minute to discover peers automatically via LibP2P discovery
|
// This runs every minute to discover peers automatically via LibP2P discovery
|
||||||
if time.Now().Unix()%60 == 0 {
|
if time.Now().Unix()%60 == 0 {
|
||||||
if success, err := n.clusterConfigManager.DiscoverClusterPeersFromLibP2P(n.host); err != nil {
|
if err := n.clusterConfigManager.DiscoverClusterPeersFromLibP2P(n.host); err != nil {
|
||||||
n.logger.ComponentWarn(logging.ComponentNode, "Failed to discover cluster peers from LibP2P", zap.Error(err))
|
n.logger.ComponentWarn(logging.ComponentNode, "Failed to discover cluster peers from LibP2P", zap.Error(err))
|
||||||
} else if success {
|
} else {
|
||||||
n.logger.ComponentInfo(logging.ComponentNode, "Cluster peer addresses discovered from LibP2P")
|
n.logger.ComponentInfo(logging.ComponentNode, "Cluster peer addresses discovered from LibP2P")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -230,16 +230,16 @@ func (n *Node) startConnectionMonitoring() {
|
|||||||
// Also try to update from cluster API (works once peers are connected)
|
// Also try to update from cluster API (works once peers are connected)
|
||||||
// Update all cluster peers every 2 minutes to discover new peers
|
// Update all cluster peers every 2 minutes to discover new peers
|
||||||
if time.Now().Unix()%120 == 0 {
|
if time.Now().Unix()%120 == 0 {
|
||||||
if success, err := n.clusterConfigManager.UpdateAllClusterPeers(); err != nil {
|
if err := n.clusterConfigManager.UpdateAllClusterPeers(); err != nil {
|
||||||
n.logger.ComponentWarn(logging.ComponentNode, "Failed to update cluster peers during monitoring", zap.Error(err))
|
n.logger.ComponentWarn(logging.ComponentNode, "Failed to update cluster peers during monitoring", zap.Error(err))
|
||||||
} else if success {
|
} else {
|
||||||
n.logger.ComponentInfo(logging.ComponentNode, "Cluster peer addresses updated during monitoring")
|
n.logger.ComponentInfo(logging.ComponentNode, "Cluster peer addresses updated during monitoring")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to repair peer configuration
|
// Try to repair peer configuration
|
||||||
if success, err := n.clusterConfigManager.RepairPeerConfiguration(); err != nil {
|
if err := n.clusterConfigManager.RepairPeerConfiguration(); err != nil {
|
||||||
n.logger.ComponentWarn(logging.ComponentNode, "Failed to repair peer addresses during monitoring", zap.Error(err))
|
n.logger.ComponentWarn(logging.ComponentNode, "Failed to repair peer addresses during monitoring", zap.Error(err))
|
||||||
} else if success {
|
} else {
|
||||||
n.logger.ComponentInfo(logging.ComponentNode, "Peer configuration repaired during monitoring")
|
n.logger.ComponentInfo(logging.ComponentNode, "Peer configuration repaired during monitoring")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
1172
pkg/node/node.go
1172
pkg/node/node.go
File diff suppressed because it is too large
Load Diff
98
pkg/node/rqlite.go
Normal file
98
pkg/node/rqlite.go
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
package node
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
database "github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// startRQLite initializes and starts the RQLite database
|
||||||
|
func (n *Node) startRQLite(ctx context.Context) error {
|
||||||
|
n.logger.Info("Starting RQLite database")
|
||||||
|
|
||||||
|
// Determine node identifier for log filename - use node ID for unique filenames
|
||||||
|
nodeID := n.config.Node.ID
|
||||||
|
if nodeID == "" {
|
||||||
|
// Default to "node" if ID is not set
|
||||||
|
nodeID = "node"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create RQLite manager
|
||||||
|
n.rqliteManager = database.NewRQLiteManager(&n.config.Database, &n.config.Discovery, n.config.Node.DataDir, n.logger.Logger)
|
||||||
|
n.rqliteManager.SetNodeType(nodeID)
|
||||||
|
|
||||||
|
// Initialize cluster discovery service if LibP2P host is available
|
||||||
|
if n.host != nil && n.discoveryManager != nil {
|
||||||
|
// Create cluster discovery service (all nodes are unified)
|
||||||
|
n.clusterDiscovery = database.NewClusterDiscoveryService(
|
||||||
|
n.host,
|
||||||
|
n.discoveryManager,
|
||||||
|
n.rqliteManager,
|
||||||
|
n.config.Node.ID,
|
||||||
|
"node", // Unified node type
|
||||||
|
n.config.Discovery.RaftAdvAddress,
|
||||||
|
n.config.Discovery.HttpAdvAddress,
|
||||||
|
n.config.Node.DataDir,
|
||||||
|
n.logger.Logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Set discovery service on RQLite manager BEFORE starting RQLite
|
||||||
|
// This is critical for pre-start cluster discovery during recovery
|
||||||
|
n.rqliteManager.SetDiscoveryService(n.clusterDiscovery)
|
||||||
|
|
||||||
|
// Start cluster discovery (but don't trigger initial sync yet)
|
||||||
|
if err := n.clusterDiscovery.Start(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to start cluster discovery: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish initial metadata (with log_index=0) so peers can discover us during recovery
|
||||||
|
// The metadata will be updated with actual log index after RQLite starts
|
||||||
|
n.clusterDiscovery.UpdateOwnMetadata()
|
||||||
|
|
||||||
|
n.logger.Info("Cluster discovery service started (waiting for RQLite)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If node-to-node TLS is configured, wait for certificates to be provisioned
|
||||||
|
// This ensures RQLite can start with TLS when joining through the SNI gateway
|
||||||
|
if n.config.Database.NodeCert != "" && n.config.Database.NodeKey != "" && n.certReady != nil {
|
||||||
|
n.logger.Info("RQLite node TLS configured, waiting for certificates to be provisioned...",
|
||||||
|
zap.String("node_cert", n.config.Database.NodeCert),
|
||||||
|
zap.String("node_key", n.config.Database.NodeKey))
|
||||||
|
|
||||||
|
// Wait for certificate ready signal with timeout
|
||||||
|
certTimeout := 5 * time.Minute
|
||||||
|
select {
|
||||||
|
case <-n.certReady:
|
||||||
|
n.logger.Info("Certificates ready, proceeding with RQLite startup")
|
||||||
|
case <-time.After(certTimeout):
|
||||||
|
return fmt.Errorf("timeout waiting for TLS certificates after %v - ensure HTTPS is configured and ports 80/443 are accessible for ACME challenges", certTimeout)
|
||||||
|
case <-ctx.Done():
|
||||||
|
return fmt.Errorf("context cancelled while waiting for certificates: %w", ctx.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start RQLite FIRST before updating metadata
|
||||||
|
if err := n.rqliteManager.Start(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOW update metadata after RQLite is running
|
||||||
|
if n.clusterDiscovery != nil {
|
||||||
|
n.clusterDiscovery.UpdateOwnMetadata()
|
||||||
|
n.clusterDiscovery.TriggerSync() // Do initial cluster sync now that RQLite is ready
|
||||||
|
n.logger.Info("RQLite metadata published and cluster synced")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create adapter for sql.DB compatibility
|
||||||
|
adapter, err := database.NewRQLiteAdapter(n.rqliteManager)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create RQLite adapter: %w", err)
|
||||||
|
}
|
||||||
|
n.rqliteAdapter = adapter
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
127
pkg/node/utils.go
Normal file
127
pkg/node/utils.go
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
package node
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
mathrand "math/rand"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/encryption"
|
||||||
|
"github.com/multiformats/go-multiaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func extractIPFromMultiaddr(multiaddrStr string) string {
|
||||||
|
ma, err := multiaddr.NewMultiaddr(multiaddrStr)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var ip string
|
||||||
|
var dnsName string
|
||||||
|
multiaddr.ForEach(ma, func(c multiaddr.Component) bool {
|
||||||
|
switch c.Protocol().Code {
|
||||||
|
case multiaddr.P_IP4, multiaddr.P_IP6:
|
||||||
|
ip = c.Value()
|
||||||
|
return false
|
||||||
|
case multiaddr.P_DNS4, multiaddr.P_DNS6, multiaddr.P_DNSADDR:
|
||||||
|
dnsName = c.Value()
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if ip != "" {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
|
if dnsName != "" {
|
||||||
|
if resolvedIPs, err := net.LookupIP(dnsName); err == nil && len(resolvedIPs) > 0 {
|
||||||
|
for _, resolvedIP := range resolvedIPs {
|
||||||
|
if resolvedIP.To4() != nil {
|
||||||
|
return resolvedIP.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resolvedIPs[0].String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func calculateNextBackoff(current time.Duration) time.Duration {
|
||||||
|
next := time.Duration(float64(current) * 1.5)
|
||||||
|
maxInterval := 10 * time.Minute
|
||||||
|
if next > maxInterval {
|
||||||
|
next = maxInterval
|
||||||
|
}
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
|
||||||
|
func addJitter(interval time.Duration) time.Duration {
|
||||||
|
jitterPercent := 0.2
|
||||||
|
jitterRange := float64(interval) * jitterPercent
|
||||||
|
jitter := (mathrand.Float64() - 0.5) * 2 * jitterRange
|
||||||
|
result := time.Duration(float64(interval) + jitter)
|
||||||
|
if result < time.Second {
|
||||||
|
result = time.Second
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadNodePeerIDFromIdentity(dataDir string) string {
|
||||||
|
identityFile := filepath.Join(os.ExpandEnv(dataDir), "identity.key")
|
||||||
|
if strings.HasPrefix(identityFile, "~") {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
identityFile = filepath.Join(home, identityFile[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
if info, err := encryption.LoadIdentity(identityFile); err == nil {
|
||||||
|
return info.PeerID.String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractPEMFromTLSCert(tlsCert *tls.Certificate, certPath, keyPath string) error {
|
||||||
|
if tlsCert == nil || len(tlsCert.Certificate) == 0 {
|
||||||
|
return fmt.Errorf("invalid tls certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
certFile, err := os.Create(certPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer certFile.Close()
|
||||||
|
|
||||||
|
for _, certBytes := range tlsCert.Certificate {
|
||||||
|
pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsCert.PrivateKey == nil {
|
||||||
|
return fmt.Errorf("private key is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
keyFile, err := os.Create(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer keyFile.Close()
|
||||||
|
|
||||||
|
var keyBytes []byte
|
||||||
|
switch key := tlsCert.PrivateKey.(type) {
|
||||||
|
case *x509.Certificate:
|
||||||
|
keyBytes, _ = x509.MarshalPKCS8PrivateKey(key)
|
||||||
|
default:
|
||||||
|
keyBytes, _ = x509.MarshalPKCS8PrivateKey(tlsCert.PrivateKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes})
|
||||||
|
os.Chmod(certPath, 0644)
|
||||||
|
os.Chmod(keyPath, 0600)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
@ -49,6 +49,13 @@ func NewClient(cfg Config, logger *zap.Logger) (*Client, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnderlyingClient returns the underlying olriclib.Client for advanced usage.
|
||||||
|
// This is useful when you need to pass the client to other packages that expect
|
||||||
|
// the raw olric client interface.
|
||||||
|
func (c *Client) UnderlyingClient() olriclib.Client {
|
||||||
|
return c.client
|
||||||
|
}
|
||||||
|
|
||||||
// Health checks if the Olric client is healthy
|
// Health checks if the Olric client is healthy
|
||||||
func (c *Client) Health(ctx context.Context) error {
|
func (c *Client) Health(ctx context.Context) error {
|
||||||
// Create a DMap to test connectivity
|
// Create a DMap to test connectivity
|
||||||
|
|||||||
217
pkg/pubsub/manager_test.go
Normal file
217
pkg/pubsub/manager_test.go
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
package pubsub
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/libp2p/go-libp2p"
|
||||||
|
pubsub "github.com/libp2p/go-libp2p-pubsub"
|
||||||
|
"github.com/libp2p/go-libp2p/core/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createTestManager(t *testing.T, ns string) (*Manager, func()) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
h, err := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create libp2p host: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ps, err := pubsub.NewGossipSub(ctx, h)
|
||||||
|
if err != nil {
|
||||||
|
h.Close()
|
||||||
|
t.Fatalf("failed to create gossipsub: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := NewManager(ps, ns)
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
mgr.Close()
|
||||||
|
h.Close()
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
return mgr, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Namespacing(t *testing.T) {
|
||||||
|
mgr, cleanup := createTestManager(t, "test-ns")
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
topic := "my-topic"
|
||||||
|
expectedNamespacedTopic := "test-ns.my-topic"
|
||||||
|
|
||||||
|
// Subscribe
|
||||||
|
err := mgr.Subscribe(ctx, topic, func(t string, d []byte) error { return nil })
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Subscribe failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr.mu.RLock()
|
||||||
|
_, exists := mgr.subscriptions[expectedNamespacedTopic]
|
||||||
|
mgr.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
t.Errorf("expected subscription for %s to exist", expectedNamespacedTopic)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test override
|
||||||
|
overrideNS := "other-ns"
|
||||||
|
overrideCtx := context.WithValue(ctx, CtxKeyNamespaceOverride, overrideNS)
|
||||||
|
expectedOverrideTopic := "other-ns.my-topic"
|
||||||
|
|
||||||
|
err = mgr.Subscribe(overrideCtx, topic, func(t string, d []byte) error { return nil })
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Subscribe with override failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr.mu.RLock()
|
||||||
|
_, exists = mgr.subscriptions[expectedOverrideTopic]
|
||||||
|
mgr.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
t.Errorf("expected subscription for %s to exist", expectedOverrideTopic)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test ListTopics
|
||||||
|
topics, err := mgr.ListTopics(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListTopics failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(topics) != 1 || topics[0] != "my-topic" {
|
||||||
|
t.Errorf("expected 1 topic [my-topic], got %v", topics)
|
||||||
|
}
|
||||||
|
|
||||||
|
topicsOverride, err := mgr.ListTopics(overrideCtx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListTopics with override failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(topicsOverride) != 1 || topicsOverride[0] != "my-topic" {
|
||||||
|
t.Errorf("expected 1 topic [my-topic] with override, got %v", topicsOverride)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_RefCount(t *testing.T) {
|
||||||
|
mgr, cleanup := createTestManager(t, "test-ns")
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
topic := "ref-topic"
|
||||||
|
namespacedTopic := "test-ns.ref-topic"
|
||||||
|
|
||||||
|
h1 := func(t string, d []byte) error { return nil }
|
||||||
|
h2 := func(t string, d []byte) error { return nil }
|
||||||
|
|
||||||
|
// First subscription
|
||||||
|
err := mgr.Subscribe(ctx, topic, h1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first subscribe failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr.mu.RLock()
|
||||||
|
ts := mgr.subscriptions[namespacedTopic]
|
||||||
|
mgr.mu.RUnlock()
|
||||||
|
|
||||||
|
if ts.refCount != 1 {
|
||||||
|
t.Errorf("expected refCount 1, got %d", ts.refCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second subscription
|
||||||
|
err = mgr.Subscribe(ctx, topic, h2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second subscribe failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ts.refCount != 2 {
|
||||||
|
t.Errorf("expected refCount 2, got %d", ts.refCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe one
|
||||||
|
err = mgr.Unsubscribe(ctx, topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unsubscribe 1 failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ts.refCount != 1 {
|
||||||
|
t.Errorf("expected refCount 1 after one unsubscribe, got %d", ts.refCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr.mu.RLock()
|
||||||
|
_, exists := mgr.subscriptions[namespacedTopic]
|
||||||
|
mgr.mu.RUnlock()
|
||||||
|
if !exists {
|
||||||
|
t.Error("expected subscription to still exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe second
|
||||||
|
err = mgr.Unsubscribe(ctx, topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unsubscribe 2 failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr.mu.RLock()
|
||||||
|
_, exists = mgr.subscriptions[namespacedTopic]
|
||||||
|
mgr.mu.RUnlock()
|
||||||
|
if exists {
|
||||||
|
t.Error("expected subscription to be removed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_PubSub(t *testing.T) {
|
||||||
|
// For a real pubsub test between two managers, we need them to be connected
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
h1, _ := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
|
||||||
|
ps1, _ := pubsub.NewGossipSub(ctx, h1)
|
||||||
|
mgr1 := NewManager(ps1, "test")
|
||||||
|
defer h1.Close()
|
||||||
|
defer mgr1.Close()
|
||||||
|
|
||||||
|
h2, _ := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
|
||||||
|
ps2, _ := pubsub.NewGossipSub(ctx, h2)
|
||||||
|
mgr2 := NewManager(ps2, "test")
|
||||||
|
defer h2.Close()
|
||||||
|
defer mgr2.Close()
|
||||||
|
|
||||||
|
// Connect hosts
|
||||||
|
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
|
||||||
|
err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect hosts: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
topic := "chat"
|
||||||
|
msgData := []byte("hello world")
|
||||||
|
received := make(chan []byte, 1)
|
||||||
|
|
||||||
|
err = mgr2.Subscribe(ctx, topic, func(t string, d []byte) error {
|
||||||
|
received <- d
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("mgr2 subscribe failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for mesh to form (mgr1 needs to know about mgr2's subscription)
|
||||||
|
// In a real network this happens via gossip. We'll just retry publish.
|
||||||
|
timeout := time.After(5 * time.Second)
|
||||||
|
ticker := time.NewTicker(100 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
Loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatal("timed out waiting for message")
|
||||||
|
case <-ticker.C:
|
||||||
|
_ = mgr1.Publish(ctx, topic, msgData)
|
||||||
|
case data := <-received:
|
||||||
|
if string(data) != string(msgData) {
|
||||||
|
t.Errorf("expected %s, got %s", string(msgData), string(data))
|
||||||
|
}
|
||||||
|
break Loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -595,10 +595,19 @@ func setReflectValue(field reflect.Value, raw any) error {
|
|||||||
switch v := raw.(type) {
|
switch v := raw.(type) {
|
||||||
case int64:
|
case int64:
|
||||||
field.SetInt(v)
|
field.SetInt(v)
|
||||||
|
case float64:
|
||||||
|
// RQLite/JSON returns numbers as float64
|
||||||
|
field.SetInt(int64(v))
|
||||||
|
case int:
|
||||||
|
field.SetInt(int64(v))
|
||||||
case []byte:
|
case []byte:
|
||||||
var n int64
|
var n int64
|
||||||
fmt.Sscan(string(v), &n)
|
fmt.Sscan(string(v), &n)
|
||||||
field.SetInt(n)
|
field.SetInt(n)
|
||||||
|
case string:
|
||||||
|
var n int64
|
||||||
|
fmt.Sscan(v, &n)
|
||||||
|
field.SetInt(n)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("cannot convert %T to int", raw)
|
return fmt.Errorf("cannot convert %T to int", raw)
|
||||||
}
|
}
|
||||||
@ -609,10 +618,22 @@ func setReflectValue(field reflect.Value, raw any) error {
|
|||||||
v = 0
|
v = 0
|
||||||
}
|
}
|
||||||
field.SetUint(uint64(v))
|
field.SetUint(uint64(v))
|
||||||
|
case float64:
|
||||||
|
// RQLite/JSON returns numbers as float64
|
||||||
|
if v < 0 {
|
||||||
|
v = 0
|
||||||
|
}
|
||||||
|
field.SetUint(uint64(v))
|
||||||
|
case uint64:
|
||||||
|
field.SetUint(v)
|
||||||
case []byte:
|
case []byte:
|
||||||
var n uint64
|
var n uint64
|
||||||
fmt.Sscan(string(v), &n)
|
fmt.Sscan(string(v), &n)
|
||||||
field.SetUint(n)
|
field.SetUint(n)
|
||||||
|
case string:
|
||||||
|
var n uint64
|
||||||
|
fmt.Sscan(v, &n)
|
||||||
|
field.SetUint(n)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("cannot convert %T to uint", raw)
|
return fmt.Errorf("cannot convert %T to uint", raw)
|
||||||
}
|
}
|
||||||
@ -628,11 +649,16 @@ func setReflectValue(field reflect.Value, raw any) error {
|
|||||||
return fmt.Errorf("cannot convert %T to float", raw)
|
return fmt.Errorf("cannot convert %T to float", raw)
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
// Support time.Time; extend as needed.
|
// Support time.Time
|
||||||
if field.Type() == reflect.TypeOf(time.Time{}) {
|
if field.Type() == reflect.TypeOf(time.Time{}) {
|
||||||
switch v := raw.(type) {
|
switch v := raw.(type) {
|
||||||
case time.Time:
|
case time.Time:
|
||||||
field.Set(reflect.ValueOf(v))
|
field.Set(reflect.ValueOf(v))
|
||||||
|
case string:
|
||||||
|
// Try RFC3339
|
||||||
|
if tt, err := time.Parse(time.RFC3339, v); err == nil {
|
||||||
|
field.Set(reflect.ValueOf(tt))
|
||||||
|
}
|
||||||
case []byte:
|
case []byte:
|
||||||
// Try RFC3339
|
// Try RFC3339
|
||||||
if tt, err := time.Parse(time.RFC3339, string(v)); err == nil {
|
if tt, err := time.Parse(time.RFC3339, string(v)); err == nil {
|
||||||
@ -641,6 +667,68 @@ func setReflectValue(field reflect.Value, raw any) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
// Support sql.NullString
|
||||||
|
if field.Type() == reflect.TypeOf(sql.NullString{}) {
|
||||||
|
ns := sql.NullString{}
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
ns.String = v
|
||||||
|
ns.Valid = true
|
||||||
|
case []byte:
|
||||||
|
ns.String = string(v)
|
||||||
|
ns.Valid = true
|
||||||
|
}
|
||||||
|
field.Set(reflect.ValueOf(ns))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Support sql.NullInt64
|
||||||
|
if field.Type() == reflect.TypeOf(sql.NullInt64{}) {
|
||||||
|
ni := sql.NullInt64{}
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case int64:
|
||||||
|
ni.Int64 = v
|
||||||
|
ni.Valid = true
|
||||||
|
case float64:
|
||||||
|
ni.Int64 = int64(v)
|
||||||
|
ni.Valid = true
|
||||||
|
case int:
|
||||||
|
ni.Int64 = int64(v)
|
||||||
|
ni.Valid = true
|
||||||
|
}
|
||||||
|
field.Set(reflect.ValueOf(ni))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Support sql.NullBool
|
||||||
|
if field.Type() == reflect.TypeOf(sql.NullBool{}) {
|
||||||
|
nb := sql.NullBool{}
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case bool:
|
||||||
|
nb.Bool = v
|
||||||
|
nb.Valid = true
|
||||||
|
case int64:
|
||||||
|
nb.Bool = v != 0
|
||||||
|
nb.Valid = true
|
||||||
|
case float64:
|
||||||
|
nb.Bool = v != 0
|
||||||
|
nb.Valid = true
|
||||||
|
}
|
||||||
|
field.Set(reflect.ValueOf(nb))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Support sql.NullFloat64
|
||||||
|
if field.Type() == reflect.TypeOf(sql.NullFloat64{}) {
|
||||||
|
nf := sql.NullFloat64{}
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case float64:
|
||||||
|
nf.Float64 = v
|
||||||
|
nf.Valid = true
|
||||||
|
case int64:
|
||||||
|
nf.Float64 = float64(v)
|
||||||
|
nf.Valid = true
|
||||||
|
}
|
||||||
|
field.Set(reflect.ValueOf(nf))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
fallthrough
|
fallthrough
|
||||||
default:
|
default:
|
||||||
// Not supported yet
|
// Not supported yet
|
||||||
|
|||||||
301
pkg/rqlite/cluster.go
Normal file
301
pkg/rqlite/cluster.go
Normal file
@ -0,0 +1,301 @@
|
|||||||
|
package rqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// establishLeadershipOrJoin handles post-startup cluster establishment
|
||||||
|
func (r *RQLiteManager) establishLeadershipOrJoin(ctx context.Context, rqliteDataDir string) error {
|
||||||
|
timeout := 5 * time.Minute
|
||||||
|
if r.config.RQLiteJoinAddress == "" {
|
||||||
|
timeout = 2 * time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := r.waitForSQLAvailable(sqlCtx); err != nil {
|
||||||
|
if r.cmd != nil && r.cmd.Process != nil {
|
||||||
|
_ = r.cmd.Process.Kill()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForMinClusterSizeBeforeStart waits for minimum cluster size to be discovered
|
||||||
|
func (r *RQLiteManager) waitForMinClusterSizeBeforeStart(ctx context.Context, rqliteDataDir string) error {
|
||||||
|
if r.discoveryService == nil {
|
||||||
|
return fmt.Errorf("discovery service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
requiredRemotePeers := r.config.MinClusterSize - 1
|
||||||
|
_ = r.discoveryService.TriggerPeerExchange(ctx)
|
||||||
|
|
||||||
|
checkInterval := 2 * time.Second
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
r.discoveryService.TriggerSync()
|
||||||
|
time.Sleep(checkInterval)
|
||||||
|
|
||||||
|
allPeers := r.discoveryService.GetAllPeers()
|
||||||
|
remotePeerCount := 0
|
||||||
|
for _, peer := range allPeers {
|
||||||
|
if peer.NodeID != r.discoverConfig.RaftAdvAddress {
|
||||||
|
remotePeerCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if remotePeerCount >= requiredRemotePeers {
|
||||||
|
peersPath := filepath.Join(rqliteDataDir, "raft", "peers.json")
|
||||||
|
r.discoveryService.TriggerSync()
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
if info, err := os.Stat(peersPath); err == nil && info.Size() > 10 {
|
||||||
|
data, err := os.ReadFile(peersPath)
|
||||||
|
if err == nil {
|
||||||
|
var peers []map[string]interface{}
|
||||||
|
if err := json.Unmarshal(data, &peers); err == nil && len(peers) >= requiredRemotePeers {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// performPreStartClusterDiscovery builds peers.json before starting RQLite
|
||||||
|
func (r *RQLiteManager) performPreStartClusterDiscovery(ctx context.Context, rqliteDataDir string) error {
|
||||||
|
if r.discoveryService == nil {
|
||||||
|
return fmt.Errorf("discovery service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = r.discoveryService.TriggerPeerExchange(ctx)
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
r.discoveryService.TriggerSync()
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
discoveryDeadline := time.Now().Add(30 * time.Second)
|
||||||
|
var discoveredPeers int
|
||||||
|
|
||||||
|
for time.Now().Before(discoveryDeadline) {
|
||||||
|
allPeers := r.discoveryService.GetAllPeers()
|
||||||
|
discoveredPeers = len(allPeers)
|
||||||
|
|
||||||
|
if discoveredPeers >= r.config.MinClusterSize {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
if discoveredPeers <= 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.hasExistingRaftState(rqliteDataDir) {
|
||||||
|
ourLogIndex := r.getRaftLogIndex()
|
||||||
|
maxPeerIndex := uint64(0)
|
||||||
|
for _, peer := range r.discoveryService.GetAllPeers() {
|
||||||
|
if peer.NodeID != r.discoverConfig.RaftAdvAddress && peer.RaftLogIndex > maxPeerIndex {
|
||||||
|
maxPeerIndex = peer.RaftLogIndex
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ourLogIndex == 0 && maxPeerIndex > 0 {
|
||||||
|
_ = r.clearRaftState(rqliteDataDir)
|
||||||
|
_ = r.discoveryService.ForceWritePeersJSON()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.discoveryService.TriggerSync()
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// recoverCluster restarts RQLite using peers.json
|
||||||
|
func (r *RQLiteManager) recoverCluster(ctx context.Context, peersJSONPath string) error {
|
||||||
|
_ = r.Stop()
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
rqliteDataDir, err := r.rqliteDataDirPath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.launchProcess(ctx, rqliteDataDir); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.waitForReadyAndConnect(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// recoverFromSplitBrain automatically recovers from split-brain state
|
||||||
|
func (r *RQLiteManager) recoverFromSplitBrain(ctx context.Context) error {
|
||||||
|
if r.discoveryService == nil {
|
||||||
|
return fmt.Errorf("discovery service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.discoveryService.TriggerPeerExchange(ctx)
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
r.discoveryService.TriggerSync()
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
rqliteDataDir, _ := r.rqliteDataDirPath()
|
||||||
|
ourIndex := r.getRaftLogIndex()
|
||||||
|
|
||||||
|
maxPeerIndex := uint64(0)
|
||||||
|
for _, peer := range r.discoveryService.GetAllPeers() {
|
||||||
|
if peer.NodeID != r.discoverConfig.RaftAdvAddress && peer.RaftLogIndex > maxPeerIndex {
|
||||||
|
maxPeerIndex = peer.RaftLogIndex
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ourIndex == 0 && maxPeerIndex > 0 {
|
||||||
|
_ = r.clearRaftState(rqliteDataDir)
|
||||||
|
r.discoveryService.TriggerPeerExchange(ctx)
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
_ = r.discoveryService.ForceWritePeersJSON()
|
||||||
|
return r.recoverCluster(ctx, filepath.Join(rqliteDataDir, "raft", "peers.json"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isInSplitBrainState detects if we're in a split-brain scenario
|
||||||
|
func (r *RQLiteManager) isInSplitBrainState() bool {
|
||||||
|
status, err := r.getRQLiteStatus()
|
||||||
|
if err != nil || r.discoveryService == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
raft := status.Store.Raft
|
||||||
|
if raft.State == "Follower" && raft.Term == 0 && raft.NumPeers == 0 && !raft.Voter {
|
||||||
|
peers := r.discoveryService.GetActivePeers()
|
||||||
|
if len(peers) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
reachableCount := 0
|
||||||
|
splitBrainCount := 0
|
||||||
|
for _, peer := range peers {
|
||||||
|
if r.isPeerReachable(peer.HTTPAddress) {
|
||||||
|
reachableCount++
|
||||||
|
peerStatus, err := r.getPeerRQLiteStatus(peer.HTTPAddress)
|
||||||
|
if err == nil {
|
||||||
|
praft := peerStatus.Store.Raft
|
||||||
|
if praft.State == "Follower" && praft.Term == 0 && praft.NumPeers == 0 && !praft.Voter {
|
||||||
|
splitBrainCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return reachableCount > 0 && splitBrainCount == reachableCount
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RQLiteManager) isPeerReachable(httpAddr string) bool {
|
||||||
|
client := &http.Client{Timeout: 3 * time.Second}
|
||||||
|
resp, err := client.Get(fmt.Sprintf("http://%s/status", httpAddr))
|
||||||
|
if err == nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
return resp.StatusCode == http.StatusOK
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RQLiteManager) getPeerRQLiteStatus(httpAddr string) (*RQLiteStatus, error) {
|
||||||
|
client := &http.Client{Timeout: 3 * time.Second}
|
||||||
|
resp, err := client.Get(fmt.Sprintf("http://%s/status", httpAddr))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
var status RQLiteStatus
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &status, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RQLiteManager) startHealthMonitoring(ctx context.Context) {
|
||||||
|
time.Sleep(30 * time.Second)
|
||||||
|
ticker := time.NewTicker(60 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if r.isInSplitBrainState() {
|
||||||
|
_ = r.recoverFromSplitBrain(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkNeedsClusterRecovery checks if the node has old cluster state that requires coordinated recovery
|
||||||
|
func (r *RQLiteManager) checkNeedsClusterRecovery(rqliteDataDir string) (bool, error) {
|
||||||
|
snapshotsDir := filepath.Join(rqliteDataDir, "rsnapshots")
|
||||||
|
if _, err := os.Stat(snapshotsDir); os.IsNotExist(err) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(snapshotsDir)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hasSnapshots := false
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() || strings.HasSuffix(entry.Name(), ".db") {
|
||||||
|
hasSnapshots = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasSnapshots {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
raftLogPath := filepath.Join(rqliteDataDir, "raft.db")
|
||||||
|
if info, err := os.Stat(raftLogPath); err == nil {
|
||||||
|
if info.Size() <= 8*1024*1024 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RQLiteManager) hasExistingRaftState(rqliteDataDir string) bool {
|
||||||
|
raftLogPath := filepath.Join(rqliteDataDir, "raft.db")
|
||||||
|
if info, err := os.Stat(raftLogPath); err == nil && info.Size() > 1024 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
peersPath := filepath.Join(rqliteDataDir, "raft", "peers.json")
|
||||||
|
_, err := os.Stat(peersPath)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RQLiteManager) clearRaftState(rqliteDataDir string) error {
|
||||||
|
_ = os.Remove(filepath.Join(rqliteDataDir, "raft.db"))
|
||||||
|
_ = os.Remove(filepath.Join(rqliteDataDir, "raft", "peers.json"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
@ -2,20 +2,12 @@ package rqlite
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/discovery"
|
"github.com/DeBrosOfficial/network/pkg/discovery"
|
||||||
"github.com/libp2p/go-libp2p/core/host"
|
"github.com/libp2p/go-libp2p/core/host"
|
||||||
"github.com/libp2p/go-libp2p/core/peer"
|
|
||||||
"github.com/multiformats/go-multiaddr"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -160,855 +152,3 @@ func (c *ClusterDiscoveryService) periodicCleanup(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// collectPeerMetadata collects RQLite metadata from LibP2P peers
|
|
||||||
func (c *ClusterDiscoveryService) collectPeerMetadata() []*discovery.RQLiteNodeMetadata {
|
|
||||||
connectedPeers := c.host.Network().Peers()
|
|
||||||
var metadata []*discovery.RQLiteNodeMetadata
|
|
||||||
|
|
||||||
// Metadata collection is routine - no need to log every occurrence
|
|
||||||
|
|
||||||
c.mu.RLock()
|
|
||||||
currentRaftAddr := c.raftAddress
|
|
||||||
currentHTTPAddr := c.httpAddress
|
|
||||||
c.mu.RUnlock()
|
|
||||||
|
|
||||||
// Add ourselves
|
|
||||||
ourMetadata := &discovery.RQLiteNodeMetadata{
|
|
||||||
NodeID: currentRaftAddr, // RQLite uses raft address as node ID
|
|
||||||
RaftAddress: currentRaftAddr,
|
|
||||||
HTTPAddress: currentHTTPAddr,
|
|
||||||
NodeType: c.nodeType,
|
|
||||||
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
|
|
||||||
LastSeen: time.Now(),
|
|
||||||
ClusterVersion: "1.0",
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.adjustSelfAdvertisedAddresses(ourMetadata) {
|
|
||||||
c.logger.Debug("Adjusted self-advertised RQLite addresses",
|
|
||||||
zap.String("raft_address", ourMetadata.RaftAddress),
|
|
||||||
zap.String("http_address", ourMetadata.HTTPAddress))
|
|
||||||
}
|
|
||||||
|
|
||||||
metadata = append(metadata, ourMetadata)
|
|
||||||
|
|
||||||
staleNodeIDs := make([]string, 0)
|
|
||||||
|
|
||||||
// Query connected peers for their RQLite metadata
|
|
||||||
// For now, we'll use a simple approach - store metadata in peer metadata store
|
|
||||||
// In a full implementation, this would use a custom protocol to exchange RQLite metadata
|
|
||||||
for _, peerID := range connectedPeers {
|
|
||||||
// Try to get stored metadata from peerstore
|
|
||||||
// This would be populated by a peer exchange protocol
|
|
||||||
if val, err := c.host.Peerstore().Get(peerID, "rqlite_metadata"); err == nil {
|
|
||||||
if jsonData, ok := val.([]byte); ok {
|
|
||||||
var peerMeta discovery.RQLiteNodeMetadata
|
|
||||||
if err := json.Unmarshal(jsonData, &peerMeta); err == nil {
|
|
||||||
if updated, stale := c.adjustPeerAdvertisedAddresses(peerID, &peerMeta); updated && stale != "" {
|
|
||||||
staleNodeIDs = append(staleNodeIDs, stale)
|
|
||||||
}
|
|
||||||
peerMeta.LastSeen = time.Now()
|
|
||||||
metadata = append(metadata, &peerMeta)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up stale entries if NodeID changed
|
|
||||||
if len(staleNodeIDs) > 0 {
|
|
||||||
c.mu.Lock()
|
|
||||||
for _, id := range staleNodeIDs {
|
|
||||||
delete(c.knownPeers, id)
|
|
||||||
delete(c.peerHealth, id)
|
|
||||||
}
|
|
||||||
c.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
return metadata
|
|
||||||
}
|
|
||||||
|
|
||||||
// membershipUpdateResult contains the result of a membership update operation
|
|
||||||
type membershipUpdateResult struct {
|
|
||||||
peersJSON []map[string]interface{}
|
|
||||||
added []string
|
|
||||||
updated []string
|
|
||||||
changed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateClusterMembership updates the cluster membership based on discovered peers
|
|
||||||
func (c *ClusterDiscoveryService) updateClusterMembership() {
|
|
||||||
metadata := c.collectPeerMetadata()
|
|
||||||
|
|
||||||
// Compute membership changes while holding lock
|
|
||||||
c.mu.Lock()
|
|
||||||
result := c.computeMembershipChangesLocked(metadata)
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
// Perform file I/O outside the lock
|
|
||||||
if result.changed {
|
|
||||||
// Log state changes (peer added/removed) at Info level
|
|
||||||
if len(result.added) > 0 || len(result.updated) > 0 {
|
|
||||||
c.logger.Info("Membership changed",
|
|
||||||
zap.Int("added", len(result.added)),
|
|
||||||
zap.Int("updated", len(result.updated)),
|
|
||||||
zap.Strings("added", result.added),
|
|
||||||
zap.Strings("updated", result.updated))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write peers.json without holding lock
|
|
||||||
if err := c.writePeersJSONWithData(result.peersJSON); err != nil {
|
|
||||||
c.logger.Error("Failed to write peers.json",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.String("data_dir", c.dataDir),
|
|
||||||
zap.Int("peers", len(result.peersJSON)))
|
|
||||||
} else {
|
|
||||||
c.logger.Debug("peers.json updated",
|
|
||||||
zap.Int("peers", len(result.peersJSON)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update lastUpdate timestamp
|
|
||||||
c.mu.Lock()
|
|
||||||
c.lastUpdate = time.Now()
|
|
||||||
c.mu.Unlock()
|
|
||||||
}
|
|
||||||
// No changes - don't log (reduces noise)
|
|
||||||
}
|
|
||||||
|
|
||||||
// computeMembershipChangesLocked computes membership changes and returns snapshot data
|
|
||||||
// Must be called with lock held
|
|
||||||
func (c *ClusterDiscoveryService) computeMembershipChangesLocked(metadata []*discovery.RQLiteNodeMetadata) membershipUpdateResult {
|
|
||||||
// Track changes
|
|
||||||
added := []string{}
|
|
||||||
updated := []string{}
|
|
||||||
|
|
||||||
// Update known peers, but skip self for health tracking
|
|
||||||
for _, meta := range metadata {
|
|
||||||
// Skip self-metadata for health tracking (we only track remote peers)
|
|
||||||
isSelf := meta.NodeID == c.raftAddress
|
|
||||||
|
|
||||||
if existing, ok := c.knownPeers[meta.NodeID]; ok {
|
|
||||||
// Update existing peer
|
|
||||||
if existing.RaftLogIndex != meta.RaftLogIndex ||
|
|
||||||
existing.HTTPAddress != meta.HTTPAddress ||
|
|
||||||
existing.RaftAddress != meta.RaftAddress {
|
|
||||||
updated = append(updated, meta.NodeID)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// New peer discovered
|
|
||||||
added = append(added, meta.NodeID)
|
|
||||||
c.logger.Info("Node added",
|
|
||||||
zap.String("node", meta.NodeID),
|
|
||||||
zap.String("raft", meta.RaftAddress),
|
|
||||||
zap.String("type", meta.NodeType),
|
|
||||||
zap.Uint64("log_index", meta.RaftLogIndex))
|
|
||||||
}
|
|
||||||
|
|
||||||
c.knownPeers[meta.NodeID] = meta
|
|
||||||
|
|
||||||
// Update health tracking only for remote peers
|
|
||||||
if !isSelf {
|
|
||||||
if _, ok := c.peerHealth[meta.NodeID]; !ok {
|
|
||||||
c.peerHealth[meta.NodeID] = &PeerHealth{
|
|
||||||
LastSeen: time.Now(),
|
|
||||||
LastSuccessful: time.Now(),
|
|
||||||
Status: "active",
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
c.peerHealth[meta.NodeID].LastSeen = time.Now()
|
|
||||||
c.peerHealth[meta.NodeID].Status = "active"
|
|
||||||
c.peerHealth[meta.NodeID].FailureCount = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CRITICAL FIX: Count remote peers (excluding self)
|
|
||||||
remotePeerCount := 0
|
|
||||||
for _, peer := range c.knownPeers {
|
|
||||||
if peer.NodeID != c.raftAddress {
|
|
||||||
remotePeerCount++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get peers JSON snapshot (for checking if it would be empty)
|
|
||||||
peers := c.getPeersJSONUnlocked()
|
|
||||||
|
|
||||||
// Determine if we should write peers.json
|
|
||||||
shouldWrite := len(added) > 0 || len(updated) > 0 || c.lastUpdate.IsZero()
|
|
||||||
|
|
||||||
// CRITICAL FIX: Don't write peers.json until we have minimum cluster size
|
|
||||||
// This prevents RQLite from starting as a single-node cluster
|
|
||||||
// For min_cluster_size=3, we need at least 2 remote peers (plus self = 3 total)
|
|
||||||
if shouldWrite {
|
|
||||||
// For initial sync, wait until we have at least (MinClusterSize - 1) remote peers
|
|
||||||
// This ensures peers.json contains enough peers for proper cluster formation
|
|
||||||
if c.lastUpdate.IsZero() {
|
|
||||||
requiredRemotePeers := c.minClusterSize - 1
|
|
||||||
|
|
||||||
if remotePeerCount < requiredRemotePeers {
|
|
||||||
c.logger.Info("Waiting for peers",
|
|
||||||
zap.Int("have", remotePeerCount),
|
|
||||||
zap.Int("need", requiredRemotePeers),
|
|
||||||
zap.Int("min_size", c.minClusterSize))
|
|
||||||
return membershipUpdateResult{
|
|
||||||
changed: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Additional safety check: don't write empty peers.json (would cause single-node cluster)
|
|
||||||
if len(peers) == 0 && c.lastUpdate.IsZero() {
|
|
||||||
c.logger.Info("No remote peers - waiting")
|
|
||||||
return membershipUpdateResult{
|
|
||||||
changed: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log initial sync if this is the first time
|
|
||||||
if c.lastUpdate.IsZero() {
|
|
||||||
c.logger.Info("Initial sync",
|
|
||||||
zap.Int("total", len(c.knownPeers)),
|
|
||||||
zap.Int("remote", remotePeerCount),
|
|
||||||
zap.Int("in_json", len(peers)))
|
|
||||||
}
|
|
||||||
|
|
||||||
return membershipUpdateResult{
|
|
||||||
peersJSON: peers,
|
|
||||||
added: added,
|
|
||||||
updated: updated,
|
|
||||||
changed: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return membershipUpdateResult{
|
|
||||||
changed: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeInactivePeers removes peers that haven't been seen for longer than the inactivity limit
|
|
||||||
func (c *ClusterDiscoveryService) removeInactivePeers() {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
removed := []string{}
|
|
||||||
|
|
||||||
for nodeID, health := range c.peerHealth {
|
|
||||||
inactiveDuration := now.Sub(health.LastSeen)
|
|
||||||
|
|
||||||
if inactiveDuration > c.inactivityLimit {
|
|
||||||
// Mark as inactive and remove
|
|
||||||
c.logger.Warn("Node removed",
|
|
||||||
zap.String("node", nodeID),
|
|
||||||
zap.String("reason", "inactive"),
|
|
||||||
zap.Duration("inactive_duration", inactiveDuration))
|
|
||||||
|
|
||||||
delete(c.knownPeers, nodeID)
|
|
||||||
delete(c.peerHealth, nodeID)
|
|
||||||
removed = append(removed, nodeID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Regenerate peers.json if any peers were removed
|
|
||||||
if len(removed) > 0 {
|
|
||||||
c.logger.Info("Removed inactive",
|
|
||||||
zap.Int("count", len(removed)),
|
|
||||||
zap.Strings("nodes", removed))
|
|
||||||
|
|
||||||
if err := c.writePeersJSON(); err != nil {
|
|
||||||
c.logger.Error("Failed to write peers.json after cleanup", zap.Error(err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getPeersJSON generates the peers.json structure from active peers (acquires lock)
|
|
||||||
func (c *ClusterDiscoveryService) getPeersJSON() []map[string]interface{} {
|
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
return c.getPeersJSONUnlocked()
|
|
||||||
}
|
|
||||||
|
|
||||||
// getPeersJSONUnlocked generates the peers.json structure (must be called with lock held)
|
|
||||||
func (c *ClusterDiscoveryService) getPeersJSONUnlocked() []map[string]interface{} {
|
|
||||||
peers := make([]map[string]interface{}, 0, len(c.knownPeers))
|
|
||||||
|
|
||||||
for _, peer := range c.knownPeers {
|
|
||||||
// CRITICAL FIX: Include ALL peers (including self) in peers.json
|
|
||||||
// When using expect configuration with recovery, RQLite needs the complete
|
|
||||||
// expected cluster configuration to properly form consensus.
|
|
||||||
// The peers.json file is used by RQLite's recovery mechanism to know
|
|
||||||
// what the full cluster membership should be, including the local node.
|
|
||||||
peerEntry := map[string]interface{}{
|
|
||||||
"id": peer.RaftAddress, // RQLite uses raft address as node ID
|
|
||||||
"address": peer.RaftAddress,
|
|
||||||
"non_voter": false,
|
|
||||||
}
|
|
||||||
peers = append(peers, peerEntry)
|
|
||||||
}
|
|
||||||
|
|
||||||
return peers
|
|
||||||
}
|
|
||||||
|
|
||||||
// writePeersJSON atomically writes the peers.json file (acquires lock)
|
|
||||||
func (c *ClusterDiscoveryService) writePeersJSON() error {
|
|
||||||
c.mu.RLock()
|
|
||||||
peers := c.getPeersJSONUnlocked()
|
|
||||||
c.mu.RUnlock()
|
|
||||||
|
|
||||||
return c.writePeersJSONWithData(peers)
|
|
||||||
}
|
|
||||||
|
|
||||||
// writePeersJSONWithData writes the peers.json file with provided data (no lock needed)
|
|
||||||
func (c *ClusterDiscoveryService) writePeersJSONWithData(peers []map[string]interface{}) error {
|
|
||||||
// Expand ~ in data directory path
|
|
||||||
dataDir := os.ExpandEnv(c.dataDir)
|
|
||||||
if strings.HasPrefix(dataDir, "~") {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to determine home directory: %w", err)
|
|
||||||
}
|
|
||||||
dataDir = filepath.Join(home, dataDir[1:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the RQLite raft directory
|
|
||||||
rqliteDir := filepath.Join(dataDir, "rqlite", "raft")
|
|
||||||
|
|
||||||
// Writing peers.json - routine operation, no need to log details
|
|
||||||
|
|
||||||
if err := os.MkdirAll(rqliteDir, 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to create raft directory %s: %w", rqliteDir, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
peersFile := filepath.Join(rqliteDir, "peers.json")
|
|
||||||
backupFile := filepath.Join(rqliteDir, "peers.json.backup")
|
|
||||||
|
|
||||||
// Backup existing peers.json if it exists
|
|
||||||
if _, err := os.Stat(peersFile); err == nil {
|
|
||||||
// Backup existing peers.json if it exists - routine operation
|
|
||||||
data, err := os.ReadFile(peersFile)
|
|
||||||
if err == nil {
|
|
||||||
_ = os.WriteFile(backupFile, data, 0644)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal to JSON
|
|
||||||
data, err := json.MarshalIndent(peers, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal peers.json: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshaled peers.json - routine operation
|
|
||||||
|
|
||||||
// Write atomically using temp file + rename
|
|
||||||
tempFile := peersFile + ".tmp"
|
|
||||||
if err := os.WriteFile(tempFile, data, 0644); err != nil {
|
|
||||||
return fmt.Errorf("failed to write temp peers.json %s: %w", tempFile, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.Rename(tempFile, peersFile); err != nil {
|
|
||||||
return fmt.Errorf("failed to rename %s to %s: %w", tempFile, peersFile, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
nodeIDs := make([]string, 0, len(peers))
|
|
||||||
for _, p := range peers {
|
|
||||||
if id, ok := p["id"].(string); ok {
|
|
||||||
nodeIDs = append(nodeIDs, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.logger.Info("peers.json written",
|
|
||||||
zap.Int("peers", len(peers)),
|
|
||||||
zap.Strings("nodes", nodeIDs))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetActivePeers returns a list of active peers (not including self)
|
|
||||||
func (c *ClusterDiscoveryService) GetActivePeers() []*discovery.RQLiteNodeMetadata {
|
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
peers := make([]*discovery.RQLiteNodeMetadata, 0, len(c.knownPeers))
|
|
||||||
for _, peer := range c.knownPeers {
|
|
||||||
// Skip self (compare by raft address since that's the NodeID now)
|
|
||||||
if peer.NodeID == c.raftAddress {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
peers = append(peers, peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
return peers
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllPeers returns a list of all known peers (including self)
|
|
||||||
func (c *ClusterDiscoveryService) GetAllPeers() []*discovery.RQLiteNodeMetadata {
|
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
peers := make([]*discovery.RQLiteNodeMetadata, 0, len(c.knownPeers))
|
|
||||||
for _, peer := range c.knownPeers {
|
|
||||||
peers = append(peers, peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
return peers
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetNodeWithHighestLogIndex returns the node with the highest Raft log index
|
|
||||||
func (c *ClusterDiscoveryService) GetNodeWithHighestLogIndex() *discovery.RQLiteNodeMetadata {
|
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
var highest *discovery.RQLiteNodeMetadata
|
|
||||||
var maxIndex uint64 = 0
|
|
||||||
|
|
||||||
for _, peer := range c.knownPeers {
|
|
||||||
// Skip self (compare by raft address since that's the NodeID now)
|
|
||||||
if peer.NodeID == c.raftAddress {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if peer.RaftLogIndex > maxIndex {
|
|
||||||
maxIndex = peer.RaftLogIndex
|
|
||||||
highest = peer
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return highest
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasRecentPeersJSON checks if peers.json was recently updated
|
|
||||||
func (c *ClusterDiscoveryService) HasRecentPeersJSON() bool {
|
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
// Consider recent if updated in last 5 minutes
|
|
||||||
return time.Since(c.lastUpdate) < 5*time.Minute
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindJoinTargets discovers join targets via LibP2P
|
|
||||||
func (c *ClusterDiscoveryService) FindJoinTargets() []string {
|
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
targets := []string{}
|
|
||||||
|
|
||||||
// All nodes are equal - prioritize by Raft log index (more advanced = better)
|
|
||||||
type nodeWithIndex struct {
|
|
||||||
address string
|
|
||||||
logIndex uint64
|
|
||||||
}
|
|
||||||
var nodes []nodeWithIndex
|
|
||||||
for _, peer := range c.knownPeers {
|
|
||||||
nodes = append(nodes, nodeWithIndex{peer.RaftAddress, peer.RaftLogIndex})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort by log index descending (higher log index = more up-to-date)
|
|
||||||
for i := 0; i < len(nodes)-1; i++ {
|
|
||||||
for j := i + 1; j < len(nodes); j++ {
|
|
||||||
if nodes[j].logIndex > nodes[i].logIndex {
|
|
||||||
nodes[i], nodes[j] = nodes[j], nodes[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, n := range nodes {
|
|
||||||
targets = append(targets, n.address)
|
|
||||||
}
|
|
||||||
|
|
||||||
return targets
|
|
||||||
}
|
|
||||||
|
|
||||||
// WaitForDiscoverySettling waits for LibP2P discovery to settle (used on concurrent startup)
|
|
||||||
func (c *ClusterDiscoveryService) WaitForDiscoverySettling(ctx context.Context) {
|
|
||||||
settleDuration := 60 * time.Second
|
|
||||||
c.logger.Info("Waiting for discovery to settle",
|
|
||||||
zap.Duration("duration", settleDuration))
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-time.After(settleDuration):
|
|
||||||
}
|
|
||||||
|
|
||||||
// Collect final peer list
|
|
||||||
c.updateClusterMembership()
|
|
||||||
|
|
||||||
c.mu.RLock()
|
|
||||||
peerCount := len(c.knownPeers)
|
|
||||||
c.mu.RUnlock()
|
|
||||||
|
|
||||||
c.logger.Info("Discovery settled",
|
|
||||||
zap.Int("peer_count", peerCount))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TriggerSync manually triggers a cluster membership sync
|
|
||||||
func (c *ClusterDiscoveryService) TriggerSync() {
|
|
||||||
// All nodes use the same discovery timing for consistency
|
|
||||||
c.updateClusterMembership()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForceWritePeersJSON forces writing peers.json regardless of membership changes
|
|
||||||
// This is useful after clearing raft state when we need to recreate peers.json
|
|
||||||
func (c *ClusterDiscoveryService) ForceWritePeersJSON() error {
|
|
||||||
c.logger.Info("Force writing peers.json")
|
|
||||||
|
|
||||||
// First, collect latest peer metadata to ensure we have current information
|
|
||||||
metadata := c.collectPeerMetadata()
|
|
||||||
|
|
||||||
// Update known peers with latest metadata (without writing file yet)
|
|
||||||
c.mu.Lock()
|
|
||||||
for _, meta := range metadata {
|
|
||||||
c.knownPeers[meta.NodeID] = meta
|
|
||||||
// Update health tracking for remote peers
|
|
||||||
if meta.NodeID != c.raftAddress {
|
|
||||||
if _, ok := c.peerHealth[meta.NodeID]; !ok {
|
|
||||||
c.peerHealth[meta.NodeID] = &PeerHealth{
|
|
||||||
LastSeen: time.Now(),
|
|
||||||
LastSuccessful: time.Now(),
|
|
||||||
Status: "active",
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
c.peerHealth[meta.NodeID].LastSeen = time.Now()
|
|
||||||
c.peerHealth[meta.NodeID].Status = "active"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
peers := c.getPeersJSONUnlocked()
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
// Now force write the file
|
|
||||||
if err := c.writePeersJSONWithData(peers); err != nil {
|
|
||||||
c.logger.Error("Failed to force write peers.json",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.String("data_dir", c.dataDir),
|
|
||||||
zap.Int("peers", len(peers)))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.logger.Info("peers.json written",
|
|
||||||
zap.Int("peers", len(peers)))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TriggerPeerExchange actively exchanges peer information with connected peers
|
|
||||||
// This populates the peerstore with RQLite metadata from other nodes
|
|
||||||
func (c *ClusterDiscoveryService) TriggerPeerExchange(ctx context.Context) error {
|
|
||||||
if c.discoveryMgr == nil {
|
|
||||||
return fmt.Errorf("discovery manager not available")
|
|
||||||
}
|
|
||||||
|
|
||||||
collected := c.discoveryMgr.TriggerPeerExchange(ctx)
|
|
||||||
c.logger.Debug("Exchange completed", zap.Int("with_metadata", collected))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateOwnMetadata updates our own RQLite metadata in the peerstore
|
|
||||||
func (c *ClusterDiscoveryService) UpdateOwnMetadata() {
|
|
||||||
c.mu.RLock()
|
|
||||||
currentRaftAddr := c.raftAddress
|
|
||||||
currentHTTPAddr := c.httpAddress
|
|
||||||
c.mu.RUnlock()
|
|
||||||
|
|
||||||
metadata := &discovery.RQLiteNodeMetadata{
|
|
||||||
NodeID: currentRaftAddr, // RQLite uses raft address as node ID
|
|
||||||
RaftAddress: currentRaftAddr,
|
|
||||||
HTTPAddress: currentHTTPAddr,
|
|
||||||
NodeType: c.nodeType,
|
|
||||||
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
|
|
||||||
LastSeen: time.Now(),
|
|
||||||
ClusterVersion: "1.0",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adjust addresses if needed
|
|
||||||
if c.adjustSelfAdvertisedAddresses(metadata) {
|
|
||||||
c.logger.Debug("Adjusted self-advertised RQLite addresses in UpdateOwnMetadata",
|
|
||||||
zap.String("raft_address", metadata.RaftAddress),
|
|
||||||
zap.String("http_address", metadata.HTTPAddress))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store in our own peerstore for peer exchange
|
|
||||||
data, err := json.Marshal(metadata)
|
|
||||||
if err != nil {
|
|
||||||
c.logger.Error("Failed to marshal own metadata", zap.Error(err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.host.Peerstore().Put(c.host.ID(), "rqlite_metadata", data); err != nil {
|
|
||||||
c.logger.Error("Failed to store own metadata", zap.Error(err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.logger.Debug("Metadata updated",
|
|
||||||
zap.String("node", metadata.NodeID),
|
|
||||||
zap.Uint64("log_index", metadata.RaftLogIndex))
|
|
||||||
}
|
|
||||||
|
|
||||||
// StoreRemotePeerMetadata stores metadata received from a remote peer
|
|
||||||
func (c *ClusterDiscoveryService) StoreRemotePeerMetadata(peerID peer.ID, metadata *discovery.RQLiteNodeMetadata) error {
|
|
||||||
if metadata == nil {
|
|
||||||
return fmt.Errorf("metadata is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adjust addresses if needed (replace localhost with actual IP)
|
|
||||||
if updated, stale := c.adjustPeerAdvertisedAddresses(peerID, metadata); updated && stale != "" {
|
|
||||||
// Clean up stale entry if NodeID changed
|
|
||||||
c.mu.Lock()
|
|
||||||
delete(c.knownPeers, stale)
|
|
||||||
delete(c.peerHealth, stale)
|
|
||||||
c.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
metadata.LastSeen = time.Now()
|
|
||||||
|
|
||||||
data, err := json.Marshal(metadata)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal metadata: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.host.Peerstore().Put(peerID, "rqlite_metadata", data); err != nil {
|
|
||||||
return fmt.Errorf("failed to store metadata: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.logger.Debug("Metadata stored",
|
|
||||||
zap.String("peer", shortPeerID(peerID)),
|
|
||||||
zap.String("node", metadata.NodeID))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// adjustPeerAdvertisedAddresses adjusts peer metadata addresses by replacing localhost/loopback
|
|
||||||
// with the actual IP address from LibP2P connection. Returns (updated, staleNodeID).
|
|
||||||
// staleNodeID is non-empty if NodeID changed (indicating old entry should be cleaned up).
|
|
||||||
func (c *ClusterDiscoveryService) adjustPeerAdvertisedAddresses(peerID peer.ID, meta *discovery.RQLiteNodeMetadata) (bool, string) {
|
|
||||||
ip := c.selectPeerIP(peerID)
|
|
||||||
if ip == "" {
|
|
||||||
return false, ""
|
|
||||||
}
|
|
||||||
|
|
||||||
changed, stale := rewriteAdvertisedAddresses(meta, ip, true)
|
|
||||||
if changed {
|
|
||||||
c.logger.Debug("Addresses normalized",
|
|
||||||
zap.String("peer", shortPeerID(peerID)),
|
|
||||||
zap.String("raft", meta.RaftAddress),
|
|
||||||
zap.String("http_address", meta.HTTPAddress))
|
|
||||||
}
|
|
||||||
return changed, stale
|
|
||||||
}
|
|
||||||
|
|
||||||
// adjustSelfAdvertisedAddresses adjusts our own metadata addresses by replacing localhost/loopback
|
|
||||||
// with the actual IP address from LibP2P host. Updates internal state if changed.
|
|
||||||
func (c *ClusterDiscoveryService) adjustSelfAdvertisedAddresses(meta *discovery.RQLiteNodeMetadata) bool {
|
|
||||||
ip := c.selectSelfIP()
|
|
||||||
if ip == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
changed, _ := rewriteAdvertisedAddresses(meta, ip, true)
|
|
||||||
if !changed {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update internal state with corrected addresses
|
|
||||||
c.mu.Lock()
|
|
||||||
c.raftAddress = meta.RaftAddress
|
|
||||||
c.httpAddress = meta.HTTPAddress
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
if c.rqliteManager != nil {
|
|
||||||
c.rqliteManager.UpdateAdvertisedAddresses(meta.RaftAddress, meta.HTTPAddress)
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectPeerIP selects the best IP address for a peer from LibP2P connections.
|
|
||||||
// Prefers public IPs, falls back to private IPs if no public IP is available.
|
|
||||||
func (c *ClusterDiscoveryService) selectPeerIP(peerID peer.ID) string {
|
|
||||||
var fallback string
|
|
||||||
|
|
||||||
// First, try to get IP from active connections
|
|
||||||
for _, conn := range c.host.Network().ConnsToPeer(peerID) {
|
|
||||||
if ip, public := ipFromMultiaddr(conn.RemoteMultiaddr()); ip != "" {
|
|
||||||
if shouldReplaceHost(ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if public {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
if fallback == "" {
|
|
||||||
fallback = ip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to peerstore addresses
|
|
||||||
for _, addr := range c.host.Peerstore().Addrs(peerID) {
|
|
||||||
if ip, public := ipFromMultiaddr(addr); ip != "" {
|
|
||||||
if shouldReplaceHost(ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if public {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
if fallback == "" {
|
|
||||||
fallback = ip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectSelfIP selects the best IP address for ourselves from LibP2P host addresses.
|
|
||||||
// Prefers public IPs, falls back to private IPs if no public IP is available.
|
|
||||||
func (c *ClusterDiscoveryService) selectSelfIP() string {
|
|
||||||
var fallback string
|
|
||||||
|
|
||||||
for _, addr := range c.host.Addrs() {
|
|
||||||
if ip, public := ipFromMultiaddr(addr); ip != "" {
|
|
||||||
if shouldReplaceHost(ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if public {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
if fallback == "" {
|
|
||||||
fallback = ip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
// rewriteAdvertisedAddresses rewrites RaftAddress and HTTPAddress in metadata,
|
|
||||||
// replacing localhost/loopback addresses with the provided IP.
|
|
||||||
// Returns (changed, staleNodeID). staleNodeID is non-empty if NodeID changed.
|
|
||||||
func rewriteAdvertisedAddresses(meta *discovery.RQLiteNodeMetadata, newHost string, allowNodeIDRewrite bool) (bool, string) {
|
|
||||||
if meta == nil || newHost == "" {
|
|
||||||
return false, ""
|
|
||||||
}
|
|
||||||
|
|
||||||
originalNodeID := meta.NodeID
|
|
||||||
changed := false
|
|
||||||
nodeIDChanged := false
|
|
||||||
|
|
||||||
// Replace host in RaftAddress if it's localhost/loopback
|
|
||||||
if newAddr, replaced := replaceAddressHost(meta.RaftAddress, newHost); replaced {
|
|
||||||
if meta.RaftAddress != newAddr {
|
|
||||||
meta.RaftAddress = newAddr
|
|
||||||
changed = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace host in HTTPAddress if it's localhost/loopback
|
|
||||||
if newAddr, replaced := replaceAddressHost(meta.HTTPAddress, newHost); replaced {
|
|
||||||
if meta.HTTPAddress != newAddr {
|
|
||||||
meta.HTTPAddress = newAddr
|
|
||||||
changed = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update NodeID to match RaftAddress if it changed
|
|
||||||
if allowNodeIDRewrite {
|
|
||||||
if meta.RaftAddress != "" && (meta.NodeID == "" || meta.NodeID == originalNodeID || shouldReplaceHost(hostFromAddress(meta.NodeID))) {
|
|
||||||
if meta.NodeID != meta.RaftAddress {
|
|
||||||
meta.NodeID = meta.RaftAddress
|
|
||||||
nodeIDChanged = meta.NodeID != originalNodeID
|
|
||||||
if nodeIDChanged {
|
|
||||||
changed = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if nodeIDChanged {
|
|
||||||
return changed, originalNodeID
|
|
||||||
}
|
|
||||||
return changed, ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// replaceAddressHost replaces the host part of an address if it's localhost/loopback.
|
|
||||||
// Returns (newAddress, replaced). replaced is true if host was replaced.
|
|
||||||
func replaceAddressHost(address, newHost string) (string, bool) {
|
|
||||||
if address == "" || newHost == "" {
|
|
||||||
return address, false
|
|
||||||
}
|
|
||||||
|
|
||||||
host, port, err := net.SplitHostPort(address)
|
|
||||||
if err != nil {
|
|
||||||
return address, false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !shouldReplaceHost(host) {
|
|
||||||
return address, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return net.JoinHostPort(newHost, port), true
|
|
||||||
}
|
|
||||||
|
|
||||||
// shouldReplaceHost returns true if the host should be replaced (localhost, loopback, etc.)
|
|
||||||
func shouldReplaceHost(host string) bool {
|
|
||||||
if host == "" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if strings.EqualFold(host, "localhost") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if it's a loopback or unspecified address
|
|
||||||
if addr, err := netip.ParseAddr(host); err == nil {
|
|
||||||
if addr.IsLoopback() || addr.IsUnspecified() {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// hostFromAddress extracts the host part from a host:port address
|
|
||||||
func hostFromAddress(address string) string {
|
|
||||||
host, _, err := net.SplitHostPort(address)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return host
|
|
||||||
}
|
|
||||||
|
|
||||||
// ipFromMultiaddr extracts an IP address from a multiaddr and returns (ip, isPublic)
|
|
||||||
func ipFromMultiaddr(addr multiaddr.Multiaddr) (string, bool) {
|
|
||||||
if addr == nil {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
if v4, err := addr.ValueForProtocol(multiaddr.P_IP4); err == nil {
|
|
||||||
return v4, isPublicIP(v4)
|
|
||||||
}
|
|
||||||
if v6, err := addr.ValueForProtocol(multiaddr.P_IP6); err == nil {
|
|
||||||
return v6, isPublicIP(v6)
|
|
||||||
}
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
// isPublicIP returns true if the IP is a public (non-private, non-loopback) address
|
|
||||||
func isPublicIP(ip string) bool {
|
|
||||||
addr, err := netip.ParseAddr(ip)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// Exclude loopback, unspecified, link-local, multicast, and private addresses
|
|
||||||
if addr.IsLoopback() || addr.IsUnspecified() || addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast() || addr.IsPrivate() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// shortPeerID returns a shortened version of a peer ID for logging
|
|
||||||
func shortPeerID(id peer.ID) string {
|
|
||||||
s := id.String()
|
|
||||||
if len(s) <= 8 {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
return s[:8] + "..."
|
|
||||||
}
|
|
||||||
|
|||||||
318
pkg/rqlite/cluster_discovery_membership.go
Normal file
318
pkg/rqlite/cluster_discovery_membership.go
Normal file
@ -0,0 +1,318 @@
|
|||||||
|
package rqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/discovery"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// collectPeerMetadata collects RQLite metadata from LibP2P peers
|
||||||
|
func (c *ClusterDiscoveryService) collectPeerMetadata() []*discovery.RQLiteNodeMetadata {
|
||||||
|
connectedPeers := c.host.Network().Peers()
|
||||||
|
var metadata []*discovery.RQLiteNodeMetadata
|
||||||
|
|
||||||
|
c.mu.RLock()
|
||||||
|
currentRaftAddr := c.raftAddress
|
||||||
|
currentHTTPAddr := c.httpAddress
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
// Add ourselves
|
||||||
|
ourMetadata := &discovery.RQLiteNodeMetadata{
|
||||||
|
NodeID: currentRaftAddr, // RQLite uses raft address as node ID
|
||||||
|
RaftAddress: currentRaftAddr,
|
||||||
|
HTTPAddress: currentHTTPAddr,
|
||||||
|
NodeType: c.nodeType,
|
||||||
|
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
ClusterVersion: "1.0",
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.adjustSelfAdvertisedAddresses(ourMetadata) {
|
||||||
|
c.logger.Debug("Adjusted self-advertised RQLite addresses",
|
||||||
|
zap.String("raft_address", ourMetadata.RaftAddress),
|
||||||
|
zap.String("http_address", ourMetadata.HTTPAddress))
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata = append(metadata, ourMetadata)
|
||||||
|
|
||||||
|
staleNodeIDs := make([]string, 0)
|
||||||
|
|
||||||
|
for _, peerID := range connectedPeers {
|
||||||
|
if val, err := c.host.Peerstore().Get(peerID, "rqlite_metadata"); err == nil {
|
||||||
|
if jsonData, ok := val.([]byte); ok {
|
||||||
|
var peerMeta discovery.RQLiteNodeMetadata
|
||||||
|
if err := json.Unmarshal(jsonData, &peerMeta); err == nil {
|
||||||
|
if updated, stale := c.adjustPeerAdvertisedAddresses(peerID, &peerMeta); updated && stale != "" {
|
||||||
|
staleNodeIDs = append(staleNodeIDs, stale)
|
||||||
|
}
|
||||||
|
peerMeta.LastSeen = time.Now()
|
||||||
|
metadata = append(metadata, &peerMeta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(staleNodeIDs) > 0 {
|
||||||
|
c.mu.Lock()
|
||||||
|
for _, id := range staleNodeIDs {
|
||||||
|
delete(c.knownPeers, id)
|
||||||
|
delete(c.peerHealth, id)
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
type membershipUpdateResult struct {
|
||||||
|
peersJSON []map[string]interface{}
|
||||||
|
added []string
|
||||||
|
updated []string
|
||||||
|
changed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClusterDiscoveryService) updateClusterMembership() {
|
||||||
|
metadata := c.collectPeerMetadata()
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
result := c.computeMembershipChangesLocked(metadata)
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if result.changed {
|
||||||
|
if len(result.added) > 0 || len(result.updated) > 0 {
|
||||||
|
c.logger.Info("Membership changed",
|
||||||
|
zap.Int("added", len(result.added)),
|
||||||
|
zap.Int("updated", len(result.updated)),
|
||||||
|
zap.Strings("added", result.added),
|
||||||
|
zap.Strings("updated", result.updated))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.writePeersJSONWithData(result.peersJSON); err != nil {
|
||||||
|
c.logger.Error("Failed to write peers.json",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("data_dir", c.dataDir),
|
||||||
|
zap.Int("peers", len(result.peersJSON)))
|
||||||
|
} else {
|
||||||
|
c.logger.Debug("peers.json updated",
|
||||||
|
zap.Int("peers", len(result.peersJSON)))
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
c.lastUpdate = time.Now()
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClusterDiscoveryService) computeMembershipChangesLocked(metadata []*discovery.RQLiteNodeMetadata) membershipUpdateResult {
|
||||||
|
added := []string{}
|
||||||
|
updated := []string{}
|
||||||
|
|
||||||
|
for _, meta := range metadata {
|
||||||
|
isSelf := meta.NodeID == c.raftAddress
|
||||||
|
|
||||||
|
if existing, ok := c.knownPeers[meta.NodeID]; ok {
|
||||||
|
if existing.RaftLogIndex != meta.RaftLogIndex ||
|
||||||
|
existing.HTTPAddress != meta.HTTPAddress ||
|
||||||
|
existing.RaftAddress != meta.RaftAddress {
|
||||||
|
updated = append(updated, meta.NodeID)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
added = append(added, meta.NodeID)
|
||||||
|
c.logger.Info("Node added",
|
||||||
|
zap.String("node", meta.NodeID),
|
||||||
|
zap.String("raft", meta.RaftAddress),
|
||||||
|
zap.String("type", meta.NodeType),
|
||||||
|
zap.Uint64("log_index", meta.RaftLogIndex))
|
||||||
|
}
|
||||||
|
|
||||||
|
c.knownPeers[meta.NodeID] = meta
|
||||||
|
|
||||||
|
if !isSelf {
|
||||||
|
if _, ok := c.peerHealth[meta.NodeID]; !ok {
|
||||||
|
c.peerHealth[meta.NodeID] = &PeerHealth{
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
LastSuccessful: time.Now(),
|
||||||
|
Status: "active",
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.peerHealth[meta.NodeID].LastSeen = time.Now()
|
||||||
|
c.peerHealth[meta.NodeID].Status = "active"
|
||||||
|
c.peerHealth[meta.NodeID].FailureCount = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
remotePeerCount := 0
|
||||||
|
for _, peer := range c.knownPeers {
|
||||||
|
if peer.NodeID != c.raftAddress {
|
||||||
|
remotePeerCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
peers := c.getPeersJSONUnlocked()
|
||||||
|
shouldWrite := len(added) > 0 || len(updated) > 0 || c.lastUpdate.IsZero()
|
||||||
|
|
||||||
|
if shouldWrite {
|
||||||
|
if c.lastUpdate.IsZero() {
|
||||||
|
requiredRemotePeers := c.minClusterSize - 1
|
||||||
|
|
||||||
|
if remotePeerCount < requiredRemotePeers {
|
||||||
|
c.logger.Info("Waiting for peers",
|
||||||
|
zap.Int("have", remotePeerCount),
|
||||||
|
zap.Int("need", requiredRemotePeers),
|
||||||
|
zap.Int("min_size", c.minClusterSize))
|
||||||
|
return membershipUpdateResult{
|
||||||
|
changed: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(peers) == 0 && c.lastUpdate.IsZero() {
|
||||||
|
c.logger.Info("No remote peers - waiting")
|
||||||
|
return membershipUpdateResult{
|
||||||
|
changed: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.lastUpdate.IsZero() {
|
||||||
|
c.logger.Info("Initial sync",
|
||||||
|
zap.Int("total", len(c.knownPeers)),
|
||||||
|
zap.Int("remote", remotePeerCount),
|
||||||
|
zap.Int("in_json", len(peers)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return membershipUpdateResult{
|
||||||
|
peersJSON: peers,
|
||||||
|
added: added,
|
||||||
|
updated: updated,
|
||||||
|
changed: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return membershipUpdateResult{
|
||||||
|
changed: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClusterDiscoveryService) removeInactivePeers() {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
removed := []string{}
|
||||||
|
|
||||||
|
for nodeID, health := range c.peerHealth {
|
||||||
|
inactiveDuration := now.Sub(health.LastSeen)
|
||||||
|
|
||||||
|
if inactiveDuration > c.inactivityLimit {
|
||||||
|
c.logger.Warn("Node removed",
|
||||||
|
zap.String("node", nodeID),
|
||||||
|
zap.String("reason", "inactive"),
|
||||||
|
zap.Duration("inactive_duration", inactiveDuration))
|
||||||
|
|
||||||
|
delete(c.knownPeers, nodeID)
|
||||||
|
delete(c.peerHealth, nodeID)
|
||||||
|
removed = append(removed, nodeID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(removed) > 0 {
|
||||||
|
c.logger.Info("Removed inactive",
|
||||||
|
zap.Int("count", len(removed)),
|
||||||
|
zap.Strings("nodes", removed))
|
||||||
|
|
||||||
|
if err := c.writePeersJSON(); err != nil {
|
||||||
|
c.logger.Error("Failed to write peers.json after cleanup", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClusterDiscoveryService) getPeersJSON() []map[string]interface{} {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
return c.getPeersJSONUnlocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClusterDiscoveryService) getPeersJSONUnlocked() []map[string]interface{} {
|
||||||
|
peers := make([]map[string]interface{}, 0, len(c.knownPeers))
|
||||||
|
|
||||||
|
for _, peer := range c.knownPeers {
|
||||||
|
peerEntry := map[string]interface{}{
|
||||||
|
"id": peer.RaftAddress,
|
||||||
|
"address": peer.RaftAddress,
|
||||||
|
"non_voter": false,
|
||||||
|
}
|
||||||
|
peers = append(peers, peerEntry)
|
||||||
|
}
|
||||||
|
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClusterDiscoveryService) writePeersJSON() error {
|
||||||
|
c.mu.RLock()
|
||||||
|
peers := c.getPeersJSONUnlocked()
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
return c.writePeersJSONWithData(peers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClusterDiscoveryService) writePeersJSONWithData(peers []map[string]interface{}) error {
|
||||||
|
dataDir := os.ExpandEnv(c.dataDir)
|
||||||
|
if strings.HasPrefix(dataDir, "~") {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to determine home directory: %w", err)
|
||||||
|
}
|
||||||
|
dataDir = filepath.Join(home, dataDir[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
rqliteDir := filepath.Join(dataDir, "rqlite", "raft")
|
||||||
|
|
||||||
|
if err := os.MkdirAll(rqliteDir, 0755); err != nil {
|
||||||
|
return fmt.Errorf("failed to create raft directory %s: %w", rqliteDir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peersFile := filepath.Join(rqliteDir, "peers.json")
|
||||||
|
backupFile := filepath.Join(rqliteDir, "peers.json.backup")
|
||||||
|
|
||||||
|
if _, err := os.Stat(peersFile); err == nil {
|
||||||
|
data, err := os.ReadFile(peersFile)
|
||||||
|
if err == nil {
|
||||||
|
_ = os.WriteFile(backupFile, data, 0644)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(peers, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal peers.json: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tempFile := peersFile + ".tmp"
|
||||||
|
if err := os.WriteFile(tempFile, data, 0644); err != nil {
|
||||||
|
return fmt.Errorf("failed to write temp peers.json %s: %w", tempFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Rename(tempFile, peersFile); err != nil {
|
||||||
|
return fmt.Errorf("failed to rename %s to %s: %w", tempFile, peersFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeIDs := make([]string, 0, len(peers))
|
||||||
|
for _, p := range peers {
|
||||||
|
if id, ok := p["id"].(string); ok {
|
||||||
|
nodeIDs = append(nodeIDs, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Info("peers.json written",
|
||||||
|
zap.Int("peers", len(peers)),
|
||||||
|
zap.Strings("nodes", nodeIDs))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
251
pkg/rqlite/cluster_discovery_queries.go
Normal file
251
pkg/rqlite/cluster_discovery_queries.go
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
package rqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/discovery"
|
||||||
|
"github.com/libp2p/go-libp2p/core/peer"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetActivePeers returns a list of active peers (not including self)
|
||||||
|
func (c *ClusterDiscoveryService) GetActivePeers() []*discovery.RQLiteNodeMetadata {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
peers := make([]*discovery.RQLiteNodeMetadata, 0, len(c.knownPeers))
|
||||||
|
for _, peer := range c.knownPeers {
|
||||||
|
if peer.NodeID == c.raftAddress {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
peers = append(peers, peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllPeers returns a list of all known peers (including self)
|
||||||
|
func (c *ClusterDiscoveryService) GetAllPeers() []*discovery.RQLiteNodeMetadata {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
peers := make([]*discovery.RQLiteNodeMetadata, 0, len(c.knownPeers))
|
||||||
|
for _, peer := range c.knownPeers {
|
||||||
|
peers = append(peers, peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNodeWithHighestLogIndex returns the node with the highest Raft log index
|
||||||
|
func (c *ClusterDiscoveryService) GetNodeWithHighestLogIndex() *discovery.RQLiteNodeMetadata {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
var highest *discovery.RQLiteNodeMetadata
|
||||||
|
var maxIndex uint64 = 0
|
||||||
|
|
||||||
|
for _, peer := range c.knownPeers {
|
||||||
|
if peer.NodeID == c.raftAddress {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.RaftLogIndex > maxIndex {
|
||||||
|
maxIndex = peer.RaftLogIndex
|
||||||
|
highest = peer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return highest
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasRecentPeersJSON checks if peers.json was recently updated
|
||||||
|
func (c *ClusterDiscoveryService) HasRecentPeersJSON() bool {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
return time.Since(c.lastUpdate) < 5*time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindJoinTargets discovers join targets via LibP2P
|
||||||
|
func (c *ClusterDiscoveryService) FindJoinTargets() []string {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
targets := []string{}
|
||||||
|
|
||||||
|
type nodeWithIndex struct {
|
||||||
|
address string
|
||||||
|
logIndex uint64
|
||||||
|
}
|
||||||
|
var nodes []nodeWithIndex
|
||||||
|
for _, peer := range c.knownPeers {
|
||||||
|
nodes = append(nodes, nodeWithIndex{peer.RaftAddress, peer.RaftLogIndex})
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(nodes)-1; i++ {
|
||||||
|
for j := i + 1; j < len(nodes); j++ {
|
||||||
|
if nodes[j].logIndex > nodes[i].logIndex {
|
||||||
|
nodes[i], nodes[j] = nodes[j], nodes[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, n := range nodes {
|
||||||
|
targets = append(targets, n.address)
|
||||||
|
}
|
||||||
|
|
||||||
|
return targets
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitForDiscoverySettling waits for LibP2P discovery to settle (used on concurrent startup)
|
||||||
|
func (c *ClusterDiscoveryService) WaitForDiscoverySettling(ctx context.Context) {
|
||||||
|
settleDuration := 60 * time.Second
|
||||||
|
c.logger.Info("Waiting for discovery to settle",
|
||||||
|
zap.Duration("duration", settleDuration))
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(settleDuration):
|
||||||
|
}
|
||||||
|
|
||||||
|
c.updateClusterMembership()
|
||||||
|
|
||||||
|
c.mu.RLock()
|
||||||
|
peerCount := len(c.knownPeers)
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
c.logger.Info("Discovery settled",
|
||||||
|
zap.Int("peer_count", peerCount))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TriggerSync manually triggers a cluster membership sync
|
||||||
|
func (c *ClusterDiscoveryService) TriggerSync() {
|
||||||
|
c.updateClusterMembership()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForceWritePeersJSON forces writing peers.json regardless of membership changes
|
||||||
|
func (c *ClusterDiscoveryService) ForceWritePeersJSON() error {
|
||||||
|
c.logger.Info("Force writing peers.json")
|
||||||
|
|
||||||
|
metadata := c.collectPeerMetadata()
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
for _, meta := range metadata {
|
||||||
|
c.knownPeers[meta.NodeID] = meta
|
||||||
|
if meta.NodeID != c.raftAddress {
|
||||||
|
if _, ok := c.peerHealth[meta.NodeID]; !ok {
|
||||||
|
c.peerHealth[meta.NodeID] = &PeerHealth{
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
LastSuccessful: time.Now(),
|
||||||
|
Status: "active",
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.peerHealth[meta.NodeID].LastSeen = time.Now()
|
||||||
|
c.peerHealth[meta.NodeID].Status = "active"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
peers := c.getPeersJSONUnlocked()
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if err := c.writePeersJSONWithData(peers); err != nil {
|
||||||
|
c.logger.Error("Failed to force write peers.json",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("data_dir", c.dataDir),
|
||||||
|
zap.Int("peers", len(peers)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Info("peers.json written",
|
||||||
|
zap.Int("peers", len(peers)))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TriggerPeerExchange actively exchanges peer information with connected peers
|
||||||
|
func (c *ClusterDiscoveryService) TriggerPeerExchange(ctx context.Context) error {
|
||||||
|
if c.discoveryMgr == nil {
|
||||||
|
return fmt.Errorf("discovery manager not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
collected := c.discoveryMgr.TriggerPeerExchange(ctx)
|
||||||
|
c.logger.Debug("Exchange completed", zap.Int("with_metadata", collected))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOwnMetadata updates our own RQLite metadata in the peerstore
|
||||||
|
func (c *ClusterDiscoveryService) UpdateOwnMetadata() {
|
||||||
|
c.mu.RLock()
|
||||||
|
currentRaftAddr := c.raftAddress
|
||||||
|
currentHTTPAddr := c.httpAddress
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
metadata := &discovery.RQLiteNodeMetadata{
|
||||||
|
NodeID: currentRaftAddr,
|
||||||
|
RaftAddress: currentRaftAddr,
|
||||||
|
HTTPAddress: currentHTTPAddr,
|
||||||
|
NodeType: c.nodeType,
|
||||||
|
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
ClusterVersion: "1.0",
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.adjustSelfAdvertisedAddresses(metadata) {
|
||||||
|
c.logger.Debug("Adjusted self-advertised RQLite addresses in UpdateOwnMetadata",
|
||||||
|
zap.String("raft_address", metadata.RaftAddress),
|
||||||
|
zap.String("http_address", metadata.HTTPAddress))
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(metadata)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("Failed to marshal own metadata", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.host.Peerstore().Put(c.host.ID(), "rqlite_metadata", data); err != nil {
|
||||||
|
c.logger.Error("Failed to store own metadata", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Debug("Metadata updated",
|
||||||
|
zap.String("node", metadata.NodeID),
|
||||||
|
zap.Uint64("log_index", metadata.RaftLogIndex))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreRemotePeerMetadata stores metadata received from a remote peer
|
||||||
|
func (c *ClusterDiscoveryService) StoreRemotePeerMetadata(peerID peer.ID, metadata *discovery.RQLiteNodeMetadata) error {
|
||||||
|
if metadata == nil {
|
||||||
|
return fmt.Errorf("metadata is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated, stale := c.adjustPeerAdvertisedAddresses(peerID, metadata); updated && stale != "" {
|
||||||
|
c.mu.Lock()
|
||||||
|
delete(c.knownPeers, stale)
|
||||||
|
delete(c.peerHealth, stale)
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata.LastSeen = time.Now()
|
||||||
|
|
||||||
|
data, err := json.Marshal(metadata)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal metadata: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.host.Peerstore().Put(peerID, "rqlite_metadata", data); err != nil {
|
||||||
|
return fmt.Errorf("failed to store metadata: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Debug("Metadata stored",
|
||||||
|
zap.String("peer", shortPeerID(peerID)),
|
||||||
|
zap.String("node", metadata.NodeID))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
97
pkg/rqlite/cluster_discovery_test.go
Normal file
97
pkg/rqlite/cluster_discovery_test.go
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
package rqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/discovery"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestShouldReplaceHost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
host string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"", true},
|
||||||
|
{"localhost", true},
|
||||||
|
{"127.0.0.1", true},
|
||||||
|
{"::1", true},
|
||||||
|
{"0.0.0.0", true},
|
||||||
|
{"1.1.1.1", false},
|
||||||
|
{"8.8.8.8", false},
|
||||||
|
{"example.com", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := shouldReplaceHost(tt.host); got != tt.expected {
|
||||||
|
t.Errorf("shouldReplaceHost(%s) = %v; want %v", tt.host, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsPublicIP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"127.0.0.1", false},
|
||||||
|
{"192.168.1.1", false},
|
||||||
|
{"10.0.0.1", false},
|
||||||
|
{"172.16.0.1", false},
|
||||||
|
{"1.1.1.1", true},
|
||||||
|
{"8.8.8.8", true},
|
||||||
|
{"2001:4860:4860::8888", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := isPublicIP(tt.ip); got != tt.expected {
|
||||||
|
t.Errorf("isPublicIP(%s) = %v; want %v", tt.ip, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplaceAddressHost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
address string
|
||||||
|
newHost string
|
||||||
|
expected string
|
||||||
|
replaced bool
|
||||||
|
}{
|
||||||
|
{"localhost:4001", "1.1.1.1", "1.1.1.1:4001", true},
|
||||||
|
{"127.0.0.1:4001", "1.1.1.1", "1.1.1.1:4001", true},
|
||||||
|
{"8.8.8.8:4001", "1.1.1.1", "8.8.8.8:4001", false}, // Don't replace public IP
|
||||||
|
{"invalid", "1.1.1.1", "invalid", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got, replaced := replaceAddressHost(tt.address, tt.newHost)
|
||||||
|
if got != tt.expected || replaced != tt.replaced {
|
||||||
|
t.Errorf("replaceAddressHost(%s, %s) = %s, %v; want %s, %v", tt.address, tt.newHost, got, replaced, tt.expected, tt.replaced)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteAdvertisedAddresses(t *testing.T) {
|
||||||
|
meta := &discovery.RQLiteNodeMetadata{
|
||||||
|
NodeID: "localhost:4001",
|
||||||
|
RaftAddress: "localhost:4001",
|
||||||
|
HTTPAddress: "localhost:4002",
|
||||||
|
}
|
||||||
|
|
||||||
|
changed, originalNodeID := rewriteAdvertisedAddresses(meta, "1.1.1.1", true)
|
||||||
|
|
||||||
|
if !changed {
|
||||||
|
t.Error("expected changed to be true")
|
||||||
|
}
|
||||||
|
if originalNodeID != "localhost:4001" {
|
||||||
|
t.Errorf("expected originalNodeID localhost:4001, got %s", originalNodeID)
|
||||||
|
}
|
||||||
|
if meta.RaftAddress != "1.1.1.1:4001" {
|
||||||
|
t.Errorf("expected RaftAddress 1.1.1.1:4001, got %s", meta.RaftAddress)
|
||||||
|
}
|
||||||
|
if meta.HTTPAddress != "1.1.1.1:4002" {
|
||||||
|
t.Errorf("expected HTTPAddress 1.1.1.1:4002, got %s", meta.HTTPAddress)
|
||||||
|
}
|
||||||
|
if meta.NodeID != "1.1.1.1:4001" {
|
||||||
|
t.Errorf("expected NodeID 1.1.1.1:4001, got %s", meta.NodeID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
233
pkg/rqlite/cluster_discovery_utils.go
Normal file
233
pkg/rqlite/cluster_discovery_utils.go
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
package rqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/discovery"
|
||||||
|
"github.com/libp2p/go-libp2p/core/peer"
|
||||||
|
"github.com/multiformats/go-multiaddr"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// adjustPeerAdvertisedAddresses adjusts peer metadata addresses
|
||||||
|
func (c *ClusterDiscoveryService) adjustPeerAdvertisedAddresses(peerID peer.ID, meta *discovery.RQLiteNodeMetadata) (bool, string) {
|
||||||
|
ip := c.selectPeerIP(peerID)
|
||||||
|
if ip == "" {
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
changed, stale := rewriteAdvertisedAddresses(meta, ip, true)
|
||||||
|
if changed {
|
||||||
|
c.logger.Debug("Addresses normalized",
|
||||||
|
zap.String("peer", shortPeerID(peerID)),
|
||||||
|
zap.String("raft", meta.RaftAddress),
|
||||||
|
zap.String("http_address", meta.HTTPAddress))
|
||||||
|
}
|
||||||
|
return changed, stale
|
||||||
|
}
|
||||||
|
|
||||||
|
// adjustSelfAdvertisedAddresses adjusts our own metadata addresses
|
||||||
|
func (c *ClusterDiscoveryService) adjustSelfAdvertisedAddresses(meta *discovery.RQLiteNodeMetadata) bool {
|
||||||
|
ip := c.selectSelfIP()
|
||||||
|
if ip == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
changed, _ := rewriteAdvertisedAddresses(meta, ip, true)
|
||||||
|
if !changed {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
c.raftAddress = meta.RaftAddress
|
||||||
|
c.httpAddress = meta.HTTPAddress
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.rqliteManager != nil {
|
||||||
|
c.rqliteManager.UpdateAdvertisedAddresses(meta.RaftAddress, meta.HTTPAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectPeerIP selects the best IP address for a peer
|
||||||
|
func (c *ClusterDiscoveryService) selectPeerIP(peerID peer.ID) string {
|
||||||
|
var fallback string
|
||||||
|
|
||||||
|
for _, conn := range c.host.Network().ConnsToPeer(peerID) {
|
||||||
|
if ip, public := ipFromMultiaddr(conn.RemoteMultiaddr()); ip != "" {
|
||||||
|
if shouldReplaceHost(ip) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if public {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
if fallback == "" {
|
||||||
|
fallback = ip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range c.host.Peerstore().Addrs(peerID) {
|
||||||
|
if ip, public := ipFromMultiaddr(addr); ip != "" {
|
||||||
|
if shouldReplaceHost(ip) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if public {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
if fallback == "" {
|
||||||
|
fallback = ip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectSelfIP selects the best IP address for ourselves
|
||||||
|
func (c *ClusterDiscoveryService) selectSelfIP() string {
|
||||||
|
var fallback string
|
||||||
|
|
||||||
|
for _, addr := range c.host.Addrs() {
|
||||||
|
if ip, public := ipFromMultiaddr(addr); ip != "" {
|
||||||
|
if shouldReplaceHost(ip) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if public {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
if fallback == "" {
|
||||||
|
fallback = ip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewriteAdvertisedAddresses rewrites RaftAddress and HTTPAddress in metadata
|
||||||
|
func rewriteAdvertisedAddresses(meta *discovery.RQLiteNodeMetadata, newHost string, allowNodeIDRewrite bool) (bool, string) {
|
||||||
|
if meta == nil || newHost == "" {
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
originalNodeID := meta.NodeID
|
||||||
|
changed := false
|
||||||
|
nodeIDChanged := false
|
||||||
|
|
||||||
|
if newAddr, replaced := replaceAddressHost(meta.RaftAddress, newHost); replaced {
|
||||||
|
if meta.RaftAddress != newAddr {
|
||||||
|
meta.RaftAddress = newAddr
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if newAddr, replaced := replaceAddressHost(meta.HTTPAddress, newHost); replaced {
|
||||||
|
if meta.HTTPAddress != newAddr {
|
||||||
|
meta.HTTPAddress = newAddr
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if allowNodeIDRewrite {
|
||||||
|
if meta.RaftAddress != "" && (meta.NodeID == "" || meta.NodeID == originalNodeID || shouldReplaceHost(hostFromAddress(meta.NodeID))) {
|
||||||
|
if meta.NodeID != meta.RaftAddress {
|
||||||
|
meta.NodeID = meta.RaftAddress
|
||||||
|
nodeIDChanged = meta.NodeID != originalNodeID
|
||||||
|
if nodeIDChanged {
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if nodeIDChanged {
|
||||||
|
return changed, originalNodeID
|
||||||
|
}
|
||||||
|
return changed, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// replaceAddressHost replaces the host part of an address
|
||||||
|
func replaceAddressHost(address, newHost string) (string, bool) {
|
||||||
|
if address == "" || newHost == "" {
|
||||||
|
return address, false
|
||||||
|
}
|
||||||
|
|
||||||
|
host, port, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return address, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !shouldReplaceHost(host) {
|
||||||
|
return address, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.JoinHostPort(newHost, port), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldReplaceHost returns true if the host should be replaced
|
||||||
|
func shouldReplaceHost(host string) bool {
|
||||||
|
if host == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.EqualFold(host, "localhost") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr, err := netip.ParseAddr(host); err == nil {
|
||||||
|
if addr.IsLoopback() || addr.IsUnspecified() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// hostFromAddress extracts the host part from a host:port address
|
||||||
|
func hostFromAddress(address string) string {
|
||||||
|
host, _, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipFromMultiaddr extracts an IP address from a multiaddr and returns (ip, isPublic)
|
||||||
|
func ipFromMultiaddr(addr multiaddr.Multiaddr) (string, bool) {
|
||||||
|
if addr == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
if v4, err := addr.ValueForProtocol(multiaddr.P_IP4); err == nil {
|
||||||
|
return v4, isPublicIP(v4)
|
||||||
|
}
|
||||||
|
if v6, err := addr.ValueForProtocol(multiaddr.P_IP6); err == nil {
|
||||||
|
return v6, isPublicIP(v6)
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPublicIP returns true if the IP is a public address
|
||||||
|
func isPublicIP(ip string) bool {
|
||||||
|
addr, err := netip.ParseAddr(ip)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if addr.IsLoopback() || addr.IsUnspecified() || addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast() || addr.IsPrivate() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// shortPeerID returns a shortened version of a peer ID
|
||||||
|
func shortPeerID(id peer.ID) string {
|
||||||
|
s := id.String()
|
||||||
|
if len(s) <= 8 {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:8] + "..."
|
||||||
|
}
|
||||||
|
|
||||||
61
pkg/rqlite/discovery_manager.go
Normal file
61
pkg/rqlite/discovery_manager.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package rqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetDiscoveryService sets the cluster discovery service
|
||||||
|
func (r *RQLiteManager) SetDiscoveryService(service *ClusterDiscoveryService) {
|
||||||
|
r.discoveryService = service
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNodeType sets the node type
|
||||||
|
func (r *RQLiteManager) SetNodeType(nodeType string) {
|
||||||
|
if nodeType != "" {
|
||||||
|
r.nodeType = nodeType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAdvertisedAddresses overrides advertised addresses
|
||||||
|
func (r *RQLiteManager) UpdateAdvertisedAddresses(raftAddr, httpAddr string) {
|
||||||
|
if r == nil || r.discoverConfig == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if raftAddr != "" && r.discoverConfig.RaftAdvAddress != raftAddr {
|
||||||
|
r.discoverConfig.RaftAdvAddress = raftAddr
|
||||||
|
}
|
||||||
|
if httpAddr != "" && r.discoverConfig.HttpAdvAddress != httpAddr {
|
||||||
|
r.discoverConfig.HttpAdvAddress = httpAddr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RQLiteManager) validateNodeID() error {
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
nodes, err := r.getRQLiteNodes()
|
||||||
|
if err != nil {
|
||||||
|
if i < 4 {
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedID := r.discoverConfig.RaftAdvAddress
|
||||||
|
if expectedID == "" || len(nodes) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, node := range nodes {
|
||||||
|
if node.Address == expectedID {
|
||||||
|
if node.ID != expectedID {
|
||||||
|
return fmt.Errorf("node ID mismatch: %s != %s", expectedID, node.ID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
@ -570,9 +570,13 @@ func (g *HTTPGateway) handleDropTable(w http.ResponseWriter, r *http.Request) {
|
|||||||
ctx, cancel := g.withTimeout(r.Context())
|
ctx, cancel := g.withTimeout(r.Context())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
stmt := "DROP TABLE IF EXISTS " + tbl
|
stmt := "DROP TABLE " + tbl
|
||||||
if _, err := g.Client.Exec(ctx, stmt); err != nil {
|
if _, err := g.Client.Exec(ctx, stmt); err != nil {
|
||||||
writeError(w, http.StatusInternalServerError, err.Error())
|
if strings.Contains(err.Error(), "no such table") {
|
||||||
|
writeError(w, http.StatusNotFound, err.Error())
|
||||||
|
} else {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
|
writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
|
||||||
|
|||||||
239
pkg/rqlite/process.go
Normal file
239
pkg/rqlite/process.go
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
package rqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/tlsutil"
|
||||||
|
"github.com/rqlite/gorqlite"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// launchProcess starts the RQLite process with appropriate arguments
|
||||||
|
func (r *RQLiteManager) launchProcess(ctx context.Context, rqliteDataDir string) error {
|
||||||
|
// Build RQLite command
|
||||||
|
args := []string{
|
||||||
|
"-http-addr", fmt.Sprintf("0.0.0.0:%d", r.config.RQLitePort),
|
||||||
|
"-http-adv-addr", r.discoverConfig.HttpAdvAddress,
|
||||||
|
"-raft-adv-addr", r.discoverConfig.RaftAdvAddress,
|
||||||
|
"-raft-addr", fmt.Sprintf("0.0.0.0:%d", r.config.RQLiteRaftPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.config.NodeCert != "" && r.config.NodeKey != "" {
|
||||||
|
r.logger.Info("Enabling node-to-node TLS encryption",
|
||||||
|
zap.String("node_cert", r.config.NodeCert),
|
||||||
|
zap.String("node_key", r.config.NodeKey))
|
||||||
|
|
||||||
|
args = append(args, "-node-cert", r.config.NodeCert)
|
||||||
|
args = append(args, "-node-key", r.config.NodeKey)
|
||||||
|
|
||||||
|
if r.config.NodeCACert != "" {
|
||||||
|
args = append(args, "-node-ca-cert", r.config.NodeCACert)
|
||||||
|
}
|
||||||
|
if r.config.NodeNoVerify {
|
||||||
|
args = append(args, "-node-no-verify")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.config.RQLiteJoinAddress != "" {
|
||||||
|
r.logger.Info("Joining RQLite cluster", zap.String("join_address", r.config.RQLiteJoinAddress))
|
||||||
|
|
||||||
|
joinArg := r.config.RQLiteJoinAddress
|
||||||
|
if strings.HasPrefix(joinArg, "http://") {
|
||||||
|
joinArg = strings.TrimPrefix(joinArg, "http://")
|
||||||
|
} else if strings.HasPrefix(joinArg, "https://") {
|
||||||
|
joinArg = strings.TrimPrefix(joinArg, "https://")
|
||||||
|
}
|
||||||
|
|
||||||
|
joinTimeout := 5 * time.Minute
|
||||||
|
if err := r.waitForJoinTarget(ctx, r.config.RQLiteJoinAddress, joinTimeout); err != nil {
|
||||||
|
r.logger.Warn("Join target did not become reachable within timeout; will still attempt to join",
|
||||||
|
zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
args = append(args, "-join", joinArg, "-join-as", r.discoverConfig.RaftAdvAddress, "-join-attempts", "30", "-join-interval", "10s")
|
||||||
|
}
|
||||||
|
|
||||||
|
args = append(args, rqliteDataDir)
|
||||||
|
|
||||||
|
r.cmd = exec.Command("rqlited", args...)
|
||||||
|
|
||||||
|
nodeType := r.nodeType
|
||||||
|
if nodeType == "" {
|
||||||
|
nodeType = "node"
|
||||||
|
}
|
||||||
|
|
||||||
|
logsDir := filepath.Join(filepath.Dir(r.dataDir), "logs")
|
||||||
|
_ = os.MkdirAll(logsDir, 0755)
|
||||||
|
|
||||||
|
logPath := filepath.Join(logsDir, fmt.Sprintf("rqlite-%s.log", nodeType))
|
||||||
|
logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open log file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.cmd.Stdout = logFile
|
||||||
|
r.cmd.Stderr = logFile
|
||||||
|
|
||||||
|
if err := r.cmd.Start(); err != nil {
|
||||||
|
logFile.Close()
|
||||||
|
return fmt.Errorf("failed to start RQLite: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logFile.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForReadyAndConnect waits for RQLite to be ready and establishes connection
|
||||||
|
func (r *RQLiteManager) waitForReadyAndConnect(ctx context.Context) error {
|
||||||
|
if err := r.waitForReady(ctx); err != nil {
|
||||||
|
if r.cmd != nil && r.cmd.Process != nil {
|
||||||
|
_ = r.cmd.Process.Kill()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var conn *gorqlite.Connection
|
||||||
|
var err error
|
||||||
|
maxConnectAttempts := 10
|
||||||
|
connectBackoff := 500 * time.Millisecond
|
||||||
|
|
||||||
|
for attempt := 0; attempt < maxConnectAttempts; attempt++ {
|
||||||
|
conn, err = gorqlite.Open(fmt.Sprintf("http://localhost:%d", r.config.RQLitePort))
|
||||||
|
if err == nil {
|
||||||
|
r.connection = conn
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(err.Error(), "store is not open") {
|
||||||
|
time.Sleep(connectBackoff)
|
||||||
|
connectBackoff = time.Duration(float64(connectBackoff) * 1.5)
|
||||||
|
if connectBackoff > 5*time.Second {
|
||||||
|
connectBackoff = 5 * time.Second
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.cmd != nil && r.cmd.Process != nil {
|
||||||
|
_ = r.cmd.Process.Kill()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to connect to RQLite: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
return fmt.Errorf("failed to connect to RQLite after max attempts")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = r.validateNodeID()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForReady waits for RQLite to be ready to accept connections
|
||||||
|
func (r *RQLiteManager) waitForReady(ctx context.Context) error {
|
||||||
|
url := fmt.Sprintf("http://localhost:%d/status", r.config.RQLitePort)
|
||||||
|
client := tlsutil.NewHTTPClient(2 * time.Second)
|
||||||
|
|
||||||
|
for i := 0; i < 180; i++ {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Get(url)
|
||||||
|
if err == nil && resp.StatusCode == http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
var statusResp map[string]interface{}
|
||||||
|
if err := json.Unmarshal(body, &statusResp); err == nil {
|
||||||
|
if raft, ok := statusResp["raft"].(map[string]interface{}); ok {
|
||||||
|
state, _ := raft["state"].(string)
|
||||||
|
if state == "leader" || state == "follower" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil // Backwards compatibility
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("RQLite did not become ready within timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForSQLAvailable waits until a simple query succeeds
|
||||||
|
func (r *RQLiteManager) waitForSQLAvailable(ctx context.Context) error {
|
||||||
|
ticker := time.NewTicker(1 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
if r.connection == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, err := r.connection.QueryOne("SELECT 1")
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testJoinAddress tests if a join address is reachable
|
||||||
|
func (r *RQLiteManager) testJoinAddress(joinAddress string) error {
|
||||||
|
client := tlsutil.NewHTTPClient(5 * time.Second)
|
||||||
|
var statusURL string
|
||||||
|
if strings.HasPrefix(joinAddress, "http://") || strings.HasPrefix(joinAddress, "https://") {
|
||||||
|
statusURL = strings.TrimRight(joinAddress, "/") + "/status"
|
||||||
|
} else {
|
||||||
|
host := joinAddress
|
||||||
|
if idx := strings.Index(joinAddress, ":"); idx != -1 {
|
||||||
|
host = joinAddress[:idx]
|
||||||
|
}
|
||||||
|
statusURL = fmt.Sprintf("http://%s:%d/status", host, 5001)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Get(statusURL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("leader returned status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForJoinTarget waits until the join target's HTTP status becomes reachable
|
||||||
|
func (r *RQLiteManager) waitForJoinTarget(ctx context.Context, joinAddress string, timeout time.Duration) error {
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if err := r.testJoinAddress(joinAddress); err == nil {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
1241
pkg/rqlite/rqlite.go
1241
pkg/rqlite/rqlite.go
File diff suppressed because it is too large
Load Diff
58
pkg/rqlite/util.go
Normal file
58
pkg/rqlite/util.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package rqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r *RQLiteManager) rqliteDataDirPath() (string, error) {
|
||||||
|
dataDir := os.ExpandEnv(r.dataDir)
|
||||||
|
if strings.HasPrefix(dataDir, "~") {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
dataDir = filepath.Join(home, dataDir[1:])
|
||||||
|
}
|
||||||
|
return filepath.Join(dataDir, "rqlite"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RQLiteManager) resolveMigrationsDir() (string, error) {
|
||||||
|
productionPath := "/home/debros/src/migrations"
|
||||||
|
if _, err := os.Stat(productionPath); err == nil {
|
||||||
|
return productionPath, nil
|
||||||
|
}
|
||||||
|
return "migrations", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RQLiteManager) prepareDataDir() (string, error) {
|
||||||
|
rqliteDataDir, err := r.rqliteDataDirPath()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(rqliteDataDir, 0755); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return rqliteDataDir, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RQLiteManager) hasExistingState(rqliteDataDir string) bool {
|
||||||
|
entries, err := os.ReadDir(rqliteDataDir)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.Name() != "." && e.Name() != ".." {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RQLiteManager) exponentialBackoff(attempt int, baseDelay time.Duration, maxDelay time.Duration) time.Duration {
|
||||||
|
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||||
|
if delay > maxDelay {
|
||||||
|
delay = maxDelay
|
||||||
|
}
|
||||||
|
return delay
|
||||||
|
}
|
||||||
|
|
||||||
89
pkg/rqlite/util_test.go
Normal file
89
pkg/rqlite/util_test.go
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
package rqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExponentialBackoff(t *testing.T) {
|
||||||
|
r := &RQLiteManager{}
|
||||||
|
baseDelay := 100 * time.Millisecond
|
||||||
|
maxDelay := 1 * time.Second
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
attempt int
|
||||||
|
expected time.Duration
|
||||||
|
}{
|
||||||
|
{0, 100 * time.Millisecond},
|
||||||
|
{1, 200 * time.Millisecond},
|
||||||
|
{2, 400 * time.Millisecond},
|
||||||
|
{3, 800 * time.Millisecond},
|
||||||
|
{4, 1000 * time.Millisecond}, // Maxed out
|
||||||
|
{10, 1000 * time.Millisecond}, // Maxed out
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := r.exponentialBackoff(tt.attempt, baseDelay, maxDelay)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("exponentialBackoff(%d) = %v; want %v", tt.attempt, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRQLiteDataDirPath(t *testing.T) {
|
||||||
|
// Test with explicit path
|
||||||
|
r := &RQLiteManager{dataDir: "/tmp/data"}
|
||||||
|
got, _ := r.rqliteDataDirPath()
|
||||||
|
expected := filepath.Join("/tmp/data", "rqlite")
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("rqliteDataDirPath() = %s; want %s", got, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with environment variable expansion
|
||||||
|
os.Setenv("TEST_DATA_DIR", "/tmp/env-data")
|
||||||
|
defer os.Unsetenv("TEST_DATA_DIR")
|
||||||
|
r = &RQLiteManager{dataDir: "$TEST_DATA_DIR"}
|
||||||
|
got, _ = r.rqliteDataDirPath()
|
||||||
|
expected = filepath.Join("/tmp/env-data", "rqlite")
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("rqliteDataDirPath() with env = %s; want %s", got, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with home directory expansion
|
||||||
|
r = &RQLiteManager{dataDir: "~/data"}
|
||||||
|
got, _ = r.rqliteDataDirPath()
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
expected = filepath.Join(home, "data", "rqlite")
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("rqliteDataDirPath() with ~ = %s; want %s", got, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasExistingState(t *testing.T) {
|
||||||
|
r := &RQLiteManager{}
|
||||||
|
|
||||||
|
// Create a temp directory for testing
|
||||||
|
tmpDir, err := os.MkdirTemp("", "rqlite-test-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
// Test empty directory
|
||||||
|
if r.hasExistingState(tmpDir) {
|
||||||
|
t.Errorf("hasExistingState() = true; want false for empty dir")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test directory with a file
|
||||||
|
testFile := filepath.Join(tmpDir, "test.txt")
|
||||||
|
if err := os.WriteFile(testFile, []byte("data"), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to create test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !r.hasExistingState(tmpDir) {
|
||||||
|
t.Errorf("hasExistingState() = false; want true for non-empty dir")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
187
pkg/serverless/config.go
Normal file
187
pkg/serverless/config.go
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
package serverless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds configuration for the serverless engine.
|
||||||
|
type Config struct {
|
||||||
|
// Memory limits
|
||||||
|
DefaultMemoryLimitMB int `yaml:"default_memory_limit_mb"`
|
||||||
|
MaxMemoryLimitMB int `yaml:"max_memory_limit_mb"`
|
||||||
|
|
||||||
|
// Execution limits
|
||||||
|
DefaultTimeoutSeconds int `yaml:"default_timeout_seconds"`
|
||||||
|
MaxTimeoutSeconds int `yaml:"max_timeout_seconds"`
|
||||||
|
|
||||||
|
// Retry configuration
|
||||||
|
DefaultRetryCount int `yaml:"default_retry_count"`
|
||||||
|
MaxRetryCount int `yaml:"max_retry_count"`
|
||||||
|
DefaultRetryDelaySeconds int `yaml:"default_retry_delay_seconds"`
|
||||||
|
|
||||||
|
// Rate limiting (global)
|
||||||
|
GlobalRateLimitPerMinute int `yaml:"global_rate_limit_per_minute"`
|
||||||
|
|
||||||
|
// Background job configuration
|
||||||
|
JobWorkers int `yaml:"job_workers"`
|
||||||
|
JobPollInterval time.Duration `yaml:"job_poll_interval"`
|
||||||
|
JobMaxQueueSize int `yaml:"job_max_queue_size"`
|
||||||
|
JobMaxPayloadSize int `yaml:"job_max_payload_size"` // bytes
|
||||||
|
|
||||||
|
// Scheduler configuration
|
||||||
|
CronPollInterval time.Duration `yaml:"cron_poll_interval"`
|
||||||
|
TimerPollInterval time.Duration `yaml:"timer_poll_interval"`
|
||||||
|
DBPollInterval time.Duration `yaml:"db_poll_interval"`
|
||||||
|
|
||||||
|
// WASM compilation cache
|
||||||
|
ModuleCacheSize int `yaml:"module_cache_size"` // Number of compiled modules to cache
|
||||||
|
EnablePrewarm bool `yaml:"enable_prewarm"` // Pre-compile frequently used functions
|
||||||
|
|
||||||
|
// Secrets encryption
|
||||||
|
SecretsEncryptionKey string `yaml:"secrets_encryption_key"` // AES-256 key (32 bytes, hex-encoded)
|
||||||
|
|
||||||
|
// Logging
|
||||||
|
LogInvocations bool `yaml:"log_invocations"` // Log all invocations to database
|
||||||
|
LogRetention int `yaml:"log_retention"` // Days to retain logs
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultConfig returns a configuration with sensible defaults.
|
||||||
|
func DefaultConfig() *Config {
|
||||||
|
return &Config{
|
||||||
|
// Memory limits
|
||||||
|
DefaultMemoryLimitMB: 64,
|
||||||
|
MaxMemoryLimitMB: 256,
|
||||||
|
|
||||||
|
// Execution limits
|
||||||
|
DefaultTimeoutSeconds: 30,
|
||||||
|
MaxTimeoutSeconds: 300, // 5 minutes max
|
||||||
|
|
||||||
|
// Retry configuration
|
||||||
|
DefaultRetryCount: 0,
|
||||||
|
MaxRetryCount: 5,
|
||||||
|
DefaultRetryDelaySeconds: 5,
|
||||||
|
|
||||||
|
// Rate limiting
|
||||||
|
GlobalRateLimitPerMinute: 10000, // 10k requests/minute globally
|
||||||
|
|
||||||
|
// Background jobs
|
||||||
|
JobWorkers: 4,
|
||||||
|
JobPollInterval: time.Second,
|
||||||
|
JobMaxQueueSize: 10000,
|
||||||
|
JobMaxPayloadSize: 1024 * 1024, // 1MB
|
||||||
|
|
||||||
|
// Scheduler
|
||||||
|
CronPollInterval: time.Minute,
|
||||||
|
TimerPollInterval: time.Second,
|
||||||
|
DBPollInterval: time.Second * 5,
|
||||||
|
|
||||||
|
// WASM cache
|
||||||
|
ModuleCacheSize: 100,
|
||||||
|
EnablePrewarm: true,
|
||||||
|
|
||||||
|
// Logging
|
||||||
|
LogInvocations: true,
|
||||||
|
LogRetention: 7, // 7 days
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate checks the configuration for errors.
|
||||||
|
func (c *Config) Validate() []error {
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
if c.DefaultMemoryLimitMB <= 0 {
|
||||||
|
errs = append(errs, &ConfigError{Field: "DefaultMemoryLimitMB", Message: "must be positive"})
|
||||||
|
}
|
||||||
|
if c.MaxMemoryLimitMB < c.DefaultMemoryLimitMB {
|
||||||
|
errs = append(errs, &ConfigError{Field: "MaxMemoryLimitMB", Message: "must be >= DefaultMemoryLimitMB"})
|
||||||
|
}
|
||||||
|
if c.DefaultTimeoutSeconds <= 0 {
|
||||||
|
errs = append(errs, &ConfigError{Field: "DefaultTimeoutSeconds", Message: "must be positive"})
|
||||||
|
}
|
||||||
|
if c.MaxTimeoutSeconds < c.DefaultTimeoutSeconds {
|
||||||
|
errs = append(errs, &ConfigError{Field: "MaxTimeoutSeconds", Message: "must be >= DefaultTimeoutSeconds"})
|
||||||
|
}
|
||||||
|
if c.GlobalRateLimitPerMinute <= 0 {
|
||||||
|
errs = append(errs, &ConfigError{Field: "GlobalRateLimitPerMinute", Message: "must be positive"})
|
||||||
|
}
|
||||||
|
if c.JobWorkers <= 0 {
|
||||||
|
errs = append(errs, &ConfigError{Field: "JobWorkers", Message: "must be positive"})
|
||||||
|
}
|
||||||
|
if c.ModuleCacheSize <= 0 {
|
||||||
|
errs = append(errs, &ConfigError{Field: "ModuleCacheSize", Message: "must be positive"})
|
||||||
|
}
|
||||||
|
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyDefaults fills in zero values with defaults.
|
||||||
|
func (c *Config) ApplyDefaults() {
|
||||||
|
defaults := DefaultConfig()
|
||||||
|
|
||||||
|
if c.DefaultMemoryLimitMB == 0 {
|
||||||
|
c.DefaultMemoryLimitMB = defaults.DefaultMemoryLimitMB
|
||||||
|
}
|
||||||
|
if c.MaxMemoryLimitMB == 0 {
|
||||||
|
c.MaxMemoryLimitMB = defaults.MaxMemoryLimitMB
|
||||||
|
}
|
||||||
|
if c.DefaultTimeoutSeconds == 0 {
|
||||||
|
c.DefaultTimeoutSeconds = defaults.DefaultTimeoutSeconds
|
||||||
|
}
|
||||||
|
if c.MaxTimeoutSeconds == 0 {
|
||||||
|
c.MaxTimeoutSeconds = defaults.MaxTimeoutSeconds
|
||||||
|
}
|
||||||
|
if c.GlobalRateLimitPerMinute == 0 {
|
||||||
|
c.GlobalRateLimitPerMinute = defaults.GlobalRateLimitPerMinute
|
||||||
|
}
|
||||||
|
if c.JobWorkers == 0 {
|
||||||
|
c.JobWorkers = defaults.JobWorkers
|
||||||
|
}
|
||||||
|
if c.JobPollInterval == 0 {
|
||||||
|
c.JobPollInterval = defaults.JobPollInterval
|
||||||
|
}
|
||||||
|
if c.JobMaxQueueSize == 0 {
|
||||||
|
c.JobMaxQueueSize = defaults.JobMaxQueueSize
|
||||||
|
}
|
||||||
|
if c.JobMaxPayloadSize == 0 {
|
||||||
|
c.JobMaxPayloadSize = defaults.JobMaxPayloadSize
|
||||||
|
}
|
||||||
|
if c.CronPollInterval == 0 {
|
||||||
|
c.CronPollInterval = defaults.CronPollInterval
|
||||||
|
}
|
||||||
|
if c.TimerPollInterval == 0 {
|
||||||
|
c.TimerPollInterval = defaults.TimerPollInterval
|
||||||
|
}
|
||||||
|
if c.DBPollInterval == 0 {
|
||||||
|
c.DBPollInterval = defaults.DBPollInterval
|
||||||
|
}
|
||||||
|
if c.ModuleCacheSize == 0 {
|
||||||
|
c.ModuleCacheSize = defaults.ModuleCacheSize
|
||||||
|
}
|
||||||
|
if c.LogRetention == 0 {
|
||||||
|
c.LogRetention = defaults.LogRetention
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMemoryLimit returns a copy with the memory limit set.
|
||||||
|
func (c *Config) WithMemoryLimit(defaultMB, maxMB int) *Config {
|
||||||
|
copy := *c
|
||||||
|
copy.DefaultMemoryLimitMB = defaultMB
|
||||||
|
copy.MaxMemoryLimitMB = maxMB
|
||||||
|
return ©
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTimeout returns a copy with the timeout set.
|
||||||
|
func (c *Config) WithTimeout(defaultSec, maxSec int) *Config {
|
||||||
|
copy := *c
|
||||||
|
copy.DefaultTimeoutSeconds = defaultSec
|
||||||
|
copy.MaxTimeoutSeconds = maxSec
|
||||||
|
return ©
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRateLimit returns a copy with the rate limit set.
|
||||||
|
func (c *Config) WithRateLimit(perMinute int) *Config {
|
||||||
|
copy := *c
|
||||||
|
copy.GlobalRateLimitPerMinute = perMinute
|
||||||
|
return ©
|
||||||
|
}
|
||||||
|
|
||||||
736
pkg/serverless/engine.go
Normal file
736
pkg/serverless/engine.go
Normal file
@ -0,0 +1,736 @@
|
|||||||
|
package serverless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/tetratelabs/wazero"
|
||||||
|
"github.com/tetratelabs/wazero/api"
|
||||||
|
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// contextAwareHostServices is an internal interface for services that need to know about
|
||||||
|
// the current invocation context.
|
||||||
|
type contextAwareHostServices interface {
|
||||||
|
SetInvocationContext(invCtx *InvocationContext)
|
||||||
|
ClearContext()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure Engine implements FunctionExecutor interface.
|
||||||
|
var _ FunctionExecutor = (*Engine)(nil)
|
||||||
|
|
||||||
|
// Engine is the core WASM execution engine using wazero.
|
||||||
|
// It manages compiled module caching and function execution.
|
||||||
|
type Engine struct {
|
||||||
|
runtime wazero.Runtime
|
||||||
|
config *Config
|
||||||
|
registry FunctionRegistry
|
||||||
|
hostServices HostServices
|
||||||
|
logger *zap.Logger
|
||||||
|
|
||||||
|
// Module cache: wasmCID -> compiled module
|
||||||
|
moduleCache map[string]wazero.CompiledModule
|
||||||
|
moduleCacheMu sync.RWMutex
|
||||||
|
|
||||||
|
// Invocation logger for metrics/debugging
|
||||||
|
invocationLogger InvocationLogger
|
||||||
|
|
||||||
|
// Rate limiter
|
||||||
|
rateLimiter RateLimiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvocationLogger logs function invocations (optional).
|
||||||
|
type InvocationLogger interface {
|
||||||
|
Log(ctx context.Context, inv *InvocationRecord) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvocationRecord represents a logged invocation.
|
||||||
|
type InvocationRecord struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
FunctionID string `json:"function_id"`
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
TriggerType TriggerType `json:"trigger_type"`
|
||||||
|
CallerWallet string `json:"caller_wallet,omitempty"`
|
||||||
|
InputSize int `json:"input_size"`
|
||||||
|
OutputSize int `json:"output_size"`
|
||||||
|
StartedAt time.Time `json:"started_at"`
|
||||||
|
CompletedAt time.Time `json:"completed_at"`
|
||||||
|
DurationMS int64 `json:"duration_ms"`
|
||||||
|
Status InvocationStatus `json:"status"`
|
||||||
|
ErrorMessage string `json:"error_message,omitempty"`
|
||||||
|
MemoryUsedMB float64 `json:"memory_used_mb"`
|
||||||
|
Logs []LogEntry `json:"logs,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimiter checks if a request should be rate limited.
|
||||||
|
type RateLimiter interface {
|
||||||
|
Allow(ctx context.Context, key string) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EngineOption configures the Engine.
|
||||||
|
type EngineOption func(*Engine)
|
||||||
|
|
||||||
|
// WithInvocationLogger sets the invocation logger.
|
||||||
|
func WithInvocationLogger(logger InvocationLogger) EngineOption {
|
||||||
|
return func(e *Engine) {
|
||||||
|
e.invocationLogger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRateLimiter sets the rate limiter.
|
||||||
|
func WithRateLimiter(limiter RateLimiter) EngineOption {
|
||||||
|
return func(e *Engine) {
|
||||||
|
e.rateLimiter = limiter
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEngine creates a new WASM execution engine.
|
||||||
|
func NewEngine(cfg *Config, registry FunctionRegistry, hostServices HostServices, logger *zap.Logger, opts ...EngineOption) (*Engine, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = DefaultConfig()
|
||||||
|
}
|
||||||
|
cfg.ApplyDefaults()
|
||||||
|
|
||||||
|
// Create wazero runtime with compilation cache
|
||||||
|
runtimeConfig := wazero.NewRuntimeConfig().
|
||||||
|
WithCloseOnContextDone(true)
|
||||||
|
|
||||||
|
runtime := wazero.NewRuntimeWithConfig(context.Background(), runtimeConfig)
|
||||||
|
|
||||||
|
// Instantiate WASI - required for WASM modules compiled with TinyGo targeting WASI
|
||||||
|
wasi_snapshot_preview1.MustInstantiate(context.Background(), runtime)
|
||||||
|
|
||||||
|
engine := &Engine{
|
||||||
|
runtime: runtime,
|
||||||
|
config: cfg,
|
||||||
|
registry: registry,
|
||||||
|
hostServices: hostServices,
|
||||||
|
logger: logger,
|
||||||
|
moduleCache: make(map[string]wazero.CompiledModule),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply options
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(engine)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register host functions
|
||||||
|
if err := engine.registerHostModule(context.Background()); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to register host module: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return engine, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute runs a function with the given input and returns the output.
|
||||||
|
func (e *Engine) Execute(ctx context.Context, fn *Function, input []byte, invCtx *InvocationContext) ([]byte, error) {
|
||||||
|
if fn == nil {
|
||||||
|
return nil, &ValidationError{Field: "function", Message: "cannot be nil"}
|
||||||
|
}
|
||||||
|
if invCtx == nil {
|
||||||
|
invCtx = &InvocationContext{
|
||||||
|
RequestID: uuid.New().String(),
|
||||||
|
FunctionID: fn.ID,
|
||||||
|
FunctionName: fn.Name,
|
||||||
|
Namespace: fn.Namespace,
|
||||||
|
TriggerType: TriggerTypeHTTP,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
// Check rate limit
|
||||||
|
if e.rateLimiter != nil {
|
||||||
|
allowed, err := e.rateLimiter.Allow(ctx, "global")
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Warn("Rate limiter error", zap.Error(err))
|
||||||
|
} else if !allowed {
|
||||||
|
return nil, ErrRateLimited
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create timeout context
|
||||||
|
timeout := time.Duration(fn.TimeoutSeconds) * time.Second
|
||||||
|
if timeout > time.Duration(e.config.MaxTimeoutSeconds)*time.Second {
|
||||||
|
timeout = time.Duration(e.config.MaxTimeoutSeconds) * time.Second
|
||||||
|
}
|
||||||
|
execCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Get compiled module (from cache or compile)
|
||||||
|
module, err := e.getOrCompileModule(execCtx, fn.WASMCID)
|
||||||
|
if err != nil {
|
||||||
|
e.logInvocation(ctx, fn, invCtx, startTime, 0, InvocationStatusError, err)
|
||||||
|
return nil, &ExecutionError{FunctionName: fn.Name, RequestID: invCtx.RequestID, Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the module
|
||||||
|
output, err := e.executeModule(execCtx, module, fn, input, invCtx)
|
||||||
|
if err != nil {
|
||||||
|
status := InvocationStatusError
|
||||||
|
if execCtx.Err() == context.DeadlineExceeded {
|
||||||
|
status = InvocationStatusTimeout
|
||||||
|
err = ErrTimeout
|
||||||
|
}
|
||||||
|
e.logInvocation(ctx, fn, invCtx, startTime, len(output), status, err)
|
||||||
|
return nil, &ExecutionError{FunctionName: fn.Name, RequestID: invCtx.RequestID, Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
e.logInvocation(ctx, fn, invCtx, startTime, len(output), InvocationStatusSuccess, nil)
|
||||||
|
return output, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Precompile compiles a WASM module and caches it for faster execution.
|
||||||
|
func (e *Engine) Precompile(ctx context.Context, wasmCID string, wasmBytes []byte) error {
|
||||||
|
if wasmCID == "" {
|
||||||
|
return &ValidationError{Field: "wasmCID", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
if len(wasmBytes) == 0 {
|
||||||
|
return &ValidationError{Field: "wasmBytes", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if already cached
|
||||||
|
e.moduleCacheMu.RLock()
|
||||||
|
_, exists := e.moduleCache[wasmCID]
|
||||||
|
e.moduleCacheMu.RUnlock()
|
||||||
|
if exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile the module
|
||||||
|
compiled, err := e.runtime.CompileModule(ctx, wasmBytes)
|
||||||
|
if err != nil {
|
||||||
|
return &DeployError{FunctionName: wasmCID, Cause: fmt.Errorf("failed to compile WASM: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache the compiled module
|
||||||
|
e.moduleCacheMu.Lock()
|
||||||
|
defer e.moduleCacheMu.Unlock()
|
||||||
|
|
||||||
|
// Evict oldest if cache is full
|
||||||
|
if len(e.moduleCache) >= e.config.ModuleCacheSize {
|
||||||
|
e.evictOldestModule()
|
||||||
|
}
|
||||||
|
|
||||||
|
e.moduleCache[wasmCID] = compiled
|
||||||
|
|
||||||
|
e.logger.Debug("Module precompiled and cached",
|
||||||
|
zap.String("wasm_cid", wasmCID),
|
||||||
|
zap.Int("cache_size", len(e.moduleCache)),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidate removes a compiled module from the cache.
|
||||||
|
func (e *Engine) Invalidate(wasmCID string) {
|
||||||
|
e.moduleCacheMu.Lock()
|
||||||
|
defer e.moduleCacheMu.Unlock()
|
||||||
|
|
||||||
|
if module, exists := e.moduleCache[wasmCID]; exists {
|
||||||
|
_ = module.Close(context.Background())
|
||||||
|
delete(e.moduleCache, wasmCID)
|
||||||
|
e.logger.Debug("Module invalidated from cache", zap.String("wasm_cid", wasmCID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close shuts down the engine and releases resources.
|
||||||
|
func (e *Engine) Close(ctx context.Context) error {
|
||||||
|
e.moduleCacheMu.Lock()
|
||||||
|
defer e.moduleCacheMu.Unlock()
|
||||||
|
|
||||||
|
// Close all cached modules
|
||||||
|
for cid, module := range e.moduleCache {
|
||||||
|
if err := module.Close(ctx); err != nil {
|
||||||
|
e.logger.Warn("Failed to close cached module", zap.String("cid", cid), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
e.moduleCache = make(map[string]wazero.CompiledModule)
|
||||||
|
|
||||||
|
// Close the runtime
|
||||||
|
return e.runtime.Close(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCacheStats returns cache statistics.
|
||||||
|
func (e *Engine) GetCacheStats() (size int, capacity int) {
|
||||||
|
e.moduleCacheMu.RLock()
|
||||||
|
defer e.moduleCacheMu.RUnlock()
|
||||||
|
return len(e.moduleCache), e.config.ModuleCacheSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Private methods
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// getOrCompileModule retrieves a compiled module from cache or compiles it.
|
||||||
|
func (e *Engine) getOrCompileModule(ctx context.Context, wasmCID string) (wazero.CompiledModule, error) {
|
||||||
|
// Check cache first
|
||||||
|
e.moduleCacheMu.RLock()
|
||||||
|
if module, exists := e.moduleCache[wasmCID]; exists {
|
||||||
|
e.moduleCacheMu.RUnlock()
|
||||||
|
return module, nil
|
||||||
|
}
|
||||||
|
e.moduleCacheMu.RUnlock()
|
||||||
|
|
||||||
|
// Fetch WASM bytes from registry
|
||||||
|
wasmBytes, err := e.registry.GetWASMBytes(ctx, wasmCID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to fetch WASM: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile the module
|
||||||
|
compiled, err := e.runtime.CompileModule(ctx, wasmBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrCompilationFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache the compiled module
|
||||||
|
e.moduleCacheMu.Lock()
|
||||||
|
defer e.moduleCacheMu.Unlock()
|
||||||
|
|
||||||
|
// Double-check (another goroutine might have added it)
|
||||||
|
if existingModule, exists := e.moduleCache[wasmCID]; exists {
|
||||||
|
_ = compiled.Close(ctx) // Discard our compilation
|
||||||
|
return existingModule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evict if cache is full
|
||||||
|
if len(e.moduleCache) >= e.config.ModuleCacheSize {
|
||||||
|
e.evictOldestModule()
|
||||||
|
}
|
||||||
|
|
||||||
|
e.moduleCache[wasmCID] = compiled
|
||||||
|
|
||||||
|
e.logger.Debug("Module compiled and cached",
|
||||||
|
zap.String("wasm_cid", wasmCID),
|
||||||
|
zap.Int("cache_size", len(e.moduleCache)),
|
||||||
|
)
|
||||||
|
|
||||||
|
return compiled, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeModule instantiates and runs a WASM module.
|
||||||
|
func (e *Engine) executeModule(ctx context.Context, compiled wazero.CompiledModule, fn *Function, input []byte, invCtx *InvocationContext) ([]byte, error) {
|
||||||
|
// Set invocation context for host functions if the service supports it
|
||||||
|
if hf, ok := e.hostServices.(contextAwareHostServices); ok {
|
||||||
|
hf.SetInvocationContext(invCtx)
|
||||||
|
defer hf.ClearContext()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create buffers for stdin/stdout (WASI uses these for I/O)
|
||||||
|
stdin := bytes.NewReader(input)
|
||||||
|
stdout := new(bytes.Buffer)
|
||||||
|
stderr := new(bytes.Buffer)
|
||||||
|
|
||||||
|
// Create module configuration with WASI stdio
|
||||||
|
moduleConfig := wazero.NewModuleConfig().
|
||||||
|
WithName(fn.Name).
|
||||||
|
WithStdin(stdin).
|
||||||
|
WithStdout(stdout).
|
||||||
|
WithStderr(stderr).
|
||||||
|
WithArgs(fn.Name) // argv[0] is the program name
|
||||||
|
|
||||||
|
// Instantiate and run the module (WASI _start will be called automatically)
|
||||||
|
instance, err := e.runtime.InstantiateModule(ctx, compiled, moduleConfig)
|
||||||
|
if err != nil {
|
||||||
|
// Check if stderr has any output
|
||||||
|
if stderr.Len() > 0 {
|
||||||
|
e.logger.Warn("WASM stderr output", zap.String("stderr", stderr.String()))
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to instantiate module: %w", err)
|
||||||
|
}
|
||||||
|
defer instance.Close(ctx)
|
||||||
|
|
||||||
|
// For WASI modules, the output is already in stdout buffer
|
||||||
|
// The _start function was called during instantiation
|
||||||
|
output := stdout.Bytes()
|
||||||
|
|
||||||
|
// Log stderr if any
|
||||||
|
if stderr.Len() > 0 {
|
||||||
|
e.logger.Debug("WASM stderr", zap.String("stderr", stderr.String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
return output, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// callHandleFunction calls the main 'handle' export in the WASM module.
|
||||||
|
func (e *Engine) callHandleFunction(ctx context.Context, instance api.Module, input []byte, invCtx *InvocationContext) ([]byte, error) {
|
||||||
|
// Get the 'handle' function export
|
||||||
|
handleFn := instance.ExportedFunction("handle")
|
||||||
|
if handleFn == nil {
|
||||||
|
return nil, fmt.Errorf("WASM module does not export 'handle' function")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get memory export
|
||||||
|
memory := instance.ExportedMemory("memory")
|
||||||
|
if memory == nil {
|
||||||
|
return nil, fmt.Errorf("WASM module does not export 'memory'")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get malloc/free exports for memory management
|
||||||
|
mallocFn := instance.ExportedFunction("malloc")
|
||||||
|
freeFn := instance.ExportedFunction("free")
|
||||||
|
|
||||||
|
var inputPtr uint32
|
||||||
|
var inputLen = uint32(len(input))
|
||||||
|
|
||||||
|
if mallocFn != nil && len(input) > 0 {
|
||||||
|
// Allocate memory for input
|
||||||
|
results, err := mallocFn.Call(ctx, uint64(inputLen))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("malloc failed: %w", err)
|
||||||
|
}
|
||||||
|
inputPtr = uint32(results[0])
|
||||||
|
|
||||||
|
// Write input to memory
|
||||||
|
if !memory.Write(inputPtr, input) {
|
||||||
|
return nil, fmt.Errorf("failed to write input to WASM memory")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Defer free if available
|
||||||
|
if freeFn != nil {
|
||||||
|
defer func() {
|
||||||
|
_, _ = freeFn.Call(ctx, uint64(inputPtr))
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call handle(input_ptr, input_len)
|
||||||
|
// Returns: output_ptr (packed with length in upper 32 bits)
|
||||||
|
results, err := handleFn.Call(ctx, uint64(inputPtr), uint64(inputLen))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("handle function error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) == 0 {
|
||||||
|
return nil, nil // No output
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse result - assume format: lower 32 bits = ptr, upper 32 bits = len
|
||||||
|
result := results[0]
|
||||||
|
outputPtr := uint32(result & 0xFFFFFFFF)
|
||||||
|
outputLen := uint32(result >> 32)
|
||||||
|
|
||||||
|
if outputLen == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read output from memory
|
||||||
|
output, ok := memory.Read(outputPtr, outputLen)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("failed to read output from WASM memory")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make a copy (memory will be freed)
|
||||||
|
outputCopy := make([]byte, len(output))
|
||||||
|
copy(outputCopy, output)
|
||||||
|
|
||||||
|
return outputCopy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// evictOldestModule removes the oldest module from cache.
|
||||||
|
// Must be called with moduleCacheMu held.
|
||||||
|
func (e *Engine) evictOldestModule() {
|
||||||
|
// Simple LRU: just remove the first one we find
|
||||||
|
// In production, you'd want proper LRU tracking
|
||||||
|
for cid, module := range e.moduleCache {
|
||||||
|
_ = module.Close(context.Background())
|
||||||
|
delete(e.moduleCache, cid)
|
||||||
|
e.logger.Debug("Evicted module from cache", zap.String("wasm_cid", cid))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// logInvocation logs an invocation record.
|
||||||
|
func (e *Engine) logInvocation(ctx context.Context, fn *Function, invCtx *InvocationContext, startTime time.Time, outputSize int, status InvocationStatus, err error) {
|
||||||
|
if e.invocationLogger == nil || !e.config.LogInvocations {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
completedAt := time.Now()
|
||||||
|
record := &InvocationRecord{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
FunctionID: fn.ID,
|
||||||
|
RequestID: invCtx.RequestID,
|
||||||
|
TriggerType: invCtx.TriggerType,
|
||||||
|
CallerWallet: invCtx.CallerWallet,
|
||||||
|
OutputSize: outputSize,
|
||||||
|
StartedAt: startTime,
|
||||||
|
CompletedAt: completedAt,
|
||||||
|
DurationMS: completedAt.Sub(startTime).Milliseconds(),
|
||||||
|
Status: status,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
record.ErrorMessage = err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect logs from host services if supported
|
||||||
|
if hf, ok := e.hostServices.(interface{ GetLogs() []LogEntry }); ok {
|
||||||
|
record.Logs = hf.GetLogs()
|
||||||
|
}
|
||||||
|
|
||||||
|
if logErr := e.invocationLogger.Log(ctx, record); logErr != nil {
|
||||||
|
e.logger.Warn("Failed to log invocation", zap.Error(logErr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerHostModule registers the Orama host functions with the wazero runtime.
|
||||||
|
func (e *Engine) registerHostModule(ctx context.Context) error {
|
||||||
|
// Register under both "env" and "host" to support different import styles
|
||||||
|
// The user requested "env" in instructions but "host" in expected result.
|
||||||
|
for _, moduleName := range []string{"env", "host"} {
|
||||||
|
_, err := e.runtime.NewHostModuleBuilder(moduleName).
|
||||||
|
NewFunctionBuilder().WithFunc(e.hGetCallerWallet).Export("get_caller_wallet").
|
||||||
|
NewFunctionBuilder().WithFunc(e.hGetRequestID).Export("get_request_id").
|
||||||
|
NewFunctionBuilder().WithFunc(e.hGetEnv).Export("get_env").
|
||||||
|
NewFunctionBuilder().WithFunc(e.hGetSecret).Export("get_secret").
|
||||||
|
NewFunctionBuilder().WithFunc(e.hDBQuery).Export("db_query").
|
||||||
|
NewFunctionBuilder().WithFunc(e.hDBExecute).Export("db_execute").
|
||||||
|
NewFunctionBuilder().WithFunc(e.hCacheGet).Export("cache_get").
|
||||||
|
NewFunctionBuilder().WithFunc(e.hCacheSet).Export("cache_set").
|
||||||
|
NewFunctionBuilder().WithFunc(e.hCacheIncr).Export("cache_incr").
|
||||||
|
NewFunctionBuilder().WithFunc(e.hCacheIncrBy).Export("cache_incr_by").
|
||||||
|
NewFunctionBuilder().WithFunc(e.hHTTPFetch).Export("http_fetch").
|
||||||
|
NewFunctionBuilder().WithFunc(e.hPubSubPublish).Export("pubsub_publish").
|
||||||
|
NewFunctionBuilder().WithFunc(e.hLogInfo).Export("log_info").
|
||||||
|
NewFunctionBuilder().WithFunc(e.hLogError).Export("log_error").
|
||||||
|
Instantiate(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) hGetCallerWallet(ctx context.Context, mod api.Module) uint64 {
|
||||||
|
wallet := e.hostServices.GetCallerWallet(ctx)
|
||||||
|
return e.writeToGuest(ctx, mod, []byte(wallet))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) hGetRequestID(ctx context.Context, mod api.Module) uint64 {
|
||||||
|
rid := e.hostServices.GetRequestID(ctx)
|
||||||
|
return e.writeToGuest(ctx, mod, []byte(rid))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) hGetEnv(ctx context.Context, mod api.Module, keyPtr, keyLen uint32) uint64 {
|
||||||
|
key, ok := mod.Memory().Read(keyPtr, keyLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
val, _ := e.hostServices.GetEnv(ctx, string(key))
|
||||||
|
return e.writeToGuest(ctx, mod, []byte(val))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) hGetSecret(ctx context.Context, mod api.Module, namePtr, nameLen uint32) uint64 {
|
||||||
|
name, ok := mod.Memory().Read(namePtr, nameLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
val, err := e.hostServices.GetSecret(ctx, string(name))
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return e.writeToGuest(ctx, mod, []byte(val))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) hDBQuery(ctx context.Context, mod api.Module, queryPtr, queryLen, argsPtr, argsLen uint32) uint64 {
|
||||||
|
query, ok := mod.Memory().Read(queryPtr, queryLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var args []interface{}
|
||||||
|
if argsLen > 0 {
|
||||||
|
argsData, ok := mod.Memory().Read(argsPtr, argsLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(argsData, &args); err != nil {
|
||||||
|
e.logger.Error("failed to unmarshal db_query arguments", zap.Error(err))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := e.hostServices.DBQuery(ctx, string(query), args)
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Error("host function db_query failed", zap.Error(err), zap.String("query", string(query)))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return e.writeToGuest(ctx, mod, results)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) hDBExecute(ctx context.Context, mod api.Module, queryPtr, queryLen, argsPtr, argsLen uint32) uint32 {
|
||||||
|
query, ok := mod.Memory().Read(queryPtr, queryLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var args []interface{}
|
||||||
|
if argsLen > 0 {
|
||||||
|
argsData, ok := mod.Memory().Read(argsPtr, argsLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(argsData, &args); err != nil {
|
||||||
|
e.logger.Error("failed to unmarshal db_execute arguments", zap.Error(err))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
affected, err := e.hostServices.DBExecute(ctx, string(query), args)
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Error("host function db_execute failed", zap.Error(err), zap.String("query", string(query)))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return uint32(affected)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) hCacheGet(ctx context.Context, mod api.Module, keyPtr, keyLen uint32) uint64 {
|
||||||
|
key, ok := mod.Memory().Read(keyPtr, keyLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
val, err := e.hostServices.CacheGet(ctx, string(key))
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return e.writeToGuest(ctx, mod, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) hCacheSet(ctx context.Context, mod api.Module, keyPtr, keyLen, valPtr, valLen uint32, ttl int64) {
|
||||||
|
key, ok := mod.Memory().Read(keyPtr, keyLen)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
val, ok := mod.Memory().Read(valPtr, valLen)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = e.hostServices.CacheSet(ctx, string(key), val, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) hCacheIncr(ctx context.Context, mod api.Module, keyPtr, keyLen uint32) int64 {
|
||||||
|
key, ok := mod.Memory().Read(keyPtr, keyLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
val, err := e.hostServices.CacheIncr(ctx, string(key))
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Error("host function cache_incr failed", zap.Error(err), zap.String("key", string(key)))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) hCacheIncrBy(ctx context.Context, mod api.Module, keyPtr, keyLen uint32, delta int64) int64 {
|
||||||
|
key, ok := mod.Memory().Read(keyPtr, keyLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
val, err := e.hostServices.CacheIncrBy(ctx, string(key), delta)
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Error("host function cache_incr_by failed", zap.Error(err), zap.String("key", string(key)), zap.Int64("delta", delta))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) hHTTPFetch(ctx context.Context, mod api.Module, methodPtr, methodLen, urlPtr, urlLen, headersPtr, headersLen, bodyPtr, bodyLen uint32) uint64 {
|
||||||
|
method, ok := mod.Memory().Read(methodPtr, methodLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
u, ok := mod.Memory().Read(urlPtr, urlLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
var headers map[string]string
|
||||||
|
if headersLen > 0 {
|
||||||
|
headersData, ok := mod.Memory().Read(headersPtr, headersLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(headersData, &headers); err != nil {
|
||||||
|
e.logger.Error("failed to unmarshal http_fetch headers", zap.Error(err))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
body, ok := mod.Memory().Read(bodyPtr, bodyLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := e.hostServices.HTTPFetch(ctx, string(method), string(u), headers, body)
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Error("host function http_fetch failed", zap.Error(err), zap.String("url", string(u)))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return e.writeToGuest(ctx, mod, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) hPubSubPublish(ctx context.Context, mod api.Module, topicPtr, topicLen, dataPtr, dataLen uint32) uint32 {
|
||||||
|
topic, ok := mod.Memory().Read(topicPtr, topicLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
data, ok := mod.Memory().Read(dataPtr, dataLen)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
err := e.hostServices.PubSubPublish(ctx, string(topic), data)
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Error("host function pubsub_publish failed", zap.Error(err), zap.String("topic", string(topic)))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return 1 // Success
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) hLogInfo(ctx context.Context, mod api.Module, ptr, size uint32) {
|
||||||
|
msg, ok := mod.Memory().Read(ptr, size)
|
||||||
|
if ok {
|
||||||
|
e.hostServices.LogInfo(ctx, string(msg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) hLogError(ctx context.Context, mod api.Module, ptr, size uint32) {
|
||||||
|
msg, ok := mod.Memory().Read(ptr, size)
|
||||||
|
if ok {
|
||||||
|
e.hostServices.LogError(ctx, string(msg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) writeToGuest(ctx context.Context, mod api.Module, data []byte) uint64 {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
// Try to find a non-conflicting allocator first, fallback to malloc
|
||||||
|
malloc := mod.ExportedFunction("orama_alloc")
|
||||||
|
if malloc == nil {
|
||||||
|
malloc = mod.ExportedFunction("malloc")
|
||||||
|
}
|
||||||
|
|
||||||
|
if malloc == nil {
|
||||||
|
e.logger.Warn("WASM module missing malloc/orama_alloc export, cannot return string/bytes to guest")
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
results, err := malloc.Call(ctx, uint64(len(data)))
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Error("failed to call malloc in WASM module", zap.Error(err))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
ptr := uint32(results[0])
|
||||||
|
if !mod.Memory().Write(ptr, data) {
|
||||||
|
e.logger.Error("failed to write to WASM memory")
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return (uint64(ptr) << 32) | uint64(len(data))
|
||||||
|
}
|
||||||
202
pkg/serverless/engine_test.go
Normal file
202
pkg/serverless/engine_test.go
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
package serverless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEngine_Execute(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
registry := NewMockRegistry()
|
||||||
|
hostServices := NewMockHostServices()
|
||||||
|
|
||||||
|
cfg := DefaultConfig()
|
||||||
|
cfg.ModuleCacheSize = 2
|
||||||
|
|
||||||
|
engine, err := NewEngine(cfg, registry, hostServices, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create engine: %v", err)
|
||||||
|
}
|
||||||
|
defer engine.Close(context.Background())
|
||||||
|
|
||||||
|
// Use a minimal valid WASM module that exports _start (WASI)
|
||||||
|
// This is just 'nop' in WASM
|
||||||
|
wasmBytes := []byte{
|
||||||
|
0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
|
||||||
|
0x01, 0x04, 0x01, 0x60, 0x00, 0x00,
|
||||||
|
0x03, 0x02, 0x01, 0x00,
|
||||||
|
0x07, 0x0a, 0x01, 0x06, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x00, 0x00,
|
||||||
|
0x0a, 0x04, 0x01, 0x02, 0x00, 0x0b,
|
||||||
|
}
|
||||||
|
|
||||||
|
fnDef := &FunctionDefinition{
|
||||||
|
Name: "test-func",
|
||||||
|
Namespace: "test-ns",
|
||||||
|
MemoryLimitMB: 64,
|
||||||
|
TimeoutSeconds: 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = registry.Register(context.Background(), fnDef, wasmBytes)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to register function: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn, err := registry.Get(context.Background(), "test-ns", "test-func", 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get function: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute function
|
||||||
|
ctx := context.Background()
|
||||||
|
output, err := engine.Execute(ctx, fn, []byte("input"), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to execute function: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Our minimal WASM doesn't write to stdout, so output should be empty
|
||||||
|
if len(output) != 0 {
|
||||||
|
t.Errorf("expected empty output, got %d bytes", len(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test cache stats
|
||||||
|
size, capacity := engine.GetCacheStats()
|
||||||
|
if size != 1 {
|
||||||
|
t.Errorf("expected cache size 1, got %d", size)
|
||||||
|
}
|
||||||
|
if capacity != 2 {
|
||||||
|
t.Errorf("expected cache capacity 2, got %d", capacity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Invalidate
|
||||||
|
engine.Invalidate(fn.WASMCID)
|
||||||
|
size, _ = engine.GetCacheStats()
|
||||||
|
if size != 0 {
|
||||||
|
t.Errorf("expected cache size 0 after invalidation, got %d", size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_Precompile(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
registry := NewMockRegistry()
|
||||||
|
hostServices := NewMockHostServices()
|
||||||
|
engine, _ := NewEngine(nil, registry, hostServices, logger)
|
||||||
|
defer engine.Close(context.Background())
|
||||||
|
|
||||||
|
wasmBytes := []byte{
|
||||||
|
0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
|
||||||
|
0x01, 0x04, 0x01, 0x60, 0x00, 0x00,
|
||||||
|
0x03, 0x02, 0x01, 0x00,
|
||||||
|
0x07, 0x0a, 0x01, 0x06, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x00, 0x00,
|
||||||
|
0x0a, 0x04, 0x01, 0x02, 0x00, 0x0b,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := engine.Precompile(context.Background(), "test-cid", wasmBytes)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to precompile: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
size, _ := engine.GetCacheStats()
|
||||||
|
if size != 1 {
|
||||||
|
t.Errorf("expected cache size 1, got %d", size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_Timeout(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
registry := NewMockRegistry()
|
||||||
|
hostServices := NewMockHostServices()
|
||||||
|
engine, _ := NewEngine(nil, registry, hostServices, logger)
|
||||||
|
defer engine.Close(context.Background())
|
||||||
|
|
||||||
|
wasmBytes := []byte{
|
||||||
|
0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
|
||||||
|
0x01, 0x04, 0x01, 0x60, 0x00, 0x00,
|
||||||
|
0x03, 0x02, 0x01, 0x00,
|
||||||
|
0x07, 0x0a, 0x01, 0x06, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x00, 0x00,
|
||||||
|
0x0a, 0x04, 0x01, 0x02, 0x00, 0x0b,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn, _ := registry.Get(context.Background(), "test", "timeout", 0)
|
||||||
|
if fn == nil {
|
||||||
|
_, _ = registry.Register(context.Background(), &FunctionDefinition{Name: "timeout", Namespace: "test"}, wasmBytes)
|
||||||
|
fn, _ = registry.Get(context.Background(), "test", "timeout", 0)
|
||||||
|
}
|
||||||
|
fn.TimeoutSeconds = 1
|
||||||
|
|
||||||
|
// Test with already canceled context
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
_, err := engine.Execute(ctx, fn, nil, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for canceled context, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_MemoryLimit(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
registry := NewMockRegistry()
|
||||||
|
hostServices := NewMockHostServices()
|
||||||
|
engine, _ := NewEngine(nil, registry, hostServices, logger)
|
||||||
|
defer engine.Close(context.Background())
|
||||||
|
|
||||||
|
wasmBytes := []byte{
|
||||||
|
0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
|
||||||
|
0x01, 0x04, 0x01, 0x60, 0x00, 0x00,
|
||||||
|
0x03, 0x02, 0x01, 0x00,
|
||||||
|
0x07, 0x0a, 0x01, 0x06, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x00, 0x00,
|
||||||
|
0x0a, 0x04, 0x01, 0x02, 0x00, 0x0b,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = registry.Register(context.Background(), &FunctionDefinition{Name: "memory", Namespace: "test", MemoryLimitMB: 1, TimeoutSeconds: 5}, wasmBytes)
|
||||||
|
fn, _ := registry.Get(context.Background(), "test", "memory", 0)
|
||||||
|
|
||||||
|
// This should pass because the minimal WASM doesn't use much memory
|
||||||
|
_, err := engine.Execute(context.Background(), fn, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expected success for minimal WASM within memory limit, got error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_RealWASM(t *testing.T) {
|
||||||
|
wasmPath := "../../examples/functions/bin/hello.wasm"
|
||||||
|
if _, err := os.Stat(wasmPath); os.IsNotExist(err) {
|
||||||
|
t.Skip("hello.wasm not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
wasmBytes, err := os.ReadFile(wasmPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read hello.wasm: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := zap.NewNop()
|
||||||
|
registry := NewMockRegistry()
|
||||||
|
hostServices := NewMockHostServices()
|
||||||
|
engine, _ := NewEngine(nil, registry, hostServices, logger)
|
||||||
|
defer engine.Close(context.Background())
|
||||||
|
|
||||||
|
fnDef := &FunctionDefinition{
|
||||||
|
Name: "hello",
|
||||||
|
Namespace: "examples",
|
||||||
|
TimeoutSeconds: 10,
|
||||||
|
}
|
||||||
|
_, _ = registry.Register(context.Background(), fnDef, wasmBytes)
|
||||||
|
fn, _ := registry.Get(context.Background(), "examples", "hello", 0)
|
||||||
|
|
||||||
|
output, err := engine.Execute(context.Background(), fn, []byte(`{"name": "Tester"}`), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execution failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := "Hello, Tester!"
|
||||||
|
if !contains(string(output), expected) {
|
||||||
|
t.Errorf("output %q does not contain %q", string(output), expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) >= len(substr) && (s[:len(substr)] == substr || contains(s[1:], substr))
|
||||||
|
}
|
||||||
216
pkg/serverless/errors.go
Normal file
216
pkg/serverless/errors.go
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
package serverless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sentinel errors for common conditions.
|
||||||
|
var (
|
||||||
|
// ErrFunctionNotFound is returned when a function does not exist.
|
||||||
|
ErrFunctionNotFound = errors.New("function not found")
|
||||||
|
|
||||||
|
// ErrFunctionExists is returned when attempting to create a function that already exists.
|
||||||
|
ErrFunctionExists = errors.New("function already exists")
|
||||||
|
|
||||||
|
// ErrVersionNotFound is returned when a specific function version does not exist.
|
||||||
|
ErrVersionNotFound = errors.New("function version not found")
|
||||||
|
|
||||||
|
// ErrSecretNotFound is returned when a secret does not exist.
|
||||||
|
ErrSecretNotFound = errors.New("secret not found")
|
||||||
|
|
||||||
|
// ErrJobNotFound is returned when a job does not exist.
|
||||||
|
ErrJobNotFound = errors.New("job not found")
|
||||||
|
|
||||||
|
// ErrTriggerNotFound is returned when a trigger does not exist.
|
||||||
|
ErrTriggerNotFound = errors.New("trigger not found")
|
||||||
|
|
||||||
|
// ErrTimerNotFound is returned when a timer does not exist.
|
||||||
|
ErrTimerNotFound = errors.New("timer not found")
|
||||||
|
|
||||||
|
// ErrUnauthorized is returned when the caller is not authorized.
|
||||||
|
ErrUnauthorized = errors.New("unauthorized")
|
||||||
|
|
||||||
|
// ErrRateLimited is returned when the rate limit is exceeded.
|
||||||
|
ErrRateLimited = errors.New("rate limit exceeded")
|
||||||
|
|
||||||
|
// ErrInvalidWASM is returned when the WASM module is invalid.
|
||||||
|
ErrInvalidWASM = errors.New("invalid WASM module")
|
||||||
|
|
||||||
|
// ErrCompilationFailed is returned when WASM compilation fails.
|
||||||
|
ErrCompilationFailed = errors.New("WASM compilation failed")
|
||||||
|
|
||||||
|
// ErrExecutionFailed is returned when function execution fails.
|
||||||
|
ErrExecutionFailed = errors.New("function execution failed")
|
||||||
|
|
||||||
|
// ErrTimeout is returned when function execution times out.
|
||||||
|
ErrTimeout = errors.New("function execution timeout")
|
||||||
|
|
||||||
|
// ErrMemoryExceeded is returned when the function exceeds memory limits.
|
||||||
|
ErrMemoryExceeded = errors.New("memory limit exceeded")
|
||||||
|
|
||||||
|
// ErrInvalidInput is returned when function input is invalid.
|
||||||
|
ErrInvalidInput = errors.New("invalid input")
|
||||||
|
|
||||||
|
// ErrWSNotAvailable is returned when WebSocket operations are used outside WS context.
|
||||||
|
ErrWSNotAvailable = errors.New("websocket operations not available in this context")
|
||||||
|
|
||||||
|
// ErrWSClientNotFound is returned when a WebSocket client is not connected.
|
||||||
|
ErrWSClientNotFound = errors.New("websocket client not found")
|
||||||
|
|
||||||
|
// ErrInvalidCronExpression is returned when a cron expression is invalid.
|
||||||
|
ErrInvalidCronExpression = errors.New("invalid cron expression")
|
||||||
|
|
||||||
|
// ErrPayloadTooLarge is returned when a job payload exceeds the maximum size.
|
||||||
|
ErrPayloadTooLarge = errors.New("payload too large")
|
||||||
|
|
||||||
|
// ErrQueueFull is returned when the job queue is full.
|
||||||
|
ErrQueueFull = errors.New("job queue is full")
|
||||||
|
|
||||||
|
// ErrJobCancelled is returned when a job is cancelled.
|
||||||
|
ErrJobCancelled = errors.New("job cancelled")
|
||||||
|
|
||||||
|
// ErrStorageUnavailable is returned when IPFS storage is unavailable.
|
||||||
|
ErrStorageUnavailable = errors.New("storage unavailable")
|
||||||
|
|
||||||
|
// ErrDatabaseUnavailable is returned when the database is unavailable.
|
||||||
|
ErrDatabaseUnavailable = errors.New("database unavailable")
|
||||||
|
|
||||||
|
// ErrCacheUnavailable is returned when the cache is unavailable.
|
||||||
|
ErrCacheUnavailable = errors.New("cache unavailable")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigError represents a configuration validation error.
|
||||||
|
type ConfigError struct {
|
||||||
|
Field string
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConfigError) Error() string {
|
||||||
|
return fmt.Sprintf("config error: %s: %s", e.Field, e.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeployError represents an error during function deployment.
|
||||||
|
type DeployError struct {
|
||||||
|
FunctionName string
|
||||||
|
Cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *DeployError) Error() string {
|
||||||
|
return fmt.Sprintf("deploy error for function '%s': %v", e.FunctionName, e.Cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *DeployError) Unwrap() error {
|
||||||
|
return e.Cause
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecutionError represents an error during function execution.
|
||||||
|
type ExecutionError struct {
|
||||||
|
FunctionName string
|
||||||
|
RequestID string
|
||||||
|
Cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ExecutionError) Error() string {
|
||||||
|
return fmt.Sprintf("execution error for function '%s' (request %s): %v",
|
||||||
|
e.FunctionName, e.RequestID, e.Cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ExecutionError) Unwrap() error {
|
||||||
|
return e.Cause
|
||||||
|
}
|
||||||
|
|
||||||
|
// HostFunctionError represents an error in a host function call.
|
||||||
|
type HostFunctionError struct {
|
||||||
|
Function string
|
||||||
|
Cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *HostFunctionError) Error() string {
|
||||||
|
return fmt.Sprintf("host function '%s' error: %v", e.Function, e.Cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *HostFunctionError) Unwrap() error {
|
||||||
|
return e.Cause
|
||||||
|
}
|
||||||
|
|
||||||
|
// TriggerError represents an error in trigger execution.
|
||||||
|
type TriggerError struct {
|
||||||
|
TriggerType string
|
||||||
|
TriggerID string
|
||||||
|
FunctionID string
|
||||||
|
Cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *TriggerError) Error() string {
|
||||||
|
return fmt.Sprintf("trigger error (%s/%s) for function '%s': %v",
|
||||||
|
e.TriggerType, e.TriggerID, e.FunctionID, e.Cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *TriggerError) Unwrap() error {
|
||||||
|
return e.Cause
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidationError represents an input validation error.
|
||||||
|
type ValidationError struct {
|
||||||
|
Field string
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ValidationError) Error() string {
|
||||||
|
return fmt.Sprintf("validation error: %s: %s", e.Field, e.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryableError wraps an error that should be retried.
|
||||||
|
type RetryableError struct {
|
||||||
|
Cause error
|
||||||
|
RetryAfter int // Suggested retry delay in seconds
|
||||||
|
MaxRetries int // Maximum number of retries remaining
|
||||||
|
CurrentTry int // Current attempt number
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RetryableError) Error() string {
|
||||||
|
return fmt.Sprintf("retryable error (attempt %d): %v", e.CurrentTry, e.Cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RetryableError) Unwrap() error {
|
||||||
|
return e.Cause
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRetryable checks if an error should be retried.
|
||||||
|
func IsRetryable(err error) bool {
|
||||||
|
var retryable *RetryableError
|
||||||
|
return errors.As(err, &retryable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNotFound checks if an error indicates a resource was not found.
|
||||||
|
func IsNotFound(err error) bool {
|
||||||
|
return errors.Is(err, ErrFunctionNotFound) ||
|
||||||
|
errors.Is(err, ErrVersionNotFound) ||
|
||||||
|
errors.Is(err, ErrSecretNotFound) ||
|
||||||
|
errors.Is(err, ErrJobNotFound) ||
|
||||||
|
errors.Is(err, ErrTriggerNotFound) ||
|
||||||
|
errors.Is(err, ErrTimerNotFound) ||
|
||||||
|
errors.Is(err, ErrWSClientNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsUnauthorized checks if an error indicates a lack of authorization.
|
||||||
|
func IsUnauthorized(err error) bool {
|
||||||
|
return errors.Is(err, ErrUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsResourceExhausted checks if an error indicates resource exhaustion.
|
||||||
|
func IsResourceExhausted(err error) bool {
|
||||||
|
return errors.Is(err, ErrRateLimited) ||
|
||||||
|
errors.Is(err, ErrMemoryExceeded) ||
|
||||||
|
errors.Is(err, ErrPayloadTooLarge) ||
|
||||||
|
errors.Is(err, ErrQueueFull) ||
|
||||||
|
errors.Is(err, ErrTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsServiceUnavailable checks if an error indicates a service is unavailable.
|
||||||
|
func IsServiceUnavailable(err error) bool {
|
||||||
|
return errors.Is(err, ErrStorageUnavailable) ||
|
||||||
|
errors.Is(err, ErrDatabaseUnavailable) ||
|
||||||
|
errors.Is(err, ErrCacheUnavailable)
|
||||||
|
}
|
||||||
687
pkg/serverless/hostfuncs.go
Normal file
687
pkg/serverless/hostfuncs.go
Normal file
@ -0,0 +1,687 @@
|
|||||||
|
package serverless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/tlsutil"
|
||||||
|
olriclib "github.com/olric-data/olric"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ensure HostFunctions implements HostServices interface.
|
||||||
|
var _ HostServices = (*HostFunctions)(nil)
|
||||||
|
|
||||||
|
// HostFunctions provides the bridge between WASM functions and Orama services.
|
||||||
|
// It implements the HostServices interface and is injected into the execution context.
|
||||||
|
type HostFunctions struct {
|
||||||
|
db rqlite.Client
|
||||||
|
cacheClient olriclib.Client
|
||||||
|
storage ipfs.IPFSClient
|
||||||
|
ipfsAPIURL string
|
||||||
|
pubsub *pubsub.ClientAdapter
|
||||||
|
wsManager WebSocketManager
|
||||||
|
secrets SecretsManager
|
||||||
|
httpClient *http.Client
|
||||||
|
logger *zap.Logger
|
||||||
|
|
||||||
|
// Current invocation context (set per-execution)
|
||||||
|
invCtx *InvocationContext
|
||||||
|
invCtxLock sync.RWMutex
|
||||||
|
|
||||||
|
// Captured logs for this invocation
|
||||||
|
logs []LogEntry
|
||||||
|
logsLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// HostFunctionsConfig holds configuration for HostFunctions.
|
||||||
|
type HostFunctionsConfig struct {
|
||||||
|
IPFSAPIURL string
|
||||||
|
HTTPTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHostFunctions creates a new HostFunctions instance.
|
||||||
|
func NewHostFunctions(
|
||||||
|
db rqlite.Client,
|
||||||
|
cacheClient olriclib.Client,
|
||||||
|
storage ipfs.IPFSClient,
|
||||||
|
pubsubAdapter *pubsub.ClientAdapter,
|
||||||
|
wsManager WebSocketManager,
|
||||||
|
secrets SecretsManager,
|
||||||
|
cfg HostFunctionsConfig,
|
||||||
|
logger *zap.Logger,
|
||||||
|
) *HostFunctions {
|
||||||
|
httpTimeout := cfg.HTTPTimeout
|
||||||
|
if httpTimeout == 0 {
|
||||||
|
httpTimeout = 30 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
return &HostFunctions{
|
||||||
|
db: db,
|
||||||
|
cacheClient: cacheClient,
|
||||||
|
storage: storage,
|
||||||
|
ipfsAPIURL: cfg.IPFSAPIURL,
|
||||||
|
pubsub: pubsubAdapter,
|
||||||
|
wsManager: wsManager,
|
||||||
|
secrets: secrets,
|
||||||
|
httpClient: tlsutil.NewHTTPClient(httpTimeout),
|
||||||
|
logger: logger,
|
||||||
|
logs: make([]LogEntry, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetInvocationContext sets the current invocation context.
|
||||||
|
// Must be called before executing a function.
|
||||||
|
func (h *HostFunctions) SetInvocationContext(invCtx *InvocationContext) {
|
||||||
|
h.invCtxLock.Lock()
|
||||||
|
defer h.invCtxLock.Unlock()
|
||||||
|
h.invCtx = invCtx
|
||||||
|
h.logs = make([]LogEntry, 0) // Reset logs for new invocation
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLogs returns the captured logs for the current invocation.
|
||||||
|
func (h *HostFunctions) GetLogs() []LogEntry {
|
||||||
|
h.logsLock.Lock()
|
||||||
|
defer h.logsLock.Unlock()
|
||||||
|
logsCopy := make([]LogEntry, len(h.logs))
|
||||||
|
copy(logsCopy, h.logs)
|
||||||
|
return logsCopy
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearContext clears the invocation context after execution.
|
||||||
|
func (h *HostFunctions) ClearContext() {
|
||||||
|
h.invCtxLock.Lock()
|
||||||
|
defer h.invCtxLock.Unlock()
|
||||||
|
h.invCtx = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Database Operations
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// DBQuery executes a SELECT query and returns JSON-encoded results.
|
||||||
|
func (h *HostFunctions) DBQuery(ctx context.Context, query string, args []interface{}) ([]byte, error) {
|
||||||
|
if h.db == nil {
|
||||||
|
return nil, &HostFunctionError{Function: "db_query", Cause: ErrDatabaseUnavailable}
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []map[string]interface{}
|
||||||
|
if err := h.db.Query(ctx, &results, query, args...); err != nil {
|
||||||
|
return nil, &HostFunctionError{Function: "db_query", Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(results)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &HostFunctionError{Function: "db_query", Cause: fmt.Errorf("failed to marshal results: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DBExecute executes an INSERT/UPDATE/DELETE query and returns affected rows.
|
||||||
|
func (h *HostFunctions) DBExecute(ctx context.Context, query string, args []interface{}) (int64, error) {
|
||||||
|
if h.db == nil {
|
||||||
|
return 0, &HostFunctionError{Function: "db_execute", Cause: ErrDatabaseUnavailable}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.db.Exec(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return 0, &HostFunctionError{Function: "db_execute", Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
affected, _ := result.RowsAffected()
|
||||||
|
return affected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Cache Operations
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
const cacheDMapName = "serverless_cache"
|
||||||
|
|
||||||
|
// CacheGet retrieves a value from the cache.
|
||||||
|
func (h *HostFunctions) CacheGet(ctx context.Context, key string) ([]byte, error) {
|
||||||
|
if h.cacheClient == nil {
|
||||||
|
return nil, &HostFunctionError{Function: "cache_get", Cause: ErrCacheUnavailable}
|
||||||
|
}
|
||||||
|
|
||||||
|
dm, err := h.cacheClient.NewDMap(cacheDMapName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &HostFunctionError{Function: "cache_get", Cause: fmt.Errorf("failed to get DMap: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := dm.Get(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &HostFunctionError{Function: "cache_get", Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := result.Byte()
|
||||||
|
if err != nil {
|
||||||
|
return nil, &HostFunctionError{Function: "cache_get", Cause: fmt.Errorf("failed to decode value: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheSet stores a value in the cache with optional TTL.
|
||||||
|
// Note: TTL is currently not supported by the underlying Olric DMap.Put method.
|
||||||
|
// Values are stored indefinitely until explicitly deleted.
|
||||||
|
func (h *HostFunctions) CacheSet(ctx context.Context, key string, value []byte, ttlSeconds int64) error {
|
||||||
|
if h.cacheClient == nil {
|
||||||
|
return &HostFunctionError{Function: "cache_set", Cause: ErrCacheUnavailable}
|
||||||
|
}
|
||||||
|
|
||||||
|
dm, err := h.cacheClient.NewDMap(cacheDMapName)
|
||||||
|
if err != nil {
|
||||||
|
return &HostFunctionError{Function: "cache_set", Cause: fmt.Errorf("failed to get DMap: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: Olric DMap.Put doesn't support TTL in the basic API
|
||||||
|
// For TTL support, consider using Olric's Expire API separately
|
||||||
|
if err := dm.Put(ctx, key, value); err != nil {
|
||||||
|
return &HostFunctionError{Function: "cache_set", Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheDelete removes a value from the cache.
|
||||||
|
func (h *HostFunctions) CacheDelete(ctx context.Context, key string) error {
|
||||||
|
if h.cacheClient == nil {
|
||||||
|
return &HostFunctionError{Function: "cache_delete", Cause: ErrCacheUnavailable}
|
||||||
|
}
|
||||||
|
|
||||||
|
dm, err := h.cacheClient.NewDMap(cacheDMapName)
|
||||||
|
if err != nil {
|
||||||
|
return &HostFunctionError{Function: "cache_delete", Cause: fmt.Errorf("failed to get DMap: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := dm.Delete(ctx, key); err != nil {
|
||||||
|
return &HostFunctionError{Function: "cache_delete", Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheIncr atomically increments a numeric value in cache by 1 and returns the new value.
|
||||||
|
// If the key doesn't exist, it is initialized to 0 before incrementing.
|
||||||
|
// Returns an error if the value exists but is not numeric.
|
||||||
|
func (h *HostFunctions) CacheIncr(ctx context.Context, key string) (int64, error) {
|
||||||
|
return h.CacheIncrBy(ctx, key, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheIncrBy atomically increments a numeric value by delta and returns the new value.
|
||||||
|
// If the key doesn't exist, it is initialized to 0 before incrementing.
|
||||||
|
// Returns an error if the value exists but is not numeric.
|
||||||
|
func (h *HostFunctions) CacheIncrBy(ctx context.Context, key string, delta int64) (int64, error) {
|
||||||
|
if h.cacheClient == nil {
|
||||||
|
return 0, &HostFunctionError{Function: "cache_incr_by", Cause: ErrCacheUnavailable}
|
||||||
|
}
|
||||||
|
|
||||||
|
dm, err := h.cacheClient.NewDMap(cacheDMapName)
|
||||||
|
if err != nil {
|
||||||
|
return 0, &HostFunctionError{Function: "cache_incr_by", Cause: fmt.Errorf("failed to get DMap: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Olric's Incr method atomically increments a numeric value
|
||||||
|
// It initializes the key to 0 if it doesn't exist, then increments by delta
|
||||||
|
// Note: Olric's Incr takes int (not int64) and returns int
|
||||||
|
newValue, err := dm.Incr(ctx, key, int(delta))
|
||||||
|
if err != nil {
|
||||||
|
return 0, &HostFunctionError{Function: "cache_incr_by", Cause: fmt.Errorf("failed to increment: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return int64(newValue), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Storage Operations
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// StoragePut uploads data to IPFS and returns the CID.
|
||||||
|
func (h *HostFunctions) StoragePut(ctx context.Context, data []byte) (string, error) {
|
||||||
|
if h.storage == nil {
|
||||||
|
return "", &HostFunctionError{Function: "storage_put", Cause: ErrStorageUnavailable}
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := bytes.NewReader(data)
|
||||||
|
resp, err := h.storage.Add(ctx, reader, "function-data")
|
||||||
|
if err != nil {
|
||||||
|
return "", &HostFunctionError{Function: "storage_put", Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.Cid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StorageGet retrieves data from IPFS by CID.
|
||||||
|
func (h *HostFunctions) StorageGet(ctx context.Context, cid string) ([]byte, error) {
|
||||||
|
if h.storage == nil {
|
||||||
|
return nil, &HostFunctionError{Function: "storage_get", Cause: ErrStorageUnavailable}
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, err := h.storage.Get(ctx, cid, h.ipfsAPIURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &HostFunctionError{Function: "storage_get", Cause: err}
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &HostFunctionError{Function: "storage_get", Cause: fmt.Errorf("failed to read data: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// PubSub Operations
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// PubSubPublish publishes a message to a topic.
|
||||||
|
func (h *HostFunctions) PubSubPublish(ctx context.Context, topic string, data []byte) error {
|
||||||
|
if h.pubsub == nil {
|
||||||
|
return &HostFunctionError{Function: "pubsub_publish", Cause: fmt.Errorf("pubsub not available")}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The pubsub adapter handles namespacing internally
|
||||||
|
if err := h.pubsub.Publish(ctx, topic, data); err != nil {
|
||||||
|
return &HostFunctionError{Function: "pubsub_publish", Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// WebSocket Operations
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// WSSend sends data to a specific WebSocket client.
|
||||||
|
func (h *HostFunctions) WSSend(ctx context.Context, clientID string, data []byte) error {
|
||||||
|
if h.wsManager == nil {
|
||||||
|
return &HostFunctionError{Function: "ws_send", Cause: ErrWSNotAvailable}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no clientID provided, use the current invocation's client
|
||||||
|
if clientID == "" {
|
||||||
|
h.invCtxLock.RLock()
|
||||||
|
if h.invCtx != nil && h.invCtx.WSClientID != "" {
|
||||||
|
clientID = h.invCtx.WSClientID
|
||||||
|
}
|
||||||
|
h.invCtxLock.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientID == "" {
|
||||||
|
return &HostFunctionError{Function: "ws_send", Cause: ErrWSNotAvailable}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.wsManager.Send(clientID, data); err != nil {
|
||||||
|
return &HostFunctionError{Function: "ws_send", Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WSBroadcast sends data to all WebSocket clients subscribed to a topic.
|
||||||
|
func (h *HostFunctions) WSBroadcast(ctx context.Context, topic string, data []byte) error {
|
||||||
|
if h.wsManager == nil {
|
||||||
|
return &HostFunctionError{Function: "ws_broadcast", Cause: ErrWSNotAvailable}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.wsManager.Broadcast(topic, data); err != nil {
|
||||||
|
return &HostFunctionError{Function: "ws_broadcast", Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// HTTP Operations
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// HTTPFetch makes an outbound HTTP request.
|
||||||
|
func (h *HostFunctions) HTTPFetch(ctx context.Context, method, url string, headers map[string]string, body []byte) ([]byte, error) {
|
||||||
|
var bodyReader io.Reader
|
||||||
|
if len(body) > 0 {
|
||||||
|
bodyReader = bytes.NewReader(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("http_fetch request creation error", zap.Error(err), zap.String("url", url))
|
||||||
|
errorResp := map[string]interface{}{
|
||||||
|
"error": "failed to create request: " + err.Error(),
|
||||||
|
"status": 0,
|
||||||
|
}
|
||||||
|
return json.Marshal(errorResp)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range headers {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("http_fetch transport error", zap.Error(err), zap.String("url", url))
|
||||||
|
errorResp := map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
"status": 0, // Transport error
|
||||||
|
}
|
||||||
|
return json.Marshal(errorResp)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("http_fetch response read error", zap.Error(err), zap.String("url", url))
|
||||||
|
errorResp := map[string]interface{}{
|
||||||
|
"error": "failed to read response: " + err.Error(),
|
||||||
|
"status": resp.StatusCode,
|
||||||
|
}
|
||||||
|
return json.Marshal(errorResp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode response with status code
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"status": resp.StatusCode,
|
||||||
|
"headers": resp.Header,
|
||||||
|
"body": string(respBody),
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &HostFunctionError{Function: "http_fetch", Cause: fmt.Errorf("failed to marshal response: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Context Operations
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// GetEnv retrieves an environment variable for the function.
|
||||||
|
func (h *HostFunctions) GetEnv(ctx context.Context, key string) (string, error) {
|
||||||
|
h.invCtxLock.RLock()
|
||||||
|
defer h.invCtxLock.RUnlock()
|
||||||
|
|
||||||
|
if h.invCtx == nil || h.invCtx.EnvVars == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return h.invCtx.EnvVars[key], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSecret retrieves a decrypted secret.
|
||||||
|
func (h *HostFunctions) GetSecret(ctx context.Context, name string) (string, error) {
|
||||||
|
if h.secrets == nil {
|
||||||
|
return "", &HostFunctionError{Function: "get_secret", Cause: fmt.Errorf("secrets manager not available")}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.invCtxLock.RLock()
|
||||||
|
namespace := ""
|
||||||
|
if h.invCtx != nil {
|
||||||
|
namespace = h.invCtx.Namespace
|
||||||
|
}
|
||||||
|
h.invCtxLock.RUnlock()
|
||||||
|
|
||||||
|
value, err := h.secrets.Get(ctx, namespace, name)
|
||||||
|
if err != nil {
|
||||||
|
return "", &HostFunctionError{Function: "get_secret", Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRequestID returns the current request ID.
|
||||||
|
func (h *HostFunctions) GetRequestID(ctx context.Context) string {
|
||||||
|
h.invCtxLock.RLock()
|
||||||
|
defer h.invCtxLock.RUnlock()
|
||||||
|
|
||||||
|
if h.invCtx == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return h.invCtx.RequestID
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCallerWallet returns the wallet address of the caller.
|
||||||
|
func (h *HostFunctions) GetCallerWallet(ctx context.Context) string {
|
||||||
|
h.invCtxLock.RLock()
|
||||||
|
defer h.invCtxLock.RUnlock()
|
||||||
|
|
||||||
|
if h.invCtx == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return h.invCtx.CallerWallet
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Job Operations
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// EnqueueBackground queues a function for background execution.
|
||||||
|
func (h *HostFunctions) EnqueueBackground(ctx context.Context, functionName string, payload []byte) (string, error) {
|
||||||
|
// This will be implemented when JobManager is integrated
|
||||||
|
// For now, return an error indicating it's not yet available
|
||||||
|
return "", &HostFunctionError{Function: "enqueue_background", Cause: fmt.Errorf("background jobs not yet implemented")}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScheduleOnce schedules a function to run once at a specific time.
|
||||||
|
func (h *HostFunctions) ScheduleOnce(ctx context.Context, functionName string, runAt time.Time, payload []byte) (string, error) {
|
||||||
|
// This will be implemented when Scheduler is integrated
|
||||||
|
return "", &HostFunctionError{Function: "schedule_once", Cause: fmt.Errorf("timers not yet implemented")}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Logging Operations
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// LogInfo logs an info message.
|
||||||
|
func (h *HostFunctions) LogInfo(ctx context.Context, message string) {
|
||||||
|
h.logsLock.Lock()
|
||||||
|
defer h.logsLock.Unlock()
|
||||||
|
|
||||||
|
h.logs = append(h.logs, LogEntry{
|
||||||
|
Level: "info",
|
||||||
|
Message: message,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
h.logger.Info(message,
|
||||||
|
zap.String("request_id", h.GetRequestID(ctx)),
|
||||||
|
zap.String("level", "function"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogError logs an error message.
|
||||||
|
func (h *HostFunctions) LogError(ctx context.Context, message string) {
|
||||||
|
h.logsLock.Lock()
|
||||||
|
defer h.logsLock.Unlock()
|
||||||
|
|
||||||
|
h.logs = append(h.logs, LogEntry{
|
||||||
|
Level: "error",
|
||||||
|
Message: message,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
h.logger.Error(message,
|
||||||
|
zap.String("request_id", h.GetRequestID(ctx)),
|
||||||
|
zap.String("level", "function"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Secrets Manager Implementation (built-in)
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// DBSecretsManager implements SecretsManager using the database.
|
||||||
|
type DBSecretsManager struct {
|
||||||
|
db rqlite.Client
|
||||||
|
encryptionKey []byte // 32-byte AES-256 key
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure DBSecretsManager implements SecretsManager.
|
||||||
|
var _ SecretsManager = (*DBSecretsManager)(nil)
|
||||||
|
|
||||||
|
// NewDBSecretsManager creates a secrets manager backed by the database.
|
||||||
|
func NewDBSecretsManager(db rqlite.Client, encryptionKeyHex string, logger *zap.Logger) (*DBSecretsManager, error) {
|
||||||
|
var key []byte
|
||||||
|
if encryptionKeyHex != "" {
|
||||||
|
var err error
|
||||||
|
key, err = hex.DecodeString(encryptionKeyHex)
|
||||||
|
if err != nil || len(key) != 32 {
|
||||||
|
return nil, fmt.Errorf("invalid encryption key: must be 32 bytes hex-encoded")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Generate a random key if none provided
|
||||||
|
key = make([]byte, 32)
|
||||||
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate encryption key: %w", err)
|
||||||
|
}
|
||||||
|
logger.Warn("Generated random secrets encryption key - secrets will not persist across restarts")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &DBSecretsManager{
|
||||||
|
db: db,
|
||||||
|
encryptionKey: key,
|
||||||
|
logger: logger,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set stores an encrypted secret.
|
||||||
|
func (s *DBSecretsManager) Set(ctx context.Context, namespace, name, value string) error {
|
||||||
|
encrypted, err := s.encrypt([]byte(value))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encrypt secret: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upsert the secret
|
||||||
|
query := `
|
||||||
|
INSERT INTO function_secrets (id, namespace, name, encrypted_value, created_at, updated_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(namespace, name) DO UPDATE SET
|
||||||
|
encrypted_value = excluded.encrypted_value,
|
||||||
|
updated_at = excluded.updated_at
|
||||||
|
`
|
||||||
|
|
||||||
|
id := fmt.Sprintf("%s:%s", namespace, name)
|
||||||
|
now := time.Now()
|
||||||
|
if _, err := s.db.Exec(ctx, query, id, namespace, name, encrypted, now, now); err != nil {
|
||||||
|
return fmt.Errorf("failed to save secret: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a decrypted secret.
|
||||||
|
func (s *DBSecretsManager) Get(ctx context.Context, namespace, name string) (string, error) {
|
||||||
|
query := `SELECT encrypted_value FROM function_secrets WHERE namespace = ? AND name = ?`
|
||||||
|
|
||||||
|
var rows []struct {
|
||||||
|
EncryptedValue []byte `db:"encrypted_value"`
|
||||||
|
}
|
||||||
|
if err := s.db.Query(ctx, &rows, query, namespace, name); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to query secret: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rows) == 0 {
|
||||||
|
return "", ErrSecretNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
decrypted, err := s.decrypt(rows[0].EncryptedValue)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to decrypt secret: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(decrypted), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns all secret names for a namespace.
|
||||||
|
func (s *DBSecretsManager) List(ctx context.Context, namespace string) ([]string, error) {
|
||||||
|
query := `SELECT name FROM function_secrets WHERE namespace = ? ORDER BY name`
|
||||||
|
|
||||||
|
var rows []struct {
|
||||||
|
Name string `db:"name"`
|
||||||
|
}
|
||||||
|
if err := s.db.Query(ctx, &rows, query, namespace); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list secrets: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
names := make([]string, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
names[i] = row.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
return names, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a secret.
|
||||||
|
func (s *DBSecretsManager) Delete(ctx context.Context, namespace, name string) error {
|
||||||
|
query := `DELETE FROM function_secrets WHERE namespace = ? AND name = ?`
|
||||||
|
|
||||||
|
result, err := s.db.Exec(ctx, query, namespace, name)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete secret: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
affected, _ := result.RowsAffected()
|
||||||
|
if affected == 0 {
|
||||||
|
return ErrSecretNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encrypt encrypts data using AES-256-GCM.
|
||||||
|
func (s *DBSecretsManager) encrypt(plaintext []byte) ([]byte, error) {
|
||||||
|
block, err := aes.NewCipher(s.encryptionKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce := make([]byte, gcm.NonceSize())
|
||||||
|
if _, err := rand.Read(nonce); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return gcm.Seal(nonce, nonce, plaintext, nil), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decrypt decrypts data using AES-256-GCM.
|
||||||
|
func (s *DBSecretsManager) decrypt(ciphertext []byte) ([]byte, error) {
|
||||||
|
block, err := aes.NewCipher(s.encryptionKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonceSize := gcm.NonceSize()
|
||||||
|
if len(ciphertext) < nonceSize {
|
||||||
|
return nil, fmt.Errorf("ciphertext too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||||
|
return gcm.Open(nil, nonce, ciphertext, nil)
|
||||||
|
}
|
||||||
45
pkg/serverless/hostfuncs_test.go
Normal file
45
pkg/serverless/hostfuncs_test.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package serverless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHostFunctions_Cache(t *testing.T) {
|
||||||
|
db := NewMockRQLite()
|
||||||
|
ipfs := NewMockIPFSClient()
|
||||||
|
logger := zap.NewNop()
|
||||||
|
|
||||||
|
// MockOlricClient needs to implement olriclib.Client
|
||||||
|
// For now, let's just test other host functions if Olric is hard to mock
|
||||||
|
|
||||||
|
h := NewHostFunctions(db, nil, ipfs, nil, nil, nil, HostFunctionsConfig{}, logger)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
h.SetInvocationContext(&InvocationContext{
|
||||||
|
RequestID: "req-1",
|
||||||
|
Namespace: "ns-1",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test Logging
|
||||||
|
h.LogInfo(ctx, "hello world")
|
||||||
|
logs := h.GetLogs()
|
||||||
|
if len(logs) != 1 || logs[0].Message != "hello world" {
|
||||||
|
t.Errorf("unexpected logs: %+v", logs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Storage
|
||||||
|
cid, err := h.StoragePut(ctx, []byte("data"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StoragePut failed: %v", err)
|
||||||
|
}
|
||||||
|
data, err := h.StorageGet(ctx, cid)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StorageGet failed: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "data" {
|
||||||
|
t.Errorf("expected 'data', got %q", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
452
pkg/serverless/invoke.go
Normal file
452
pkg/serverless/invoke.go
Normal file
@ -0,0 +1,452 @@
|
|||||||
|
package serverless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Invoker handles function invocation with retry logic and DLQ support.
|
||||||
|
// It wraps the Engine to provide higher-level invocation semantics.
|
||||||
|
type Invoker struct {
|
||||||
|
engine *Engine
|
||||||
|
registry FunctionRegistry
|
||||||
|
hostServices HostServices
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInvoker creates a new function invoker.
|
||||||
|
func NewInvoker(engine *Engine, registry FunctionRegistry, hostServices HostServices, logger *zap.Logger) *Invoker {
|
||||||
|
return &Invoker{
|
||||||
|
engine: engine,
|
||||||
|
registry: registry,
|
||||||
|
hostServices: hostServices,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvokeRequest contains the parameters for invoking a function.
|
||||||
|
type InvokeRequest struct {
|
||||||
|
Namespace string `json:"namespace"`
|
||||||
|
FunctionName string `json:"function_name"`
|
||||||
|
Version int `json:"version,omitempty"` // 0 = latest
|
||||||
|
Input []byte `json:"input"`
|
||||||
|
TriggerType TriggerType `json:"trigger_type"`
|
||||||
|
CallerWallet string `json:"caller_wallet,omitempty"`
|
||||||
|
WSClientID string `json:"ws_client_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvokeResponse contains the result of a function invocation.
|
||||||
|
type InvokeResponse struct {
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
Output []byte `json:"output,omitempty"`
|
||||||
|
Status InvocationStatus `json:"status"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
DurationMS int64 `json:"duration_ms"`
|
||||||
|
Retries int `json:"retries,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invoke executes a function with automatic retry logic.
|
||||||
|
func (i *Invoker) Invoke(ctx context.Context, req *InvokeRequest) (*InvokeResponse, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, &ValidationError{Field: "request", Message: "cannot be nil"}
|
||||||
|
}
|
||||||
|
if req.FunctionName == "" {
|
||||||
|
return nil, &ValidationError{Field: "function_name", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
if req.Namespace == "" {
|
||||||
|
return nil, &ValidationError{Field: "namespace", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := uuid.New().String()
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
// Get function from registry
|
||||||
|
fn, err := i.registry.Get(ctx, req.Namespace, req.FunctionName, req.Version)
|
||||||
|
if err != nil {
|
||||||
|
return &InvokeResponse{
|
||||||
|
RequestID: requestID,
|
||||||
|
Status: InvocationStatusError,
|
||||||
|
Error: err.Error(),
|
||||||
|
DurationMS: time.Since(startTime).Milliseconds(),
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check authorization
|
||||||
|
authorized, err := i.CanInvoke(ctx, req.Namespace, req.FunctionName, req.CallerWallet)
|
||||||
|
if err != nil || !authorized {
|
||||||
|
return &InvokeResponse{
|
||||||
|
RequestID: requestID,
|
||||||
|
Status: InvocationStatusError,
|
||||||
|
Error: "unauthorized",
|
||||||
|
DurationMS: time.Since(startTime).Milliseconds(),
|
||||||
|
}, ErrUnauthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get environment variables
|
||||||
|
envVars, err := i.getEnvVars(ctx, fn.ID)
|
||||||
|
if err != nil {
|
||||||
|
i.logger.Warn("Failed to get env vars", zap.Error(err))
|
||||||
|
envVars = make(map[string]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build invocation context
|
||||||
|
invCtx := &InvocationContext{
|
||||||
|
RequestID: requestID,
|
||||||
|
FunctionID: fn.ID,
|
||||||
|
FunctionName: fn.Name,
|
||||||
|
Namespace: fn.Namespace,
|
||||||
|
CallerWallet: req.CallerWallet,
|
||||||
|
TriggerType: req.TriggerType,
|
||||||
|
WSClientID: req.WSClientID,
|
||||||
|
EnvVars: envVars,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute with retry logic
|
||||||
|
output, retries, err := i.executeWithRetry(ctx, fn, req.Input, invCtx)
|
||||||
|
|
||||||
|
response := &InvokeResponse{
|
||||||
|
RequestID: requestID,
|
||||||
|
Output: output,
|
||||||
|
DurationMS: time.Since(startTime).Milliseconds(),
|
||||||
|
Retries: retries,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
response.Status = InvocationStatusError
|
||||||
|
response.Error = err.Error()
|
||||||
|
|
||||||
|
// Check if it's a timeout
|
||||||
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
|
response.Status = InvocationStatusTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, err
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Status = InvocationStatusSuccess
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvokeByID invokes a function by its ID.
|
||||||
|
func (i *Invoker) InvokeByID(ctx context.Context, functionID string, input []byte, invCtx *InvocationContext) (*InvokeResponse, error) {
|
||||||
|
// Get function from registry by ID
|
||||||
|
fn, err := i.getByID(ctx, functionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if invCtx == nil {
|
||||||
|
invCtx = &InvocationContext{
|
||||||
|
RequestID: uuid.New().String(),
|
||||||
|
FunctionID: fn.ID,
|
||||||
|
FunctionName: fn.Name,
|
||||||
|
Namespace: fn.Namespace,
|
||||||
|
TriggerType: TriggerTypeHTTP,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
output, retries, err := i.executeWithRetry(ctx, fn, input, invCtx)
|
||||||
|
|
||||||
|
response := &InvokeResponse{
|
||||||
|
RequestID: invCtx.RequestID,
|
||||||
|
Output: output,
|
||||||
|
DurationMS: time.Since(startTime).Milliseconds(),
|
||||||
|
Retries: retries,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
response.Status = InvocationStatusError
|
||||||
|
response.Error = err.Error()
|
||||||
|
return response, err
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Status = InvocationStatusSuccess
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateCache removes a compiled module from the engine's cache.
|
||||||
|
func (i *Invoker) InvalidateCache(wasmCID string) {
|
||||||
|
i.engine.Invalidate(wasmCID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeWithRetry executes a function with retry logic and DLQ.
|
||||||
|
func (i *Invoker) executeWithRetry(ctx context.Context, fn *Function, input []byte, invCtx *InvocationContext) ([]byte, int, error) {
|
||||||
|
var lastErr error
|
||||||
|
var output []byte
|
||||||
|
|
||||||
|
maxAttempts := fn.RetryCount + 1 // Initial attempt + retries
|
||||||
|
if maxAttempts < 1 {
|
||||||
|
maxAttempts = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||||
|
// Check if context is cancelled
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, attempt, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the function
|
||||||
|
output, lastErr = i.engine.Execute(ctx, fn, input, invCtx)
|
||||||
|
if lastErr == nil {
|
||||||
|
return output, attempt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
i.logger.Warn("Function execution failed",
|
||||||
|
zap.String("function", fn.Name),
|
||||||
|
zap.String("request_id", invCtx.RequestID),
|
||||||
|
zap.Int("attempt", attempt+1),
|
||||||
|
zap.Int("max_attempts", maxAttempts),
|
||||||
|
zap.Error(lastErr),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Don't retry on certain errors
|
||||||
|
if !i.isRetryable(lastErr) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't wait after the last attempt
|
||||||
|
if attempt < maxAttempts-1 {
|
||||||
|
delay := i.calculateBackoff(fn.RetryDelaySeconds, attempt)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, attempt + 1, ctx.Err()
|
||||||
|
case <-time.After(delay):
|
||||||
|
// Continue to next attempt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All retries exhausted - send to DLQ if configured
|
||||||
|
if fn.DLQTopic != "" {
|
||||||
|
i.sendToDLQ(ctx, fn, input, invCtx, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, maxAttempts - 1, lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// isRetryable determines if an error should trigger a retry.
|
||||||
|
func (i *Invoker) isRetryable(err error) bool {
|
||||||
|
// Don't retry validation errors or not-found errors
|
||||||
|
if IsNotFound(err) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't retry resource exhaustion (rate limits, memory)
|
||||||
|
if IsResourceExhausted(err) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retry service unavailable errors
|
||||||
|
if IsServiceUnavailable(err) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retry execution errors (could be transient)
|
||||||
|
var execErr *ExecutionError
|
||||||
|
if ok := errorAs(err, &execErr); ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to retryable for unknown errors
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateBackoff calculates the delay before the next retry attempt.
|
||||||
|
// Uses exponential backoff with jitter.
|
||||||
|
func (i *Invoker) calculateBackoff(baseDelaySeconds, attempt int) time.Duration {
|
||||||
|
if baseDelaySeconds <= 0 {
|
||||||
|
baseDelaySeconds = 5
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exponential backoff: delay * 2^attempt
|
||||||
|
delay := time.Duration(baseDelaySeconds) * time.Second
|
||||||
|
for j := 0; j < attempt; j++ {
|
||||||
|
delay *= 2
|
||||||
|
if delay > 5*time.Minute {
|
||||||
|
delay = 5 * time.Minute
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return delay
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendToDLQ sends a failed invocation to the dead letter queue.
|
||||||
|
func (i *Invoker) sendToDLQ(ctx context.Context, fn *Function, input []byte, invCtx *InvocationContext, err error) {
|
||||||
|
dlqMessage := DLQMessage{
|
||||||
|
FunctionID: fn.ID,
|
||||||
|
FunctionName: fn.Name,
|
||||||
|
Namespace: fn.Namespace,
|
||||||
|
RequestID: invCtx.RequestID,
|
||||||
|
Input: input,
|
||||||
|
Error: err.Error(),
|
||||||
|
FailedAt: time.Now(),
|
||||||
|
TriggerType: invCtx.TriggerType,
|
||||||
|
CallerWallet: invCtx.CallerWallet,
|
||||||
|
}
|
||||||
|
|
||||||
|
data, marshalErr := json.Marshal(dlqMessage)
|
||||||
|
if marshalErr != nil {
|
||||||
|
i.logger.Error("Failed to marshal DLQ message",
|
||||||
|
zap.Error(marshalErr),
|
||||||
|
zap.String("function", fn.Name),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish to DLQ topic via host services
|
||||||
|
if err := i.hostServices.PubSubPublish(ctx, fn.DLQTopic, data); err != nil {
|
||||||
|
i.logger.Error("Failed to send to DLQ",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("function", fn.Name),
|
||||||
|
zap.String("dlq_topic", fn.DLQTopic),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
i.logger.Info("Sent failed invocation to DLQ",
|
||||||
|
zap.String("function", fn.Name),
|
||||||
|
zap.String("dlq_topic", fn.DLQTopic),
|
||||||
|
zap.String("request_id", invCtx.RequestID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEnvVars retrieves environment variables for a function.
|
||||||
|
func (i *Invoker) getEnvVars(ctx context.Context, functionID string) (map[string]string, error) {
|
||||||
|
// Type assert to get extended registry methods
|
||||||
|
if reg, ok := i.registry.(*Registry); ok {
|
||||||
|
return reg.GetEnvVars(ctx, functionID)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getByID retrieves a function by ID.
|
||||||
|
func (i *Invoker) getByID(ctx context.Context, functionID string) (*Function, error) {
|
||||||
|
// Type assert to get extended registry methods
|
||||||
|
if reg, ok := i.registry.(*Registry); ok {
|
||||||
|
return reg.GetByID(ctx, functionID)
|
||||||
|
}
|
||||||
|
return nil, ErrFunctionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// DLQMessage represents a message sent to the dead letter queue.
|
||||||
|
type DLQMessage struct {
|
||||||
|
FunctionID string `json:"function_id"`
|
||||||
|
FunctionName string `json:"function_name"`
|
||||||
|
Namespace string `json:"namespace"`
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
Input []byte `json:"input"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
FailedAt time.Time `json:"failed_at"`
|
||||||
|
TriggerType TriggerType `json:"trigger_type"`
|
||||||
|
CallerWallet string `json:"caller_wallet,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorAs is a helper to avoid import of errors package.
|
||||||
|
func errorAs(err error, target interface{}) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Simple type assertion for our custom error types
|
||||||
|
switch t := target.(type) {
|
||||||
|
case **ExecutionError:
|
||||||
|
if e, ok := err.(*ExecutionError); ok {
|
||||||
|
*t = e
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Batch Invocation (for future use)
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// BatchInvokeRequest contains parameters for batch invocation.
|
||||||
|
type BatchInvokeRequest struct {
|
||||||
|
Requests []*InvokeRequest `json:"requests"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchInvokeResponse contains results of batch invocation.
|
||||||
|
type BatchInvokeResponse struct {
|
||||||
|
Responses []*InvokeResponse `json:"responses"`
|
||||||
|
Duration time.Duration `json:"duration"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchInvoke executes multiple functions in parallel.
|
||||||
|
func (i *Invoker) BatchInvoke(ctx context.Context, req *BatchInvokeRequest) (*BatchInvokeResponse, error) {
|
||||||
|
if req == nil || len(req.Requests) == 0 {
|
||||||
|
return nil, &ValidationError{Field: "requests", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
responses := make([]*InvokeResponse, len(req.Requests))
|
||||||
|
|
||||||
|
// For simplicity, execute sequentially for now
|
||||||
|
// TODO: Implement parallel execution with goroutines and semaphore
|
||||||
|
for idx, invReq := range req.Requests {
|
||||||
|
resp, err := i.Invoke(ctx, invReq)
|
||||||
|
if err != nil && resp == nil {
|
||||||
|
responses[idx] = &InvokeResponse{
|
||||||
|
RequestID: uuid.New().String(),
|
||||||
|
Status: InvocationStatusError,
|
||||||
|
Error: err.Error(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
responses[idx] = resp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &BatchInvokeResponse{
|
||||||
|
Responses: responses,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Public Invocation Helpers
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// CanInvoke checks if a caller is authorized to invoke a function.
|
||||||
|
func (i *Invoker) CanInvoke(ctx context.Context, namespace, functionName string, callerWallet string) (bool, error) {
|
||||||
|
fn, err := i.registry.Get(ctx, namespace, functionName, 0)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Public functions can be invoked by anyone
|
||||||
|
if fn.IsPublic {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-public functions require the caller to be in the same namespace
|
||||||
|
// (simplified authorization - can be extended)
|
||||||
|
if callerWallet == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// For now, just check if caller wallet matches namespace
|
||||||
|
// In production, you'd check group membership, roles, etc.
|
||||||
|
return callerWallet == namespace || fn.CreatedBy == callerWallet, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFunctionInfo returns basic info about a function for invocation.
|
||||||
|
func (i *Invoker) GetFunctionInfo(ctx context.Context, namespace, functionName string, version int) (*Function, error) {
|
||||||
|
return i.registry.Get(ctx, namespace, functionName, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateInput performs basic input validation.
|
||||||
|
func (i *Invoker) ValidateInput(input []byte, maxSize int) error {
|
||||||
|
if maxSize > 0 && len(input) > maxSize {
|
||||||
|
return &ValidationError{
|
||||||
|
Field: "input",
|
||||||
|
Message: fmt.Sprintf("exceeds maximum size of %d bytes", maxSize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
434
pkg/serverless/mocks_test.go
Normal file
434
pkg/serverless/mocks_test.go
Normal file
@ -0,0 +1,434 @@
|
|||||||
|
package serverless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockRegistry is a mock implementation of FunctionRegistry
|
||||||
|
type MockRegistry struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
functions map[string]*Function
|
||||||
|
wasm map[string][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockRegistry() *MockRegistry {
|
||||||
|
return &MockRegistry{
|
||||||
|
functions: make(map[string]*Function),
|
||||||
|
wasm: make(map[string][]byte),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRegistry) Register(ctx context.Context, fn *FunctionDefinition, wasmBytes []byte) (*Function, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
id := fn.Namespace + "/" + fn.Name
|
||||||
|
wasmCID := "cid-" + id
|
||||||
|
oldFn := m.functions[id]
|
||||||
|
m.functions[id] = &Function{
|
||||||
|
ID: id,
|
||||||
|
Name: fn.Name,
|
||||||
|
Namespace: fn.Namespace,
|
||||||
|
WASMCID: wasmCID,
|
||||||
|
MemoryLimitMB: fn.MemoryLimitMB,
|
||||||
|
TimeoutSeconds: fn.TimeoutSeconds,
|
||||||
|
IsPublic: fn.IsPublic,
|
||||||
|
RetryCount: fn.RetryCount,
|
||||||
|
RetryDelaySeconds: fn.RetryDelaySeconds,
|
||||||
|
Status: FunctionStatusActive,
|
||||||
|
}
|
||||||
|
m.wasm[wasmCID] = wasmBytes
|
||||||
|
return oldFn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRegistry) Get(ctx context.Context, namespace, name string, version int) (*Function, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
fn, ok := m.functions[namespace+"/"+name]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrFunctionNotFound
|
||||||
|
}
|
||||||
|
return fn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRegistry) List(ctx context.Context, namespace string) ([]*Function, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
var res []*Function
|
||||||
|
for _, fn := range m.functions {
|
||||||
|
if fn.Namespace == namespace {
|
||||||
|
res = append(res, fn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRegistry) Delete(ctx context.Context, namespace, name string, version int) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
delete(m.functions, namespace+"/"+name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRegistry) GetWASMBytes(ctx context.Context, wasmCID string) ([]byte, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
data, ok := m.wasm[wasmCID]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrFunctionNotFound
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRegistry) GetLogs(ctx context.Context, namespace, name string, limit int) ([]LogEntry, error) {
|
||||||
|
return []LogEntry{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockHostServices is a mock implementation of HostServices
|
||||||
|
type MockHostServices struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
cache map[string][]byte
|
||||||
|
storage map[string][]byte
|
||||||
|
logs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockHostServices() *MockHostServices {
|
||||||
|
return &MockHostServices{
|
||||||
|
cache: make(map[string][]byte),
|
||||||
|
storage: make(map[string][]byte),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) DBQuery(ctx context.Context, query string, args []interface{}) ([]byte, error) {
|
||||||
|
return []byte("[]"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) DBExecute(ctx context.Context, query string, args []interface{}) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) CacheGet(ctx context.Context, key string) ([]byte, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
return m.cache[key], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) CacheSet(ctx context.Context, key string, value []byte, ttl int64) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.cache[key] = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) CacheDelete(ctx context.Context, key string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
delete(m.cache, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) CacheIncr(ctx context.Context, key string) (int64, error) {
|
||||||
|
return m.CacheIncrBy(ctx, key, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) CacheIncrBy(ctx context.Context, key string, delta int64) (int64, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
var currentValue int64
|
||||||
|
if val, ok := m.cache[key]; ok {
|
||||||
|
var err error
|
||||||
|
currentValue, err = parseInt64FromBytes(val)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("value is not numeric")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newValue := currentValue + delta
|
||||||
|
m.cache[key] = []byte(fmt.Sprintf("%d", newValue))
|
||||||
|
return newValue, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) StoragePut(ctx context.Context, data []byte) (string, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
cid := "cid-" + time.Now().String()
|
||||||
|
m.storage[cid] = data
|
||||||
|
return cid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) StorageGet(ctx context.Context, cid string) ([]byte, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
return m.storage[cid], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) PubSubPublish(ctx context.Context, topic string, data []byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) WSSend(ctx context.Context, clientID string, data []byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) WSBroadcast(ctx context.Context, topic string, data []byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) HTTPFetch(ctx context.Context, method, url string, headers map[string]string, body []byte) ([]byte, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) GetEnv(ctx context.Context, key string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) GetSecret(ctx context.Context, name string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) GetRequestID(ctx context.Context) string {
|
||||||
|
return "req-123"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) GetCallerWallet(ctx context.Context) string {
|
||||||
|
return "wallet-123"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) EnqueueBackground(ctx context.Context, functionName string, payload []byte) (string, error) {
|
||||||
|
return "job-123", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) ScheduleOnce(ctx context.Context, functionName string, runAt time.Time, payload []byte) (string, error) {
|
||||||
|
return "timer-123", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) LogInfo(ctx context.Context, message string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.logs = append(m.logs, "INFO: "+message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHostServices) LogError(ctx context.Context, message string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.logs = append(m.logs, "ERROR: "+message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockIPFSClient is a mock for ipfs.IPFSClient
|
||||||
|
type MockIPFSClient struct {
|
||||||
|
data map[string][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockIPFSClient() *MockIPFSClient {
|
||||||
|
return &MockIPFSClient{data: make(map[string][]byte)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPFSClient) Add(ctx context.Context, reader io.Reader, filename string) (*ipfs.AddResponse, error) {
|
||||||
|
data, _ := io.ReadAll(reader)
|
||||||
|
cid := "cid-" + filename
|
||||||
|
m.data[cid] = data
|
||||||
|
return &ipfs.AddResponse{Cid: cid, Name: filename}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPFSClient) Pin(ctx context.Context, cid string, name string, replicationFactor int) (*ipfs.PinResponse, error) {
|
||||||
|
return &ipfs.PinResponse{Cid: cid, Name: name}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPFSClient) PinStatus(ctx context.Context, cid string) (*ipfs.PinStatus, error) {
|
||||||
|
return &ipfs.PinStatus{Cid: cid, Status: "pinned"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPFSClient) Get(ctx context.Context, cid, apiURL string) (io.ReadCloser, error) {
|
||||||
|
data, ok := m.data[cid]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("not found")
|
||||||
|
}
|
||||||
|
return io.NopCloser(strings.NewReader(string(data))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPFSClient) Unpin(ctx context.Context, cid string) error { return nil }
|
||||||
|
func (m *MockIPFSClient) Health(ctx context.Context) error { return nil }
|
||||||
|
func (m *MockIPFSClient) GetPeerCount(ctx context.Context) (int, error) { return 1, nil }
|
||||||
|
func (m *MockIPFSClient) Close(ctx context.Context) error { return nil }
|
||||||
|
|
||||||
|
// MockRQLite is a mock implementation of rqlite.Client
|
||||||
|
type MockRQLite struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
tables map[string][]map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockRQLite() *MockRQLite {
|
||||||
|
return &MockRQLite{
|
||||||
|
tables: make(map[string][]map[string]any),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRQLite) Query(ctx context.Context, dest any, query string, args ...any) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Very limited mock query logic for scanning into structs
|
||||||
|
if strings.Contains(query, "FROM functions") {
|
||||||
|
rows := m.tables["functions"]
|
||||||
|
filtered := rows
|
||||||
|
if strings.Contains(query, "namespace = ? AND name = ?") {
|
||||||
|
ns := args[0].(string)
|
||||||
|
name := args[1].(string)
|
||||||
|
filtered = nil
|
||||||
|
for _, r := range rows {
|
||||||
|
if r["namespace"] == ns && r["name"] == name {
|
||||||
|
filtered = append(filtered, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
destVal := reflect.ValueOf(dest).Elem()
|
||||||
|
if destVal.Kind() == reflect.Slice {
|
||||||
|
elemType := destVal.Type().Elem()
|
||||||
|
for _, r := range filtered {
|
||||||
|
newElem := reflect.New(elemType).Elem()
|
||||||
|
// This is a simplified mapping
|
||||||
|
if f := newElem.FieldByName("ID"); f.IsValid() {
|
||||||
|
f.SetString(r["id"].(string))
|
||||||
|
}
|
||||||
|
if f := newElem.FieldByName("Name"); f.IsValid() {
|
||||||
|
f.SetString(r["name"].(string))
|
||||||
|
}
|
||||||
|
if f := newElem.FieldByName("Namespace"); f.IsValid() {
|
||||||
|
f.SetString(r["namespace"].(string))
|
||||||
|
}
|
||||||
|
destVal.Set(reflect.Append(destVal, newElem))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRQLite) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return &mockResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRQLite) FindBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...rqlite.FindOption) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *MockRQLite) FindOneBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...rqlite.FindOption) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *MockRQLite) Save(ctx context.Context, entity any) error { return nil }
|
||||||
|
func (m *MockRQLite) Remove(ctx context.Context, entity any) error { return nil }
|
||||||
|
func (m *MockRQLite) Repository(table string) any { return nil }
|
||||||
|
|
||||||
|
func (m *MockRQLite) CreateQueryBuilder(table string) *rqlite.QueryBuilder {
|
||||||
|
return nil // Should return a valid QueryBuilder if needed by tests
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRQLite) Tx(ctx context.Context, fn func(tx rqlite.Tx) error) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockResult struct{}
|
||||||
|
|
||||||
|
func (m *mockResult) LastInsertId() (int64, error) { return 1, nil }
|
||||||
|
func (m *mockResult) RowsAffected() (int64, error) { return 1, nil }
|
||||||
|
|
||||||
|
// MockOlricClient is a mock for olriclib.Client
|
||||||
|
type MockOlricClient struct {
|
||||||
|
dmaps map[string]*MockDMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockOlricClient() *MockOlricClient {
|
||||||
|
return &MockOlricClient{dmaps: make(map[string]*MockDMap)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockOlricClient) NewDMap(name string) (any, error) {
|
||||||
|
if dm, ok := m.dmaps[name]; ok {
|
||||||
|
return dm, nil
|
||||||
|
}
|
||||||
|
dm := &MockDMap{data: make(map[string][]byte)}
|
||||||
|
m.dmaps[name] = dm
|
||||||
|
return dm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockOlricClient) Close(ctx context.Context) error { return nil }
|
||||||
|
func (m *MockOlricClient) Stats(ctx context.Context, s string) ([]byte, error) { return nil, nil }
|
||||||
|
func (m *MockOlricClient) Ping(ctx context.Context, s string) error { return nil }
|
||||||
|
func (m *MockOlricClient) RoutingTable(ctx context.Context) (map[uint64][]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockDMap is a mock for olriclib.DMap
|
||||||
|
type MockDMap struct {
|
||||||
|
data map[string][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDMap) Get(ctx context.Context, key string) (any, error) {
|
||||||
|
val, ok := m.data[key]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("not found")
|
||||||
|
}
|
||||||
|
return &MockGetResponse{val: val}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDMap) Put(ctx context.Context, key string, value any) error {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case []byte:
|
||||||
|
m.data[key] = v
|
||||||
|
case string:
|
||||||
|
m.data[key] = []byte(v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDMap) Delete(ctx context.Context, key string) (bool, error) {
|
||||||
|
_, ok := m.data[key]
|
||||||
|
delete(m.data, key)
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDMap) Incr(ctx context.Context, key string, delta int64) (int64, error) {
|
||||||
|
var currentValue int64
|
||||||
|
|
||||||
|
// Get current value if it exists
|
||||||
|
if val, ok := m.data[key]; ok {
|
||||||
|
// Try to parse as int64
|
||||||
|
var err error
|
||||||
|
currentValue, err = parseInt64FromBytes(val)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("value is not numeric")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment
|
||||||
|
newValue := currentValue + delta
|
||||||
|
|
||||||
|
// Store the new value
|
||||||
|
m.data[key] = []byte(fmt.Sprintf("%d", newValue))
|
||||||
|
|
||||||
|
return newValue, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseInt64FromBytes parses an int64 from byte slice
|
||||||
|
func parseInt64FromBytes(data []byte) (int64, error) {
|
||||||
|
var val int64
|
||||||
|
_, err := fmt.Sscanf(string(data), "%d", &val)
|
||||||
|
return val, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockGetResponse struct {
|
||||||
|
val []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockGetResponse) Byte() ([]byte, error) { return m.val, nil }
|
||||||
|
func (m *MockGetResponse) String() (string, error) { return string(m.val), nil }
|
||||||
561
pkg/serverless/registry.go
Normal file
561
pkg/serverless/registry.go
Normal file
@ -0,0 +1,561 @@
|
|||||||
|
package serverless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ensure Registry implements FunctionRegistry and InvocationLogger interfaces.
|
||||||
|
var _ FunctionRegistry = (*Registry)(nil)
|
||||||
|
var _ InvocationLogger = (*Registry)(nil)
|
||||||
|
|
||||||
|
// Registry manages function metadata in RQLite and bytecode in IPFS.
|
||||||
|
// It implements the FunctionRegistry interface.
|
||||||
|
type Registry struct {
|
||||||
|
db rqlite.Client
|
||||||
|
ipfs ipfs.IPFSClient
|
||||||
|
ipfsAPIURL string
|
||||||
|
logger *zap.Logger
|
||||||
|
tableName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegistryConfig holds configuration for the Registry.
|
||||||
|
type RegistryConfig struct {
|
||||||
|
IPFSAPIURL string // IPFS API URL for content retrieval
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRegistry creates a new function registry.
|
||||||
|
func NewRegistry(db rqlite.Client, ipfsClient ipfs.IPFSClient, cfg RegistryConfig, logger *zap.Logger) *Registry {
|
||||||
|
return &Registry{
|
||||||
|
db: db,
|
||||||
|
ipfs: ipfsClient,
|
||||||
|
ipfsAPIURL: cfg.IPFSAPIURL,
|
||||||
|
logger: logger,
|
||||||
|
tableName: "functions",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register deploys a new function or updates an existing one.
|
||||||
|
func (r *Registry) Register(ctx context.Context, fn *FunctionDefinition, wasmBytes []byte) (*Function, error) {
|
||||||
|
if fn == nil {
|
||||||
|
return nil, &ValidationError{Field: "definition", Message: "cannot be nil"}
|
||||||
|
}
|
||||||
|
fn.Name = strings.TrimSpace(fn.Name)
|
||||||
|
fn.Namespace = strings.TrimSpace(fn.Namespace)
|
||||||
|
|
||||||
|
if fn.Name == "" {
|
||||||
|
return nil, &ValidationError{Field: "name", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
if fn.Namespace == "" {
|
||||||
|
return nil, &ValidationError{Field: "namespace", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
if len(wasmBytes) == 0 {
|
||||||
|
return nil, &ValidationError{Field: "wasmBytes", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if function already exists (regardless of status) to get old metadata for invalidation
|
||||||
|
oldFn, err := r.getByNameInternal(ctx, fn.Namespace, fn.Name)
|
||||||
|
if err != nil && err != ErrFunctionNotFound {
|
||||||
|
return nil, &DeployError{FunctionName: fn.Name, Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upload WASM to IPFS
|
||||||
|
wasmCID, err := r.uploadWASM(ctx, wasmBytes, fn.Name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &DeployError{FunctionName: fn.Name, Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply defaults
|
||||||
|
memoryLimit := fn.MemoryLimitMB
|
||||||
|
if memoryLimit == 0 {
|
||||||
|
memoryLimit = 64
|
||||||
|
}
|
||||||
|
timeout := fn.TimeoutSeconds
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = 30
|
||||||
|
}
|
||||||
|
retryDelay := fn.RetryDelaySeconds
|
||||||
|
if retryDelay == 0 {
|
||||||
|
retryDelay = 5
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
id := uuid.New().String()
|
||||||
|
version := 1
|
||||||
|
|
||||||
|
if oldFn != nil {
|
||||||
|
// Use existing ID and increment version
|
||||||
|
id = oldFn.ID
|
||||||
|
version = oldFn.Version + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use INSERT OR REPLACE to ensure we never hit UNIQUE constraint failures on (namespace, name).
|
||||||
|
// This handles both new registrations and overwriting existing (even inactive) functions.
|
||||||
|
query := `
|
||||||
|
INSERT OR REPLACE INTO functions (
|
||||||
|
id, name, namespace, version, wasm_cid,
|
||||||
|
memory_limit_mb, timeout_seconds, is_public,
|
||||||
|
retry_count, retry_delay_seconds, dlq_topic,
|
||||||
|
status, created_at, updated_at, created_by
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
_, err = r.db.Exec(ctx, query,
|
||||||
|
id, fn.Name, fn.Namespace, version, wasmCID,
|
||||||
|
memoryLimit, timeout, fn.IsPublic,
|
||||||
|
fn.RetryCount, retryDelay, fn.DLQTopic,
|
||||||
|
string(FunctionStatusActive), now, now, fn.Namespace,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &DeployError{FunctionName: fn.Name, Cause: fmt.Errorf("failed to register function: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save environment variables
|
||||||
|
if err := r.saveEnvVars(ctx, id, fn.EnvVars); err != nil {
|
||||||
|
return nil, &DeployError{FunctionName: fn.Name, Cause: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Info("Function registered",
|
||||||
|
zap.String("id", id),
|
||||||
|
zap.String("name", fn.Name),
|
||||||
|
zap.String("namespace", fn.Namespace),
|
||||||
|
zap.String("wasm_cid", wasmCID),
|
||||||
|
zap.Int("version", version),
|
||||||
|
zap.Bool("updated", oldFn != nil),
|
||||||
|
)
|
||||||
|
|
||||||
|
return oldFn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a function by name and optional version.
|
||||||
|
// If version is 0, returns the latest version.
|
||||||
|
func (r *Registry) Get(ctx context.Context, namespace, name string, version int) (*Function, error) {
|
||||||
|
namespace = strings.TrimSpace(namespace)
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
|
||||||
|
var query string
|
||||||
|
var args []interface{}
|
||||||
|
|
||||||
|
if version == 0 {
|
||||||
|
// Get latest version
|
||||||
|
query = `
|
||||||
|
SELECT id, name, namespace, version, wasm_cid, source_cid,
|
||||||
|
memory_limit_mb, timeout_seconds, is_public,
|
||||||
|
retry_count, retry_delay_seconds, dlq_topic,
|
||||||
|
status, created_at, updated_at, created_by
|
||||||
|
FROM functions
|
||||||
|
WHERE namespace = ? AND name = ? AND status = ?
|
||||||
|
ORDER BY version DESC
|
||||||
|
LIMIT 1
|
||||||
|
`
|
||||||
|
args = []interface{}{namespace, name, string(FunctionStatusActive)}
|
||||||
|
} else {
|
||||||
|
query = `
|
||||||
|
SELECT id, name, namespace, version, wasm_cid, source_cid,
|
||||||
|
memory_limit_mb, timeout_seconds, is_public,
|
||||||
|
retry_count, retry_delay_seconds, dlq_topic,
|
||||||
|
status, created_at, updated_at, created_by
|
||||||
|
FROM functions
|
||||||
|
WHERE namespace = ? AND name = ? AND version = ?
|
||||||
|
`
|
||||||
|
args = []interface{}{namespace, name, version}
|
||||||
|
}
|
||||||
|
|
||||||
|
var functions []functionRow
|
||||||
|
if err := r.db.Query(ctx, &functions, query, args...); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query function: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(functions) == 0 {
|
||||||
|
if version == 0 {
|
||||||
|
return nil, ErrFunctionNotFound
|
||||||
|
}
|
||||||
|
return nil, ErrVersionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.rowToFunction(&functions[0]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns all functions for a namespace.
|
||||||
|
func (r *Registry) List(ctx context.Context, namespace string) ([]*Function, error) {
|
||||||
|
// Get latest version of each function in the namespace
|
||||||
|
query := `
|
||||||
|
SELECT f.id, f.name, f.namespace, f.version, f.wasm_cid, f.source_cid,
|
||||||
|
f.memory_limit_mb, f.timeout_seconds, f.is_public,
|
||||||
|
f.retry_count, f.retry_delay_seconds, f.dlq_topic,
|
||||||
|
f.status, f.created_at, f.updated_at, f.created_by
|
||||||
|
FROM functions f
|
||||||
|
INNER JOIN (
|
||||||
|
SELECT namespace, name, MAX(version) as max_version
|
||||||
|
FROM functions
|
||||||
|
WHERE namespace = ? AND status = ?
|
||||||
|
GROUP BY namespace, name
|
||||||
|
) latest ON f.namespace = latest.namespace
|
||||||
|
AND f.name = latest.name
|
||||||
|
AND f.version = latest.max_version
|
||||||
|
ORDER BY f.name
|
||||||
|
`
|
||||||
|
|
||||||
|
var rows []functionRow
|
||||||
|
if err := r.db.Query(ctx, &rows, query, namespace, string(FunctionStatusActive)); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list functions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
functions := make([]*Function, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
functions[i] = r.rowToFunction(&row)
|
||||||
|
}
|
||||||
|
|
||||||
|
return functions, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a function. If version is 0, removes all versions.
|
||||||
|
func (r *Registry) Delete(ctx context.Context, namespace, name string, version int) error {
|
||||||
|
namespace = strings.TrimSpace(namespace)
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
|
||||||
|
var query string
|
||||||
|
var args []interface{}
|
||||||
|
|
||||||
|
if version == 0 {
|
||||||
|
// Mark all versions as inactive (soft delete)
|
||||||
|
query = `UPDATE functions SET status = ?, updated_at = ? WHERE namespace = ? AND name = ?`
|
||||||
|
args = []interface{}{string(FunctionStatusInactive), time.Now(), namespace, name}
|
||||||
|
} else {
|
||||||
|
query = `UPDATE functions SET status = ?, updated_at = ? WHERE namespace = ? AND name = ? AND version = ?`
|
||||||
|
args = []interface{}{string(FunctionStatusInactive), time.Now(), namespace, name, version}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := r.db.Exec(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete function: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowsAffected, _ := result.RowsAffected()
|
||||||
|
if rowsAffected == 0 {
|
||||||
|
if version == 0 {
|
||||||
|
return ErrFunctionNotFound
|
||||||
|
}
|
||||||
|
return ErrVersionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Info("Function deleted",
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.String("name", name),
|
||||||
|
zap.Int("version", version),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWASMBytes retrieves the compiled WASM bytecode for a function.
|
||||||
|
func (r *Registry) GetWASMBytes(ctx context.Context, wasmCID string) ([]byte, error) {
|
||||||
|
if wasmCID == "" {
|
||||||
|
return nil, &ValidationError{Field: "wasmCID", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, err := r.ipfs.Get(ctx, wasmCID, r.ipfsAPIURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get WASM from IPFS: %w", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read WASM data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEnvVars retrieves environment variables for a function.
|
||||||
|
func (r *Registry) GetEnvVars(ctx context.Context, functionID string) (map[string]string, error) {
|
||||||
|
query := `SELECT key, value FROM function_env_vars WHERE function_id = ?`
|
||||||
|
|
||||||
|
var rows []envVarRow
|
||||||
|
if err := r.db.Query(ctx, &rows, query, functionID); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query env vars: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
envVars := make(map[string]string, len(rows))
|
||||||
|
for _, row := range rows {
|
||||||
|
envVars[row.Key] = row.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
return envVars, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID retrieves a function by its ID.
|
||||||
|
func (r *Registry) GetByID(ctx context.Context, id string) (*Function, error) {
|
||||||
|
query := `
|
||||||
|
SELECT id, name, namespace, version, wasm_cid, source_cid,
|
||||||
|
memory_limit_mb, timeout_seconds, is_public,
|
||||||
|
retry_count, retry_delay_seconds, dlq_topic,
|
||||||
|
status, created_at, updated_at, created_by
|
||||||
|
FROM functions
|
||||||
|
WHERE id = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
var functions []functionRow
|
||||||
|
if err := r.db.Query(ctx, &functions, query, id); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query function: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(functions) == 0 {
|
||||||
|
return nil, ErrFunctionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.rowToFunction(&functions[0]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListVersions returns all versions of a function.
|
||||||
|
func (r *Registry) ListVersions(ctx context.Context, namespace, name string) ([]*Function, error) {
|
||||||
|
query := `
|
||||||
|
SELECT id, name, namespace, version, wasm_cid, source_cid,
|
||||||
|
memory_limit_mb, timeout_seconds, is_public,
|
||||||
|
retry_count, retry_delay_seconds, dlq_topic,
|
||||||
|
status, created_at, updated_at, created_by
|
||||||
|
FROM functions
|
||||||
|
WHERE namespace = ? AND name = ?
|
||||||
|
ORDER BY version DESC
|
||||||
|
`
|
||||||
|
|
||||||
|
var rows []functionRow
|
||||||
|
if err := r.db.Query(ctx, &rows, query, namespace, name); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list versions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
functions := make([]*Function, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
functions[i] = r.rowToFunction(&row)
|
||||||
|
}
|
||||||
|
|
||||||
|
return functions, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log records a function invocation and its logs to the database.
|
||||||
|
func (r *Registry) Log(ctx context.Context, inv *InvocationRecord) error {
|
||||||
|
if inv == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert invocation record
|
||||||
|
invQuery := `
|
||||||
|
INSERT INTO function_invocations (
|
||||||
|
id, function_id, request_id, trigger_type, caller_wallet,
|
||||||
|
input_size, output_size, started_at, completed_at,
|
||||||
|
duration_ms, status, error_message, memory_used_mb
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
_, err := r.db.Exec(ctx, invQuery,
|
||||||
|
inv.ID, inv.FunctionID, inv.RequestID, string(inv.TriggerType), inv.CallerWallet,
|
||||||
|
inv.InputSize, inv.OutputSize, inv.StartedAt, inv.CompletedAt,
|
||||||
|
inv.DurationMS, string(inv.Status), inv.ErrorMessage, inv.MemoryUsedMB,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to insert invocation record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert logs if any
|
||||||
|
if len(inv.Logs) > 0 {
|
||||||
|
for _, entry := range inv.Logs {
|
||||||
|
logID := uuid.New().String()
|
||||||
|
logQuery := `
|
||||||
|
INSERT INTO function_logs (
|
||||||
|
id, function_id, invocation_id, level, message, timestamp
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
_, err := r.db.Exec(ctx, logQuery,
|
||||||
|
logID, inv.FunctionID, inv.ID, entry.Level, entry.Message, entry.Timestamp,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.Warn("Failed to insert function log", zap.Error(err))
|
||||||
|
// Continue with other logs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLogs retrieves logs for a function.
|
||||||
|
func (r *Registry) GetLogs(ctx context.Context, namespace, name string, limit int) ([]LogEntry, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT l.level, l.message, l.timestamp
|
||||||
|
FROM function_logs l
|
||||||
|
JOIN functions f ON l.function_id = f.id
|
||||||
|
WHERE f.namespace = ? AND f.name = ?
|
||||||
|
ORDER BY l.timestamp DESC
|
||||||
|
LIMIT ?
|
||||||
|
`
|
||||||
|
|
||||||
|
var results []struct {
|
||||||
|
Level string `db:"level"`
|
||||||
|
Message string `db:"message"`
|
||||||
|
Timestamp time.Time `db:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.db.Query(ctx, &results, query, namespace, name, limit); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query logs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logs := make([]LogEntry, len(results))
|
||||||
|
for i, res := range results {
|
||||||
|
logs[i] = LogEntry{
|
||||||
|
Level: res.Level,
|
||||||
|
Message: res.Message,
|
||||||
|
Timestamp: res.Timestamp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return logs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Private helpers
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// uploadWASM uploads WASM bytecode to IPFS and returns the CID.
|
||||||
|
func (r *Registry) uploadWASM(ctx context.Context, wasmBytes []byte, name string) (string, error) {
|
||||||
|
reader := bytes.NewReader(wasmBytes)
|
||||||
|
resp, err := r.ipfs.Add(ctx, reader, name+".wasm")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to upload WASM to IPFS: %w", err)
|
||||||
|
}
|
||||||
|
return resp.Cid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLatestVersion returns the latest version number for a function.
|
||||||
|
func (r *Registry) getLatestVersion(ctx context.Context, namespace, name string) (int, error) {
|
||||||
|
query := `SELECT MAX(version) FROM functions WHERE namespace = ? AND name = ?`
|
||||||
|
|
||||||
|
var maxVersion sql.NullInt64
|
||||||
|
var results []struct {
|
||||||
|
MaxVersion sql.NullInt64 `db:"max(version)"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.db.Query(ctx, &results, query, namespace, name); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) == 0 || !results[0].MaxVersion.Valid {
|
||||||
|
return 0, ErrFunctionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
maxVersion = results[0].MaxVersion
|
||||||
|
return int(maxVersion.Int64), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getByNameInternal retrieves a function by name regardless of status.
|
||||||
|
func (r *Registry) getByNameInternal(ctx context.Context, namespace, name string) (*Function, error) {
|
||||||
|
namespace = strings.TrimSpace(namespace)
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT id, name, namespace, version, wasm_cid, source_cid,
|
||||||
|
memory_limit_mb, timeout_seconds, is_public,
|
||||||
|
retry_count, retry_delay_seconds, dlq_topic,
|
||||||
|
status, created_at, updated_at, created_by
|
||||||
|
FROM functions
|
||||||
|
WHERE namespace = ? AND name = ?
|
||||||
|
ORDER BY version DESC
|
||||||
|
LIMIT 1
|
||||||
|
`
|
||||||
|
|
||||||
|
var functions []functionRow
|
||||||
|
if err := r.db.Query(ctx, &functions, query, namespace, name); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query function: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(functions) == 0 {
|
||||||
|
return nil, ErrFunctionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.rowToFunction(&functions[0]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveEnvVars saves environment variables for a function.
|
||||||
|
func (r *Registry) saveEnvVars(ctx context.Context, functionID string, envVars map[string]string) error {
|
||||||
|
// Clear existing env vars first
|
||||||
|
deleteQuery := `DELETE FROM function_env_vars WHERE function_id = ?`
|
||||||
|
if _, err := r.db.Exec(ctx, deleteQuery, functionID); err != nil {
|
||||||
|
return fmt.Errorf("failed to clear existing env vars: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(envVars) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range envVars {
|
||||||
|
id := uuid.New().String()
|
||||||
|
query := `INSERT INTO function_env_vars (id, function_id, key, value, created_at) VALUES (?, ?, ?, ?, ?)`
|
||||||
|
if _, err := r.db.Exec(ctx, query, id, functionID, key, value, time.Now()); err != nil {
|
||||||
|
return fmt.Errorf("failed to save env var '%s': %w", key, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rowToFunction converts a database row to a Function struct.
|
||||||
|
func (r *Registry) rowToFunction(row *functionRow) *Function {
|
||||||
|
return &Function{
|
||||||
|
ID: row.ID,
|
||||||
|
Name: row.Name,
|
||||||
|
Namespace: row.Namespace,
|
||||||
|
Version: row.Version,
|
||||||
|
WASMCID: row.WASMCID,
|
||||||
|
SourceCID: row.SourceCID.String,
|
||||||
|
MemoryLimitMB: row.MemoryLimitMB,
|
||||||
|
TimeoutSeconds: row.TimeoutSeconds,
|
||||||
|
IsPublic: row.IsPublic,
|
||||||
|
RetryCount: row.RetryCount,
|
||||||
|
RetryDelaySeconds: row.RetryDelaySeconds,
|
||||||
|
DLQTopic: row.DLQTopic.String,
|
||||||
|
Status: FunctionStatus(row.Status),
|
||||||
|
CreatedAt: row.CreatedAt,
|
||||||
|
UpdatedAt: row.UpdatedAt,
|
||||||
|
CreatedBy: row.CreatedBy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Database row types (internal)
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type functionRow struct {
|
||||||
|
ID string `db:"id"`
|
||||||
|
Name string `db:"name"`
|
||||||
|
Namespace string `db:"namespace"`
|
||||||
|
Version int `db:"version"`
|
||||||
|
WASMCID string `db:"wasm_cid"`
|
||||||
|
SourceCID sql.NullString `db:"source_cid"`
|
||||||
|
MemoryLimitMB int `db:"memory_limit_mb"`
|
||||||
|
TimeoutSeconds int `db:"timeout_seconds"`
|
||||||
|
IsPublic bool `db:"is_public"`
|
||||||
|
RetryCount int `db:"retry_count"`
|
||||||
|
RetryDelaySeconds int `db:"retry_delay_seconds"`
|
||||||
|
DLQTopic sql.NullString `db:"dlq_topic"`
|
||||||
|
Status string `db:"status"`
|
||||||
|
CreatedAt time.Time `db:"created_at"`
|
||||||
|
UpdatedAt time.Time `db:"updated_at"`
|
||||||
|
CreatedBy string `db:"created_by"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type envVarRow struct {
|
||||||
|
Key string `db:"key"`
|
||||||
|
Value string `db:"value"`
|
||||||
|
}
|
||||||
40
pkg/serverless/registry_test.go
Normal file
40
pkg/serverless/registry_test.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package serverless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRegistry_RegisterAndGet(t *testing.T) {
|
||||||
|
db := NewMockRQLite()
|
||||||
|
ipfs := NewMockIPFSClient()
|
||||||
|
logger := zap.NewNop()
|
||||||
|
|
||||||
|
registry := NewRegistry(db, ipfs, RegistryConfig{IPFSAPIURL: "http://localhost:5001"}, logger)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
fnDef := &FunctionDefinition{
|
||||||
|
Name: "test-func",
|
||||||
|
Namespace: "test-ns",
|
||||||
|
IsPublic: true,
|
||||||
|
}
|
||||||
|
wasmBytes := []byte("mock wasm")
|
||||||
|
|
||||||
|
_, err := registry.Register(ctx, fnDef, wasmBytes)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Since MockRQLite doesn't fully implement Query scanning yet,
|
||||||
|
// we won't be able to test Get() effectively without more work.
|
||||||
|
// But we can check if wasm was uploaded.
|
||||||
|
wasm, err := registry.GetWASMBytes(ctx, "cid-test-func.wasm")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetWASMBytes failed: %v", err)
|
||||||
|
}
|
||||||
|
if string(wasm) != "mock wasm" {
|
||||||
|
t.Errorf("expected 'mock wasm', got %q", string(wasm))
|
||||||
|
}
|
||||||
|
}
|
||||||
379
pkg/serverless/types.go
Normal file
379
pkg/serverless/types.go
Normal file
@ -0,0 +1,379 @@
|
|||||||
|
// Package serverless provides a WASM-based serverless function engine for the Orama Network.
|
||||||
|
// It enables users to deploy and execute Go functions (compiled to WASM) across all nodes,
|
||||||
|
// with support for HTTP/WebSocket triggers, cron jobs, database triggers, pub/sub triggers,
|
||||||
|
// one-time timers, retries with DLQ, and background jobs.
|
||||||
|
package serverless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FunctionStatus represents the current state of a deployed function.
|
||||||
|
type FunctionStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
FunctionStatusActive FunctionStatus = "active"
|
||||||
|
FunctionStatusInactive FunctionStatus = "inactive"
|
||||||
|
FunctionStatusError FunctionStatus = "error"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TriggerType identifies the type of event that triggered a function invocation.
|
||||||
|
type TriggerType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TriggerTypeHTTP TriggerType = "http"
|
||||||
|
TriggerTypeWebSocket TriggerType = "websocket"
|
||||||
|
TriggerTypeCron TriggerType = "cron"
|
||||||
|
TriggerTypeDatabase TriggerType = "database"
|
||||||
|
TriggerTypePubSub TriggerType = "pubsub"
|
||||||
|
TriggerTypeTimer TriggerType = "timer"
|
||||||
|
TriggerTypeJob TriggerType = "job"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JobStatus represents the current state of a background job.
|
||||||
|
type JobStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
JobStatusPending JobStatus = "pending"
|
||||||
|
JobStatusRunning JobStatus = "running"
|
||||||
|
JobStatusCompleted JobStatus = "completed"
|
||||||
|
JobStatusFailed JobStatus = "failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InvocationStatus represents the result of a function invocation.
|
||||||
|
type InvocationStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
InvocationStatusSuccess InvocationStatus = "success"
|
||||||
|
InvocationStatusError InvocationStatus = "error"
|
||||||
|
InvocationStatusTimeout InvocationStatus = "timeout"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DBOperation represents the type of database operation that triggered a function.
|
||||||
|
type DBOperation string
|
||||||
|
|
||||||
|
const (
|
||||||
|
DBOperationInsert DBOperation = "INSERT"
|
||||||
|
DBOperationUpdate DBOperation = "UPDATE"
|
||||||
|
DBOperationDelete DBOperation = "DELETE"
|
||||||
|
)
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Core Interfaces (following Interface Segregation Principle)
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// FunctionRegistry manages function metadata and bytecode storage.
|
||||||
|
// Responsible for CRUD operations on function definitions.
|
||||||
|
type FunctionRegistry interface {
|
||||||
|
// Register deploys a new function or updates an existing one.
|
||||||
|
// Returns the old function definition if it was updated, or nil if it was a new registration.
|
||||||
|
Register(ctx context.Context, fn *FunctionDefinition, wasmBytes []byte) (*Function, error)
|
||||||
|
|
||||||
|
// Get retrieves a function by name and optional version.
|
||||||
|
// If version is 0, returns the latest version.
|
||||||
|
Get(ctx context.Context, namespace, name string, version int) (*Function, error)
|
||||||
|
|
||||||
|
// List returns all functions for a namespace.
|
||||||
|
List(ctx context.Context, namespace string) ([]*Function, error)
|
||||||
|
|
||||||
|
// Delete removes a function. If version is 0, removes all versions.
|
||||||
|
Delete(ctx context.Context, namespace, name string, version int) error
|
||||||
|
|
||||||
|
// GetWASMBytes retrieves the compiled WASM bytecode for a function.
|
||||||
|
GetWASMBytes(ctx context.Context, wasmCID string) ([]byte, error)
|
||||||
|
|
||||||
|
// GetLogs retrieves logs for a function.
|
||||||
|
GetLogs(ctx context.Context, namespace, name string, limit int) ([]LogEntry, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FunctionExecutor handles the actual execution of WASM functions.
|
||||||
|
type FunctionExecutor interface {
|
||||||
|
// Execute runs a function with the given input and returns the output.
|
||||||
|
Execute(ctx context.Context, fn *Function, input []byte, invCtx *InvocationContext) ([]byte, error)
|
||||||
|
|
||||||
|
// Precompile compiles a WASM module and caches it for faster execution.
|
||||||
|
Precompile(ctx context.Context, wasmCID string, wasmBytes []byte) error
|
||||||
|
|
||||||
|
// Invalidate removes a compiled module from the cache.
|
||||||
|
Invalidate(wasmCID string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SecretsManager handles secure storage and retrieval of secrets.
|
||||||
|
type SecretsManager interface {
|
||||||
|
// Set stores an encrypted secret.
|
||||||
|
Set(ctx context.Context, namespace, name, value string) error
|
||||||
|
|
||||||
|
// Get retrieves a decrypted secret.
|
||||||
|
Get(ctx context.Context, namespace, name string) (string, error)
|
||||||
|
|
||||||
|
// List returns all secret names for a namespace (not values).
|
||||||
|
List(ctx context.Context, namespace string) ([]string, error)
|
||||||
|
|
||||||
|
// Delete removes a secret.
|
||||||
|
Delete(ctx context.Context, namespace, name string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// TriggerManager manages function triggers (cron, database, pubsub, timer).
|
||||||
|
type TriggerManager interface {
|
||||||
|
// AddCronTrigger adds a cron-based trigger to a function.
|
||||||
|
AddCronTrigger(ctx context.Context, functionID, cronExpr string) error
|
||||||
|
|
||||||
|
// AddDBTrigger adds a database trigger to a function.
|
||||||
|
AddDBTrigger(ctx context.Context, functionID, tableName string, operation DBOperation, condition string) error
|
||||||
|
|
||||||
|
// AddPubSubTrigger adds a pubsub trigger to a function.
|
||||||
|
AddPubSubTrigger(ctx context.Context, functionID, topic string) error
|
||||||
|
|
||||||
|
// ScheduleOnce schedules a one-time execution.
|
||||||
|
ScheduleOnce(ctx context.Context, functionID string, runAt time.Time, payload []byte) (string, error)
|
||||||
|
|
||||||
|
// RemoveTrigger removes a trigger by ID.
|
||||||
|
RemoveTrigger(ctx context.Context, triggerID string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobManager manages background job execution.
|
||||||
|
type JobManager interface {
|
||||||
|
// Enqueue adds a job to the queue for background execution.
|
||||||
|
Enqueue(ctx context.Context, functionID string, payload []byte) (string, error)
|
||||||
|
|
||||||
|
// GetStatus retrieves the current status of a job.
|
||||||
|
GetStatus(ctx context.Context, jobID string) (*Job, error)
|
||||||
|
|
||||||
|
// List returns jobs for a function.
|
||||||
|
List(ctx context.Context, functionID string, limit int) ([]*Job, error)
|
||||||
|
|
||||||
|
// Cancel attempts to cancel a pending or running job.
|
||||||
|
Cancel(ctx context.Context, jobID string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocketManager manages WebSocket connections for function streaming.
|
||||||
|
type WebSocketManager interface {
|
||||||
|
// Register registers a new WebSocket connection.
|
||||||
|
Register(clientID string, conn WebSocketConn)
|
||||||
|
|
||||||
|
// Unregister removes a WebSocket connection.
|
||||||
|
Unregister(clientID string)
|
||||||
|
|
||||||
|
// Send sends data to a specific client.
|
||||||
|
Send(clientID string, data []byte) error
|
||||||
|
|
||||||
|
// Broadcast sends data to all clients subscribed to a topic.
|
||||||
|
Broadcast(topic string, data []byte) error
|
||||||
|
|
||||||
|
// Subscribe adds a client to a topic.
|
||||||
|
Subscribe(clientID, topic string)
|
||||||
|
|
||||||
|
// Unsubscribe removes a client from a topic.
|
||||||
|
Unsubscribe(clientID, topic string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocketConn abstracts a WebSocket connection for testability.
|
||||||
|
type WebSocketConn interface {
|
||||||
|
WriteMessage(messageType int, data []byte) error
|
||||||
|
ReadMessage() (messageType int, p []byte, err error)
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Data Types
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// FunctionDefinition contains the configuration for deploying a function.
|
||||||
|
type FunctionDefinition struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Namespace string `json:"namespace"`
|
||||||
|
Version int `json:"version,omitempty"`
|
||||||
|
MemoryLimitMB int `json:"memory_limit_mb,omitempty"`
|
||||||
|
TimeoutSeconds int `json:"timeout_seconds,omitempty"`
|
||||||
|
IsPublic bool `json:"is_public,omitempty"`
|
||||||
|
RetryCount int `json:"retry_count,omitempty"`
|
||||||
|
RetryDelaySeconds int `json:"retry_delay_seconds,omitempty"`
|
||||||
|
DLQTopic string `json:"dlq_topic,omitempty"`
|
||||||
|
EnvVars map[string]string `json:"env_vars,omitempty"`
|
||||||
|
CronExpressions []string `json:"cron_expressions,omitempty"`
|
||||||
|
DBTriggers []DBTriggerConfig `json:"db_triggers,omitempty"`
|
||||||
|
PubSubTopics []string `json:"pubsub_topics,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DBTriggerConfig defines a database trigger configuration.
|
||||||
|
type DBTriggerConfig struct {
|
||||||
|
Table string `json:"table"`
|
||||||
|
Operation DBOperation `json:"operation"`
|
||||||
|
Condition string `json:"condition,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function represents a deployed serverless function.
|
||||||
|
type Function struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Namespace string `json:"namespace"`
|
||||||
|
Version int `json:"version"`
|
||||||
|
WASMCID string `json:"wasm_cid"`
|
||||||
|
SourceCID string `json:"source_cid,omitempty"`
|
||||||
|
MemoryLimitMB int `json:"memory_limit_mb"`
|
||||||
|
TimeoutSeconds int `json:"timeout_seconds"`
|
||||||
|
IsPublic bool `json:"is_public"`
|
||||||
|
RetryCount int `json:"retry_count"`
|
||||||
|
RetryDelaySeconds int `json:"retry_delay_seconds"`
|
||||||
|
DLQTopic string `json:"dlq_topic,omitempty"`
|
||||||
|
Status FunctionStatus `json:"status"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
CreatedBy string `json:"created_by"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvocationContext provides context for a function invocation.
|
||||||
|
type InvocationContext struct {
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
FunctionID string `json:"function_id"`
|
||||||
|
FunctionName string `json:"function_name"`
|
||||||
|
Namespace string `json:"namespace"`
|
||||||
|
CallerWallet string `json:"caller_wallet,omitempty"`
|
||||||
|
TriggerType TriggerType `json:"trigger_type"`
|
||||||
|
WSClientID string `json:"ws_client_id,omitempty"`
|
||||||
|
EnvVars map[string]string `json:"env_vars,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvocationResult represents the result of a function invocation.
|
||||||
|
type InvocationResult struct {
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
Output []byte `json:"output,omitempty"`
|
||||||
|
Status InvocationStatus `json:"status"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
DurationMS int64 `json:"duration_ms"`
|
||||||
|
Logs []LogEntry `json:"logs,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogEntry represents a log message from a function.
|
||||||
|
type LogEntry struct {
|
||||||
|
Level string `json:"level"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Job represents a background job.
|
||||||
|
type Job struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
FunctionID string `json:"function_id"`
|
||||||
|
Payload []byte `json:"payload,omitempty"`
|
||||||
|
Status JobStatus `json:"status"`
|
||||||
|
Progress int `json:"progress"`
|
||||||
|
Result []byte `json:"result,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||||
|
CompletedAt *time.Time `json:"completed_at,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CronTrigger represents a cron-based trigger.
|
||||||
|
type CronTrigger struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
FunctionID string `json:"function_id"`
|
||||||
|
CronExpression string `json:"cron_expression"`
|
||||||
|
NextRunAt *time.Time `json:"next_run_at,omitempty"`
|
||||||
|
LastRunAt *time.Time `json:"last_run_at,omitempty"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DBTrigger represents a database trigger.
|
||||||
|
type DBTrigger struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
FunctionID string `json:"function_id"`
|
||||||
|
TableName string `json:"table_name"`
|
||||||
|
Operation DBOperation `json:"operation"`
|
||||||
|
Condition string `json:"condition,omitempty"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PubSubTrigger represents a pubsub trigger.
|
||||||
|
type PubSubTrigger struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
FunctionID string `json:"function_id"`
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timer represents a one-time scheduled execution.
|
||||||
|
type Timer struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
FunctionID string `json:"function_id"`
|
||||||
|
RunAt time.Time `json:"run_at"`
|
||||||
|
Payload []byte `json:"payload,omitempty"`
|
||||||
|
Status JobStatus `json:"status"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DBChangeEvent is passed to functions triggered by database changes.
|
||||||
|
type DBChangeEvent struct {
|
||||||
|
Table string `json:"table"`
|
||||||
|
Operation DBOperation `json:"operation"`
|
||||||
|
Row map[string]interface{} `json:"row"`
|
||||||
|
OldRow map[string]interface{} `json:"old_row,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Host Function Types (passed to WASM functions)
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// HostServices provides access to Orama services from within WASM functions.
|
||||||
|
// This interface is implemented by the host and exposed to WASM modules.
|
||||||
|
type HostServices interface {
|
||||||
|
// Database operations
|
||||||
|
DBQuery(ctx context.Context, query string, args []interface{}) ([]byte, error)
|
||||||
|
DBExecute(ctx context.Context, query string, args []interface{}) (int64, error)
|
||||||
|
|
||||||
|
// Cache operations
|
||||||
|
CacheGet(ctx context.Context, key string) ([]byte, error)
|
||||||
|
CacheSet(ctx context.Context, key string, value []byte, ttlSeconds int64) error
|
||||||
|
CacheDelete(ctx context.Context, key string) error
|
||||||
|
CacheIncr(ctx context.Context, key string) (int64, error)
|
||||||
|
CacheIncrBy(ctx context.Context, key string, delta int64) (int64, error)
|
||||||
|
|
||||||
|
// Storage operations
|
||||||
|
StoragePut(ctx context.Context, data []byte) (string, error)
|
||||||
|
StorageGet(ctx context.Context, cid string) ([]byte, error)
|
||||||
|
|
||||||
|
// PubSub operations
|
||||||
|
PubSubPublish(ctx context.Context, topic string, data []byte) error
|
||||||
|
|
||||||
|
// WebSocket operations (only valid in WS context)
|
||||||
|
WSSend(ctx context.Context, clientID string, data []byte) error
|
||||||
|
WSBroadcast(ctx context.Context, topic string, data []byte) error
|
||||||
|
|
||||||
|
// HTTP operations
|
||||||
|
HTTPFetch(ctx context.Context, method, url string, headers map[string]string, body []byte) ([]byte, error)
|
||||||
|
|
||||||
|
// Context operations
|
||||||
|
GetEnv(ctx context.Context, key string) (string, error)
|
||||||
|
GetSecret(ctx context.Context, name string) (string, error)
|
||||||
|
GetRequestID(ctx context.Context) string
|
||||||
|
GetCallerWallet(ctx context.Context) string
|
||||||
|
|
||||||
|
// Job operations
|
||||||
|
EnqueueBackground(ctx context.Context, functionName string, payload []byte) (string, error)
|
||||||
|
ScheduleOnce(ctx context.Context, functionName string, runAt time.Time, payload []byte) (string, error)
|
||||||
|
|
||||||
|
// Logging
|
||||||
|
LogInfo(ctx context.Context, message string)
|
||||||
|
LogError(ctx context.Context, message string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Deployment Types
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// DeployRequest represents a request to deploy a function.
|
||||||
|
type DeployRequest struct {
|
||||||
|
Definition *FunctionDefinition `json:"definition"`
|
||||||
|
Source io.Reader `json:"-"` // Go source code or WASM bytes
|
||||||
|
IsWASM bool `json:"is_wasm"` // True if Source contains WASM bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeployResult represents the result of a deployment.
|
||||||
|
type DeployResult struct {
|
||||||
|
Function *Function `json:"function"`
|
||||||
|
WASMCID string `json:"wasm_cid"`
|
||||||
|
Triggers []string `json:"triggers,omitempty"`
|
||||||
|
}
|
||||||
332
pkg/serverless/websocket.go
Normal file
332
pkg/serverless/websocket.go
Normal file
@ -0,0 +1,332 @@
|
|||||||
|
package serverless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ensure WSManager implements WebSocketManager interface.
|
||||||
|
var _ WebSocketManager = (*WSManager)(nil)
|
||||||
|
|
||||||
|
// WSManager manages WebSocket connections for serverless functions.
|
||||||
|
// It handles connection registration, message routing, and topic subscriptions.
|
||||||
|
type WSManager struct {
|
||||||
|
// connections maps client IDs to their WebSocket connections
|
||||||
|
connections map[string]*wsConnection
|
||||||
|
connectionsMu sync.RWMutex
|
||||||
|
|
||||||
|
// subscriptions maps topic names to sets of client IDs
|
||||||
|
subscriptions map[string]map[string]struct{}
|
||||||
|
subscriptionsMu sync.RWMutex
|
||||||
|
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// wsConnection wraps a WebSocket connection with metadata.
|
||||||
|
type wsConnection struct {
|
||||||
|
conn WebSocketConn
|
||||||
|
clientID string
|
||||||
|
topics map[string]struct{} // Topics this client is subscribed to
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// GorillaWSConn wraps a gorilla/websocket.Conn to implement WebSocketConn.
|
||||||
|
type GorillaWSConn struct {
|
||||||
|
*websocket.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure GorillaWSConn implements WebSocketConn.
|
||||||
|
var _ WebSocketConn = (*GorillaWSConn)(nil)
|
||||||
|
|
||||||
|
// WriteMessage writes a message to the WebSocket connection.
|
||||||
|
func (c *GorillaWSConn) WriteMessage(messageType int, data []byte) error {
|
||||||
|
return c.Conn.WriteMessage(messageType, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadMessage reads a message from the WebSocket connection.
|
||||||
|
func (c *GorillaWSConn) ReadMessage() (messageType int, p []byte, err error) {
|
||||||
|
return c.Conn.ReadMessage()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the WebSocket connection.
|
||||||
|
func (c *GorillaWSConn) Close() error {
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWSManager creates a new WebSocket manager.
|
||||||
|
func NewWSManager(logger *zap.Logger) *WSManager {
|
||||||
|
return &WSManager{
|
||||||
|
connections: make(map[string]*wsConnection),
|
||||||
|
subscriptions: make(map[string]map[string]struct{}),
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register registers a new WebSocket connection.
|
||||||
|
func (m *WSManager) Register(clientID string, conn WebSocketConn) {
|
||||||
|
m.connectionsMu.Lock()
|
||||||
|
defer m.connectionsMu.Unlock()
|
||||||
|
|
||||||
|
// Close existing connection if any
|
||||||
|
if existing, exists := m.connections[clientID]; exists {
|
||||||
|
_ = existing.conn.Close()
|
||||||
|
m.logger.Debug("Closed existing connection", zap.String("client_id", clientID))
|
||||||
|
}
|
||||||
|
|
||||||
|
m.connections[clientID] = &wsConnection{
|
||||||
|
conn: conn,
|
||||||
|
clientID: clientID,
|
||||||
|
topics: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Debug("Registered WebSocket connection",
|
||||||
|
zap.String("client_id", clientID),
|
||||||
|
zap.Int("total_connections", len(m.connections)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unregister removes a WebSocket connection and its subscriptions.
|
||||||
|
func (m *WSManager) Unregister(clientID string) {
|
||||||
|
m.connectionsMu.Lock()
|
||||||
|
conn, exists := m.connections[clientID]
|
||||||
|
if exists {
|
||||||
|
delete(m.connections, clientID)
|
||||||
|
}
|
||||||
|
m.connectionsMu.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove from all subscriptions
|
||||||
|
m.subscriptionsMu.Lock()
|
||||||
|
for topic := range conn.topics {
|
||||||
|
if clients, ok := m.subscriptions[topic]; ok {
|
||||||
|
delete(clients, clientID)
|
||||||
|
if len(clients) == 0 {
|
||||||
|
delete(m.subscriptions, topic)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.subscriptionsMu.Unlock()
|
||||||
|
|
||||||
|
// Close the connection
|
||||||
|
_ = conn.conn.Close()
|
||||||
|
|
||||||
|
m.logger.Debug("Unregistered WebSocket connection",
|
||||||
|
zap.String("client_id", clientID),
|
||||||
|
zap.Int("remaining_connections", m.GetConnectionCount()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sends data to a specific client.
|
||||||
|
func (m *WSManager) Send(clientID string, data []byte) error {
|
||||||
|
m.connectionsMu.RLock()
|
||||||
|
conn, exists := m.connections[clientID]
|
||||||
|
m.connectionsMu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return ErrWSClientNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.mu.Lock()
|
||||||
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
|
if err := conn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||||
|
m.logger.Warn("Failed to send WebSocket message",
|
||||||
|
zap.String("client_id", clientID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast sends data to all clients subscribed to a topic.
|
||||||
|
func (m *WSManager) Broadcast(topic string, data []byte) error {
|
||||||
|
m.subscriptionsMu.RLock()
|
||||||
|
clients, exists := m.subscriptions[topic]
|
||||||
|
if !exists || len(clients) == 0 {
|
||||||
|
m.subscriptionsMu.RUnlock()
|
||||||
|
return nil // No subscribers, not an error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy client IDs to avoid holding lock during send
|
||||||
|
clientIDs := make([]string, 0, len(clients))
|
||||||
|
for clientID := range clients {
|
||||||
|
clientIDs = append(clientIDs, clientID)
|
||||||
|
}
|
||||||
|
m.subscriptionsMu.RUnlock()
|
||||||
|
|
||||||
|
// Send to all subscribers
|
||||||
|
var sendErrors int
|
||||||
|
for _, clientID := range clientIDs {
|
||||||
|
if err := m.Send(clientID, data); err != nil {
|
||||||
|
sendErrors++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Debug("Broadcast message",
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.Int("recipients", len(clientIDs)),
|
||||||
|
zap.Int("errors", sendErrors),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe adds a client to a topic.
|
||||||
|
func (m *WSManager) Subscribe(clientID, topic string) {
|
||||||
|
// Add to connection's topic list
|
||||||
|
m.connectionsMu.RLock()
|
||||||
|
conn, exists := m.connections[clientID]
|
||||||
|
m.connectionsMu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.mu.Lock()
|
||||||
|
conn.topics[topic] = struct{}{}
|
||||||
|
conn.mu.Unlock()
|
||||||
|
|
||||||
|
// Add to topic's client list
|
||||||
|
m.subscriptionsMu.Lock()
|
||||||
|
if m.subscriptions[topic] == nil {
|
||||||
|
m.subscriptions[topic] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
m.subscriptions[topic][clientID] = struct{}{}
|
||||||
|
m.subscriptionsMu.Unlock()
|
||||||
|
|
||||||
|
m.logger.Debug("Client subscribed to topic",
|
||||||
|
zap.String("client_id", clientID),
|
||||||
|
zap.String("topic", topic),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe removes a client from a topic.
|
||||||
|
func (m *WSManager) Unsubscribe(clientID, topic string) {
|
||||||
|
// Remove from connection's topic list
|
||||||
|
m.connectionsMu.RLock()
|
||||||
|
conn, exists := m.connections[clientID]
|
||||||
|
m.connectionsMu.RUnlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
conn.mu.Lock()
|
||||||
|
delete(conn.topics, topic)
|
||||||
|
conn.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove from topic's client list
|
||||||
|
m.subscriptionsMu.Lock()
|
||||||
|
if clients, ok := m.subscriptions[topic]; ok {
|
||||||
|
delete(clients, clientID)
|
||||||
|
if len(clients) == 0 {
|
||||||
|
delete(m.subscriptions, topic)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.subscriptionsMu.Unlock()
|
||||||
|
|
||||||
|
m.logger.Debug("Client unsubscribed from topic",
|
||||||
|
zap.String("client_id", clientID),
|
||||||
|
zap.String("topic", topic),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConnectionCount returns the number of active connections.
|
||||||
|
func (m *WSManager) GetConnectionCount() int {
|
||||||
|
m.connectionsMu.RLock()
|
||||||
|
defer m.connectionsMu.RUnlock()
|
||||||
|
return len(m.connections)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTopicSubscriberCount returns the number of subscribers for a topic.
|
||||||
|
func (m *WSManager) GetTopicSubscriberCount(topic string) int {
|
||||||
|
m.subscriptionsMu.RLock()
|
||||||
|
defer m.subscriptionsMu.RUnlock()
|
||||||
|
if clients, exists := m.subscriptions[topic]; exists {
|
||||||
|
return len(clients)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientTopics returns all topics a client is subscribed to.
|
||||||
|
func (m *WSManager) GetClientTopics(clientID string) []string {
|
||||||
|
m.connectionsMu.RLock()
|
||||||
|
conn, exists := m.connections[clientID]
|
||||||
|
m.connectionsMu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.mu.Lock()
|
||||||
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
|
topics := make([]string, 0, len(conn.topics))
|
||||||
|
for topic := range conn.topics {
|
||||||
|
topics = append(topics, topic)
|
||||||
|
}
|
||||||
|
return topics
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsConnected checks if a client is connected.
|
||||||
|
func (m *WSManager) IsConnected(clientID string) bool {
|
||||||
|
m.connectionsMu.RLock()
|
||||||
|
defer m.connectionsMu.RUnlock()
|
||||||
|
_, exists := m.connections[clientID]
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes all connections and cleans up resources.
|
||||||
|
func (m *WSManager) Close() {
|
||||||
|
m.connectionsMu.Lock()
|
||||||
|
defer m.connectionsMu.Unlock()
|
||||||
|
|
||||||
|
for clientID, conn := range m.connections {
|
||||||
|
_ = conn.conn.Close()
|
||||||
|
delete(m.connections, clientID)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.subscriptionsMu.Lock()
|
||||||
|
m.subscriptions = make(map[string]map[string]struct{})
|
||||||
|
m.subscriptionsMu.Unlock()
|
||||||
|
|
||||||
|
m.logger.Info("WebSocket manager closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns statistics about the WebSocket manager.
|
||||||
|
type WSStats struct {
|
||||||
|
ConnectionCount int `json:"connection_count"`
|
||||||
|
TopicCount int `json:"topic_count"`
|
||||||
|
SubscriptionCount int `json:"subscription_count"`
|
||||||
|
TopicStats map[string]int `json:"topic_stats"` // topic -> subscriber count
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStats returns current statistics.
|
||||||
|
func (m *WSManager) GetStats() *WSStats {
|
||||||
|
m.connectionsMu.RLock()
|
||||||
|
connCount := len(m.connections)
|
||||||
|
m.connectionsMu.RUnlock()
|
||||||
|
|
||||||
|
m.subscriptionsMu.RLock()
|
||||||
|
topicCount := len(m.subscriptions)
|
||||||
|
topicStats := make(map[string]int, topicCount)
|
||||||
|
totalSubs := 0
|
||||||
|
for topic, clients := range m.subscriptions {
|
||||||
|
topicStats[topic] = len(clients)
|
||||||
|
totalSubs += len(clients)
|
||||||
|
}
|
||||||
|
m.subscriptionsMu.RUnlock()
|
||||||
|
|
||||||
|
return &WSStats{
|
||||||
|
ConnectionCount: connCount,
|
||||||
|
TopicCount: topicCount,
|
||||||
|
SubscriptionCount: totalSubs,
|
||||||
|
TopicStats: topicStats,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@ -1,53 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Setup local domains for DeBros Network development
|
|
||||||
# Adds entries to /etc/hosts for node-1.local through node-5.local
|
|
||||||
# Maps them to 127.0.0.1 for local development
|
|
||||||
|
|
||||||
set -e
|
|
||||||
|
|
||||||
HOSTS_FILE="/etc/hosts"
|
|
||||||
NODES=("node-1" "node-2" "node-3" "node-4" "node-5")
|
|
||||||
|
|
||||||
# Check if we have sudo access
|
|
||||||
if [ "$EUID" -ne 0 ]; then
|
|
||||||
echo "This script requires sudo to modify /etc/hosts"
|
|
||||||
echo "Please run: sudo bash scripts/setup-local-domains.sh"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Function to add or update domain entry
|
|
||||||
add_domain() {
|
|
||||||
local domain=$1
|
|
||||||
local ip="127.0.0.1"
|
|
||||||
|
|
||||||
# Check if domain already exists
|
|
||||||
if grep -q "^[[:space:]]*$ip[[:space:]]\+$domain" "$HOSTS_FILE"; then
|
|
||||||
echo "✓ $domain already configured"
|
|
||||||
return 0
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Add domain to /etc/hosts
|
|
||||||
echo "$ip $domain" >> "$HOSTS_FILE"
|
|
||||||
echo "✓ Added $domain -> $ip"
|
|
||||||
}
|
|
||||||
|
|
||||||
echo "Setting up local domains for DeBros Network..."
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Add each node domain
|
|
||||||
for node in "${NODES[@]}"; do
|
|
||||||
add_domain "${node}.local"
|
|
||||||
done
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "✓ Local domains configured successfully!"
|
|
||||||
echo ""
|
|
||||||
echo "You can now access nodes via:"
|
|
||||||
for node in "${NODES[@]}"; do
|
|
||||||
echo " - ${node}.local (HTTP Gateway)"
|
|
||||||
done
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "Example: curl http://node-1.local:8080/rqlite/http/db/status"
|
|
||||||
|
|
||||||
@ -1,85 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Test local domain routing for DeBros Network
|
|
||||||
# Validates that all HTTP gateway routes are working
|
|
||||||
|
|
||||||
set -e
|
|
||||||
|
|
||||||
NODES=("1" "2" "3" "4" "5")
|
|
||||||
GATEWAY_PORTS=(8080 8081 8082 8083 8084)
|
|
||||||
|
|
||||||
# Color codes
|
|
||||||
GREEN='\033[0;32m'
|
|
||||||
RED='\033[0;31m'
|
|
||||||
YELLOW='\033[1;33m'
|
|
||||||
NC='\033[0m' # No Color
|
|
||||||
|
|
||||||
# Counters
|
|
||||||
PASSED=0
|
|
||||||
FAILED=0
|
|
||||||
|
|
||||||
# Test a single endpoint
|
|
||||||
test_endpoint() {
|
|
||||||
local node=$1
|
|
||||||
local port=$2
|
|
||||||
local path=$3
|
|
||||||
local description=$4
|
|
||||||
|
|
||||||
local url="http://node-${node}.local:${port}${path}"
|
|
||||||
|
|
||||||
printf "Testing %-50s ... " "$description"
|
|
||||||
|
|
||||||
if curl -s -f "$url" > /dev/null 2>&1; then
|
|
||||||
echo -e "${GREEN}✓ PASS${NC}"
|
|
||||||
((PASSED++))
|
|
||||||
return 0
|
|
||||||
else
|
|
||||||
echo -e "${RED}✗ FAIL${NC}"
|
|
||||||
((FAILED++))
|
|
||||||
return 1
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
echo "=========================================="
|
|
||||||
echo "DeBros Network Local Domain Tests"
|
|
||||||
echo "=========================================="
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Test each node's HTTP gateway
|
|
||||||
for i in "${!NODES[@]}"; do
|
|
||||||
node=${NODES[$i]}
|
|
||||||
port=${GATEWAY_PORTS[$i]}
|
|
||||||
|
|
||||||
echo "Testing node-${node}.local (port ${port}):"
|
|
||||||
|
|
||||||
# Test health endpoint
|
|
||||||
test_endpoint "$node" "$port" "/health" "Node-${node} health check"
|
|
||||||
|
|
||||||
# Test RQLite HTTP endpoint
|
|
||||||
test_endpoint "$node" "$port" "/rqlite/http/db/execute" "Node-${node} RQLite HTTP"
|
|
||||||
|
|
||||||
# Test IPFS API endpoint (may fail if IPFS not running, but at least connection should work)
|
|
||||||
test_endpoint "$node" "$port" "/ipfs/api/v0/version" "Node-${node} IPFS API" || true
|
|
||||||
|
|
||||||
# Test Cluster API endpoint (may fail if Cluster not running, but at least connection should work)
|
|
||||||
test_endpoint "$node" "$port" "/cluster/health" "Node-${node} Cluster API" || true
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
done
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
echo "=========================================="
|
|
||||||
echo "Test Results"
|
|
||||||
echo "=========================================="
|
|
||||||
echo -e "${GREEN}Passed: $PASSED${NC}"
|
|
||||||
echo -e "${RED}Failed: $FAILED${NC}"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
if [ $FAILED -eq 0 ]; then
|
|
||||||
echo -e "${GREEN}✓ All tests passed!${NC}"
|
|
||||||
exit 0
|
|
||||||
else
|
|
||||||
echo -e "${YELLOW}⚠ Some tests failed (this is expected if services aren't running)${NC}"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user