mirror of
https://github.com/DeBrosOfficial/network.git
synced 2025-10-06 13:49:07 +00:00

- Created new rqlite folder - Created rqlite adapter, client, gateway, migrations and rqlite init - Updated node.go to support new rqlite architecture - Updated readme - Updated version
836 lines
21 KiB
Go
836 lines
21 KiB
Go
package rqlite
|
|
|
|
// client.go defines the ORM-like interfaces and a minimal implementation over database/sql.
|
|
// It builds on the rqlite stdlib driver so it behaves like a regular SQL-backed ORM.
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"reflect"
	"strconv"
	"strings"
	"time"
)
|
|
|
|
// TableNamer lets a struct provide its table name.
|
|
type TableNamer interface {
|
|
TableName() string
|
|
}
|
|
|
|
// Client is the high-level ORM-like API.
|
|
type Client interface {
|
|
// Query runs an arbitrary SELECT and scans rows into dest (pointer to slice of structs or []map[string]any).
|
|
Query(ctx context.Context, dest any, query string, args ...any) error
|
|
// Exec runs a write statement (INSERT/UPDATE/DELETE).
|
|
Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
|
|
|
|
// FindBy/FindOneBy provide simple map-based criteria filtering.
|
|
FindBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...FindOption) error
|
|
FindOneBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...FindOption) error
|
|
|
|
// Save inserts or updates an entity (single-PK).
|
|
Save(ctx context.Context, entity any) error
|
|
// Remove deletes by PK (single-PK).
|
|
Remove(ctx context.Context, entity any) error
|
|
|
|
// Repositories (generic layer). Optional but convenient if you use Go generics.
|
|
Repository(table string) any
|
|
|
|
// Fluent query builder for advanced querying.
|
|
CreateQueryBuilder(table string) *QueryBuilder
|
|
|
|
// Tx executes a function within a transaction.
|
|
Tx(ctx context.Context, fn func(tx Tx) error) error
|
|
}
|
|
|
|
// Tx mirrors Client but executes within a transaction.
|
|
type Tx interface {
|
|
Query(ctx context.Context, dest any, query string, args ...any) error
|
|
Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
|
|
CreateQueryBuilder(table string) *QueryBuilder
|
|
|
|
// Optional: scoped Save/Remove inside tx
|
|
Save(ctx context.Context, entity any) error
|
|
Remove(ctx context.Context, entity any) error
|
|
}
|
|
|
|
// Repository provides typed entity operations for a table.
|
|
type Repository[T any] interface {
|
|
Find(ctx context.Context, dest *[]T, criteria map[string]any, opts ...FindOption) error
|
|
FindOne(ctx context.Context, dest *T, criteria map[string]any, opts ...FindOption) error
|
|
Save(ctx context.Context, entity *T) error
|
|
Remove(ctx context.Context, entity *T) error
|
|
|
|
// Builder helpers
|
|
Q() *QueryBuilder
|
|
}
|
|
|
|
// NewClient wires the ORM client to a *sql.DB (from your RQLiteAdapter).
|
|
func NewClient(db *sql.DB) Client {
|
|
return &client{db: db}
|
|
}
|
|
|
|
// NewClientFromAdapter is convenient if you already created the adapter.
|
|
func NewClientFromAdapter(adapter *RQLiteAdapter) Client {
|
|
return NewClient(adapter.GetSQLDB())
|
|
}
|
|
|
|
// client implements Client over *sql.DB.
|
|
type client struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
func (c *client) Query(ctx context.Context, dest any, query string, args ...any) error {
|
|
rows, err := c.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
return scanIntoDest(rows, dest)
|
|
}
|
|
|
|
func (c *client) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
|
return c.db.ExecContext(ctx, query, args...)
|
|
}
|
|
|
|
func (c *client) FindBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...FindOption) error {
|
|
qb := c.CreateQueryBuilder(table)
|
|
for k, v := range criteria {
|
|
qb = qb.AndWhere(fmt.Sprintf("%s = ?", k), v)
|
|
}
|
|
for _, opt := range opts {
|
|
opt(qb)
|
|
}
|
|
return qb.GetMany(ctx, dest)
|
|
}
|
|
|
|
func (c *client) FindOneBy(ctx context.Context, dest any, table string, criteria map[string]any, opts ...FindOption) error {
|
|
qb := c.CreateQueryBuilder(table)
|
|
for k, v := range criteria {
|
|
qb = qb.AndWhere(fmt.Sprintf("%s = ?", k), v)
|
|
}
|
|
for _, opt := range opts {
|
|
opt(qb)
|
|
}
|
|
return qb.GetOne(ctx, dest)
|
|
}
|
|
|
|
func (c *client) Save(ctx context.Context, entity any) error {
|
|
return saveEntity(ctx, c.db, entity)
|
|
}
|
|
|
|
func (c *client) Remove(ctx context.Context, entity any) error {
|
|
return removeEntity(ctx, c.db, entity)
|
|
}
|
|
|
|
func (c *client) Repository(table string) any {
|
|
// This returns an untyped interface since Go methods cannot have type parameters
|
|
// Users will need to type assert the result to Repository[T]
|
|
return func() any {
|
|
return &repository[any]{c: c, table: table}
|
|
}()
|
|
}
|
|
|
|
func (c *client) CreateQueryBuilder(table string) *QueryBuilder {
|
|
return newQueryBuilder(c.db, table)
|
|
}
|
|
|
|
func (c *client) Tx(ctx context.Context, fn func(tx Tx) error) error {
|
|
sqlTx, err := c.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
txc := &txClient{tx: sqlTx}
|
|
if err := fn(txc); err != nil {
|
|
_ = sqlTx.Rollback()
|
|
return err
|
|
}
|
|
return sqlTx.Commit()
|
|
}
|
|
|
|
// txClient implements Tx over *sql.Tx.
|
|
type txClient struct {
|
|
tx *sql.Tx
|
|
}
|
|
|
|
func (t *txClient) Query(ctx context.Context, dest any, query string, args ...any) error {
|
|
rows, err := t.tx.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
return scanIntoDest(rows, dest)
|
|
}
|
|
|
|
func (t *txClient) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
|
return t.tx.ExecContext(ctx, query, args...)
|
|
}
|
|
|
|
func (t *txClient) CreateQueryBuilder(table string) *QueryBuilder {
|
|
return newQueryBuilder(t.tx, table)
|
|
}
|
|
|
|
func (t *txClient) Save(ctx context.Context, entity any) error {
|
|
return saveEntity(ctx, t.tx, entity)
|
|
}
|
|
|
|
func (t *txClient) Remove(ctx context.Context, entity any) error {
|
|
return removeEntity(ctx, t.tx, entity)
|
|
}
|
|
|
|
// executor is the query/exec subset shared by *sql.DB and *sql.Tx. It lets the
// query builder and the save/remove helpers run both inside and outside a
// transaction.
type executor interface {
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}

