Skip to content

Commit 98af42a

Browse files
committed
More tests, simplify wrapper code
1 parent 6d2e03c commit 98af42a

File tree

5 files changed

+294
-87
lines changed

5 files changed

+294
-87
lines changed

src/abstractsparsearrayinterface.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,20 @@ function storedlength end
1414
function storedpairs end
1515
function storedvalues end
1616

17+
# Generic functionality for converting to a
18+
# dense array, trying to preserve information
19+
# about the array (such as which device it is on).
20+
# TODO: Maybe call `densecopy`?
21+
# TODO: Make sure this actually preserves the device,
22+
# maybe use `TypeParameterAccessors.unwrap_array_type`.
23+
# TODO: Turn into an `@interface` function.
24+
function densearray(a::AbstractArray)
25+
# TODO: `set_ndims(unwrap_array_type(a), ndims(a))(a)`
26+
# Maybe define `densetype(a) = set_ndims(unwrap_array_type(a), ndims(a))`.
27+
# Or could use `unspecify_parameters(unwrap_array_type(a))(a)`.
28+
return Array(a)
29+
end
30+
1731
# Minimal interface for `SparseArrayInterface`.
1832
# Fallbacks for dense/non-sparse arrays.
1933
@interface ::AbstractArrayInterface isstored(a::AbstractArray, I::Int...) = true
@@ -32,8 +46,8 @@ end
3246
@interface ::AbstractArrayInterface function setunstoredindex!(
3347
a::AbstractArray, value, I::Int...
3448
)
35-
setindex!(a, value, I...)
36-
return a
49+
# TODO: Make this a `MethodError`?
50+
return error("Not implemented.")
3751
end
3852

3953
# TODO: Use `Base.to_indices`?
@@ -116,10 +130,17 @@ end
116130

117131
@interface ::AbstractSparseArrayInterface storedvalues(a::AbstractArray) = StoredValues(a)
118132

119-
@interface ::AbstractSparseArrayInterface function eachstoredindex(as::AbstractArray...)
133+
@interface ::AbstractSparseArrayInterface function eachstoredindex(
134+
a1::AbstractArray, a2::AbstractArray, a_rest::AbstractArray...
135+
)
120136
# TODO: Make this more customizable, say with a function
121137
# `combine/promote_storedindices(a1, a2)`.
122-
return union(eachstoredindex.(as)...)
138+
return union(eachstoredindex.((a1, a2, a_rest...))...)
139+
end
140+
141+
@interface ::AbstractSparseArrayInterface function eachstoredindex(a::AbstractArray)
142+
# TODO: Use `MethodError`?
143+
return error("Not implemented.")
123144
end
124145

125146
# We restrict to `I::Vararg{Int,N}` to allow more general functions to handle trailing

src/wrappers.jl

Lines changed: 96 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
parentvalue_to_value(a::AbstractArray, value) = value
2+
value_to_parentvalue(a::AbstractArray, value) = value
3+
eachstoredparentindex(a::AbstractArray) = eachstoredindex(parent(a))
4+
storedparentvalues(a::AbstractArray) = storedvalues(parent(a))
5+
parentindex_to_index(a::AbstractArray, I::CartesianIndex) = error()
6+
function parentindex_to_index(a::AbstractArray, I::Int...)
7+
return Tuple(parentindex_to_index(a, CartesianIndex(I)))
8+
end
9+
index_to_parentindex(a::AbstractArray, I::CartesianIndex) = error()
10+
function index_to_parentindex(a::AbstractArray, I::Int...)
11+
return Tuple(index_to_parentindex(a, CartesianIndex(I)))
12+
end
13+
114
function cartesianindex_reverse(I::CartesianIndex)
215
return CartesianIndex(reverse(Tuple(I)))
316
end
@@ -7,115 +20,117 @@ tuple_oneto(n) = ntuple(identity, n)
720
# https://github.com/jipolanco/StaticPermutations.jl?
821
genperm(v, perm) = map(j -> v[j], perm)
922

