From e8b194ff011a4d958c1861b527ec5e02cacd5c16 Mon Sep 17 00:00:00 2001 From: Peter Stace Date: Fri, 7 Nov 2025 13:22:59 +1100 Subject: [PATCH 1/5] Make rtree package generic over record type The `rtree` package previously hardcoded record IDs as integers, requiring users to maintain separate mappings between their records and integer IDs. This added unnecessary indirection and made the API less ergonomic. By making `RTree` generic over the record type T, users can now store their actual records directly in the tree (or use integers if they prefer to keep the existing behaviour). This simplifies usage patterns and eliminates the need for separate index mappings in consuming code. Note: This is a breaking change. The easiest way for users to upgrade is to simply replace all occurrences of `*rtree.RTree` with `*rtree.RTree[int]` in their codebase. --- .golangci.yaml | 4 +- CHANGELOG.md | 14 ++++++- geom/alg_distance.go | 20 +++++----- geom/alg_intersection.go | 2 +- geom/alg_intersects.go | 10 ++--- geom/alg_point_in_ring.go | 2 +- geom/dcel_ghosts.go | 6 +-- geom/dcel_re_noding.go | 4 +- geom/rtree.go | 20 +++++----- geom/type_line_string.go | 6 +-- geom/type_multi_line_string.go | 10 ++--- geom/type_multi_polygon.go | 6 +-- geom/type_polygon.go | 4 +- rtree/box.go | 2 +- rtree/bulk.go | 32 ++++++++-------- rtree/golden_internal_test.go | 4 +- rtree/nearest.go | 34 ++++++++--------- rtree/nearest_internal_test.go | 10 ++--- rtree/perf_internal_test.go | 6 +-- rtree/quick_partition_internal_test.go | 4 +- rtree/rtree.go | 51 ++++++++++++-------------- rtree/rtree_internal_test.go | 32 ++++++++-------- 22 files changed, 145 insertions(+), 138 deletions(-) diff --git a/.golangci.yaml b/.golangci.yaml index b42ba769..25c0fd79 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -35,8 +35,6 @@ linters-settings: errcheck: exclude-functions: - io.Copy(os.Stdout) - - (*github.com/peterstace/simplefeatures/rtree.RTree).RangeSearch - - (*github.com/peterstace/simplefeatures/rtree.RTree).PrioritySearch # NOTE: every linter supported by golangci-lint is either explicitly included # or excluded. @@ -79,7 +77,6 @@ linters: - importas - ineffassign - intrange - - ireturn - loggercheck - makezero - mirror @@ -143,6 +140,7 @@ linters: - gomnd - inamedparam - interfacebloat + - ireturn - lll - maintidx - nestif diff --git a/CHANGELOG.md b/CHANGELOG.md index 6706428c..2cbc2956 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,8 +13,20 @@ This includes function parameters, return types, struct fields, and type assertions. +- **Breaking change:** The `rtree` package types and functions are now generic + over the record type. The `RTree` type is now `RTree[T]`, `BulkItem` is now + `BulkItem[T]`, and `BulkLoad` is now `BulkLoad[T]`. The `RecordID int` field + in `BulkItem` has been renamed to `Record T`. This allows users to store + their records directly in the tree rather than maintaining separate mappings + between integer IDs and records. Users can upgrade by adding type parameters + to their rtree usage (e.g., `RTree[int]` to maintain existing behavior with + integer IDs, or use a custom type like `RTree[MyRecord]` to store records + directly). The `RecordID` field in `BulkItem` should be renamed to `Record`, + and callback function signatures should change from `func(recordID int)` to + `func(record T)` where `T` is the type parameter. + - **Breaking change:** The minimum required Go version is now 1.18 (previously - 1.17). This is required to support the `any` keyword. + 1.17). This is required to support the `any` keyword and generics. ## v0.55.0 diff --git a/geom/alg_distance.go b/geom/alg_distance.go index 95148352..5b3ea83a 100644 --- a/geom/alg_distance.go +++ b/geom/alg_distance.go @@ -73,7 +73,7 @@ func Distance(g1, g2 Geometry) (float64, bool) { } for _, xy := range xys1 { xyEnv := xy.uncheckedEnvelope() - tr.PrioritySearch(xy.box(), func(recordID int) error { + _ = tr.PrioritySearch(xy.box(), func(recordID int) error { return searchBody( xyEnv, recordID, @@ -84,7 +84,7 @@ func Distance(g1, g2 Geometry) (float64, bool) { } for _, ln := range lns1 { lnEnv := ln.uncheckedEnvelope() - tr.PrioritySearch(ln.box(), func(recordID int) error { + _ = tr.PrioritySearch(ln.box(), func(recordID int) error { return searchBody( lnEnv, recordID, @@ -132,18 +132,18 @@ func extractXYsAndLines(g Geometry) ([]XY, []line) { // uses positive record IDs to refer to the XYs, and negative recordIDs to // refer to the lines. Because +0 and -0 are the same, indexing is 1-based and // recordID 0 is not used. -func loadTree(xys []XY, lns []line) *rtree.RTree { - items := make([]rtree.BulkItem, len(xys)+len(lns)) +func loadTree(xys []XY, lns []line) *rtree.RTree[int] { + items := make([]rtree.BulkItem[int], len(xys)+len(lns)) for i, xy := range xys { - items[i] = rtree.BulkItem{ - Box: xy.box(), - RecordID: i + 1, + items[i] = rtree.BulkItem[int]{ + Box: xy.box(), + Record: i + 1, } } for i, ln := range lns { - items[i+len(xys)] = rtree.BulkItem{ - Box: ln.box(), - RecordID: -(i + 1), + items[i+len(xys)] = rtree.BulkItem[int]{ + Box: ln.box(), + Record: -(i + 1), } } return rtree.BulkLoad(items) diff --git a/geom/alg_intersection.go b/geom/alg_intersection.go index 23762c1a..89e56e5a 100644 --- a/geom/alg_intersection.go +++ b/geom/alg_intersection.go @@ -10,7 +10,7 @@ func intersectionOfIndexedLines( var pts []Point seen := make(map[XY]bool) for i := range lines1.lines { - lines2.tree.RangeSearch(lines1.lines[i].box(), func(j int) error { + _ = lines2.tree.RangeSearch(lines1.lines[i].box(), func(j int) error { inter := lines1.lines[i].intersectLine(lines2.lines[j]) if inter.empty { return nil diff --git a/geom/alg_intersects.go b/geom/alg_intersects.go index b541d781..94d80f4d 100644 --- a/geom/alg_intersects.go +++ b/geom/alg_intersects.go @@ -195,11 +195,11 @@ func hasIntersectionBetweenLines( lines1, lines2 = lines2, lines1 } - bulk := make([]rtree.BulkItem, len(lines1)) + bulk := make([]rtree.BulkItem[int], len(lines1)) for i, ln := range lines1 { - bulk[i] = rtree.BulkItem{ - Box: ln.box(), - RecordID: i, + bulk[i] = rtree.BulkItem[int]{ + Box: ln.box(), + Record: i, } } tree := rtree.BulkLoad(bulk) @@ -209,7 +209,7 @@ func hasIntersectionBetweenLines( var env Envelope for _, lnA := range lines2 { - tree.RangeSearch(lnA.box(), func(i int) error { + _ = tree.RangeSearch(lnA.box(), func(i int) error { lnB := lines1[i] inter := lnA.intersectLine(lnB) if inter.empty { diff --git a/geom/alg_point_in_ring.go b/geom/alg_point_in_ring.go index e8b68dc4..3f03ac68 100644 --- a/geom/alg_point_in_ring.go +++ b/geom/alg_point_in_ring.go @@ -61,7 +61,7 @@ func relatePointToPolygon(pt XY, polyBoundary indexedLines) side { } var onBound bool var count int - polyBoundary.tree.RangeSearch(box, func(i int) error { + _ = polyBoundary.tree.RangeSearch(box, func(i int) error { ln := polyBoundary.lines[i] crossing, onLine := hasCrossing(pt, ln) if onLine { diff --git a/geom/dcel_ghosts.go b/geom/dcel_ghosts.go index d428d9d7..e6ae07f9 100644 --- a/geom/dcel_ghosts.go +++ b/geom/dcel_ghosts.go @@ -30,9 +30,9 @@ func spanningTree(xys []XY) MultiLineString { // Load points into r-tree. xys = sortAndUniquifyXYs(xys) - items := make([]rtree.BulkItem, len(xys)) + items := make([]rtree.BulkItem[int], len(xys)) for i, xy := range xys { - items[i] = rtree.BulkItem{Box: xy.box(), RecordID: i} + items[i] = rtree.BulkItem[int]{Box: xy.box(), Record: i} } tree := rtree.BulkLoad(items) @@ -49,7 +49,7 @@ func spanningTree(xys []XY) MultiLineString { // of being the closest to another point. continue } - tree.PrioritySearch(xyi.box(), func(j int) error { + _ = tree.PrioritySearch(xyi.box(), func(j int) error { // We don't want to include a new edge in the spanning tree if it // would cause a cycle (i.e. the two endpoints are already in the // same tree). This is checked via dset. diff --git a/geom/dcel_re_noding.go b/geom/dcel_re_noding.go index 71a103a1..943ac8bf 100644 --- a/geom/dcel_re_noding.go +++ b/geom/dcel_re_noding.go @@ -48,7 +48,7 @@ func reNodeGeometries(g1, g2 Geometry, mls MultiLineString) (Geometry, Geometry, // Create new nodes for point/line intersections. ptIndex := newIndexedPoints(nodes.list()) appendCutsForPointXLine := func(ln line, cuts []XY) []XY { - ptIndex.tree.RangeSearch(ln.box(), func(i int) error { + _ = ptIndex.tree.RangeSearch(ln.box(), func(i int) error { xy := ptIndex.points[i] if !ln.hasEndpoint(xy) && distBetweenXYAndLine(xy, ln) < ulp*0x200 { cuts = append(cuts, xy) @@ -64,7 +64,7 @@ func reNodeGeometries(g1, g2 Geometry, mls MultiLineString) (Geometry, Geometry, // Create new nodes for line/line intersections. lnIndex := newIndexedLines(appendLines(nil, all())) appendCutsLineXLine := func(ln line, cuts []XY) []XY { - lnIndex.tree.RangeSearch(ln.box(), func(i int) error { + _ = lnIndex.tree.RangeSearch(ln.box(), func(i int) error { other := lnIndex.lines[i] // TODO: This is a hacky approach (re-orders inputs, rather than diff --git a/geom/rtree.go b/geom/rtree.go index eb4f9d60..fbdad7f8 100644 --- a/geom/rtree.go +++ b/geom/rtree.go @@ -7,15 +7,15 @@ import "github.com/peterstace/simplefeatures/rtree" // the indices of the lines slice. type indexedLines struct { lines []line - tree *rtree.RTree + tree *rtree.RTree[int] } func newIndexedLines(lines []line) indexedLines { - bulk := make([]rtree.BulkItem, len(lines)) + bulk := make([]rtree.BulkItem[int], len(lines)) for i, ln := range lines { - bulk[i] = rtree.BulkItem{ - Box: ln.box(), - RecordID: i, + bulk[i] = rtree.BulkItem[int]{ + Box: ln.box(), + Record: i, } } return indexedLines{lines, rtree.BulkLoad(bulk)} @@ -26,15 +26,15 @@ func newIndexedLines(lines []line) indexedLines { // the indices of the points slice. type indexedPoints struct { points []XY - tree *rtree.RTree + tree *rtree.RTree[int] } func newIndexedPoints(points []XY) indexedPoints { - bulk := make([]rtree.BulkItem, len(points)) + bulk := make([]rtree.BulkItem[int], len(points)) for i, pt := range points { - bulk[i] = rtree.BulkItem{ - Box: rtree.Box{MinX: pt.X, MaxX: pt.X, MinY: pt.Y, MaxY: pt.Y}, - RecordID: i, + bulk[i] = rtree.BulkItem[int]{ + Box: rtree.Box{MinX: pt.X, MaxX: pt.X, MinY: pt.Y, MaxY: pt.Y}, + Record: i, } } return indexedPoints{points, rtree.BulkLoad(bulk)} diff --git a/geom/type_line_string.go b/geom/type_line_string.go index 72814448..f4ac322e 100644 --- a/geom/type_line_string.go +++ b/geom/type_line_string.go @@ -116,13 +116,13 @@ func (s LineString) IsSimple() bool { } n := s.seq.Length() - items := make([]rtree.BulkItem, 0, n-1) + items := make([]rtree.BulkItem[int], 0, n-1) for i := 0; i < n; i++ { ln, ok := getLine(s.seq, i) if !ok { continue } - items = append(items, rtree.BulkItem{Box: ln.box(), RecordID: i}) + items = append(items, rtree.BulkItem[int]{Box: ln.box(), Record: i}) } tree := rtree.BulkLoad(items) @@ -142,7 +142,7 @@ func (s LineString) IsSimple() bool { } simple := true // assume simple until proven otherwise - tree.RangeSearch(ln.box(), func(j int) error { + _ = tree.RangeSearch(ln.box(), func(j int) error { // Skip finding the original line (i == j) or cases where we have // already checked that pair (i > j). if i >= j { diff --git a/geom/type_multi_line_string.go b/geom/type_multi_line_string.go index 89caa230..2ec952d5 100644 --- a/geom/type_multi_line_string.go +++ b/geom/type_multi_line_string.go @@ -125,7 +125,7 @@ func (m MultiLineString) IsSimple() bool { for _, ls := range m.lines { numItems += maxInt(0, ls.Coordinates().Length()-1) } - items := make([]rtree.BulkItem, 0, numItems) + items := make([]rtree.BulkItem[int], 0, numItems) for i, ls := range m.lines { seq := ls.Coordinates() seqLen := seq.Length() @@ -134,9 +134,9 @@ func (m MultiLineString) IsSimple() bool { if !ok { continue } - items = append(items, rtree.BulkItem{ - Box: ln.box(), - RecordID: toRecordID(i, j), + items = append(items, rtree.BulkItem[int]{ + Box: ln.box(), + Record: toRecordID(i, j), }) } } @@ -151,7 +151,7 @@ func (m MultiLineString) IsSimple() bool { continue } isSimple := true // assume simple until proven otherwise - tree.RangeSearch(ln.box(), func(recordID int) error { + _ = tree.RangeSearch(ln.box(), func(recordID int) error { // Ignore the intersection if it's for the same LineString that we're currently looking up. lineStringIdx, seqIdx := fromRecordID(recordID) if lineStringIdx == i { diff --git a/geom/type_multi_polygon.go b/geom/type_multi_polygon.go index 235770c4..9b821584 100644 --- a/geom/type_multi_polygon.go +++ b/geom/type_multi_polygon.go @@ -62,11 +62,11 @@ func (m MultiPolygon) checkMultiPolygonConstraints() error { // Construct RTree of Polygons. boxes := make([]rtree.Box, len(m.polys)) - items := make([]rtree.BulkItem, 0, len(m.polys)) + items := make([]rtree.BulkItem[int], 0, len(m.polys)) for i, p := range m.polys { if box, ok := p.Envelope().AsBox(); ok { boxes[i] = box - item := rtree.BulkItem{Box: boxes[i], RecordID: i} + item := rtree.BulkItem[int]{Box: boxes[i], Record: i} items = append(items, item) } } @@ -142,7 +142,7 @@ func validatePolyNotInsidePoly(p1, p2 indexedLines) error { for j := range p2.lines { // Find intersection points. var pts []XY - p1.tree.RangeSearch(p2.lines[j].box(), func(i int) error { + _ = p1.tree.RangeSearch(p2.lines[j].box(), func(i int) error { inter := p1.lines[i].intersectLine(p2.lines[j]) if inter.empty { return nil diff --git a/geom/type_polygon.go b/geom/type_polygon.go index 8352d5a9..75ebc36e 100644 --- a/geom/type_polygon.go +++ b/geom/type_polygon.go @@ -65,7 +65,7 @@ func (p Polygon) Validate() error { // Construct RTree of rings. boxes := make([]rtree.Box, len(p.rings)) - items := make([]rtree.BulkItem, len(p.rings)) + items := make([]rtree.BulkItem[int], len(p.rings)) for i, r := range p.rings { box, ok := r.Envelope().AsBox() if !ok { @@ -74,7 +74,7 @@ func (p Polygon) Validate() error { panic("unexpected empty ring") } boxes[i] = box - items[i] = rtree.BulkItem{Box: boxes[i], RecordID: i} + items[i] = rtree.BulkItem[int]{Box: boxes[i], Record: i} } tree := rtree.BulkLoad(items) diff --git a/rtree/box.go b/rtree/box.go index 704be2d3..e6f3461f 100644 --- a/rtree/box.go +++ b/rtree/box.go @@ -6,7 +6,7 @@ type Box struct { } // calculateBound calculates the smallest bounding box that fits a node. -func calculateBound(n *node) Box { +func calculateBound[T any](n *node[T]) Box { box := n.entries[0].box for i := 1; i < n.numEntries; i++ { box = combine(box, n.entries[i].box) diff --git a/rtree/bulk.go b/rtree/bulk.go index 18482716..425d4b3c 100644 --- a/rtree/bulk.go +++ b/rtree/bulk.go @@ -1,23 +1,23 @@ package rtree // BulkItem is an item that can be inserted for bulk loading. -type BulkItem struct { - Box Box - RecordID int +type BulkItem[T any] struct { + Box Box + Record T } // BulkLoad bulk loads multiple items into a new R-Tree. The bulk load // operation is optimised for creating R-Trees with minimal node overlap. This // allows for fast searching. -func BulkLoad(items []BulkItem) *RTree { +func BulkLoad[T any](items []BulkItem[T]) *RTree[T] { if len(items) == 0 { - return &RTree{} + return &RTree[T]{} } root := bulkInsert(items) - return &RTree{root, len(items)} + return &RTree[T]{root, len(items)} } -func bulkInsert(items []BulkItem) *node { +func bulkInsert[T any](items []BulkItem[T]) *node[T] { if len(items) == 0 { panic("should not have recursed into bulkInsert without any items") } @@ -27,11 +27,11 @@ func bulkInsert(items []BulkItem) *node { // 4 or fewer items can fit into a single node. if len(items) <= 4 { - n := &node{numEntries: len(items)} + n := &node[T]{numEntries: len(items)} for i, item := range items { - n.entries[i] = entry{ - box: item.Box, - recordID: item.RecordID, + n.entries[i] = entry[T]{ + box: item.Box, + record: item.Record, } } return n @@ -52,8 +52,8 @@ func bulkInsert(items []BulkItem) *node { return bulkNode(firstQuarter, secondQuarter, thirdQuarter, fourthQuarter) } -func bulkNode(parts ...[]BulkItem) *node { - root := &node{numEntries: len(parts)} +func bulkNode[T any](parts ...[]BulkItem[T]) *node[T] { + root := &node[T]{numEntries: len(parts)} for i, part := range parts { child := bulkInsert(part) root.entries[i].child = child @@ -62,7 +62,7 @@ func bulkNode(parts ...[]BulkItem) *node { return root } -func splitBulkItems2Ways(items []BulkItem) ([]BulkItem, []BulkItem) { +func splitBulkItems2Ways[T any](items []BulkItem[T]) ([]BulkItem[T], []BulkItem[T]) { horizontal := itemsAreHorizontal(items) split := len(items) / 2 quickPartition(items, split, horizontal) @@ -72,7 +72,7 @@ func splitBulkItems2Ways(items []BulkItem) ([]BulkItem, []BulkItem) { // quickPartition performs a partial in-place sort on the items slice. The // partial sort is such that items 0 through k-1 are less than or equal to item // k, and items k+1 through n-1 are greater than or equal to item k. -func quickPartition(items []BulkItem, k int, horizontal bool) { +func quickPartition[T any](items []BulkItem[T], k int, horizontal bool) { // Use a custom linear congruential random number generator. This is used // because we don't need high quality random numbers. Using a regular // rand.Rand generator causes a significant bottleneck due to the reliance @@ -150,7 +150,7 @@ func quickPartition(items []BulkItem, k int, horizontal bool) { } } -func itemsAreHorizontal(items []BulkItem) bool { +func itemsAreHorizontal[T any](items []BulkItem[T]) bool { box := items[0].Box for _, item := range items[1:] { box = combine(box, item.Box) diff --git a/rtree/golden_internal_test.go b/rtree/golden_internal_test.go index 74bec92e..9ffcb90c 100644 --- a/rtree/golden_internal_test.go +++ b/rtree/golden_internal_test.go @@ -135,12 +135,12 @@ func TestBulkLoadGolden(t *testing.T) { } } -func checksum(n *node) uint64 { +func checksum(n *node[int]) uint64 { var entries []string for i := 0; i < n.numEntries; i++ { var entry string if n.entries[i].child == nil { - entry = strconv.Itoa(n.entries[i].recordID) + entry = strconv.Itoa(n.entries[i].record) } else { entry = strconv.FormatUint(checksum(n.entries[i].child), 10) } diff --git a/rtree/nearest.go b/rtree/nearest.go index 39b27038..b4bf661c 100644 --- a/rtree/nearest.go +++ b/rtree/nearest.go @@ -9,13 +9,13 @@ import ( // as measured by the Euclidean metric. Note that there may be multiple records // that are equidistant from the input box, in which case one is chosen // arbitrarily. If the RTree is empty, then false is returned. -func (t *RTree) Nearest(box Box) (recordID int, found bool) { - t.PrioritySearch(box, func(rid int) error { - recordID = rid +func (t *RTree[T]) Nearest(box Box) (record T, found bool) { + _ = t.PrioritySearch(box, func(rec T) error { + record = rec found = true return Stop }) - return recordID, found + return record, found } // PrioritySearch iterates over the records in the RTree in priority order of @@ -25,13 +25,13 @@ func (t *RTree) Nearest(box Box) (recordID int, found bool) { // error returned from the callback is returned by PrioritySearch, except for // the case where the special Stop sentinel error is returned (in which case // nil will be returned from PrioritySearch). Stop may be wrapped. -func (t *RTree) PrioritySearch(box Box, callback func(recordID int) error) error { +func (t *RTree[T]) PrioritySearch(box Box, callback func(record T) error) error { if t.root == nil { return nil } - queue := entriesQueue{origin: box} - equeueNode := func(n *node) { + queue := entriesQueue[T]{origin: box} + equeueNode := func(n *node[T]) { for i := 0; i < n.numEntries; i++ { heap.Push(&queue, &n.entries[i]) } @@ -39,9 +39,9 @@ func (t *RTree) PrioritySearch(box Box, callback func(recordID int) error) error equeueNode(t.root) for len(queue.entries) > 0 { - nearest := heap.Pop(&queue).(*entry) + nearest := heap.Pop(&queue).(*entry[T]) if nearest.child == nil { - if err := callback(nearest.recordID); err != nil { + if err := callback(nearest.record); err != nil { if errors.Is(err, Stop) { return nil } @@ -54,30 +54,30 @@ func (t *RTree) PrioritySearch(box Box, callback func(recordID int) error) error return nil } -type entriesQueue struct { - entries []*entry +type entriesQueue[T any] struct { + entries []*entry[T] origin Box } -func (q *entriesQueue) Len() int { +func (q *entriesQueue[T]) Len() int { return len(q.entries) } -func (q *entriesQueue) Less(i int, j int) bool { +func (q *entriesQueue[T]) Less(i int, j int) bool { d1 := squaredEuclideanDistance(q.entries[i].box, q.origin) d2 := squaredEuclideanDistance(q.entries[j].box, q.origin) return d1 < d2 } -func (q *entriesQueue) Swap(i int, j int) { +func (q *entriesQueue[T]) Swap(i int, j int) { q.entries[i], q.entries[j] = q.entries[j], q.entries[i] } -func (q *entriesQueue) Push(x any) { - q.entries = append(q.entries, x.(*entry)) +func (q *entriesQueue[T]) Push(x interface{}) { + q.entries = append(q.entries, x.(*entry[T])) } -func (q *entriesQueue) Pop() any { +func (q *entriesQueue[T]) Pop() interface{} { e := q.entries[len(q.entries)-1] q.entries = q.entries[:len(q.entries)-1] return e diff --git a/rtree/nearest_internal_test.go b/rtree/nearest_internal_test.go index e518d143..e95aeac0 100644 --- a/rtree/nearest_internal_test.go +++ b/rtree/nearest_internal_test.go @@ -21,7 +21,7 @@ func TestNearest(t *testing.T) { } } -func checkNearest(t *testing.T, rt *RTree, boxes []Box, rnd *rand.Rand) { +func checkNearest(t *testing.T, rt *RTree[int], boxes []Box, rnd *rand.Rand) { t.Helper() for i := 0; i < 10; i++ { originBB := randomBox(rnd, 0.9, 0.1) @@ -47,13 +47,13 @@ func checkNearest(t *testing.T, rt *RTree, boxes []Box, rnd *rand.Rand) { } } -func checkPrioritySearch(t *testing.T, rt *RTree, boxes []Box, rnd *rand.Rand) { +func checkPrioritySearch(t *testing.T, rt *RTree[int], boxes []Box, rnd *rand.Rand) { t.Helper() for i := 0; i < 10; i++ { var got []int originBB := randomBox(rnd, 0.9, 0.1) t.Logf("origin: %v", originBB) - rt.PrioritySearch(originBB, func(recordID int) error { + _ = rt.PrioritySearch(originBB, func(recordID int) error { got = append(got, recordID) return nil }) @@ -79,10 +79,10 @@ func TestPrioritySearchEarlyStop(t *testing.T) { boxes[i] = randomBox(rnd, 0.9, 0.1) } - inserts := make([]BulkItem, len(boxes)) + inserts := make([]BulkItem[int], len(boxes)) for i := range inserts { inserts[i].Box = boxes[i] - inserts[i].RecordID = i + inserts[i].Record = i } rt := BulkLoad(inserts) origin := randomBox(rnd, 0.9, 0.1) diff --git a/rtree/perf_internal_test.go b/rtree/perf_internal_test.go index e9f6bf36..664c5b0e 100644 --- a/rtree/perf_internal_test.go +++ b/rtree/perf_internal_test.go @@ -13,10 +13,10 @@ func BenchmarkBulk(b *testing.B) { for i := range boxes { boxes[i] = randomBox(rnd, 0.9, 0.1) } - inserts := make([]BulkItem, len(boxes)) + inserts := make([]BulkItem[int], len(boxes)) for i := range inserts { inserts[i].Box = boxes[i] - inserts[i].RecordID = i + inserts[i].Record = i } b.Run(fmt.Sprintf("n=%d", pop), func(b *testing.B) { for i := 0; i < b.N; i++ { @@ -33,7 +33,7 @@ func BenchmarkRangeSearch(b *testing.B) { tree, _ := testBulkLoad(rnd, pop) b.ResetTimer() for i := 0; i < b.N; i++ { - tree.RangeSearch(Box{0.5, 0.5, 0.5, 0.5}, func(int) error { return nil }) + _ = tree.RangeSearch(Box{0.5, 0.5, 0.5, 0.5}, func(int) error { return nil }) } }) } diff --git a/rtree/quick_partition_internal_test.go b/rtree/quick_partition_internal_test.go index 625065a8..50116ea7 100644 --- a/rtree/quick_partition_internal_test.go +++ b/rtree/quick_partition_internal_test.go @@ -45,10 +45,10 @@ func TestQuickPartition(t *testing.T) { t.Run(strconv.Itoa(i), func(t *testing.T) { for k := range tc { t.Run(fmt.Sprintf("k=%d", k), func(t *testing.T) { - items := make([]BulkItem, 0, len(tc)) + items := make([]BulkItem[int], 0, len(tc)) for _, num := range tc { f := float64(num) - items = append(items, BulkItem{ + items = append(items, BulkItem[int]{ Box{f, f, f, f}, len(items), }) diff --git a/rtree/rtree.go b/rtree/rtree.go index d1fc92a9..48af0247 100644 --- a/rtree/rtree.go +++ b/rtree/rtree.go @@ -9,29 +9,26 @@ const ( maxEntries = 4 ) -// node is a node in an R-Tree, holding user record IDs and/or links to deeper +// node is a node in an R-Tree, holding user records and/or links to deeper // nodes in the tree. -type node struct { - entries [maxEntries]entry +type node[T any] struct { + entries [maxEntries]entry[T] numEntries int } // entry is an entry contained inside a node. An entry can either hold a user -// record ID, or point to a deeper node in the tree (but not both). Because 0 -// is a valid record ID, the child pointer should be used to distinguish -// between the two types of entries. -type entry struct { - box Box - child *node - recordID int +// record, or point to a deeper node in the tree (but not both). The child +// pointer should be used to distinguish between the two types of entries. +type entry[T any] struct { + box Box + child *node[T] + record T } -// RTree is an in-memory R-Tree data structure. It holds record ID and bounding -// box pairs (the actual records aren't stored in the tree; the user is -// responsible for storing their own records). Its zero value is an empty -// R-Tree. -type RTree struct { - root *node +// RTree is an in-memory R-Tree data structure. It holds records of type T +// along with their bounding boxes. Its zero value is an empty R-Tree. +type RTree[T any] struct { + root *node[T] count int } @@ -40,24 +37,24 @@ type RTree struct { var Stop = errors.New("stop") //nolint:stylecheck,revive // RangeSearch looks for any items in the tree that overlap with the given -// bounding box. The callback is called with the record ID for each found item. -// If an error is returned from the callback then the search is terminated -// early. Any error returned from the callback is returned by RangeSearch, -// except for the case where the special Stop sentinel error is returned (in -// which case nil will be returned from RangeSearch). Stop may be wrapped. -func (t *RTree) RangeSearch(box Box, callback func(recordID int) error) error { +// bounding box. The callback is called with each found item's record. If an +// error is returned from the callback then the search is terminated early. +// Any error returned from the callback is returned by RangeSearch, except for +// the case where the special Stop sentinel error is returned (in which case +// nil will be returned from RangeSearch). Stop may be wrapped. +func (t *RTree[T]) RangeSearch(box Box, callback func(record T) error) error { if t.root == nil { return nil } - var recurse func(*node) error - recurse = func(n *node) error { + var recurse func(*node[T]) error + recurse = func(n *node[T]) error { for i := 0; i < n.numEntries; i++ { entry := n.entries[i] if !overlap(entry.box, box) { continue } if entry.child == nil { - if err := callback(entry.recordID); errors.Is(err, Stop) { + if err := callback(entry.record); errors.Is(err, Stop) { return nil } else if err != nil { return err @@ -75,7 +72,7 @@ func (t *RTree) RangeSearch(box Box, callback func(recordID int) error) error { // Extent gives the Box that most closely bounds the RTree. If the RTree is // empty, then false is returned. -func (t *RTree) Extent() (Box, bool) { +func (t *RTree[T]) Extent() (Box, bool) { if t.root == nil || t.root.numEntries == 0 { return Box{}, false } @@ -83,6 +80,6 @@ func (t *RTree) Extent() (Box, bool) { } // Count gives the number of entries in the RTree. -func (t *RTree) Count() int { +func (t *RTree[T]) Count() int { return t.count } diff --git a/rtree/rtree_internal_test.go b/rtree/rtree_internal_test.go index 5b39b1bb..1afb2f92 100644 --- a/rtree/rtree_internal_test.go +++ b/rtree/rtree_internal_test.go @@ -9,7 +9,7 @@ import ( "testing" ) -func testBulkLoad(rnd *rand.Rand, pop int) (*RTree, []Box) { +func testBulkLoad(rnd *rand.Rand, pop int) (*RTree[int], []Box) { boxes := make([]Box, pop) seenX := make(map[float64]bool) seenY := make(map[float64]bool) @@ -27,10 +27,10 @@ func testBulkLoad(rnd *rand.Rand, pop int) (*RTree, []Box) { } boxes[i] = box } - inserts := make([]BulkItem, len(boxes)) + inserts := make([]BulkItem[int], len(boxes)) for i := range inserts { inserts[i].Box = boxes[i] - inserts[i].RecordID = i + inserts[i].Record = i } return BulkLoad(inserts), boxes } @@ -57,12 +57,12 @@ func TestRandom(t *testing.T) { } } -func checkSearch(t *testing.T, rt *RTree, boxes []Box, rnd *rand.Rand) { +func checkSearch(t *testing.T, rt *RTree[int], boxes []Box, rnd *rand.Rand) { t.Helper() for i := 0; i < 10; i++ { searchBB := randomBox(rnd, 0.5, 0.5) var got []int - rt.RangeSearch(searchBB, func(idx int) error { + _ = rt.RangeSearch(searchBB, func(idx int) error { got = append(got, idx) return nil }) @@ -99,16 +99,16 @@ func randomBox(rnd *rand.Rand, maxStart, maxWidth float64) Box { return box } -func checkInvariants(t *testing.T, rt *RTree, boxes []Box) { +func checkInvariants(t *testing.T, rt *RTree[int], boxes []Box) { t.Helper() - var recurse func(*node, string) - recurse = func(current *node, indent string) { + var recurse func(*node[int], string) + recurse = func(current *node[int], indent string) { t.Logf("%sNode addr=%p numEntries=%d", indent, current, current.numEntries) indent += "\t" for i := 0; i < current.numEntries; i++ { e := current.entries[i] if e.child == nil { - t.Logf("%sEntry[%d] recordID=%d box=%v", indent, i, e.recordID, e.box) + t.Logf("%sEntry[%d] recordID=%d box=%v", indent, i, e.record, e.box) } else { t.Logf("%sEntry[%d] box=%v", indent, i, e.box) recurse(e.child, indent+"\t") @@ -134,20 +134,20 @@ func checkInvariants(t *testing.T, rt *RTree, boxes []Box) { minLeafLevel := math.MaxInt maxLeafLevel := math.MinInt - var check func(n *node, level int) - check = func(current *node, level int) { + var check func(n *node[int], level int) + check = func(current *node[int], level int) { for i := 0; i < current.numEntries; i++ { e := current.entries[i] if e.child == nil { minLeafLevel = minInt(minLeafLevel, level) maxLeafLevel = maxInt(maxLeafLevel, level) - if _, ok := unfound[e.recordID]; !ok { + if _, ok := unfound[e.record]; !ok { t.Fatal("record ID found in tree but wasn't in unfound map") } - delete(unfound, e.recordID) + delete(unfound, e.record) } else { - if e.recordID != 0 { - t.Fatal("non-leaf has recordID") + if e.record != 0 { + t.Fatal("non-leaf has record") } box := e.child.entries[0].box for j := 1; j < e.child.numEntries; j++ { @@ -161,7 +161,7 @@ func checkInvariants(t *testing.T, rt *RTree, boxes []Box) { } for i := current.numEntries; i < len(current.entries); i++ { e := current.entries[i] - if e != (entry{}) { + if e != (entry[int]{}) { t.Fatal("entry past numEntries is not the zero value") } } From 5fb24077eb5503befc45e04a6dec2e0460b73778 Mon Sep 17 00:00:00 2001 From: Peter Stace Date: Fri, 21 Nov 2025 10:24:59 +1100 Subject: [PATCH 2/5] Fix logical merge conflict in CHANGELOG.md --- CHANGELOG.md | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 716ec35c..c0e5ca0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,19 @@ # Changelog +## Unreleased + +- **Breaking change:** The `rtree` package types and functions are now generic + over the record type. The `RTree` type is now `RTree[T]`, `BulkItem` is now + `BulkItem[T]`, and `BulkLoad` is now `BulkLoad[T]`. The `RecordID int` field + in `BulkItem` has been renamed to `Record T`. This allows users to store + their records directly in the tree rather than maintaining separate mappings + between integer IDs and records. Users can upgrade by adding type parameters + to their rtree usage (e.g., `RTree[int]` to maintain existing behavior with + integer IDs, or use a custom type like `RTree[MyRecord]` to store records + directly). The `RecordID` field in `BulkItem` should be renamed to `Record`, + and callback function signatures should change from `func(recordID int)` to + `func(record T)` where `T` is the type parameter. + ## v0.56.0 2025-11-21 @@ -14,18 +28,6 @@ This includes function parameters, return types, struct fields, and type assertions. -- **Breaking change:** The `rtree` package types and functions are now generic - over the record type. The `RTree` type is now `RTree[T]`, `BulkItem` is now - `BulkItem[T]`, and `BulkLoad` is now `BulkLoad[T]`. The `RecordID int` field - in `BulkItem` has been renamed to `Record T`. This allows users to store - their records directly in the tree rather than maintaining separate mappings - between integer IDs and records. Users can upgrade by adding type parameters - to their rtree usage (e.g., `RTree[int]` to maintain existing behavior with - integer IDs, or use a custom type like `RTree[MyRecord]` to store records - directly). The `RecordID` field in `BulkItem` should be renamed to `Record`, - and callback function signatures should change from `func(recordID int)` to - `func(record T)` where `T` is the type parameter. - - **Breaking change:** The minimum required Go version is now 1.18 (previously 1.17). This is required to support the `any` keyword and generics. From 65541f2e261cb6284edc1aab8e3acd6d2b895005 Mon Sep 17 00:00:00 2001 From: Peter Stace Date: Fri, 21 Nov 2025 10:49:46 +1100 Subject: [PATCH 3/5] Use any instead of interface{} --- geom/type_null_geometry_test.go | 2 +- rtree/nearest.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/geom/type_null_geometry_test.go b/geom/type_null_geometry_test.go index a6d444f7..ca91fe6d 100644 --- a/geom/type_null_geometry_test.go +++ b/geom/type_null_geometry_test.go @@ -13,7 +13,7 @@ func TestNullGeometryScan(t *testing.T) { for _, tc := range []struct { description string - value interface{} + value any wantValid bool wantWKT string }{ diff --git a/rtree/nearest.go b/rtree/nearest.go index b4bf661c..e4dc8e9b 100644 --- a/rtree/nearest.go +++ b/rtree/nearest.go @@ -73,11 +73,11 @@ func (q *entriesQueue[T]) Swap(i int, j int) { q.entries[i], q.entries[j] = q.entries[j], q.entries[i] } -func (q *entriesQueue[T]) Push(x interface{}) { +func (q *entriesQueue[T]) Push(x any) { q.entries = append(q.entries, x.(*entry[T])) } -func (q *entriesQueue[T]) Pop() interface{} { +func (q *entriesQueue[T]) Pop() any { e := q.entries[len(q.entries)-1] q.entries = q.entries[:len(q.entries)-1] return e From e4801d6c826e461d3f089a6236c5eeefbfeb7254 Mon Sep 17 00:00:00 2001 From: Peter Stace Date: Fri, 21 Nov 2025 14:08:14 +1100 Subject: [PATCH 4/5] Use generic RTree (WIP) --- geom/alg_distance.go | 103 ++++++++++++++------------------- geom/rtree.go | 24 ++++++++ geom/type_multi_line_string.go | 29 ++++------ 3 files changed, 78 insertions(+), 78 deletions(-) diff --git a/geom/alg_distance.go b/geom/alg_distance.go index 5b3ea83a..6add7e93 100644 --- a/geom/alg_distance.go +++ b/geom/alg_distance.go @@ -37,60 +37,42 @@ func Distance(g1, g2 Geometry) (float64, bool) { lns1, lns2 = lns2, lns1 } - tr := loadTree(xys2, lns2) + xyTree := loadXYTree(xys2) + lnTree := loadLineTree(lns2) minDist := math.Inf(+1) - searchBody := func( - env Envelope, - recordID int, - xyDist func(int) float64, - lnDist func(int) float64, - ) error { - // Convert recordID back to array indexes. - xyIdx := recordID - 1 - lnIdx := -recordID - 1 - - // Abort the search if we're gone further away compared to our best - // distance so far. - var recordEnv Envelope - if recordID > 0 { - recordEnv = xys2[xyIdx].uncheckedEnvelope() - } else { - recordEnv = lns2[lnIdx].uncheckedEnvelope() - } - if d, ok := recordEnv.Distance(env); ok && d > minDist { - return rtree.Stop - } - - // See if the current item in the tree is better than our current best - // distance. - if recordID > 0 { - minDist = fastMin(minDist, xyDist(xyIdx)) - } else { - minDist = fastMin(minDist, lnDist(lnIdx)) - } - return nil - } - for _, xy := range xys1 { - xyEnv := xy.uncheckedEnvelope() - _ = tr.PrioritySearch(xy.box(), func(recordID int) error { - return searchBody( - xyEnv, - recordID, - func(i int) float64 { return distBetweenXYs(xy, xys2[i]) }, - func(i int) float64 { return distBetweenXYAndLine(xy, lns2[i]) }, - ) + for _, xy1 := range xys1 { + xy1Env := xy1.uncheckedEnvelope() + _ = xyTree.PrioritySearch(xy1.box(), func(xy2 XY) error { + if d, ok := xy2.uncheckedEnvelope().Distance(xy1Env); ok && d > minDist { + return rtree.Stop + } + minDist = fastMin(minDist, distBetweenXYs(xy1, xy2)) + return nil + }) + _ = lnTree.PrioritySearch(xy1.box(), func(ln2 line) error { + if d, ok := ln2.uncheckedEnvelope().Distance(xy1Env); ok && d > minDist { + return rtree.Stop + } + minDist = fastMin(minDist, distBetweenXYAndLine(xy1, ln2)) + return nil }) } - for _, ln := range lns1 { - lnEnv := ln.uncheckedEnvelope() - _ = tr.PrioritySearch(ln.box(), func(recordID int) error { - return searchBody( - lnEnv, - recordID, - func(i int) float64 { return distBetweenXYAndLine(xys2[i], ln) }, - func(i int) float64 { return distBetweenLineAndLine(lns2[i], ln) }, - ) + for _, ln1 := range lns1 { + ln1Env := ln1.uncheckedEnvelope() + _ = xyTree.PrioritySearch(ln1.box(), func(xy2 XY) error { + if d, ok := xy2.uncheckedEnvelope().Distance(ln1Env); ok && d > minDist { + return rtree.Stop + } + minDist = fastMin(minDist, distBetweenXYAndLine(xy2, ln1)) + return nil + }) + _ = lnTree.PrioritySearch(ln1.box(), func(ln2 line) error { + if d, ok := ln2.uncheckedEnvelope().Distance(ln1Env); ok && d > minDist { + return rtree.Stop + } + minDist = fastMin(minDist, distBetweenLineAndLine(ln1, ln2)) + return nil }) } @@ -128,22 +110,23 @@ func extractXYsAndLines(g Geometry) ([]XY, []line) { } } -// loadTree creates a new RTree that indexes both the XYs and the lines. It -// uses positive record IDs to refer to the XYs, and negative recordIDs to -// refer to the lines. Because +0 and -0 are the same, indexing is 1-based and -// recordID 0 is not used. -func loadTree(xys []XY, lns []line) *rtree.RTree[int] { - items := make([]rtree.BulkItem[int], len(xys)+len(lns)) +func loadXYTree(xys []XY) *rtree.RTree[XY] { + items := make([]rtree.BulkItem[XY], len(xys)) for i, xy := range xys { - items[i] = rtree.BulkItem[int]{ + items[i] = rtree.BulkItem[XY]{ Box: xy.box(), - Record: i + 1, + Record: xy, } } + return rtree.BulkLoad(items) +} + +func loadLineTree(lns []line) *rtree.RTree[line] { + items := make([]rtree.BulkItem[line], len(lns)) for i, ln := range lns { - items[i+len(xys)] = rtree.BulkItem[int]{ + items[i] = rtree.BulkItem[line]{ Box: ln.box(), - Record: -(i + 1), + Record: ln, } } return rtree.BulkLoad(items) diff --git a/geom/rtree.go b/geom/rtree.go index fbdad7f8..568c23fd 100644 --- a/geom/rtree.go +++ b/geom/rtree.go @@ -2,6 +2,30 @@ package geom import "github.com/peterstace/simplefeatures/rtree" +// TODO: Use this instead of indexedLines/Points where possible. +func newLineRTree(lines []line) *rtree.RTree[line] { + items := make([]rtree.BulkItem[line], len(lines)) + for i, ln := range lines { + items[i] = rtree.BulkItem[line]{ + Box: ln.box(), + Record: ln, + } + } + return rtree.BulkLoad(items) +} + +// TODO: Use this instead of indexedLines/Points where possible. +func newPointRTree(points []XY) *rtree.RTree[XY] { + items := make([]rtree.BulkItem[XY], len(points)) + for i, pt := range points { + items[i] = rtree.BulkItem[XY]{ + Box: pt.box(), + Record: pt, + } + } + return rtree.BulkLoad(items) +} + // indexedLines is a simple container to hold a list of lines, and a r-tree // structure indexing those lines. The record IDs in the rtree correspond to // the indices of the lines slice. diff --git a/geom/type_multi_line_string.go b/geom/type_multi_line_string.go index 2ec952d5..3b3ded7b 100644 --- a/geom/type_multi_line_string.go +++ b/geom/type_multi_line_string.go @@ -110,22 +110,16 @@ func (m MultiLineString) IsSimple() bool { } } - // Map between record ID in the rtree and a particular line segment: - toRecordID := func(lineStringIdx, seqIdx int) int { - return int(uint64(lineStringIdx)<<32 | uint64(seqIdx)) - } - fromRecordID := func(recordID int) (lineStringIdx, seqIdx int) { - lineStringIdx = int(uint64(recordID) >> 32) - seqIdx = int((uint64(recordID) << 32) >> 32) - return - } - // Create an RTree containing all line segments. + type record struct { + lineStringIdx int + seqIdx int + } var numItems int for _, ls := range m.lines { numItems += maxInt(0, ls.Coordinates().Length()-1) } - items := make([]rtree.BulkItem[int], 0, numItems) + items := make([]rtree.BulkItem[record], 0, numItems) for i, ls := range m.lines { seq := ls.Coordinates() seqLen := seq.Length() @@ -134,9 +128,9 @@ func (m MultiLineString) IsSimple() bool { if !ok { continue } - items = append(items, rtree.BulkItem[int]{ + items = append(items, rtree.BulkItem[record]{ Box: ln.box(), - Record: toRecordID(i, j), + Record: record{lineStringIdx: i, seqIdx: j}, }) } } @@ -151,15 +145,14 @@ func (m MultiLineString) IsSimple() bool { continue } isSimple := true // assume simple until proven otherwise - _ = tree.RangeSearch(ln.box(), func(recordID int) error { + _ = tree.RangeSearch(ln.box(), func(rec record) error { // Ignore the intersection if it's for the same LineString that we're currently looking up. - lineStringIdx, seqIdx := fromRecordID(recordID) - if lineStringIdx == i { + if rec.lineStringIdx == i { return nil } - otherLS := m.lines[lineStringIdx] - other, ok := getLine(otherLS.Coordinates(), seqIdx) + otherLS := m.lines[rec.lineStringIdx] + other, ok := getLine(otherLS.Coordinates(), rec.seqIdx) if !ok { // Shouldn't even happen, since we were able to insert this // entry into the RTree. From e55fa7c380204d73ad89a1010a406c96d70bc56d Mon Sep 17 00:00:00 2001 From: Peter Stace Date: Fri, 21 Nov 2025 14:12:43 +1100 Subject: [PATCH 5/5] Add nolint directives --- geom/rtree.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/geom/rtree.go b/geom/rtree.go index 568c23fd..5080afd9 100644 --- a/geom/rtree.go +++ b/geom/rtree.go @@ -3,7 +3,7 @@ package geom import "github.com/peterstace/simplefeatures/rtree" // TODO: Use this instead of indexedLines/Points where possible. -func newLineRTree(lines []line) *rtree.RTree[line] { +func newLineRTree(lines []line) *rtree.RTree[line] { //nolint:unused items := make([]rtree.BulkItem[line], len(lines)) for i, ln := range lines { items[i] = rtree.BulkItem[line]{ @@ -15,7 +15,7 @@ func newLineRTree(lines []line) *rtree.RTree[line] { } // TODO: Use this instead of indexedLines/Points where possible. -func newPointRTree(points []XY) *rtree.RTree[XY] { +func newPointRTree(points []XY) *rtree.RTree[XY] { //nolint:unused items := make([]rtree.BulkItem[XY], len(points)) for i, pt := range points { items[i] = rtree.BulkItem[XY]{