// QueryBuilder is a fluent SELECT builder. It supports column selection, a
// table alias, joins, AND/OR conditions, GROUP BY, ORDER BY, LIMIT and OFFSET.
//
// Table names, column lists, join conditions and ORDER/GROUP expressions are
// inserted into the SQL verbatim. Only the args passed to Where/AndWhere/
// OrWhere are bound as parameters, so never feed untrusted input into the
// string parts. Methods mutate the builder without locking, so a QueryBuilder
// must not be shared between goroutines.
type QueryBuilder struct {
	exec  executor
	table string
	alias string

	selects []string

	joins  []joinClause
	wheres []whereClause

	groupBys []string
	orderBys []string
	limit    *int
	offset   *int
}

// joinClause is one JOIN of the query. kind is "INNER", "LEFT" or "JOIN"; the
// last one (and an empty kind) renders as a plain JOIN.
type joinClause struct {
	kind  string
	table string
	on    string
}

// whereClause is one WHERE condition. conj ("AND" or "OR") links it to the
// preceding condition and is ignored for the first one.
type whereClause struct {
	conj string
	expr string
	args []any
}

// newQueryBuilder returns an empty builder selecting from table via exec.
func newQueryBuilder(exec executor, table string) *QueryBuilder {
	return &QueryBuilder{
		exec:  exec,
		table: table,
	}
}

// Select appends columns/expressions to the select list; with none, "*" is used.
func (qb *QueryBuilder) Select(cols ...string) *QueryBuilder {
	qb.selects = append(qb.selects, cols...)
	return qb
}

// Alias sets the alias of the base table ("FROM table AS a").
func (qb *QueryBuilder) Alias(a string) *QueryBuilder {
	qb.alias = a
	return qb
}

// Where is a synonym for AndWhere.
func (qb *QueryBuilder) Where(expr string, args ...any) *QueryBuilder {
	return qb.AndWhere(expr, args...)
}

// AndWhere adds a condition linked to the previous one with AND. expr is
// wrapped in parentheses; args bind its "?" placeholders.
func (qb *QueryBuilder) AndWhere(expr string, args ...any) *QueryBuilder {
	qb.wheres = append(qb.wheres, whereClause{conj: "AND", expr: expr, args: args})
	return qb
}