10-
## TODO: Use this and something similar for `Dictionary` to make a faster
11-
## implementation of `storedvalues(::SubArray)`.
12-
## function valuesview(d::Dict, keys)
13-
## return @view d.vals[[Base.ht_keyindex(d, key) for key in keys]]
14-
## end
23+
using LinearAlgebra: Adjoint
24+
function parentindex_to_index(a::Adjoint, I::CartesianIndex)
25+
return cartesianindex_reverse(I)
26+
end
27+
function index_to_parentindex(a::Adjoint, I::CartesianIndex)
28+
return cartesianindex_reverse(I)
29+
end
30+
function parentvalue_to_value(a::Adjoint, value)
31+
return adjoint(value)
32+
end
33+
function value_to_parentvalue(a::Adjoint, value)
34+
return adjoint(value)
35+
end
36+
37+
perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p
38+
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip
39+
function index_to_parentindex(a::PermutedDimsArray, I::CartesianIndex)
40+
return CartesianIndex(genperm(I, iperm(a)))
41+
end
42+
function parentindex_to_index(a::PermutedDimsArray, I::CartesianIndex)
43+
return CartesianIndex(genperm(I, perm(a)))
44+
end
45+
46+
using Base: ReshapedArray
47+
function parentindex_to_index(a::ReshapedArray, I::CartesianIndex)
48+
return CartesianIndices(size(a))[LinearIndices(parent(a))[I]]
49+
end
50+
function index_to_parentindex(a::ReshapedArray, I::CartesianIndex)
51+
return CartesianIndices(parent(a))[LinearIndices(size(a))[I]]
52+
end
1553

1654
function eachstoredparentindex(a::SubArray)
1755
return filter(eachstoredindex(parent(a))) do I
1856
return all(d -> I[d] parentindices(a)[d], 1:ndims(parent(a)))
1957
end
2058
end
21-
@interface ::AbstractSparseArrayInterface function storedvalues(a::SubArray)
59+
function index_to_parentindex(a::SubArray, I::CartesianIndex)
60+
return CartesianIndex(Base.reindex(parentindices(a), Tuple(I)))
61+
end
62+
function parentindex_to_index(a::SubArray, I::CartesianIndex)
63+
nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d
64+
return !(parentindices(a)[d] isa Real)
65+
end
66+
return CartesianIndex(
67+
map(nonscalardims) do d
68+
return findfirst(==(I[d]), parentindices(a)[d])
69+
end,
70+
)
71+
end
72+
## TODO: Use this and something similar for `Dictionary` to make a faster
73+
## implementation of `storedvalues(::SubArray)`.
74+
## function valuesview(d::Dict, keys)
75+
## return @view d.vals[[Base.ht_keyindex(d, key) for key in keys]]
76+
## end
77+
function storedparentvalues(a::SubArray)
2278
# We use `StoredValues` rather than `@view`/`SubArray` so that
2379
# it gets interpreted as a dense array.
2480
return StoredValues(parent(a), collect(eachstoredparentindex(a)))
2581
end
26-
@interface ::AbstractSparseArrayInterface function isstored(a::SubArray, I::Int...)
27-
return isstored(parent(a), Base.reindex(parentindices(a), I)...)
82+
83+
using LinearAlgebra: Transpose
84+
function parentindex_to_index(a::Transpose, I::CartesianIndex)
85+
return cartesianindex_reverse(I)
2886
end
29-
@interface ::AbstractSparseArrayInterface function getstoredindex(a::SubArray, I::Int...)
30-
return getstoredindex(parent(a), Base.reindex(parentindices(a), I)...)
87+
function index_to_parentindex(a::Transpose, I::CartesianIndex)
88+
return cartesianindex_reverse(I)
3189
end
32-
@interface ::AbstractSparseArrayInterface function getunstoredindex(a::SubArray, I::Int...)
33-
return getunstoredindex(parent(a), Base.reindex(parentindices(a), I)...)
90+
function parentvalue_to_value(a::Transpose, value)
91+
return transpose(value)
3492
end
35-
@interface ::AbstractSparseArrayInterface function eachstoredindex(a::SubArray)
36-
nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d
37-
return !(parentindices(a)[d] isa Real)
38-
end
39-
return collect((
40-
CartesianIndex(
41-
map(nonscalardims) do d
42-
return findfirst(==(I[d]), parentindices(a)[d])
43-
end,
44-
) for I in eachstoredparentindex(a)
45-
))
46-
end
47-
48-
perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p
49-
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip
50-
51-
@interface ::AbstractSparseArrayInterface storedvalues(a::PermutedDimsArray) =
52-
storedvalues(parent(a))
53-
@interface ::AbstractSparseArrayInterface function isstored(a::PermutedDimsArray, I::Int...)
54-
return isstored(parent(a), genperm(I, iperm(a))...)
55-
end
56-
@interface ::AbstractSparseArrayInterface function getstoredindex(
57-
a::PermutedDimsArray, I::Int...
58-
)
59-
return getstoredindex(parent(a), genperm(I, iperm(a))...)
60-
end
61-
@interface ::AbstractSparseArrayInterface function getunstoredindex(
62-
a::PermutedDimsArray, I::Int...
63-
)
64-
return getunstoredindex(parent(a), genperm(I, iperm(a))...)
65-
end
66-
@interface ::AbstractSparseArrayInterface function setstoredindex!(
67-
a::PermutedDimsArray, value, I::Int...
68-
)
69-
# TODO: Should this be `iperm(a)`?
70-
setstoredindex!(parent(a), value, genperm(I, perm(a))...)
71-
return a
72-
end
73-
@interface ::AbstractSparseArrayInterface function setunstoredindex!(
74-
a::PermutedDimsArray, value, I::Int...
75-
)
76-
# TODO: Should this be `iperm(a)`?
77-
setunstoredindex!(parent(a), value, genperm(I, perm(a))...)
78-
return a
79-
end
80-
@interface ::AbstractSparseArrayInterface function eachstoredindex(a::PermutedDimsArray)
81-
# TODO: Make lazy with `Iterators.map`.
82-
return map(collect(eachstoredindex(parent(a)))) do I
83-
return CartesianIndex(genperm(I, perm(a)))
84-
end
93+
function value_to_parentvalue(a::Transpose, value)
94+
return transpose(value)
8595
end
8696

