From c94ea3d9aeb4b115e03d4ec99dd17de14dbd1c48 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Tue, 15 Apr 2025 15:08:28 +0200 Subject: [PATCH 1/2] Fix adding custom collation sequences --- Sources/SQLite/SQLiteDatabase.swift | 75 +++++++------ Tests/SQLiteTests/SQLiteDatabaseTests.swift | 116 ++------------------ 2 files changed, 53 insertions(+), 138 deletions(-) diff --git a/Sources/SQLite/SQLiteDatabase.swift b/Sources/SQLite/SQLiteDatabase.swift index 69eb535..38b3f8b 100644 --- a/Sources/SQLite/SQLiteDatabase.swift +++ b/Sources/SQLite/SQLiteDatabase.swift @@ -9,6 +9,17 @@ public final class SQLiteDatabase: DatabaseProtocol, @unchecked Sendable { public static let suspendNotification = GRDB.Database.suspendNotification public static let resumeNotification = GRDB.Database.resumeNotification + public static let unicodeCompare = + GRDB.DatabaseCollation.unicodeCompare.name + public static let caseInsensitiveCompare = + GRDB.DatabaseCollation.caseInsensitiveCompare.name + public static let localizedCaseInsensitiveCompare = + GRDB.DatabaseCollation.localizedCaseInsensitiveCompare.name + public static let localizedCompare = + GRDB.DatabaseCollation.localizedCompare.name + public static let localizedStandardCompare = + GRDB.DatabaseCollation.localizedStandardCompare.name + public let path: String public let sqliteVersion: String @@ -33,7 +44,10 @@ public final class SQLiteDatabase: DatabaseProtocol, @unchecked Sendable { public static func makeShared( path: String, - busyTimeout: TimeInterval = 5 + busyTimeout: TimeInterval = 5, + collationSequences: [ + String: @Sendable (String, String) -> ComparisonResult + ] = [:] ) throws -> SQLiteDatabase { guard path != ":memory:" else { throw SQLiteError.SQLITE_IOERR @@ -58,7 +72,8 @@ public final class SQLiteDatabase: DatabaseProtocol, @unchecked Sendable { do { database = try SQLiteDatabase( path: url.path, - busyTimeout: busyTimeout + busyTimeout: busyTimeout, + collationSequences: collationSequences ) } catch { databaseError = error @@ -79,8 +94,18 @@ public final class SQLiteDatabase: DatabaseProtocol, @unchecked Sendable { return db } - public init(path: String = ":memory:", busyTimeout: TimeInterval = 5) throws { - database = try Self.open(at: path, busyTimeout: busyTimeout) + public init( + path: String = ":memory:", + busyTimeout: TimeInterval = 5, + collationSequences: [ + String: @Sendable (String, String) -> ComparisonResult, + ] = [:] + ) throws { + database = try Self.open( + at: path, + busyTimeout: busyTimeout, + collationSequences: collationSequences + ) self.path = path let sqliteVersion = try Self.getSQLiteVersion(database) self.sqliteVersion = sqliteVersion.description @@ -587,33 +612,6 @@ public extension SQLiteDatabase { } } -// MARK: - Collating sequences - -public extension SQLiteDatabase { - func addCollation( - named name: String, - comparator: @escaping @Sendable (String, String) -> ComparisonResult - ) throws { - let collation = DatabaseCollation( - name, - function: comparator - ) - try database - .writer - .barrierWriteWithoutTransaction { $0.add(collation: collation) } - } - - func removeCollation(named name: String) throws { - let collation = DatabaseCollation( - name, - function: { _, _ in .orderedSame } - ) - try database - .writer - .barrierWriteWithoutTransaction { $0.remove(collation: collation) } - } -} - // MARK: - Pragmas public extension SQLiteDatabase { @@ -748,7 +746,10 @@ extension SQLiteDatabase { private extension SQLiteDatabase { class func open( at path: String, - busyTimeout: TimeInterval + busyTimeout: TimeInterval, + collationSequences: [ + String: @Sendable (String, String) -> ComparisonResult + ] ) throws -> Database { let isInMemory: Bool = { let p = path.lowercased() @@ -763,6 +764,16 @@ private extension SQLiteDatabase { ProcessInfo.processInfo.processorCount, 6 ) + if !collationSequences.isEmpty { + config.prepareDatabase { db in + for (name, comparator) in collationSequences { + db.add(collation: DatabaseCollation( + name, + function: comparator + )) + } + } + } guard !isInMemory else { do { diff --git a/Tests/SQLiteTests/SQLiteDatabaseTests.swift b/Tests/SQLiteTests/SQLiteDatabaseTests.swift index 1137442..9ffdfac 100644 --- a/Tests/SQLiteTests/SQLiteDatabaseTests.swift +++ b/Tests/SQLiteTests/SQLiteDatabaseTests.swift @@ -94,112 +94,10 @@ final class SQLiteDatabaseTests: XCTestCase { } } - func testAddAndRemoveCollation() throws { - struct Entity: Hashable, SQLiteTransformable { - let id: String - let string: String? - - init(_ id: Int, _ string: String? = nil) { - self.id = String(id) - self.string = string - } - - init(row: SQLiteRow) throws { - id = try row.value(for: "id") - string = row.optionalValue(for: "string") - } - - var asArguments: SQLiteArguments { - [ - "id": .text(id), - "string": string.map { .text($0) } ?? .null, - ] - } - } - - let apple = Entity(1, "Apple") - let banana = Entity(2, "banana") - let zebra = Entity(3, "Zebra") - let null1 = Entity(4) - let null2 = Entity(5) - - try database.inTransaction { db in - try db.write(_createTableWithIDAsStringAndNullableString) - try [apple, banana, zebra, null1, null2] - .forEach { entity in - try db.write( - _insertIDAndString, - arguments: entity.asArguments - ) - } - } - - let selectDefaultSorted: SQL = """ - SELECT * FROM test ORDER BY string; - """ - - let selectCustomCaseSensitiveSorted: SQL = """ - SELECT * FROM test ORDER BY string COLLATE CUSTOM; - """ - - let selectCustomCaseInsensitiveSorted: SQL = """ - SELECT * FROM test ORDER BY string COLLATE CUSTOM_NOCASE; - """ - - let defaultSorted: [Entity] = try database.read(selectDefaultSorted) - XCTAssertEqual( - defaultSorted, - [null1, null2, apple, zebra, banana] - ) - - XCTAssertThrowsError( - try database.read(selectCustomCaseSensitiveSorted) - ) { error in - guard case SQLiteError.SQLITE_ERROR_MISSING_COLLSEQ = error else { - XCTFail("Should have thrown SQLITE_ERROR") - return - } - } - - try database.addCollation(named: "CUSTOM") { $0.compare($1) } - let customSorted: [Entity] = try database.read(selectCustomCaseSensitiveSorted) - XCTAssertEqual( - customSorted, - [null1, null2, apple, zebra, banana] - ) - - try database.addCollation( - named: "CUSTOM_NOCASE" - ) { $0.caseInsensitiveCompare($1) } - - let customNoCaseSorted: [Entity] = try database - .read(selectCustomCaseInsensitiveSorted) - XCTAssertEqual( - customNoCaseSorted, - [null1, null2, apple, banana, zebra] - ) - - try database.removeCollation(named: "CUSTOM_NOCASE") - XCTAssertThrowsError( - try database.read(selectCustomCaseInsensitiveSorted) - ) { error in - guard case SQLiteError.SQLITE_ERROR_MISSING_COLLSEQ = error else { - XCTFail("Should have thrown SQLITE_ERROR") - return - } - } - let customSortedAfterRemovingNoCase: [Entity] = try database - .read(selectCustomCaseSensitiveSorted) - XCTAssertEqual( - customSortedAfterRemovingNoCase, - [null1, null2, apple, zebra, banana] - ) - } - func testCustomLocalizedCollation() throws { - try database.addCollation(named: "LOCALIZED") { lhs, rhs in - lhs.localizedStandardCompare(rhs) - } + database = try SQLiteDatabase(collationSequences: [ + "CUSTOM_LOCALIZED": { $0.localizedStandardCompare($1) }, + ]) // NOTE: ([toInsert], [binary sort], [localized sort]) let cases: [([String], [String], [String])] = [ @@ -252,10 +150,16 @@ final class SQLiteDatabaseTests: XCTestCase { XCTAssertEqual(binarySorted, binarySort) let localizedSorted: [String] = try database - .read("SELECT * FROM test ORDER BY string COLLATE LOCALIZED;") + .read("SELECT * FROM test ORDER BY string COLLATE CUSTOM_LOCALIZED;") .compactMap { $0["string"]?.stringValue } XCTAssertEqual(localizedSorted, localizedSort) + let grdbName = SQLiteDatabase.localizedStandardCompare + let grdbStandardSorted: [String] = try database + .read("SELECT * FROM test ORDER BY string COLLATE \(grdbName);") + .compactMap { $0["string"]?.stringValue } + XCTAssertEqual(grdbStandardSorted, localizedSort) + try database.write("DROP TABLE test;") } } From 28cda0842d06632a1adb7fbea50b8009ed7cebc7 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Tue, 15 Apr 2025 15:14:44 +0200 Subject: [PATCH 2/2] Fix compilation --- Sources/SQLite/SQLiteDatabase.swift | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/Sources/SQLite/SQLiteDatabase.swift b/Sources/SQLite/SQLiteDatabase.swift index 38b3f8b..3d8c4dd 100644 --- a/Sources/SQLite/SQLiteDatabase.swift +++ b/Sources/SQLite/SQLiteDatabase.swift @@ -45,9 +45,7 @@ public final class SQLiteDatabase: DatabaseProtocol, @unchecked Sendable { public static func makeShared( path: String, busyTimeout: TimeInterval = 5, - collationSequences: [ - String: @Sendable (String, String) -> ComparisonResult - ] = [:] + collationSequences: [String: @Sendable (String, String) -> ComparisonResult] = [:] ) throws -> SQLiteDatabase { guard path != ":memory:" else { throw SQLiteError.SQLITE_IOERR @@ -97,9 +95,7 @@ public final class SQLiteDatabase: DatabaseProtocol, @unchecked Sendable { public init( path: String = ":memory:", busyTimeout: TimeInterval = 5, - collationSequences: [ - String: @Sendable (String, String) -> ComparisonResult, - ] = [:] + collationSequences: [String: @Sendable (String, String) -> ComparisonResult] = [:] ) throws { database = try Self.open( at: path, @@ -747,9 +743,7 @@ private extension SQLiteDatabase { class func open( at path: String, busyTimeout: TimeInterval, - collationSequences: [ - String: @Sendable (String, String) -> ComparisonResult - ] + collationSequences: [String: @Sendable (String, String) -> ComparisonResult] ) throws -> Database { let isInMemory: Bool = { let p = path.lowercased()