Skip to content

Commit 6d2e03c

Browse files
committed
Make sparse array minimal interface into interface functions
1 parent 6e1a946 commit 6d2e03c

File tree

3 files changed

+97
-39
lines changed

3 files changed

+97
-39
lines changed

src/abstractsparsearrayinterface.jl

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,37 @@
1+
using Derive: Derive, @derive, @interface, AbstractArrayInterface
2+
13
# This is to bring `ArrayLayouts.zero!` into the namespace
24
# since it is considered part of the sparse array interface.
35
using ArrayLayouts: zero!
46

7+
function eachstoredindex end
8+
function getstoredindex end
9+
function getunstoredindex end
10+
function isstored end
11+
function setstoredindex! end
12+
function setunstoredindex! end
13+
function storedlength end
14+
function storedpairs end
15+
function storedvalues end
16+
517
# Minimal interface for `SparseArrayInterface`.
6-
isstored(a::AbstractArray, I::Int...) = true
7-
eachstoredindex(a::AbstractArray) = eachindex(a)
8-
getstoredindex(a::AbstractArray, I::Int...) = getindex(a, I...)
9-
function setstoredindex!(a::AbstractArray, value, I::Int...)
18+
# Fallbacks for dense/non-sparse arrays.
19+
@interface ::AbstractArrayInterface isstored(a::AbstractArray, I::Int...) = true
20+
@interface ::AbstractArrayInterface eachstoredindex(a::AbstractArray) = eachindex(a)
21+
@interface ::AbstractArrayInterface getstoredindex(a::AbstractArray, I::Int...) =
22+
getindex(a, I...)
23+
@interface ::AbstractArrayInterface function setstoredindex!(
24+
a::AbstractArray, value, I::Int...
25+
)
1026
setindex!(a, value, I...)
1127
return a
1228
end
1329
# TODO: Should this error by default if the value at the index
1430
# is stored? It could be disabled with something analogous
1531
# to `checkbounds`, like `checkstored`/`checkunstored`.
16-
function setunstoredindex!(a::AbstractArray, value, I::Int...)
32+
@interface ::AbstractArrayInterface function setunstoredindex!(
33+
a::AbstractArray, value, I::Int...
34+
)
1735
setindex!(a, value, I...)
1836
return a
1937
end
@@ -36,11 +54,41 @@ end
3654
# Interface defaults.
3755
# TODO: Have a fallback that handles element types
3856
# that don't define `zero(::Type)`.
39-
getunstoredindex(a::AbstractArray, I::Int...) = zero(eltype(a))
57+
@interface ::AbstractArrayInterface getunstoredindex(a::AbstractArray, I::Int...) =
58+
zero(eltype(a))
4059

4160
# Derived interface.
42-
storedlength(a::AbstractArray) = length(storedvalues(a))
43-
storedpairs(a::AbstractArray) = map(I -> I => getstoredindex(a, I), eachstoredindex(a))
61+
@interface ::AbstractArrayInterface storedlength(a::AbstractArray) = length(storedvalues(a))
62+
@interface ::AbstractArrayInterface storedpairs(a::AbstractArray) =
63+
map(I -> I => getstoredindex(a, I), eachstoredindex(a))
64+
65+
@interface ::AbstractArrayInterface function eachstoredindex(as::AbstractArray...)
66+
return eachindex(as...)
67+
end
68+
69+
@interface ::AbstractArrayInterface storedvalues(a::AbstractArray) = a
70+
71+
# Automatically derive the interface for all `AbstractArray` subtypes.
72+
# TODO: Define `SparseArrayInterfaceOps` derivable trait and rewrite this
73+
# as `@derive AbstractArray SparseArrayInterfaceOps`.
74+
@derive (T=AbstractArray,) begin
75+
SparseArraysBase.eachstoredindex(::T)
76+
SparseArraysBase.eachstoredindex(::T...)
77+
SparseArraysBase.getstoredindex(::T, ::Int...)
78+
SparseArraysBase.getunstoredindex(::T, ::Int...)
79+
SparseArraysBase.isstored(::T, ::Int...)
80+
SparseArraysBase.setstoredindex!(::T, ::Any, ::Int...)
81+
SparseArraysBase.setunstoredindex!(::T, ::Any, ::Int...)
82+
SparseArraysBase.storedlength(::T)
83+
SparseArraysBase.storedpairs(::T)
84+
SparseArraysBase.storedvalues(::T)
85+
end
86+
87+
# TODO: Add `ndims` type parameter, like `Base.Broadcast.AbstractArrayStyle`.
88+
# TODO: This isn't used to define interface functions right now.
89+
# Currently, `@interface` expects an instance, probably it should take a
90+
# type instead so fallback functions can use abstract types.
91+
abstract type AbstractSparseArrayInterface <: AbstractArrayInterface end
4492

