mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-03-17 08:36:57 +00:00
Namespace bug fix and fixing bugs on serverless and deployments
This commit is contained in:
parent
8aef779fcd
commit
83804422c4
2
Makefile
2
Makefile
@ -63,7 +63,7 @@ test-e2e-quick:
|
||||
|
||||
.PHONY: build clean test deps tidy fmt vet lint install-hooks redeploy-devnet redeploy-testnet release health
|
||||
|
||||
VERSION := 0.107.2
|
||||
VERSION := 0.108.0
|
||||
COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo unknown)
|
||||
DATE ?= $(shell date -u +%Y-%m-%dT%H:%M:%SZ)
|
||||
LDFLAGS := -X 'main.version=$(VERSION)' -X 'main.commit=$(COMMIT)' -X 'main.date=$(DATE)'
|
||||
|
||||
@ -20,6 +20,10 @@ type Credentials struct {
|
||||
LastUsedAt time.Time `json:"last_used_at,omitempty"`
|
||||
Plan string `json:"plan,omitempty"`
|
||||
NamespaceURL string `json:"namespace_url,omitempty"`
|
||||
|
||||
// ProvisioningPollURL is set when namespace cluster is being provisioned.
|
||||
// Used only during the login flow, not persisted.
|
||||
ProvisioningPollURL string `json:"-"`
|
||||
}
|
||||
|
||||
// CredentialStore manages credentials for multiple gateways
|
||||
|
||||
@ -117,6 +117,18 @@ func PerformRootWalletAuthentication(gatewayURL, namespace string) (*Credentials
|
||||
return nil, fmt.Errorf("failed to verify signature: %w", err)
|
||||
}
|
||||
|
||||
// If namespace cluster is being provisioned, poll until ready
|
||||
if creds.ProvisioningPollURL != "" {
|
||||
fmt.Println("⏳ Provisioning namespace cluster...")
|
||||
pollErr := pollNamespaceProvisioning(client, gatewayURL, creds.ProvisioningPollURL)
|
||||
if pollErr != nil {
|
||||
fmt.Printf("⚠️ Provisioning poll failed: %v\n", pollErr)
|
||||
fmt.Println(" Credentials are saved. Cluster may still be provisioning in background.")
|
||||
} else {
|
||||
fmt.Println("✅ Namespace cluster ready!")
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("\n🎉 Authentication successful!\n")
|
||||
fmt.Printf("🏢 Namespace: %s\n", creds.Namespace)
|
||||
|
||||
@ -184,7 +196,7 @@ func verifySignature(client *http.Client, gatewayURL, wallet, nonce, signature,
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("gateway returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
@ -196,6 +208,9 @@ func verifySignature(client *http.Client, gatewayURL, wallet, nonce, signature,
|
||||
Subject string `json:"subject"`
|
||||
Namespace string `json:"namespace"`
|
||||
APIKey string `json:"api_key"`
|
||||
// Provisioning fields (202 Accepted)
|
||||
Status string `json:"status"`
|
||||
PollURL string `json:"poll_url"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
@ -225,8 +240,51 @@ func verifySignature(client *http.Client, gatewayURL, wallet, nonce, signature,
|
||||
NamespaceURL: namespaceURL,
|
||||
}
|
||||
|
||||
// If 202, namespace cluster is being provisioned — set poll URL
|
||||
if resp.StatusCode == http.StatusAccepted && result.PollURL != "" {
|
||||
creds.ProvisioningPollURL = result.PollURL
|
||||
}
|
||||
|
||||
// Note: result.ExpiresIn is the JWT access token lifetime (15min),
|
||||
// NOT the API key lifetime. Don't set ExpiresAt — the API key is permanent.
|
||||
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
// pollNamespaceProvisioning polls the namespace status endpoint until the cluster is ready.
|
||||
func pollNamespaceProvisioning(client *http.Client, gatewayURL, pollPath string) error {
|
||||
pollURL := gatewayURL + pollPath
|
||||
timeout := time.After(120 * time.Second)
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
return fmt.Errorf("timed out after 120s waiting for namespace cluster")
|
||||
case <-ticker.C:
|
||||
resp, err := client.Get(pollURL)
|
||||
if err != nil {
|
||||
continue // Retry on network error
|
||||
}
|
||||
|
||||
var status struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
decErr := json.NewDecoder(resp.Body).Decode(&status)
|
||||
resp.Body.Close()
|
||||
if decErr != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch status.Status {
|
||||
case "ready":
|
||||
return nil
|
||||
case "failed", "error":
|
||||
return fmt.Errorf("namespace provisioning failed")
|
||||
}
|
||||
// "provisioning" or other — keep polling
|
||||
fmt.Print(".")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
112
pkg/cli/production/report/deployments.go
Normal file
112
pkg/cli/production/report/deployments.go
Normal file
@ -0,0 +1,112 @@
|
||||
package report
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// collectDeployments discovers deployed applications by querying the local gateway.
|
||||
func collectDeployments() *DeploymentsReport {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
report := &DeploymentsReport{}
|
||||
|
||||
// Query the local gateway for deployment list
|
||||
url := "http://localhost:8080/v1/health"
|
||||
body, err := httpGet(ctx, url)
|
||||
if err != nil {
|
||||
// Gateway not available — fall back to systemd unit discovery
|
||||
return collectDeploymentsFromSystemd()
|
||||
}
|
||||
|
||||
// Check if gateway reports deployment counts in health response
|
||||
var health map[string]interface{}
|
||||
if err := json.Unmarshal(body, &health); err == nil {
|
||||
if deps, ok := health["deployments"].(map[string]interface{}); ok {
|
||||
if v, ok := deps["total"].(float64); ok {
|
||||
report.TotalCount = int(v)
|
||||
}
|
||||
if v, ok := deps["running"].(float64); ok {
|
||||
report.RunningCount = int(v)
|
||||
}
|
||||
if v, ok := deps["failed"].(float64); ok {
|
||||
report.FailedCount = int(v)
|
||||
}
|
||||
return report
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: count deployment systemd units
|
||||
return collectDeploymentsFromSystemd()
|
||||
}
|
||||
|
||||
// collectDeploymentsFromSystemd discovers deployments by listing systemd units.
|
||||
func collectDeploymentsFromSystemd() *DeploymentsReport {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
report := &DeploymentsReport{}
|
||||
|
||||
// List orama-deploy-* units
|
||||
out, err := runCmd(ctx, "systemctl", "list-units", "--type=service", "--no-legend", "--no-pager", "orama-deploy-*")
|
||||
if err != nil {
|
||||
return report
|
||||
}
|
||||
|
||||
for _, line := range strings.Split(out, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
report.TotalCount++
|
||||
fields := strings.Fields(line)
|
||||
// systemctl list-units format: UNIT LOAD ACTIVE SUB DESCRIPTION...
|
||||
if len(fields) >= 4 {
|
||||
switch fields[3] {
|
||||
case "running":
|
||||
report.RunningCount++
|
||||
case "failed", "dead":
|
||||
report.FailedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return report
|
||||
}
|
||||
|
||||
// collectServerless checks if the serverless engine is available via the gateway health endpoint.
|
||||
func collectServerless() *ServerlessReport {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
report := &ServerlessReport{
|
||||
EngineStatus: "unknown",
|
||||
}
|
||||
|
||||
// Check gateway health for serverless subsystem
|
||||
url := "http://localhost:8080/v1/health"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return report
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
report.EngineStatus = "unreachable"
|
||||
return report
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
report.EngineStatus = "healthy"
|
||||
} else {
|
||||
report.EngineStatus = fmt.Sprintf("unhealthy (HTTP %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
return report
|
||||
}
|
||||
@ -40,19 +40,19 @@ func discoverNamespaces() []nsInfo {
|
||||
var result []nsInfo
|
||||
seen := make(map[string]bool)
|
||||
|
||||
// Strategy 1: Glob for orama-deploy-*-rqlite.service files.
|
||||
matches, _ := filepath.Glob("/etc/systemd/system/orama-deploy-*-rqlite.service")
|
||||
// Strategy 1: Glob for orama-namespace-rqlite@*.service files.
|
||||
matches, _ := filepath.Glob("/etc/systemd/system/orama-namespace-rqlite@*.service")
|
||||
for _, path := range matches {
|
||||
base := filepath.Base(path)
|
||||
// Extract namespace name: orama-deploy-<name>-rqlite.service
|
||||
name := strings.TrimPrefix(base, "orama-deploy-")
|
||||
name = strings.TrimSuffix(name, "-rqlite.service")
|
||||
// Extract namespace name: orama-namespace-rqlite@<name>.service
|
||||
name := strings.TrimPrefix(base, "orama-namespace-rqlite@")
|
||||
name = strings.TrimSuffix(name, ".service")
|
||||
if name == "" || seen[name] {
|
||||
continue
|
||||
}
|
||||
seen[name] = true
|
||||
|
||||
portBase := parsePortBaseFromUnit(path)
|
||||
portBase := parsePortFromEnvFile(name)
|
||||
if portBase > 0 {
|
||||
result = append(result, nsInfo{name: name, portBase: portBase})
|
||||
}
|
||||
@ -69,9 +69,7 @@ func discoverNamespaces() []nsInfo {
|
||||
name := entry.Name()
|
||||
seen[name] = true
|
||||
|
||||
// Try to find the port base from a corresponding service unit.
|
||||
unitPath := fmt.Sprintf("/etc/systemd/system/orama-deploy-%s-rqlite.service", name)
|
||||
portBase := parsePortBaseFromUnit(unitPath)
|
||||
portBase := parsePortFromEnvFile(name)
|
||||
if portBase > 0 {
|
||||
result = append(result, nsInfo{name: name, portBase: portBase})
|
||||
}
|
||||
@ -81,58 +79,21 @@ func discoverNamespaces() []nsInfo {
|
||||
return result
|
||||
}
|
||||
|
||||
// parsePortBaseFromUnit reads a systemd unit file and extracts the port base
|
||||
// from ExecStart arguments or environment variables.
|
||||
//
|
||||
// It looks for patterns like:
|
||||
// - "-http-addr localhost:PORT" or "-http-addr 0.0.0.0:PORT" in ExecStart
|
||||
// - "PORT_BASE=NNNN" in environment files
|
||||
// - Any port number that appears to be the RQLite HTTP port (the base port)
|
||||
func parsePortBaseFromUnit(unitPath string) int {
|
||||
data, err := os.ReadFile(unitPath)
|
||||
// parsePortFromEnvFile reads the RQLite env file for a namespace and extracts
|
||||
// the HTTP port from HTTP_ADDR (e.g. "0.0.0.0:14001").
|
||||
func parsePortFromEnvFile(namespace string) int {
|
||||
envPath := fmt.Sprintf("/opt/orama/.orama/data/namespaces/%s/rqlite.env", namespace)
|
||||
data, err := os.ReadFile(envPath)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
content := string(data)
|
||||
|
||||
// Look for -http-addr with a port number in ExecStart line.
|
||||
httpAddrRe := regexp.MustCompile(`-http-addr\s+\S+:(\d+)`)
|
||||
if m := httpAddrRe.FindStringSubmatch(content); len(m) >= 2 {
|
||||
httpAddrRe := regexp.MustCompile(`HTTP_ADDR=\S+:(\d+)`)
|
||||
if m := httpAddrRe.FindStringSubmatch(string(data)); len(m) >= 2 {
|
||||
if port, err := strconv.Atoi(m[1]); err == nil {
|
||||
return port
|
||||
}
|
||||
}
|
||||
|
||||
// Look for a port in -addr or -http flags.
|
||||
addrRe := regexp.MustCompile(`(?:-addr|-http)\s+\S*:(\d+)`)
|
||||
if m := addrRe.FindStringSubmatch(content); len(m) >= 2 {
|
||||
if port, err := strconv.Atoi(m[1]); err == nil {
|
||||
return port
|
||||
}
|
||||
}
|
||||
|
||||
// Look for PORT_BASE environment variable in EnvironmentFile or Environment= directives.
|
||||
portBaseRe := regexp.MustCompile(`PORT_BASE=(\d+)`)
|
||||
if m := portBaseRe.FindStringSubmatch(content); len(m) >= 2 {
|
||||
if port, err := strconv.Atoi(m[1]); err == nil {
|
||||
return port
|
||||
}
|
||||
}
|
||||
|
||||
// Check referenced EnvironmentFile for PORT_BASE.
|
||||
envFileRe := regexp.MustCompile(`EnvironmentFile=(.+)`)
|
||||
if m := envFileRe.FindStringSubmatch(content); len(m) >= 2 {
|
||||
envPath := strings.TrimSpace(m[1])
|
||||
envPath = strings.TrimPrefix(envPath, "-") // optional prefix means "ignore if missing"
|
||||
if envData, err := os.ReadFile(envPath); err == nil {
|
||||
if m2 := portBaseRe.FindStringSubmatch(string(envData)); len(m2) >= 2 {
|
||||
if port, err := strconv.Atoi(m2[1]); err == nil {
|
||||
return port
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
|
||||
@ -102,6 +102,14 @@ func Handle(jsonFlag bool, version string) error {
|
||||
rpt.Namespaces = collectNamespaces()
|
||||
})
|
||||
|
||||
safeGo(&wg, "deployments", func() {
|
||||
rpt.Deployments = collectDeployments()
|
||||
})
|
||||
|
||||
safeGo(&wg, "serverless", func() {
|
||||
rpt.Serverless = collectServerless()
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Populate top-level WireGuard IP from the WireGuard collector result.
|
||||
|
||||
@ -181,10 +181,10 @@ func collectFailedUnits(ctx context.Context) []string {
|
||||
return units
|
||||
}
|
||||
|
||||
// discoverNamespaceServices finds orama-deploy-*.service files in /etc/systemd/system
|
||||
// discoverNamespaceServices finds orama-namespace-*@*.service files in /etc/systemd/system
|
||||
// and returns the service names (without the .service suffix path).
|
||||
func discoverNamespaceServices() []string {
|
||||
matches, err := filepath.Glob("/etc/systemd/system/orama-deploy-*.service")
|
||||
matches, err := filepath.Glob("/etc/systemd/system/orama-namespace-*@*.service")
|
||||
if err != nil || len(matches) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -23,7 +23,9 @@ type NodeReport struct {
|
||||
Anyone *AnyoneReport `json:"anyone,omitempty"`
|
||||
Network *NetworkReport `json:"network"`
|
||||
Processes *ProcessReport `json:"processes"`
|
||||
Namespaces []NamespaceReport `json:"namespaces,omitempty"`
|
||||
Namespaces []NamespaceReport `json:"namespaces,omitempty"`
|
||||
Deployments *DeploymentsReport `json:"deployments,omitempty"`
|
||||
Serverless *ServerlessReport `json:"serverless,omitempty"`
|
||||
}
|
||||
|
||||
// --- System ---
|
||||
@ -273,3 +275,19 @@ type NamespaceReport struct {
|
||||
GatewayUp bool `json:"gateway_up"`
|
||||
GatewayStatus int `json:"gateway_status,omitempty"`
|
||||
}
|
||||
|
||||
// --- Deployments ---
|
||||
|
||||
type DeploymentsReport struct {
|
||||
TotalCount int `json:"total_count"`
|
||||
RunningCount int `json:"running_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
StaticCount int `json:"static_count"`
|
||||
}
|
||||
|
||||
// --- Serverless ---
|
||||
|
||||
type ServerlessReport struct {
|
||||
FunctionCount int `json:"function_count"`
|
||||
EngineStatus string `json:"engine_status"`
|
||||
}
|
||||
|
||||
@ -387,6 +387,23 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
|
||||
}
|
||||
}
|
||||
|
||||
// Create WASM engine configuration (needed before secrets manager)
|
||||
engineCfg := serverless.DefaultConfig()
|
||||
engineCfg.DefaultMemoryLimitMB = 128
|
||||
engineCfg.MaxMemoryLimitMB = 256
|
||||
engineCfg.DefaultTimeoutSeconds = 30
|
||||
engineCfg.MaxTimeoutSeconds = 60
|
||||
engineCfg.ModuleCacheSize = 100
|
||||
|
||||
// Create secrets manager for serverless functions (AES-256-GCM encrypted)
|
||||
var secretsMgr serverless.SecretsManager
|
||||
if smImpl, secretsErr := hostfunctions.NewDBSecretsManager(deps.ORMClient, engineCfg.SecretsEncryptionKey, logger.Logger); secretsErr != nil {
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "Failed to initialize secrets manager; get_secret will be unavailable",
|
||||
zap.Error(secretsErr))
|
||||
} else {
|
||||
secretsMgr = smImpl
|
||||
}
|
||||
|
||||
// Create host functions provider (allows functions to call Orama services)
|
||||
hostFuncsCfg := hostfunctions.HostFunctionsConfig{
|
||||
IPFSAPIURL: cfg.IPFSAPIURL,
|
||||
@ -398,21 +415,17 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
|
||||
deps.IPFSClient,
|
||||
pubsubAdapter, // pubsub adapter for serverless functions
|
||||
deps.ServerlessWSMgr,
|
||||
nil, // secrets manager - TODO: implement
|
||||
secretsMgr,
|
||||
hostFuncsCfg,
|
||||
logger.Logger,
|
||||
)
|
||||
|
||||
// Create WASM engine configuration
|
||||
engineCfg := serverless.DefaultConfig()
|
||||
engineCfg.DefaultMemoryLimitMB = 128
|
||||
engineCfg.MaxMemoryLimitMB = 256
|
||||
engineCfg.DefaultTimeoutSeconds = 30
|
||||
engineCfg.MaxTimeoutSeconds = 60
|
||||
engineCfg.ModuleCacheSize = 100
|
||||
|
||||
// Create WASM engine
|
||||
engine, err := serverless.NewEngine(engineCfg, registry, hostFuncs, logger.Logger, serverless.WithInvocationLogger(registry))
|
||||
// Create WASM engine with rate limiter
|
||||
rateLimiter := serverless.NewTokenBucketLimiter(engineCfg.GlobalRateLimitPerMinute)
|
||||
engine, err := serverless.NewEngine(engineCfg, registry, hostFuncs, logger.Logger,
|
||||
serverless.WithInvocationLogger(registry),
|
||||
serverless.WithRateLimiter(rateLimiter),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize serverless engine: %w", err)
|
||||
}
|
||||
|
||||
@ -264,6 +264,27 @@ func (h *Handlers) PhantomCompleteHandler(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger namespace cluster provisioning if needed (for non-default namespaces)
|
||||
if h.clusterProvisioner != nil && namespace != "default" {
|
||||
_, _, needsProvisioning, checkErr := h.clusterProvisioner.CheckNamespaceCluster(ctx, namespace)
|
||||
if checkErr != nil {
|
||||
_ = checkErr // Log but don't fail auth
|
||||
} else if needsProvisioning {
|
||||
nsIDInt := 0
|
||||
if id, ok := nsID.(int); ok {
|
||||
nsIDInt = id
|
||||
} else if id, ok := nsID.(int64); ok {
|
||||
nsIDInt = int(id)
|
||||
} else if id, ok := nsID.(float64); ok {
|
||||
nsIDInt = int(id)
|
||||
}
|
||||
_, _, provErr := h.clusterProvisioner.ProvisionNamespaceCluster(ctx, nsIDInt, namespace, req.Wallet)
|
||||
if provErr != nil {
|
||||
_ = provErr // Log but don't fail auth — provisioning is async
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Issue API key
|
||||
apiKey, err := h.authService.GetOrCreateAPIKey(ctx, req.Wallet, namespace)
|
||||
if err != nil {
|
||||
|
||||
@ -9,11 +9,13 @@ import (
|
||||
|
||||
// VerifyHandler verifies a wallet signature and issues JWT tokens and an API key.
|
||||
// This completes the authentication flow by validating the signed nonce and returning
|
||||
// access credentials.
|
||||
// access credentials. For non-default namespaces, may trigger cluster provisioning
|
||||
// and return 202 Accepted with credentials + poll URL.
|
||||
//
|
||||
// POST /v1/auth/verify
|
||||
// Request body: VerifyRequest
|
||||
// Response: { "access_token", "token_type", "expires_in", "refresh_token", "subject", "namespace", "api_key", "nonce", "signature_verified" }
|
||||
// Response 200: { "access_token", "token_type", "expires_in", "refresh_token", "subject", "namespace", "api_key", "nonce", "signature_verified" }
|
||||
// Response 202: { "status": "provisioning", "cluster_id", "poll_url", "access_token", "refresh_token", "api_key", ... }
|
||||
func (h *Handlers) VerifyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if h.authService == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "auth service not initialized")
|
||||
@ -46,6 +48,70 @@ func (h *Handlers) VerifyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
nsID, _ := h.resolveNamespace(ctx, req.Namespace)
|
||||
h.markNonceUsed(ctx, nsID, strings.ToLower(req.Wallet), req.Nonce)
|
||||
|
||||
// Check if namespace cluster provisioning is needed (for non-default namespaces)
|
||||
namespace := strings.TrimSpace(req.Namespace)
|
||||
if namespace == "" {
|
||||
namespace = "default"
|
||||
}
|
||||
|
||||
if h.clusterProvisioner != nil && namespace != "default" {
|
||||
clusterID, status, needsProvisioning, checkErr := h.clusterProvisioner.CheckNamespaceCluster(ctx, namespace)
|
||||
if checkErr != nil {
|
||||
_ = checkErr // Log but don't fail
|
||||
} else if needsProvisioning || status == "provisioning" {
|
||||
// Issue tokens and API key before returning provisioning status
|
||||
token, refresh, expUnix, tokenErr := h.authService.IssueTokens(ctx, req.Wallet, req.Namespace)
|
||||
if tokenErr != nil {
|
||||
writeError(w, http.StatusInternalServerError, tokenErr.Error())
|
||||
return
|
||||
}
|
||||
apiKey, keyErr := h.authService.GetOrCreateAPIKey(ctx, req.Wallet, req.Namespace)
|
||||
if keyErr != nil {
|
||||
writeError(w, http.StatusInternalServerError, keyErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
pollURL := ""
|
||||
if needsProvisioning {
|
||||
nsIDInt := 0
|
||||
if id, ok := nsID.(int); ok {
|
||||
nsIDInt = id
|
||||
} else if id, ok := nsID.(int64); ok {
|
||||
nsIDInt = int(id)
|
||||
} else if id, ok := nsID.(float64); ok {
|
||||
nsIDInt = int(id)
|
||||
}
|
||||
|
||||
newClusterID, newPollURL, provErr := h.clusterProvisioner.ProvisionNamespaceCluster(ctx, nsIDInt, namespace, req.Wallet)
|
||||
if provErr != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to start cluster provisioning")
|
||||
return
|
||||
}
|
||||
clusterID = newClusterID
|
||||
pollURL = newPollURL
|
||||
} else {
|
||||
pollURL = "/v1/namespace/status?id=" + clusterID
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusAccepted, map[string]any{
|
||||
"status": "provisioning",
|
||||
"cluster_id": clusterID,
|
||||
"poll_url": pollURL,
|
||||
"estimated_time_seconds": 60,
|
||||
"access_token": token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": int(expUnix - time.Now().Unix()),
|
||||
"refresh_token": refresh,
|
||||
"api_key": apiKey,
|
||||
"namespace": req.Namespace,
|
||||
"subject": req.Wallet,
|
||||
"nonce": req.Nonce,
|
||||
"signature_verified": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
token, refresh, expUnix, err := h.authService.IssueTokens(ctx, req.Wallet, req.Namespace)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
|
||||
@ -120,7 +120,12 @@ func (h *GoHandler) HandleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Create DNS records (use background context since HTTP context will be cancelled)
|
||||
go h.service.CreateDNSRecords(context.Background(), deployment)
|
||||
go func() {
|
||||
if err := h.service.CreateDNSRecords(context.Background(), deployment); err != nil {
|
||||
h.logger.Error("Background DNS creation failed",
|
||||
zap.String("deployment", deployment.Name), zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
// Build response
|
||||
urls := h.service.BuildDeploymentURLs(deployment)
|
||||
|
||||
@ -244,11 +244,15 @@ func (h *ListHandler) HandleDelete(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// 4. Delete subdomain registry
|
||||
subdomainQuery := `DELETE FROM global_deployment_subdomains WHERE deployment_id = ?`
|
||||
_, _ = h.service.db.Exec(ctx, subdomainQuery, deployment.ID)
|
||||
if _, subErr := h.service.db.Exec(ctx, subdomainQuery, deployment.ID); subErr != nil {
|
||||
h.logger.Warn("Failed to delete subdomain registry", zap.String("id", deployment.ID), zap.Error(subErr))
|
||||
}
|
||||
|
||||
// 5. Delete DNS records
|
||||
dnsQuery := `DELETE FROM dns_records WHERE deployment_id = ?`
|
||||
_, _ = h.service.db.Exec(ctx, dnsQuery, deployment.ID)
|
||||
if _, dnsErr := h.service.db.Exec(ctx, dnsQuery, deployment.ID); dnsErr != nil {
|
||||
h.logger.Warn("Failed to delete DNS records", zap.String("id", deployment.ID), zap.Error(dnsErr))
|
||||
}
|
||||
|
||||
// 6. Delete deployment record
|
||||
query := `DELETE FROM deployments WHERE namespace = ? AND name = ?`
|
||||
|
||||
@ -126,7 +126,12 @@ func (h *NextJSHandler) HandleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Create DNS records (use background context since HTTP context will be cancelled)
|
||||
go h.service.CreateDNSRecords(context.Background(), deployment)
|
||||
go func() {
|
||||
if err := h.service.CreateDNSRecords(context.Background(), deployment); err != nil {
|
||||
h.logger.Error("Background DNS creation failed",
|
||||
zap.String("deployment", deployment.Name), zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
// Build response
|
||||
urls := h.service.BuildDeploymentURLs(deployment)
|
||||
|
||||
@ -112,7 +112,12 @@ func (h *NodeJSHandler) HandleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Create DNS records (use background context since HTTP context will be cancelled)
|
||||
go h.service.CreateDNSRecords(context.Background(), deployment)
|
||||
go func() {
|
||||
if err := h.service.CreateDNSRecords(context.Background(), deployment); err != nil {
|
||||
h.logger.Error("Background DNS creation failed",
|
||||
zap.String("deployment", deployment.Name), zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
// Build response
|
||||
urls := h.service.BuildDeploymentURLs(deployment)
|
||||
|
||||
@ -209,32 +209,43 @@ func (s *DeploymentService) CreateDeployment(ctx context.Context, deployment *de
|
||||
return fmt.Errorf("failed to marshal environment: %w", err)
|
||||
}
|
||||
|
||||
// Insert deployment
|
||||
query := `
|
||||
INSERT INTO deployments (
|
||||
id, namespace, name, type, version, status,
|
||||
content_cid, build_cid, home_node_id, port, subdomain, environment,
|
||||
memory_limit_mb, cpu_limit_percent, disk_limit_mb,
|
||||
health_check_path, health_check_interval, restart_policy, max_restart_count,
|
||||
created_at, updated_at, deployed_by
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err = s.db.Exec(ctx, query,
|
||||
deployment.ID, deployment.Namespace, deployment.Name, deployment.Type, deployment.Version, deployment.Status,
|
||||
deployment.ContentCID, deployment.BuildCID, deployment.HomeNodeID, deployment.Port, deployment.Subdomain, string(envJSON),
|
||||
deployment.MemoryLimitMB, deployment.CPULimitPercent, deployment.DiskLimitMB,
|
||||
deployment.HealthCheckPath, deployment.HealthCheckInterval, deployment.RestartPolicy, deployment.MaxRestartCount,
|
||||
deployment.CreatedAt, deployment.UpdatedAt, deployment.DeployedBy,
|
||||
)
|
||||
// Insert deployment + record history in a single transaction
|
||||
err = s.db.Tx(ctx, func(tx rqlite.Tx) error {
|
||||
insertQuery := `
|
||||
INSERT INTO deployments (
|
||||
id, namespace, name, type, version, status,
|
||||
content_cid, build_cid, home_node_id, port, subdomain, environment,
|
||||
memory_limit_mb, cpu_limit_percent, disk_limit_mb,
|
||||
health_check_path, health_check_interval, restart_policy, max_restart_count,
|
||||
created_at, updated_at, deployed_by
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
_, insertErr := tx.Exec(ctx, insertQuery,
|
||||
deployment.ID, deployment.Namespace, deployment.Name, deployment.Type, deployment.Version, deployment.Status,
|
||||
deployment.ContentCID, deployment.BuildCID, deployment.HomeNodeID, deployment.Port, deployment.Subdomain, string(envJSON),
|
||||
deployment.MemoryLimitMB, deployment.CPULimitPercent, deployment.DiskLimitMB,
|
||||
deployment.HealthCheckPath, deployment.HealthCheckInterval, deployment.RestartPolicy, deployment.MaxRestartCount,
|
||||
deployment.CreatedAt, deployment.UpdatedAt, deployment.DeployedBy,
|
||||
)
|
||||
if insertErr != nil {
|
||||
return insertErr
|
||||
}
|
||||
|
||||
historyQuery := `
|
||||
INSERT INTO deployment_history (id, deployment_id, version, content_cid, build_cid, deployed_at, deployed_by, status)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
_, histErr := tx.Exec(ctx, historyQuery,
|
||||
uuid.New().String(), deployment.ID, deployment.Version,
|
||||
deployment.ContentCID, deployment.BuildCID,
|
||||
time.Now(), deployment.DeployedBy, "deployed",
|
||||
)
|
||||
return histErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert deployment: %w", err)
|
||||
}
|
||||
|
||||
// Record in history
|
||||
s.recordHistory(ctx, deployment, "deployed")
|
||||
|
||||
// Create replica records
|
||||
if s.replicaManager != nil {
|
||||
s.createDeploymentReplicas(ctx, deployment)
|
||||
|
||||
@ -155,7 +155,12 @@ func (h *StaticDeploymentHandler) HandleUpload(w http.ResponseWriter, r *http.Re
|
||||
}
|
||||
|
||||
// Create DNS records (use background context since HTTP context will be cancelled)
|
||||
go h.service.CreateDNSRecords(context.Background(), deployment)
|
||||
go func() {
|
||||
if err := h.service.CreateDNSRecords(context.Background(), deployment); err != nil {
|
||||
h.logger.Error("Background DNS creation failed",
|
||||
zap.String("deployment", deployment.Name), zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
// Build URLs
|
||||
urls := h.service.BuildDeploymentURLs(deployment)
|
||||
|
||||
@ -139,9 +139,11 @@ func (h *UpdateHandler) updateStatic(ctx context.Context, existing *deployments.
|
||||
|
||||
cid := addResp.Cid
|
||||
|
||||
oldContentCID := existing.ContentCID
|
||||
|
||||
h.logger.Info("New content uploaded",
|
||||
zap.String("deployment", existing.Name),
|
||||
zap.String("old_cid", existing.ContentCID),
|
||||
zap.String("old_cid", oldContentCID),
|
||||
zap.String("new_cid", cid),
|
||||
)
|
||||
|
||||
@ -160,6 +162,13 @@ func (h *UpdateHandler) updateStatic(ctx context.Context, existing *deployments.
|
||||
return nil, fmt.Errorf("failed to update deployment: %w", err)
|
||||
}
|
||||
|
||||
// Unpin old IPFS content (best-effort)
|
||||
if oldContentCID != "" && oldContentCID != cid {
|
||||
if unpinErr := h.staticHandler.ipfsClient.Unpin(ctx, oldContentCID); unpinErr != nil {
|
||||
h.logger.Warn("Failed to unpin old content CID", zap.String("cid", oldContentCID), zap.Error(unpinErr))
|
||||
}
|
||||
}
|
||||
|
||||
// Record in history
|
||||
h.service.recordHistory(ctx, existing, "updated")
|
||||
|
||||
@ -193,9 +202,11 @@ func (h *UpdateHandler) updateDynamic(ctx context.Context, existing *deployments
|
||||
|
||||
cid := addResp.Cid
|
||||
|
||||
oldBuildCID := existing.BuildCID
|
||||
|
||||
h.logger.Info("New build uploaded",
|
||||
zap.String("deployment", existing.Name),
|
||||
zap.String("old_cid", existing.BuildCID),
|
||||
zap.String("old_cid", oldBuildCID),
|
||||
zap.String("new_cid", cid),
|
||||
)
|
||||
|
||||
@ -264,6 +275,13 @@ func (h *UpdateHandler) updateDynamic(ctx context.Context, existing *deployments
|
||||
// Cleanup old
|
||||
removeDirectory(oldPath)
|
||||
|
||||
// Unpin old IPFS build (best-effort)
|
||||
if oldBuildCID != "" && oldBuildCID != cid {
|
||||
if unpinErr := h.nextjsHandler.ipfsClient.Unpin(ctx, oldBuildCID); unpinErr != nil {
|
||||
h.logger.Warn("Failed to unpin old build CID", zap.String("cid", oldBuildCID), zap.Error(unpinErr))
|
||||
}
|
||||
}
|
||||
|
||||
existing.BuildCID = cid
|
||||
existing.Version = newVersion
|
||||
existing.UpdatedAt = now
|
||||
|
||||
@ -41,6 +41,7 @@ type ClusterManager struct {
|
||||
portAllocator *NamespacePortAllocator
|
||||
nodeSelector *ClusterNodeSelector
|
||||
systemdSpawner *SystemdSpawner // NEW: Systemd-based spawner replaces old spawners
|
||||
dnsManager *DNSRecordManager
|
||||
logger *zap.Logger
|
||||
baseDomain string
|
||||
baseDataDir string
|
||||
@ -70,6 +71,7 @@ func NewClusterManager(
|
||||
portAllocator := NewNamespacePortAllocator(db, logger)
|
||||
nodeSelector := NewClusterNodeSelector(db, portAllocator, logger)
|
||||
systemdSpawner := NewSystemdSpawner(cfg.BaseDataDir, logger)
|
||||
dnsManager := NewDNSRecordManager(db, cfg.BaseDomain, logger)
|
||||
|
||||
// Set IPFS defaults
|
||||
ipfsClusterAPIURL := cfg.IPFSClusterAPIURL
|
||||
@ -94,6 +96,7 @@ func NewClusterManager(
|
||||
portAllocator: portAllocator,
|
||||
nodeSelector: nodeSelector,
|
||||
systemdSpawner: systemdSpawner,
|
||||
dnsManager: dnsManager,
|
||||
baseDomain: cfg.BaseDomain,
|
||||
baseDataDir: cfg.BaseDataDir,
|
||||
globalRQLiteDSN: cfg.GlobalRQLiteDSN,
|
||||
@ -138,6 +141,7 @@ func NewClusterManagerWithComponents(
|
||||
portAllocator: portAllocator,
|
||||
nodeSelector: nodeSelector,
|
||||
systemdSpawner: systemdSpawner,
|
||||
dnsManager: NewDNSRecordManager(db, cfg.BaseDomain, logger),
|
||||
baseDomain: cfg.BaseDomain,
|
||||
baseDataDir: cfg.BaseDataDir,
|
||||
globalRQLiteDSN: cfg.GlobalRQLiteDSN,
|
||||
@ -757,10 +761,8 @@ func (cm *ClusterManager) sendStopRequest(ctx context.Context, nodeIP, action, n
|
||||
}
|
||||
|
||||
// createDNSRecords creates DNS records for the namespace gateway.
|
||||
// Creates A records pointing to the public IPs of nodes running the namespace gateway cluster.
|
||||
// Creates A records (+ wildcards) pointing to the public IPs of nodes running the namespace gateway cluster.
|
||||
func (cm *ClusterManager) createDNSRecords(ctx context.Context, cluster *NamespaceCluster, nodes []NodeCapacity, portBlocks []*PortBlock) error {
|
||||
fqdn := fmt.Sprintf("ns-%s.%s.", cluster.NamespaceName, cm.baseDomain)
|
||||
|
||||
// Collect public IPs from the selected cluster nodes
|
||||
var gatewayIPs []string
|
||||
for _, node := range nodes {
|
||||
@ -777,34 +779,12 @@ func (cm *ClusterManager) createDNSRecords(ctx context.Context, cluster *Namespa
|
||||
return fmt.Errorf("no valid node IPs found for DNS records")
|
||||
}
|
||||
|
||||
cm.logger.Info("Creating DNS records for namespace gateway",
|
||||
zap.String("namespace", cluster.NamespaceName),
|
||||
zap.Strings("ips", gatewayIPs),
|
||||
)
|
||||
|
||||
recordCount := 0
|
||||
for _, ip := range gatewayIPs {
|
||||
query := `
|
||||
INSERT INTO dns_records (fqdn, record_type, value, ttl, namespace, created_by)
|
||||
VALUES (?, 'A', ?, 300, ?, 'system')
|
||||
`
|
||||
_, err := cm.db.Exec(ctx, query, fqdn, ip, cluster.NamespaceName)
|
||||
if err != nil {
|
||||
cm.logger.Warn("Failed to create DNS record",
|
||||
zap.String("fqdn", fqdn),
|
||||
zap.String("ip", ip),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
cm.logger.Info("Created DNS A record for gateway node",
|
||||
zap.String("fqdn", fqdn),
|
||||
zap.String("ip", ip),
|
||||
)
|
||||
recordCount++
|
||||
}
|
||||
if err := cm.dnsManager.CreateNamespaceRecords(ctx, cluster.NamespaceName, gatewayIPs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cm.logEvent(ctx, cluster.ID, EventDNSCreated, "", fmt.Sprintf("DNS records created for %s (%d gateway node records)", fqdn, recordCount), nil)
|
||||
fqdn := fmt.Sprintf("ns-%s.%s.", cluster.NamespaceName, cm.baseDomain)
|
||||
cm.logEvent(ctx, cluster.ID, EventDNSCreated, "", fmt.Sprintf("DNS records created for %s (%d gateway node records)", fqdn, len(gatewayIPs)*2), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -864,12 +844,10 @@ func (cm *ClusterManager) DeprovisionCluster(ctx context.Context, namespaceID in
|
||||
cm.portAllocator.DeallocateAllPortBlocks(ctx, cluster.ID)
|
||||
|
||||
// Delete DNS records
|
||||
query := `DELETE FROM dns_records WHERE namespace = ?`
|
||||
cm.db.Exec(ctx, query, cluster.NamespaceName)
|
||||
cm.dnsManager.DeleteNamespaceRecords(ctx, cluster.NamespaceName)
|
||||
|
||||
// Delete cluster record
|
||||
query = `DELETE FROM namespace_clusters WHERE id = ?`
|
||||
cm.db.Exec(ctx, query, cluster.ID)
|
||||
cm.db.Exec(ctx, `DELETE FROM namespace_clusters WHERE id = ?`, cluster.ID)
|
||||
|
||||
cm.logEvent(ctx, cluster.ID, EventDeprovisioned, "", "Cluster deprovisioned", nil)
|
||||
|
||||
|
||||
@ -33,6 +33,9 @@ type Config struct {
|
||||
TimerPollInterval time.Duration `yaml:"timer_poll_interval"`
|
||||
DBPollInterval time.Duration `yaml:"db_poll_interval"`
|
||||
|
||||
// WASM execution limits
|
||||
MaxConcurrentExecutions int `yaml:"max_concurrent_executions"` // Max concurrent WASM module instantiations
|
||||
|
||||
// WASM compilation cache
|
||||
ModuleCacheSize int `yaml:"module_cache_size"` // Number of compiled modules to cache
|
||||
EnablePrewarm bool `yaml:"enable_prewarm"` // Pre-compile frequently used functions
|
||||
@ -75,6 +78,9 @@ func DefaultConfig() *Config {
|
||||
TimerPollInterval: time.Second,
|
||||
DBPollInterval: time.Second * 5,
|
||||
|
||||
// WASM execution
|
||||
MaxConcurrentExecutions: 10,
|
||||
|
||||
// WASM cache
|
||||
ModuleCacheSize: 100,
|
||||
EnablePrewarm: true,
|
||||
@ -154,6 +160,9 @@ func (c *Config) ApplyDefaults() {
|
||||
if c.DBPollInterval == 0 {
|
||||
c.DBPollInterval = defaults.DBPollInterval
|
||||
}
|
||||
if c.MaxConcurrentExecutions == 0 {
|
||||
c.MaxConcurrentExecutions = defaults.MaxConcurrentExecutions
|
||||
}
|
||||
if c.ModuleCacheSize == 0 {
|
||||
c.ModuleCacheSize = defaults.ModuleCacheSize
|
||||
}
|
||||
|
||||
@ -116,7 +116,7 @@ func NewEngine(cfg *Config, registry FunctionRegistry, hostServices HostServices
|
||||
hostServices: hostServices,
|
||||
logger: logger,
|
||||
moduleCache: cache.NewModuleCache(cfg.ModuleCacheSize, logger),
|
||||
executor: execution.NewExecutor(runtime, logger),
|
||||
executor: execution.NewExecutor(runtime, logger, cfg.MaxConcurrentExecutions),
|
||||
lifecycle: execution.NewModuleLifecycle(runtime, logger),
|
||||
}
|
||||
|
||||
@ -204,6 +204,12 @@ func (e *Engine) Precompile(ctx context.Context, wasmCID string, wasmBytes []byt
|
||||
return &DeployError{FunctionName: wasmCID, Cause: err}
|
||||
}
|
||||
|
||||
// Enforce memory limits
|
||||
if err := e.checkMemoryLimits(compiled); err != nil {
|
||||
compiled.Close(ctx)
|
||||
return &DeployError{FunctionName: wasmCID, Cause: err}
|
||||
}
|
||||
|
||||
// Cache the compiled module
|
||||
e.moduleCache.Set(wasmCID, compiled)
|
||||
|
||||
@ -233,6 +239,19 @@ func (e *Engine) GetCacheStats() (size int, capacity int) {
|
||||
// Private methods
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// checkMemoryLimits validates that a compiled module's memory declarations
|
||||
// don't exceed the configured maximum. Each WASM memory page is 64KB.
|
||||
func (e *Engine) checkMemoryLimits(compiled wazero.CompiledModule) error {
|
||||
maxPages := uint32(e.config.MaxMemoryLimitMB * 16) // 1 MB = 16 pages (64KB each)
|
||||
for _, mem := range compiled.ExportedMemories() {
|
||||
if max, hasMax := mem.Max(); hasMax && max > maxPages {
|
||||
return fmt.Errorf("module declares %d MB max memory, exceeds limit of %d MB",
|
||||
max/16, e.config.MaxMemoryLimitMB)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getOrCompileModule retrieves a compiled module from cache or compiles it.
|
||||
func (e *Engine) getOrCompileModule(ctx context.Context, wasmCID string) (wazero.CompiledModule, error) {
|
||||
return e.moduleCache.GetOrCompute(wasmCID, func() (wazero.CompiledModule, error) {
|
||||
@ -248,6 +267,12 @@ func (e *Engine) getOrCompileModule(ctx context.Context, wasmCID string) (wazero
|
||||
return nil, ErrCompilationFailed
|
||||
}
|
||||
|
||||
// Enforce memory limits
|
||||
if err := e.checkMemoryLimits(compiled); err != nil {
|
||||
compiled.Close(ctx)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return compiled, nil
|
||||
})
|
||||
}
|
||||
|
||||
@ -15,13 +15,20 @@ import (
|
||||
type Executor struct {
|
||||
runtime wazero.Runtime
|
||||
logger *zap.Logger
|
||||
sem chan struct{} // concurrency limiter
|
||||
}
|
||||
|
||||
// NewExecutor creates a new Executor.
|
||||
func NewExecutor(runtime wazero.Runtime, logger *zap.Logger) *Executor {
|
||||
// maxConcurrent limits simultaneous module instantiations (0 = unlimited).
|
||||
func NewExecutor(runtime wazero.Runtime, logger *zap.Logger, maxConcurrent int) *Executor {
|
||||
var sem chan struct{}
|
||||
if maxConcurrent > 0 {
|
||||
sem = make(chan struct{}, maxConcurrent)
|
||||
}
|
||||
return &Executor{
|
||||
runtime: runtime,
|
||||
logger: logger,
|
||||
sem: sem,
|
||||
}
|
||||
}
|
||||
|
||||
@ -49,6 +56,16 @@ func (e *Executor) ExecuteModule(ctx context.Context, compiled wazero.CompiledMo
|
||||
WithStderr(stderr).
|
||||
WithArgs(moduleName) // argv[0] is the program name
|
||||
|
||||
// Acquire concurrency slot
|
||||
if e.sem != nil {
|
||||
select {
|
||||
case e.sem <- struct{}{}:
|
||||
defer func() { <-e.sem }()
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiate and run the module (WASI _start will be called automatically)
|
||||
instance, err := e.runtime.InstantiateModule(ctx, compiled, moduleConfig)
|
||||
if err != nil {
|
||||
|
||||
51
pkg/serverless/ratelimit.go
Normal file
51
pkg/serverless/ratelimit.go
Normal file
@ -0,0 +1,51 @@
|
||||
package serverless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenBucketLimiter implements RateLimiter using a token bucket algorithm.
|
||||
type TokenBucketLimiter struct {
|
||||
mu sync.Mutex
|
||||
tokens float64
|
||||
max float64
|
||||
refill float64 // tokens per second
|
||||
lastTime time.Time
|
||||
}
|
||||
|
||||
// NewTokenBucketLimiter creates a rate limiter with the given per-minute limit.
|
||||
func NewTokenBucketLimiter(perMinute int) *TokenBucketLimiter {
|
||||
perSecond := float64(perMinute) / 60.0
|
||||
return &TokenBucketLimiter{
|
||||
tokens: float64(perMinute), // start full
|
||||
max: float64(perMinute),
|
||||
refill: perSecond,
|
||||
lastTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request should be allowed. Returns true if allowed.
|
||||
func (t *TokenBucketLimiter) Allow(_ context.Context, _ string) (bool, error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(t.lastTime).Seconds()
|
||||
t.lastTime = now
|
||||
|
||||
// Refill tokens
|
||||
t.tokens += elapsed * t.refill
|
||||
if t.tokens > t.max {
|
||||
t.tokens = t.max
|
||||
}
|
||||
|
||||
// Check if we have a token
|
||||
if t.tokens < 1.0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
t.tokens--
|
||||
return true, nil
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user