87-
for (type, func) in ((:Adjoint, :adjoint), (:Transpose, :transpose))
97+
# TODO: Turn these into `AbstractWrappedSparseArrayInterface` functions?
98+
for type in (:Adjoint, :PermutedDimsArray, :ReshapedArray, :SubArray, :Transpose)
8899
@eval begin
89-
using LinearAlgebra: $type
90-
@interface ::AbstractSparseArrayInterface storedvalues(a::$type) =
91-
storedvalues(parent(a))
92-
@interface ::AbstractSparseArrayInterface function isstored(a::$type, i::Int, j::Int)
93-
return isstored(parent(a), j, i)
100+
@interface ::AbstractSparseArrayInterface storedvalues(a::$type) = storedparentvalues(a)
101+
@interface ::AbstractSparseArrayInterface function isstored(a::$type, I::Int...)
102+
return isstored(parent(a), index_to_parentindex(a, I...)...)
94103
end
95104
@interface ::AbstractSparseArrayInterface function eachstoredindex(a::$type)
96105
# TODO: Make lazy with `Iterators.map`.
97-
return map(cartesianindex_reverse, collect(eachstoredindex(parent(a))))
106+
return map(collect(eachstoredparentindex(a))) do I
107+
return parentindex_to_index(a, I)
108+
end
98109
end
99-
@interface ::AbstractSparseArrayInterface function getstoredindex(
100-
a::$type, i::Int, j::Int
101-
)
102-
return $func(getstoredindex(parent(a), j, i))
110+
@interface ::AbstractSparseArrayInterface function getstoredindex(a::$type, I::Int...)
111+
return parentvalue_to_value(
112+
a, getstoredindex(parent(a), index_to_parentindex(a, I...)...)
113+
)
103114
end
104-
@interface ::AbstractSparseArrayInterface function getunstoredindex(
105-
a::$type, i::Int, j::Int
106-
)
107-
return $func(getunstoredindex(parent(a), j, i))
115+
@interface ::AbstractSparseArrayInterface function getunstoredindex(a::$type, I::Int...)
116+
return parentvalue_to_value(
117+
a, getunstoredindex(parent(a), index_to_parentindex(a, I...)...)
118+
)
108119
end
109120
@interface ::AbstractSparseArrayInterface function setstoredindex!(
110-
a::$type, value, i::Int, j::Int
121+
a::$type, value, I::Int...
111122
)
112-
setstoredindex!(parent(a), $func(value), j, i)
123+
setstoredindex!(
124+
parent(a), value_to_parentvalue(a, value), index_to_parentindex(a, I...)...
125+
)
113126
return a
114127
end
115128
@interface ::AbstractSparseArrayInterface function setunstoredindex!(
116-
a::$type, value, i::Int, j::Int
129+
a::$type, value, I::Int...
117130
)
118-
setunstoredindex!(parent(a), $func(value), j, i)
131+
setunstoredindex!(
132+
parent(a), value_to_parentvalue(a, value), index_to_parentindex(a, I...)...
133+
)
119134
return a
120135
end
121136
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
33
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4+
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
45
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
56
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
67
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"

test/basics/test_basics.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ arrayts = (Array, JLArray)
5757
for I in ((1, 2), (CartesianIndex(1, 2),))
5858
b = copy(a)
5959
value = randn(elt)
60-
@allowscalar setunstoredindex!(b, value, I...)
61-
@allowscalar b[I...] == value
60+
@test_throws ErrorException setunstoredindex!(b, value, I...)
6261
end
6362
end

0 commit comments

Comments
 (0)