Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
name = "FunctionImplementations"
uuid = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.3.0"
version = "0.3.1"

[weakdeps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[extensions]
FunctionImplementationsBlockArraysExt = "BlockArrays"
FunctionImplementationsFillArraysExt = "FillArrays"
FunctionImplementationsLinearAlgebraExt = "LinearAlgebra"

[compat]
BlockArrays = "1.4"
FillArrays = "1.15"
LinearAlgebra = "1.10"
julia = "1.10"

Expand Down
2 changes: 1 addition & 1 deletion docs/src/reference.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Reference

```@autodocs
Modules = [FunctionImplementations]
Modules = [FunctionImplementations, FunctionImplementations.Concatenate]
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module FunctionImplementationsBlockArraysExt

using BlockArrays: AbstractBlockedUnitRange, blockedrange, blocklengths
using FunctionImplementations.Concatenate: Concatenate

function Concatenate.cat_axis(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
first(a1) == first(a2) == 1 || throw(ArgumentError("Concatenated axes must start at 1"))
return blockedrange([blocklengths(a1); blocklengths(a2)])
end

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module FunctionImplementationsFillArraysExt

using FillArrays: RectDiagonal
using FunctionImplementations: FunctionImplementations

function FunctionImplementations.permuteddims(a::RectDiagonal, perm)
(ndims(a) == length(perm) && isperm(perm)) ||
throw(ArgumentError("no valid permutation of dimensions"))
return RectDiagonal(parent(a), ntuple(d -> axes(a)[perm[d]], ndims(a)))
end

end
2 changes: 2 additions & 0 deletions src/FunctionImplementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@ module FunctionImplementations
include("implementation.jl")
include("style.jl")
include("permuteddims.jl")
include("zero.jl")
include("concatenate.jl")

end
225 changes: 225 additions & 0 deletions src/concatenate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
"""
module Concatenate

Alternative implementation for `Base.cat` through `Concatenate.cat(!)`.

This is mostly a copy of the Base implementation, with the main difference being
that the destination is chosen based on all inputs instead of just the first.

Additionally, we have an intermediate representation in terms of a Concatenated object,
reminiscent of how Broadcast works.

The various entry points for specializing behavior are:

* Destination selection can be achieved through:

```julia
Base.similar(concat::Concatenated{Style}, ::Type{T}, axes) where {Style}
```

* Custom implementations:

```julia
Base.copy(concat::Concatenated{Style}) # custom implementation of cat
Base.copyto!(dest, concat::Concatenated{Style}) # custom implementation of cat! based on style
Base.copyto!(dest, concat::Concatenated{Nothing}) # custom implementation of cat! based on typeof(dest)
```
"""
module Concatenate

export concatenate
VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public Concatenated, cat, cat!, concatenated"))

using Base: promote_eltypeof
import Base.Broadcast as BC
using ..FunctionImplementations: zero!

unval(::Val{x}) where {x} = x

function _Concatenated end

"""
Concatenated{Style, Dims, Args <: Tuple}

Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide
hooks to customize the implementation.
"""
struct Concatenated{Style, Dims, Args <: Tuple}
style::Style
dims::Val{Dims}
args::Args
global @inline function _Concatenated(
style::Style, dims::Val{Dims}, args::Args
) where {Style, Dims, Args <: Tuple}
return new{Style, Dims, Args}(style, dims, args)
end
end

function Concatenated(
style::Union{BC.AbstractArrayStyle, Nothing}, dims::Val, args::Tuple
)
return _Concatenated(style, dims, args)
end
function Concatenated(dims::Val, args::Tuple)
return Concatenated(cat_style(dims, args...), dims, args)
end
function Concatenated{Style}(
dims::Val, args::Tuple
) where {Style <: Union{BC.AbstractArrayStyle, Nothing}}
return Concatenated(Style(), dims, args)
end

dims(::Concatenated{<:Any, D}) where {D} = D
style(concat::Concatenated) = getfield(concat, :style)

concatenated(dims, args...) = concatenated(Val(dims), args...)
concatenated(dims::Val, args...) = Concatenated(dims, args)

function Base.convert(
::Type{Concatenated{NewStyle}}, concat::Concatenated{<:Any, Dims, Args}
) where {NewStyle, Dims, Args}
return Concatenated{NewStyle}(
concat.dims, concat.args
)::Concatenated{NewStyle, Dims, Args}
end

# allocating the destination container
# ------------------------------------
Base.similar(concat::Concatenated) = similar(concat, eltype(concat))
Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat))
function Base.similar(concat::Concatenated, ax)
return similar(concat, eltype(concat), ax)
end

function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T}
# Convert to a broadcasted to leverage its similar implementation.
bc = BC.Broadcasted(style(concat), identity, concat.args, ax)
return similar(bc, T)
end

function cat_axis(
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
)
return cat_axis(cat_axis(a1, a2), a_rest...)
end
function cat_axis(a1::AbstractUnitRange, a2::AbstractUnitRange)
first(a1) == first(a2) == 1 || throw(ArgumentError("Concatenated axes must start at 1"))
return Base.OneTo(length(a1) + length(a2))
end

function cat_ndims(dims, as::AbstractArray...)
return max(maximum(dims), maximum(ndims, as))
end
function cat_ndims(dims::Val, as::AbstractArray...)
return cat_ndims(unval(dims), as...)
end

function cat_axes(dims, a::AbstractArray, as::AbstractArray...)
return ntuple(cat_ndims(dims, a, as...)) do dim
return dim in dims ? cat_axis(map(Base.Fix2(axes, dim), (a, as...))...) : axes(a, dim)
end
end
function cat_axes(dims::Val, as::AbstractArray...)
return cat_axes(unval(dims), as...)
end

function cat_style(dims, as::AbstractArray...)
N = cat_ndims(dims, as...)
return typeof(BC.combine_styles(as...))(Val(N))
end

Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...)
Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...)
Base.size(concat::Concatenated) = length.(axes(concat))
Base.ndims(concat::Concatenated) = cat_ndims(dims(concat), concat.args...)

# Main logic
# ----------
"""
concatenate(dims, args...)

Concatenate the supplied `args` along dimensions `dims`.

See also [`cat`](@ref) and [`cat!`](@ref).
"""
concatenate(dims, args...) = Base.materialize(concatenated(dims, args...))

"""
Concatenate.cat(args...; dims)

Concatenate the supplied `args` along dimensions `dims`.

See also [`concatenate`](@ref) and [`cat!`](@ref).
"""
cat(args...; dims) = concatenate(dims, args...)
Base.materialize(concat::Concatenated) = copy(concat)

"""
Concatenate.cat!(dest, args...; dims)

Concatenate the supplied `args` along dimensions `dims`, placing the result into `dest`.
"""
function cat!(dest, args...; dims)
Base.materialize!(dest, concatenated(dims, args...))
return dest
end
Base.materialize!(dest, concat::Concatenated) = copyto!(dest, concat)

Base.copy(concat::Concatenated) = copyto!(similar(concat), concat)

# The following is largely copied from the Base implementation of `Base.cat`, see:
# https://github.com/JuliaLang/julia/blob/885b1cd875f101f227b345f681cc36879124d80d/base/abstractarray.jl#L1778-L1887
_copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x)
_copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x)

cat_size(A) = (1,)
cat_size(A::AbstractArray) = size(A)
cat_size(A, d) = 1
cat_size(A::AbstractArray, d) = size(A, d)

cat_indices(A, d) = Base.OneTo(1)
cat_indices(A::AbstractArray, d) = axes(A, d)

function __cat!(A, shape, catdims, X...)
return __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...)
end
function __cat_offset!(A, shape, catdims, offsets, x, X...)
# splitting the "work" on x from X... may reduce latency (fewer costly specializations)
newoffsets = __cat_offset1!(A, shape, catdims, offsets, x)
return __cat_offset!(A, shape, catdims, newoffsets, X...)
end
__cat_offset!(A, shape, catdims, offsets) = A
function __cat_offset1!(A, shape, catdims, offsets, x)
inds = ntuple(length(offsets)) do i
(i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i]
end
_copy_or_fill!(A, inds, x)
newoffsets = ntuple(length(offsets)) do i
(i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i]
end
return newoffsets
end

dims2cat(dims::Val) = dims2cat(unval(dims))
function dims2cat(dims)
if any(≤(0), dims)
throw(ArgumentError("All cat dimensions must be positive integers, but got $dims"))
end
return ntuple(in(dims), maximum(dims))
end

# default falls back to replacing style with Nothing
# this permits specializing on typeof(dest) without ambiguities
# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base.
@inline function Base.copyto!(dest::AbstractArray, concat::Concatenated)
return copyto!(dest, convert(Concatenated{Nothing}, concat))
end

function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing})
catdims = dims2cat(dims(concat))
shape = size(concat)
count(!iszero, catdims)::Int > 1 && zero!(dest)
return __cat!(dest, shape, catdims, concat.args...)
end

end
10 changes: 10 additions & 0 deletions src/zero.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
zero!(a::AbstractArray)

In-place version of `zero(a)`, sets all entries of `a` to zero.
"""
zero!(a::AbstractArray) = style(a)(zero!)(a)
function (::Implementation{typeof(zero!)})(a::AbstractArray)
fill!(a, zero(eltype(a)))
return a
end
8 changes: 8 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Expand All @@ -10,8 +14,12 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
FunctionImplementations = {path = ".."}

[compat]
Adapt = "4"
Aqua = "0.8"
BlockArrays = "1.4"
FillArrays = "1.15"
FunctionImplementations = "0.3"
JLArrays = "0.3"
LinearAlgebra = "1.10"
SafeTestsets = "0.1"
Suppressor = "0.2"
Expand Down
17 changes: 17 additions & 0 deletions test/test_blockarraysext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using BlockArrays: BlockArray, blockedrange, blockisequal
using FunctionImplementations.Concatenate: concatenate
using Test: @test, @testset

@testset "BlockArraysExt" begin
a = BlockArray(randn(4, 4), [2, 2], [2, 2])
b = BlockArray(randn(4, 4), [2, 2], [2, 2])

concat = concatenate(1, a, b)
@test axes(concat) == (Base.OneTo(8), Base.OneTo(4))
@test blockisequal(axes(concat, 1), blockedrange([2, 2, 2, 2]))
@test blockisequal(axes(concat, 2), blockedrange([2, 2]))
@test size(concat) == (8, 4)
@test eltype(concat) ≡ Float64
@test copy(concat) == cat(a, b; dims = 1)
@test copy(concat) isa BlockArray{Float64, 2}
end
38 changes: 38 additions & 0 deletions test/test_concatenate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using Adapt: adapt
using FunctionImplementations.Concatenate: concatenated
using JLArrays: JLArray
using Test: @test, @testset

@testset "Concatenated" for arrayt in (Array, JLArray)
dev = adapt(arrayt)
a = dev(randn(Float32, 2, 2))
b = dev(randn(Float64, 2, 2))

concat = concatenated((1, 2), a, b)
@test axes(concat) == Base.OneTo.((4, 4))
@test size(concat) == (4, 4)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims = (1, 2))
@test copy(concat) isa arrayt{promote_type(eltype(a), eltype(b)), 2}

concat = concatenated(1, a, b)
@test axes(concat) == Base.OneTo.((4, 2))
@test size(concat) == (4, 2)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims = 1)
@test copy(concat) isa arrayt{promote_type(eltype(a), eltype(b)), 2}

concat = concatenated(3, a, b)
@test axes(concat) == Base.OneTo.((2, 2, 2))
@test size(concat) == (2, 2, 2)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims = 3)
@test copy(concat) isa arrayt{promote_type(eltype(a), eltype(b)), 3}

concat = concatenated(4, a, b)
@test axes(concat) == Base.OneTo.((2, 2, 1, 2))
@test size(concat) == (2, 2, 1, 2)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims = 4)
@test copy(concat) isa arrayt{promote_type(eltype(a), eltype(b)), 4}
end
Loading
Loading