Skip to content

Commit 1098d1c

Browse files
authored
Merge pull request #101 from Cthonios/physics/block-by-block
Physics/block by block
2 parents e7c1a79 + 97719ae commit 1098d1c

File tree

12 files changed

+212
-67
lines changed

12 files changed

+212
-67
lines changed

ext/FiniteElementContainersAdaptExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ function Adapt.adapt_structure(to, field::L2ElementField{T, NF, V, S}) where {T,
106106
return L2ElementField{T, NF, typeof(vals), S}(vals)
107107
end
108108

109+
function Adapt.adapt_structure(to, field::L2QuadratureField{T, NF, NQ, V, S}) where {T, NF, NQ, V, S}
110+
vals = adapt(to, field.vals)
111+
return L2QuadratureField{T, NF, NQ, typeof(vals), S}(vals)
112+
end
113+
109114
function Adapt.adapt_structure(to, field::H1Field{T, NF, V, S}) where {T, NF, V, S}
110115
vals = adapt(to, field.vals)
111116
return H1Field{T, NF, typeof(vals), S}(vals)

src/Parameters.jl

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
abstract type AbstractParameters end
22

3+
# TODO need to break up bcs to different field types
34
struct Parameters{D, N, T, Phys, Props, S, V, H1} <: AbstractParameters
45
# actual parameter fields
56
# TODO add boundary condition stuff and time stepping stuff
@@ -12,7 +13,7 @@ struct Parameters{D, N, T, Phys, Props, S, V, H1} <: AbstractParameters
1213
state_old::S
1314
state_new::S
1415
# scratch fields
15-
h1_dbcs::V
16+
h1_dbcs::V # TODO remove this
1617
h1_field::H1
1718
end
1819

@@ -34,8 +35,6 @@ function Parameters(
3435

3536
# TODO
3637
properties = nothing
37-
state_old = nothing
38-
state_new = nothing
3938

4039
# for mixed spaces we'll need to do this more carefully
4140
if isa(physics, AbstractPhysics)
@@ -47,6 +46,25 @@ function Parameters(
4746
# TODO re-arrange physics tuple to match fspaces when appropriate
4847
end
4948

49+
# state_old = Array{Float64, 3}[]
50+
state_old = L2QuadratureField[]
51+
for (key, val) in pairs(physics)
52+
NS = num_states(val)
53+
NQ = ReferenceFiniteElements.num_quadrature_points(
54+
getfield(values(assembler.dof.H1_vars)[1].fspace.ref_fes, key)
55+
)
56+
NE = size(
57+
getfield(values(assembler.dof.H1_vars)[1].fspace.elem_conns, key),
58+
2
59+
)
60+
syms = tuple(map(x -> Symbol("state_variable_$x"), 1:NS)...)
61+
state_old_temp = L2QuadratureField(zeros(NS, NQ, NE), syms)
62+
push!(state_old, state_old_temp)
63+
end
64+
state_new = copy(state_old)
65+
state_old = NamedTuple{keys(physics)}(tuple(state_old...))
66+
state_new = NamedTuple{keys(physics)}(tuple(state_new...))
67+
5068
if dirichlet_bcs !== nothing
5169
syms = map(x -> Symbol("dirichlet_bc_$x"), 1:length(dirichlet_bcs))
5270
# dbcs = NamedTuple{tuple(syms...)}(tuple(dbcs...))
@@ -82,6 +100,11 @@ function Parameters(
82100

83101
update_dofs!(assembler, p)
84102

103+
# assemble the stiffness at least once for
104+
# making easier to use on GPU
105+
assemble!(assembler, H1Field, p, :stiffness)
106+
K = stiffness(assembler)
107+
85108
return p
86109
end
87110

src/assemblers/Assemblers.jl

Lines changed: 85 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,39 +38,107 @@ function _assemble_element!(global_val::H1Field, local_val, conn, e, b)
3838
return nothing
3939
end
4040

41-
_assemble_block_method_from_sym(::Val{:mass}) = _assemble_block_mass!
42-
_assemble_block_method_from_sym(::Val{:residual}) = _assemble_block_residual!
43-
_assemble_block_method_from_sym(::Val{:residual_and_stiffness}) = _assemble_block_residual_and_stiffness!
44-
_assemble_block_method_from_sym(::Val{:stiffness}) = _assemble_block_stiffness!
45-
46-
function _check_backends(assembler, U, X, conns)
41+
function _check_backends(assembler, U, X, state_old, state_new, conns)
4742
backend = KA.get_backend(assembler)
4843
# TODO add get_backend method of ref_fe
4944
@assert backend == KA.get_backend(U)
5045
@assert backend == KA.get_backend(X)
5146
@assert backend == KA.get_backend(conns)
47+
@assert backend == KA.get_backend(state_old)
48+
@assert backend == KA.get_backend(state_new)
5249
return backend
5350
end
5451

52+
function assemble!(assembler, ::Type{H1Field}, p, val_sym::Val{:mass})
53+
_zero_storage(assembler, val_sym)
54+
fspace = assembler.dof.H1_vars[1].fspace
55+
for (b, (conns, block_physics, state_old, state_new)) in enumerate(zip(
56+
values(fspace.elem_conns),
57+
values(p.physics),
58+
values(p.state_old), values(p.state_new)
59+
))
60+
ref_fe = values(fspace.ref_fes)[b]
61+
backend = _check_backends(assembler, p.h1_field, fspace.coords, state_old, state_new, conns)
62+
_assemble_block_mass!(
63+
assembler, block_physics, ref_fe,
64+
p.h1_field, fspace.coords, state_old, state_new,
65+
conns, b,
66+
backend
67+
)
68+
end
69+
end
70+
71+
function assemble!(assembler, ::Type{H1Field}, p, val_sym::Val{:residual})
72+
_zero_storage(assembler, val_sym)
73+
fspace = assembler.dof.H1_vars[1].fspace
74+
for (b, (conns, block_physics, state_old, state_new)) in enumerate(zip(
75+
values(fspace.elem_conns),
76+
values(p.physics),
77+
values(p.state_old), values(p.state_new)
78+
))
79+
ref_fe = values(fspace.ref_fes)[b]
80+
backend = _check_backends(assembler, p.h1_field, fspace.coords, state_old, state_new, conns)
81+
_assemble_block_residual!(
82+
assembler, block_physics, ref_fe,
83+
p.h1_field, fspace.coords, state_old, state_new,
84+
conns, b,
85+
backend
86+
)
87+
end
88+
end
89+
90+
function assemble!(assembler, ::Type{H1Field}, p, val_sym::Val{:residual_and_stiffness})
91+
_zero_storage(assembler, val_sym)
92+
fspace = assembler.dof.H1_vars[1].fspace
93+
for (b, (conns, block_physics, state_old, state_new)) in enumerate(zip(
94+
values(fspace.elem_conns),
95+
values(p.physics),
96+
values(p.state_old), values(p.state_new)
97+
))
98+
ref_fe = values(fspace.ref_fes)[b]
99+
backend = _check_backends(assembler, p.h1_field, fspace.coords, state_old, state_new, conns)
100+
_assemble_block_residual_and_stiffness!(
101+
assembler, block_physics, ref_fe,
102+
p.h1_field, fspace.coords, state_old, state_new,
103+
conns, b,
104+
backend
105+
)
106+
end
107+
end
108+
109+
function assemble!(assembler, ::Type{H1Field}, p, val_sym::Val{:stiffness})
110+
_zero_storage(assembler, val_sym)
111+
fspace = assembler.dof.H1_vars[1].fspace
112+
for (b, (conns, block_physics, state_old, state_new)) in enumerate(zip(
113+
values(fspace.elem_conns),
114+
values(p.physics),
115+
values(p.state_old), values(p.state_new)
116+
))
117+
ref_fe = values(fspace.ref_fes)[b]
118+
backend = _check_backends(assembler, p.h1_field, fspace.coords, state_old, state_new, conns)
119+
_assemble_block_stiffness!(
120+
assembler, block_physics, ref_fe,
121+
p.h1_field, fspace.coords, state_old, state_new,
122+
conns, b,
123+
backend
124+
)
125+
end
126+
end
127+
55128
"""
56129
$(TYPEDSIGNATURES)
57130
Top level assembly method for ```H1Field``` that loops over blocks and dispatches
58131
to appropriate kernels based on sym.
59132
60133
TODO need to make sure at setup time that physics and elem_conns have the same
61134
values order. Otherwise, shenanigans.
135+
136+
TODO figure out how to do generated functions
137+
138+
creates one type instability from the Val
62139
"""
63-
function assemble!(assembler, physics, U::H1Field, sym)
64-
val_sym = Val{sym}()
65-
_assemble_block_method! = _assemble_block_method_from_sym(val_sym)
66-
_zero_storage(assembler, val_sym)
67-
fspace = assembler.dof.H1_vars[1].fspace
68-
X = fspace.coords
69-
for (b, (conns, block_physics)) in enumerate(zip(values(fspace.elem_conns), values(physics)))
70-
ref_fe = values(fspace.ref_fes)[b]
71-
backend = _check_backends(assembler, U, X, conns)
72-
_assemble_block_method!(assembler, block_physics, ref_fe, U, X, conns, b, backend)
73-
end
140+
function assemble!(assembler, type::Type{H1Field}, p, sym::Symbol)
141+
assemble!(assembler, type, p, Val{sym}())
74142
end
75143

76144
"""

src/assemblers/CPUGeneral.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ with no threading.
3333
TODO add state variables and physics properties
3434
TODO remove Float64 typing below for eventual unitful use
3535
"""
36-
function _assemble_block_residual!(assembler, physics, ref_fe, U, X, conns, block_id, ::KA.CPU)
36+
function _assemble_block_residual!(
37+
assembler, physics, ref_fe,
38+
U, X, state_old, state_new,
39+
conns, block_id, ::KA.CPU
40+
)
3741
ND = size(U, 1)
3842
NNPE = ReferenceFiniteElements.num_vertices(ref_fe)
3943
NxNDof = NNPE * ND

src/assemblers/GPUGeneral.jl

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
$(TYPEDSIGNATURES)
55
Kernel for residual block assembly
66
"""
7-
KA.@kernel function _assemble_block_residual_kernel!(assembler, physics, ref_fe, U, X, conns, block_id)
7+
KA.@kernel function _assemble_block_residual_kernel!(
8+
assembler, physics, ref_fe,
9+
U, X, state_old, state_new,
10+
conns, block_id
11+
)
812
E = KA.@index(Global)
913

1014
ND = size(U, 1)
@@ -39,13 +43,25 @@ using KernelAbstractions and Atomix for eliminating race conditions
3943
4044
TODO add state variables and physics properties
4145
"""
42-
function _assemble_block_residual!(assembler, physics, ref_fe, U, X, conns, block_id, backend::KA.Backend)
46+
function _assemble_block_residual!(
47+
assembler, physics, ref_fe,
48+
U, X, state_old, state_new,
49+
conns, block_id, backend::KA.Backend
50+
)
4351
kernel! = _assemble_block_residual_kernel!(backend)
44-
kernel!(assembler, physics, ref_fe, U, X, conns, block_id, ndrange=size(conns, 2))
52+
kernel!(
53+
assembler, physics, ref_fe,
54+
U, X, state_old, state_new,
55+
conns, block_id, ndrange=size(conns, 2)
56+
)
4557
return nothing
4658
end
4759

48-
KA.@kernel function _assemble_block_stiffness_kernel!(assembler, physics, ref_fe, U, X, conns, block_id)
60+
KA.@kernel function _assemble_block_stiffness_kernel!(
61+
assembler, physics, ref_fe,
62+
U, X, state_old, state_new,
63+
conns, block_id
64+
)
4965
E = KA.@index(Global)
5066

5167
ND = size(U, 1)
@@ -80,9 +96,17 @@ using KernelAbstractions and Atomix for eliminating race conditions
8096
8197
TODO add state variables and physics properties
8298
"""
83-
function _assemble_block_stiffness!(assembler, physics, ref_fe, U, X, conns, block_id, backend::KA.Backend)
99+
function _assemble_block_stiffness!(
100+
assembler, physics, ref_fe,
101+
U, X, state_old, state_new,
102+
conns, block_id, backend::KA.Backend
103+
)
84104
kernel! = _assemble_block_stiffness_kernel!(backend)
85-
kernel!(assembler, physics, ref_fe, U, X, conns, block_id, ndrange=size(conns, 2))
105+
kernel!(
106+
assembler, physics, ref_fe,
107+
U, X, state_old, state_new,
108+
conns, block_id, ndrange=size(conns, 2)
109+
)
86110
return nothing
87111
end
88112

src/assemblers/SparseMatrixAssembler.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,11 @@ with no threading.
9898
9999
TODO add state variables and physics properties
100100
"""
101-
function _assemble_block_mass!(assembler, physics, ref_fe, U, X, conns, block_id, ::KA.CPU)
101+
function _assemble_block_mass!(
102+
assembler, physics, ref_fe,
103+
U, X, state_old, state_new,
104+
conns, block_id, ::KA.CPU
105+
)
102106
ND = size(U, 1)
103107
NNPE = ReferenceFiniteElements.num_vertices(ref_fe)
104108
NxNDof = NNPE * ND
@@ -125,7 +129,11 @@ with no threading.
125129
126130
TODO add state variables and physics properties
127131
"""
128-
function _assemble_block_stiffness!(assembler, physics, ref_fe, U, X, conns, block_id, ::KA.CPU)
132+
function _assemble_block_stiffness!(
133+
assembler, physics, ref_fe,
134+
U, X, state_old, state_new,
135+
conns, block_id, ::KA.CPU
136+
)
129137
ND = size(U, 1)
130138
NNPE = ReferenceFiniteElements.num_vertices(ref_fe)
131139
NxNDof = NNPE * ND
@@ -153,7 +161,11 @@ with no threading.
153161
TODO add state variables and physics properties
154162
TODO remove Float64 typing below for eventual unitful use
155163
"""
156-
function _assemble_block_residual_and_stiffness!(assembler, physics, ref_fe, U, X, conns, block_id, ::KA.CPU)
164+
function _assemble_block_residual_and_stiffness!(
165+
assembler, physics, ref_fe,
166+
U, X, state_old, state_new,
167+
conns, block_id, ::KA.CPU
168+
)
157169
ND = size(U, 1)
158170
NNPE = ReferenceFiniteElements.num_vertices(ref_fe)
159171
NxNDof = NNPE * ND

src/fields/L2QuadratureField.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ struct L2QuadratureField{T, NF, NQ, Vals <: AbstractArray{T, 1}, SymIDMap} <: Ab
77
vals::Vals
88
end
99

10+
# currently has shenanigans when vals has zero fields
11+
# will create a field that is 0 x NQ x 0 when it should
12+
# be 0 x NQ x NF. Likely due to the vec(vals) call
1013
function L2QuadratureField(vals::A, syms) where A <: AbstractArray{<:Number, 3}
1114
NF, NQ = size(vals, 1), size(vals, 2)
1215
@assert length(syms) == NF
@@ -21,7 +24,11 @@ end
2124
Base.IndexStyle(::Type{<:L2QuadratureField}) = IndexLinear()
2225

2326
function Base.axes(field::L2QuadratureField{T, NF, NQ, V, SymIDMap}) where {T, NF, NQ, V, SymIDMap}
24-
NE = length(field.vals) ÷ NF ÷ NQ
27+
if NF == 0
28+
NE = length(field.vals) ÷ NQ
29+
else
30+
NE = length(field.vals) ÷ NF ÷ NQ
31+
end
2532
return (Base.OneTo(NF), Base.OneTo(NQ), Base.OneTo(NE))
2633
end
2734

@@ -50,7 +57,14 @@ Base.setindex!(field::L2QuadratureField, v, n::Int) = setindex!(field.vals, v, n
5057
function Base.setindex!(field::L2QuadratureField{T, NF, NQ, A, S}, v, n::Int, q::Int, e::Int) where {T, NF, NQ, A, S}
5158
setindex!(field.vals, v, (e - 1) * NQ + (q - 1) * NF + n)
5259
end
53-
Base.size(field::L2QuadratureField{T, NF, NQ, A, S}) where {T, NF, NQ, A, S} = (NF, NQ, length(field.vals) ÷ NF ÷ NQ)
60+
61+
function Base.size(field::L2QuadratureField{T, NF, NQ, A, S}) where {T, NF, NQ, A, S}
62+
if NF == 0
63+
(NF, NQ, length(field.vals) ÷ NQ)
64+
else
65+
(NF, NQ, length(field.vals) ÷ NF ÷ NQ)
66+
end
67+
end
5468

5569
function Base.view(field::L2QuadratureField, sym::Symbol)
5670
d = _sym_id_map(field, sym)

0 commit comments

Comments
 (0)