@@ -104,11 +104,10 @@ to_vec(x::AbstractArray) = vec(x)
104104
105105# TODO : This may need to be defined in `sparsearraydok.jl`, after `SparseArrayDOK`
106106# is defined. And/or define `default_type(::SparseArrayStyle, T::Type) = SparseArrayDOK{T}`.
107- @interface :: AbstractSparseArrayInterface function Base. similar (
108- a:: AbstractArray , T:: Type , size:: Tuple{Vararg{Int}}
109- )
110- # TODO : Define `default_similartype` or something like that?
111- return SparseArrayDOK {T} (size... )
107+ @interface I:: AbstractSparseArrayInterface function Base. similar (
108+ :: AbstractArray , :: Type{T} , ax
109+ ) where {T}
110+ return similar (I, T, ax)
112111end
113112
114113using ArrayLayouts: ArrayLayouts, zero!
@@ -117,13 +116,11 @@ using ArrayLayouts: ArrayLayouts, zero!
117116# and is useful for sparse array logic, since it can be used to empty
118117# the sparse array storage.
119118# We use a single function definition to minimize method ambiguities.
120- @interface interface:: AbstractSparseArrayInterface function ArrayLayouts . zero! (
121- a :: AbstractArray
119+ @interface interface:: AbstractSparseArrayInterface function DerivableInterfaces . zero! (
120+ A :: AbstractArray
122121)
123- # More generally, this codepath could be taking if `zero(eltype(a))`
124- # is defined and the elements are immutable.
125- f = eltype (a) <: Number ? Returns (zero (eltype (a))) : zero!
126- return @interface interface map_stored! (f, a, a)
122+ storedvalues (A) .= zero! (storedvalues (A))
123+ return A
127124end
128125
129126# `f::typeof(norm)`, `op::typeof(max)` used by `norm`.
150147 return output
151148end
152149
153- abstract type AbstractSparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
154-
155- @derive (T= AbstractSparseArrayStyle,) begin
156- Base. similar (:: Broadcast.Broadcasted{<:T} , :: Type , :: Tuple )
157- Base. copyto! (:: AbstractArray , :: Broadcast.Broadcasted{<:T} )
158- end
159-
160- struct SparseArrayStyle{N} <: AbstractSparseArrayStyle{N} end
161-
162- SparseArrayStyle {M} (:: Val{N} ) where {M,N} = SparseArrayStyle {N} ()
163-
164- DerivableInterfaces. interface (:: Type{<:AbstractSparseArrayStyle} ) = SparseArrayInterface ()
165-
166- @interface :: AbstractSparseArrayInterface function Broadcast. BroadcastStyle (type:: Type )
167- return SparseArrayStyle {ndims(type)} ()
168- end
169-
170150using ArrayLayouts: ArrayLayouts, MatMulMatAdd
171151
172152abstract type AbstractSparseLayout <: ArrayLayouts.MemoryLayout end
@@ -190,19 +170,20 @@ using LinearAlgebra: LinearAlgebra, mul!
190170@interface :: AbstractSparseArrayInterface function LinearAlgebra. mul! (
191171 C:: AbstractArray , A:: AbstractArray , B:: AbstractArray , α:: Number , β:: Number
192172)
193- a_dest .*= β
173+ C .*= β
194174 β′ = one (Bool)
195- for I1 in eachstoredindex (a1)
196- for I2 in eachstoredindex (a2)
197- I_dest = mul_indices (I1, I2)
198- if ! isnothing (I_dest)
199- a_dest[I_dest] = mul! (a_dest[I_dest], a1[I1], a2[I2], α, β′)
200- end
175+ for iA in eachstoredindex (A), iB in eachstoredindex (B)
176+ iC = mul_indices (iA, iB)
177+ if ! isnothing (iC)
178+ C[iC] = mul!! (C[iC], A[iA], B[iB], α, β′)
201179 end
202180 end
203- return a_dest
181+ return C
204182end
205183
184+ mul!! (C, A, B, α, β) = mul! (C, A, B, α, β)
185+ mul!! (C:: Number , A:: Number , B:: Number , α:: Number , β:: Number ) = β * C + α * A * B
186+
206187function ArrayLayouts. materialize! (
207188 m:: MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout}
208189)
0 commit comments