Skip to content

Commit 2673552

Browse files
authored
Merge pull request #171 from Cthonios/dofmanager/condensed-bcs-initial-stab
initial stab at condensed bcs (maybe we need a better name) where the…
2 parents ee21851 + c4396f9 commit 2673552

File tree

6 files changed

+249
-82
lines changed

6 files changed

+249
-82
lines changed

ext/FiniteElementContainersAdaptExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,13 @@ function Adapt.adapt_structure(to, bc::FiniteElementContainers.NeumannBCContaine
8686
end
8787

8888
# DofManagers
89-
function Adapt.adapt_structure(to, dof::DofManager)
89+
function Adapt.adapt_structure(to, dof::DofManager{C, IT, IDs, Var}) where {C, IT, IDs, Var}
9090
dirichlet_dofs = adapt(to, dof.dirichlet_dofs)
9191
unknowns = adapt(to, dof.unknown_dofs)
9292
var = adapt(to, dof.var)
93-
return DofManager(dirichlet_dofs, unknowns, var)
93+
return DofManager{
94+
C, IT, typeof(dirichlet_dofs), typeof(var)
95+
}(dirichlet_dofs, unknowns, var)
9496
end
9597

9698
# Fields

src/DofManagers.jl

Lines changed: 94 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ abstract type AbstractDofManager{
44
} end
55

66
struct DofManager{
7+
Condensed, # boolean flag for whether or not to seperate data between unknowns and constrained dofs
8+
# when creating unknowns
79
IT,
810
IDs <: AbstractArray{IT, 1},
911
Var <: AbstractFunction
@@ -15,13 +17,19 @@ end
1517

1618
# method that initializes dof manager
1719
# with all dofs unknown
18-
function DofManager(var::AbstractFunction)
20+
function DofManager(var::AbstractFunction; use_condensed::Bool = false)
1921
dirichlet_dofs = zeros(Int, 0)
2022
unknown_dofs = 1:size(var.fspace.coords, 2) * length(names(var)) |> collect
21-
return DofManager(dirichlet_dofs, unknown_dofs, var)
23+
return DofManager{
24+
use_condensed,
25+
eltype(dirichlet_dofs),
26+
typeof(dirichlet_dofs),
27+
typeof(var)
28+
}(dirichlet_dofs, unknown_dofs, var)
2229
end
2330

2431
_field_type(dof::DofManager) = eval(typeof(dof.var.fspace.coords).name.name)
32+
_is_condensed(dof::DofManager{C, IT, IDs, V}) where {C, IT, IDs, V} = C
2533

2634
Base.length(dof::DofManager) = length(dof.dirichlet_dofs) + length(dof.unknown_dofs)
2735

@@ -57,29 +65,69 @@ function create_field(dof::DofManager)
5765
return _field_type(dof)(field)
5866
end
5967

60-
function create_unknowns(dof::DofManager)
68+
function create_unknowns(dof::DofManager{false, IT, IDs, Var}) where {IT, IDs, Var}
6169
backend = KA.get_backend(dof)
6270
return KA.zeros(backend, Float64, length(dof.unknown_dofs))
6371
end
6472

73+
function create_unknowns(dof::DofManager{true, IT, IDs, Var}) where {IT, IDs, Var}
74+
backend = KA.get_backend(dof)
75+
return KA.zeros(backend, Float64, length(dof))
76+
end
77+
6578
# COV_EXCL_START
66-
KA.@kernel function _extract_field_unknowns_kernel!(Uu::V, dof::DofManager, U::AbstractField) where V <: AbstractVector{<:Number}
79+
KA.@kernel function _extract_field_unknowns_kernel!(
80+
Uu::V,
81+
dof::DofManager{false, IT, IDs, Var},
82+
U::AbstractField
83+
) where {V <: AbstractVector{<:Number}, IT, IDs, Var}
6784
N = KA.@index(Global)
6885
@inbounds Uu[N] = U[dof.unknown_dofs[N]]
6986
end
7087
# COV_EXCL_STOP
7188

72-
function _extract_field_unknowns!(Uu::V, dof::DofManager, U::AbstractField, backend::KA.Backend) where V <: AbstractVector{<:Number}
89+
# COV_EXCL_START
90+
KA.@kernel function _extract_field_unknowns_kernel!(
91+
Uu::V,
92+
dof::DofManager{true, IT, IDs, Var},
93+
U::AbstractField
94+
) where {V <: AbstractVector{<:Number}, IT, IDs, Var}
95+
N = KA.@index(Global)
96+
@inbounds Uu[dof.unknown_dofs[N]] = U[dof.unknown_dofs[N]]
97+
end
98+
# COV_EXCL_STOP
99+
100+
function _extract_field_unknowns!(
101+
Uu::V,
102+
dof::DofManager,
103+
U::AbstractField,
104+
backend::KA.Backend
105+
) where V <: AbstractVector{<:Number}
73106
kernel! = _extract_field_unknowns_kernel!(backend)
74107
kernel!(Uu, dof, U, ndrange = length(Uu))
75108
return nothing
76109
end
77110

78-
function _extract_field_unknowns!(Uu::V, dof::DofManager, U::AbstractField, ::KA.CPU) where V <: AbstractVector{<:Number}
111+
function _extract_field_unknowns!(
112+
Uu::V,
113+
dof::DofManager{false, IT, IDs, Var},
114+
U::AbstractField,
115+
::KA.CPU
116+
) where {V <: AbstractVector{<:Number}, IT, IDs, Var}
79117
@views Uu .= U[dof.unknown_dofs]
80118
return nothing
81119
end
82120

121+
function _extract_field_unknowns!(
122+
Uu::V,
123+
dof::DofManager{true, IT, IDs, Var},
124+
U::AbstractField,
125+
::KA.CPU
126+
) where {V <: AbstractVector{<:Number}, IT, IDs, Var}
127+
@views Uu[dof.unknown_dofs] .= U[dof.unknown_dofs]
128+
return nothing
129+
end
130+
83131
function extract_field_unknowns!(
84132
Uu::V,
85133
dof::DofManager,
@@ -96,7 +144,7 @@ function function_space(dof::DofManager)
96144
return dof.var.fspace
97145
end
98146

99-
function update_dofs!(dof::DofManager, dirichlet_dofs)
147+
function update_dofs!(dof::DofManager, dirichlet_dofs::V) where V <: AbstractArray{<:Integer, 1}
100148
ND, NI = size(dof)
101149
Base.resize!(dof.dirichlet_dofs, length(dirichlet_dofs))
102150
Base.resize!(dof.unknown_dofs, ND * NI)
@@ -114,24 +162,59 @@ function update_dofs!(dof::DofManager, dirichlet_dofs)
114162
end
115163

116164
# COV_EXCL_START
117-
KA.@kernel function _update_field_unknowns_kernel!(U::AbstractField, dof::DofManager, Uu::V) where V <: AbstractVector{<:Number}
165+
KA.@kernel function _update_field_unknowns_kernel!(
166+
U::AbstractField,
167+
dof::DofManager{false, IT, IDs, Var},
168+
Uu::V
169+
) where {V <: AbstractVector{<:Number}, IT, IDs, Var}
118170
N = KA.@index(Global)
119171
@inbounds U[dof.unknown_dofs[N]] = Uu[N]
120172
end
121173
# COV_EXCL_STOP
174+
175+
# COV_EXCL_START
176+
KA.@kernel function _update_field_unknowns_kernel!(
177+
U::AbstractField,
178+
dof::DofManager{true, IT, IDs, Var},
179+
Uu::V
180+
) where {V <: AbstractVector{<:Number}, IT, IDs, Var}
181+
N = KA.@index(Global)
182+
@inbounds U[dof.unknown_dofs[N]] = Uu[dof.unknown_dofs[N]]
183+
end
184+
# COV_EXCL_STOP
122185

123-
function _update_field_unknowns!(U::AbstractField, dof::DofManager, Uu::T, backend::KA.Backend) where T <: AbstractVector{<:Number}
186+
function _update_field_unknowns!(
187+
U::AbstractField,
188+
dof::DofManager,
189+
Uu::T,
190+
backend::KA.Backend
191+
) where T <: AbstractVector{<:Number}
124192
kernel! = _update_field_unknowns_kernel!(backend)
125193
kernel!(U, dof, Uu, ndrange = length(Uu))
126194
return nothing
127195
end
128196

129197
# Need a seperate CPU method since CPU is basically busted in KA
130-
function _update_field_unknowns!(U::AbstractField, dof::DofManager, Uu::T, ::KA.CPU) where T <: AbstractVector{<:Number}
198+
function _update_field_unknowns!(
199+
U::AbstractField,
200+
dof::DofManager{false, IT, IDs, Var},
201+
Uu::T,
202+
::KA.CPU
203+
) where {T <: AbstractVector{<:Number}, IT, IDs, Var}
131204
U[dof.unknown_dofs] .= Uu
132205
return nothing
133206
end
134207

208+
function _update_field_unknowns!(
209+
U::AbstractField,
210+
dof::DofManager{true, IT, IDs, Var},
211+
Uu::T,
212+
::KA.CPU
213+
) where {T <: AbstractVector{<:Number}, IT, IDs, Var}
214+
@views U[dof.unknown_dofs] .= Uu[dof.unknown_dofs]
215+
return nothing
216+
end
217+
135218
function update_field_unknowns!(
136219
U::AbstractField,
137220
dof::DofManager,
@@ -140,6 +223,7 @@ function update_field_unknowns!(
140223
backend = KA.get_backend(dof)
141224
@assert KA.get_backend(U) == backend
142225
@assert KA.get_backend(Uu) == backend
226+
# @assert length(dof.unknown_dofs) == length(Uu)
143227
_update_field_unknowns!(U, dof, Uu, backend)
144228
return nothing
145229
end

src/assemblers/SparseMatrixAssembler.jl

Lines changed: 26 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,8 @@ function SparseMatrixAssembler(dof::DofManager)
7979
)
8080
end
8181

82-
# function SparseMatrixAssembler(::Type{<:H1Field}, vars...)
83-
# dof = DofManager(vars...)
84-
# return SparseMatrixAssembler(dof, H1Field)
85-
# end
86-
87-
function SparseMatrixAssembler(var::AbstractFunction)
88-
dof = DofManager(var)
82+
function SparseMatrixAssembler(var::AbstractFunction; use_condensed::Bool = false)
83+
dof = DofManager(var; use_condensed=use_condensed)
8984
return SparseMatrixAssembler(dof)
9085
end
9186

@@ -119,15 +114,27 @@ end
119114
$(TYPEDSIGNATURES)
120115
TODO add symbol to interface
121116
"""
122-
function SparseArrays.sparse!(assembler::SparseMatrixAssembler, ::Val{:stiffness})
117+
function SparseArrays.sparse!(
118+
assembler::SparseMatrixAssembler, ::Val{:stiffness};
119+
unknowns_only=true
120+
)
123121
pattern = assembler.pattern
124122
storage = assembler.stiffness_storage
125-
return @views SparseArrays.sparse!(
126-
pattern.Is, pattern.Js, storage[assembler.pattern.unknown_dofs],
127-
length(pattern.klasttouch), length(pattern.klasttouch), +, pattern.klasttouch,
128-
pattern.csrrowptr, pattern.csrcolval, pattern.csrnzval,
129-
pattern.csccolptr, pattern.cscrowval, pattern.cscnzval
130-
)
123+
if unknowns_only
124+
return @views SparseArrays.sparse!(
125+
pattern.Is, pattern.Js, storage[assembler.pattern.unknown_dofs],
126+
length(pattern.klasttouch), length(pattern.klasttouch), +, pattern.klasttouch,
127+
pattern.csrrowptr, pattern.csrcolval, pattern.csrnzval,
128+
pattern.csccolptr, pattern.cscrowval, pattern.cscnzval
129+
)
130+
# else
131+
# return SparseArrays.sparse!(
132+
# pattern.Is, pattern.Js, storage[assembler.pattern.unknown_dofs],
133+
# length(pattern.klasttouch), length(pattern.klasttouch), +, pattern.klasttouch,
134+
# pattern.csrrowptr, pattern.csrcolval, pattern.csrnzval,
135+
# pattern.csccolptr, pattern.cscrowval, pattern.cscnzval
136+
# )
137+
end
131138
end
132139

133140
"""
@@ -170,15 +177,9 @@ end
170177
# the residual and stiffness appropriately without having to reshape, Is, Js, etc.
171178
# when we want to change BCs which is slow
172179

173-
function update_dofs!(assembler::SparseMatrixAssembler, dirichlet_bcs; use_condensed=false)
174-
# vars = assembler.dof.H1_vars
175-
var = assembler.dof.var
180+
function update_dofs!(assembler::SparseMatrixAssembler, dirichlet_bcs)
181+
use_condensed = _is_condensed(assembler.dof)
176182

177-
# if length(vars) != 1
178-
# @assert false "multiple fspace not supported yet"
179-
# end
180-
181-
# dirichlet_dofs = dirichlet_bcs.bookkeeping.dofs
182183
if length(dirichlet_bcs) > 0
183184
dirichlet_dofs = mapreduce(x -> x.bookkeeping.dofs, vcat, dirichlet_bcs)
184185
dirichlet_dofs = unique(sort(dirichlet_dofs))
@@ -203,8 +204,8 @@ end
203204
# when we want to change BCs which is slow
204205

205206
function _update_dofs_condensed!(assembler::SparseMatrixAssembler)
206-
assembler.constraint_storage[assembler.dof.unknown_dofs] .= 1.
207-
assembler.constraint_storage[assembler.dof.dirichlet_bcs] .= 0.
207+
assembler.constraint_storage[assembler.dof.unknown_dofs] .= 0.
208+
assembler.constraint_storage[assembler.dof.dirichlet_dofs] .= 1.
208209
return nothing
209210
end
210211

@@ -214,55 +215,10 @@ end
214215
function _update_dofs!(assembler::SparseMatrixAssembler, dirichlet_dofs::T) where T <: AbstractArray{<:Integer, 1}
215216

216217
# resize the resiual unkowns
217-
# n_total_H1_dofs = num_nodes(assembler.dof) * num_dofs_per_node(assembler.dof)
218-
# n_unknown_dofs = length(assembler.dof.unknown_dofs)
219-
# resize!(assembler.residual_unknowns, length(assembler.dof.H1_unknown_dofs))
220-
# resize!(assembler.stiffness_action_unknowns, length(assembler.dof.H1_unknown_dofs))
221218
resize!(assembler.residual_unknowns, length(assembler.dof.unknown_dofs))
222219
resize!(assembler.stiffness_action_unknowns, length(assembler.dof.unknown_dofs))
223220

224-
# n_total_dofs = length(assembler.dof) - length(dirichlet_dofs)
225-
n_total_dofs = length(assembler.dof) - length(dirichlet_dofs)
226-
# n_total_dofs = n_unknown_dofs - length(dirichlet_dofs)
227-
228-
# TODO change to a good sizehint!
229-
resize!(assembler.pattern.Is, 0)
230-
resize!(assembler.pattern.Js, 0)
231-
resize!(assembler.pattern.unknown_dofs, 0)
232-
233-
# ND, NN = num_dofs_per_node(assembler.dof), num_nodes(assembler.dof)
234-
ND, NN = size(assembler.dof)
235-
# ids = reshape(1:length(assembler.dof), ND, NN)
236-
ids = reshape(1:length(assembler.dof), ND, NN)
237-
238-
# TODO
239-
# vars = assembler.dof.H1_vars
240-
# fspace = vars[1].fspace
241-
fspace = function_space(assembler.dof)
242-
243-
n = 1
244-
for conns in values(fspace.elem_conns)
245-
dof_conns = @views reshape(ids[:, conns], ND * size(conns, 1), size(conns, 2))
246-
247-
for e in 1:size(conns, 2)
248-
conn = @views dof_conns[:, e]
249-
for temp in Iterators.product(conn, conn)
250-
if insorted(temp[1], dirichlet_dofs) || insorted(temp[2], dirichlet_dofs)
251-
# really do nothing here
252-
else
253-
push!(assembler.pattern.Is, temp[1] - count(x -> x < temp[1], dirichlet_dofs))
254-
push!(assembler.pattern.Js, temp[2] - count(x -> x < temp[2], dirichlet_dofs))
255-
push!(assembler.pattern.unknown_dofs, n)
256-
end
257-
n += 1
258-
end
259-
end
260-
end
261-
262-
resize!(assembler.pattern.klasttouch, n_total_dofs)
263-
resize!(assembler.pattern.csrrowptr, n_total_dofs + 1)
264-
resize!(assembler.pattern.csrcolval, length(assembler.pattern.Is))
265-
resize!(assembler.pattern.csrnzval, length(assembler.pattern.Is))
221+
_update_dofs!(assembler.pattern, assembler.dof, dirichlet_dofs)
266222

267223
return nothing
268224
end

src/assemblers/SparsityPattern.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,51 @@ function SparsityPattern(dof::DofManager)
109109
end
110110

111111
num_entries(s::SparsityPattern) = length(s.Is)
112+
113+
# NOTE this methods assumes that dof is up to date
114+
# NOTE this method also only resizes unknown_dofs
115+
# in the pattern object, that means that things
116+
# like Is, Js, etc. need to be viewed into or sliced
117+
function _update_dofs!(pattern::SparsityPattern, dof, dirichlet_dofs)
118+
n_total_dofs = length(dof) - length(dirichlet_dofs)
119+
120+
# remove me
121+
resize!(pattern.Is, 0)
122+
resize!(pattern.Js, 0)
123+
# end remove me
124+
resize!(pattern.unknown_dofs, 0)
125+
ND, NN = size(dof)
126+
ids = reshape(1:length(dof), ND, NN)
127+
fspace = function_space(dof)
128+
129+
n = 1
130+
for conns in values(fspace.elem_conns)
131+
dof_conns = @views reshape(ids[:, conns], ND * size(conns, 1), size(conns, 2))
132+
133+
for e in 1:size(conns, 2)
134+
conn = @views dof_conns[:, e]
135+
for temp in Iterators.product(conn, conn)
136+
if insorted(temp[1], dirichlet_dofs) || insorted(temp[2], dirichlet_dofs)
137+
# really do nothing here
138+
else
139+
# remove me
140+
push!(pattern.Is, temp[1] - count(x -> x < temp[1], dirichlet_dofs))
141+
push!(pattern.Js, temp[2] - count(x -> x < temp[2], dirichlet_dofs))
142+
# end remove me
143+
push!(pattern.unknown_dofs, n)
144+
end
145+
n += 1
146+
end
147+
end
148+
end
149+
150+
resize!(pattern.klasttouch, n_total_dofs)
151+
resize!(pattern.csrrowptr, n_total_dofs + 1)
152+
# TODO Not sure about below 2 sizes
153+
# resize!(assembler.pattern.csrcolval, length(assembler.pattern.Is))
154+
# resize!(assembler.pattern.csrnzval, length(assembler.pattern.Is))
155+
resize!(pattern.csrcolval, length(pattern.Is))
156+
resize!(pattern.csrnzval, length(pattern.Is))
157+
158+
return nothing
159+
end

0 commit comments

Comments
 (0)