Skip to content

Commit cfc2e58

Browse files
authored
Improve type-inference in blockbandwidths for BlockBandedMatrix (#210)
1 parent 7467a90 commit cfc2e58

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

src/broadcast.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,13 @@ import BandedMatrices: _isweakzero
9696
function blockbandwidths(bc::Broadcasted)
9797
(a,b) = size(bc)
9898
bnds = (a-1,b-1)
99-
_isweakzero(bc.f, bc.args...) && return min.(bnds, max.(_broadcast_blockbandwidths.(Ref(bnds), bc.args, Ref(axes(bc)))...))
99+
if _isweakzero(bc.f, bc.args...)
100+
ax = axes(bc)
101+
t = map(bc.args) do x
102+
_broadcast_blockbandwidths(bnds, x, ax)
103+
end
104+
return min.(bnds, max.(t...))
105+
end
100106
bnds
101107
end
102108

test/test_broadcasting.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,14 @@ import Base: oneto
289289
@test C == A + A
290290
end
291291
end
292+
293+
@testset "blockbandwidths" begin
294+
B = BlockArray(ones(6,6), 1:3, 1:3)
295+
BB = BlockBandedMatrix(B, (1,1))
296+
bc = Broadcast.broadcasted(+, BB, BB)
297+
bbw = @inferred blockbandwidths(bc)
298+
@test bbw == blockbandwidths(BB)
299+
end
292300
end
293301

294302
end # module

0 commit comments

Comments
 (0)