From 9a772282dfc1634cf362716b125c21ca69025fd8 Mon Sep 17 00:00:00 2001 From: Rafael Schouten Date: Tue, 1 Apr 2025 23:41:35 +0200 Subject: [PATCH 1/4] add PermutedDimsArray method that returns PermutedDiskArray --- src/permute.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/permute.jl b/src/permute.jl index b092f00..502bc01 100644 --- a/src/permute.jl +++ b/src/permute.jl @@ -43,7 +43,7 @@ _getperm(::PermutedDimsArray{<:Any,<:Any,perm}) where {perm} = perm _getiperm(a::PermutedDiskArray) = _getiperm(a.a) _getiperm(::PermutedDimsArray{<:Any,<:Any,<:Any,iperm}) where {iperm} = iperm -# Implementaion macros +# Implementation macros function permutedims_disk(a, perm) pd = PermutedDimsArray(a, perm) @@ -54,5 +54,9 @@ macro implement_permutedims(t) t = esc(t) quote Base.permutedims(parent::$t, perm) = permutedims_disk(parent, perm) + # It's not correct to return a PermutedDiskArray from the PermutedDimsArray constructor. + # Instead we need a Base julia method that behaves like view for SubArray, such as `lazypermutedims`. + # But until that exists this is better than returning a broken disk array. + Base.PermutedDimsArray(parent::$t, perm) = permutedims_disk(parent, perm) end end From 17d266f5057a1e0575e4e628a2e709a0f73fdf92 Mon Sep 17 00:00:00 2001 From: Rafael Schouten Date: Wed, 2 Apr 2025 12:37:44 +0200 Subject: [PATCH 2/4] bugfix recursion --- src/permute.jl | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/src/permute.jl b/src/permute.jl index 502bc01..3fc6fa6 100644 --- a/src/permute.jl +++ b/src/permute.jl @@ -3,21 +3,30 @@ A lazily permuted disk array returned by `permutedims(diskarray, permutation)`. """ -struct PermutedDiskArray{T,N,P<:PermutedDimsArray{T,N}} <: AbstractDiskArray{T,N} +struct PermutedDiskArray{T,N,perm,iperm,A<:AbstractArray{T,N}} <: AbstractDiskArray{T,N} a::P end +PermutedDiskArray(A::AbstractArray, perm::Union{Tuple,AbstractVector}) = + PermutedDiskArray(A, PermutedDimsArray(CartesianIndices(A), perm)) +# We use PermutedDimsArray internals instead of duplicating them, +# and just copy the type parameters +function PermutedDiskArray( + a::A, perm::PermutedDimsArray{<:Any,<:Any,perm,iperm} +) where {A<:AbstractArray{T,N},perm,iperm} where {T,N} = + PermutedDiskArray{T,N,perm,iperm,A}(a) +end # Base methods -Base.size(a::PermutedDiskArray) = size(a.a) +Base.size(a::PermutedDiskArray) = genperm(size(parent(a)), _getperm(a)) # DiskArrays interface -haschunks(a::PermutedDiskArray) = haschunks(a.a.parent) +haschunks(a::PermutedDiskArray) = haschunks(parent(a)) function eachchunk(a::PermutedDiskArray) # Get the parent chunks - gridchunks = eachchunk(a.a.parent) - perm = _getperm(a.a) + gridchunks = eachchunk(parent(a)) + perm = _getperm(a) # Return permuted GridChunks return GridChunks(genperm(gridchunks.chunks, perm)...) end @@ -37,26 +46,18 @@ function DiskArrays.writeblock!(a::PermutedDiskArray, v, i::OrdinalRange...) return nothing end -_getperm(a::PermutedDiskArray) = _getperm(a.a) -_getperm(::PermutedDimsArray{<:Any,<:Any,perm}) where {perm} = perm - -_getiperm(a::PermutedDiskArray) = _getiperm(a.a) -_getiperm(::PermutedDimsArray{<:Any,<:Any,<:Any,iperm}) where {iperm} = iperm +_getperm(::PermutedDiskArray{<:Any,<:Any,perm}) where {perm} = perm +_getiperm(::PermutedDiskArray{<:Any,<:Any,<:Any,iperm}) where {iperm} = iperm -# Implementation macros - -function permutedims_disk(a, perm) - pd = PermutedDimsArray(a, perm) - return PermutedDiskArray{eltype(a),ndims(a),typeof(pd)}(pd) -end +# Implementation macro macro implement_permutedims(t) t = esc(t) quote - Base.permutedims(parent::$t, perm) = permutedims_disk(parent, perm) + Base.permutedims(parent::$t, perm) = PermutedDiskArray(parent, perm) # It's not correct to return a PermutedDiskArray from the PermutedDimsArray constructor. # Instead we need a Base julia method that behaves like view for SubArray, such as `lazypermutedims`. # But until that exists this is better than returning a broken disk array. - Base.PermutedDimsArray(parent::$t, perm) = permutedims_disk(parent, perm) + Base.PermutedDimsArray(parent::$t, perm) = PermutedDiskArray(parent, perm) end end From 36305866c5eeb0007a92a80033c48c411bff6a8f Mon Sep 17 00:00:00 2001 From: Rafael Schouten Date: Wed, 2 Apr 2025 12:39:23 +0200 Subject: [PATCH 3/4] A type param --- src/permute.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/permute.jl b/src/permute.jl index 3fc6fa6..2e06376 100644 --- a/src/permute.jl +++ b/src/permute.jl @@ -4,7 +4,7 @@ A lazily permuted disk array returned by `permutedims(diskarray, permutation)`. """ struct PermutedDiskArray{T,N,perm,iperm,A<:AbstractArray{T,N}} <: AbstractDiskArray{T,N} - a::P + a::A end PermutedDiskArray(A::AbstractArray, perm::Union{Tuple,AbstractVector}) = PermutedDiskArray(A, PermutedDimsArray(CartesianIndices(A), perm)) From 4841be1b7434c46f13d3c87a42d3a3e6b92edc87 Mon Sep 17 00:00:00 2001 From: Rafael Schouten Date: Wed, 2 Apr 2025 13:55:35 +0200 Subject: [PATCH 4/4] bugfix parent and ConstructionBase --- Project.toml | 2 ++ src/DiskArrays.jl | 2 ++ src/permute.jl | 25 +++++++++++++++------ src/reshape.jl | 1 + src/util/testtypes.jl | 52 ++++++++++++++++++++++--------------------- test/runtests.jl | 14 ++++++++---- 6 files changed, 60 insertions(+), 36 deletions(-) diff --git a/Project.toml b/Project.toml index e16fa4d..05f46b0 100644 --- a/Project.toml +++ b/Project.toml @@ -4,12 +4,14 @@ authors = ["Fabian Gans "] version = "0.4.11" [deps] +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" [compat] Aqua = "0.8" +ConstructionBase = "1" LRUCache = "1" Mmap = "1" OffsetArrays = "1" diff --git a/src/DiskArrays.jl b/src/DiskArrays.jl index 733108d..5443bd0 100644 --- a/src/DiskArrays.jl +++ b/src/DiskArrays.jl @@ -1,5 +1,7 @@ module DiskArrays +import ConstructionBase + using LRUCache: LRUCache, LRU # Use the README as the module docs diff --git a/src/permute.jl b/src/permute.jl index 2e06376..9eae918 100644 --- a/src/permute.jl +++ b/src/permute.jl @@ -4,20 +4,31 @@ A lazily permuted disk array returned by `permutedims(diskarray, permutation)`. """ struct PermutedDiskArray{T,N,perm,iperm,A<:AbstractArray{T,N}} <: AbstractDiskArray{T,N} - a::A + parent::A end +# We use PermutedDimsArray internals instead of duplicating them, +# and just copy the type parameters it calculates. PermutedDiskArray(A::AbstractArray, perm::Union{Tuple,AbstractVector}) = PermutedDiskArray(A, PermutedDimsArray(CartesianIndices(A), perm)) -# We use PermutedDimsArray internals instead of duplicating them, -# and just copy the type parameters function PermutedDiskArray( - a::A, perm::PermutedDimsArray{<:Any,<:Any,perm,iperm} -) where {A<:AbstractArray{T,N},perm,iperm} where {T,N} = + a::A, ::PermutedDimsArray{<:Any,<:Any,perm,iperm} +) where {A<:AbstractArray{T,N},perm,iperm} where {T,N} PermutedDiskArray{T,N,perm,iperm,A}(a) end +# We need explicit ConstructionBase support as perm and iperm are only in the type. +# We include N so that only arrays of the same dimensionality can be set with this perm and iperm +struct PermutedDiskArrayConstructor{N,perm,iperm} end + +(::PermutedDiskArrayConstructor{N,perm,iperm})(a::A) where A<:AbstractArray{T,N} where {T,N,perm,iperm} = + PermutedDiskArray{T,N,perm,iperm,A}(a) + +ConstructionBase.constructorof(::Type{<:PermutedDiskArray{<:Any,N,perm,iperm}}) where {N,perm,iperm} = + PermutedDiskArrayConstructor{N,perm,iperm}() + # Base methods +Base.parent(a::PermutedDiskArray) = a.parent Base.size(a::PermutedDiskArray) = genperm(size(parent(a)), _getperm(a)) # DiskArrays interface @@ -35,14 +46,14 @@ function DiskArrays.readblock!(a::PermutedDiskArray, aout, i::OrdinalRange...) # Permute the indices inew = genperm(i, iperm) # Permute the dest block and read from the true parent - DiskArrays.readblock!(a.a.parent, PermutedDimsArray(aout, iperm), inew...) + DiskArrays.readblock!(parent(a), PermutedDimsArray(aout, iperm), inew...) return nothing end function DiskArrays.writeblock!(a::PermutedDiskArray, v, i::OrdinalRange...) iperm = _getiperm(a) inew = genperm(i, iperm) # Permute the dest block and write from the true parent - DiskArrays.writeblock!(a.a.parent, PermutedDimsArray(v, iperm), inew...) + DiskArrays.writeblock!(parent(a), PermutedDimsArray(v, iperm), inew...) return nothing end diff --git a/src/reshape.jl b/src/reshape.jl index fe71d77..58234ee 100644 --- a/src/reshape.jl +++ b/src/reshape.jl @@ -23,6 +23,7 @@ end # Base methods +Base.parent(r::ReshapedDiskArray) = r.parent Base.size(r::ReshapedDiskArray) = r.newsize # DiskArrays interface diff --git a/src/util/testtypes.jl b/src/util/testtypes.jl index 0f3f657..70cfbf5 100644 --- a/src/util/testtypes.jl +++ b/src/util/testtypes.jl @@ -25,7 +25,8 @@ DiskArrays.batchstrategy(a::AccessCountDiskArray) = a.batchstrategy AccessCountDiskArray(a; chunksize=size(a), batchstrategy=DiskArrays.ChunkRead(DiskArrays.NoStepRange(), 0.5)) = AccessCountDiskArray([], [], a, chunksize, batchstrategy) -Base.size(a::AccessCountDiskArray) = size(a.parent) +Base.parent(a::AccessCountDiskArray) = a.parent +Base.size(a::AccessCountDiskArray) = size(parent(a)) # Apply the all in one macro rather than inheriting @@ -38,7 +39,7 @@ function DiskArrays.readblock!(a::AccessCountDiskArray, aout, i::OrdinalRange... end # println("reading from indices ", join(string.(i)," ")) push!(a.getindex_log, i) - return aout .= a.parent[i...] + return aout .= parent(a)[i...] end function DiskArrays.writeblock!(a::AccessCountDiskArray, v, i::OrdinalRange...) ndims(a) == length(i) || error("Number of indices is not correct") @@ -47,31 +48,30 @@ function DiskArrays.writeblock!(a::AccessCountDiskArray, v, i::OrdinalRange...) end # println("Writing to indices ", join(string.(i)," ")) push!(a.setindex_log, i) - return view(a.parent, i...) .= v + return view(parent(a), i...) .= v end getindex_count(a::AccessCountDiskArray) = length(a.getindex_log) setindex_count(a::AccessCountDiskArray) = length(a.setindex_log) getindex_log(a::AccessCountDiskArray) = a.getindex_log setindex_log(a::AccessCountDiskArray) = a.setindex_log -trueparent(a::AccessCountDiskArray) = a.parent - -getindex_count(a::DiskArrays.ReshapedDiskArray) = getindex_count(a.parent) -setindex_count(a::DiskArrays.ReshapedDiskArray) = setindex_count(a.parent) -getindex_log(a::DiskArrays.ReshapedDiskArray) = getindex_log(a.parent) -setindex_log(a::DiskArrays.ReshapedDiskArray) = setindex_log(a.parent) -trueparent(a::DiskArrays.ReshapedDiskArray) = trueparent(a.parent) - -getindex_count(a::DiskArrays.PermutedDiskArray) = getindex_count(a.a.parent) -setindex_count(a::DiskArrays.PermutedDiskArray) = setindex_count(a.a.parent) -getindex_log(a::DiskArrays.PermutedDiskArray) = getindex_log(a.a.parent) -setindex_log(a::DiskArrays.PermutedDiskArray) = setindex_log(a.a.parent) -function trueparent( - a::DiskArrays.PermutedDiskArray{T,N,<:PermutedDimsArray{T,N,perm,iperm}} -) where {T,N,perm,iperm} - return permutedims(trueparent(a.a.parent), perm) +trueparent(a::AccessCountDiskArray) = parent(a) + +getindex_count(a::DiskArrays.AbstractDiskArray) = getindex_count(parent(a)) +setindex_count(a::DiskArrays.AbstractDiskArray) = setindex_count(parent(a)) +getindex_log(a::DiskArrays.AbstractDiskArray) = getindex_log(parent(a)) +setindex_log(a::DiskArrays.AbstractDiskArray) = setindex_log(parent(a)) +function trueparent(a::DiskArrays.AbstractDiskArray) + if parent(a) === a + a + else + trueparent(parent(a)) + end end +trueparent(a::DiskArrays.PermutedDiskArray{T,N,perm,iperm}) where {T,N,perm,iperm} = + permutedims(trueparent(parent(a)), perm) + """ ChunkedDiskArray(A; chunksize) @@ -83,12 +83,13 @@ struct ChunkedDiskArray{T,N,A<:AbstractArray{T,N}} <: DiskArrays.AbstractDiskArr end ChunkedDiskArray(a; chunksize=size(a)) = ChunkedDiskArray(a, chunksize) -Base.size(a::ChunkedDiskArray) = size(a.parent) +Base.parent(a::ChunkedDiskArray) = a.parent +Base.size(a::ChunkedDiskArray) = size(parent(a)) DiskArrays.haschunks(::ChunkedDiskArray) = DiskArrays.Chunked() DiskArrays.eachchunk(a::ChunkedDiskArray) = DiskArrays.GridChunks(a, a.chunksize) -DiskArrays.readblock!(a::ChunkedDiskArray, aout, i::AbstractUnitRange...) = aout .= a.parent[i...] -DiskArrays.writeblock!(a::ChunkedDiskArray, v, i::AbstractUnitRange...) = view(a.parent, i...) .= v +DiskArrays.readblock!(a::ChunkedDiskArray, aout, i::AbstractUnitRange...) = aout .= parent(a)[i...] +DiskArrays.writeblock!(a::ChunkedDiskArray, v, i::AbstractUnitRange...) = view(parent(a), i...) .= v """ UnchunkedDiskArray(A) @@ -96,16 +97,17 @@ DiskArrays.writeblock!(a::ChunkedDiskArray, v, i::AbstractUnitRange...) = view(a A disk array without chunking, that can wrap any other `AbstractArray`. """ struct UnchunkedDiskArray{T,N,P<:AbstractArray{T,N}} <: DiskArrays.AbstractDiskArray{T,N} - p::P + parent::P end -Base.size(a::UnchunkedDiskArray) = size(a.p) +Base.parent(a::UnchunkedDiskArray) = a.parent +Base.size(a::UnchunkedDiskArray) = size(parent(a)) DiskArrays.haschunks(::UnchunkedDiskArray) = DiskArrays.Unchunked() function DiskArrays.readblock!(a::UnchunkedDiskArray, aout, i::AbstractUnitRange...) ndims(a) == length(i) || error("Number of indices is not correct") all(r -> isa(r, AbstractUnitRange), i) || error("Not all indices are unit ranges") - return aout .= a.p[i...] + return aout .= parent(a)[i...] end end diff --git a/test/runtests.jl b/test/runtests.jl index 689ec3a..736ffa2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using DiskArrays.TestTypes using Test using Statistics using Aqua +using ConstructionBase # Run with any code changes # using JET @@ -815,15 +816,20 @@ import Base.PermutedDimsArrays.invperm ip = invperm(p) a = permutedims(AccessCountDiskArray(permutedims(reshape(1:20, 4, 5, 1), ip)), p) test_getindex(a) - a = permutedims(AccessCountDiskArray(zeros(Int, 5, 1, 4)), p) + a = PermutedDimsArray(AccessCountDiskArray(zeros(Int, 5, 1, 4)), p) test_setindex(a) a = permutedims(AccessCountDiskArray(zeros(Int, 5, 1, 4)), p) test_view(a) - a = data -> permutedims(AccessCountDiskArray(permutedims(data, ip); chunksize=(4, 2, 5)), p) - test_reductions(a) + f = data -> permutedims(AccessCountDiskArray(permutedims(data, ip); chunksize=(4, 2, 5)), p) + test_reductions(f) a_disk1 = permutedims(AccessCountDiskArray(rand(9, 2, 10); chunksize=(3, 2, 5)), p) test_broadcast(a_disk1) - @test PermutedDiskArray(a_disk1.a) === a_disk1 + + @testset "ConstructionBase works on PermutedDiskArray" begin + v = ones(Int, 10, 2, 2) + av = ConstructionBase.setproperties(a, (; parent=v)) + @test parent(av) === v + end end @testset "Unchunked String arrays" begin