diff --git a/Project.toml b/Project.toml index a09a9f8..88fff6e 100644 --- a/Project.toml +++ b/Project.toml @@ -4,12 +4,14 @@ authors = ["Fabian Gans "] version = "0.4.11" [deps] +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" [compat] Aqua = "0.8" +ConstructionBase = "1" LRUCache = "1" Mmap = "1" OffsetArrays = "1" diff --git a/src/DiskArrays.jl b/src/DiskArrays.jl index 733108d..5443bd0 100644 --- a/src/DiskArrays.jl +++ b/src/DiskArrays.jl @@ -1,5 +1,7 @@ module DiskArrays +import ConstructionBase + using LRUCache: LRUCache, LRU # Use the README as the module docs diff --git a/src/permute.jl b/src/permute.jl index b092f00..9eae918 100644 --- a/src/permute.jl +++ b/src/permute.jl @@ -3,21 +3,41 @@ A lazily permuted disk array returned by `permutedims(diskarray, permutation)`. """ -struct PermutedDiskArray{T,N,P<:PermutedDimsArray{T,N}} <: AbstractDiskArray{T,N} - a::P +struct PermutedDiskArray{T,N,perm,iperm,A<:AbstractArray{T,N}} <: AbstractDiskArray{T,N} + parent::A end +# We use PermutedDimsArray internals instead of duplicating them, +# and just copy the type parameters it calculates. +PermutedDiskArray(A::AbstractArray, perm::Union{Tuple,AbstractVector}) = + PermutedDiskArray(A, PermutedDimsArray(CartesianIndices(A), perm)) +function PermutedDiskArray( + a::A, ::PermutedDimsArray{<:Any,<:Any,perm,iperm} +) where {A<:AbstractArray{T,N},perm,iperm} where {T,N} + PermutedDiskArray{T,N,perm,iperm,A}(a) +end + +# We need explicit ConstructionBase support as perm and iperm are only in the type. +# We include N so that only arrays of the same dimensionality can be set with this perm and iperm +struct PermutedDiskArrayConstructor{N,perm,iperm} end + +(::PermutedDiskArrayConstructor{N,perm,iperm})(a::A) where A<:AbstractArray{T,N} where {T,N,perm,iperm} = + PermutedDiskArray{T,N,perm,iperm,A}(a) + +ConstructionBase.constructorof(::Type{<:PermutedDiskArray{<:Any,N,perm,iperm}}) where {N,perm,iperm} = + PermutedDiskArrayConstructor{N,perm,iperm}() # Base methods -Base.size(a::PermutedDiskArray) = size(a.a) +Base.parent(a::PermutedDiskArray) = a.parent +Base.size(a::PermutedDiskArray) = genperm(size(parent(a)), _getperm(a)) # DiskArrays interface -haschunks(a::PermutedDiskArray) = haschunks(a.a.parent) +haschunks(a::PermutedDiskArray) = haschunks(parent(a)) function eachchunk(a::PermutedDiskArray) # Get the parent chunks - gridchunks = eachchunk(a.a.parent) - perm = _getperm(a.a) + gridchunks = eachchunk(parent(a)) + perm = _getperm(a) # Return permuted GridChunks return GridChunks(genperm(gridchunks.chunks, perm)...) end @@ -26,33 +46,29 @@ function DiskArrays.readblock!(a::PermutedDiskArray, aout, i::OrdinalRange...) # Permute the indices inew = genperm(i, iperm) # Permute the dest block and read from the true parent - DiskArrays.readblock!(a.a.parent, PermutedDimsArray(aout, iperm), inew...) + DiskArrays.readblock!(parent(a), PermutedDimsArray(aout, iperm), inew...) return nothing end function DiskArrays.writeblock!(a::PermutedDiskArray, v, i::OrdinalRange...) iperm = _getiperm(a) inew = genperm(i, iperm) # Permute the dest block and write from the true parent - DiskArrays.writeblock!(a.a.parent, PermutedDimsArray(v, iperm), inew...) + DiskArrays.writeblock!(parent(a), PermutedDimsArray(v, iperm), inew...) return nothing end -_getperm(a::PermutedDiskArray) = _getperm(a.a) -_getperm(::PermutedDimsArray{<:Any,<:Any,perm}) where {perm} = perm +_getperm(::PermutedDiskArray{<:Any,<:Any,perm}) where {perm} = perm +_getiperm(::PermutedDiskArray{<:Any,<:Any,<:Any,iperm}) where {iperm} = iperm -_getiperm(a::PermutedDiskArray) = _getiperm(a.a) -_getiperm(::PermutedDimsArray{<:Any,<:Any,<:Any,iperm}) where {iperm} = iperm - -# Implementaion macros - -function permutedims_disk(a, perm) - pd = PermutedDimsArray(a, perm) - return PermutedDiskArray{eltype(a),ndims(a),typeof(pd)}(pd) -end +# Implementation macro macro implement_permutedims(t) t = esc(t) quote - Base.permutedims(parent::$t, perm) = permutedims_disk(parent, perm) + Base.permutedims(parent::$t, perm) = PermutedDiskArray(parent, perm) + # It's not correct to return a PermutedDiskArray from the PermutedDimsArray constructor. + # Instead we need a Base julia method that behaves like view for SubArray, such as `lazypermutedims`. + # But until that exists this is better than returning a broken disk array. + Base.PermutedDimsArray(parent::$t, perm) = PermutedDiskArray(parent, perm) end end diff --git a/src/reshape.jl b/src/reshape.jl index fe71d77..58234ee 100644 --- a/src/reshape.jl +++ b/src/reshape.jl @@ -23,6 +23,7 @@ end # Base methods +Base.parent(r::ReshapedDiskArray) = r.parent Base.size(r::ReshapedDiskArray) = r.newsize # DiskArrays interface diff --git a/src/util/testtypes.jl b/src/util/testtypes.jl index 0f3f657..70cfbf5 100644 --- a/src/util/testtypes.jl +++ b/src/util/testtypes.jl @@ -25,7 +25,8 @@ DiskArrays.batchstrategy(a::AccessCountDiskArray) = a.batchstrategy AccessCountDiskArray(a; chunksize=size(a), batchstrategy=DiskArrays.ChunkRead(DiskArrays.NoStepRange(), 0.5)) = AccessCountDiskArray([], [], a, chunksize, batchstrategy) -Base.size(a::AccessCountDiskArray) = size(a.parent) +Base.parent(a::AccessCountDiskArray) = a.parent +Base.size(a::AccessCountDiskArray) = size(parent(a)) # Apply the all in one macro rather than inheriting @@ -38,7 +39,7 @@ function DiskArrays.readblock!(a::AccessCountDiskArray, aout, i::OrdinalRange... end # println("reading from indices ", join(string.(i)," ")) push!(a.getindex_log, i) - return aout .= a.parent[i...] + return aout .= parent(a)[i...] end function DiskArrays.writeblock!(a::AccessCountDiskArray, v, i::OrdinalRange...) ndims(a) == length(i) || error("Number of indices is not correct") @@ -47,31 +48,30 @@ function DiskArrays.writeblock!(a::AccessCountDiskArray, v, i::OrdinalRange...) end # println("Writing to indices ", join(string.(i)," ")) push!(a.setindex_log, i) - return view(a.parent, i...) .= v + return view(parent(a), i...) .= v end getindex_count(a::AccessCountDiskArray) = length(a.getindex_log) setindex_count(a::AccessCountDiskArray) = length(a.setindex_log) getindex_log(a::AccessCountDiskArray) = a.getindex_log setindex_log(a::AccessCountDiskArray) = a.setindex_log -trueparent(a::AccessCountDiskArray) = a.parent - -getindex_count(a::DiskArrays.ReshapedDiskArray) = getindex_count(a.parent) -setindex_count(a::DiskArrays.ReshapedDiskArray) = setindex_count(a.parent) -getindex_log(a::DiskArrays.ReshapedDiskArray) = getindex_log(a.parent) -setindex_log(a::DiskArrays.ReshapedDiskArray) = setindex_log(a.parent) -trueparent(a::DiskArrays.ReshapedDiskArray) = trueparent(a.parent) - -getindex_count(a::DiskArrays.PermutedDiskArray) = getindex_count(a.a.parent) -setindex_count(a::DiskArrays.PermutedDiskArray) = setindex_count(a.a.parent) -getindex_log(a::DiskArrays.PermutedDiskArray) = getindex_log(a.a.parent) -setindex_log(a::DiskArrays.PermutedDiskArray) = setindex_log(a.a.parent) -function trueparent( - a::DiskArrays.PermutedDiskArray{T,N,<:PermutedDimsArray{T,N,perm,iperm}} -) where {T,N,perm,iperm} - return permutedims(trueparent(a.a.parent), perm) +trueparent(a::AccessCountDiskArray) = parent(a) + +getindex_count(a::DiskArrays.AbstractDiskArray) = getindex_count(parent(a)) +setindex_count(a::DiskArrays.AbstractDiskArray) = setindex_count(parent(a)) +getindex_log(a::DiskArrays.AbstractDiskArray) = getindex_log(parent(a)) +setindex_log(a::DiskArrays.AbstractDiskArray) = setindex_log(parent(a)) +function trueparent(a::DiskArrays.AbstractDiskArray) + if parent(a) === a + a + else + trueparent(parent(a)) + end end +trueparent(a::DiskArrays.PermutedDiskArray{T,N,perm,iperm}) where {T,N,perm,iperm} = + permutedims(trueparent(parent(a)), perm) + """ ChunkedDiskArray(A; chunksize) @@ -83,12 +83,13 @@ struct ChunkedDiskArray{T,N,A<:AbstractArray{T,N}} <: DiskArrays.AbstractDiskArr end ChunkedDiskArray(a; chunksize=size(a)) = ChunkedDiskArray(a, chunksize) -Base.size(a::ChunkedDiskArray) = size(a.parent) +Base.parent(a::ChunkedDiskArray) = a.parent +Base.size(a::ChunkedDiskArray) = size(parent(a)) DiskArrays.haschunks(::ChunkedDiskArray) = DiskArrays.Chunked() DiskArrays.eachchunk(a::ChunkedDiskArray) = DiskArrays.GridChunks(a, a.chunksize) -DiskArrays.readblock!(a::ChunkedDiskArray, aout, i::AbstractUnitRange...) = aout .= a.parent[i...] -DiskArrays.writeblock!(a::ChunkedDiskArray, v, i::AbstractUnitRange...) = view(a.parent, i...) .= v +DiskArrays.readblock!(a::ChunkedDiskArray, aout, i::AbstractUnitRange...) = aout .= parent(a)[i...] +DiskArrays.writeblock!(a::ChunkedDiskArray, v, i::AbstractUnitRange...) = view(parent(a), i...) .= v """ UnchunkedDiskArray(A) @@ -96,16 +97,17 @@ DiskArrays.writeblock!(a::ChunkedDiskArray, v, i::AbstractUnitRange...) = view(a A disk array without chunking, that can wrap any other `AbstractArray`. """ struct UnchunkedDiskArray{T,N,P<:AbstractArray{T,N}} <: DiskArrays.AbstractDiskArray{T,N} - p::P + parent::P end -Base.size(a::UnchunkedDiskArray) = size(a.p) +Base.parent(a::UnchunkedDiskArray) = a.parent +Base.size(a::UnchunkedDiskArray) = size(parent(a)) DiskArrays.haschunks(::UnchunkedDiskArray) = DiskArrays.Unchunked() function DiskArrays.readblock!(a::UnchunkedDiskArray, aout, i::AbstractUnitRange...) ndims(a) == length(i) || error("Number of indices is not correct") all(r -> isa(r, AbstractUnitRange), i) || error("Not all indices are unit ranges") - return aout .= a.p[i...] + return aout .= parent(a)[i...] end end diff --git a/test/runtests.jl b/test/runtests.jl index 8ecefa2..a15ca9b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using DiskArrays.TestTypes using Test using Statistics using Aqua +using ConstructionBase using TraceFuns, Suppressor # Run with any code changes @@ -816,15 +817,20 @@ import Base.PermutedDimsArrays.invperm ip = invperm(p) a = permutedims(AccessCountDiskArray(permutedims(reshape(1:20, 4, 5, 1), ip)), p) test_getindex(a) - a = permutedims(AccessCountDiskArray(zeros(Int, 5, 1, 4)), p) + a = PermutedDimsArray(AccessCountDiskArray(zeros(Int, 5, 1, 4)), p) test_setindex(a) a = permutedims(AccessCountDiskArray(zeros(Int, 5, 1, 4)), p) test_view(a) - a = data -> permutedims(AccessCountDiskArray(permutedims(data, ip); chunksize=(4, 2, 5)), p) - test_reductions(a) + f = data -> permutedims(AccessCountDiskArray(permutedims(data, ip); chunksize=(4, 2, 5)), p) + test_reductions(f) a_disk1 = permutedims(AccessCountDiskArray(rand(9, 2, 10); chunksize=(3, 2, 5)), p) test_broadcast(a_disk1) - @test PermutedDiskArray(a_disk1.a) === a_disk1 + + @testset "ConstructionBase works on PermutedDiskArray" begin + v = ones(Int, 10, 2, 2) + av = ConstructionBase.setproperties(a, (; parent=v)) + @test parent(av) === v + end end @testset "Unchunked String arrays" begin