Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ name: test

on:
push:
paths-ignore: [benchmarks/**, docs/**]
paths-ignore: [benchmarks/**, docs/**, '**.md']
branches: [main]
pull_request:
paths-ignore: [benchmarks/**, docs/**]
paths-ignore: [benchmarks/**, docs/**, '**.md']
branches: [main]

jobs:
Expand Down
17 changes: 12 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,20 @@
[![Go Reference](https://pkg.go.dev/badge/github.com/rfberaldo/sqlz.svg)](https://pkg.go.dev/github.com/rfberaldo/sqlz)
[![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go)

**sqlz** is a lightweight, dependency-free Go library that extends the standard [database/sql](https://pkg.go.dev/database/sql) package with named queries, scanning, and batch operations, while having a simple API.
**sqlz** is a lightweight, dependency-free Go library that extends the standard [database/sql](https://pkg.go.dev/database/sql) package, adding support for named queries, struct scanning, and batch operations, while having a clean, minimal API.

It's designed to feel familiar to anyone using [database/sql](https://pkg.go.dev/database/sql), while removing repetitive boilerplate code. It can scan directly into structs, maps, or slices, and run named queries with full UTF-8/multilingual support.

> Documentation: https://rfberaldo.github.io/sqlz/.

## Features

- Named queries for structs and maps.
- Automatic scanning into primitives, structs, maps and slices.
- Automatic expanding "IN" clauses.
- Automatic expanding batch inserts.
- Automatic prepared statement caching.

## Getting started

### Install
Expand Down Expand Up @@ -80,7 +90,4 @@ db.Exec(ctx, "INSERT INTO user (name, email) VALUES (:name, :email)", users)

- It was designed with a simpler API for everyday use, with fewer concepts and less verbose.
- It has full support for UTF-8/multilingual named queries.

### Performance

Take a look at [benchmarks](benchmarks) for more info.
- It's more performant in most cases, take a look at the [benchmarks](benchmarks) for comparison.
103 changes: 86 additions & 17 deletions base.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,31 @@ import (

"github.com/rfberaldo/sqlz/internal/parser"
"github.com/rfberaldo/sqlz/internal/reflectutil"
"github.com/rfberaldo/sqlz/internal/stmtcache"
)

// querier is satisfied by [sql.DB], [sql.Tx] or [sql.Conn].
type querier interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}

// base contains main methods that are shared between [DB] and [Tx].
type base struct {
*config
stmtCache *stmtcache.StmtCache
}

func newBase(cfg *config) *base {
cfg = applyDefaults(cfg)
base := &base{config: cfg}

if cfg.stmtCacheCapacity > 0 {
base.stmtCache = stmtcache.New(cfg.stmtCacheCapacity)
}

return base
}

func (c *base) resolveQuery(query string, args []any) (string, []any, error) {
Expand All @@ -28,24 +42,24 @@ func (c *base) resolveQuery(query string, args []any) (string, []any, error) {
}

if len(args) == 0 {
return query, args, nil
return query, nil, nil
}

switch reflectutil.TypeOfAny(args[0]) {
case reflectutil.Struct, reflectutil.Map,
reflectutil.SliceStruct, reflectutil.SliceMap:
if len(args) > 1 {
return "", nil, fmt.Errorf("sqlz: too many arguments: want 1 got %d", len(args))
}
return processNamed(query, args[0], c.config)
argType := reflectutil.TypeOfAny(args[0])

case reflectutil.Invalid:
if argType == reflectutil.Invalid {
panic(fmt.Sprintf("sqlz: unsupported argument type: %T", args[0]))
}

default:
// must be a native query, just parse for possible "IN" clauses
return parser.ParseInClause(c.bind, query, args)
if argType.IsNamed() {
if len(args) > 1 {
return "", nil, fmt.Errorf("sqlz: too many arguments for named query, want 1 got %d", len(args))
}
return processNamed(query, args[0], c.config)
}

// must be a native query, just parse for possible "IN" clauses
return parser.ParseInClause(c.bind, query, args)
}

func (c *base) query(ctx context.Context, db querier, query string, args ...any) *Scanner {
Expand All @@ -54,11 +68,22 @@ func (c *base) query(ctx context.Context, db querier, query string, args ...any)
return &Scanner{err: err}
}

rows, err := db.QueryContext(ctx, query, args...)
if c.stmtCache == nil || len(args) == 0 {
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return &Scanner{err: err}
}
return newScanner(rows, c.config)
}

stmt, err := c.loadOrPrepare(ctx, db, query)
if err != nil {
return &Scanner{err: err}
}
rows, err := stmt.QueryContext(ctx, args...)
if err != nil {
return &Scanner{err: err}
}

return newScanner(rows, c.config)
}

Expand All @@ -68,11 +93,22 @@ func (c *base) queryRow(ctx context.Context, db querier, query string, args ...a
return &Scanner{err: err}
}

rows, err := db.QueryContext(ctx, query, args...)
if c.stmtCache == nil || len(args) == 0 {
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return &Scanner{err: err}
}
return newRowScanner(rows, c.config)
}

stmt, err := c.loadOrPrepare(ctx, db, query)
if err != nil {
return &Scanner{err: err}
}
rows, err := stmt.QueryContext(ctx, args...)
if err != nil {
return &Scanner{err: err}
}

return newRowScanner(rows, c.config)
}

Expand All @@ -82,5 +118,38 @@ func (c *base) exec(ctx context.Context, db querier, query string, args ...any)
return nil, err
}

return db.ExecContext(ctx, query, args...)
if c.stmtCache == nil || len(args) == 0 {
return db.ExecContext(ctx, query, args...)
}

stmt, err := c.loadOrPrepare(ctx, db, query)
if err != nil {
return nil, err
}
return stmt.ExecContext(ctx, args...)
}

func (c *base) loadOrPrepare(ctx context.Context, db querier, query string) (*sql.Stmt, error) {
if c.stmtCache == nil {
panic("sqlz: stmt cache is not enabled")
}

stmt, ok := c.stmtCache.Get(query)
if !ok {
var err error
stmt, err = db.PrepareContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("sqlz: preparing stmt: %w", err)
}
c.stmtCache.Put(query, stmt)
}

return stmt.(*sql.Stmt), nil
}

func (c *base) clearStmtCache() {
if c.stmtCache == nil {
return
}
c.stmtCache.Clear()
}
62 changes: 55 additions & 7 deletions base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,12 @@ import (

var ctx = context.Background()

func newBase(cfg *config) *base {
cfg.defaults()
return &base{cfg}
}

func TestBase_basic(t *testing.T) {
runConn(t, func(t *testing.T, conn *Conn) {
base := newBase(&config{bind: conn.bind})
query := "SELECT 'Hello World'"

t.Run("select", func(t *testing.T) {
t.Run("query", func(t *testing.T) {
var got []string
err := base.query(ctx, conn.db, query).Scan(&got)
require.NoError(t, err)
Expand All @@ -36,7 +31,7 @@ func TestBase_basic(t *testing.T) {
assert.Equal(t, expect, got)
})

t.Run("get", func(t *testing.T) {
t.Run("queryRow", func(t *testing.T) {
var got string
err := base.queryRow(ctx, conn.db, query).Scan(&got)
require.NoError(t, err)
Expand All @@ -45,6 +40,18 @@ func TestBase_basic(t *testing.T) {
assert.Equal(t, expect, got)
})

t.Run("exec", func(t *testing.T) {
_, err := base.exec(ctx, conn.db, query)
require.NoError(t, err)
})
})
}

func TestBase_basic_no_stmt_cache(t *testing.T) {
runConn(t, func(t *testing.T, conn *Conn) {
base := newBase(&config{bind: conn.bind, stmtCacheCapacity: 0})
query := "SELECT 'Hello World'"

t.Run("query", func(t *testing.T) {
var got []string
err := base.query(ctx, conn.db, query).Scan(&got)
Expand All @@ -62,6 +69,11 @@ func TestBase_basic(t *testing.T) {
expect := "Hello World"
assert.Equal(t, expect, got)
})

t.Run("exec", func(t *testing.T) {
_, err := base.exec(ctx, conn.db, query)
require.NoError(t, err)
})
})
}

Expand Down Expand Up @@ -596,3 +608,39 @@ func TestBase_valuerInterface(t *testing.T) {
assert.ErrorContains(t, err, "not a valid email")
})
}

// BenchmarkBatchInsertStruct-12 210 5568681 ns/op 389638 B/op 3042 allocs/op
func BenchmarkBatchInsertStruct(b *testing.B) {
conn := mysqlConn
base := newBase(&config{bind: conn.bind})
th := newTableHelper(b, conn.db, conn.bind)

_, err := conn.db.Exec(th.fmt(`
CREATE TABLE IF NOT EXISTS %s (
id INT PRIMARY KEY AUTO_INCREMENT,
username VARCHAR(255) NOT NULL,
email VARCHAR(255),
password VARCHAR(255),
age INT
)`,
))
require.NoError(b, err)

type user struct {
Username string
Email string
Password string
Age int
}
var args []user
for range 1000 {
args = append(args, user{"john", "john@id.com", "doom", 42})
}
input := th.fmt(`INSERT INTO %s (username, email, password, age)
VALUES (:username, :email, :password, :age)`)

for b.Loop() {
_, err := base.exec(ctx, conn.db, input, args)
require.NoError(b, err)
}
}
34 changes: 30 additions & 4 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,46 @@ import (
"github.com/rfberaldo/sqlz/internal/parser"
)

const (
defaultStructTag = "db"
defaultBind = parser.BindQuestion
defaultStmtCacheCapacity = 16
)

var (
defaultFieldNameTransformer = ToSnakeCase
)

// config contains flags that are used across internal objects.
type config struct {
defaultsApplied bool
bind parser.Bind
structTag string
fieldNameTransformer func(string) string
ignoreMissingFields bool
stmtCacheCapacity int
}

// defaults sets config defaults if not set.
func (cfg *config) defaults() {
cfg.bind = cmp.Or(cfg.bind, parser.BindQuestion)
// applyDefaults returns a cfg with defaults applied, if not set.
func applyDefaults(cfg *config) *config {
if cfg == nil {
cfg = &config{}
}

// make it easy to create custom configs during tests and avoid data racing
if cfg.defaultsApplied {
return cfg
}

cfg.defaultsApplied = true

cfg.bind = cmp.Or(cfg.bind, defaultBind)
cfg.structTag = cmp.Or(cfg.structTag, defaultStructTag)
cfg.stmtCacheCapacity = cmp.Or(cfg.stmtCacheCapacity, defaultStmtCacheCapacity)

if cfg.fieldNameTransformer == nil {
cfg.fieldNameTransformer = ToSnakeCase
cfg.fieldNameTransformer = defaultFieldNameTransformer
}

return cfg
}
1 change: 1 addition & 0 deletions docs/.vitepress/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ export default defineConfig({
{ text: 'Querying', link: '/querying' },
{ text: 'Scanning', link: '/scanning' },
{ text: 'Transactions', link: '/transactions' },
{ text: 'Prepared statement caching', link: '/prepared-stmt-caching' },
{ text: 'Custom options', link: '/custom-options' },
{ text: 'Connection pool', link: '/connection-pool' },
],
Expand Down
14 changes: 10 additions & 4 deletions docs/connection-pool.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,18 @@
Query execution requires a connection, and [sql.DB](https://pkg.go.dev/database/sql#DB) is a pool of connections: whenever you make a query, it grabs a connection, executes it, and returns it to the pool.
There are two ways to control the size of the connection pool:

```go
db, err := sqlz.Connect("sqlite3", ":memory:")
db.Pool().SetMaxOpenConns(n)
db.Pool().SetMaxIdleConns(n)
::: code-group
```go [sqlz.DB]
DB.Pool().SetMaxOpenConns(n int)
DB.Pool().SetMaxIdleConns(n int)
```

```go [sql.DB]
DB.SetMaxOpenConns(n int)
DB.SetMaxIdleConns(n int)
```
:::

By default, the pool creates a new connection whenever needed if all existing connections are in use.
[sql.DB.SetMaxOpenConns](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns) imposes a limit on the number of open connections. Past this limit, new database operations will wait for an existing operation to finish.

Expand Down
Loading