4593
to_vec(x) = vec(collect(x))
4694
to_vec(x::AbstractArray) = vec(x)
@@ -66,22 +114,14 @@ function Base.setindex!(a::StoredValues, value, I::Int)
66114
return setstoredindex!(a.array, value, a.storedindices[I])
67115
end
68116

69-
storedvalues(a::AbstractArray) = StoredValues(a)
117+
@interface ::AbstractSparseArrayInterface storedvalues(a::AbstractArray) = StoredValues(a)
70118

71-
function eachstoredindex(a1, a2, a_rest...)
119+
@interface ::AbstractSparseArrayInterface function eachstoredindex(as::AbstractArray...)
72120
# TODO: Make this more customizable, say with a function
73121
# `combine/promote_storedindices(a1, a2)`.
74-
return union(eachstoredindex.((a1, a2, a_rest...))...)
122+
return union(eachstoredindex.(as)...)
75123
end
76124

77-
using Derive: Derive, @derive, @interface, AbstractArrayInterface
78-
79-
# TODO: Add `ndims` type parameter.
80-
# TODO: This isn't used to define interface functions right now.
81-
# Currently, `@interface` expects an instance, probably it should take a
82-
# type instead so fallback functions can use abstract types.
83-
abstract type AbstractSparseArrayInterface <: AbstractArrayInterface end
84-
85125
# We restrict to `I::Vararg{Int,N}` to allow more general functions to handle trailing
86126
# indices and linear indices.
87127
@interface ::AbstractSparseArrayInterface function Base.getindex(

src/wrappers.jl

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,21 @@ function eachstoredparentindex(a::SubArray)
1818
return all(d -> I[d] parentindices(a)[d], 1:ndims(parent(a)))
1919
end
2020
end
21-
function storedvalues(a::SubArray)
21+
@interface ::AbstractSparseArrayInterface function storedvalues(a::SubArray)
2222
# We use `StoredValues` rather than `@view`/`SubArray` so that
2323
# it gets interpreted as a dense array.
2424
return StoredValues(parent(a), collect(eachstoredparentindex(a)))
2525
end
26-
function isstored(a::SubArray, I::Int...)
26+
@interface ::AbstractSparseArrayInterface function isstored(a::SubArray, I::Int...)
2727
return isstored(parent(a), Base.reindex(parentindices(a), I)...)
2828
end
29-
function getstoredindex(a::SubArray, I::Int...)
29+
@interface ::AbstractSparseArrayInterface function getstoredindex(a::SubArray, I::Int...)
3030
return getstoredindex(parent(a), Base.reindex(parentindices(a), I)...)
3131
end
32-
function getunstoredindex(a::SubArray, I::Int...)
32+
@interface ::AbstractSparseArrayInterface function getunstoredindex(a::SubArray, I::Int...)
3333
return getunstoredindex(parent(a), Base.reindex(parentindices(a), I)...)
3434
end
35-
function eachstoredindex(a::SubArray)
35+
@interface ::AbstractSparseArrayInterface function eachstoredindex(a::SubArray)
3636
nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d
3737
return !(parentindices(a)[d] isa Real)
3838
end
@@ -48,27 +48,36 @@ end
4848
perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p
4949
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip
5050

51-
storedvalues(a::PermutedDimsArray) = storedvalues(parent(a))
52-
function isstored(a::PermutedDimsArray, I::Int...)
51+
@interface ::AbstractSparseArrayInterface storedvalues(a::PermutedDimsArray) =
52+
storedvalues(parent(a))
53+
@interface ::AbstractSparseArrayInterface function isstored(a::PermutedDimsArray, I::Int...)
5354
return isstored(parent(a), genperm(I, iperm(a))...)
5455
end
55-
function getstoredindex(a::PermutedDimsArray, I::Int...)
56+
@interface ::AbstractSparseArrayInterface function getstoredindex(
57+
a::PermutedDimsArray, I::Int...
58+
)
5659
return getstoredindex(parent(a), genperm(I, iperm(a))...)
5760
end
58-
function getunstoredindex(a::PermutedDimsArray, I::Int...)
61+
@interface ::AbstractSparseArrayInterface function getunstoredindex(
62+
a::PermutedDimsArray, I::Int...
63+
)
5964
return getunstoredindex(parent(a), genperm(I, iperm(a))...)
6065
end
61-
function setstoredindex!(a::PermutedDimsArray, value, I::Int...)
66+
@interface ::AbstractSparseArrayInterface function setstoredindex!(
67+
a::PermutedDimsArray, value, I::Int...
68+
)
6269
# TODO: Should this be `iperm(a)`?
6370
setstoredindex!(parent(a), value, genperm(I, perm(a))...)
6471
return a
6572
end
66-
function setunstoredindex!(a::PermutedDimsArray, value, I::Int...)
73+
@interface ::AbstractSparseArrayInterface function setunstoredindex!(
74+
a::PermutedDimsArray, value, I::Int...
75+
)
6776
# TODO: Should this be `iperm(a)`?
6877
setunstoredindex!(parent(a), value, genperm(I, perm(a))...)
6978
return a
7079
end
71-
function eachstoredindex(a::PermutedDimsArray)
80+
@interface ::AbstractSparseArrayInterface function eachstoredindex(a::PermutedDimsArray)
7281
# TODO: Make lazy with `Iterators.map`.
7382
return map(collect(eachstoredindex(parent(a)))) do I
7483
return CartesianIndex(genperm(I, perm(a)))
@@ -78,25 +87,34 @@ end
7887
for (type, func) in ((:Adjoint, :adjoint), (:Transpose, :transpose))
7988
@eval begin
8089
using LinearAlgebra: $type
81-
storedvalues(a::$type) = storedvalues(parent(a))
82-
function isstored(a::$type, i::Int, j::Int)
90+
@interface ::AbstractSparseArrayInterface storedvalues(a::$type) =
91+
storedvalues(parent(a))
92+
@interface ::AbstractSparseArrayInterface function isstored(a::$type, i::Int, j::Int)
8393
return isstored(parent(a), j, i)
8494
end
85-
function eachstoredindex(a::$type)
95+
@interface ::AbstractSparseArrayInterface function eachstoredindex(a::$type)
8696
# TODO: Make lazy with `Iterators.map`.
8797
return map(cartesianindex_reverse, collect(eachstoredindex(parent(a))))
8898
end
89-
function getstoredindex(a::$type, i::Int, j::Int)
99+
@interface ::AbstractSparseArrayInterface function getstoredindex(
100+
a::$type, i::Int, j::Int
101+
)
90102
return $func(getstoredindex(parent(a), j, i))
91103
end
92-
function getunstoredindex(a::$type, i::Int, j::Int)
104+
@interface ::AbstractSparseArrayInterface function getunstoredindex(
105+
a::$type, i::Int, j::Int
106+
)
93107
return $func(getunstoredindex(parent(a), j, i))
94108
end
95-
function setstoredindex!(a::$type, value, i::Int, j::Int)
109+
@interface ::AbstractSparseArrayInterface function setstoredindex!(
110+
a::$type, value, i::Int, j::Int
111+
)
96112
setstoredindex!(parent(a), $func(value), j, i)
97113
return a
98114
end
99-
function setunstoredindex!(a::$type, value, i::Int, j::Int)
115+
@interface ::AbstractSparseArrayInterface function setunstoredindex!(
116+
a::$type, value, i::Int, j::Int
117+
)
100118
setunstoredindex!(parent(a), $func(value), j, i)
101119
return a
102120
end

test/basics/test_basics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ arrayts = (Array, JLArray)
3232
# TODO: We should be specializing these for dense/strided arrays,
3333
# probably we can have a trait for that. It could be based
3434
# on the `ArrayLayouts.MemoryLayout`.
35-
@allowscalar @test storedvalues(a) == vec(a)
35+
@allowscalar @test storedvalues(a) == a
3636
@allowscalar @test storedpairs(a) == collect(pairs(vec(a)))
3737
@allowscalar for I in eachindex(a)
3838
@test getstoredindex(a, I) == a[I]

0 commit comments

Comments
 (0)