Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/host/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,14 @@ end
function Base._unsafe_setindex!(::IndexStyle, A::WrappedGPUArray, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
end
# And allow one more `ReshapedArray` wrapper to handle the `_maybe_reshape` optimization.
function Base._unsafe_setindex!(::IndexStyle, A::Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray}, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N

#Implementation for ReshapedArrays using Cartesian indexing to resolve dispatch ties.
function Base._unsafe_setindex!(::Base.IndexCartesian, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, M}) where {T, N, M}
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
end

#Implementation for ReshapedArrays using Linear indexing to resolve dispatch ties.
function Base._unsafe_setindex!(::Base.IndexLinear, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, M}) where {T, N, M}
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
end

Expand Down
70 changes: 70 additions & 0 deletions test/testsuite/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,73 @@ end
@test compare(argmin, AT, -rand(Int, 10))
end
end

@testsuite "indexing combinatorial" (AT, eltypes) -> begin
@testset "Reshaped SubArray dispatch" for T in eltypes
@testset "3D slice assignment" begin
A = AT(ones(T, 4, 4, 4))
@views V = A[:, :, 1:2]
@allowscalar begin
@test_nowarn V .= zero(T)
@test all(Array(V) .== zero(T))
end
end

@testset "Logical mask view (dim = 3) — GPU safe" begin
A = AT(ones(T, 4, 4, 4))
idx = findall(Bool[true, false, true, false])
@views V = A[:, :, idx]
@allowscalar begin
@test_nowarn V .+= T(2)
@test all(Array(V) .== T(3))
end
end

@testset "Nested Reshape" begin
A = AT(ones(T, 4, 4, 4))
V = view(A, 1:2, 1:2, 1:2)
R1 = reshape(V, 4, 2)
R2 = reshape(R1, :)
@allowscalar begin
@test_nowarn R2 .+= one(T)
@test all(Array(R2) .== T(2))
end
end
end

@testset "Permuted and Reinterpreted Views" for T in eltypes
@testset "Reshaped PermutedDims" begin
A = AT(ones(T, 4, 4))
P = PermutedDimsArray(A, (2, 1))
R = reshape(P, :)
@allowscalar begin
@test_nowarn R[1:2] .= zero(T)
# Check the full assigned range.
@test all(Array(R)[1:2] .== zero(T))
end
end

@testset "Reshaped Reinterpreted" begin
T_base = real(T)
if T <: Complex
A = AT(ones(T, 4, 4))
IT = Complex{Int16}
R = reshape(reinterpret(IT, A), :)
@allowscalar begin
@test_nowarn R[1:2] .= zero(IT)
@test all(Array(R)[1:2] .== zero(IT))
end
end
end
end

@testset "Data parity with compare() — GPU safe" for T in eltypes
idx = 2:4
@test compare(AT, rand(T, 8, 8, 8)) do A
# compare() handles CPU/GPU execution no @allowscalar needed here
V = view(A, :, idx, :)
V .+= one(T)
A
end
end
end
Loading