Optimization of PushRowPage for high number of cpu cores #11182
base: master
Changes from 21 commits
1780f5b
922af51
77923f2
ab31368
836c768
0879f24
49e309c
d30e657
add81f6
e00bf4a
2c0e5dd
7982415
a0d5bd6
9280b5f
1519084
8130d94
c09dc9e
ab2462a
f2b33b8
79d9cb6
7aee6e1
6604939
@@ -748,37 +748,79 @@ std::vector<bst_idx_t> CalcColumnSize(Batch const &batch, bst_feature_t const n_
   return entries_per_columns;
 }
 
+struct WLBalance {
+  explicit WLBalance(size_t n_columns) : is_column_splited(n_columns) {}
+
+  struct ThreadWorkLoad {
+    std::vector<size_t> columns;
+    size_t split_idx = 0;
+    size_t n_splits = 1;
+
+    ThreadWorkLoad() : columns() {}
+  };
+
+  std::vector<ThreadWorkLoad> baskets;
+  std::vector<bool> is_column_splited;
+  bool has_splitted = false;
+};
+
 template <typename Batch, typename IsValid>
-std::vector<bst_feature_t> LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns,
-                                       size_t const nthreads, IsValid&& is_valid) {
-  /* Some sparse datasets have their mass concentrating on small number of features. To
-   * avoid waiting for a few threads running forever, we here distribute different number
-   * of columns to different threads according to number of entries.
+WLBalance LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns,
+                      size_t const nthreads, IsValid&& is_valid) {
+  /* Some datasets have long columns. When the number of threads is high enough, it is
+   * beneficial to split such columns between threads and then collect the results. In this
+   * case, each thread involved in processing a split column works with only that single column.
+   *
+   * Columns that are too small for splitting are distributed between threads. In this case each
+   * thread can process multiple columns, and the ranges of column indexes assigned to different
+   * threads don't overlap with each other.
    */
+  WLBalance wl_balance(n_columns);
+  if (nnz == 0) return wl_balance;
+  auto& wl_baskets = wl_balance.baskets;
+
   size_t const total_entries = nnz;
   size_t const entries_per_thread = DivRoundUp(total_entries, nthreads);
 
   // Need to calculate the size for each batch.
   std::vector<bst_idx_t> entries_per_columns = CalcColumnSize(batch, n_columns, nthreads, is_valid);
-  std::vector<bst_feature_t> cols_ptr(nthreads + 1, 0);
-  size_t count{0};
-  size_t current_thread{1};
-
-  for (auto col : entries_per_columns) {
-    cols_ptr.at(current_thread)++;  // add one column to thread
-    count += col;
-    CHECK_LE(count, total_entries);
-    if (count > entries_per_thread) {
-      current_thread++;
-      count = 0;
-      cols_ptr.at(current_thread) = cols_ptr[current_thread - 1];
-    }
-  }
+  size_t count = 0;
+  for (size_t column_idx = 0; column_idx < n_columns; ++column_idx) {
+    size_t n_entries = entries_per_columns[column_idx];
+
+    if (n_entries > 0) {
+      size_t n_splits = std::min(nthreads * n_entries / total_entries, n_entries);
[Inline review thread on the line above]
Reviewer: Why is n_splits computed this way?
Author: The goal is to assign a number of threads for processing the column that is proportional to the number of entries in the column.
+      constexpr size_t kMinBlockSize = (1u << 16);
+      if ((n_splits > 1) && (kMinBlockSize * n_splits < n_entries)) {
+        // Split column between threads
+        wl_balance.has_splitted = true;
+        wl_balance.is_column_splited[column_idx] = true;
+        for (size_t split_idx = 0; split_idx < n_splits; split_idx++) {
+          wl_baskets.emplace_back();
+
+          auto& wl = wl_baskets.back();
+          wl.columns.push_back(column_idx);
+          wl.split_idx = split_idx;
+          wl.n_splits = n_splits;
+        }
+      } else {
+        if (wl_baskets.empty() || count > entries_per_thread) {
+          wl_baskets.emplace_back();
+          count = 0;
+        }
+        count += n_entries;
+
+        auto& wl = wl_baskets.back();
+        wl.columns.push_back(column_idx);
+        wl_balance.is_column_splited[column_idx] = false;
+      }
+    }
+  }
-  // Idle threads.
-  for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
-    cols_ptr[current_thread + 1] = cols_ptr[current_thread];
-  }
-  return cols_ptr;
+
+  CHECK_LE(wl_baskets.size(), nthreads);
+  return wl_balance;
 }
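For intuition, the splitting heuristic can be tried in isolation. The following is a toy stand-alone sketch, not part of the patch: the XGBoost types and DivRoundUp are replaced by plain standard C++, and the column sizes are invented (one dominant column and three small ones, with 16 threads) purely to exercise the same n_splits formula and kMinBlockSize guard:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::size_t const nthreads = 16;
  // One dominant column and three small ones, as in a skewed sparse dataset.
  std::vector<std::size_t> entries_per_column = {1u << 22, 1000, 2000, 500};
  std::size_t total_entries = 0;
  for (auto n : entries_per_column) total_entries += n;

  constexpr std::size_t kMinBlockSize = 1u << 16;
  for (std::size_t col = 0; col < entries_per_column.size(); ++col) {
    std::size_t n_entries = entries_per_column[col];
    // Threads are assigned proportionally to the column's share of all entries.
    std::size_t n_splits = std::min(nthreads * n_entries / total_entries, n_entries);
    bool split = (n_splits > 1) && (kMinBlockSize * n_splits < n_entries);
    std::printf("column %zu: n_splits=%zu -> %s\n", col, n_splits,
                split ? "split between threads" : "kept whole");
  }
  return 0;
}

With these numbers the large column gets n_splits = 15 and is split into row blocks, while the small columns get n_splits = 0 and fall through to the basket-packing branch — the two strategies described in the function comment above.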
 
 /*!
@@ -840,46 +882,126 @@ class SketchContainerImpl {
   template <typename Batch, typename IsValid>
   void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz,
                        size_t n_features, bool is_dense, IsValid is_valid) {
-    auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
+    auto threads_wl = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
+    if (threads_wl.baskets.empty()) return;
+
+    std::vector<std::set<float>> categories_buff;
+    std::vector<WQSketch> sketches_buff;
+    std::vector<int> buff_was_used;
+
+    if (threads_wl.has_splitted) {
+      sketches_buff.resize(threads_wl.baskets.size());
+      categories_buff.resize(threads_wl.baskets.size());
+      buff_was_used.resize(threads_wl.baskets.size(), 0);
+    }
+
     dmlc::OMPException exc;
-#pragma omp parallel num_threads(n_threads_)
+#pragma omp parallel num_threads(threads_wl.baskets.size())
     {
       exc.Run([&]() {
         auto tid = static_cast<uint32_t>(omp_get_thread_num());
-        auto const begin = thread_columns_ptr[tid];
-        auto const end = thread_columns_ptr[tid + 1];
-
-        // do not iterate if no columns are assigned to the thread
-        if (begin < end && end <= n_features) {
+        const auto& wl = threads_wl.baskets[tid];
+        if (wl.n_splits > 1) {
+          // We process only a single column in this case
+          size_t column = wl.columns.front();
+
+          auto n_bins = std::min(static_cast<bst_idx_t>(max_bins_), columns_size_[column]);
+          auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
+          sketches_buff[tid].Init(columns_size_[column], eps);
+
+          size_t split_size = DivRoundUp(batch.Size(), wl.n_splits);
+          size_t begin = wl.split_idx * split_size;
+          size_t end = std::min(begin + split_size, batch.Size());
+
+          for (size_t ridx = begin; ridx < end; ++ridx) {
+            auto const &line = batch.GetLine(ridx);
+            auto w = weights[ridx + base_rowid];
+            if (is_dense) {
+              auto const &elem = line.GetElement(column);
+              /* elem.column_idx == column */
+              if (is_valid(elem)) {
+                buff_was_used[tid] = 1;
+                PushElement(elem, &categories_buff[tid], &sketches_buff[tid], w);
+              }
+            } else {
+              size_t n_columns_with_high_idx = n_features - column;
+              size_t col_begin = line.Size() < n_columns_with_high_idx
+                                     ? 0 : line.Size() - n_columns_with_high_idx;
+              size_t col_end = std::min(column + 1, line.Size());
+              for (size_t i = col_begin; i < col_end; ++i) {
+                auto const &elem = line.GetElement(i);
+                if (is_valid(elem) && (elem.column_idx == column)) {
+                  buff_was_used[tid] = 1;
+                  PushElement(elem, &categories_buff[tid], &sketches_buff[tid], w);
+                }
+              }
+            }
+          }
+        } else {
           for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
             auto const &line = batch.GetLine(ridx);
             auto w = weights[ridx + base_rowid];
             if (is_dense) {
-              for (size_t ii = begin; ii < end; ii++) {
-                auto elem = line.GetElement(ii);
-                if (is_valid(elem)) {
-                  if (IsCat(feature_types_, ii)) {
-                    categories_[ii].emplace(elem.value);
-                  } else {
-                    sketches_[ii].Push(elem.value, w);
+              for (size_t ii = wl.columns.front(); ii <= wl.columns.back(); ++ii) {
+                if (!threads_wl.is_column_splited[ii]) {
+                  auto const &elem = line.GetElement(ii);
+                  /* elem.column_idx == ii */
+                  if (is_valid(elem)) {
+                    PushElement(elem, &categories_[ii], &sketches_[ii], w);
                   }
                 }
               }
             } else {
-              for (size_t i = 0; i < line.Size(); ++i) {
+              // number of columns with idx >= wl.columns.front()
+              size_t n_columns_with_high_idx = n_features - wl.columns.front();
+              size_t col_begin = line.Size() < n_columns_with_high_idx
+                                     ? 0 : line.Size() - n_columns_with_high_idx;
+              size_t col_end = std::min(wl.columns.back() + 1, line.Size());
+              for (size_t i = col_begin; i < col_end; ++i) {
                 auto const &elem = line.GetElement(i);
-                if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
-                  if (IsCat(feature_types_, elem.column_idx)) {
-                    categories_[elem.column_idx].emplace(elem.value);
-                  } else {
-                    sketches_[elem.column_idx].Push(elem.value, w);
+                if (is_valid(elem)) {
+                  if (!threads_wl.is_column_splited[elem.column_idx] &&
+                      (elem.column_idx >= wl.columns.front()) &&
+                      (elem.column_idx <= wl.columns.back())) {
+                    PushElement(elem, &categories_[elem.column_idx],
+                                &sketches_[elem.column_idx], w);
                   }
                 }
               }
             }
           }
         }
+
+#pragma omp barrier
+        if (wl.n_splits > 1 && wl.split_idx == 0) {
+          /* The thread responsible for the first block of a split
+           * collects info from the other ones.
+           */
+          size_t column_idx = wl.columns.front();
+
+          typename WQSketch::SummaryContainer main_summary;
+          main_summary.Reserve(sketches_[column_idx].limit_size);
+          typename WQSketch::SummaryContainer split_summary;
+          split_summary.Reserve(2 * sketches_[column_idx].limit_size);
+          typename WQSketch::SummaryContainer comb_summary;
+          comb_summary.Reserve(3 * sketches_[column_idx].limit_size);
[Inline review thread on lines +982 to +987]
Reviewer: How are these sizes chosen?
Author: The size of split_summary guarantees the absence of intermediate pruning while calculating the summary, while the size of comb_summary is the sum of the sizes of main_summary and split_summary.
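For concreteness (following the author's reply): writing L for sketches_[column_idx].limit_size, main_summary is reserved at L entries; split_summary at 2L so that a per-thread summary can be taken without intermediate pruning; and comb_summary at L + 2L = 3L, since SetCombine merges main_summary and split_summary and can emit at most the sum of their sizes before SetPrune shrinks the result back to L.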
+          for (size_t th = tid + 0; th < tid + wl.n_splits; ++th) {
+            CHECK_LT(th, threads_wl.baskets.size());
+            // Make sure some work was done by the thread
+            if (buff_was_used[th] > 0) {
+              if (IsCat(feature_types_, column_idx)) {
+                categories_[column_idx].merge(categories_buff[th]);
+              } else {
+                sketches_buff[th].GetSummary(&split_summary);
+
+                comb_summary.SetCombine(main_summary, split_summary);
+                main_summary.SetPrune(comb_summary, sketches_[column_idx].limit_size);
+              }
+            }
+          }
+          sketches_[column_idx].PushSummary(main_summary);
+        }
       });
     }
     exc.Rethrow();
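The row partitioning used in the splitting branch above is a plain ceiling-divided block decomposition. A minimal stand-alone sketch (assuming nothing from XGBoost; DivRoundUp is inlined as (a + b - 1) / b) shows that the blocks cover every row exactly once:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  std::size_t const n_rows = 100, n_splits = 3;
  // DivRoundUp(batch.Size(), wl.n_splits) from the patch, inlined.
  std::size_t const split_size = (n_rows + n_splits - 1) / n_splits;  // 34
  for (std::size_t split_idx = 0; split_idx < n_splits; ++split_idx) {
    std::size_t begin = split_idx * split_size;
    std::size_t end = std::min(begin + split_size, n_rows);
    std::printf("split %zu processes rows [%zu, %zu)\n", split_idx, begin, end);
  }
  // Prints [0, 34), [34, 68), [68, 100): disjoint ranges covering all rows.
  return 0;
}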
@@ -893,6 +1015,18 @@ class SketchContainerImpl {
  private:
   // Merge all categories from other workers.
   void AllreduceCategories(Context const* ctx, MetaInfo const& info);
+
+  template <class ElemType>
+  void PushElement(const ElemType& elem,
+                   std::set<float>* categorie,
+                   WQSketch* sketch,
+                   float w) {
+    if (IsCat(feature_types_, elem.column_idx)) {
+      categorie->emplace(elem.value);
+    } else {
+      sketch->Push(elem.value, w);
+    }
+  }
 };
 
 class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
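The new PushElement helper centralizes the categorical-versus-numeric dispatch that both loops previously inlined, so the buffered (split-column) path and the direct path share one code path for pushing an element. Below is a hedged, self-contained sketch of the same dispatch, with the XGBoost types (feature_types_, WQSketch) replaced by minimal hypothetical stand-ins purely to illustrate the routing:

#include <cstdio>
#include <set>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the real XGBoost types.
enum class FeatureType { kNumerical, kCategorical };
struct MiniSketch {
  std::vector<std::pair<float, float>> pushed;  // (value, weight) pairs
  void Push(float value, float w) { pushed.emplace_back(value, w); }
};
struct Elem { std::size_t column_idx; float value; };

std::vector<FeatureType> feature_types = {FeatureType::kNumerical, FeatureType::kCategorical};

// Same shape as the patch's PushElement: categorical values are collected as a
// set of distinct categories, numeric values feed the quantile sketch.
void PushElement(const Elem& elem, std::set<float>* categories, MiniSketch* sketch, float w) {
  if (feature_types[elem.column_idx] == FeatureType::kCategorical) {
    categories->emplace(elem.value);
  } else {
    sketch->Push(elem.value, w);
  }
}

int main() {
  std::set<float> categories;
  MiniSketch sketch;
  PushElement({0, 3.14f}, &categories, &sketch, 1.0f);  // numeric -> sketch
  PushElement({1, 2.0f}, &categories, &sketch, 1.0f);   // categorical -> set
  std::printf("sketch entries: %zu, categories: %zu\n", sketch.pushed.size(), categories.size());
  return 0;
}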
[Inline review comment]
Reviewer: "is_spitted" / "has_been_xxx"? (questioning the spelling and naming of the new is_column_splited / has_splitted members)