@@ -283,9 +283,43 @@ function LinearAlgebra.tr(t::AbstractTensorMap)
283283end
284284
285285# TensorMap multiplication
286- function LinearAlgebra. mul! (tC:: AbstractTensorMap ,
287- tA:: AbstractTensorMap ,
288- tB:: AbstractTensorMap , α= true , β= false )
286+ function LinearAlgebra. mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
287+ tB:: AbstractTensorMap ,
288+ α:: Number , β:: Number ,
289+ backend= TO. DefaultBackend ())
290+ if backend isa TO. DefaultBackend
291+ newbackend = TO. select_backend (mul!, tC, tA, tB)
292+ return mul! (tC, tA, tB, α, β, newbackend)
293+ elseif backend isa TO. NoBackend # error for missing backend
294+ TC = typeof (tC)
295+ TA = typeof (tA)
296+ TB = typeof (tB)
297+ throw (ArgumentError (" No suitable backend found for `mul!` and tensor types $TC , $TA and $TB " ))
298+ else # error for unknown backend
299+ TC = typeof (tC)
300+ TA = typeof (tA)
301+ TB = typeof (tB)
302+ throw (ArgumentError (" Unknown backend for `mul!` and tensor types $TC , $TA and $TB " ))
303+ end
304+ end
305+
306+ function TO. select_backend (:: typeof (mul!), C:: AbstractTensorMap , A:: AbstractTensorMap ,
307+ B:: AbstractTensorMap )
308+ return SerialScheduler ()
309+ end
310+
311+ function LinearAlgebra. mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
312+ tB:: AbstractTensorMap , α:: Number , β:: Number ,
313+ scheduler:: Union{Nothing,Scheduler} )
314+ if isnothing (scheduler)
315+ return sequential_mul! (tC, tA, tB, α, β)
316+ else
317+ return threaded_mul! (tC, tA, tB, α, β, scheduler)
318+ end
319+ end
320+
321+ function sequential_mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
322+ tB:: AbstractTensorMap , α:: Number , β:: Number )
289323 compose (space (tA), space (tB)) == space (tC) ||
290324 throw (SpaceMismatch (lazy " $(space(tC)) ≠ $(space(tA)) * $(space(tB))" ))
291325
@@ -325,7 +359,21 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
325359 return tC
326360end
327361
328- # TODO : consider spawning threads for different blocks, support backends
362+ function threaded_mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap , tB:: AbstractTensorMap ,
363+ α:: Number , β:: Number , scheduler:: Scheduler )
364+ # obtain cached data before multithreading
365+ bCs, bAs, bBs = blocks (tC), blocks (tA), blocks (tB)
366+
367+ tforeach (blocksectors (tC); scheduler) do c
368+ if haskey (bAs, c) # then also bBs should have it
369+ mul! (bCs[c], bAs[c], bBs[c], α, β)
370+ elseif ! isone (β)
371+ scale! (bCs[c], β)
372+ end
373+ end
374+
375+ return tC
376+ end
329377
330378# TensorMap inverse
331379function Base. inv (t:: AbstractTensorMap )
0 commit comments