diff --git a/CHANGELOG.md b/CHANGELOG.md index 82534f5..664f1e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,12 +8,20 @@ The format is based on [Keep a Changelog][keepachangelog] and adheres to [Semant ### Added +- Created new rqlite folder +- Created rqlite adapter, client, gateway, migrations and rqlite init + ### Changed +- Updated node.go to support new rqlite architecture +- Updated readme + ### Deprecated ### Removed +- Removed old storage folder + ### Fixed ### Security diff --git a/Makefile b/Makefile index e3a5b18..2343443 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,7 @@ test-e2e: .PHONY: build clean test run-node run-node2 run-node3 run-example deps tidy fmt vet lint clear-ports -VERSION := 0.44.0-beta +VERSION := 0.50.0-beta COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo unknown) DATE ?= $(shell date -u +%Y-%m-%dT%H:%M:%SZ) LDFLAGS := -X 'main.version=$(VERSION)' -X 'main.commit=$(COMMIT)' -X 'main.date=$(DATE)' diff --git a/README.md b/README.md index 7eabb77..a0f038b 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ A robust, decentralized peer-to-peer network built in Go, providing distributed - [CLI Usage](#cli-usage) - [HTTP Gateway](#http-gateway) - [Development](#development) +- [Database Client (Go ORM-like)](#database-client-go-orm-like) - [Troubleshooting](#troubleshooting) - [License](#license) @@ -700,6 +701,242 @@ make clean # Clean build artifacts scripts/test-multinode.sh ``` +--- + +## Database Client (Go ORM-like) + +A lightweight ORM-like client over rqlite using Go’s `database/sql`. It provides: +- Query/Exec for raw SQL +- A fluent QueryBuilder (`Where`, `InnerJoin`, `LeftJoin`, `OrderBy`, `GroupBy`, `Limit`, `Offset`) +- Simple repositories with `Find`/`FindOne` +- `Save`/`Remove` for entities with primary keys +- Transaction support via `Tx` + +### Installation + +- Ensure rqlite is running (the node starts and manages rqlite automatically). +- Import the client: + - Package: `github.com/DeBrosOfficial/network/pkg/rqlite` + +### Quick Start + +```go +package main + +import ( + "context" + "database/sql" + "time" + + "github.com/DeBrosOfficial/network/pkg/rqlite" + _ "github.com/rqlite/gorqlite/stdlib" +) + +type User struct { + ID int64 `db:"id,pk,auto"` + Email string `db:"email"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + CreatedAt time.Time `db:"created_at"` +} + +func (User) TableName() string { return "users" } + +func main() { + ctx := context.Background() + + adapter, _ := rqlite.NewRQLiteAdapter(manager) + client := rqlite.NewClientFromAdapter(adapter) + + // Save (INSERT) + u := &User{Email: "alice@example.com", FirstName: "Alice", LastName: "A"} + _ = client.Save(ctx, u) // auto-sets u.ID if autoincrement is available + + // FindOneBy + var one User + _ = client.FindOneBy(ctx, &one, "users", map[string]any{"email": "alice@example.com"}) + + // QueryBuilder + var users []User + _ = client.CreateQueryBuilder("users"). + Where("email LIKE ?", "%@example.com"). + OrderBy("created_at DESC"). + Limit(10). + GetMany(ctx, &users) +} + +### Entities and Mapping + +- Use struct tags: `db:"column_name"`; the first tag value is the column name. +- Mark primary key: `db:"id,pk"` (and `auto` if autoincrement): `db:"id,pk,auto"`. +- Fallbacks: + - If no `db` tag is provided, the field name is used as the column (case-insensitive). + - If a field is named `ID`, it is treated as the primary key by default. + +```go +type Post struct { + ID int64 `db:"id,pk,auto"` + UserID int64 `db:"user_id"` + Title string `db:"title"` + Body string `db:"body"` + CreatedAt time.Time `db:"created_at"` +} +func (Post) TableName() string { return "posts" } +``` + +### Basic queries + +Raw SQL with scanning into structs or maps: + +```go +var users []User +err := client.Query(ctx, &users, "SELECT id, email, first_name, last_name, created_at FROM users WHERE email LIKE ?", "%@example.com") +if err != nil { + // handle +} + +var rows []map[string]any +_ = client.Query(ctx, &rows, "SELECT id, email FROM users WHERE id IN (?,?)", 1, 2) +``` + +### Query Buider + +Build complex SELECTs with joins, filters, grouping, ordering, and pagination. + +```go +var results []User +qb := client.CreateQueryBuilder("users u"). + InnerJoin("posts p", "p.user_id = u.id"). + Where("u.email LIKE ?", "%@example.com"). + AndWhere("p.created_at >= ?", "2024-01-01T00:00:00Z"). + GroupBy("u.id"). + OrderBy("u.created_at DESC"). + Limit(20). + Offset(0) + +if err := qb.GetMany(ctx, &results); err != nil { + // handle +} + +// Single row (LIMIT 1) +var one User +if err := qb.Limit(1).GetOne(ctx, &one); err != nil { + // handle sql.ErrNoRows, etc. +} +``` + +### FindBy / FindOneBy + +Simple map-based filters: + +```go +var active []User +_ = client.FindBy(ctx, &active, "users", map[string]any{"last_name": "A"}, rqlite.WithOrderBy("created_at DESC"), rqlite.WithLimit(50)) + +var u User +if err := client.FindOneBy(ctx, &u, "users", map[string]any{"email": "alice@example.com"}); err != nil { + // sql.ErrNoRows if not found +} +``` + +### Save / Remove + +`Save` inserts if PK is zero, otherwise updates by PK. +`Remove` deletes by PK. + +```go +// Insert (ID is zero) +u := &User{Email: "bob@example.com", FirstName: "Bob"} +_ = client.Save(ctx, u) // INSERT; sets u.ID if autoincrement + +// Update (ID is non-zero) +u.FirstName = "Bobby" +_ = client.Save(ctx, u) // UPDATE ... WHERE id = ? + +// Remove +_ = client.Remove(ctx, u) // DELETE ... WHERE id = ? + +``` + +### transactions + +Run multiple operations atomically. If your function returns an error, the transaction is rolled back; otherwise it commits. + +```go +err := client.Tx(ctx, func(tx rqlite.Tx) error { + // Read inside the same transaction + var me User + if err := tx.Query(ctx, &me, "SELECT * FROM users WHERE id = ?", 1); err != nil { + return err + } + + // Write inside the same transaction + me.LastName = "Updated" + if err := tx.Save(ctx, &me); err != nil { + return err + } + + // Complex query via builder + var recent []User + if err := tx.CreateQueryBuilder("users"). + OrderBy("created_at DESC"). + Limit(5). + GetMany(ctx, &recent); err != nil { + return err + } + + return nil // commit +}) + +``` + +### Repositories (optional, generic) + +Strongly-typed convenience layer bound to a table: + +```go +repo := client.Repository[User]("users") + +var many []User +_ = repo.Find(ctx, &many, map[string]any{"last_name": "A"}, rqlite.WithOrderBy("created_at DESC"), rqlite.WithLimit(10)) + +var one User +_ = repo.FindOne(ctx, &one, map[string]any{"email": "alice@example.com"}) + +_ = repo.Save(ctx, &one) +_ = repo.Remove(ctx, &one) + +``` + +### Migrations + +Option A: From the node (after rqlite is ready) + +```go +ctx := context.Background() +dirs := []string{ + "network/migrations", // default + "path/to/your/app/migrations", // extra +} + +if err := rqliteManager.ApplyMigrationsDirs(ctx, dirs); err != nil { + logger.Fatal("apply migrations failed", zap.Error(err)) +} +``` + +Option B: Using the adapter sql.DB + +```go +ctx := context.Background() +db := adapter.GetSQLDB() +dirs := []string{"network/migrations", "app/migrations"} + +if err := rqlite.ApplyMigrationsDirs(ctx, db, dirs, logger); err != nil { + logger.Fatal("apply migrations failed", zap.Error(err)) +} +``` + + --- ## Troubleshooting diff --git a/pkg/node/node.go b/pkg/node/node.go index b41feca..78fc0ab 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -23,9 +23,9 @@ import ( "go.uber.org/zap" "github.com/DeBrosOfficial/network/pkg/config" - "github.com/DeBrosOfficial/network/pkg/database" "github.com/DeBrosOfficial/network/pkg/logging" "github.com/DeBrosOfficial/network/pkg/pubsub" + database "github.com/DeBrosOfficial/network/pkg/rqlite" ) // Node represents a network node with RQLite database diff --git a/pkg/database/adapter.go b/pkg/rqlite/adapter.go similarity index 98% rename from pkg/database/adapter.go rename to pkg/rqlite/adapter.go index 1aa2686..81205bf 100644 --- a/pkg/database/adapter.go +++ b/pkg/rqlite/adapter.go @@ -1,4 +1,4 @@ -package database +package rqlite import ( "database/sql" diff --git a/pkg/rqlite/client.go b/pkg/rqlite/client.go new file mode 100644 index 0000000..70c78e2 --- /dev/null +++ b/pkg/rqlite/client.go @@ -0,0 +1,835 @@ +package rqlite + +// client.go defines the ORM-like interfaces and a minimal implementation over database/sql. +// It builds on the rqlite stdlib driver so it behaves like a regular SQL-backed ORM. + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "time" +) + +// TableNamer lets a struct provide its table name. +type TableNamer interface { + TableName() string +} + +// Client is the high-level ORM-like API. +type Client interface { + // Query runs an arbitrary SELECT and scans rows into dest (pointer to slice of structs or []map[string]any). + Query(ctx context.Context, dest any, query string, args ...any) error + // Exec runs a write statement (INSERT/UPDATE/DELETE). + Exec(ctx context.Context, query string, args ...any) (sql.Result, error) + + // FindBy/FindOneBy provide simple map-based criteria filtering. + FindBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...FindOption) error + FindOneBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...FindOption) error + + // Save inserts or updates an entity (single-PK). + Save(ctx context.Context, entity any) error + // Remove deletes by PK (single-PK). + Remove(ctx context.Context, entity any) error + + // Repositories (generic layer). Optional but convenient if you use Go generics. + Repository(table string) any + + // Fluent query builder for advanced querying. + CreateQueryBuilder(table string) *QueryBuilder + + // Tx executes a function within a transaction. + Tx(ctx context.Context, fn func(tx Tx) error) error +} + +// Tx mirrors Client but executes within a transaction. +type Tx interface { + Query(ctx context.Context, dest any, query string, args ...any) error + Exec(ctx context.Context, query string, args ...any) (sql.Result, error) + CreateQueryBuilder(table string) *QueryBuilder + + // Optional: scoped Save/Remove inside tx + Save(ctx context.Context, entity any) error + Remove(ctx context.Context, entity any) error +} + +// Repository provides typed entity operations for a table. +type Repository[T any] interface { + Find(ctx context.Context, dest *[]T, criteria map[string]any, opts ...FindOption) error + FindOne(ctx context.Context, dest *T, criteria map[string]any, opts ...FindOption) error + Save(ctx context.Context, entity *T) error + Remove(ctx context.Context, entity *T) error + + // Builder helpers + Q() *QueryBuilder +} + +// NewClient wires the ORM client to a *sql.DB (from your RQLiteAdapter). +func NewClient(db *sql.DB) Client { + return &client{db: db} +} + +// NewClientFromAdapter is convenient if you already created the adapter. +func NewClientFromAdapter(adapter *RQLiteAdapter) Client { + return NewClient(adapter.GetSQLDB()) +} + +// client implements Client over *sql.DB. +type client struct { + db *sql.DB +} + +func (c *client) Query(ctx context.Context, dest any, query string, args ...any) error { + rows, err := c.db.QueryContext(ctx, query, args...) + if err != nil { + return err + } + defer rows.Close() + return scanIntoDest(rows, dest) +} + +func (c *client) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) { + return c.db.ExecContext(ctx, query, args...) +} + +func (c *client) FindBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...FindOption) error { + qb := c.CreateQueryBuilder(table) + for k, v := range criteria { + qb = qb.AndWhere(fmt.Sprintf("%s = ?", k), v) + } + for _, opt := range opts { + opt(qb) + } + return qb.GetMany(ctx, dest) +} + +func (c *client) FindOneBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...FindOption) error { + qb := c.CreateQueryBuilder(table) + for k, v := range criteria { + qb = qb.AndWhere(fmt.Sprintf("%s = ?", k), v) + } + for _, opt := range opts { + opt(qb) + } + return qb.GetOne(ctx, dest) +} + +func (c *client) Save(ctx context.Context, entity any) error { + return saveEntity(ctx, c.db, entity) +} + +func (c *client) Remove(ctx context.Context, entity any) error { + return removeEntity(ctx, c.db, entity) +} + +func (c *client) Repository(table string) any { + // This returns an untyped interface since Go methods cannot have type parameters + // Users will need to type assert the result to Repository[T] + return func() any { + return &repository[any]{c: c, table: table} + }() +} + +func (c *client) CreateQueryBuilder(table string) *QueryBuilder { + return newQueryBuilder(c.db, table) +} + +func (c *client) Tx(ctx context.Context, fn func(tx Tx) error) error { + sqlTx, err := c.db.BeginTx(ctx, nil) + if err != nil { + return err + } + txc := &txClient{tx: sqlTx} + if err := fn(txc); err != nil { + _ = sqlTx.Rollback() + return err + } + return sqlTx.Commit() +} + +// txClient implements Tx over *sql.Tx. +type txClient struct { + tx *sql.Tx +} + +func (t *txClient) Query(ctx context.Context, dest any, query string, args ...any) error { + rows, err := t.tx.QueryContext(ctx, query, args...) + if err != nil { + return err + } + defer rows.Close() + return scanIntoDest(rows, dest) +} + +func (t *txClient) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) { + return t.tx.ExecContext(ctx, query, args...) +} + +func (t *txClient) CreateQueryBuilder(table string) *QueryBuilder { + return newQueryBuilder(t.tx, table) +} + +func (t *txClient) Save(ctx context.Context, entity any) error { + return saveEntity(ctx, t.tx, entity) +} + +func (t *txClient) Remove(ctx context.Context, entity any) error { + return removeEntity(ctx, t.tx, entity) +} + +// executor is implemented by *sql.DB and *sql.Tx. +type executor interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) +} + +// QueryBuilder implements a fluent SELECT builder with joins, where, etc. +type QueryBuilder struct { + exec executor + table string + alias string + selects []string + + joins []joinClause + wheres []whereClause + + groupBys []string + orderBys []string + limit *int + offset *int +} + +// joinClause represents INNER/LEFT/etc joins. +type joinClause struct { + kind string // "INNER", "LEFT", "JOIN" (default) + table string + on string +} + +// whereClause holds an expression and args with a conjunction. +type whereClause struct { + conj string // "AND" or "OR" + expr string + args []any +} + +func newQueryBuilder(exec executor, table string) *QueryBuilder { + return &QueryBuilder{ + exec: exec, + table: table, + } +} + +func (qb *QueryBuilder) Select(cols ...string) *QueryBuilder { + qb.selects = append(qb.selects, cols...) + return qb +} + +func (qb *QueryBuilder) Alias(a string) *QueryBuilder { + qb.alias = a + return qb +} + +func (qb *QueryBuilder) Where(expr string, args ...any) *QueryBuilder { + return qb.AndWhere(expr, args...) +} + +func (qb *QueryBuilder) AndWhere(expr string, args ...any) *QueryBuilder { + qb.wheres = append(qb.wheres, whereClause{conj: "AND", expr: expr, args: args}) + return qb +} + +func (qb *QueryBuilder) OrWhere(expr string, args ...any) *QueryBuilder { + qb.wheres = append(qb.wheres, whereClause{conj: "OR", expr: expr, args: args}) + return qb +} + +func (qb *QueryBuilder) InnerJoin(table string, on string) *QueryBuilder { + qb.joins = append(qb.joins, joinClause{kind: "INNER", table: table, on: on}) + return qb +} + +func (qb *QueryBuilder) LeftJoin(table string, on string) *QueryBuilder { + qb.joins = append(qb.joins, joinClause{kind: "LEFT", table: table, on: on}) + return qb +} + +func (qb *QueryBuilder) Join(table string, on string) *QueryBuilder { + qb.joins = append(qb.joins, joinClause{kind: "JOIN", table: table, on: on}) + return qb +} + +func (qb *QueryBuilder) GroupBy(cols ...string) *QueryBuilder { + qb.groupBys = append(qb.groupBys, cols...) + return qb +} + +func (qb *QueryBuilder) OrderBy(exprs ...string) *QueryBuilder { + qb.orderBys = append(qb.orderBys, exprs...) + return qb +} + +func (qb *QueryBuilder) Limit(n int) *QueryBuilder { + qb.limit = &n + return qb +} + +func (qb *QueryBuilder) Offset(n int) *QueryBuilder { + qb.offset = &n + return qb +} + +// Build returns the SQL string and args for a SELECT. +func (qb *QueryBuilder) Build() (string, []any) { + cols := "*" + if len(qb.selects) > 0 { + cols = strings.Join(qb.selects, ", ") + } + base := fmt.Sprintf("SELECT %s FROM %s", cols, qb.table) + if qb.alias != "" { + base += " AS " + qb.alias + } + + args := make([]any, 0, 16) + for _, j := range qb.joins { + base += fmt.Sprintf(" %s JOIN %s ON %s", j.kind, j.table, j.on) + } + + if len(qb.wheres) > 0 { + base += " WHERE " + for i, w := range qb.wheres { + if i > 0 { + base += " " + w.conj + " " + } + base += "(" + w.expr + ")" + args = append(args, w.args...) + } + } + + if len(qb.groupBys) > 0 { + base += " GROUP BY " + strings.Join(qb.groupBys, ", ") + } + if len(qb.orderBys) > 0 { + base += " ORDER BY " + strings.Join(qb.orderBys, ", ") + } + if qb.limit != nil { + base += fmt.Sprintf(" LIMIT %d", *qb.limit) + } + if qb.offset != nil { + base += fmt.Sprintf(" OFFSET %d", *qb.offset) + } + return base, args +} + +// GetMany executes the built query and scans into dest (pointer to slice). +func (qb *QueryBuilder) GetMany(ctx context.Context, dest any) error { + sqlStr, args := qb.Build() + rows, err := qb.exec.QueryContext(ctx, sqlStr, args...) + if err != nil { + return err + } + defer rows.Close() + return scanIntoDest(rows, dest) +} + +// GetOne executes the built query and scans into dest (pointer to struct or map) with LIMIT 1. +func (qb *QueryBuilder) GetOne(ctx context.Context, dest any) error { + limit := 1 + if qb.limit == nil { + qb.limit = &limit + } else if qb.limit != nil && *qb.limit > 1 { + qb.limit = &limit + } + sqlStr, args := qb.Build() + rows, err := qb.exec.QueryContext(ctx, sqlStr, args...) + if err != nil { + return err + } + defer rows.Close() + if !rows.Next() { + return sql.ErrNoRows + } + return scanIntoSingle(rows, dest) +} + +// FindOption customizes Find queries. +type FindOption func(q *QueryBuilder) + +func WithOrderBy(exprs ...string) FindOption { + return func(q *QueryBuilder) { q.OrderBy(exprs...) } +} +func WithGroupBy(cols ...string) FindOption { + return func(q *QueryBuilder) { q.GroupBy(cols...) } +} +func WithLimit(n int) FindOption { + return func(q *QueryBuilder) { q.Limit(n) } +} +func WithOffset(n int) FindOption { + return func(q *QueryBuilder) { q.Offset(n) } +} +func WithSelect(cols ...string) FindOption { + return func(q *QueryBuilder) { q.Select(cols...) } +} +func WithJoin(kind, table, on string) FindOption { + return func(q *QueryBuilder) { + switch strings.ToUpper(kind) { + case "INNER": + q.InnerJoin(table, on) + case "LEFT": + q.LeftJoin(table, on) + default: + q.Join(table, on) + } + } +} + +// repository is a generic table repository for type T. +type repository[T any] struct { + c *client + table string +} + +func (r *repository[T]) Find(ctx context.Context, dest *[]T, criteria map[string]any, opts ...FindOption) error { + qb := r.c.CreateQueryBuilder(r.table) + for k, v := range criteria { + qb.AndWhere(fmt.Sprintf("%s = ?", k), v) + } + for _, opt := range opts { + opt(qb) + } + return qb.GetMany(ctx, dest) +} + +func (r *repository[T]) FindOne(ctx context.Context, dest *T, criteria map[string]any, opts ...FindOption) error { + qb := r.c.CreateQueryBuilder(r.table) + for k, v := range criteria { + qb.AndWhere(fmt.Sprintf("%s = ?", k), v) + } + for _, opt := range opts { + opt(qb) + } + return qb.GetOne(ctx, dest) +} + +func (r *repository[T]) Save(ctx context.Context, entity *T) error { + return saveEntity(ctx, r.c.db, entity) +} + +func (r *repository[T]) Remove(ctx context.Context, entity *T) error { + return removeEntity(ctx, r.c.db, entity) +} + +func (r *repository[T]) Q() *QueryBuilder { + return r.c.CreateQueryBuilder(r.table) +} + +// ----------------------- +// Reflection + scanning +// ----------------------- + +func scanIntoDest(rows *sql.Rows, dest any) error { + // dest must be pointer to slice (of struct or map) + rv := reflect.ValueOf(dest) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return errors.New("dest must be a non-nil pointer") + } + sliceVal := rv.Elem() + if sliceVal.Kind() != reflect.Slice { + return errors.New("dest must be pointer to a slice") + } + elemType := sliceVal.Type().Elem() + + cols, err := rows.Columns() + if err != nil { + return err + } + + for rows.Next() { + itemPtr := reflect.New(elemType) + // Support map[string]any and struct + if elemType.Kind() == reflect.Map { + m, err := scanRowToMap(rows, cols) + if err != nil { + return err + } + sliceVal.Set(reflect.Append(sliceVal, reflect.ValueOf(m))) + continue + } + + if elemType.Kind() == reflect.Struct { + if err := scanCurrentRowIntoStruct(rows, cols, itemPtr.Elem()); err != nil { + return err + } + sliceVal.Set(reflect.Append(sliceVal, itemPtr.Elem())) + continue + } + + return fmt.Errorf("unsupported slice element type: %s", elemType.Kind()) + } + return rows.Err() +} + +func scanIntoSingle(rows *sql.Rows, dest any) error { + rv := reflect.ValueOf(dest) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return errors.New("dest must be a non-nil pointer") + } + cols, err := rows.Columns() + if err != nil { + return err + } + + switch rv.Elem().Kind() { + case reflect.Map: + m, err := scanRowToMap(rows, cols) + if err != nil { + return err + } + rv.Elem().Set(reflect.ValueOf(m)) + return nil + case reflect.Struct: + return scanCurrentRowIntoStruct(rows, cols, rv.Elem()) + default: + return fmt.Errorf("unsupported dest kind: %s", rv.Elem().Kind()) + } +} + +func scanRowToMap(rows *sql.Rows, cols []string) (map[string]any, error) { + raw := make([]any, len(cols)) + ptrs := make([]any, len(cols)) + for i := range raw { + ptrs[i] = &raw[i] + } + if err := rows.Scan(ptrs...); err != nil { + return nil, err + } + out := make(map[string]any, len(cols)) + for i, c := range cols { + out[c] = normalizeSQLValue(raw[i]) + } + return out, nil +} + +func scanCurrentRowIntoStruct(rows *sql.Rows, cols []string, destStruct reflect.Value) error { + raw := make([]any, len(cols)) + ptrs := make([]any, len(cols)) + for i := range raw { + ptrs[i] = &raw[i] + } + if err := rows.Scan(ptrs...); err != nil { + return err + } + fieldIndex := buildFieldIndex(destStruct.Type()) + for i, c := range cols { + if idx, ok := fieldIndex[strings.ToLower(c)]; ok { + field := destStruct.Field(idx) + if field.CanSet() { + if err := setReflectValue(field, raw[i]); err != nil { + return fmt.Errorf("column %s: %w", c, err) + } + } + } + } + return nil +} + +func normalizeSQLValue(v any) any { + switch t := v.(type) { + case []byte: + return string(t) + default: + return v + } +} + +func buildFieldIndex(t reflect.Type) map[string]int { + m := make(map[string]int) + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.IsExported() == false { + continue + } + tag := f.Tag.Get("db") + col := "" + if tag != "" { + col = strings.Split(tag, ",")[0] + } + if col == "" { + col = f.Name + } + m[strings.ToLower(col)] = i + } + return m +} + +func setReflectValue(field reflect.Value, raw any) error { + if raw == nil { + // leave zero value + return nil + } + switch field.Kind() { + case reflect.String: + switch v := raw.(type) { + case string: + field.SetString(v) + case []byte: + field.SetString(string(v)) + default: + field.SetString(fmt.Sprint(v)) + } + case reflect.Bool: + switch v := raw.(type) { + case bool: + field.SetBool(v) + case int64: + field.SetBool(v != 0) + case []byte: + s := string(v) + field.SetBool(s == "1" || strings.EqualFold(s, "true")) + default: + field.SetBool(false) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch v := raw.(type) { + case int64: + field.SetInt(v) + case []byte: + var n int64 + fmt.Sscan(string(v), &n) + field.SetInt(n) + default: + return fmt.Errorf("cannot convert %T to int", raw) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + switch v := raw.(type) { + case int64: + if v < 0 { + v = 0 + } + field.SetUint(uint64(v)) + case []byte: + var n uint64 + fmt.Sscan(string(v), &n) + field.SetUint(n) + default: + return fmt.Errorf("cannot convert %T to uint", raw) + } + case reflect.Float32, reflect.Float64: + switch v := raw.(type) { + case float64: + field.SetFloat(v) + case []byte: + var fv float64 + fmt.Sscan(string(v), &fv) + field.SetFloat(fv) + default: + return fmt.Errorf("cannot convert %T to float", raw) + } + case reflect.Struct: + // Support time.Time; extend as needed. + if field.Type() == reflect.TypeOf(time.Time{}) { + switch v := raw.(type) { + case time.Time: + field.Set(reflect.ValueOf(v)) + case []byte: + // Try RFC3339 + if tt, err := time.Parse(time.RFC3339, string(v)); err == nil { + field.Set(reflect.ValueOf(tt)) + } + } + return nil + } + fallthrough + default: + // Not supported yet + return fmt.Errorf("unsupported dest field kind: %s", field.Kind()) + } + return nil +} + +// ----------------------- +// Save/Remove (basic PK) +// ----------------------- + +type fieldMeta struct { + index int + column string + isPK bool + auto bool +} + +func collectMeta(t reflect.Type) (fields []fieldMeta, pk fieldMeta, hasPK bool) { + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !f.IsExported() { + continue + } + tag := f.Tag.Get("db") + if tag == "-" { + continue + } + opts := strings.Split(tag, ",") + col := opts[0] + if col == "" { + col = f.Name + } + meta := fieldMeta{index: i, column: col} + for _, o := range opts[1:] { + switch strings.ToLower(strings.TrimSpace(o)) { + case "pk": + meta.isPK = true + case "auto", "autoincrement": + meta.auto = true + } + } + // If not tagged as pk, fallback to field name "ID" + if !meta.isPK && f.Name == "ID" { + meta.isPK = true + if col == "" { + meta.column = "id" + } + } + fields = append(fields, meta) + if meta.isPK { + pk = meta + hasPK = true + } + } + return +} + +func getTableNameFromEntity(v reflect.Value) (string, bool) { + // If entity implements TableNamer + if v.CanInterface() { + if tn, ok := v.Interface().(TableNamer); ok { + return tn.TableName(), true + } + } + // Fallback: very naive pluralization (append 's') + typ := v.Type() + if typ.Kind() == reflect.Pointer { + typ = typ.Elem() + } + if typ.Kind() == reflect.Struct { + return strings.ToLower(typ.Name()) + "s", true + } + return "", false +} + +func saveEntity(ctx context.Context, exec executor, entity any) error { + rv := reflect.ValueOf(entity) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return errors.New("entity must be a non-nil pointer to struct") + } + ev := rv.Elem() + if ev.Kind() != reflect.Struct { + return errors.New("entity must point to a struct") + } + + fields, pkMeta, hasPK := collectMeta(ev.Type()) + if !hasPK { + return errors.New("no primary key field found (tag db:\"...,pk\" or field named ID)") + } + table, ok := getTableNameFromEntity(ev) + if !ok || table == "" { + return errors.New("unable to resolve table name; implement TableNamer or set up a repository with explicit table") + } + + // Build lists + cols := make([]string, 0, len(fields)) + vals := make([]any, 0, len(fields)) + setParts := make([]string, 0, len(fields)) + + var pkVal any + var pkIsZero bool + + for _, fm := range fields { + f := ev.Field(fm.index) + if fm.isPK { + pkVal = f.Interface() + pkIsZero = isZeroValue(f) + continue + } + cols = append(cols, fm.column) + vals = append(vals, f.Interface()) + setParts = append(setParts, fmt.Sprintf("%s = ?", fm.column)) + } + + if pkIsZero { + // INSERT + placeholders := strings.Repeat("?,", len(cols)) + if len(placeholders) > 0 { + placeholders = placeholders[:len(placeholders)-1] + } + sqlStr := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, strings.Join(cols, ", "), placeholders) + res, err := exec.ExecContext(ctx, sqlStr, vals...) + if err != nil { + return err + } + // Set auto ID if needed + if pkMeta.auto { + if id, err := res.LastInsertId(); err == nil { + ev.Field(pkMeta.index).SetInt(id) + } + } + return nil + } + + // UPDATE ... WHERE pk = ? + sqlStr := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?", table, strings.Join(setParts, ", "), pkMeta.column) + valsWithPK := append(vals, pkVal) + _, err := exec.ExecContext(ctx, sqlStr, valsWithPK...) + return err +} + +func removeEntity(ctx context.Context, exec executor, entity any) error { + rv := reflect.ValueOf(entity) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return errors.New("entity must be a non-nil pointer to struct") + } + ev := rv.Elem() + if ev.Kind() != reflect.Struct { + return errors.New("entity must point to a struct") + } + _, pkMeta, hasPK := collectMeta(ev.Type()) + if !hasPK { + return errors.New("no primary key field found") + } + table, ok := getTableNameFromEntity(ev) + if !ok || table == "" { + return errors.New("unable to resolve table name") + } + pkVal := ev.Field(pkMeta.index).Interface() + sqlStr := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", table, pkMeta.column) + _, err := exec.ExecContext(ctx, sqlStr, pkVal) + return err +} + +func isZeroValue(v reflect.Value) bool { + switch v.Kind() { + case reflect.String: + return v.Len() == 0 + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return v.Uint() == 0 + case reflect.Bool: + return v.Bool() == false + case reflect.Pointer, reflect.Interface: + return v.IsNil() + case reflect.Slice, reflect.Map: + return v.Len() == 0 + case reflect.Struct: + // Special-case time.Time + if v.Type() == reflect.TypeOf(time.Time{}) { + t := v.Interface().(time.Time) + return t.IsZero() + } + zero := reflect.Zero(v.Type()) + return reflect.DeepEqual(v.Interface(), zero.Interface()) + default: + return false + } +} diff --git a/pkg/rqlite/gateway.go b/pkg/rqlite/gateway.go new file mode 100644 index 0000000..e69de29 diff --git a/pkg/rqlite/migrations.go b/pkg/rqlite/migrations.go new file mode 100644 index 0000000..1c43ed9 --- /dev/null +++ b/pkg/rqlite/migrations.go @@ -0,0 +1,436 @@ +package rqlite + +import ( + "context" + "database/sql" + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "unicode" + + _ "github.com/rqlite/gorqlite/stdlib" + "go.uber.org/zap" +) + +// ApplyMigrations scans a directory for *.sql files, orders them by numeric prefix, +// and applies any that are not yet recorded in schema_migrations(version). +func ApplyMigrations(ctx context.Context, db *sql.DB, dir string, logger *zap.Logger) error { + if logger == nil { + logger = zap.NewNop() + } + + if err := ensureMigrationsTable(ctx, db); err != nil { + return fmt.Errorf("ensure schema_migrations: %w", err) + } + + files, err := readMigrationFiles(dir) + if err != nil { + return fmt.Errorf("read migration files: %w", err) + } + if len(files) == 0 { + logger.Info("No migrations found", zap.String("dir", dir)) + return nil + } + + applied, err := loadAppliedVersions(ctx, db) + if err != nil { + return fmt.Errorf("load applied versions: %w", err) + } + + for _, mf := range files { + if applied[mf.Version] { + logger.Info("Migration already applied; skipping", zap.Int("version", mf.Version), zap.String("name", mf.Name)) + continue + } + + sqlBytes, err := os.ReadFile(mf.Path) + if err != nil { + return fmt.Errorf("read migration %s: %w", mf.Path, err) + } + + logger.Info("Applying migration", zap.Int("version", mf.Version), zap.String("name", mf.Name)) + if err := applySQL(ctx, db, string(sqlBytes)); err != nil { + return fmt.Errorf("apply migration %d (%s): %w", mf.Version, mf.Name, err) + } + + if _, err := db.ExecContext(ctx, `INSERT OR IGNORE INTO schema_migrations(version) VALUES (?)`, mf.Version); err != nil { + return fmt.Errorf("record migration %d: %w", mf.Version, err) + } + logger.Info("Migration applied", zap.Int("version", mf.Version), zap.String("name", mf.Name)) + } + + return nil +} + +// ApplyMigrationsDirs applies migrations from multiple directories. +// - Gathers *.sql files from each dir +// - Parses numeric prefix as the version +// - Errors if the same version appears in more than one dir (to avoid ambiguity) +// - Sorts globally by version and applies those not yet in schema_migrations +func ApplyMigrationsDirs(ctx context.Context, db *sql.DB, dirs []string, logger *zap.Logger) error { + if logger == nil { + logger = zap.NewNop() + } + if err := ensureMigrationsTable(ctx, db); err != nil { + return fmt.Errorf("ensure schema_migrations: %w", err) + } + + files, err := readMigrationFilesFromDirs(dirs) + if err != nil { + return err + } + if len(files) == 0 { + logger.Info("No migrations found in provided directories", zap.Strings("dirs", dirs)) + return nil + } + + applied, err := loadAppliedVersions(ctx, db) + if err != nil { + return fmt.Errorf("load applied versions: %w", err) + } + + for _, mf := range files { + if applied[mf.Version] { + logger.Info("Migration already applied; skipping", zap.Int("version", mf.Version), zap.String("name", mf.Name), zap.String("path", mf.Path)) + continue + } + sqlBytes, err := os.ReadFile(mf.Path) + if err != nil { + return fmt.Errorf("read migration %s: %w", mf.Path, err) + } + + logger.Info("Applying migration", zap.Int("version", mf.Version), zap.String("name", mf.Name), zap.String("path", mf.Path)) + if err := applySQL(ctx, db, string(sqlBytes)); err != nil { + return fmt.Errorf("apply migration %d (%s): %w", mf.Version, mf.Name, err) + } + + if _, err := db.ExecContext(ctx, `INSERT OR IGNORE INTO schema_migrations(version) VALUES (?)`, mf.Version); err != nil { + return fmt.Errorf("record migration %d: %w", mf.Version, err) + } + logger.Info("Migration applied", zap.Int("version", mf.Version), zap.String("name", mf.Name)) + } + + return nil +} + +// ApplyMigrationsFromManager is a convenience helper bound to RQLiteManager. +func (r *RQLiteManager) ApplyMigrations(ctx context.Context, dir string) error { + db, err := sql.Open("rqlite", fmt.Sprintf("http://localhost:%d", r.config.RQLitePort)) + if err != nil { + return fmt.Errorf("open rqlite db: %w", err) + } + defer db.Close() + + return ApplyMigrations(ctx, db, dir, r.logger) +} + +// ApplyMigrationsDirs is the multi-dir variant on RQLiteManager. +func (r *RQLiteManager) ApplyMigrationsDirs(ctx context.Context, dirs []string) error { + db, err := sql.Open("rqlite", fmt.Sprintf("http://localhost:%d", r.config.RQLitePort)) + if err != nil { + return fmt.Errorf("open rqlite db: %w", err) + } + defer db.Close() + + return ApplyMigrationsDirs(ctx, db, dirs, r.logger) +} + +func ensureMigrationsTable(ctx context.Context, db *sql.DB) error { + _, err := db.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +)`) + return err +} + +type migrationFile struct { + Version int + Name string + Path string +} + +func readMigrationFiles(dir string) ([]migrationFile, error) { + entries, err := os.ReadDir(dir) + if err != nil { + if os.IsNotExist(err) { + return []migrationFile{}, nil + } + return nil, err + } + + var out []migrationFile + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + if !strings.HasSuffix(strings.ToLower(name), ".sql") { + continue + } + ver, ok := parseVersionPrefix(name) + if !ok { + continue + } + out = append(out, migrationFile{ + Version: ver, + Name: name, + Path: filepath.Join(dir, name), + }) + } + sort.Slice(out, func(i, j int) bool { return out[i].Version < out[j].Version }) + return out, nil +} + +func readMigrationFilesFromDirs(dirs []string) ([]migrationFile, error) { + all := make([]migrationFile, 0, 64) + seen := map[int]string{} // version -> path (for duplicate detection) + + for _, d := range dirs { + files, err := readMigrationFiles(d) + if err != nil { + return nil, fmt.Errorf("reading dir %s: %w", d, err) + } + for _, f := range files { + if prev, dup := seen[f.Version]; dup { + return nil, fmt.Errorf("duplicate migration version %d detected in %s and %s; ensure global version uniqueness", f.Version, prev, f.Path) + } + seen[f.Version] = f.Path + all = append(all, f) + } + } + sort.Slice(all, func(i, j int) bool { return all[i].Version < all[j].Version }) + return all, nil +} + +func parseVersionPrefix(name string) (int, bool) { + // Expect formats like "001_initial.sql", "2_add_table.sql", etc. + i := 0 + for i < len(name) && unicode.IsDigit(rune(name[i])) { + i++ + } + if i == 0 { + return 0, false + } + ver, err := strconv.Atoi(name[:i]) + if err != nil { + return 0, false + } + return ver, true +} + +func loadAppliedVersions(ctx context.Context, db *sql.DB) (map[int]bool, error) { + rows, err := db.QueryContext(ctx, `SELECT version FROM schema_migrations`) + if err != nil { + // If the table doesn't exist yet (very first run), ensure it and return empty set. + if isNoSuchTable(err) { + if err := ensureMigrationsTable(ctx, db); err != nil { + return nil, err + } + return map[int]bool{}, nil + } + return nil, err + } + defer rows.Close() + + applied := make(map[int]bool) + for rows.Next() { + var v int + if err := rows.Scan(&v); err != nil { + return nil, err + } + applied[v] = true + } + return applied, rows.Err() +} + +func isNoSuchTable(err error) bool { + // rqlite/sqlite error messages vary; keep it permissive + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "no such table") || strings.Contains(msg, "does not exist") +} + +// applySQL tries to run the entire script in one Exec. +// If the driver rejects multi-statement Exec, it falls back to splitting statements and executing sequentially. +func applySQL(ctx context.Context, db *sql.DB, script string) error { + s := strings.TrimSpace(script) + if s == "" { + return nil + } + if _, err := db.ExecContext(ctx, s); err == nil { + return nil + } else { + // Fall back to splitting into statements and executing sequentially (respecting BEGIN/COMMIT if present). + stmts := splitSQLStatements(s) + // If the script already contains explicit BEGIN/COMMIT, we just run as-is. + // Otherwise, we attempt to wrap in a transaction; if BeginTx fails, execute one-by-one. + hasExplicitTxn := containsToken(stmts, "BEGIN") || containsToken(stmts, "BEGIN;") + if !hasExplicitTxn { + if tx, txErr := db.BeginTx(ctx, nil); txErr == nil { + for _, stmt := range stmts { + if stmt == "" { + continue + } + if _, execErr := tx.ExecContext(ctx, stmt); execErr != nil { + _ = tx.Rollback() + return fmt.Errorf("exec stmt failed: %w (stmt: %s)", execErr, snippet(stmt)) + } + } + return tx.Commit() + } + // Fall through to plain sequential exec if BeginTx not supported. + } + + for _, stmt := range stmts { + if stmt == "" { + continue + } + if _, execErr := db.ExecContext(ctx, stmt); execErr != nil { + return fmt.Errorf("exec stmt failed: %w (stmt: %s)", execErr, snippet(stmt)) + } + } + return nil + } +} + +func containsToken(stmts []string, token string) bool { + for _, s := range stmts { + if strings.EqualFold(strings.TrimSpace(s), token) { + return true + } + } + return false +} + +func snippet(s string) string { + s = strings.TrimSpace(s) + if len(s) > 120 { + return s[:120] + "..." + } + return s +} + +// splitSQLStatements splits a SQL script into statements by semicolon, ignoring semicolons +// inside single/double-quoted strings and skipping comments (-- and /* */). +func splitSQLStatements(in string) []string { + var out []string + var b strings.Builder + + inLineComment := false + inBlockComment := false + inSingle := false + inDouble := false + + runes := []rune(in) + for i := 0; i < len(runes); i++ { + ch := runes[i] + next := rune(0) + if i+1 < len(runes) { + next = runes[i+1] + } + + // Handle end of line comment + if inLineComment { + if ch == '\n' { + inLineComment = false + // keep newline normalization but don't include comment + } + continue + } + // Handle end of block comment + if inBlockComment { + if ch == '*' && next == '/' { + inBlockComment = false + i++ + } + continue + } + + // Start of comments? + if !inSingle && !inDouble { + if ch == '-' && next == '-' { + inLineComment = true + i++ + continue + } + if ch == '/' && next == '*' { + inBlockComment = true + i++ + continue + } + } + + // Quotes + if !inDouble && ch == '\'' { + // Toggle single quotes, respecting escaped '' inside. + if inSingle { + // Check for escaped '' (two single quotes) + if next == '\'' { + b.WriteRune(ch) // write one ' + i++ // skip the next ' + continue + } + inSingle = false + } else { + inSingle = true + } + b.WriteRune(ch) + continue + } + if !inSingle && ch == '"' { + if inDouble { + if next == '"' { + b.WriteRune(ch) + i++ + continue + } + inDouble = false + } else { + inDouble = true + } + b.WriteRune(ch) + continue + } + + // Statement boundary + if ch == ';' && !inSingle && !inDouble { + stmt := strings.TrimSpace(b.String()) + if stmt != "" { + out = append(out, stmt) + } + b.Reset() + continue + } + + b.WriteRune(ch) + } + + // Final fragment + if s := strings.TrimSpace(b.String()); s != "" { + out = append(out, s) + } + return out +} + +// Optional helper to load embedded migrations if you later decide to embed. +// Keep for future use; currently unused. +func readDirFS(fsys fs.FS, root string) ([]string, error) { + var files []string + err := fs.WalkDir(fsys, root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + if strings.HasSuffix(strings.ToLower(d.Name()), ".sql") { + files = append(files, path) + } + return nil + }) + return files, err +} diff --git a/pkg/database/rqlite.go b/pkg/rqlite/rqlite.go similarity index 96% rename from pkg/database/rqlite.go rename to pkg/rqlite/rqlite.go index de1866b..740e887 100644 --- a/pkg/database/rqlite.go +++ b/pkg/rqlite/rqlite.go @@ -1,4 +1,4 @@ -package database +package rqlite import ( "context" @@ -174,6 +174,14 @@ func (r *RQLiteManager) Start(ctx context.Context) error { } } + // After waitForLeadership / waitForSQLAvailable succeeds, before returning: + migrationsDir := "network/migrations" + + if err := r.ApplyMigrations(ctx, migrationsDir); err != nil { + r.logger.Error("Migrations failed", zap.Error(err), zap.String("dir", migrationsDir)) + return fmt.Errorf("apply migrations: %w", err) + } + r.logger.Info("RQLite node started successfully") return nil } diff --git a/pkg/storage/client.go b/pkg/storage/client.go deleted file mode 100644 index 2a551d4..0000000 --- a/pkg/storage/client.go +++ /dev/null @@ -1,231 +0,0 @@ -package storage - -import ( - "context" - "fmt" - "io" - "time" - - "github.com/libp2p/go-libp2p/core/host" - "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/protocol" - "go.uber.org/zap" -) - -// Client provides distributed storage client functionality -type Client struct { - host host.Host - logger *zap.Logger - namespace string -} - -// Context utilities for namespace override -type ctxKey string - -// CtxKeyNamespaceOverride is the context key used to override namespace per request -const CtxKeyNamespaceOverride ctxKey = "storage_ns_override" - -// WithNamespace returns a new context that carries a storage namespace override -func WithNamespace(ctx context.Context, ns string) context.Context { - return context.WithValue(ctx, CtxKeyNamespaceOverride, ns) -} - -// NewClient creates a new storage client -func NewClient(h host.Host, namespace string, logger *zap.Logger) *Client { - return &Client{ - host: h, - logger: logger, - namespace: namespace, - } -} - -// Put stores a key-value pair in the distributed storage -func (c *Client) Put(ctx context.Context, key string, value []byte) error { - ns := c.namespace - if v := ctx.Value(CtxKeyNamespaceOverride); v != nil { - if s, ok := v.(string); ok && s != "" { - ns = s - } - } - request := &StorageRequest{ - Type: MessageTypePut, - Key: key, - Value: value, - Namespace: ns, - } - - return c.sendRequest(ctx, request) -} - -// Get retrieves a value by key from the distributed storage -func (c *Client) Get(ctx context.Context, key string) ([]byte, error) { - ns := c.namespace - if v := ctx.Value(CtxKeyNamespaceOverride); v != nil { - if s, ok := v.(string); ok && s != "" { - ns = s - } - } - request := &StorageRequest{ - Type: MessageTypeGet, - Key: key, - Namespace: ns, - } - - response, err := c.sendRequestWithResponse(ctx, request) - if err != nil { - return nil, err - } - - if !response.Success { - return nil, fmt.Errorf(response.Error) - } - - return response.Value, nil -} - -// Delete removes a key from the distributed storage -func (c *Client) Delete(ctx context.Context, key string) error { - ns := c.namespace - if v := ctx.Value(CtxKeyNamespaceOverride); v != nil { - if s, ok := v.(string); ok && s != "" { - ns = s - } - } - request := &StorageRequest{ - Type: MessageTypeDelete, - Key: key, - Namespace: ns, - } - - return c.sendRequest(ctx, request) -} - -// List returns keys with a given prefix -func (c *Client) List(ctx context.Context, prefix string, limit int) ([]string, error) { - ns := c.namespace - if v := ctx.Value(CtxKeyNamespaceOverride); v != nil { - if s, ok := v.(string); ok && s != "" { - ns = s - } - } - request := &StorageRequest{ - Type: MessageTypeList, - Prefix: prefix, - Limit: limit, - Namespace: ns, - } - - response, err := c.sendRequestWithResponse(ctx, request) - if err != nil { - return nil, err - } - - if !response.Success { - return nil, fmt.Errorf(response.Error) - } - - return response.Keys, nil -} - -// Exists checks if a key exists in the distributed storage -func (c *Client) Exists(ctx context.Context, key string) (bool, error) { - ns := c.namespace - if v := ctx.Value(CtxKeyNamespaceOverride); v != nil { - if s, ok := v.(string); ok && s != "" { - ns = s - } - } - request := &StorageRequest{ - Type: MessageTypeExists, - Key: key, - Namespace: ns, - } - - response, err := c.sendRequestWithResponse(ctx, request) - if err != nil { - return false, err - } - - if !response.Success { - return false, fmt.Errorf(response.Error) - } - - return response.Exists, nil -} - -// sendRequest sends a request without expecting a response -func (c *Client) sendRequest(ctx context.Context, request *StorageRequest) error { - _, err := c.sendRequestWithResponse(ctx, request) - return err -} - -// sendRequestWithResponse sends a request and waits for a response -func (c *Client) sendRequestWithResponse(ctx context.Context, request *StorageRequest) (*StorageResponse, error) { - // Get connected peers - peers := c.host.Network().Peers() - if len(peers) == 0 { - return nil, fmt.Errorf("no peers connected") - } - - // Try to send to the first available peer - // In a production system, you might want to implement peer selection logic - for _, peerID := range peers { - response, err := c.sendToPeer(ctx, peerID, request) - if err != nil { - c.logger.Debug("Failed to send to peer", - zap.String("peer", peerID.String()), - zap.Error(err)) - continue - } - return response, nil - } - - return nil, fmt.Errorf("failed to send request to any peer") -} - -// sendToPeer sends a request to a specific peer -func (c *Client) sendToPeer(ctx context.Context, peerID peer.ID, request *StorageRequest) (*StorageResponse, error) { - // Create a new stream to the peer - stream, err := c.host.NewStream(ctx, peerID, protocol.ID(StorageProtocolID)) - if err != nil { - return nil, fmt.Errorf("failed to create stream: %w", err) - } - defer stream.Close() - - // Set deadline for the operation - deadline, ok := ctx.Deadline() - if ok { - stream.SetDeadline(deadline) - } else { - stream.SetDeadline(time.Now().Add(30 * time.Second)) - } - - // Marshal and send request - requestData, err := request.Marshal() - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - if _, err := stream.Write(requestData); err != nil { - return nil, fmt.Errorf("failed to write request: %w", err) - } - - // Close write side to signal end of request - if err := stream.CloseWrite(); err != nil { - return nil, fmt.Errorf("failed to close write: %w", err) - } - - // Read response - responseData, err := io.ReadAll(stream) - if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) - } - - // Unmarshal response - var response StorageResponse - if err := response.Unmarshal(responseData); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &response, nil -} diff --git a/pkg/storage/kv_ops.go b/pkg/storage/kv_ops.go deleted file mode 100644 index ad46b85..0000000 --- a/pkg/storage/kv_ops.go +++ /dev/null @@ -1,182 +0,0 @@ -package storage - -import ( - "database/sql" - "fmt" - - "go.uber.org/zap" -) - -// processRequest processes a storage request and returns a response -func (s *Service) processRequest(req *StorageRequest) *StorageResponse { - switch req.Type { - case MessageTypePut: - return s.handlePut(req) - case MessageTypeGet: - return s.handleGet(req) - case MessageTypeDelete: - return s.handleDelete(req) - case MessageTypeList: - return s.handleList(req) - case MessageTypeExists: - return s.handleExists(req) - default: - return &StorageResponse{ - Success: false, - Error: fmt.Sprintf("unknown message type: %s", req.Type), - } - } -} - -// handlePut stores a key-value pair -func (s *Service) handlePut(req *StorageRequest) *StorageResponse { - s.mu.Lock() - defer s.mu.Unlock() - - // Use REPLACE to handle both insert and update - query := ` - REPLACE INTO kv_storage (namespace, key, value, updated_at) - VALUES (?, ?, ?, CURRENT_TIMESTAMP) - ` - - _, err := s.db.Exec(query, req.Namespace, req.Key, req.Value) - if err != nil { - return &StorageResponse{ - Success: false, - Error: fmt.Sprintf("failed to store key: %v", err), - } - } - - s.logger.Debug("Stored key", zap.String("key", req.Key), zap.String("namespace", req.Namespace)) - return &StorageResponse{Success: true} -} - -// handleGet retrieves a value by key -func (s *Service) handleGet(req *StorageRequest) *StorageResponse { - s.mu.RLock() - defer s.mu.RUnlock() - - query := `SELECT value FROM kv_storage WHERE namespace = ? AND key = ?` - - var value []byte - err := s.db.QueryRow(query, req.Namespace, req.Key).Scan(&value) - if err != nil { - if err == sql.ErrNoRows { - return &StorageResponse{ - Success: false, - Error: fmt.Sprintf("key not found: %s", req.Key), - } - } - return &StorageResponse{ - Success: false, - Error: fmt.Sprintf("failed to get key: %v", err), - } - } - - return &StorageResponse{ - Success: true, - Value: value, - } -} - -// handleDelete removes a key -func (s *Service) handleDelete(req *StorageRequest) *StorageResponse { - s.mu.Lock() - defer s.mu.Unlock() - - query := `DELETE FROM kv_storage WHERE namespace = ? AND key = ?` - - result, err := s.db.Exec(query, req.Namespace, req.Key) - if err != nil { - return &StorageResponse{ - Success: false, - Error: fmt.Sprintf("failed to delete key: %v", err), - } - } - - rowsAffected, _ := result.RowsAffected() - if rowsAffected == 0 { - return &StorageResponse{ - Success: false, - Error: fmt.Sprintf("key not found: %s", req.Key), - } - } - - s.logger.Debug("Deleted key", zap.String("key", req.Key), zap.String("namespace", req.Namespace)) - return &StorageResponse{Success: true} -} - -// handleList lists keys with a prefix -func (s *Service) handleList(req *StorageRequest) *StorageResponse { - s.mu.RLock() - defer s.mu.RUnlock() - - var query string - var args []interface{} - - if req.Prefix == "" { - // List all keys in namespace - query = `SELECT key FROM kv_storage WHERE namespace = ?` - args = []interface{}{req.Namespace} - } else { - // List keys with prefix - query = `SELECT key FROM kv_storage WHERE namespace = ? AND key LIKE ?` - args = []interface{}{req.Namespace, req.Prefix + "%"} - } - - if req.Limit > 0 { - query += ` LIMIT ?` - args = append(args, req.Limit) - } - - rows, err := s.db.Query(query, args...) - if err != nil { - return &StorageResponse{ - Success: false, - Error: fmt.Sprintf("failed to query keys: %v", err), - } - } - defer rows.Close() - - var keys []string - for rows.Next() { - var key string - if err := rows.Scan(&key); err != nil { - continue - } - keys = append(keys, key) - } - - return &StorageResponse{ - Success: true, - Keys: keys, - } -} - -// handleExists checks if a key exists -func (s *Service) handleExists(req *StorageRequest) *StorageResponse { - s.mu.RLock() - defer s.mu.RUnlock() - - query := `SELECT 1 FROM kv_storage WHERE namespace = ? AND key = ? LIMIT 1` - - var exists int - err := s.db.QueryRow(query, req.Namespace, req.Key).Scan(&exists) - if err != nil { - if err == sql.ErrNoRows { - return &StorageResponse{ - Success: true, - Exists: false, - } - } - return &StorageResponse{ - Success: false, - Error: fmt.Sprintf("failed to check key existence: %v", err), - } - } - - return &StorageResponse{ - Success: true, - Exists: true, - } -} diff --git a/pkg/storage/logging.go b/pkg/storage/logging.go deleted file mode 100644 index 648af74..0000000 --- a/pkg/storage/logging.go +++ /dev/null @@ -1,16 +0,0 @@ -package storage - -import "go.uber.org/zap" - -// newStorageLogger creates a zap.Logger for storage components. -// Callers can pass quiet=true to reduce log verbosity. -func newStorageLogger(quiet bool) (*zap.Logger, error) { - if quiet { - cfg := zap.NewProductionConfig() - cfg.Level = zap.NewAtomicLevelAt(zap.WarnLevel) - cfg.DisableCaller = true - cfg.DisableStacktrace = true - return cfg.Build() - } - return zap.NewDevelopment() -} diff --git a/pkg/storage/protocol.go b/pkg/storage/protocol.go deleted file mode 100644 index af9c4dd..0000000 --- a/pkg/storage/protocol.go +++ /dev/null @@ -1,60 +0,0 @@ -package storage - -import ( - "encoding/json" -) - -// Storage protocol definitions for distributed storage -const ( - StorageProtocolID = "/network/storage/1.0.0" -) - -// Message types for storage operations -type MessageType string - -const ( - MessageTypePut MessageType = "put" - MessageTypeGet MessageType = "get" - MessageTypeDelete MessageType = "delete" - MessageTypeList MessageType = "list" - MessageTypeExists MessageType = "exists" -) - -// StorageRequest represents a storage operation request -type StorageRequest struct { - Type MessageType `json:"type"` - Key string `json:"key"` - Value []byte `json:"value,omitempty"` - Prefix string `json:"prefix,omitempty"` - Limit int `json:"limit,omitempty"` - Namespace string `json:"namespace"` -} - -// StorageResponse represents a storage operation response -type StorageResponse struct { - Success bool `json:"success"` - Error string `json:"error,omitempty"` - Value []byte `json:"value,omitempty"` - Keys []string `json:"keys,omitempty"` - Exists bool `json:"exists,omitempty"` -} - -// Marshal serializes a request to JSON -func (r *StorageRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -// Unmarshal deserializes a request from JSON -func (r *StorageRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -// Marshal serializes a response to JSON -func (r *StorageResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -// Unmarshal deserializes a response from JSON -func (r *StorageResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} diff --git a/pkg/storage/protocol_test.go b/pkg/storage/protocol_test.go deleted file mode 100644 index 07d3673..0000000 --- a/pkg/storage/protocol_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package storage - -import "testing" - -func TestRequestResponseJSON(t *testing.T) { - req := &StorageRequest{Type: MessageTypePut, Key: "k", Value: []byte("v"), Namespace: "ns"} - b, err := req.Marshal() - if err != nil { t.Fatal(err) } - var out StorageRequest - if err := out.Unmarshal(b); err != nil { t.Fatal(err) } - if out.Type != MessageTypePut || out.Key != "k" || out.Namespace != "ns" { - t.Fatalf("roundtrip mismatch: %+v", out) - } - - resp := &StorageResponse{Success: true, Keys: []string{"a"}, Exists: true} - b, err = resp.Marshal() - if err != nil { t.Fatal(err) } - var outR StorageResponse - if err := outR.Unmarshal(b); err != nil { t.Fatal(err) } - if !outR.Success || !outR.Exists || len(outR.Keys) != 1 { - t.Fatalf("resp mismatch: %+v", outR) - } -} diff --git a/pkg/storage/rqlite_init.go b/pkg/storage/rqlite_init.go deleted file mode 100644 index c339467..0000000 --- a/pkg/storage/rqlite_init.go +++ /dev/null @@ -1,37 +0,0 @@ -package storage - -import ( - "fmt" -) - -// initTables creates the necessary tables for key-value storage -func (s *Service) initTables() error { - // Create storage table with namespace support - createTableSQL := ` - CREATE TABLE IF NOT EXISTS kv_storage ( - namespace TEXT NOT NULL, - key TEXT NOT NULL, - value BLOB NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - PRIMARY KEY (namespace, key) - ) - ` - - // Create index for faster queries - createIndexSQL := ` - CREATE INDEX IF NOT EXISTS idx_kv_storage_namespace_key - ON kv_storage(namespace, key) - ` - - if _, err := s.db.Exec(createTableSQL); err != nil { - return fmt.Errorf("failed to create storage table: %w", err) - } - - if _, err := s.db.Exec(createIndexSQL); err != nil { - return fmt.Errorf("failed to create storage index: %w", err) - } - - s.logger.Info("Storage tables initialized successfully") - return nil -} diff --git a/pkg/storage/service.go b/pkg/storage/service.go deleted file mode 100644 index e3062c2..0000000 --- a/pkg/storage/service.go +++ /dev/null @@ -1,32 +0,0 @@ -package storage - -import ( - "database/sql" - "sync" - - "go.uber.org/zap" -) - -// Service provides distributed storage functionality using RQLite -type Service struct { - logger *zap.Logger - db *sql.DB - mu sync.RWMutex -} - -// NewService creates a new storage service backed by RQLite -func NewService(db *sql.DB, logger *zap.Logger) (*Service, error) { - service := &Service{ - logger: logger, - db: db, - } - - return service, nil -} - -// Close closes the storage service -func (s *Service) Close() error { - // The database connection is managed elsewhere - s.logger.Info("Storage service closed") - return nil -} diff --git a/pkg/storage/stream_handler.go b/pkg/storage/stream_handler.go deleted file mode 100644 index 4c38fa1..0000000 --- a/pkg/storage/stream_handler.go +++ /dev/null @@ -1,48 +0,0 @@ -package storage - -import ( - "io" - - "github.com/libp2p/go-libp2p/core/network" - "go.uber.org/zap" -) - -// HandleStorageStream handles incoming storage protocol streams -func (s *Service) HandleStorageStream(stream network.Stream) { - defer stream.Close() - - // Read request - data, err := io.ReadAll(stream) - if err != nil { - s.logger.Error("Failed to read storage request", zap.Error(err)) - return - } - - var request StorageRequest - if err := request.Unmarshal(data); err != nil { - s.logger.Error("Failed to unmarshal storage request", zap.Error(err)) - return - } - - // Process request - response := s.processRequest(&request) - - // Send response - responseData, err := response.Marshal() - if err != nil { - s.logger.Error("Failed to marshal storage response", zap.Error(err)) - return - } - - if _, err := stream.Write(responseData); err != nil { - s.logger.Error("Failed to write storage response", zap.Error(err)) - return - } - - s.logger.Debug("Handled storage request", - zap.String("type", string(request.Type)), - zap.String("key", request.Key), - zap.String("namespace", request.Namespace), - zap.Bool("success", response.Success), - ) -}