Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/in_memory_handler.o \
$(PKGROOT)/src/collective/loop.o \
$(PKGROOT)/src/collective/topo.o \
$(PKGROOT)/src/collective/socket.o \
$(PKGROOT)/src/common/charconv.o \
$(PKGROOT)/src/common/column_matrix.o \
Expand Down
1 change: 1 addition & 0 deletions R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/in_memory_handler.o \
$(PKGROOT)/src/collective/loop.o \
$(PKGROOT)/src/collective/topo.o \
$(PKGROOT)/src/collective/socket.o \
$(PKGROOT)/src/common/charconv.o \
$(PKGROOT)/src/common/column_matrix.o \
Expand Down
2 changes: 1 addition & 1 deletion include/xgboost/collective/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ class TCPSocket {
/**
* @brief Listen to incoming requests. Should be called after bind.
*/
[[nodiscard]] Result Listen(std::int32_t backlog = 16) {
[[nodiscard]] Result Listen(std::int32_t backlog = 512) {
if (listen(handle_, backlog) != 0) {
return system::FailWithCode("Failed to listen.");
}
Expand Down
1 change: 1 addition & 0 deletions src/collective/allgather.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "../common/type.h" // for EraseType
#include "comm.h" // for Comm, Channel
#include "comm_group.h" // for CommGroup
#include "topo.h" // for BootstrapNext, BootstrapPrev
#include "xgboost/collective/result.h" // for Result
#include "xgboost/linalg.h" // for MakeVec
#include "xgboost/span.h" // for Span
Expand Down
79 changes: 33 additions & 46 deletions src/collective/broadcast.cc
Original file line number Diff line number Diff line change
@@ -1,51 +1,19 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include "broadcast.h"

#include <cmath> // for ceil, log2
#include <cstdint> // for int32_t, int8_t
#include <utility> // for move

#include "../common/bitfield.h" // for TrailingZeroBits, RBitField32

#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
#include "topo.h" // for ParentRank

namespace xgboost::collective::cpu_impl {
namespace {
std::int32_t ShiftedParentRank(std::int32_t shifted_rank, std::int32_t depth) {
std::uint32_t mask{std::uint32_t{0} - 1}; // Oxff...
RBitField32 maskbits{common::Span<std::uint32_t>{&mask, 1}};
RBitField32 rankbits{
common::Span<std::uint32_t>{reinterpret_cast<std::uint32_t*>(&shifted_rank), 1}};
// prepare for counting trailing zeros.
for (std::int32_t i = 0; i < depth + 1; ++i) {
if (rankbits.Check(i)) {
maskbits.Set(i);
} else {
maskbits.Clear(i);
}
}

CHECK_NE(mask, 0);
auto k = TrailingZeroBits(mask);
auto shifted_parent = shifted_rank - (1 << k);
return shifted_parent;
}

/**
 * @brief Re-index a rank so that `root` maps to 0.
 *
 * @param rank  Original rank.
 * @param world Number of workers, used as the modulus.
 * @param root  Rank that should become 0.
 *
 * @return `(rank + world - root) % world`.
 */
std::int32_t ShiftLeft(std::int32_t rank, std::int32_t world, std::int32_t root) {
  return (rank + world - root) % world;
}
/**
 * @brief Map a root-shifted rank back to the original rank.
 *
 * @param rank  Shifted rank.
 * @param world Number of workers, used as the modulus.
 * @param root  Rank that was shifted to 0.
 *
 * @return `(rank + root) % world`.
 */
std::int32_t ShiftRight(std::int32_t rank, std::int32_t world, std::int32_t root) {
  return (rank + root) % world;
}
} // namespace

Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root) {
// Binomial tree broadcast
// * Wiki
Expand All @@ -56,28 +24,47 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t
auto rank = comm.Rank();
auto world = comm.World();

// shift root to rank 0
auto shifted_rank = ShiftLeft(rank, world, root);
// Send data to the root to preserve the topology. The alternative is to shift the rank,
// but that requires an all-to-all connection.
//
// Most uses of broadcasting in XGBoost are for short messages, so this should be
// fine. Otherwise, we can implement a linear pipeline broadcast.
if (root != 0) {
auto rc = Success() << [&] {
return (rank == 0) ? comm.Chan(root)->RecvAll(data) : Success();
} << [&] {
return (rank == root) ? comm.Chan(0)->SendAll(data) : Success();
} << [&] {
return comm.Block();
};
if (!rc.OK()) {
return Fail("Broadcast failed to send data to root.", std::move(rc));
}
root = 0;
}

std::int32_t depth = std::ceil(std::log2(static_cast<double>(world))) - 1;

if (shifted_rank != 0) { // not root
auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root);
auto rc = Success() << [&] { return comm.Chan(parent)->RecvAll(data); }
<< [&] { return comm.Chan(parent)->Block(); };
if (rank != 0) { // not root
auto parent = ParentRank(rank, depth);
auto rc = Success() << [&] {
return comm.Chan(parent)->RecvAll(data);
} << [&] {
return comm.Chan(parent)->Block();
};
if (!rc.OK()) {
return Fail("broadcast failed.", std::move(rc));
return Fail("Broadcast failed to send data to parent.", std::move(rc));
}
}

for (std::int32_t i = depth; i >= 0; --i) {
CHECK_GE((i + 1), 0); // weird clang-tidy error that i might be negative
if (shifted_rank % (1 << (i + 1)) == 0 && shifted_rank + (1 << i) < world) {
auto sft_peer = shifted_rank + (1 << i);
auto peer = ShiftRight(sft_peer, world, root);
if (rank % (1 << (i + 1)) == 0 && rank + (1 << i) < world) {
auto peer = rank + (1 << i);
CHECK_NE(peer, root);
auto rc = comm.Chan(peer)->SendAll(data);
if (!rc.OK()) {
return rc;
return Fail("Failed to seed to " + std::to_string(peer), std::move(rc));
}
}
}
Expand Down
133 changes: 89 additions & 44 deletions src/collective/comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#endif // !defined(XGBOOST_USE_NCCL)
#include "allgather.h" // for RingAllgather
#include "protocol.h" // for kMagic
#include "topo.h" // for BootstrapNext
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/json.h" // for Json, Object
Expand Down Expand Up @@ -58,6 +59,7 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
this->Rank(), this->World());
}

// Connect ring and tree neighbors
[[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport,
proto::PeerInfo ninfo, std::chrono::seconds timeout,
std::int32_t retry,
Expand All @@ -80,10 +82,10 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
return prev->NonBlocking(true);
};
if (!rc.OK()) {
return rc;
return Fail("Bootstrap failed to recv from ring prev.", std::move(rc));
}

// exchange host name and port
// Exchange host name and port
std::vector<std::int8_t> buffer(HOST_NAME_MAX * comm.World(), 0);
auto s_buffer = common::Span{buffer.data(), buffer.size()};
auto next_host = s_buffer.subspan(HOST_NAME_MAX * comm.Rank(), HOST_NAME_MAX);
Expand All @@ -107,7 +109,9 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st

rc = std::move(rc) << [&] {
return cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch);
} << [&] { return block(); };
} << [&] {
return block();
};
if (!rc.OK()) {
return Fail("Failed to get host names from peers.", std::move(rc));
}
Expand All @@ -118,7 +122,9 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
auto s_ports = common::Span{reinterpret_cast<std::int8_t*>(peers_port.data()),
peers_port.size() * sizeof(ninfo.port)};
return cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch);
} << [&] { return block(); };
} << [&] {
return block();
};
if (!rc.OK()) {
return Fail("Failed to get the port from peers.", std::move(rc));
}
Expand All @@ -138,55 +144,94 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st

