mirror of
https://github.com/DeBrosOfficial/network.git
synced 2026-01-30 14:33:03 +00:00
- Introduced a new `network.mdc` file containing comprehensive guidelines for utilizing the network Model Context Protocol (MCP). - Documented available MCP tools for code understanding, skill learning, and recommended workflows to enhance developer efficiency. - Provided detailed instructions on the collaborative skill learning process and user override commands for better interaction with the MCP.
924 lines
23 KiB
Go
924 lines
23 KiB
Go
package rqlite
|
|
|
|
// client.go defines the ORM-like interfaces and a minimal implementation over database/sql.
|
|
// It builds on the rqlite stdlib driver so it behaves like a regular SQL-backed ORM.
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// TableNamer is implemented by entities that report the name of the table
// they are stored in.
type TableNamer interface {
	// TableName returns the table name to use for the entity.
	TableName() string
}
|
|
|
|
// Client is the high-level ORM-like API.
|
|
type Client interface {
|
|
// Query runs an arbitrary SELECT and scans rows into dest (pointer to slice of structs or []map[string]any).
|
|
Query(ctx context.Context, dest any, query string, args ...any) error
|
|
// Exec runs a write statement (INSERT/UPDATE/DELETE).
|
|
Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
|
|
|
|
// FindBy/FindOneBy provide simple map-based criteria filtering.
|
|
FindBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...FindOption) error
|
|
FindOneBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...FindOption) error
|
|
|
|
// Save inserts or updates an entity (single-PK).
|
|
Save(ctx context.Context, entity any) error
|
|
// Remove deletes by PK (single-PK).
|
|
Remove(ctx context.Context, entity any) error
|
|
|
|
// Repositories (generic layer). Optional but convenient if you use Go generics.
|
|
Repository(table string) any
|
|
|
|
// Fluent query builder for advanced querying.
|
|
CreateQueryBuilder(table string) *QueryBuilder
|
|
|
|
// Tx executes a function within a transaction.
|
|
Tx(ctx context.Context, fn func(tx Tx) error) error
|
|
}
|
|
|
|
// Tx mirrors Client but executes within a transaction.
|
|
type Tx interface {
|
|
Query(ctx context.Context, dest any, query string, args ...any) error
|
|
Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
|
|
CreateQueryBuilder(table string) *QueryBuilder
|
|
|
|
// Optional: scoped Save/Remove inside tx
|
|
Save(ctx context.Context, entity any) error
|
|
Remove(ctx context.Context, entity any) error
|
|
}
|
|
|
|
// Repository provides typed entity operations for a table.
|
|
type Repository[T any] interface {
|
|
Find(ctx context.Context, dest *[]T, criteria map[string]any, opts ...FindOption) error
|
|
FindOne(ctx context.Context, dest *T, criteria map[string]any, opts ...FindOption) error
|
|
Save(ctx context.Context, entity *T) error
|
|
Remove(ctx context.Context, entity *T) error
|
|
|
|
// Builder helpers
|
|
Q() *QueryBuilder
|
|
}
|
|
|
|
// NewClient wires the ORM client to a *sql.DB (from your RQLiteAdapter).
|
|
func NewClient(db *sql.DB) Client {
|
|
return &client{db: db}
|
|
}
|
|
|
|
// NewClientFromAdapter is convenient if you already created the adapter.
|
|
func NewClientFromAdapter(adapter *RQLiteAdapter) Client {
|
|
return NewClient(adapter.GetSQLDB())
|
|
}
|
|
|
|
// client is the Client implementation backed by a *sql.DB.
type client struct {
	// db is the handle all non-transactional operations run against.
	db *sql.DB
}
|
|
|
|
func (c *client) Query(ctx context.Context, dest any, query string, args ...any) error {
|
|
rows, err := c.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
return scanIntoDest(rows, dest)
|
|
}
|
|
|
|
func (c *client) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
|
return c.db.ExecContext(ctx, query, args...)
|
|
}
|
|
|
|
func (c *client) FindBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...FindOption) error {
|
|
qb := c.CreateQueryBuilder(table)
|
|
for k, v := range criteria {
|
|
qb = qb.AndWhere(fmt.Sprintf("%s = ?", k), v)
|
|
}
|
|
for _, opt := range opts {
|
|
opt(qb)
|
|
}
|
|
return qb.GetMany(ctx, dest)
|
|
}
|
|
|
|
func (c *client) FindOneBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...FindOption) error {
|
|
qb := c.CreateQueryBuilder(table)
|
|
for k, v := range criteria {
|
|
qb = qb.AndWhere(fmt.Sprintf("%s = ?", k), v)
|
|
}
|
|
for _, opt := range opts {
|
|
opt(qb)
|
|
}
|
|
return qb.GetOne(ctx, dest)
|
|
}
|
|
|
|
func (c *client) Save(ctx context.Context, entity any) error {
|
|
return saveEntity(ctx, c.db, entity)
|
|
}
|
|
|
|
func (c *client) Remove(ctx context.Context, entity any) error {
|
|
return removeEntity(ctx, c.db, entity)
|
|
}
|
|
|
|
func (c *client) Repository(table string) any {
|
|
// This returns an untyped interface since Go methods cannot have type parameters
|
|
// Users will need to type assert the result to Repository[T]
|
|
return func() any {
|
|
return &repository[any]{c: c, table: table}
|
|
}()
|
|
}
|
|
|
|
func (c *client) CreateQueryBuilder(table string) *QueryBuilder {
|
|
return newQueryBuilder(c.db, table)
|
|
}
|
|
|
|
func (c *client) Tx(ctx context.Context, fn func(tx Tx) error) error {
|
|
sqlTx, err := c.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
txc := &txClient{tx: sqlTx}
|
|
if err := fn(txc); err != nil {
|
|
_ = sqlTx.Rollback()
|
|
return err
|
|
}
|
|
return sqlTx.Commit()
|
|
}
|
|
|
|
// txClient is the Tx implementation backed by a *sql.Tx.
type txClient struct {
	// tx is the open transaction every operation runs in.
	tx *sql.Tx
}
|
|
|
|
func (t *txClient) Query(ctx context.Context, dest any, query string, args ...any) error {
|
|
rows, err := t.tx.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
return scanIntoDest(rows, dest)
|
|
}
|
|
|
|
func (t *txClient) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
|
return t.tx.ExecContext(ctx, query, args...)
|
|
}
|
|
|
|
func (t *txClient) CreateQueryBuilder(table string) *QueryBuilder {
|
|
return newQueryBuilder(t.tx, table)
|
|
}
|
|
|
|
func (t *txClient) Save(ctx context.Context, entity any) error {
|
|
return saveEntity(ctx, t.tx, entity)
|
|
}
|
|
|
|
func (t *txClient) Remove(ctx context.Context, entity any) error {
|
|
return removeEntity(ctx, t.tx, entity)
|
|
}
|
|
|
|
// executor is the subset of database/sql behaviour this package needs to run
// statements. Both *sql.DB and *sql.Tx satisfy it.
type executor interface {
	// QueryContext runs a statement that returns rows.
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	// ExecContext runs a statement that returns no rows.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
|
|
|
|
// QueryBuilder implements a fluent SELECT builder with joins, where, etc.
|
|
type QueryBuilder struct {
|
|
exec executor
|
|
table string
|
|
alias string
|
|
selects []string
|
|
|
|
joins []joinClause
|
|
wheres []whereClause
|
|
|
|
groupBys []string
|
|
orderBys []string
|
|
limit *int
|
|
offset *int
|
|
}
|
|
|
|
// joinClause describes one JOIN of a query: its flavour, the joined table and
// the ON condition.
type joinClause struct {
	// kind is "INNER", "LEFT" or "JOIN" (the default).
	kind string
	// table is the joined table.
	table string
	// on is the join condition.
	on string
}
|
|
|
|
// whereClause is one WHERE predicate together with its bound arguments and
// the conjunction that links it to the preceding predicate.
type whereClause struct {
	// conj is "AND" or "OR".
	conj string
	// expr is the predicate text, using "?" placeholders.
	expr string
	// args are the values bound to expr's placeholders.
	args []any
}
|
|
|
|
func newQueryBuilder(exec executor, table string) *QueryBuilder {
|
|
return &QueryBuilder{
|
|
exec: exec,
|
|
table: table,
|
|
}
|
|
}
|
|
|
|
func (qb *QueryBuilder) Select(cols ...string) *QueryBuilder {
|
|
qb.selects = append(qb.selects, cols...)
|
|
return qb
|
|
}
|
|
|
|
func (qb *QueryBuilder) Alias(a string) *QueryBuilder {
|
|
qb.alias = a
|
|
return qb
|
|
}
|
|
|
|
func (qb *QueryBuilder) Where(expr string, args ...any) *QueryBuilder {
|
|
return qb.AndWhere(expr, args...)
|
|
}
|
|
|
|
func (qb *QueryBuilder) AndWhere(expr string, args ...any) *QueryBuilder {
|
|
qb.wheres = append(qb.wheres, whereClause{conj: "AND", expr: expr, args: args})
|
|
return qb
|
|
}
|
|
|
|
func (qb *QueryBuilder) OrWhere(expr string, args ...any) *QueryBuilder {
|
|
qb.wheres = append(qb.wheres, whereClause{conj: "OR", expr: expr, args: args})
|
|
return qb
|
|
}
|
|
|
|
func (qb *QueryBuilder) InnerJoin(table string, on string) *QueryBuilder {
|
|
qb.joins = append(qb.joins, joinClause{kind: "INNER", table: table, on: on})
|
|
return qb
|
|
}
|
|
|
|
func (qb *QueryBuilder) LeftJoin(table string, on string) *QueryBuilder {
|
|
qb.joins = append(qb.joins, joinClause{kind: "LEFT", table: table, on: on})
|
|
return qb
|
|
}
|
|
|
|
func (qb *QueryBuilder) Join(table string, on string) *QueryBuilder {
|
|
qb.joins = append(qb.joins, joinClause{kind: "JOIN", table: table, on: on})
|
|
return qb
|
|
}
|
|
|
|
func (qb *QueryBuilder) GroupBy(cols ...string) *QueryBuilder {
|
|
qb.groupBys = append(qb.groupBys, cols...)
|
|
return qb
|
|
}
|
|
|
|
func (qb *QueryBuilder) OrderBy(exprs ...string) *QueryBuilder {
|
|
qb.orderBys = append(qb.orderBys, exprs...)
|
|
return qb
|
|
}
|
|
|
|
func (qb *QueryBuilder) Limit(n int) *QueryBuilder {
|
|
qb.limit = &n
|
|
return qb
|
|
}
|
|
|
|
func (qb *QueryBuilder) Offset(n int) *QueryBuilder {
|
|
qb.offset = &n
|
|
return qb
|
|
}
|
|
|
|
// Build returns the SQL string and args for a SELECT.
|
|
func (qb *QueryBuilder) Build() (string, []any) {
|
|
cols := "*"
|
|
if len(qb.selects) > 0 {
|
|
cols = strings.Join(qb.selects, ", ")
|
|
}
|
|
base := fmt.Sprintf("SELECT %s FROM %s", cols, qb.table)
|
|
if qb.alias != "" {
|
|
base += " AS " + qb.alias
|
|
}
|
|
|
|
args := make([]any, 0, 16)
|
|
for _, j := range qb.joins {
|
|
base += fmt.Sprintf(" %s JOIN %s ON %s", j.kind, j.table, j.on)
|
|
}
|
|
|
|
if len(qb.wheres) > 0 {
|
|
base += " WHERE "
|
|
for i, w := range qb.wheres {
|
|
if i > 0 {
|
|
base += " " + w.conj + " "
|
|
}
|
|
base += "(" + w.expr + ")"
|
|
args = append(args, w.args...)
|
|
}
|
|
}
|
|
|
|
if len(qb.groupBys) > 0 {
|
|
base += " GROUP BY " + strings.Join(qb.groupBys, ", ")
|
|
}
|
|
if len(qb.orderBys) > 0 {
|
|
base += " ORDER BY " + strings.Join(qb.orderBys, ", ")
|
|
}
|
|
if qb.limit != nil {
|
|
base += fmt.Sprintf(" LIMIT %d", *qb.limit)
|
|
}
|
|
if qb.offset != nil {
|
|
base += fmt.Sprintf(" OFFSET %d", *qb.offset)
|
|
}
|
|
return base, args
|
|
}
|
|
|
|
// GetMany executes the built query and scans into dest (pointer to slice).
|
|
func (qb *QueryBuilder) GetMany(ctx context.Context, dest any) error {
|
|
sqlStr, args := qb.Build()
|
|
rows, err := qb.exec.QueryContext(ctx, sqlStr, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
return scanIntoDest(rows, dest)
|
|
}
|
|
|
|
// GetOne executes the built query and scans into dest (pointer to struct or map) with LIMIT 1.
|
|
func (qb *QueryBuilder) GetOne(ctx context.Context, dest any) error {
|
|
limit := 1
|
|
if qb.limit == nil {
|
|
qb.limit = &limit
|
|
} else if qb.limit != nil && *qb.limit > 1 {
|
|
qb.limit = &limit
|
|
}
|
|
sqlStr, args := qb.Build()
|
|
rows, err := qb.exec.QueryContext(ctx, sqlStr, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
if !rows.Next() {
|
|
return sql.ErrNoRows
|
|
}
|
|
return scanIntoSingle(rows, dest)
|
|
}
|
|
|
|
// FindOption customizes Find queries.
|
|
type FindOption func(q *QueryBuilder)
|
|
|
|
func WithOrderBy(exprs ...string) FindOption {
|
|
return func(q *QueryBuilder) { q.OrderBy(exprs...) }
|
|
}
|
|
func WithGroupBy(cols ...string) FindOption {
|
|
return func(q *QueryBuilder) { q.GroupBy(cols...) }
|
|
}
|
|
func WithLimit(n int) FindOption {
|
|
return func(q *QueryBuilder) { q.Limit(n) }
|
|
}
|
|
func WithOffset(n int) FindOption {
|
|
return func(q *QueryBuilder) { q.Offset(n) }
|
|
}
|
|
func WithSelect(cols ...string) FindOption {
|
|
return func(q *QueryBuilder) { q.Select(cols...) }
|
|
}
|
|
func WithJoin(kind, table, on string) FindOption {
|
|
return func(q *QueryBuilder) {
|
|
switch strings.ToUpper(kind) {
|
|
case "INNER":
|
|
q.InnerJoin(table, on)
|
|
case "LEFT":
|
|
q.LeftJoin(table, on)
|
|
default:
|
|
q.Join(table, on)
|
|
}
|
|
}
|
|
}
|
|
|
|
// repository is a generic table repository for type T.
|
|
type repository[T any] struct {
|
|
c *client
|
|
table string
|
|
}
|
|
|
|
func (r *repository[T]) Find(ctx context.Context, dest *[]T, criteria map[string]any, opts ...FindOption) error {
|
|
qb := r.c.CreateQueryBuilder(r.table)
|
|
for k, v := range criteria {
|
|
qb.AndWhere(fmt.Sprintf("%s = ?", k), v)
|
|
}
|
|
for _, opt := range opts {
|
|
opt(qb)
|
|
}
|
|
return qb.GetMany(ctx, dest)
|
|
}
|
|
|
|
func (r *repository[T]) FindOne(ctx context.Context, dest *T, criteria map[string]any, opts ...FindOption) error {
|
|
qb := r.c.CreateQueryBuilder(r.table)
|
|
for k, v := range criteria {
|
|
qb.AndWhere(fmt.Sprintf("%s = ?", k), v)
|
|
}
|
|
for _, opt := range opts {
|
|
opt(qb)
|
|
}
|
|
return qb.GetOne(ctx, dest)
|
|
}
|
|
|
|
func (r *repository[T]) Save(ctx context.Context, entity *T) error {
|
|
return saveEntity(ctx, r.c.db, entity)
|
|
}
|
|
|
|
func (r *repository[T]) Remove(ctx context.Context, entity *T) error {
|
|
return removeEntity(ctx, r.c.db, entity)
|
|
}
|
|
|
|
func (r *repository[T]) Q() *QueryBuilder {
|
|
return r.c.CreateQueryBuilder(r.table)
|
|
}
|
|
|
|
// -----------------------
|
|
// Reflection + scanning
|
|
// -----------------------
|
|
|
|
func scanIntoDest(rows *sql.Rows, dest any) error {
|
|
// dest must be pointer to slice (of struct or map)
|
|
rv := reflect.ValueOf(dest)
|
|
if rv.Kind() != reflect.Pointer || rv.IsNil() {
|
|
return errors.New("dest must be a non-nil pointer")
|
|
}
|
|
sliceVal := rv.Elem()
|
|
if sliceVal.Kind() != reflect.Slice {
|
|
return errors.New("dest must be pointer to a slice")
|
|
}
|
|
elemType := sliceVal.Type().Elem()
|
|
|
|
cols, err := rows.Columns()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for rows.Next() {
|
|
itemPtr := reflect.New(elemType)
|
|
// Support map[string]any and struct
|
|
if elemType.Kind() == reflect.Map {
|
|
m, err := scanRowToMap(rows, cols)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sliceVal.Set(reflect.Append(sliceVal, reflect.ValueOf(m)))
|
|
continue
|
|
}
|
|
|
|
if elemType.Kind() == reflect.Struct {
|
|
if err := scanCurrentRowIntoStruct(rows, cols, itemPtr.Elem()); err != nil {
|
|
return err
|
|
}
|
|
sliceVal.Set(reflect.Append(sliceVal, itemPtr.Elem()))
|
|
continue
|
|
}
|
|
|
|
return fmt.Errorf("unsupported slice element type: %s", elemType.Kind())
|
|
}
|
|
return rows.Err()
|
|
}
|
|
|
|
func scanIntoSingle(rows *sql.Rows, dest any) error {
|
|
rv := reflect.ValueOf(dest)
|
|
if rv.Kind() != reflect.Pointer || rv.IsNil() {
|
|
return errors.New("dest must be a non-nil pointer")
|
|
}
|
|
cols, err := rows.Columns()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch rv.Elem().Kind() {
|
|
case reflect.Map:
|
|
m, err := scanRowToMap(rows, cols)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rv.Elem().Set(reflect.ValueOf(m))
|
|
return nil
|
|
case reflect.Struct:
|
|
return scanCurrentRowIntoStruct(rows, cols, rv.Elem())
|
|
default:
|
|
return fmt.Errorf("unsupported dest kind: %s", rv.Elem().Kind())
|
|
}
|
|
}
|
|
|
|
func scanRowToMap(rows *sql.Rows, cols []string) (map[string]any, error) {
|
|
raw := make([]any, len(cols))
|
|
ptrs := make([]any, len(cols))
|
|
for i := range raw {
|
|
ptrs[i] = &raw[i]
|
|
}
|
|
if err := rows.Scan(ptrs...); err != nil {
|
|
return nil, err
|
|
}
|
|
out := make(map[string]any, len(cols))
|
|
for i, c := range cols {
|
|
out[c] = normalizeSQLValue(raw[i])
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func scanCurrentRowIntoStruct(rows *sql.Rows, cols []string, destStruct reflect.Value) error {
|
|
raw := make([]any, len(cols))
|
|
ptrs := make([]any, len(cols))
|
|
for i := range raw {
|
|
ptrs[i] = &raw[i]
|
|
}
|
|
if err := rows.Scan(ptrs...); err != nil {
|
|
return err
|
|
}
|
|
fieldIndex := buildFieldIndex(destStruct.Type())
|
|
for i, c := range cols {
|
|
if idx, ok := fieldIndex[strings.ToLower(c)]; ok {
|
|
field := destStruct.Field(idx)
|
|
if field.CanSet() {
|
|
if err := setReflectValue(field, raw[i]); err != nil {
|
|
return fmt.Errorf("column %s: %w", c, err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// normalizeSQLValue converts a []byte value into a string and returns every
// other value unchanged.
func normalizeSQLValue(v any) any {
	if b, isBytes := v.([]byte); isBytes {
		return string(b)
	}
	return v
}
|
|
|
|
// buildFieldIndex maps lower-cased column names to field indexes for the
// struct type t.
//
// The column name of a field is the part of its `db` tag before the first
// comma, or the field name when that part is empty. Unexported fields and
// fields tagged `db:"-"` are left out. If two fields resolve to the same
// name, the later field wins.
func buildFieldIndex(t reflect.Type) map[string]int {
	index := make(map[string]int, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if !f.IsExported() {
			continue
		}
		tag := f.Tag.Get("db")
		if tag == "-" {
			// Explicitly excluded from column mapping.
			continue
		}
		col, _, _ := strings.Cut(tag, ",")
		if col == "" {
			col = f.Name
		}
		index[strings.ToLower(col)] = i
	}
	return index
}
|
|
|
|
// setReflectValue stores the driver value raw into field, converting it to
// the field's kind.
//
// A nil raw leaves the field untouched. []byte input is treated as its string
// form. Supported destinations are string, bool, signed and unsigned integers,
// floats, time.Time and the sql.NullString/NullInt64/NullBool/NullFloat64
// wrappers:
//
//   - string: non-string values are formatted with fmt.Sprint.
//   - bool: numbers are true when non-zero; text is true for "1" or "true"
//     (case-insensitive); any other input yields false.
//   - integers and floats: numeric input is converted (negative input is
//     clamped to 0 for unsigned fields); text is parsed and a parse failure is
//     returned as an error.
//   - time.Time: text is tried against RFC 3339 and the common SQLite
//     date-time layouts; text that matches none leaves the field unchanged.
//   - sql.Null*: the wrapper is marked valid only for the input types it
//     understands, otherwise it is stored as its zero (invalid) value.
//
// An error is returned for input types that cannot be converted to a numeric
// field and for unsupported field kinds.
func setReflectValue(field reflect.Value, raw any) error {
	if raw == nil {
		// leave zero value
		return nil
	}
	// Text may arrive as []byte or string depending on the driver; handle
	// both through the string branches below.
	if b, ok := raw.([]byte); ok {
		raw = string(b)
	}

	switch field.Kind() {
	case reflect.String:
		if s, ok := raw.(string); ok {
			field.SetString(s)
		} else {
			field.SetString(fmt.Sprint(raw))
		}

	case reflect.Bool:
		switch v := raw.(type) {
		case bool:
			field.SetBool(v)
		case int64:
			field.SetBool(v != 0)
		case float64:
			// RQLite/JSON returns numbers as float64
			field.SetBool(v != 0)
		case string:
			field.SetBool(v == "1" || strings.EqualFold(v, "true"))
		default:
			field.SetBool(false)
		}

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		switch v := raw.(type) {
		case int64:
			field.SetInt(v)
		case float64:
			// RQLite/JSON returns numbers as float64
			field.SetInt(int64(v))
		case int:
			field.SetInt(int64(v))
		case string:
			var n int64
			if _, err := fmt.Sscan(v, &n); err != nil {
				return fmt.Errorf("cannot parse %q as int: %w", v, err)
			}
			field.SetInt(n)
		default:
			return fmt.Errorf("cannot convert %T to int", raw)
		}

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		switch v := raw.(type) {
		case int64:
			if v < 0 {
				v = 0
			}
			field.SetUint(uint64(v))
		case float64:
			// RQLite/JSON returns numbers as float64
			if v < 0 {
				v = 0
			}
			field.SetUint(uint64(v))
		case uint64:
			field.SetUint(v)
		case string:
			var n uint64
			if _, err := fmt.Sscan(v, &n); err != nil {
				return fmt.Errorf("cannot parse %q as uint: %w", v, err)
			}
			field.SetUint(n)
		default:
			return fmt.Errorf("cannot convert %T to uint", raw)
		}

	case reflect.Float32, reflect.Float64:
		switch v := raw.(type) {
		case float64:
			field.SetFloat(v)
		case int64:
			field.SetFloat(float64(v))
		case string:
			var fv float64
			if _, err := fmt.Sscan(v, &fv); err != nil {
				return fmt.Errorf("cannot parse %q as float: %w", v, err)
			}
			field.SetFloat(fv)
		default:
			return fmt.Errorf("cannot convert %T to float", raw)
		}

	case reflect.Struct:
		switch field.Type() {
		case reflect.TypeOf(time.Time{}):
			switch v := raw.(type) {
			case time.Time:
				field.Set(reflect.ValueOf(v))
			case string:
				// Best effort: the first layout that parses wins; text that
				// matches none leaves the field as it was.
				layouts := []string{
					time.RFC3339,
					"2006-01-02 15:04:05Z07:00",
					"2006-01-02 15:04:05",
					"2006-01-02T15:04:05",
					"2006-01-02",
				}
				for _, layout := range layouts {
					if tt, err := time.Parse(layout, v); err == nil {
						field.Set(reflect.ValueOf(tt))
						break
					}
				}
			}

		case reflect.TypeOf(sql.NullString{}):
			ns := sql.NullString{}
			if s, ok := raw.(string); ok {
				ns.String = s
				ns.Valid = true
			}
			field.Set(reflect.ValueOf(ns))

		case reflect.TypeOf(sql.NullInt64{}):
			ni := sql.NullInt64{}
			switch v := raw.(type) {
			case int64:
				ni.Int64 = v
				ni.Valid = true
			case float64:
				ni.Int64 = int64(v)
				ni.Valid = true
			case int:
				ni.Int64 = int64(v)
				ni.Valid = true
			}
			field.Set(reflect.ValueOf(ni))

		case reflect.TypeOf(sql.NullBool{}):
			nb := sql.NullBool{}
			switch v := raw.(type) {
			case bool:
				nb.Bool = v
				nb.Valid = true
			case int64:
				nb.Bool = v != 0
				nb.Valid = true
			case float64:
				nb.Bool = v != 0
				nb.Valid = true
			}
			field.Set(reflect.ValueOf(nb))

		case reflect.TypeOf(sql.NullFloat64{}):
			nf := sql.NullFloat64{}
			switch v := raw.(type) {
			case float64:
				nf.Float64 = v
				nf.Valid = true
			case int64:
				nf.Float64 = float64(v)
				nf.Valid = true
			}
			field.Set(reflect.ValueOf(nf))

		default:
			// Not supported yet
			return fmt.Errorf("unsupported dest field kind: %s", field.Kind())
		}

	default:
		// Not supported yet
		return fmt.Errorf("unsupported dest field kind: %s", field.Kind())
	}
	return nil
}
|
|
|
|
// -----------------------
|
|
// Save/Remove (basic PK)
|
|
// -----------------------
|
|
|
|
// fieldMeta is the persistence metadata derived from one struct field.
type fieldMeta struct {
	// index is the field's position within the struct type.
	index int
	// column is the column name the field maps to.
	column string
	// isPK reports whether the field is treated as the primary key.
	isPK bool
	// auto is set by the "auto" / "autoincrement" tag option.
	auto bool
}
|
|
|
|
func collectMeta(t reflect.Type) (fields []fieldMeta, pk fieldMeta, hasPK bool) {
|
|
for i := 0; i < t.NumField(); i++ {
|
|
f := t.Field(i)
|
|
if !f.IsExported() {
|
|
continue
|
|
}
|
|
tag := f.Tag.Get("db")
|
|
if tag == "-" {
|
|
continue
|
|
}
|
|
opts := strings.Split(tag, ",")
|
|
col := opts[0]
|
|
if col == "" {
|
|
col = f.Name
|
|
}
|
|
meta := fieldMeta{index: i, column: col}
|
|
for _, o := range opts[1:] {
|
|
switch strings.ToLower(strings.TrimSpace(o)) {
|
|
case "pk":
|
|
meta.isPK = true
|
|
case "auto", "autoincrement":
|
|
meta.auto = true
|
|
}
|
|
}
|
|
// If not tagged as pk, fallback to field name "ID"
|
|
if !meta.isPK && f.Name == "ID" {
|
|
meta.isPK = true
|
|
if col == "" {
|
|
meta.column = "id"
|
|
}
|
|
}
|
|
fields = append(fields, meta)
|
|
if meta.isPK {
|
|
pk = meta
|
|
hasPK = true
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func getTableNameFromEntity(v reflect.Value) (string, bool) {
|
|
// If entity implements TableNamer
|
|
if v.CanInterface() {
|
|
if tn, ok := v.Interface().(TableNamer); ok {
|
|
return tn.TableName(), true
|
|
}
|
|
}
|
|
// Fallback: very naive pluralization (append 's')
|
|
typ := v.Type()
|
|
if typ.Kind() == reflect.Pointer {
|
|
typ = typ.Elem()
|
|
}
|
|
if typ.Kind() == reflect.Struct {
|
|
return strings.ToLower(typ.Name()) + "s", true
|
|
}
|
|
return "", false
|
|
}
|
|
|
|
func saveEntity(ctx context.Context, exec executor, entity any) error {
|
|
rv := reflect.ValueOf(entity)
|
|
if rv.Kind() != reflect.Pointer || rv.IsNil() {
|
|
return errors.New("entity must be a non-nil pointer to struct")
|
|
}
|
|
ev := rv.Elem()
|
|
if ev.Kind() != reflect.Struct {
|
|
return errors.New("entity must point to a struct")
|
|
}
|
|
|
|
fields, pkMeta, hasPK := collectMeta(ev.Type())
|
|
if !hasPK {
|
|
return errors.New("no primary key field found (tag db:\"...,pk\" or field named ID)")
|
|
}
|
|
table, ok := getTableNameFromEntity(ev)
|
|
if !ok || table == "" {
|
|
return errors.New("unable to resolve table name; implement TableNamer or set up a repository with explicit table")
|
|
}
|
|
|
|
// Build lists
|
|
cols := make([]string, 0, len(fields))
|
|
vals := make([]any, 0, len(fields))
|
|
setParts := make([]string, 0, len(fields))
|
|
|
|
var pkVal any
|
|
var pkIsZero bool
|
|
|
|
for _, fm := range fields {
|
|
f := ev.Field(fm.index)
|
|
if fm.isPK {
|
|
pkVal = f.Interface()
|
|
pkIsZero = isZeroValue(f)
|
|
continue
|
|
}
|
|
cols = append(cols, fm.column)
|
|
vals = append(vals, f.Interface())
|
|
setParts = append(setParts, fmt.Sprintf("%s = ?", fm.column))
|
|
}
|
|
|
|
if pkIsZero {
|
|
// INSERT
|
|
placeholders := strings.Repeat("?,", len(cols))
|
|
if len(placeholders) > 0 {
|
|
placeholders = placeholders[:len(placeholders)-1]
|
|
}
|
|
sqlStr := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, strings.Join(cols, ", "), placeholders)
|
|
res, err := exec.ExecContext(ctx, sqlStr, vals...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Set auto ID if needed
|
|
if pkMeta.auto {
|
|
if id, err := res.LastInsertId(); err == nil {
|
|
ev.Field(pkMeta.index).SetInt(id)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UPDATE ... WHERE pk = ?
|
|
sqlStr := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?", table, strings.Join(setParts, ", "), pkMeta.column)
|
|
valsWithPK := append(vals, pkVal)
|
|
_, err := exec.ExecContext(ctx, sqlStr, valsWithPK...)
|
|
return err
|
|
}
|
|
|
|
func removeEntity(ctx context.Context, exec executor, entity any) error {
|
|
rv := reflect.ValueOf(entity)
|
|
if rv.Kind() != reflect.Pointer || rv.IsNil() {
|
|
return errors.New("entity must be a non-nil pointer to struct")
|
|
}
|
|
ev := rv.Elem()
|
|
if ev.Kind() != reflect.Struct {
|
|
return errors.New("entity must point to a struct")
|
|
}
|
|
_, pkMeta, hasPK := collectMeta(ev.Type())
|
|
if !hasPK {
|
|
return errors.New("no primary key field found")
|
|
}
|
|
table, ok := getTableNameFromEntity(ev)
|
|
if !ok || table == "" {
|
|
return errors.New("unable to resolve table name")
|
|
}
|
|
pkVal := ev.Field(pkMeta.index).Interface()
|
|
sqlStr := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", table, pkMeta.column)
|
|
_, err := exec.ExecContext(ctx, sqlStr, pkVal)
|
|
return err
|
|
}
|
|
|
|
// isZeroValue reports whether v counts as "unset".
//
// Strings, slices and maps are unset when empty; integers when 0; booleans
// when false; pointers and interfaces when nil. A time.Time is unset when
// IsZero reports true, and any other struct when it deeply equals the zero
// value of its type. Every other kind (floats, for example) is never
// considered unset.
func isZeroValue(v reflect.Value) bool {
	kind := v.Kind()
	switch kind {
	case reflect.String, reflect.Slice, reflect.Map:
		return v.Len() == 0
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v.Int() == 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return v.Uint() == 0
	case reflect.Bool:
		return !v.Bool()
	case reflect.Pointer, reflect.Interface:
		return v.IsNil()
	}

	if kind != reflect.Struct {
		return false
	}
	// time.Time is compared through IsZero rather than field by field.
	if stamp, isTime := v.Interface().(time.Time); isTime && v.Type() == reflect.TypeOf(time.Time{}) {
		return stamp.IsZero()
	}
	blank := reflect.Zero(v.Type()).Interface()
	return reflect.DeepEqual(v.Interface(), blank)
}
|