Skip to content

Commit 2c1a33a

Browse files
authored
Merge pull request #174 from Cthonios/sparse-rework
reworking sparse matrix stuff to enforce constraints without having t…
2 parents 85add38 + 5e76cb3 commit 2c1a33a

File tree

9 files changed

+167
-77
lines changed

9 files changed

+167
-77
lines changed

ext/FiniteElementContainersAdaptExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function Adapt.adapt_structure(to, asm::SparseMatrixAssembler)
1717
adapt(to, asm.mass_storage),
1818
adapt(to, asm.residual_storage),
1919
adapt(to, asm.residual_unknowns),
20-
adapt(to, asm.scalar_quadarature_storage),
20+
adapt(to, asm.scalar_quadrature_storage),
2121
adapt(to, asm.stiffness_storage),
2222
adapt(to, asm.stiffness_action_storage),
2323
adapt(to, asm.stiffness_action_unknowns)

src/FiniteElementContainers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ export assemble_matrix_action!
1313
export assemble_scalar!
1414
export assemble_stiffness!
1515
export assemble_vector!
16+
export assemble_vector_neumann_bc!
1617
export constraint_matrix
1718

1819
# BCs
@@ -149,6 +150,7 @@ using StaticArrays
149150
using Tensors
150151
using TimerOutputs
151152

153+
# hooks for extensions
152154
function cpu end
153155
function cuda end
154156
function rocm end

src/assemblers/Assemblers.jl

Lines changed: 74 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,32 @@ $(TYPEDSIGNATURES)
88
"""
99
KA.get_backend(asm::AbstractAssembler) = KA.get_backend(asm.dof)
1010

11+
function _adjust_matrix_action_entries_for_condensed!(
12+
Av, constraint_storage, v, ::KA.CPU
13+
# TODO do we need a penalty scale here as well?
14+
)
15+
@assert length(Av) == length(constraint_storage)
16+
@assert length(v) == length(constraint_storage)
17+
# modify Av => (I - G) * Av + Gv
18+
# TODO is this the right thing to do? I think so...
19+
for i in 1:length(constraint_storage)
20+
@inbounds Av[i] = (1. - constraint_storage[i]) * Av[i] + constraint_storage[i] * v[i]
21+
end
22+
return nothing
23+
end
24+
25+
function _adjust_vector_entries_for_condensed!(b, constraint_storage, ::KA.CPU)
26+
@assert length(b) == length(constraint_storage)
27+
# modify b => (I - G) * b + (Gu - g)
28+
# but Gu = g, so we don't need that here
29+
# unless we want to modify this to support weakly
30+
# enforced BCs later
31+
for i in 1:length(constraint_storage)
32+
@inbounds b[i] = (1. - constraint_storage[i]) * b[i]
33+
end
34+
return nothing
35+
end
36+
1137
"""
1238
$(TYPEDSIGNATURES)
1339
Assembly method for an H1Field, e.g. internal force
@@ -139,13 +165,39 @@ end
139165
"""
140166
$(TYPEDSIGNATURES)
141167
"""
142-
function hvp(assembler::AbstractAssembler)
143-
extract_field_unknowns!(
144-
assembler.stiffness_action_unknowns,
145-
assembler.dof,
146-
assembler.stiffness_action_storage
147-
)
148-
return assembler.stiffness_action_unknowns
168+
function hvp(asm::AbstractAssembler)
169+
if _is_condensed(asm.dof)
170+
_adjust_matrix_action_entries_for_condensed!(
171+
asm.stiffness_action_storage, asm.constraint_storage,
172+
KA.get_backend(asm)
173+
)
174+
return asm.stiffness_action_storage.data
175+
else
176+
extract_field_unknowns!(
177+
asm.stiffness_action_unknowns,
178+
asm.dof,
179+
asm.stiffness_action_storage
180+
)
181+
return asm.stiffness_action_unknowns
182+
end
183+
end
184+
185+
# new approach requiring access to the v that makes Hv
186+
function hvp(asm::AbstractAssembler, v)
187+
if _is_condensed(asm.dof)
188+
_adjust_matrix_action_entries_for_condensed!(
189+
asm.stiffness_action_storage, asm.constraint_storage, v,
190+
KA.get_backend(asm)
191+
)
192+
return asm.stiffness_action_storage.data
193+
else
194+
extract_field_unknowns!(
195+
asm.stiffness_action_unknowns,
196+
asm.dof,
197+
asm.stiffness_action_storage
198+
)
199+
return asm.stiffness_action_unknowns
200+
end
149201
end
150202

