Skip to content

Commit b07536c

Browse files
authored
Merge pull request #175 from Cthonios/sparse-gpu-rework
reworking sparse implementation on GPU.
2 parents 2c1a33a + 00283ef commit b07536c

File tree

8 files changed

+132
-62
lines changed

8 files changed

+132
-62
lines changed

ext/FiniteElementContainersAMDGPUExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@ function AMDGPU.rocSPARSE.ROCSparseMatrixCSC(asm::SparseMatrixAssembler)
2121
#
2222

2323
# n_dofs = FiniteElementContainers.num_unknowns(asm.dof)
24-
n_dofs = length(asm.dof.unknown_dofs)
24+
if FiniteElementContainers._is_condensed(asm.dof)
25+
n_dofs = length(asm.dof)
26+
else
27+
n_dofs = length(asm.dof.unknown_dofs)
28+
end
29+
2530
return AMDGPU.rocSPARSE.ROCSparseMatrixCSC(
2631
asm.pattern.csccolptr,
2732
asm.pattern.cscrowval,

ext/FiniteElementContainersCUDAExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@ function CUDA.CUSPARSE.CuSparseMatrixCSC(asm::SparseMatrixAssembler)
1616
# they are doing the right thing
1717
# @assert length(asm.pattern.cscnzval) > 0 "Need to assemble the assembler once with SparseArrays.sparse!(assembler)"
1818
# @assert all(x -> x != zero(eltype(asm.pattern.cscnzval)), asm.pattern.cscnzval) "Need to assemble the assembler once with SparseArrays.sparse!(assembler)"
19-
n_dofs = FiniteElementContainers.num_unknowns(asm.dof)
19+
if FiniteElementContainers._is_condensed(asm.dof)
20+
n_dofs = length(asm.dof)
21+
else
22+
n_dofs = FiniteElementContainers.num_unknowns(asm.dof)
23+
end
24+
2025
return CUDA.CUSPARSE.CuSparseMatrixCSC(
2126
asm.pattern.csccolptr,
2227
asm.pattern.cscrowval,

src/DofManagers.jl

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -86,23 +86,12 @@ KA.@kernel function _extract_field_unknowns_kernel!(
8686
end
8787
# COV_EXCL_STOP
8888

89-
# COV_EXCL_START
90-
KA.@kernel function _extract_field_unknowns_kernel!(
91-
Uu::V,
92-
dof::DofManager{true, IT, IDs, Var},
93-
U::AbstractField
94-
) where {V <: AbstractVector{<:Number}, IT, IDs, Var}
95-
N = KA.@index(Global)
96-
@inbounds Uu[dof.unknown_dofs[N]] = U[dof.unknown_dofs[N]]
97-
end
98-
# COV_EXCL_STOP
99-
10089
function _extract_field_unknowns!(
10190
Uu::V,
102-
dof::DofManager,
91+
dof::DofManager{false, IT, IDs, Var},
10392
U::AbstractField,
10493
backend::KA.Backend
105-
) where V <: AbstractVector{<:Number}
94+
) where {V <: AbstractVector{<:Number}, IT, IDs, Var}
10695
kernel! = _extract_field_unknowns_kernel!(backend)
10796
kernel!(Uu, dof, U, ndrange = length(Uu))
10897
return nothing
@@ -118,21 +107,11 @@ function _extract_field_unknowns!(
118107
return nothing
119108
end
120109

121-
function _extract_field_unknowns!(
122-
Uu::V,
123-
dof::DofManager{true, IT, IDs, Var},
124-
U::AbstractField,
125-
::KA.CPU
126-
) where {V <: AbstractVector{<:Number}, IT, IDs, Var}
127-
@views Uu[dof.unknown_dofs] .= U[dof.unknown_dofs]
128-
return nothing
129-
end
130-
131110
function extract_field_unknowns!(
132111
Uu::V,
133-
dof::DofManager,
112+
dof::DofManager{false, IT, IDs, Var},
134113
U::AbstractField
135-
) where V <: AbstractVector{<:Number}
114+
) where {V <: AbstractVector{<:Number}, IT, IDs, Var}
136115
backend = KA.get_backend(dof)
137116
@assert KA.get_backend(U) == backend
138117
@assert KA.get_backend(Uu) == backend
@@ -168,7 +147,7 @@ KA.@kernel function _update_field_unknowns_kernel!(
168147
Uu::V
169148
) where {V <: AbstractVector{<:Number}, IT, IDs, Var}
170149
N = KA.@index(Global)
171-
@inbounds U[dof.unknown_dofs[N]] = Uu[N]
150+
@inbounds U.data[dof.unknown_dofs[N]] = Uu[N]
172151
end
173152
# COV_EXCL_STOP
174153

@@ -179,7 +158,7 @@ KA.@kernel function _update_field_unknowns_kernel!(
179158
Uu::V
180159
) where {V <: AbstractVector{<:Number}, IT, IDs, Var}
181160
N = KA.@index(Global)
182-
@inbounds U[dof.unknown_dofs[N]] = Uu[dof.unknown_dofs[N]]
161+
@inbounds U.data[dof.unknown_dofs[N]] = Uu[dof.unknown_dofs[N]]
183162
end
184163
# COV_EXCL_STOP
185164

@@ -190,7 +169,7 @@ function _update_field_unknowns!(
190169
backend::KA.Backend
191170
) where T <: AbstractVector{<:Number}
192171
kernel! = _update_field_unknowns_kernel!(backend)
193-
kernel!(U, dof, Uu, ndrange = length(Uu))
172+
kernel!(U, dof, Uu, ndrange = length(dof.unknown_dofs))
194173
return nothing
195174
end
196175

@@ -223,7 +202,13 @@ function update_field_unknowns!(
223202
backend = KA.get_backend(dof)
224203
@assert KA.get_backend(U) == backend
225204
@assert KA.get_backend(Uu) == backend
226-
# @assert length(dof.unknown_dofs) == length(Uu)
205+
206+
if _is_condensed(dof)
207+
@assert length(Uu) == length(U)
208+
else
209+
@assert length(Uu) == length(dof.unknown_dofs)
210+
end
211+
227212
_update_field_unknowns!(U, dof, Uu, backend)
228213
return nothing
229214
end

src/assemblers/Assemblers.jl

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ $(TYPEDSIGNATURES)
88
"""
99
KA.get_backend(asm::AbstractAssembler) = KA.get_backend(asm.dof)
1010

11-
function _adjust_matrix_action_entries_for_condensed!(
11+
function _adjust_matrix_action_entries_for_constraints!(
1212
Av, constraint_storage, v, ::KA.CPU
1313
# TODO do we need a penalty scale here as well?
1414
)
@@ -22,7 +22,25 @@ function _adjust_matrix_action_entries_for_condensed!(
2222
return nothing
2323
end
2424

25-
function _adjust_vector_entries_for_condensed!(b, constraint_storage, ::KA.CPU)
25+
KA.@kernel function _adjust_matrix_action_entries_for_constraints_kernel!(
26+
Av, constraint_storage, v
27+
)
28+
I = KA.@index(Global)
29+
# modify Av => (I - G) * Av + Gv
30+
@inbounds Av[I] = (1. - constraint_storage[I]) * Av[I] + constraint_storage[I] * v[I]
31+
end
32+
33+
function _adjust_matrix_action_entries_for_constraints!(
34+
Av, constraint_storage, v, backend::KA.Backend
35+
)
36+
@assert length(Av) == length(constraint_storage)
37+
@assert length(v) == length(constraint_storage)
38+
kernel! = _adjust_matrix_action_entries_for_constraints_kernel!(backend)
39+
kernel!(Av, constraint_storage, v, ndrange = length(Av))
40+
return nothing
41+
end
42+
43+
function _adjust_vector_entries_for_constraints!(b, constraint_storage, ::KA.CPU)
2644
@assert length(b) == length(constraint_storage)
2745
# modify b => (I - G) * b + (Gu - g)
2846
# but Gu = g, so we don't need that here
@@ -34,6 +52,19 @@ function _adjust_vector_entries_for_condensed!(b, constraint_storage, ::KA.CPU)
3452
return nothing
3553
end
3654

55+
KA.@kernel function _adjust_vector_entries_for_constraints_kernel(b, constraint_storage)
56+
I = KA.@index(Global)
57+
# modify b => (I - G) * b + (Gu - g)
58+
@inbounds b[I] = (1. - constraint_storage[I]) * b[I]
59+
end
60+
61+
function _adjust_vector_entries_for_constraints!(b, constraint_storage, backend::KA.Backend)
62+
@assert length(b) == length(constraint_storage)
63+
kernel! = _adjust_vector_entries_for_constraints_kernel(backend)
64+
kernel!(b, constraint_storage, ndrange = length(b))
65+
return nothing
66+
end
67+
3768
"""
3869
$(TYPEDSIGNATURES)
3970
Assembly method for an H1Field, e.g. internal force
@@ -162,30 +193,31 @@ function _quadrature_level_state(state::L2QuadratureField, q::Int, e::Int)
162193
return state_q
163194
end
164195

196+
197+
# function hvp(asm::AbstractAssembler)
198+
# if _is_condensed(asm.dof)
199+
# _adjust_matrix_action_entries_for_constraints!(
200+
# asm.stiffness_action_storage, asm.constraint_storage,
201+
# KA.get_backend(asm)
202+
# )
203+
# return asm.stiffness_action_storage.data
204+
# else
205+
# extract_field_unknowns!(
206+
# asm.stiffness_action_unknowns,
207+
# asm.dof,
208+
# asm.stiffness_action_storage
209+
# )
210+
# return asm.stiffness_action_unknowns
211+
# end
212+
# end
213+
214+
# new approach requiring access to the v that makes Hv
165215
"""
166216
$(TYPEDSIGNATURES)
167217
"""
168-
function hvp(asm::AbstractAssembler)
169-
if _is_condensed(asm.dof)
170-
_adjust_matrix_action_entries_for_condensed!(
171-
asm.stiffness_action_storage, asm.constraint_storage,
172-
KA.get_backend(asm)
173-
)
174-
return asm.stiffness_action_storage.data
175-
else
176-
extract_field_unknowns!(
177-
asm.stiffness_action_unknowns,
178-
asm.dof,
179-
asm.stiffness_action_storage
180-
)
181-
return asm.stiffness_action_unknowns
182-
end
183-
end
184-
185-
# new approach requiring access to the v that makes Hv
186218
function hvp(asm::AbstractAssembler, v)
187219
if _is_condensed(asm.dof)
188-
_adjust_matrix_action_entries_for_condensed!(
220+
_adjust_matrix_action_entries_for_constraints!(
189221
asm.stiffness_action_storage, asm.constraint_storage, v,
190222
KA.get_backend(asm)
191223
)
@@ -213,7 +245,7 @@ assumes assemble_vector! has already been called
213245
"""
214246
function residual(asm::AbstractAssembler)
215247
if _is_condensed(asm.dof)
216-
_adjust_vector_entries_for_condensed!(
248+
_adjust_vector_entries_for_constraints!(
217249
asm.residual_storage, asm.constraint_storage,
218250
KA.get_backend(asm)
219251
)

src/assemblers/SparseMatrixAssembler.jl

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ function _assemble_element!(
111111
end
112112

113113
# TODO this only work on CPU right now
114-
function _adjust_matrix_entries_for_condensed!(
114+
function _adjust_matrix_entries_for_constraints!(
115115
A::SparseMatrixCSC, constraint_storage, ::KA.CPU;
116116
penalty_scale = 1.e6
117117
)
@@ -145,6 +145,46 @@ function _adjust_matrix_entries_for_condensed!(
145145
return nothing
146146
end
147147

148+
KA.@kernel function _adjust_matrix_entries_for_constraints_kernel!(
149+
A, constraint_storage, trA;
150+
penalty_scale = 1.e6
151+
)
152+
J = KA.@index(Global)
153+
154+
penalty = penalty_scale * trA / size(A, 2)
155+
156+
# now modify A => (I - G) * A + G
157+
nz = nonzeros(A)
158+
rowval = rowvals(A)
159+
160+
col_start = A.colptr[J]
161+
col_end = A.colptr[J + 1] - 1
162+
for k in col_start:col_end
163+
# for (I - G) * A term
164+
nz[k] = (1. - constraint_storage[J]) * nz[k]
165+
166+
# for + G term
167+
if rowval[k] == J
168+
@inbounds nz[k] = nz[k] + penalty * constraint_storage[J]
169+
end
170+
end
171+
end
172+
173+
function _adjust_matrix_entries_for_constraints(
174+
A, constraint_storage, backend::KA.Backend
175+
)
176+
# first ensure things are the right size
177+
@assert size(A, 1) == size(A, 2)
178+
@assert length(constraint_storage) == size(A, 2)
179+
180+
# get trA ahead of time to save some allocations at kernel level
181+
trA = tr(A)
182+
183+
kernel! = _adjust_matrix_entries_for_constraints_kernel!(backend)
184+
kernel!(A, constraint_storage, trA, ndrange = size(A, 2))
185+
return nothing
186+
end
187+
148188
function constraint_matrix(assembler::SparseMatrixAssembler)
149189
return Diagonal(assembler.constraint_storage)
150190
end
@@ -153,7 +193,7 @@ function _mass(assembler::SparseMatrixAssembler, ::KA.CPU)
153193
M = SparseArrays.sparse!(assembler.pattern, assembler.mass_storage)
154194

155195
if _is_condensed(assembler.dof)
156-
_adjust_matrix_entries_for_condensed!(M, assembler.constraint_storage, KA.get_getbackend(assembler))
196+
_adjust_matrix_entries_for_constraints!(M, assembler.constraint_storage, KA.get_getbackend(assembler))
157197
end
158198

159199
return M
@@ -163,7 +203,7 @@ function _stiffness(assembler::SparseMatrixAssembler, ::KA.CPU)
163203
K = SparseArrays.sparse!(assembler.pattern, assembler.stiffness_storage)
164204

165205
if _is_condensed(assembler.dof)
166-
_adjust_matrix_entries_for_condensed!(K, assembler.constraint_storage, KA.get_backend(assembler))
206+
_adjust_matrix_entries_for_constraints!(K, assembler.constraint_storage, KA.get_backend(assembler))
167207
end
168208

169209
return K

src/assemblers/Vector.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ function assemble_vector!(
2626
backend
2727
)
2828
end
29+
30+
return nothing
2931
end
3032

3133
# CPU Implementation

test/TestAssemblers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,5 @@ include("poisson/TestPoissonCommon.jl")
3333
K = stiffness(asm)
3434
M = mass(asm)
3535
R = residual(asm)
36-
Mv = hvp(asm)
36+
Mv = hvp(asm, Vu)
3737
end

test/poisson/TestPoissonAMDGPU.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,15 @@ bc_func(_, _) = 0.
1717

1818
include("TestPoissonCommon.jl")
1919

20-
function poisson_amdgpu()
20+
function poisson_amdgpu(use_condensed)
2121
# do all setup on CPU
2222
# the mesh for instance is not gpu compatable
2323
mesh = UnstructuredMesh(mesh_file)
2424
V = FunctionSpace(mesh, H1Field, Lagrange)
2525
physics = Poisson()
2626
props = create_properties(physics)
2727
u = ScalarFunction(V, :u)
28-
dof = DofManager(u)
29-
asm = SparseMatrixAssembler(u)
28+
asm = SparseMatrixAssembler(u; use_condensed=use_condensed)
3029

3130
dbcs = DirichletBC[
3231
DirichletBC(:u, :sset_1, bc_func),
@@ -63,7 +62,9 @@ function poisson_amdgpu()
6362
display(solver.timer)
6463
end
6564

66-
@time poisson_amdgpu()
67-
@time poisson_amdgpu()
65+
@time poisson_amdgpu(false)
66+
@time poisson_amdgpu(false)
67+
@time poisson_amdgpu(true)
68+
@time poisson_amdgpu(true)
6869

6970
# @benchmark poisson_cuda()

0 commit comments

Comments
 (0)