Skip to content

Commit 9b21f20

Browse files
authored
Merge pull request #177 from Cthonios/sparse-more-testing
a bunch of re-work for poisson regression tests.
2 parents 7959823 + 00af7aa commit 9b21f20

13 files changed

+183
-431
lines changed

src/DofManagers.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,9 @@ function update_field_unknowns!(
212212
_update_field_unknowns!(U, dof, Uu, backend)
213213
return nothing
214214
end
215+
216+
function update_field_unknowns!(
217+
U::F, dof::DofManager, Uu::F
218+
) where F <: AbstractField
219+
update_field_unknowns!(U, dof, Uu.data)
220+
end

src/assemblers/Assemblers.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,13 @@ end
215215
# end
216216
# end
217217

218+
"""
219+
$(TYPEDSIGNATURES)
220+
"""
221+
function hessian(asm::AbstractAssembler)
222+
return _hessian(asm, KA.get_backend(asm))
223+
end
224+
218225
# new approach requiring access to the v that makes Hv
219226
"""
220227
$(TYPEDSIGNATURES)

src/assemblers/SparseMatrixAssembler.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,21 @@ function _adjust_matrix_entries_for_constraints(
187187
return nothing
188188
end
189189

190+
function _hessian(assembler::SparseMatrixAssembler, ::KA.CPU)
191+
H = SparseArrays.sparse!(assembler.pattern, assembler.hessian_storage)
192+
193+
if _is_condensed(assembler.dof)
194+
_adjust_matrix_entries_for_constraints!(H, assembler.constraint_storage, KA.get_backend(assembler))
195+
end
196+
197+
return H
198+
end
199+
190200
function _mass(assembler::SparseMatrixAssembler, ::KA.CPU)
191201
M = SparseArrays.sparse!(assembler.pattern, assembler.mass_storage)
192202

193203
if _is_condensed(assembler.dof)
194-
_adjust_matrix_entries_for_constraints!(M, assembler.constraint_storage, KA.get_getbackend(assembler))
204+
_adjust_matrix_entries_for_constraints!(M, assembler.constraint_storage, KA.get_backend(assembler))
195205
end
196206

197207
return M

src/bcs/BoundaryConditions.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,33 @@ function BCBookKeeping(
9191
)
9292
end
9393

94+
function BCBookKeeping(mesh, sset_name::Symbol)
95+
# get sset specific fields
96+
elements = getproperty(mesh.sideset_elems, sset_name)
97+
nodes = getproperty(mesh.sideset_nodes, sset_name)
98+
sides = getproperty(mesh.sideset_sides, sset_name)
99+
side_nodes = getproperty(mesh.sideset_side_nodes, sset_name)
100+
101+
blocks = Vector{Int64}(undef, 0)
102+
103+
# gather the blocks that are present in this sideset
104+
# TODO this isn't quite right
105+
for (n, val) in enumerate(values(mesh.element_id_maps))
106+
# note these are the local elem id to the block, e.g. starting from 1.
107+
indices_in_sset = indexin(val, elements)
108+
filter!(x -> x !== nothing, indices_in_sset)
109+
110+
if length(indices_in_sset) > 0
111+
append!(blocks, repeat([n], length(indices_in_sset)))
112+
end
113+
end
114+
115+
dofs = Vector{Int64}(undef, 0)
116+
return BCBookKeeping(
117+
blocks, dofs, elements, nodes, sides, side_nodes
118+
)
119+
end
120+
94121
"""
95122
$(TYPEDSIGNATURES)
96123
"""

src/bcs/DirichletBCs.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,14 @@ function _update_field_dirichlet_bcs!(U, bc::DirichletBCContainer, ::KA.CPU)
140140
return nothing
141141
end
142142

143+
function _update_field_dirichlet_bcs!(U, V, bc::DirichletBCContainer, ::KA.CPU)
144+
for (dof, val, val_dot) in zip(bc.bookkeeping.dofs, bc.vals, bc.vals_dot)
145+
U[dof] = val
146+
V[dof] = val_dot
147+
end
148+
return nothing
149+
end
150+
143151
function _update_field_dirichlet_bcs!(U, V, A, bc::DirichletBCContainer, ::KA.CPU)
144152
for (dof, val, val_dot, val_dot_dot) in zip(bc.bookkeeping.dofs, bc.vals, bc.vals_dot, bc.vals_dot_dot)
145153
U[dof] = val
@@ -188,6 +196,12 @@ function update_field_dirichlet_bcs!(U, bcs::NamedTuple)
188196
return nothing
189197
end
190198

199+
function update_field_dirichlet_bcs!(U, V, bcs::NamedTuple)
200+
for bc in values(bcs)
201+
_update_field_dirichlet_bcs!(U, V, bc, KA.get_backend(bc))
202+
end
203+
end
204+
191205
function update_field_dirichlet_bcs!(U, V, A, bcs::NamedTuple)
192206
for bc in values(bcs)
193207
_update_field_dirichlet_bcs!(U, V, A, bc, KA.get_backend(bc))

src/fields/L2QuadratureField.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ end
3737
# setindex!(field.vals, v, (e - 1) * NQ + (q - 1) * NF + n)
3838
# end
3939

40+
function Base.setindex!(field::L2QuadratureField{T, D, NF, NQ}, v, n::Int, q::Int, e::Int) where {T, D, NF, NQ}
41+
setindex!(field.data, v, (e - 1) * NQ + (q - 1) * NF + n)
42+
return nothing
43+
end
44+
4045
function Base.size(field::L2QuadratureField{T, D, NF, NQ}) where {T, D, NF, NQ}
4146
if NF == 0
4247
(NF, NQ, length(field.data) ÷ NQ)

test/poisson/TestPoisson.jl

Lines changed: 89 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ include("TestPoissonCommon.jl")
1818

1919
# read mesh and relevant quantities
2020

21-
function test_poisson_direct(use_condensed)
21+
function test_poisson_dirichlet(
22+
dev, use_condensed,
23+
nsolver, lsolver
24+
)
2225
mesh = UnstructuredMesh(mesh_file)
2326
V = FunctionSpace(mesh, H1Field, Lagrange)
2427
physics = Poisson()
@@ -34,29 +37,41 @@ function test_poisson_direct(use_condensed)
3437
DirichletBC(:u, :sset_4, bc_func),
3538
]
3639

37-
# direct solver test
3840
# setup the parameters
39-
@show p = create_parameters(mesh, asm, physics, props; dirichlet_bcs=dbcs)
41+
p = create_parameters(mesh, asm, physics, props; dirichlet_bcs=dbcs)
42+
43+
if dev != cpu
44+
p = p |> dev
45+
asm = asm |> dev
46+
end
4047

4148
# setup solver and integrator
42-
solver = NewtonSolver(DirectLinearSolver(asm))
49+
solver = nsolver(lsolver(asm))
4350
integrator = QuasiStaticIntegrator(solver)
4451
evolve!(integrator, p)
4552

53+
if dev != cpu
54+
p = p |> cpu
55+
end
56+
57+
U = p.h1_field
58+
4659
pp = PostProcessor(mesh, output_file, u)
4760
write_times(pp, 1, 0.0)
48-
write_field(pp, 1, ("u",), p.h1_field)
61+
write_field(pp, 1, ("u",), U)
4962
close(pp)
5063

5164
if !Sys.iswindows()
5265
@test exodiff(output_file, gold_file)
5366
end
5467
rm(output_file; force=true)
5568
display(solver.timer)
56-
5769
end
5870

59-
function test_poisson_direct_neumman(use_condensed)
71+
function test_poisson_neumann(
72+
dev, use_condensed,
73+
nsolver, lsolver
74+
)
6075
mesh = UnstructuredMesh(mesh_file)
6176
V = FunctionSpace(mesh, H1Field, Lagrange)
6277
physics = Poisson()
@@ -77,123 +92,86 @@ function test_poisson_direct_neumman(use_condensed)
7792

7893
# direct solver test
7994
# setup the parameters
80-
@show p = create_parameters(
95+
p = create_parameters(
8196
mesh, asm, physics, props;
8297
dirichlet_bcs=dbcs,
8398
neumann_bcs=nbcs
8499
)
85100

101+
if dev != cpu
102+
p = p |> dev
103+
asm = asm |> dev
104+
end
105+
86106
# setup solver and integrator
87-
solver = NewtonSolver(DirectLinearSolver(asm))
107+
solver = nsolver(lsolver(asm))
88108
integrator = QuasiStaticIntegrator(solver)
89109
evolve!(integrator, p)
90110

91-
pp = PostProcessor(mesh, output_file, u)
92-
write_times(pp, 1, 0.0)
93-
write_field(pp, 1, ("u",), p.h1_field)
94-
close(pp)
95-
96-
if !Sys.iswindows()
97-
@test exodiff(output_file, gold_file)
111+
if dev != cpu
112+
p = p |> cpu
98113
end
99-
rm(output_file; force=true)
100-
display(solver.timer)
101-
102-
end
103-
104-
function test_poisson_iterative(use_condensed)
105-
mesh = UnstructuredMesh(mesh_file)
106-
V = FunctionSpace(mesh, H1Field, Lagrange)
107-
physics = Poisson()
108-
props = create_properties(physics)
109-
u = ScalarFunction(V, :u)
110-
asm = SparseMatrixAssembler(u; use_condensed)
111-
112-
# setup and update bcs
113-
dbcs = DirichletBC[
114-
DirichletBC(:u, :sset_1, bc_func),
115-
DirichletBC(:u, :sset_2, bc_func),
116-
DirichletBC(:u, :sset_3, bc_func),
117-
DirichletBC(:u, :sset_4, bc_func),
118-
]
119-
120-
# iterative solver test
121-
# setup the parameters
122-
p = create_parameters(mesh, asm, physics, props; dirichlet_bcs=dbcs)
123-
124-
# setup solver and integrator
125-
solver = NewtonSolver(IterativeLinearSolver(asm, :CgSolver))
126-
integrator = QuasiStaticIntegrator(solver)
127-
@time evolve!(integrator, p)
128114

129-
display(solver.timer)
115+
# TODO make a neumann gold file
116+
# U = p.h1_field
130117

131-
pp = PostProcessor(mesh, output_file, u)
132-
write_times(pp, 1, 0.0)
133-
write_field(pp, 1, ("u",), p.h1_field)
134-
close(pp)
118+
# pp = PostProcessor(mesh, output_file, u)
119+
# write_times(pp, 1, 0.0)
120+
# write_field(pp, 1, ("u",), U)
121+
# close(pp)
135122

136-
if !Sys.iswindows()
137-
@test exodiff(output_file, gold_file)
138-
end
139-
rm(output_file; force=true)
123+
# if !Sys.iswindows()
124+
# @test exodiff(output_file, gold_file)
125+
# end
126+
# rm(output_file; force=true)
140127
display(solver.timer)
141128
end
142129

143-
@time test_poisson_direct(false)
144-
@time test_poisson_direct(false)
145-
@time test_poisson_direct(true)
146-
@time test_poisson_direct(true)
147-
# @time test_poisson_direct_neumman()
148-
# @time test_poisson_direct_neumman()
149-
@time test_poisson_iterative(false)
150-
@time test_poisson_iterative(false)
151-
@time test_poisson_iterative(true)
152-
@time test_poisson_iterative(true)
153-
154-
# # condensed test
155-
# mesh = UnstructuredMesh(mesh_file)
156-
# V = FunctionSpace(mesh, H1Field, Lagrange)
157-
# physics = Poisson()
158-
# u = ScalarFunction(V, :u)
159-
# asm = SparseMatrixAssembler(H1Field, u)
160-
# # pp = PostProcessor(mesh, output_file, u)
161-
162-
# # setup and update bcs
163-
# dbcs = DirichletBC[
164-
# DirichletBC(asm.dof, :u, :sset_1, bc_func),
165-
# DirichletBC(asm.dof, :u, :sset_2, bc_func),
166-
# DirichletBC(asm.dof, :u, :sset_3, bc_func),
167-
# DirichletBC(asm.dof, :u, :sset_4, bc_func),
168-
# ]
169-
# update_dofs!(asm, dbcs; use_condensed=true)
170-
# Uu = create_unknowns(asm)
171-
# Ubc = create_bcs(asm, H1Field)
172-
# U = create_field(asm, H1Field)
173-
# update_field!(U, asm, Uu, Ubc)
174-
# update_field_bcs!(U, asm.dof, dbcs, 0.)
175-
# assemble!(asm, physics, U, :residual_and_stiffness)
176-
# K = stiffness(asm)
177-
# G = constraint_matrix(asm)
178-
# # @time H = (G + I) * K
179-
# K[asm.dof.H1_bc_dofs, asm.dof.H1_bc_dofs] .= 1.
180-
# # R = G * residual(asm)
181-
# # R = G * asm.residual_storage.vals
182-
# R = asm.residual_storage
183-
# R[asm.dof.H1_bc_dofs] .= 0.
184-
# ΔUu = -K \ R.vals
185-
# U.vals .= U.vals .+ ΔUu
186-
# assemble!(asm, physics, U, :residual_and_stiffness)
187-
# K = stiffness(asm)
188-
# G = constraint_matrix(asm)
189-
# # @time H = (G + I) * K
190-
# K[asm.dof.H1_bc_dofs, asm.dof.H1_bc_dofs] .= 1.
191-
# # R = G * residual(asm)
192-
# # R = G * asm.residual_storage.vals
193-
# R = asm.residual_storage
194-
# R[asm.dof.H1_bc_dofs] .= 0.
195-
# ΔUu = -K \ R.vals
196-
# U.vals .= U.vals .+ ΔUu
197-
# U
198-
# # @time H = G * K + G * I
199-
# # @time H = G * K
130+
# function test_poisson_iterative(dev, use_condensed)
131+
# mesh = UnstructuredMesh(mesh_file)
132+
# V = FunctionSpace(mesh, H1Field, Lagrange)
133+
# physics = Poisson()
134+
# props = create_properties(physics)
135+
# u = ScalarFunction(V, :u)
136+
# asm = SparseMatrixAssembler(u; use_condensed)
137+
138+
# # setup and update bcs
139+
# dbcs = DirichletBC[
140+
# DirichletBC(:u, :sset_1, bc_func),
141+
# DirichletBC(:u, :sset_2, bc_func),
142+
# DirichletBC(:u, :sset_3, bc_func),
143+
# DirichletBC(:u, :sset_4, bc_func),
144+
# ]
145+
146+
# # iterative solver test
147+
# # setup the parameters
148+
# p = create_parameters(mesh, asm, physics, props; dirichlet_bcs=dbcs)
149+
150+
# if dev != cpu
151+
# p = p |> dev
152+
# asm = asm |> dev
153+
# end
154+
155+
# # setup solver and integrator
156+
# solver = NewtonSolver(IterativeLinearSolver(asm, :CgSolver))
157+
# integrator = QuasiStaticIntegrator(solver)
158+
# @time evolve!(integrator, p)
159+
160+
# display(solver.timer)
161+
162+
# if dev != cpu
163+
# p = p |> cpu
164+
# end
165+
166+
# U = p.h1_field
167+
# pp = PostProcessor(mesh, output_file, u)
168+
# write_times(pp, 1, 0.0)
169+
# write_field(pp, 1, ("u",), U)
170+
# close(pp)
171+
172+
# if !Sys.iswindows()
173+
# @test exodiff(output_file, gold_file)
174+
# end
175+
# rm(output_file; force=true)
176+
# display(solver.timer)
177+
# end

0 commit comments

Comments
 (0)