Merged
4 changes: 3 additions & 1 deletion .gitignore
@@ -8,4 +8,6 @@
*.e
*.swp
*.vtu

*.g.nem
*.g.pex
*.g.*.*
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "FiniteElementContainers"
uuid = "d08262e4-672f-4e7f-a976-f2cea5767631"
version = "0.9.1"
version = "0.9.2"
authors = ["Craig M. Hamel <[email protected]> and contributors"]

[deps]
48 changes: 48 additions & 0 deletions examples/mpi-example/example.jl
@@ -0,0 +1,48 @@
import FiniteElementContainers:
communication_graph,
cpu,
decompose_mesh,
distribute,
global_colorings,
DistributedDevice,
getdata,
num_dofs_per_rank,
ParVector,
rank_devices,
scatter_ghosts!
using Exodus
using KernelAbstractions
using MPI

backend = CPU()
num_dofs = 1
num_ranks = 4
ranks = collect(1:num_ranks)
mesh_file = Base.source_dir() * "/square.g"

decompose_mesh(mesh_file, num_ranks)
global_dofs_to_colors = global_colorings(mesh_file, num_dofs, num_ranks)

comm = MPI.COMM_WORLD
# comm = nothing

ranks = distribute(ranks, comm)
# n_dofs_per_rank = distribute(num_dofs_per_rank(global_dofs_to_colors), comm)

# shards = shard_indices(global_dofs_to_colors, )
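# build one communication graph per rank from its piece of the decomposed mesh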
comm_graphs = map(ranks) do rank
mesh_file = Base.source_dir() * "/square.g"
mesh_file = mesh_file * ".$num_ranks" * ".$(lpad(rank - 1, Exodus.exodus_pad(num_ranks |> Int32), '0'))"
comm_graph = communication_graph(mesh_file, global_dofs_to_colors, num_dofs, rank)
end

par_vec = ParVector(comm_graphs)
parts = map(par_vec.parts, ranks) do part, rank
part .= rank
part
end
par_vec = ParVector(parts)
par_vec_2 = scatter_ghosts!(par_vec)
MPI.Barrier(comm)
# par_vec_2 = par_vec_2.parts
print("part on $(MPI.Comm_rank(comm) + 1) is $par_vec_2\n")
Binary file added examples/mpi-example/square.g
Binary file not shown.
2 changes: 1 addition & 1 deletion ext/FiniteElementContainersCUDAExt.jl
@@ -19,7 +19,7 @@ function CUDA.CUSPARSE.CuSparseMatrixCSC(asm::SparseMatrixAssembler)
if FiniteElementContainers._is_condensed(asm.dof)
n_dofs = length(asm.dof)
else
n_dofs = FiniteElementContainers.num_unknowns(asm.dof)
n_dofs = length(asm.dof.unknown_dofs)
end

return CUDA.CUSPARSE.CuSparseMatrixCSC(
35 changes: 35 additions & 0 deletions src/parallel/Communication.jl
@@ -0,0 +1,35 @@
# The assumptions of this struct are as follows:
# the edge points from the current rank
# to the neighbor stored in `rank`.
# This struct can be used to build
# send/receive data exchangers.

struct CommunicationGraphEdge{
TV <: AbstractVector,
IV <: AbstractVector{<:Integer}
}
data_recv::TV
data_send::TV
indices::IV
is_owned_recv::IV
is_owned_send::IV
rank::Int32
end

# TODO: CPU-only method currently
function setdatasend!(edge::CommunicationGraphEdge, part)
for n in axes(edge.indices, 1)
index = edge.indices[n]
edge.data_send[n] = part[index]
end
end

struct CommunicationGraph{
TV <: AbstractVector,
IV <: AbstractVector{<:Integer}
}
edges::Vector{CommunicationGraphEdge{TV, IV}}
global_to_color::IV
n_local::Int
n_owned::Int
end
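# Illustrative sketch (not part of this file): once `setdatasend!` has packed
# `data_send` for each edge, every edge of a CommunicationGraph can drive a
# nonblocking exchange with its neighbor, mirroring the MPI pattern used in
# Interface.jl; `comm` and `part` are assumed to be defined by the caller:
#
#   for edge in graph.edges
#     setdatasend!(edge, part)
#     recv_req = MPI.Irecv!(edge.data_recv, comm; source=edge.rank - 1)
#     send_req = MPI.Isend(edge.data_send, comm; dest=edge.rank - 1)
#     MPI.Waitall([recv_req, send_req])
#   end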
11 changes: 11 additions & 0 deletions src/parallel/DistributedDevice.jl
@@ -0,0 +1,11 @@
abstract type AbstractDevice end

struct DistributedDevice <: AbstractDevice
backend::KA.Backend
num_ranks::Int32
rank::Int32
# maybe other settings later, e.g.
# threads on a local rank
# gpu settings
# etc.
end
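# For illustration (values are hypothetical): rank 1 of a 4-rank CPU run would
# be described by
#   DistributedDevice(KA.CPU(), Int32(4), Int32(1))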
146 changes: 146 additions & 0 deletions src/parallel/Interface.jl
@@ -0,0 +1,146 @@
# TODO eventually don't specialize to MPI
# right now this is just for gathering to the element level
function communication_graph(
file_name::String,
global_dofs_to_colors,
n_dofs, rank
)
# open mesh and read id maps
mesh = UnstructuredMesh(file_name)
node_id_map = mesh.node_id_map
cmaps = node_cmaps(mesh.mesh_obj, rank)

# getting local dofs
n_nodes = length(node_id_map)
local_dofs = 1:n_dofs * n_nodes
local_dofs = reshape(local_dofs, n_dofs, n_nodes)

# find global dofs owned by this rank
n_nodes_total = length(global_dofs_to_colors) ÷ n_dofs
global_dofs = 1:length(global_dofs_to_colors)
global_dofs = reshape(global_dofs, n_dofs, n_nodes_total)
global_dofs_in_rank = findall(x -> x == rank, global_dofs_to_colors)

# setup graph edges
graph_edges = Vector{
CommunicationGraphEdge{Vector{Float64}, Vector{Int64}}
}(undef, length(cmaps))

for (n, cmap) in enumerate(cmaps)
# check to make sure proc is unique
procs = unique(cmap.proc_ids)
@assert length(procs) == 1
dest = procs[1]
local_cmap_nodes = cmap.node_ids
global_cmap_nodes = node_id_map[cmap.node_ids]

local_cmap_dofs = local_dofs[:, local_cmap_nodes] |> vec
global_cmap_dofs = global_dofs[:, global_cmap_nodes] |> vec

# figure out which parts of the comm maps are owned by this rank
is_owned_send = Vector{Int}(undef, length(local_cmap_dofs))
for (m, dof) in enumerate(global_cmap_dofs)
if dof in global_dofs_in_rank
is_owned_send[m] = 1
else
is_owned_send[m] = 0
end
end

# TODO make vector type generic, but floats are likely the main thing
data_send = zeros(Float64, length(local_cmap_dofs))
data_recv = similar(data_send)
# graph_edges[n] = CommunicationGraphEdge(data_recv, data_send, local_cmap_dofs, is_owned, dest)
is_owned_recv = similar(is_owned_send)

# TODO move to interface so it's not MPI-specific
comm = MPI.COMM_WORLD
recv_req = MPI.Irecv!(is_owned_recv, comm; source=dest - 1)
send_req = MPI.Isend(is_owned_send, comm; dest=dest - 1)
MPI.Waitall([recv_req, send_req])

graph_edges[n] = CommunicationGraphEdge(
data_recv, data_send,
local_cmap_dofs,
is_owned_recv, is_owned_send,
dest
)
end

return CommunicationGraph(graph_edges, global_dofs_to_colors, length(node_id_map), length(global_dofs_in_rank))
end

# serial/debug case
function distribute(data)
return data
end

# serial/debug case
function distribute(data, ::Nothing)
return data
end

function distribute(data, comm::MPI.Comm)
rank = MPI.Comm_rank(comm) + 1
return MPIValue(comm, data[rank])
end

function num_dofs_per_rank(global_dofs_to_colors)
num_ranks = length(unique(global_dofs_to_colors))
num_dofs = Vector{Int}(undef, num_ranks)
for n in axes(num_dofs, 1)
num_dofs[n] = length(filter(x -> x == n, global_dofs_to_colors))
end
return num_dofs
end
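# For example (hypothetical coloring of six dofs across three ranks):
#   num_dofs_per_rank([1, 1, 2, 2, 2, 3]) == [2, 3, 1]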

function rank_devices(
num_ranks, ::Val{:debug},
backend::KA.Backend
)
if !MPI.Initialized()
# @assert false "You're trying to use MPI from debug mode"
@warn "Running in debug mode. Make sure you're not using mpiexecjl"
end
map(x -> DistributedDevice(backend, num_ranks, x), 1:num_ranks)
end

function rank_devices(
num_ranks, ::Val{:mpi},
backend::KA.Backend
)
# initialize MPI here so users don't need to do it themselves
if !MPI.Initialized()
MPI.Init()
end

# MPI indexing stuff
comm = MPI.COMM_WORLD
@assert MPI.Comm_size(comm) == num_ranks
rank = MPI.Comm_rank(comm) + 1
device = DistributedDevice(backend, num_ranks, rank)

# return MPIVector(comm, device)
return MPIValue(device)
end

function rank_devices(
num_ranks, par_type::Symbol;
backend = KA.CPU()
)
val = Val{par_type}()
return rank_devices(num_ranks, val, backend)
end
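# Usage sketch (not from the original source): for a serial debug run,
#   devices = rank_devices(4, :debug)  # Vector of 4 DistributedDevices
# and under MPI, launched on 4 ranks,
#   device = rank_devices(4, :mpi)     # MPIValue wrapping this rank's device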

# function shard_indices(file_name, global_dofs_to_colors, rank)
# indices = findall(x -> x == rank, global_dofs_to_colors)
# mesh = UnstructuredMesh(file_name)
# node_id_map = mesh.node_id_map
# cmaps = node_cmaps(mesh.mesh_obj, rank)

# if length(indices) == length(node_id_map)
# # special case that this rank owns all its dofs
# else

# end
# end
47 changes: 47 additions & 0 deletions src/parallel/MPI.jl
@@ -0,0 +1,47 @@
abstract type AbstractMPIData{T} <: AbstractArray{T, 1} end
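# AbstractMPIData subtypes expose rank-local data through a minimal array
# interface; note that indexing returns the local data regardless of the
# index passed.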

Base.IndexStyle(::Type{<:AbstractMPIData}) = Base.IndexLinear()
Base.getindex(a::AbstractMPIData, n) = a.data[1]
Base.size(a::AbstractMPIData) = size(a.data)

struct MPIValue{T} <: AbstractMPIData{T}
comm::MPI.Comm
data::T
end

function MPIValue(data)
comm = MPI.COMM_WORLD
return MPIValue(comm, data)
end

function Base.map(f, v::MPIValue)
new_data = f(v.data)
return MPIValue(v.comm, new_data)
end

function Base.map(f, v1::MPIValue, v2::MPIValue)
new_data = f(v1.data, v2.data)
return MPIValue(v1.comm, new_data)
end
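# For example (hypothetical values), with v = MPIValue(comm, 2.0) on every
# rank, map(x -> 10x, v) yields an MPIValue wrapping 20.0 on each rank.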

function Base.show(io::IO, v::MPIValue)
rank = MPI.Comm_rank(v.comm) + 1
println(io, "MPIValue on rank $rank")
println(io, v.data)
end

function getdata(v::MPIValue)
return v.data
end

function setdata(v::MPIValue{T}, data::T) where T
if ismutable(data)
if isa(data, AbstractArray)
copyto!(v.data, data)
else
@assert false "Unsupported type setting in setdata in MPIValue"
end
else
@assert false "Can't set immutable data in MPIValue"
end
end