Skip to content

Commit 21600f7

Browse files
authored
Merge pull request #103 from Cthonios/fields/abstract-methods
consolidating a lot of field methods and fixing a bug in GPU assemble…
2 parents a710974 + e4c460c commit 21600f7

File tree

6 files changed

+103
-317
lines changed

6 files changed

+103
-317
lines changed

src/assemblers/GPUGeneral.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function _assemble_block_residual!(
5151
kernel! = _assemble_block_residual_kernel!(backend)
5252
kernel!(
5353
assembler, physics, ref_fe,
54-
U, X, state_old, state_new,
54+
U, X, state_old, state_new, props,
5555
conns, block_id, ndrange=size(conns, 2)
5656
)
5757
return nothing
@@ -104,7 +104,7 @@ function _assemble_block_stiffness!(
104104
kernel! = _assemble_block_stiffness_kernel!(backend)
105105
kernel!(
106106
assembler, physics, ref_fe,
107-
U, X, state_old, state_new,
107+
U, X, state_old, state_new, props,
108108
conns, block_id, ndrange=size(conns, 2)
109109
)
110110
return nothing

src/fields/Fields.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,85 @@ $(TYPEDSIGNATURES)
3232
"""
3333
_sym_id_map(::AbstractField{T, N, NF, Vals, SymIDMap}, sym::Symbol) where {T, N, NF, Vals, SymIDMap} = getproperty(SymIDMap, sym)
3434

35+
# minimal abstractarray interface methods below
36+
37+
function Base.axes(field::AbstractField{T, 2, NF, V, S}) where {T, NF, V, S}
38+
NN = length(field) ÷ NF
39+
return (Base.OneTo(NF), Base.OneTo(NN))
40+
end
41+
42+
function Base.getindex(field::AbstractField, n::Int)
43+
return getindex(field.vals, n)
44+
end
45+
46+
function Base.getindex(field::AbstractField{T, 2, NF, V, S}, sym::Symbol) where {T, NF, V, S}
47+
d = _sym_id_map(field, sym)
48+
return field[d, :]
49+
end
50+
51+
function Base.getindex(field::AbstractField{T, 3, NF, V, S}, sym::Symbol) where {T, NF, V, S}
52+
d = _sym_id_map(field, sym)
53+
return field[d, :, :]
54+
end
55+
56+
function Base.getindex(field::AbstractField{T, 2, NF, V, S}, sym::Symbol, ::Colon) where {T, NF, V, S}
57+
d = _sym_id_map(field, sym)
58+
return field[d, :]
59+
end
60+
61+
function Base.getindex(field::AbstractField{T, 3, NF, V, S}, sym::Symbol, ::Colon, ::Colon) where {T, NF, V, S}
62+
d = _sym_id_map(field, sym)
63+
return field[d, :, :]
64+
end
65+
66+
function Base.getindex(field::AbstractField{T, 2, NF, V, S}, sym::Symbol, n::Int) where {T, NF, V, S}
67+
d = _sym_id_map(field, sym)
68+
return field[d, n]
69+
end
70+
71+
function Base.getindex(field::AbstractField{T, 3, NF, V, S}, sym::Symbol, m::Int, n::Int) where {T, NF, V, S}
72+
d = _sym_id_map(field, sym)
73+
return field[d, m, n]
74+
end
75+
76+
function Base.IndexStyle(::Type{<:AbstractField})
77+
return IndexLinear()
78+
end
79+
80+
function Base.setindex!(field::AbstractField{T, N, NF, V, S}, v::T, n::Int) where {T, N, NF, V, S}
81+
setindex!(field.vals, v, n)
82+
return nothing
83+
end
84+
85+
function Base.similar(field::AbstractField)
86+
vals = similar(field.vals)
87+
return typeof(field)(vals)
88+
end
89+
90+
function Base.size(field::AbstractField{T, 2, NF, V, SymIDMap}) where {T, NF, V <: DenseArray, SymIDMap}
91+
NN = length(field.vals) ÷ NF
92+
return (NF, NN)
93+
end
94+
95+
function Base.view(field::AbstractField{T, 2, NF, V, S}, sym::Symbol) where {T, NF, V, S}
96+
d = _sym_id_map(field, sym)
97+
return view(field, d, :)
98+
end
99+
100+
function Base.view(field::AbstractField{T, 2, NF, V, S}, sym::Symbol, ::Colon) where {T, NF, V, S}
101+
d = _sym_id_map(field, sym)
102+
return view(field, d, :)
103+
end
104+
105+
function Base.view(field::AbstractField{T, 3, NF, V, S}, sym::Symbol) where {T, NF, V, S}
106+
d = _sym_id_map(field, sym)
107+
return view(field, d, :, :)
108+
end
109+
110+
function Base.view(field::AbstractField{T, 3, NF, V, S}, sym::Symbol, ::Colon, ::Colon) where {T, NF, V, S}
111+
d = _sym_id_map(field, sym)
112+
return view(field, d, :, :)
113+
end
35114

36115
# actual implementations
37116
# include("ElementField.jl")

src/fields/H1Field.jl

Lines changed: 0 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -7,39 +7,9 @@ struct H1Field{T, NF, Vals <: AbstractArray{T, 1}, SymIDMap} <: AbstractField{T,
77
vals::Vals
88
end
99

10-
# constructors
11-
1210
"""
1311
$(TYPEDSIGNATURES)
14-
```H1Field{NF, NN}(vals::V) where {NF, NN, V <: AbstractArray{<:Number, 1}}```
15-
"""
16-
function H1Field{NF, NN}(vals::V, syms) where {NF, NN, V <: AbstractArray{<:Number, 1}}
17-
@assert length(vals) == NF * NN
18-
@assert length(syms) == NF
19-
nt = NamedTuple{syms}(1:length(syms))
20-
H1Field{eltype(vals), NF, typeof(vals), nt}(vals)
21-
end
22-
2312
"""
24-
$(TYPEDSIGNATURES)
25-
```H1Field{NF, NN}(vals::M) where {NF, NN, M <: AbstractArray{<:Number, 2}}```
26-
"""
27-
function H1Field{NF, NN}(vals::M, syms) where {NF, NN, M <: AbstractArray{<:Number, 2}}
28-
@assert size(vals) == (NF, NN)
29-
@assert length(syms) == NF
30-
new_vals = vec(vals)
31-
nt = NamedTuple{syms}(1:length(syms))
32-
H1Field{eltype(new_vals), NF, typeof(new_vals), nt}(new_vals)
33-
end
34-
35-
function H1Field{Tup, T}(::UndefInitializer, syms) where {Tup, T}
36-
NF, NN = Tup
37-
@assert length(syms) == NF
38-
nt = NamedTuple{syms}(1:length(syms))
39-
vals = Vector{T}(undef, NF * NN)
40-
return H1Field{T, NF, typeof(vals), nt}(vals)
41-
end
42-
4313
function H1Field(vals::M, syms) where M <: AbstractMatrix
4414
NF = size(vals, 1)
4515
@assert length(syms) == NF
@@ -48,59 +18,21 @@ function H1Field(vals::M, syms) where M <: AbstractMatrix
4818
return H1Field{eltype(vals), NF, typeof(vals), nt}(vals)
4919
end
5020

51-
"""
52-
$(TYPEDSIGNATURES)
53-
"""
54-
H1Field{Tup}(vals, syms) where Tup = H1Field{Tup[1], Tup[2]}(vals, syms)
55-
5621
# general base methods
57-
"""
58-
$(TYPEDSIGNATURES)
59-
"""
60-
function Base.similar(field::H1Field{T, NF, Vals, SymIDMap}) where {T, NF, Vals, SymIDMap}
61-
vals = similar(field.vals)
62-
return H1Field{T, NF, Vals, SymIDMap}(vals)
63-
end
6422

6523
function Base.zero(::Type{H1Field{T, NF, Vals, SymIDMap}}, n_nodes) where {T, NF, Vals, SymIDMap}
6624
vals = zeros(T, NF * n_nodes)
6725
return H1Field{T, NF, typeof(vals), SymIDMap}(vals)
6826
end
6927

7028
# abstract array interface
71-
Base.IndexStyle(::Type{<:H1Field}) = IndexLinear()
72-
73-
function Base.axes(field::H1Field{T, NF, V, SymIDMap}) where {T, NF, V <: DenseArray, SymIDMap}
74-
NN = length(field) ÷ NF
75-
return (Base.OneTo(NF), Base.OneTo(NN))
76-
end
77-
78-
Base.getindex(field::H1Field, n::Int) = getindex(field.vals, n)
7929

8030
function Base.getindex(field::H1Field, d::Int, n::Int)
8131
@assert d > 0 && d <= num_fields(field)
8232
@assert n > 0 && n <= num_nodes(field)
8333
getindex(field.vals, (n - 1) * num_fields(field) + d)
8434
end
8535

86-
function Base.getindex(field::H1Field, sym::Symbol, n::Int)
87-
d = _sym_id_map(field, sym)
88-
return getindex(field, d, n)
89-
end
90-
91-
function Base.getindex(field::H1Field, sym::Symbol)
92-
d = _sym_id_map(field, sym)
93-
return field[d, :]
94-
end
95-
96-
function Base.getindex(field::H1Field, sym::Symbol, ::Colon)
97-
d = _sym_id_map(field, sym)
98-
return field[d, :]
99-
end
100-
101-
102-
Base.setindex!(field::H1Field, v, n::Int) = setindex!(field.vals, v, n)
103-
10436
function Base.setindex!(field::H1Field{T, NF, V, SymIDMap}, v, d::Int, n::Int) where {T, NF, V <: DenseArray, SymIDMap}
10537
@assert d > 0 && d <= num_fields(field)
10638
@assert n > 0 && n <= num_nodes(field)
@@ -113,22 +45,6 @@ end
11345
# # d = _sym_id_map(field)
11446
# end
11547

116-
function Base.size(field::H1Field{T, NF, V, SymIDMap}) where {T, NF, V <: DenseArray, SymIDMap}
117-
NN = length(field.vals) ÷ NF
118-
return (NF, NN)
119-
end
120-
121-
# TODO
122-
function Base.view(field::H1Field, sym::Symbol)
123-
d = _sym_id_map(field, sym)
124-
return view(field, d, :)
125-
end
126-
127-
function Base.view(field::H1Field, sym::Symbol, ::Colon)
128-
d = _sym_id_map(field, sym)
129-
return view(field, d, :)
130-
end
131-
13248
# additional methods
13349
"""
13450
$(TYPEDSIGNATURES)

src/fields/L2ElementField.jl

Lines changed: 0 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -7,40 +7,9 @@ struct L2ElementField{T, NN, Vals <: AbstractArray{T, 1}, SymIDMap} <: AbstractF
77
vals::Vals
88
end
99

10-
# constructors
1110
"""
1211
$(TYPEDSIGNATURES)
1312
"""
14-
function L2ElementField{NF, NN}(vals::V, syms) where {NF, NN, V <: AbstractArray{<:Number, 1}}
15-
@assert length(vals) == NF * NN
16-
@assert length(syms) == NF
17-
nt = NamedTuple{syms}(1:length(syms))
18-
L2ElementField{eltype(vals), NF, typeof(vals), nt}(vals)
19-
end
20-
21-
"""
22-
$(TYPEDSIGNATURES)
23-
"""
24-
function L2ElementField{NN, NE}(vals::M, syms) where {NN, NE, M <: AbstractArray{<:Number, 2}}
25-
@assert size(vals) == (NN, NE)
26-
@assert length(syms) == NN
27-
new_vals = vec(vals)
28-
nt = NamedTuple{syms}(1:length(syms))
29-
L2ElementField{eltype(new_vals), NN, typeof(new_vals), nt}(new_vals)
30-
end
31-
32-
function L2ElementField{Tup, T}(::UndefInitializer, syms) where {Tup, T}
33-
NN, NE = Tup
34-
vals = Vector{T}(undef, NN * NE)
35-
nt = NamedTuple{syms}(1:length(syms))
36-
return L2ElementField{T, NN, typeof(vals), nt}(vals)
37-
end
38-
39-
"""
40-
$(TYPEDSIGNATURES)
41-
"""
42-
L2ElementField{Tup}(vals, syms) where Tup = L2ElementField{Tup[1], Tup[2]}(vals, syms)
43-
4413
function L2ElementField(vals::M, syms) where M <: AbstractMatrix
4514
NN = size(vals, 1)
4615
vals = vec(vals)
@@ -49,73 +18,26 @@ function L2ElementField(vals::M, syms) where M <: AbstractMatrix
4918
end
5019

5120
# general base methods
52-
"""
53-
$(TYPEDSIGNATURES)
54-
"""
55-
function Base.similar(field::L2ElementField{T, NN, Vals, SymIDMap}) where {T, NN, Vals, SymIDMap}
56-
vals = similar(field.vals)
57-
return L2ElementField{T, NN, Vals, SymIDMap}(vals)
58-
end
5921

6022
function Base.zero(::Type{L2ElementField{T, NN, Vals, SymIDMap}}, n_elements) where {T, NN, Vals, SymIDMap}
6123
vals = zeros(T, NN * n_elements)
6224
return L2ElementField{T, NN, typeof(vals), SymIDMap}(vals)
6325
end
6426

6527
# abstract array interface
66-
Base.IndexStyle(::Type{<:L2ElementField}) = IndexLinear()
67-
68-
function Base.axes(field::L2ElementField{T, NN, V, SymIDMap}) where {T, NN, V <: DenseArray, SymIDMap}
69-
NE = length(field) ÷ NN
70-
return (Base.OneTo(NN), Base.OneTo(NE))
71-
end
72-
73-
Base.getindex(field::L2ElementField, n::Int) = getindex(field.vals, n)
7428

7529
function Base.getindex(field::L2ElementField, d::Int, n::Int)
7630
@assert d > 0 && d <= num_fields(field)
7731
@assert n > 0 && n <= num_elements(field)
7832
getindex(field.vals, (n - 1) * num_fields(field) + d)
7933
end
8034

81-
function Base.getindex(field::L2ElementField, sym::Symbol, n::Int)
82-
d = _sym_id_map(field, sym)
83-
return getindex(field, d, n)
84-
end
85-
86-
function Base.getindex(field::L2ElementField, sym::Symbol)
87-
d = _sym_id_map(field, sym)
88-
return field[d, :]
89-
end
90-
91-
function Base.getindex(field::L2ElementField, sym::Symbol, ::Colon)
92-
d = _sym_id_map(field, sym)
93-
return field[d, :]
94-
end
95-
96-
Base.setindex!(field::L2ElementField, v, n::Int) = setindex!(field.vals, v, n)
97-
9835
function Base.setindex!(field::L2ElementField{T, NN, V, SymIDMap}, v, d::Int, n::Int) where {T, NN, V <: DenseArray, SymIDMap}
9936
@assert d > 0 && d <= num_fields(field)
10037
@assert n > 0 && n <= num_elements(field)
10138
setindex!(field.vals, v, (n - 1) * num_fields(field) + d)
10239
end
10340

104-
function Base.size(field::L2ElementField{T, NN, V, SymIDMap}) where {T, NN, V <: DenseArray, SymIDMap}
105-
NE = length(field.vals) ÷ NN
106-
return (NN, NE)
107-
end
108-
109-
function Base.view(field::L2ElementField, sym::Symbol)
110-
d = _sym_id_map(field, sym)
111-
return view(field, d, :)
112-
end
113-
114-
function Base.view(field::L2ElementField, sym::Symbol, ::Colon)
115-
d = _sym_id_map(field, sym)
116-
return view(field, d, :)
117-
end
118-
11941
"""
12042
$(TYPEDSIGNATURES)
12143
"""

0 commit comments

Comments
 (0)