std::vector<std::shared_ptr<TCPSocket>>& workers = *out_workers;
workers.resize(comm.World());

for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) {
auto const& peer = peers[r];
auto worker = std::make_shared<TCPSocket>();
rc = std::move(rc)
<< [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); }
<< [&] { return worker->RecvTimeout(timeout); };
if (!rc.OK()) {
return rc;
}

auto rank = comm.Rank();
std::size_t n_bytes{0};
auto rc = worker->SendAll(&rank, sizeof(comm.Rank()), &n_bytes);
if (!rc.OK()) {
return rc;
} else if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to send rank.", std::move(rc));
workers[BootstrapNext(comm.Rank(), comm.World())] = next;
if (BootstrapNext(comm.Rank(), comm.World()) == BootstrapPrev(comm.Rank(), comm.World())) {
if (comm.Rank() == 0) {
if (comm.World() == 2) {
workers[BootstrapNext(comm.Rank(), comm.World())] = prev;
} else {
CHECK_EQ(comm.World(), 1);
}
}
workers[r] = std::move(worker);
} else {
workers[BootstrapPrev(comm.Rank(), comm.World())] = prev;
}

for (std::int32_t r = 0; r < comm.Rank(); ++r) {
auto peer = std::make_shared<TCPSocket>();
rc = std::move(rc) << [&] {
/**
* Construct tree.
*/
// All workers connect to rank 0 so that we can always use rank 0 as broadcast root.
if (comm.Rank() == 0) {
for (std::int32_t i = 0; i < comm.World() - 3; ++i) {
auto worker = std::make_shared<TCPSocket>();
SockAddress addr;
return listener->Accept(peer.get(), &addr);
} << [&] {
return peer->RecvTimeout(timeout);
};
if (!rc.OK()) {
return rc;
rc = listener->Accept(worker.get(), &addr);
if (!rc.OK()) {
return Fail("Failed to accept for rank 0.", std::move(rc));
}
std::int32_t r{-1};
std::size_t n_bytes{0};
rc = worker->RecvAll(&r, sizeof(r), &n_bytes);
if (!rc.OK()) {
return Fail("Failed to recv rank.", std::move(rc));
}
if (n_bytes != sizeof(r)) {
return Fail("Failed to recv rank due to size.", std::move(rc));
}
workers[r] = worker;
}
std::int32_t rank{-1};
std::size_t n_bytes{0};
auto rc = peer->RecvAll(&rank, sizeof(rank), &n_bytes);
if (!rc.OK()) {
return rc;
} else if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to recv rank.");
} else {
if (!workers[0]) {
auto worker = std::make_shared<TCPSocket>();
rc = std::move(rc) << [&] {
return Connect(peers[0].host, peers[0].port, retry, timeout, worker.get());
} << [&] {
auto rank = comm.Rank();
std::size_t n_bytes = 0;
auto rc = worker->SendAll(&rank, sizeof(rank), &n_bytes);
if (n_bytes != sizeof(rank)) {
return Fail("Failed to send rank due to size.", std::move(rc));
}
return rc;
};
if (!rc.OK()) {
return Fail("Failed to connect to root.", std::move(rc));
}
workers[0] = worker;
}
}
// Binomial tree connect
std::int32_t const kDepth = std::ceil(std::log2(static_cast<double>(comm.World()))) - 1;
if (comm.Rank() != 0) {
auto prank = ParentRank(comm.Rank(), kDepth);
if (!workers[prank]) { // Skip if it's part of the ring.
auto parent = std::make_shared<TCPSocket>();
SockAddress addr;
rc = listener->Accept(parent.get(), &addr);
if (!rc.OK()) {
return Fail("Failed to recv connection from tree parent.", std::move(rc));
}
workers[prank] = parent;
}
workers[rank] = std::move(peer);
}

for (std::int32_t r = 0; r < comm.World(); ++r) {
if (r == comm.Rank()) {
continue;
for (std::int32_t i = kDepth; i >= 0; --i) {
if (comm.Rank() % (1 << (i + 1)) == 0 && comm.Rank() + (1 << i) < comm.World()) {
auto peer = comm.Rank() + (1 << i);
if (workers[peer]) { // skip if it's part of the ring.
continue;
}
auto worker = std::make_shared<TCPSocket>();
rc = std::move(rc) << [&] {
return Connect(peers[peer].host, peers[peer].port, retry, timeout, worker.get());
} << [&] {
return worker->RecvTimeout(timeout);
};
if (!rc.OK()) {
return Fail("Failed to connect to tree neighbor", std::move(rc));
}
workers[peer] = worker;
}
CHECK(workers[r]);
}

return Success();
Expand Down
17 changes: 3 additions & 14 deletions src/collective/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,10 @@

namespace xgboost::collective {

inline constexpr std::int64_t DefaultTimeoutSec() { return 60 * 30; } // 30min
inline constexpr std::int32_t DefaultRetry() { return 3; }
constexpr std::int64_t DefaultTimeoutSec() { return 60 * 30; } // 30min
constexpr std::int32_t DefaultRetry() { return 3; }

// indexing into the ring
/**
 * @brief Rank that follows `r` in a ring of `world` workers, wrapping around at the end.
 *
 * @param r     Current rank.
 * @param world Number of workers in the ring.
 *
 * @return `(r + world + 1) % world`.
 */
inline std::int32_t BootstrapNext(std::int32_t r, std::int32_t world) {
  return (r + world + 1) % world;
}

/**
 * @brief Rank that precedes `r` in a ring of `world` workers, wrapping around at 0.
 *
 * @param r     Current rank.
 * @param world Number of workers in the ring.
 *
 * @return `(r + world - 1) % world`.
 */
inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) {
  return (r + world - 1) % world;
}

inline StringView DefaultNcclName() { return "libnccl.so.2"; }
constexpr StringView DefaultNcclName() { return "libnccl.so.2"; }

class Channel;
class Coll;
Expand Down
26 changes: 26 additions & 0 deletions src/collective/topo.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/**
* Copyright 2023-2024, XGBoost Contributors
*/
#include "topo.h"

#include "../common/bitfield.h" // for TrailingZeroBits, RBitField32
namespace xgboost::collective {
std::int32_t ParentRank(std::int32_t rank, std::int32_t depth) {
std::uint32_t mask{std::uint32_t{0} - 1}; // Oxff...
RBitField32 maskbits{common::Span<std::uint32_t>{&mask, 1}};
RBitField32 rankbits{common::Span<std::uint32_t>{reinterpret_cast<std::uint32_t*>(&rank), 1}};
// prepare for counting trailing zeros.
for (std::int32_t i = 0; i < depth + 1; ++i) {
if (rankbits.Check(i)) {
maskbits.Set(i);
} else {
maskbits.Clear(i);
}
}

CHECK_NE(mask, 0);
auto k = TrailingZeroBits(mask);
auto parent = rank - (1 << k);
return parent;
}
} // namespace xgboost::collective
Loading