diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index e859409..cbb5ad3 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -2,10 +2,10 @@ name: test on: push: - paths-ignore: [benchmarks/**, docs/**] + paths-ignore: [benchmarks/**, docs/**, '**.md'] branches: [main] pull_request: - paths-ignore: [benchmarks/**, docs/**] + paths-ignore: [benchmarks/**, docs/**, '**.md'] branches: [main] jobs: diff --git a/README.md b/README.md index 459fa70..e144b27 100644 --- a/README.md +++ b/README.md @@ -6,10 +6,20 @@ [![Go Reference](https://pkg.go.dev/badge/github.com/rfberaldo/sqlz.svg)](https://pkg.go.dev/github.com/rfberaldo/sqlz) [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) -**sqlz** is a lightweight, dependency-free Go library that extends the standard [database/sql](https://pkg.go.dev/database/sql) package with named queries, scanning, and batch operations, while having a simple API. +**sqlz** is a lightweight, dependency-free Go library that extends the standard [database/sql](https://pkg.go.dev/database/sql) package, adding support for named queries, struct scanning, and batch operations, while having a clean, minimal API. + +It's designed to feel familiar to anyone using [database/sql](https://pkg.go.dev/database/sql), while removing repetitive boilerplate code. It can scan directly into structs, maps, or slices, and run named queries with full UTF-8/multilingual support. > Documentation: https://rfberaldo.github.io/sqlz/. +## Features + +- Named queries for structs and maps. +- Automatic scanning into primitives, structs, maps and slices. +- Automatic expanding "IN" clauses. +- Automatic expanding batch inserts. +- Automatic prepared statement caching. 
+ ## Getting started ### Install @@ -80,7 +90,4 @@ db.Exec(ctx, "INSERT INTO user (name, email) VALUES (:name, :email)", users) - It was designed with a simpler API for everyday use, with fewer concepts and less verbose. - It has full support for UTF-8/multilingual named queries. - -### Performance - -Take a look at [benchmarks](benchmarks) for more info. +- It's more performant in most cases, take a look at the [benchmarks](benchmarks) for comparison. diff --git a/base.go b/base.go index f6630a3..2021790 100644 --- a/base.go +++ b/base.go @@ -8,17 +8,31 @@ import ( "github.com/rfberaldo/sqlz/internal/parser" "github.com/rfberaldo/sqlz/internal/reflectutil" + "github.com/rfberaldo/sqlz/internal/stmtcache" ) // querier is satisfied by [sql.DB], [sql.Tx] or [sql.Conn]. type querier interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) } // base contains main methods that are shared between [DB] and [Tx]. 
type base struct { *config + stmtCache *stmtcache.StmtCache +} + +func newBase(cfg *config) *base { + cfg = applyDefaults(cfg) + base := &base{config: cfg} + + if cfg.stmtCacheCapacity > 0 { + base.stmtCache = stmtcache.New(cfg.stmtCacheCapacity) + } + + return base } func (c *base) resolveQuery(query string, args []any) (string, []any, error) { @@ -28,24 +42,24 @@ func (c *base) resolveQuery(query string, args []any) (string, []any, error) { } if len(args) == 0 { - return query, args, nil + return query, nil, nil } - switch reflectutil.TypeOfAny(args[0]) { - case reflectutil.Struct, reflectutil.Map, - reflectutil.SliceStruct, reflectutil.SliceMap: - if len(args) > 1 { - return "", nil, fmt.Errorf("sqlz: too many arguments: want 1 got %d", len(args)) - } - return processNamed(query, args[0], c.config) + argType := reflectutil.TypeOfAny(args[0]) - case reflectutil.Invalid: + if argType == reflectutil.Invalid { panic(fmt.Sprintf("sqlz: unsupported argument type: %T", args[0])) + } - default: - // must be a native query, just parse for possible "IN" clauses - return parser.ParseInClause(c.bind, query, args) + if argType.IsNamed() { + if len(args) > 1 { + return "", nil, fmt.Errorf("sqlz: too many arguments for named query, want 1 got %d", len(args)) + } + return processNamed(query, args[0], c.config) } + + // must be a native query, just parse for possible "IN" clauses + return parser.ParseInClause(c.bind, query, args) } func (c *base) query(ctx context.Context, db querier, query string, args ...any) *Scanner { @@ -54,11 +68,22 @@ func (c *base) query(ctx context.Context, db querier, query string, args ...any) return &Scanner{err: err} } - rows, err := db.QueryContext(ctx, query, args...) + if c.stmtCache == nil || len(args) == 0 { + rows, err := db.QueryContext(ctx, query, args...) 
+ if err != nil { + return &Scanner{err: err} + } + return newScanner(rows, c.config) + } + + stmt, err := c.loadOrPrepare(ctx, db, query) + if err != nil { + return &Scanner{err: err} + } + rows, err := stmt.QueryContext(ctx, args...) if err != nil { return &Scanner{err: err} } - return newScanner(rows, c.config) } @@ -68,11 +93,22 @@ func (c *base) queryRow(ctx context.Context, db querier, query string, args ...a return &Scanner{err: err} } - rows, err := db.QueryContext(ctx, query, args...) + if c.stmtCache == nil || len(args) == 0 { + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + return &Scanner{err: err} + } + return newRowScanner(rows, c.config) + } + + stmt, err := c.loadOrPrepare(ctx, db, query) + if err != nil { + return &Scanner{err: err} + } + rows, err := stmt.QueryContext(ctx, args...) if err != nil { return &Scanner{err: err} } - return newRowScanner(rows, c.config) } @@ -82,5 +118,38 @@ func (c *base) exec(ctx context.Context, db querier, query string, args ...any) return nil, err } - return db.ExecContext(ctx, query, args...) + if c.stmtCache == nil || len(args) == 0 { + return db.ExecContext(ctx, query, args...) + } + + stmt, err := c.loadOrPrepare(ctx, db, query) + if err != nil { + return nil, err + } + return stmt.ExecContext(ctx, args...) 
+} + +func (c *base) loadOrPrepare(ctx context.Context, db querier, query string) (*sql.Stmt, error) { + if c.stmtCache == nil { + panic("sqlz: stmt cache is not enabled") + } + + stmt, ok := c.stmtCache.Get(query) + if !ok { + var err error + stmt, err = db.PrepareContext(ctx, query) + if err != nil { + return nil, fmt.Errorf("sqlz: preparing stmt: %w", err) + } + c.stmtCache.Put(query, stmt) + } + + return stmt.(*sql.Stmt), nil +} + +func (c *base) clearStmtCache() { + if c.stmtCache == nil { + return + } + c.stmtCache.Clear() } diff --git a/base_test.go b/base_test.go index f83fd7b..147f9c3 100644 --- a/base_test.go +++ b/base_test.go @@ -17,17 +17,12 @@ import ( var ctx = context.Background() -func newBase(cfg *config) *base { - cfg.defaults() - return &base{cfg} -} - func TestBase_basic(t *testing.T) { runConn(t, func(t *testing.T, conn *Conn) { base := newBase(&config{bind: conn.bind}) query := "SELECT 'Hello World'" - t.Run("select", func(t *testing.T) { + t.Run("query", func(t *testing.T) { var got []string err := base.query(ctx, conn.db, query).Scan(&got) require.NoError(t, err) @@ -36,7 +31,7 @@ func TestBase_basic(t *testing.T) { assert.Equal(t, expect, got) }) - t.Run("get", func(t *testing.T) { + t.Run("queryRow", func(t *testing.T) { var got string err := base.queryRow(ctx, conn.db, query).Scan(&got) require.NoError(t, err) @@ -45,6 +40,18 @@ func TestBase_basic(t *testing.T) { assert.Equal(t, expect, got) }) + t.Run("exec", func(t *testing.T) { + _, err := base.exec(ctx, conn.db, query) + require.NoError(t, err) + }) + }) +} + +func TestBase_basic_no_stmt_cache(t *testing.T) { + runConn(t, func(t *testing.T, conn *Conn) { + base := newBase(&config{bind: conn.bind, stmtCacheCapacity: 0}) + query := "SELECT 'Hello World'" + t.Run("query", func(t *testing.T) { var got []string err := base.query(ctx, conn.db, query).Scan(&got) @@ -62,6 +69,11 @@ func TestBase_basic(t *testing.T) { expect := "Hello World" assert.Equal(t, expect, got) }) + + t.Run("exec", 
func(t *testing.T) { + _, err := base.exec(ctx, conn.db, query) + require.NoError(t, err) + }) }) } @@ -596,3 +608,39 @@ func TestBase_valuerInterface(t *testing.T) { assert.ErrorContains(t, err, "not a valid email") }) } + +// BenchmarkBatchInsertStruct-12 210 5568681 ns/op 389638 B/op 3042 allocs/op +func BenchmarkBatchInsertStruct(b *testing.B) { + conn := mysqlConn + base := newBase(&config{bind: conn.bind}) + th := newTableHelper(b, conn.db, conn.bind) + + _, err := conn.db.Exec(th.fmt(` + CREATE TABLE IF NOT EXISTS %s ( + id INT PRIMARY KEY AUTO_INCREMENT, + username VARCHAR(255) NOT NULL, + email VARCHAR(255), + password VARCHAR(255), + age INT + )`, + )) + require.NoError(b, err) + + type user struct { + Username string + Email string + Password string + Age int + } + var args []user + for range 1000 { + args = append(args, user{"john", "john@id.com", "doom", 42}) + } + input := th.fmt(`INSERT INTO %s (username, email, password, age) + VALUES (:username, :email, :password, :age)`) + + for b.Loop() { + _, err := base.exec(ctx, conn.db, input, args) + require.NoError(b, err) + } +} diff --git a/config.go b/config.go index 9ce1c61..9b16b24 100644 --- a/config.go +++ b/config.go @@ -6,20 +6,46 @@ import ( "github.com/rfberaldo/sqlz/internal/parser" ) +const ( + defaultStructTag = "db" + defaultBind = parser.BindQuestion + defaultStmtCacheCapacity = 16 +) + +var ( + defaultFieldNameTransformer = ToSnakeCase +) + // config contains flags that are used across internal objects. type config struct { + defaultsApplied bool bind parser.Bind structTag string fieldNameTransformer func(string) string ignoreMissingFields bool + stmtCacheCapacity int } -// defaults sets config defaults if not set. -func (cfg *config) defaults() { - cfg.bind = cmp.Or(cfg.bind, parser.BindQuestion) +// applyDefaults returns a cfg with defaults applied, if not set. 
+func applyDefaults(cfg *config) *config { + if cfg == nil { + cfg = &config{} + } + + // make it easy to create custom configs during tests and avoid data racing + if cfg.defaultsApplied { + return cfg + } + + cfg.defaultsApplied = true + + cfg.bind = cmp.Or(cfg.bind, defaultBind) cfg.structTag = cmp.Or(cfg.structTag, defaultStructTag) + cfg.stmtCacheCapacity = cmp.Or(cfg.stmtCacheCapacity, defaultStmtCacheCapacity) if cfg.fieldNameTransformer == nil { - cfg.fieldNameTransformer = ToSnakeCase + cfg.fieldNameTransformer = defaultFieldNameTransformer } + + return cfg } diff --git a/docs/.vitepress/config.ts b/docs/.vitepress/config.ts index e184e29..e8176d7 100644 --- a/docs/.vitepress/config.ts +++ b/docs/.vitepress/config.ts @@ -29,6 +29,7 @@ export default defineConfig({ { text: 'Querying', link: '/querying' }, { text: 'Scanning', link: '/scanning' }, { text: 'Transactions', link: '/transactions' }, + { text: 'Prepared statement caching', link: '/prepared-stmt-caching' }, { text: 'Custom options', link: '/custom-options' }, { text: 'Connection pool', link: '/connection-pool' }, ], diff --git a/docs/connection-pool.md b/docs/connection-pool.md index 648609b..2d17381 100644 --- a/docs/connection-pool.md +++ b/docs/connection-pool.md @@ -3,12 +3,18 @@ Query execution requires a connection, and [sql.DB](https://pkg.go.dev/database/sql#DB) is a pool of connections: whenever you make a query, it grabs a connection, executes it, and returns it to the pool. There are two ways to control the size of the connection pool: -```go -db, err := sqlz.Connect("sqlite3", ":memory:") -db.Pool().SetMaxOpenConns(n) -db.Pool().SetMaxIdleConns(n) +::: code-group +```go [sqlz.DB] +DB.Pool().SetMaxOpenConns(n int) +DB.Pool().SetMaxIdleConns(n int) ``` +```go [sql.DB] +DB.SetMaxOpenConns(n int) +DB.SetMaxIdleConns(n int) +``` +::: + By default, the pool creates a new connection whenever needed if all existing connections are in use. 
[sql.DB.SetMaxOpenConns](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns) imposes a limit on the number of open connections. Past this limit, new database operations will wait for an existing operation to finish. diff --git a/docs/custom-options.md b/docs/custom-options.md index 6228269..844e1a4 100644 --- a/docs/custom-options.md +++ b/docs/custom-options.md @@ -1,7 +1,8 @@ # Custom options To set custom options, use the [Options](https://pkg.go.dev/github.com/rfberaldo/sqlz#Options) object with the `New()` constructor. -Any option can be left blank for defaults: + +Shown values are defaults: ```go pool, err := sql.Open("sqlite3", ":memory:") @@ -16,5 +17,10 @@ db := sqlz.New("sqlite3", pool, &sqlz.Options{ // IgnoreMissingFields causes the scanner to ignore missing struct fields // rather than returning an error. IgnoreMissingFields: false, + + // StatementCacheCapacity sets the maximum number of cached statements, + // if it's zero, prepared statement caching is completely disabled. + // Note that each statement may be prepared on each connection in the pool. + StatementCacheCapacity: 16, }) ``` diff --git a/docs/index.md b/docs/index.md index 168e423..3fdb5d2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,7 +4,7 @@ title: Guide Introduction # Introduction -**sqlz** is a lightweight Go library that builds on top of the standard [database/sql](https://pkg.go.dev/database/sql) package, adding first-class support for named queries, struct scanning, and batch operations, while having a clean, minimal API and zero external dependencies. +**sqlz** is a lightweight, dependency-free Go library that extends the standard [database/sql](https://pkg.go.dev/database/sql) package, adding support for named queries, struct scanning, and batch operations, while having a clean, minimal API. It's designed to feel familiar to anyone using [database/sql](https://pkg.go.dev/database/sql), while removing repetitive boilerplate code. 
It can scan directly into structs, maps, or slices, and run named queries with full UTF-8/multilingual support. @@ -17,10 +17,10 @@ It also doesn't know anything about relationships between objects. ## Features - Named queries for structs and maps. -- Auto-scanning into primitives, structs, maps and slices. -- Auto-expanding "IN" clauses. -- Auto-expanding batch inserts. -- Performant. +- Automatic scanning into primitives, structs, maps and slices. +- Automatic expanding "IN" clauses. +- Automatic expanding batch inserts. +- Automatic prepared statement caching. ## About this documentation @@ -30,4 +30,4 @@ It also doesn't know anything about relationships between objects. ## Similar projects -**sqlz** was inspired by [sqlx](https://github.com/jmoiron/sqlx/) and [scanny](https://github.com/georgysavva/scany/). +**sqlz** was inspired by [sqlx](https://github.com/jmoiron/sqlx/) and [scany](https://github.com/georgysavva/scany/). diff --git a/docs/prepared-stmt-caching.md b/docs/prepared-stmt-caching.md new file mode 100644 index 0000000..dd2394c --- /dev/null +++ b/docs/prepared-stmt-caching.md @@ -0,0 +1,26 @@ +# Prepared statement caching + +By default, **sqlz** automatically caches prepared statements. +Under the hood, it uses an [LRU caching policy](https://en.wikipedia.org/wiki/Cache_replacement_policies#LRU), meaning it will always keep the most recently used queries prepared. +Default capacity is 16, but it can be [customized](/custom-options). + +Setting `StatementCacheCapacity: 0` completely disables this feature. + +Finding the sweet spot for the caching capacity will depend on your application. +When increasing the capacity, database memory usage will also increase, while CPU usage will decrease. + +Some databases limit the number of prepared statements; [MySQL](https://dev.mysql.com/doc/refman/8.4/en/server-system-variables.html#sysvar_max_prepared_stmt_count) for instance has a default limit of 16382, while PostgreSQL has no fixed limit. 
+ > [!IMPORTANT] > Note that internally, each prepared statement is bound to a connection, but [database/sql](https://pkg.go.dev/database/sql) will prepare it on other connections automatically when needed. > This means that, effectively, the cache capacity is per active connection, which is why the default capacity is conservative. + +Limiting the [connection pool](/connection-pool) may have a large impact: statements will eventually be prepared across all active connections, making memory usage predictable. + +For example, given a maximum of 16 connections and 16 cache capacity, the **maximum number** of cached statements would be 256. + +Transactions have their own cache, which is cleared on `Commit()` or `Rollback()`. + +> [!WARNING] > Note that while this feature is active, database schema changes also require the cache to be reset. > You can just restart the application, or call `DB.ClearStmtCache()` to clear the cache. diff --git a/internal/reflectutil/reflectutil.go b/internal/reflectutil/reflectutil.go index 22d623b..c69d06c 100644 --- a/internal/reflectutil/reflectutil.go +++ b/internal/reflectutil/reflectutil.go @@ -7,7 +7,7 @@ import ( // Type is similar to [reflect.Kind], but adds support for type of slices. // [reflect.Func], [reflect.Chan], [reflect.Array] and [reflect.UnsafePointer] are considered Invalid. // Nil is considered Primitive. -type Type uint8 +type Type uint const ( Invalid Type = 0 @@ -21,12 +21,16 @@ const ( ) func (t Type) IsSlice() bool { - return t&Slice != 0 + return (t & Slice) != 0 } -// IsPrimitive reports whether t is [Primitive] or [SlicePrimitive]. func (t Type) IsPrimitive() bool { - return t&Primitive != 0 + return (t & Primitive) != 0 +} + +// IsNamed reports whether t contains [Struct] or [Map]. +func (t Type) IsNamed() bool { + return (t & (Struct | Map)) != 0 } // TypeOfAny recursively returns the Type of arg, nil is considered Primitive. 
diff --git a/internal/stmtcache/lrucache.go b/internal/stmtcache/lrucache.go new file mode 100644 index 0000000..e295cb9 --- /dev/null +++ b/internal/stmtcache/lrucache.go @@ -0,0 +1,66 @@ +package stmtcache + +import ( + "container/list" + "sync" +) + +type lruCache[K comparable, V any] struct { + cap int + mutex sync.Mutex + m map[K]*list.Element + l *list.List + onEvict func(K, V) +} + +func newLRUCache[K comparable, V any](cap int, onEvict func(K, V)) *lruCache[K, V] { + return &lruCache[K, V]{ + cap, sync.Mutex{}, make(map[K]*list.Element), list.New(), onEvict, + } +} + +type entry[K comparable, V any] struct { + key K + val V +} + +func (c *lruCache[K, V]) get(key K) (val V, ok bool) { + defer c.mutex.Unlock() + c.mutex.Lock() + + if el, ok := c.m[key]; ok { + c.l.MoveToFront(el) + return el.Value.(entry[K, V]).val, true + } + + return val, false +} + +func (c *lruCache[K, V]) put(key K, val V) (evicted bool) { + defer c.mutex.Unlock() + c.mutex.Lock() + + if el, ok := c.m[key]; ok { + el.Value = entry[K, V]{key, val} + c.l.MoveToFront(el) + return + } + + if c.l.Len() >= c.cap { + evicted = true + c.evict() + } + + el := c.l.PushFront(entry[K, V]{key, val}) + c.m[key] = el + + return evicted +} + +func (c *lruCache[K, V]) evict() { + el := c.l.Remove(c.l.Back()).(entry[K, V]) + delete(c.m, el.key) + if c.onEvict != nil { + c.onEvict(el.key, el.val) + } +} diff --git a/internal/stmtcache/lrucache_test.go b/internal/stmtcache/lrucache_test.go new file mode 100644 index 0000000..e289600 --- /dev/null +++ b/internal/stmtcache/lrucache_test.go @@ -0,0 +1,86 @@ +package stmtcache + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLRUCache(t *testing.T) { + const cap = 2 + c := newLRUCache[string, string](cap, nil) + + t.Run("put and get value", func(t *testing.T) { + evicted := c.put("foo", "fooval") + assert.False(t, evicted) + v, ok := c.get("foo") + require.True(t, ok) + 
assert.Equal(t, "fooval", v) + }) + + t.Run("updating existing key moves it to front", func(t *testing.T) { + evicted := c.put("foo", "fooval2") + assert.False(t, evicted) + v, ok := c.get("foo") + require.True(t, ok) + assert.Equal(t, "fooval2", v) + assert.Equal(t, "fooval2", c.l.Front().Value.(entry[string, string]).val) + }) + + t.Run("evict when full", func(t *testing.T) { + evicted := c.put("bar", "barval") + assert.False(t, evicted) + + evicted = c.put("baz", "bazval") + assert.True(t, evicted) + + _, ok := c.get("foo") + assert.False(t, ok) + + v, ok := c.get("bar") + assert.True(t, ok) + assert.Equal(t, "barval", v) + + v, ok = c.get("baz") + assert.True(t, ok) + assert.Equal(t, "bazval", v) + + assert.Equal(t, cap, c.l.Len()) + assert.Equal(t, cap, len(c.m)) + }) +} + +func TestLRUCache_concurrency(t *testing.T) { + c := newLRUCache[string, int](50, nil) + var wg sync.WaitGroup + + // multiple writers + for range 5 { + wg.Add(1) + go func() { + defer wg.Done() + for j := range 100 { + key := string(rune(j)) + c.put(key, j) + } + }() + } + + // multiple readers + for range 5 { + wg.Add(1) + go func() { + defer wg.Done() + for j := range 100 { + key := string(rune(j)) + c.get(key) + } + }() + } + + wg.Wait() + assert.Equal(t, c.l.Len(), c.cap) + assert.Equal(t, len(c.m), c.cap) +} diff --git a/internal/stmtcache/stmtcache.go b/internal/stmtcache/stmtcache.go new file mode 100644 index 0000000..b08a54c --- /dev/null +++ b/internal/stmtcache/stmtcache.go @@ -0,0 +1,69 @@ +package stmtcache + +import ( + "container/list" + "context" + "crypto/sha256" + "database/sql" + "encoding/hex" +) + +// stmt is satisfied by [sql.Stmt]. +type stmt interface { + Close() error + ExecContext(ctx context.Context, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) +} + +type StmtCache struct { + *lruCache[string, stmt] +} + +// New returns a new [StmtCache] with n maximum capacity, panics if capacity <= 0. 
+func New(cap int) *StmtCache { + if cap <= 0 { + panic("sqlz/stmtcache: capacity must be > 0") + } + + return &StmtCache{ + newLRUCache(cap, func(key string, stmt stmt) { + _ = stmt.Close() + }), + } +} + +func (c *StmtCache) Get(key string) (stmt, bool) { + return c.get(hashKey(key)) +} + +// Put adds a new entry to cache, returns whether an item was evicted, +// panics if key is blank. +func (c *StmtCache) Put(key string, stmt stmt) (evicted bool) { + if key == "" { + panic("sqlz/stmtcache: key must not be blank") + } + + return c.put(hashKey(key), stmt) +} + +// Clear removes all entries from the cache, closing all prepared statements. +func (c *StmtCache) Clear() { + for el := c.l.Front(); el != nil; el = el.Next() { + stmt := el.Value.(entry[string, stmt]).val + _ = stmt.Close() + } + c.l.Init() + c.m = make(map[string]*list.Element) +} + +// Len returns the number of cached statements. +func (c *StmtCache) Len() int { + return c.l.Len() +} + +// hashKey hashes s using SHA256, it's deterministic, and it's a consistent +// way to store a query as a key. 
+func hashKey(s string) string { + digest := sha256.Sum256([]byte(s)) + return hex.EncodeToString(digest[0:24]) +} diff --git a/internal/stmtcache/stmtcache_test.go b/internal/stmtcache/stmtcache_test.go new file mode 100644 index 0000000..c922eb8 --- /dev/null +++ b/internal/stmtcache/stmtcache_test.go @@ -0,0 +1,127 @@ +package stmtcache + +import ( + "context" + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockStmt struct { + closeCalled bool +} + +func (m *mockStmt) Close() error { + m.closeCalled = true + return nil +} + +func (m *mockStmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) { + return nil, nil +} + +func (m *mockStmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) { + return nil, nil +} + +func TestStmtCache(t *testing.T) { + t.Run("panic if cap <= 0", func(t *testing.T) { + assert.Panics(t, func() { New(0) }) + }) + + const cap = 2 + c := New(cap) + + fooStmt := &mockStmt{} + barStmt := &mockStmt{} + bazStmt := &mockStmt{} + + t.Run("put and get value", func(t *testing.T) { + evicted := c.Put("foo", nil) + assert.False(t, evicted) + v, ok := c.Get("foo") + require.True(t, ok) + assert.Equal(t, nil, v) + assert.Equal(t, 1, c.Len()) + }) + + t.Run("updating existing key moves it to front", func(t *testing.T) { + evicted := c.Put("foo", fooStmt) + assert.False(t, evicted) + v, ok := c.Get("foo") + require.True(t, ok) + assert.Equal(t, fooStmt, v) + assert.Equal(t, 1, c.Len()) + }) + + t.Run("evict when full", func(t *testing.T) { + evicted := c.Put("bar", barStmt) + assert.False(t, evicted) + + assert.False(t, fooStmt.closeCalled) + evicted = c.Put("baz", bazStmt) + assert.True(t, evicted) + assert.True(t, fooStmt.closeCalled) + + _, ok := c.Get("foo") + assert.False(t, ok) + + v, ok := c.Get("bar") + assert.True(t, ok) + assert.Equal(t, barStmt, v) + + v, ok = c.Get("baz") + assert.True(t, ok) + assert.Equal(t, bazStmt, v) + + 
assert.Equal(t, cap, c.Len()) + }) + + t.Run("clear", func(t *testing.T) { + assert.False(t, barStmt.closeCalled) + assert.False(t, bazStmt.closeCalled) + c.Clear() + assert.True(t, barStmt.closeCalled) + assert.True(t, bazStmt.closeCalled) + assert.Equal(t, 0, c.Len()) + }) + + t.Run("blank key should panic", func(t *testing.T) { + assert.Panics(t, func() { + c.Put("", nil) + }) + }) +} + +func TestHashKey(t *testing.T) { + tests := []struct { + name string + input string + expect string + }{ + { + name: "small string", + input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + expect: "a58dd8680234c1f8cc2ef2b325a43733605a7f16f288e072", + }, + { + name: "medium string", + input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Proin sed dapibus sapien. Donec nec ipsum a lorem aliquet blandit. Nullam quis tempus velit. In id massa blandit, sollicitudin dui non, fermentum nulla. Sed sed eros ac elit aliquet malesuada quis nec ligula.", + expect: "bd78ae92057058526d9f6a8cf2b3d6e6911196f15c030d9f", + }, + { + name: "large string", + input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Proin sed dapibus sapien. Donec nec ipsum a lorem aliquet blandit. Mauris metus nibh, commodo ut elit sed, eleifend sollicitudin tellus. Nullam quis tempus velit. In id massa blandit, sollicitudin dui non, fermentum nulla. Sed sed eros ac elit aliquet malesuada quis nec ligula. Etiam nunc ex, accumsan a bibendum pellentesque, maximus et lectus. Ut nisl massa, rutrum id bibendum fringilla, suscipit a nunc. 
Vivamus fringilla mi eget leo condimentum convallis.", + expect: "aecb66379c0bdac5883dfa5ea01fa7e2bfd5d753b6f31724", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := hashKey(tc.input) + assert.Equal(t, tc.expect, got) + }) + } +} diff --git a/named.go b/named.go index 4f8f5a3..9054af8 100644 --- a/named.go +++ b/named.go @@ -20,12 +20,7 @@ type namedQuery struct { } func processNamed(query string, arg any, cfg *config) (string, []any, error) { - if cfg == nil { - cfg = &config{} - } - cfg.defaults() - - n := &namedQuery{config: cfg} + n := &namedQuery{config: applyDefaults(cfg)} if err := n.process(query, arg); err != nil { return "", nil, err @@ -51,9 +46,8 @@ func (n *namedQuery) process(query string, arg any) error { return fmt.Errorf("sqlz/named: unsupported argument type: %T", arg) } -func (n *namedQuery) processOne(query string, argValue reflect.Value, kind reflect.Kind) error { +func (n *namedQuery) processOne(query string, argValue reflect.Value, kind reflect.Kind) (err error) { query, idents := parser.Parse(n.bind, query) - var err error switch kind { case reflect.Map: @@ -145,8 +139,6 @@ func (n *namedQuery) bindMapArgs(idents []string, argValue reflect.Value) error return nil } -type binderFunc = func(idents []string, argValue reflect.Value) error - func (n *namedQuery) processSlice(query string, sliceValue reflect.Value) error { if sliceValue.Len() == 0 { return fmt.Errorf("sqlz/named: slice is zero length: %s", sliceValue.Type()) @@ -165,20 +157,22 @@ func (n *namedQuery) processSlice(query string, sliceValue reflect.Value) error } } -func (n *namedQuery) bindSliceArgs(query string, sliceValue reflect.Value, binder binderFunc) error { +func (n *namedQuery) bindSliceArgs( + query string, + sliceValue reflect.Value, + fn func(idents []string, argValue reflect.Value) error, +) (err error) { idents := parser.ParseIdents(n.bind, query) if n.args == nil { n.args = make([]any, 0, len(idents)*sliceValue.Len()) } for i := 
range sliceValue.Len() { - if err := binder(idents, sliceValue.Index(i)); err != nil { + if err := fn(idents, sliceValue.Index(i)); err != nil { return err } } - var err error - // if bind is '?', parse query before expanding if n.bind == parser.BindQuestion { n.query = parser.ParseQuery(n.bind, query) diff --git a/named_test.go b/named_test.go index 8118eea..4b6dcd4 100644 --- a/named_test.go +++ b/named_test.go @@ -321,7 +321,6 @@ func TestProcessNamed(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := &config{structTag: tt.structTag} - cfg.defaults() cfg.bind = parser.BindAt query, args, err := processNamed(tt.inputQuery, tt.inputArg, cfg) diff --git a/scanner.go b/scanner.go index 653f0c4..470baf8 100644 --- a/scanner.go +++ b/scanner.go @@ -37,21 +37,18 @@ type Scanner struct { } func newScanner(rows rows, cfg *config) *Scanner { - if cfg == nil { - cfg = &config{} - } - cfg.defaults() - return &Scanner{ - config: cfg, + config: applyDefaults(cfg), rows: rows, } } func newRowScanner(rows rows, cfg *config) *Scanner { - scanner := newScanner(rows, cfg) - scanner.queryRow = true - return scanner + return &Scanner{ + config: applyDefaults(cfg), + rows: rows, + queryRow: true, + } } func (s *Scanner) resolveColumns() (err error) { diff --git a/sqlz.go b/sqlz.go index b7dca76..930a85f 100644 --- a/sqlz.go +++ b/sqlz.go @@ -28,8 +28,7 @@ type Options struct { // Default is "db". StructTag string - // FieldNameTransformer transforms a struct field name, - // it is only used when the struct tag is not found. + // FieldNameTransformer transforms a struct field name when the struct tag is not found. // Default is [ToSnakeCase]. FieldNameTransformer func(string) string @@ -37,6 +36,12 @@ type Options struct { // rather than returning an error. // Default is false. IgnoreMissingFields bool + + // StatementCacheCapacity sets the maximum number of cached statements, + // if it's zero, prepared statement caching is completely disabled. 
+ // Note that each statement may be prepared on each connection in the pool. + // Default is 16. + StatementCacheCapacity int } // New returns a [DB] instance using an existing [sql.DB]. @@ -47,6 +52,10 @@ type Options struct { // pool, err := sql.Open("sqlite3", ":memory:") // db := sqlz.New("sqlite3", pool, nil) func New(driverName string, db *sql.DB, opts *Options) *DB { + if opts != nil && opts.StatementCacheCapacity == 0 { + opts.StatementCacheCapacity = -1 + } + if opts == nil { opts = &Options{} } @@ -56,15 +65,13 @@ func New(driverName string, db *sql.DB, opts *Options) *DB { panic(fmt.Sprintf("sqlz: unable to find bind for '%s', set with Options.Bind", driverName)) } - cfg := &config{ + return &DB{db, newBase(&config{ bind: bind, structTag: opts.StructTag, fieldNameTransformer: opts.FieldNameTransformer, ignoreMissingFields: opts.IgnoreMissingFields, - } - cfg.defaults() - - return &DB{db, &base{cfg}} + stmtCacheCapacity: opts.StatementCacheCapacity, + })} } // Connect opens a database specified by its database driver name and a @@ -111,6 +118,13 @@ type DB struct { // Pool return the underlying [sql.DB]. func (db *DB) Pool() *sql.DB { return db.pool } +// ClearStmtCache clears the prepared statement cache. +// This is useful when the database schema has changed and cached statements +// may no longer be valid. +func (db *DB) ClearStmtCache() { + db.base.clearStmtCache() +} + // Begin starts a transaction. The default isolation level is dependent on // the driver. // @@ -138,7 +152,7 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { return nil, err } - return &Tx{tx, db.base}, nil + return &Tx{tx, newBase(db.base.config)}, nil } // Query executes a query that can return multiple rows. Any errors are deferred @@ -195,13 +209,19 @@ func (tx *Tx) Conn() *sql.Tx { return tx.conn } // Commit commits the transaction. // // If Commit fails, then all queries on the Tx should be discarded as invalid. 
-func (tx *Tx) Commit() error { return tx.conn.Commit() } +func (tx *Tx) Commit() error { + tx.base.clearStmtCache() + return tx.conn.Commit() +} // Rollback aborts the transaction. // // Even if Rollback fails, the transaction will no longer be valid, // nor will it have been committed to the database. -func (tx *Tx) Rollback() error { return tx.conn.Rollback() } +func (tx *Tx) Rollback() error { + tx.base.clearStmtCache() + return tx.conn.Rollback() +} // Query executes a query that can return multiple rows. Any errors are deferred // until [Scanner.Err] or [Scanner.Scan] is called. diff --git a/sqlz_test.go b/sqlz_test.go index f997db6..292839d 100644 --- a/sqlz_test.go +++ b/sqlz_test.go @@ -21,6 +21,11 @@ func TestNew(t *testing.T) { assert.IsType(t, &DB{}, db) } +func TestNew_no_stmt_cache(t *testing.T) { + db := New("sqlite3", &sql.DB{}, &Options{StatementCacheCapacity: 0}) + assert.Nil(t, db.base.stmtCache) +} + func TestNew_panic(t *testing.T) { defer func() { assert.Contains(t, recover(), "unable to find bind") @@ -405,3 +410,17 @@ func TestDB_Pool(t *testing.T) { assert.IsType(t, &sql.Tx{}, tx.Conn()) }) } + +func TestDB_ClearStmtCache(t *testing.T) { + runConn(t, func(t *testing.T, conn *Conn) { + db := New(conn.driverName, conn.db, nil) + query := rebind(conn.bind, "SELECT 'Hello World' WHERE 1 = ?") + + err := db.QueryRow(ctx, query, 1).Scan(new(string)) + require.NoError(t, err) + assert.Equal(t, db.base.stmtCache.Len(), 1) + + db.ClearStmtCache() + assert.Equal(t, db.base.stmtCache.Len(), 0) + }) +} diff --git a/util.go b/util.go index 5265351..1ddc7ba 100644 --- a/util.go +++ b/util.go @@ -13,8 +13,6 @@ import ( "github.com/rfberaldo/sqlz/internal/parser" ) -const defaultStructTag = "db" - var ( // scannerType is [reflect.Type] of [sql.Scanner] scannerType = reflect.TypeFor[sql.Scanner]()