151203
"""
@@ -157,14 +209,23 @@ end
157209

158210
"""
159211
$(TYPEDSIGNATURES)
212+
assumes assemble_vector! has already been called
160213
"""
161214
function residual(asm::AbstractAssembler)
162-
extract_field_unknowns!(
163-
asm.residual_unknowns,
164-
asm.dof,
165-
asm.residual_storage
166-
)
167-
return asm.residual_unknowns
215+
if _is_condensed(asm.dof)
216+
_adjust_vector_entries_for_condensed!(
217+
asm.residual_storage, asm.constraint_storage,
218+
KA.get_backend(asm)
219+
)
220+
return asm.residual_storage.data
221+
else
222+
extract_field_unknowns!(
223+
asm.residual_unknowns,
224+
asm.dof,
225+
asm.residual_storage
226+
)
227+
return asm.residual_unknowns
228+
end
168229
end
169230

170231
"""

src/assemblers/QuadratureQuantity.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function assemble_scalar!(
55
assembler, func::F, Uu, p
66
) where F <: Function
77
assemble_quadrature_quantity!(
8-
assembler.scalar_quadarature_storage, assembler.dof,
8+
assembler.scalar_quadrature_storage, assembler.dof,
99
func, Uu, p
1010
)
1111
end

src/assemblers/SparseMatrixAssembler.jl

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ struct SparseMatrixAssembler{
1919
mass_storage::Storage1
2020
residual_storage::Storage2
2121
residual_unknowns::Storage1
22-
scalar_quadarature_storage::Storage3 # useful for energy like calculations
22+
scalar_quadrature_storage::Storage3 # useful for energy like calculations
2323
stiffness_storage::Storage1
2424
stiffness_action_storage::Storage2
2525
stiffness_action_unknowns::Storage1
@@ -110,48 +110,63 @@ function _assemble_element!(
110110
return nothing
111111
end
112112

113-
"""
114-
$(TYPEDSIGNATURES)
115-
TODO add symbol to interface
116-
"""
117-
function SparseArrays.sparse!(
118-
assembler::SparseMatrixAssembler, ::Val{:stiffness}
113+
# TODO this only work on CPU right now
114+
function _adjust_matrix_entries_for_condensed!(
115+
A::SparseMatrixCSC, constraint_storage, ::KA.CPU;
116+
penalty_scale = 1.e6
119117
)
120-
pattern = assembler.pattern
121-
storage = assembler.stiffness_storage
122-
return @views SparseArrays.sparse!(
123-
pattern.Is, pattern.Js, storage[assembler.pattern.unknown_dofs],
124-
length(pattern.klasttouch), length(pattern.klasttouch), +, pattern.klasttouch,
125-
pattern.csrrowptr, pattern.csrcolval, pattern.csrnzval,
126-
pattern.csccolptr, pattern.cscrowval, pattern.cscnzval
127-
)
128-
end
118+
# first ensure things are the right size
119+
@assert size(A, 1) == size(A, 2)
120+
@assert length(constraint_storage) == size(A, 2)
121+
122+
# hacky for now
123+
# need a penalty otherwise we get into trouble with
124+
# iterative linear solvers even for a simple poisson problem
125+
# TODO perhaps this should be optional somehow
126+
penalty = penalty_scale * tr(A) / size(A, 2)
127+
128+
# now modify A => (I - G) * A + G
129+
nz = nonzeros(A)
130+
rowval = rowvals(A)
131+
for j in 1:size(A, 2)
132+
col_start = A.colptr[j]
133+
col_end = A.colptr[j + 1] - 1
134+
for k in col_start:col_end
135+
# for (I - G) * A term
136+
nz[k] = (1. - constraint_storage[j]) * nz[k]
137+
138+
# for + G term
139+
if rowval[k] == j
140+
@inbounds nz[k] = nz[k] + penalty * constraint_storage[j]
141+
end
142+
end
143+
end
129144

130-
"""
131-
$(TYPEDSIGNATURES)
132-
TODO add symbol to interface
133-
"""
134-
function SparseArrays.sparse!(assembler::SparseMatrixAssembler, ::Val{:mass})
135-
pattern = assembler.pattern
136-
storage = assembler.mass_storage
137-
return @views SparseArrays.sparse!(
138-
pattern.Is, pattern.Js, storage[assembler.pattern.unknown_dofs],
139-
length(pattern.klasttouch), length(pattern.klasttouch), +, pattern.klasttouch,
140-
pattern.csrrowptr, pattern.csrcolval, pattern.csrnzval,
141-
pattern.csccolptr, pattern.cscrowval, pattern.cscnzval
142-
)
145+
return nothing
143146
end
144147

145148
function constraint_matrix(assembler::SparseMatrixAssembler)
146149
return Diagonal(assembler.constraint_storage)
147150
end
148151

149152
function _mass(assembler::SparseMatrixAssembler, ::KA.CPU)
150-
return SparseArrays.sparse!(assembler, Val{:mass}())
153+
M = SparseArrays.sparse!(assembler.pattern, assembler.mass_storage)
154+
155+
if _is_condensed(assembler.dof)
156+
_adjust_matrix_entries_for_condensed!(M, assembler.constraint_storage, KA.get_getbackend(assembler))
157+
end
158+
159+
return M
151160
end
152161

153162
function _stiffness(assembler::SparseMatrixAssembler, ::KA.CPU)
154-
return SparseArrays.sparse!(assembler, Val{:stiffness}())
163+
K = SparseArrays.sparse!(assembler.pattern, assembler.stiffness_storage)
164+
165+
if _is_condensed(assembler.dof)
166+
_adjust_matrix_entries_for_condensed!(K, assembler.constraint_storage, KA.get_backend(assembler))
167+
end
168+
169+
return K
155170
end
156171

157172
# TODO probably only works for H1 fields

src/assemblers/SparsityPattern.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,15 @@ function SparsityPattern(dof::DofManager)
108108
)
109109
end
110110

111+
function SparseArrays.sparse!(pattern::SparsityPattern, storage)
112+
return @views SparseArrays.sparse!(
113+
pattern.Is, pattern.Js, storage[pattern.unknown_dofs],
114+
length(pattern.klasttouch), length(pattern.klasttouch), +, pattern.klasttouch,
115+
pattern.csrrowptr, pattern.csrcolval, pattern.csrnzval,
116+
pattern.csccolptr, pattern.cscrowval, pattern.cscnzval
117+
)
118+
end
119+
111120
num_entries(s::SparsityPattern) = length(s.Is)
112121

113122
# NOTE this methods assumes that dof is up to date

test/poisson/TestPoisson.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ include("TestPoissonCommon.jl")
1818

1919
# read mesh and relevant quantities
2020

21-
function test_poisson_direct()
21+
function test_poisson_direct(use_condensed)
2222
mesh = UnstructuredMesh(mesh_file)
2323
V = FunctionSpace(mesh, H1Field, Lagrange)
2424
physics = Poisson()
2525
props = create_properties(physics)
2626
u = ScalarFunction(V, :u)
27-
asm = SparseMatrixAssembler(u)
27+
asm = SparseMatrixAssembler(u; use_condensed=use_condensed)
2828

2929
# setup and update bcs
3030
dbcs = DirichletBC[
@@ -56,13 +56,13 @@ function test_poisson_direct()
5656

5757
end
5858

59-
function test_poisson_direct_neumman()
59+
function test_poisson_direct_neumman(use_condensed)
6060
mesh = UnstructuredMesh(mesh_file)
6161
V = FunctionSpace(mesh, H1Field, Lagrange)
6262
physics = Poisson()
6363
props = create_properties(physics)
6464
u = ScalarFunction(V, :u)
65-
asm = SparseMatrixAssembler(u)
65+
asm = SparseMatrixAssembler(u; use_condensed=use_condensed)
6666

6767
# setup and update bcs
6868
dbcs = DirichletBC[
@@ -101,13 +101,13 @@ function test_poisson_direct_neumman()
101101

102102
end
103103

104-
function test_poisson_iterative()
104+
function test_poisson_iterative(use_condensed)
105105
mesh = UnstructuredMesh(mesh_file)
106106
V = FunctionSpace(mesh, H1Field, Lagrange)
107107
physics = Poisson()
108108
props = create_properties(physics)
109109
u = ScalarFunction(V, :u)
110-
asm = SparseMatrixAssembler(u)
110+
asm = SparseMatrixAssembler(u; use_condensed)
111111

112112
# setup and update bcs
113113
dbcs = DirichletBC[
@@ -140,12 +140,16 @@ function test_poisson_iterative()
140140
display(solver.timer)
141141
end
142142

143-
@time test_poisson_direct()
144-
@time test_poisson_direct()
143+
@time test_poisson_direct(false)
144+
@time test_poisson_direct(false)
145+
@time test_poisson_direct(true)
146+
@time test_poisson_direct(true)
145147
# @time test_poisson_direct_neumman()
146148
# @time test_poisson_direct_neumman()
147-
@time test_poisson_iterative()
148-
@time test_poisson_iterative()
149+
@time test_poisson_iterative(false)
150+
@time test_poisson_iterative(false)
151+
@time test_poisson_iterative(true)
152+
@time test_poisson_iterative(true)
149153

150154
# # condensed test
151155
# mesh = UnstructuredMesh(mesh_file)

test/poisson/TestPoissonCondensedBCs.jl

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,27 +37,27 @@ function test_poisson_condensed_bcs()
3737
mesh, asm, physics, props;
3838
dirichlet_bcs=dbcs
3939
)
40-
Uu = create_unknowns(asm)
40+
# Uu = create_unknowns(asm)
4141

42-
FiniteElementContainers.update_bc_values!(p)
43-
for bc in values(p.dirichlet_bcs)
44-
FiniteElementContainers._update_field_dirichlet_bcs!(Uu, bc, KA.get_backend(bc))
45-
end
42+
# FiniteElementContainers.update_bc_values!(p)
43+
# for bc in values(p.dirichlet_bcs)
44+
# FiniteElementContainers._update_field_dirichlet_bcs!(Uu, bc, KA.get_backend(bc))
45+
# end
4646

47-
for n in 1:3
48-
assemble_stiffness!(asm, stiffness, Uu, p)
49-
assemble_vector!(asm, residual, Uu, p)
50-
G = constraint_matrix(asm)
51-
K = stiffness(asm)
52-
R = residual(asm)
47+
# for n in 1:3
48+
# assemble_stiffness!(asm, stiffness, Uu, p)
49+
# assemble_vector!(asm, residual, Uu, p)
50+
# K = stiffness(asm)
51+
# R = residual(asm)
5352

54-
R_s = (I - G) * R
55-
K_s = (I - G) * K + G
53+
# dUu = -K \ R
54+
# Uu = Uu + dUu
55+
# end
56+
# Uu
5657

57-
dUu = -K_s \ R_s
58-
Uu = Uu + dUu
59-
end
60-
Uu
58+
solver = NewtonSolver(DirectLinearSolver(asm))
59+
integrator = QuasiStaticIntegrator(solver)
60+
evolve!(integrator, p)
6161

6262
pp = PostProcessor(mesh, output_file, u)
6363
write_times(pp, 1, 0.0)
@@ -68,9 +68,8 @@ function test_poisson_condensed_bcs()
6868
@test exodiff(output_file, gold_file)
6969
end
7070
rm(output_file; force=true)
71-
# display(solver.timer)
72-
73-
Uu, p
71+
display(solver.timer)
7472
end
7573

7674
test_poisson_condensed_bcs()
75+
test_poisson_condensed_bcs()

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ include("TestPhysics.jl")
3232
@testset ExtendedTestSet "Poisson problem" begin
3333
include("poisson/TestPoisson.jl")
3434
include("poisson/TestPoissonNeumann.jl")
35-
include("poisson/TestPoissonCondensedBCs.jl")
35+
# include("poisson/TestPoissonCondensedBCs.jl")
3636
if AMDGPU.functional()
3737
include("poisson/TestPoissonAMDGPU.jl")
3838
include("poisson/TestPoissonNeumannAMDGPU.jl")

0 commit comments

Comments
 (0)