// OrWhere adds a condition linked to the previous one with OR. Conditions are
// not grouped, so SQL precedence applies: A AND B OR C means (A AND B) OR C.
func (qb *QueryBuilder) OrWhere(expr string, args ...any) *QueryBuilder {
	qb.wheres = append(qb.wheres, whereClause{conj: "OR", expr: expr, args: args})
	return qb
}

// InnerJoin adds "INNER JOIN table ON on".
func (qb *QueryBuilder) InnerJoin(table string, on string) *QueryBuilder {
	qb.joins = append(qb.joins, joinClause{kind: "INNER", table: table, on: on})
	return qb
}

// LeftJoin adds "LEFT JOIN table ON on".
func (qb *QueryBuilder) LeftJoin(table string, on string) *QueryBuilder {
	qb.joins = append(qb.joins, joinClause{kind: "LEFT", table: table, on: on})
	return qb
}

// Join adds a plain "JOIN table ON on".
func (qb *QueryBuilder) Join(table string, on string) *QueryBuilder {
	qb.joins = append(qb.joins, joinClause{kind: "JOIN", table: table, on: on})
	return qb
}

// GroupBy appends GROUP BY expressions.
func (qb *QueryBuilder) GroupBy(cols ...string) *QueryBuilder {
	qb.groupBys = append(qb.groupBys, cols...)
	return qb
}

// OrderBy appends ORDER BY expressions (e.g. "created_at DESC").
func (qb *QueryBuilder) OrderBy(exprs ...string) *QueryBuilder {
	qb.orderBys = append(qb.orderBys, exprs...)
	return qb
}

// Limit sets the LIMIT.
func (qb *QueryBuilder) Limit(n int) *QueryBuilder {
	qb.limit = &n
	return qb
}

// Offset sets the OFFSET.
func (qb *QueryBuilder) Offset(n int) *QueryBuilder {
	qb.offset = &n
	return qb
}

// Build renders the SELECT statement and returns it with the bound arguments
// in placeholder order.
//
// Clauses are emitted in this order: SELECT, FROM [AS], JOINs, WHERE,
// GROUP BY, ORDER BY, LIMIT, OFFSET. Join kinds "INNER"/"LEFT" render as
// "INNER JOIN"/"LEFT JOIN"; "JOIN" (from Join) or an empty kind renders as a
// plain "JOIN". SQLite accepts OFFSET only after a LIMIT, so an offset without
// a limit renders as "LIMIT -1 OFFSET n" (a negative LIMIT means "no limit").
func (qb *QueryBuilder) Build() (string, []any) {
	var sb strings.Builder
	args := make([]any, 0, len(qb.wheres))

	sb.WriteString("SELECT ")
	if len(qb.selects) > 0 {
		sb.WriteString(strings.Join(qb.selects, ", "))
	} else {
		sb.WriteString("*")
	}
	sb.WriteString(" FROM ")
	sb.WriteString(qb.table)
	if qb.alias != "" {
		sb.WriteString(" AS ")
		sb.WriteString(qb.alias)
	}

	for _, j := range qb.joins {
		sb.WriteString(" ")
		if j.kind != "" && j.kind != "JOIN" {
			sb.WriteString(j.kind)
			sb.WriteString(" ")
		}
		sb.WriteString("JOIN ")
		sb.WriteString(j.table)
		sb.WriteString(" ON ")
		sb.WriteString(j.on)
	}

	for i, w := range qb.wheres {
		if i == 0 {
			sb.WriteString(" WHERE ")
		} else {
			sb.WriteString(" " + w.conj + " ")
		}
		sb.WriteString("(" + w.expr + ")")
		args = append(args, w.args...)
	}

	if len(qb.groupBys) > 0 {
		sb.WriteString(" GROUP BY ")
		sb.WriteString(strings.Join(qb.groupBys, ", "))
	}
	if len(qb.orderBys) > 0 {
		sb.WriteString(" ORDER BY ")
		sb.WriteString(strings.Join(qb.orderBys, ", "))
	}

	switch {
	case qb.limit != nil:
		fmt.Fprintf(&sb, " LIMIT %d", *qb.limit)
	case qb.offset != nil:
		sb.WriteString(" LIMIT -1")
	}
	if qb.offset != nil {
		fmt.Fprintf(&sb, " OFFSET %d", *qb.offset)
	}

	return sb.String(), args
}
|
|
|
|
// GetMany executes the built query and scans into dest (pointer to slice).
|
|
func (qb *QueryBuilder) GetMany(ctx context.Context, dest any) error {
|
|
sqlStr, args := qb.Build()
|
|
rows, err := qb.exec.QueryContext(ctx, sqlStr, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
return scanIntoDest(rows, dest)
|
|
}
|
|
|
|
// GetOne executes the built query and scans into dest (pointer to struct or map) with LIMIT 1.
|
|
func (qb *QueryBuilder) GetOne(ctx context.Context, dest any) error {
|
|
limit := 1
|
|
if qb.limit == nil {
|
|
qb.limit = &limit
|
|
} else if qb.limit != nil && *qb.limit > 1 {
|
|
qb.limit = &limit
|
|
}
|
|
sqlStr, args := qb.Build()
|
|
rows, err := qb.exec.QueryContext(ctx, sqlStr, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
if !rows.Next() {
|
|
return sql.ErrNoRows
|
|
}
|
|
return scanIntoSingle(rows, dest)
|
|
}
|
|
|
|
// FindOption customizes Find queries.
|
|
type FindOption func(q *QueryBuilder)
|
|
|
|
func WithOrderBy(exprs ...string) FindOption {
|
|
return func(q *QueryBuilder) { q.OrderBy(exprs...) }
|
|
}
|
|
func WithGroupBy(cols ...string) FindOption {
|
|
return func(q *QueryBuilder) { q.GroupBy(cols...) }
|
|
}
|
|
func WithLimit(n int) FindOption {
|
|
return func(q *QueryBuilder) { q.Limit(n) }
|
|
}
|
|
func WithOffset(n int) FindOption {
|
|
return func(q *QueryBuilder) { q.Offset(n) }
|
|
}
|
|
func WithSelect(cols ...string) FindOption {
|
|
return func(q *QueryBuilder) { q.Select(cols...) }
|
|
}
|
|
func WithJoin(kind, table, on string) FindOption {
|
|
return func(q *QueryBuilder) {
|
|
switch strings.ToUpper(kind) {
|
|
case "INNER":
|
|
q.InnerJoin(table, on)
|
|
case "LEFT":
|
|
q.LeftJoin(table, on)
|
|
default:
|
|
q.Join(table, on)
|
|
}
|
|
}
|
|
}
|
|
|
|
// repository is a generic table repository for type T.
|
|
type repository[T any] struct {
|
|
c *client
|
|
table string
|
|
}
|
|
|
|
func (r *repository[T]) Find(ctx context.Context, dest *[]T, criteria map[string]any, opts ...FindOption) error {
|
|
qb := r.c.CreateQueryBuilder(r.table)
|
|
for k, v := range criteria {
|
|
qb.AndWhere(fmt.Sprintf("%s = ?", k), v)
|
|
}
|
|
for _, opt := range opts {
|
|
opt(qb)
|
|
}
|
|
return qb.GetMany(ctx, dest)
|
|
}
|
|
|
|
func (r *repository[T]) FindOne(ctx context.Context, dest *T, criteria map[string]any, opts ...FindOption) error {
|
|
qb := r.c.CreateQueryBuilder(r.table)
|
|
for k, v := range criteria {
|
|
qb.AndWhere(fmt.Sprintf("%s = ?", k), v)
|
|
}
|
|
for _, opt := range opts {
|
|
opt(qb)
|
|
}
|
|
return qb.GetOne(ctx, dest)
|
|
}
|
|
|
|
func (r *repository[T]) Save(ctx context.Context, entity *T) error {
|
|
return saveEntity(ctx, r.c.db, entity)
|
|
}
|
|
|
|
func (r *repository[T]) Remove(ctx context.Context, entity *T) error {
|
|
return removeEntity(ctx, r.c.db, entity)
|
|
}
|
|
|
|
func (r *repository[T]) Q() *QueryBuilder {
|
|
return r.c.CreateQueryBuilder(r.table)
|
|
}
|
|
|
|
// -----------------------
|
|
// Reflection + scanning
|
|
// -----------------------
|
|
|
|
func scanIntoDest(rows *sql.Rows, dest any) error {
|
|
// dest must be pointer to slice (of struct or map)
|
|
rv := reflect.ValueOf(dest)
|
|
if rv.Kind() != reflect.Pointer || rv.IsNil() {
|
|
return errors.New("dest must be a non-nil pointer")
|
|
}
|
|
sliceVal := rv.Elem()
|
|
if sliceVal.Kind() != reflect.Slice {
|
|
return errors.New("dest must be pointer to a slice")
|
|
}
|
|
elemType := sliceVal.Type().Elem()
|
|
|
|
cols, err := rows.Columns()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for rows.Next() {
|
|
itemPtr := reflect.New(elemType)
|
|
// Support map[string]any and struct
|
|
if elemType.Kind() == reflect.Map {
|
|
m, err := scanRowToMap(rows, cols)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sliceVal.Set(reflect.Append(sliceVal, reflect.ValueOf(m)))
|
|
continue
|
|
}
|
|
|
|
if elemType.Kind() == reflect.Struct {
|
|
if err := scanCurrentRowIntoStruct(rows, cols, itemPtr.Elem()); err != nil {
|
|
return err
|
|
}
|
|
sliceVal.Set(reflect.Append(sliceVal, itemPtr.Elem()))
|
|
continue
|
|
}
|
|
|
|
return fmt.Errorf("unsupported slice element type: %s", elemType.Kind())
|
|
}
|
|
return rows.Err()
|
|
}
|
|
|
|
func scanIntoSingle(rows *sql.Rows, dest any) error {
|
|
rv := reflect.ValueOf(dest)
|
|
if rv.Kind() != reflect.Pointer || rv.IsNil() {
|
|
return errors.New("dest must be a non-nil pointer")
|
|
}
|
|
cols, err := rows.Columns()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch rv.Elem().Kind() {
|
|
case reflect.Map:
|
|
m, err := scanRowToMap(rows, cols)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rv.Elem().Set(reflect.ValueOf(m))
|
|
return nil
|
|
case reflect.Struct:
|
|
return scanCurrentRowIntoStruct(rows, cols, rv.Elem())
|
|
default:
|
|
return fmt.Errorf("unsupported dest kind: %s", rv.Elem().Kind())
|
|
}
|
|
}
|
|
|
|
func scanRowToMap(rows *sql.Rows, cols []string) (map[string]any, error) {
|
|
raw := make([]any, len(cols))
|
|
ptrs := make([]any, len(cols))
|
|
for i := range raw {
|
|
ptrs[i] = &raw[i]
|
|
}
|
|
if err := rows.Scan(ptrs...); err != nil {
|
|
return nil, err
|
|
}
|
|
out := make(map[string]any, len(cols))
|
|
for i, c := range cols {
|
|
out[c] = normalizeSQLValue(raw[i])
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func scanCurrentRowIntoStruct(rows *sql.Rows, cols []string, destStruct reflect.Value) error {
|
|
raw := make([]any, len(cols))
|
|
ptrs := make([]any, len(cols))
|
|
for i := range raw {
|
|
ptrs[i] = &raw[i]
|
|
}
|
|
if err := rows.Scan(ptrs...); err != nil {
|
|
return err
|
|
}
|
|
fieldIndex := buildFieldIndex(destStruct.Type())
|
|
for i, c := range cols {
|
|
if idx, ok := fieldIndex[strings.ToLower(c)]; ok {
|
|
field := destStruct.Field(idx)
|
|
if field.CanSet() {
|
|
if err := setReflectValue(field, raw[i]); err != nil {
|
|
return fmt.Errorf("column %s: %w", c, err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// normalizeSQLValue turns a raw []byte column value into a string so map
// results are readable; every other value (including nil) is returned as is.
func normalizeSQLValue(v any) any {
	if b, isBytes := v.([]byte); isBytes {
		return string(b)
	}
	return v
}
|
|
|
|
// buildFieldIndex maps lower-cased column names to field indices of struct
// type t.
//
// The column name is the first comma-separated element of the field's db tag,
// or the field name when the tag is absent or starts with a comma. Unexported
// fields and fields tagged db:"-" are skipped, consistent with collectMeta.
// Only direct fields are indexed (embedded structs are not flattened). If two
// fields map to the same column, the later one wins.
func buildFieldIndex(t reflect.Type) map[string]int {
	index := make(map[string]int, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if !f.IsExported() {
			continue
		}
		tag := f.Tag.Get("db")
		if tag == "-" {
			continue
		}
		col, _, _ := strings.Cut(tag, ",")
		if col == "" {
			col = f.Name
		}
		index[strings.ToLower(col)] = i
	}
	return index
}
|
|
|
|
// setReflectValue stores raw, a column value as produced by database/sql when
// scanning into an *any, in the struct field `field`.
//
// Conversion rules:
//   - nil (SQL NULL) leaves the field untouched.
//   - If the field's address implements sql.Scanner (sql.NullString,
//     sql.NullTime, custom types, ...), raw is handed to its Scan method.
//   - Pointer fields receive a newly allocated value filled by these rules.
//   - string: text is stored as is; any other value via fmt.Sprint.
//   - bool: bool values, non-zero numbers and the texts "1"/"true" (any case)
//     are true; anything else is false.
//   - Integer, unsigned and float fields accept native numbers and their
//     decimal text. Errors are returned for fractional values in integer
//     fields, unparseable text, and values that overflow the field. Negative
//     numbers stored in an unsigned field become 0.
//   - time.Time accepts a time.Time or text in one of sqlTimeLayouts.
//   - Interface fields (e.g. any) receive raw itself, with []byte turned into
//     a string.
//   - Any other destination kind is an error.
func setReflectValue(field reflect.Value, raw any) error {
	if raw == nil {
		// SQL NULL: keep whatever the field already holds.
		return nil
	}

	if field.CanAddr() {
		if scanner, ok := field.Addr().Interface().(sql.Scanner); ok {
			return scanner.Scan(raw)
		}
	}

	switch field.Kind() {
	case reflect.Pointer:
		elem := reflect.New(field.Type().Elem())
		if err := setReflectValue(elem.Elem(), raw); err != nil {
			return err
		}
		field.Set(elem)

	case reflect.Interface:
		if b, ok := raw.([]byte); ok {
			raw = string(b)
		}
		val := reflect.ValueOf(raw)
		if !val.Type().AssignableTo(field.Type()) {
			return fmt.Errorf("cannot convert %T to %s", raw, field.Type())
		}
		field.Set(val)

	case reflect.String:
		if s, ok := sqlText(raw); ok {
			field.SetString(s)
		} else {
			field.SetString(fmt.Sprint(raw))
		}

	case reflect.Bool:
		field.SetBool(sqlToBool(raw))

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		n, err := sqlToInt64(raw)
		if err != nil {
			return err
		}
		if field.OverflowInt(n) {
			return fmt.Errorf("value %d overflows %s", n, field.Type())
		}
		field.SetInt(n)

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		n, err := sqlToUint64(raw)
		if err != nil {
			return err
		}
		if field.OverflowUint(n) {
			return fmt.Errorf("value %d overflows %s", n, field.Type())
		}
		field.SetUint(n)

	case reflect.Float32, reflect.Float64:
		f, err := sqlToFloat64(raw)
		if err != nil {
			return err
		}
		if field.OverflowFloat(f) {
			return fmt.Errorf("value %v overflows %s", f, field.Type())
		}
		field.SetFloat(f)

	case reflect.Struct:
		if field.Type() != reflect.TypeOf(time.Time{}) {
			return fmt.Errorf("unsupported dest field kind: %s", field.Kind())
		}
		t, err := sqlToTime(raw)
		if err != nil {
			return err
		}
		field.Set(reflect.ValueOf(t))

	default:
		return fmt.Errorf("unsupported dest field kind: %s", field.Kind())
	}
	return nil
}

// sqlTimeLayouts are tried in order when a time.Time field receives text. They
// cover RFC 3339 and SQLite's "YYYY-MM-DD HH:MM:SS" datetime text (with
// optional fractional seconds and zone), a 'T'-separated variant without a
// zone, and bare dates. Text without a zone is parsed as UTC.
var sqlTimeLayouts = []string{
	time.RFC3339Nano,
	"2006-01-02 15:04:05.999999999Z07:00",
	"2006-01-02 15:04:05.999999999",
	"2006-01-02T15:04:05.999999999",
	"2006-01-02",
}

// sqlText returns raw as a string when it is textual (string or []byte).
func sqlText(raw any) (string, bool) {
	switch v := raw.(type) {
	case string:
		return v, true
	case []byte:
		return string(v), true
	}
	return "", false
}

// sqlToBool interprets raw as a boolean; unknown values are false.
func sqlToBool(raw any) bool {
	switch v := raw.(type) {
	case bool:
		return v
	case int64:
		return v != 0
	case float64:
		return v != 0
	}
	if s, ok := sqlText(raw); ok {
		return s == "1" || strings.EqualFold(s, "true")
	}
	return false
}

// sqlToInt64 converts raw to an int64, rejecting fractional or out-of-range
// floats and non-numeric text.
func sqlToInt64(raw any) (int64, error) {
	switch v := raw.(type) {
	case int64:
		return v, nil
	case float64:
		const limit = 1 << 63
		if v < -limit || v >= limit || v != float64(int64(v)) {
			return 0, fmt.Errorf("cannot convert float64 %v to int", v)
		}
		return int64(v), nil
	}
	if s, ok := sqlText(raw); ok {
		n, err := strconv.ParseInt(strings.TrimSpace(s), 10, 64)
		if err != nil {
			return 0, fmt.Errorf("cannot convert %q to int: %w", s, err)
		}
		return n, nil
	}
	return 0, fmt.Errorf("cannot convert %T to int", raw)
}

// sqlToUint64 converts raw to a uint64. Negative numbers are clamped to 0;
// fractional or out-of-range floats and non-numeric text are rejected.
func sqlToUint64(raw any) (uint64, error) {
	switch v := raw.(type) {
	case int64:
		if v < 0 {
			return 0, nil
		}
		return uint64(v), nil
	case float64:
		const limit = 1 << 64
		if v < 0 {
			return 0, nil
		}
		if v >= limit || v != float64(uint64(v)) {
			return 0, fmt.Errorf("cannot convert float64 %v to uint", v)
		}
		return uint64(v), nil
	}
	if s, ok := sqlText(raw); ok {
		n, err := strconv.ParseUint(strings.TrimSpace(s), 10, 64)
		if err != nil {
			return 0, fmt.Errorf("cannot convert %q to uint: %w", s, err)
		}
		return n, nil
	}
	return 0, fmt.Errorf("cannot convert %T to uint", raw)
}

// sqlToFloat64 converts raw to a float64, rejecting non-numeric text.
func sqlToFloat64(raw any) (float64, error) {
	switch v := raw.(type) {
	case float64:
		return v, nil
	case int64:
		return float64(v), nil
	}
	if s, ok := sqlText(raw); ok {
		f, err := strconv.ParseFloat(strings.TrimSpace(s), 64)
		if err != nil {
			return 0, fmt.Errorf("cannot convert %q to float: %w", s, err)
		}
		return f, nil
	}
	return 0, fmt.Errorf("cannot convert %T to float", raw)
}

// sqlToTime converts raw to a time.Time using sqlTimeLayouts for text.
func sqlToTime(raw any) (time.Time, error) {
	if t, ok := raw.(time.Time); ok {
		return t, nil
	}
	s, ok := sqlText(raw)
	if !ok {
		return time.Time{}, fmt.Errorf("cannot convert %T to time.Time", raw)
	}
	s = strings.TrimSpace(s)
	for _, layout := range sqlTimeLayouts {
		if t, err := time.Parse(layout, s); err == nil {
			return t, nil
		}
	}
	return time.Time{}, fmt.Errorf("cannot parse %q as time.Time", s)
}
|
|
|
|
// -----------------------
|
|
// Save/Remove (basic PK)
|
|
// -----------------------
|
|
|
|
// fieldMeta describes one persisted field of an entity struct.
type fieldMeta struct {
	index  int    // position in the struct, for reflect.Value.Field
	column string // column name: db tag name, or the Go field name
	isPK   bool   // primary-key field
	auto   bool   // tagged "auto"/"autoincrement": saveEntity writes LastInsertId back after an INSERT
}

// collectMeta lists the persistable fields of struct type t and identifies its
// primary key.
//
// Unexported fields and fields tagged db:"-" are skipped. The column name is
// the first element of the db tag, or the field name. A field is the primary
// key when its tag carries the "pk" option. Only if no field does is a field
// named ID used as the fallback key; otherwise ID would be treated as a second
// key and silently dropped from INSERT/UPDATE column lists. With several "pk"
// fields the last one wins. hasPK reports whether any key was found.
func collectMeta(t reflect.Type) (fields []fieldMeta, pk fieldMeta, hasPK bool) {
	idPos := -1 // index into fields of a field named ID, for the fallback

	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if !f.IsExported() {
			continue
		}
		tag := f.Tag.Get("db")
		if tag == "-" {
			continue
		}

		opts := strings.Split(tag, ",")
		col := opts[0]
		if col == "" {
			col = f.Name
		}
		meta := fieldMeta{index: i, column: col}
		for _, o := range opts[1:] {
			switch strings.ToLower(strings.TrimSpace(o)) {
			case "pk":
				meta.isPK = true
			case "auto", "autoincrement":
				meta.auto = true
			}
		}

		if f.Name == "ID" {
			idPos = len(fields)
		}
		fields = append(fields, meta)
		if meta.isPK {
			pk = meta
			hasPK = true
		}
	}

	if !hasPK && idPos >= 0 {
		fields[idPos].isPK = true
		pk = fields[idPos]
		hasPK = true
	}
	return fields, pk, hasPK
}
|
|
|
|
func getTableNameFromEntity(v reflect.Value) (string, bool) {
|
|
// If entity implements TableNamer
|
|
if v.CanInterface() {
|
|
if tn, ok := v.Interface().(TableNamer); ok {
|
|
return tn.TableName(), true
|
|
}
|
|
}
|
|
// Fallback: very naive pluralization (append 's')
|
|
typ := v.Type()
|
|
if typ.Kind() == reflect.Pointer {
|
|
typ = typ.Elem()
|
|
}
|
|
if typ.Kind() == reflect.Struct {
|
|
return strings.ToLower(typ.Name()) + "s", true
|
|
}
|
|
return "", false
|
|
}
|
|
|
|
func saveEntity(ctx context.Context, exec executor, entity any) error {
|
|
rv := reflect.ValueOf(entity)
|
|
if rv.Kind() != reflect.Pointer || rv.IsNil() {
|
|
return errors.New("entity must be a non-nil pointer to struct")
|
|
}
|
|
ev := rv.Elem()
|
|
if ev.Kind() != reflect.Struct {
|
|
return errors.New("entity must point to a struct")
|
|
}
|
|
|
|
fields, pkMeta, hasPK := collectMeta(ev.Type())
|
|
if !hasPK {
|
|
return errors.New("no primary key field found (tag db:\"...,pk\" or field named ID)")
|
|
}
|
|
table, ok := getTableNameFromEntity(ev)
|
|
if !ok || table == "" {
|
|
return errors.New("unable to resolve table name; implement TableNamer or set up a repository with explicit table")
|
|
}
|
|
|
|
// Build lists
|
|
cols := make([]string, 0, len(fields))
|
|
vals := make([]any, 0, len(fields))
|
|
setParts := make([]string, 0, len(fields))
|
|
|
|
var pkVal any
|
|
var pkIsZero bool
|
|
|
|
for _, fm := range fields {
|
|
f := ev.Field(fm.index)
|
|
if fm.isPK {
|
|
pkVal = f.Interface()
|
|
pkIsZero = isZeroValue(f)
|
|
continue
|
|
}
|
|
cols = append(cols, fm.column)
|
|
vals = append(vals, f.Interface())
|
|
setParts = append(setParts, fmt.Sprintf("%s = ?", fm.column))
|
|
}
|
|
|
|
if pkIsZero {
|
|
// INSERT
|
|
placeholders := strings.Repeat("?,", len(cols))
|
|
if len(placeholders) > 0 {
|
|
placeholders = placeholders[:len(placeholders)-1]
|
|
}
|
|
sqlStr := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, strings.Join(cols, ", "), placeholders)
|
|
res, err := exec.ExecContext(ctx, sqlStr, vals...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Set auto ID if needed
|
|
if pkMeta.auto {
|
|
if id, err := res.LastInsertId(); err == nil {
|
|
ev.Field(pkMeta.index).SetInt(id)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UPDATE ... WHERE pk = ?
|
|
sqlStr := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?", table, strings.Join(setParts, ", "), pkMeta.column)
|
|
valsWithPK := append(vals, pkVal)
|
|
_, err := exec.ExecContext(ctx, sqlStr, valsWithPK...)
|
|
return err
|
|
}
|
|
|
|
func removeEntity(ctx context.Context, exec executor, entity any) error {
|
|
rv := reflect.ValueOf(entity)
|
|
if rv.Kind() != reflect.Pointer || rv.IsNil() {
|
|
return errors.New("entity must be a non-nil pointer to struct")
|
|
}
|
|
ev := rv.Elem()
|
|
if ev.Kind() != reflect.Struct {
|
|
return errors.New("entity must point to a struct")
|
|
}
|
|
_, pkMeta, hasPK := collectMeta(ev.Type())
|
|
if !hasPK {
|
|
return errors.New("no primary key field found")
|
|
}
|
|
table, ok := getTableNameFromEntity(ev)
|
|
if !ok || table == "" {
|
|
return errors.New("unable to resolve table name")
|
|
}
|
|
pkVal := ev.Field(pkMeta.index).Interface()
|
|
sqlStr := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", table, pkMeta.column)
|
|
_, err := exec.ExecContext(ctx, sqlStr, pkVal)
|
|
return err
|
|
}
|
|
|
|
// isZeroValue reports whether v holds an "unset" primary-key value:
//   - an empty string, slice or map (including non-nil empty ones);
//   - a zero integer or unsigned integer;
//   - false;
//   - a nil pointer or interface;
//   - a zero time.Time (per Time.IsZero);
//   - a struct deeply equal to its zero value.
//
// Every other kind (floats, arrays, channels, ...) is never considered zero.
func isZeroValue(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.String, reflect.Slice, reflect.Map:
		return v.Len() == 0
	case reflect.Bool:
		return !v.Bool()
	case reflect.Pointer, reflect.Interface:
		return v.IsNil()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v.Int() == 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return v.Uint() == 0
	case reflect.Struct:
		if t, isTime := v.Interface().(time.Time); isTime {
			return t.IsZero()
		}
		return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface())
	}
	return false
}
|