include/xgboost/predictor.h (2 changes: 1 addition & 1 deletion)
@@ -158,7 +158,7 @@ class Predictor {
                                    gbm::GBTreeModel const& model, bst_tree_t tree_end = 0,
                                    std::vector<float> const* tree_weights = nullptr,
                                    bool approximate = false, int condition = 0,
-                                   unsigned condition_feature = 0) const = 0;
+                                   unsigned condition_feature = 0, HostDeviceVector<bst_feature_t>* feature_reprs = nullptr) const = 0;
 
   virtual void PredictInteractionContributions(DMatrix* dmat, HostDeviceVector<float>* out_contribs,
                                                gbm::GBTreeModel const& model,
src/predictor/cpu_predictor.cc (10 changes: 5 additions & 5 deletions)
@@ -747,7 +747,7 @@ class CPUPredictor : public Predictor {
       DataView batch, const MetaInfo &info, const gbm::GBTreeModel &model,
       const std::vector<bst_float> *tree_weights, std::vector<std::vector<float>> *mean_values,
       std::vector<RegTree::FVec> *feat_vecs, std::vector<bst_float> *contribs,
-      bst_tree_t ntree_limit, bool approximate, int condition, unsigned condition_feature) const {
+      bst_tree_t ntree_limit, bool approximate, int condition, unsigned condition_feature, std::vector<bst_feature_t> *feature_reprs = nullptr) const {
     const int num_feature = model.learner_model_param->num_feature;
     const int ngroup = model.learner_model_param->num_output_group;
     CHECK_NE(ngroup, 0);
@@ -778,7 +778,7 @@
         }
         if (!approximate) {
           CalculateContributions(*model.trees[j], feats, tree_mean_values,
-                                 &this_tree_contribs[0], condition, condition_feature);
+                                 &this_tree_contribs[0], condition, condition_feature, feature_reprs == nullptr ? nullptr : feature_reprs->data());
         } else {
           model.trees[j]->CalculateContributionsApprox(
               feats, tree_mean_values, &this_tree_contribs[0]);
@@ -950,7 +950,7 @@
   void PredictContribution(DMatrix *p_fmat, HostDeviceVector<float> *out_contribs,
                            const gbm::GBTreeModel &model, bst_tree_t ntree_limit,
                            std::vector<bst_float> const *tree_weights, bool approximate,
-                           int condition, unsigned condition_feature) const override {
+                           int condition, unsigned condition_feature, HostDeviceVector<bst_feature_t> *feature_reprs = nullptr) const override {
     CHECK(!model.learner_model_param->IsVectorLeaf())
         << "Predict contribution" << MTNotImplemented();
     CHECK(!p_fmat->Info().IsColumnSplit())
@@ -982,13 +982,13 @@
       for (const auto &batch : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, {})) {
         PredictContributionKernel(GHistIndexMatrixView{batch, std::forward<Enc>(acc), ft}, info,
                                   model, tree_weights, &mean_values, &feat_vecs, &contribs,
-                                  ntree_limit, approximate, condition, condition_feature);
+                                  ntree_limit, approximate, condition, condition_feature, feature_reprs == nullptr ? nullptr : &feature_reprs->HostVector());
       }
     } else {
       for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
         PredictContributionKernel(SparsePageView{&batch, std::forward<Enc>(acc)}, info, model,
                                   tree_weights, &mean_values, &feat_vecs, &contribs, ntree_limit,
-                                  approximate, condition, condition_feature);
+                                  approximate, condition, condition_feature, feature_reprs == nullptr ? nullptr : &feature_reprs->HostVector());
       }
     }
   };
src/predictor/cpu_treeshap.cc (23 changes: 13 additions & 10 deletions)
@@ -106,11 +106,13 @@ float UnwoundPathSum(const PathElement* unique_path, std::uint32_t unique_depth,
  * \param condition fix one feature to either off (-1) on (1) or not fixed (0 default)
  * \param condition_feature the index of the feature to fix
  * \param condition_fraction what fraction of the current weight matches our conditioning feature
+ * \param feature_reprs mapping of features to group representatives, for groupSHAP calculation
  */
 void TreeShap(RegTree const& tree, const RegTree::FVec& feat, float* phi, bst_node_t node_index,
               std::uint32_t unique_depth, PathElement* parent_unique_path,
               float parent_zero_fraction, float parent_one_fraction, int parent_feature_index,
-              int condition, std::uint32_t condition_feature, float condition_fraction) {
+              int condition, std::uint32_t condition_feature, float condition_fraction,
+              std::uint32_t* feature_reprs) {
   const auto node = tree[node_index];
 
   // stop if we have no weight coming down to us
@@ -125,6 +127,7 @@ void TreeShap(RegTree const& tree, const RegTree::FVec& feat, float* phi, bst_no
                parent_feature_index);
   }
   const std::uint32_t split_index = node.SplitIndex();
+  const std::uint32_t split_index_repr = feature_reprs == nullptr ? split_index : feature_reprs[split_index];
 
   // leaf node
   if (node.IsLeaf()) {
@@ -153,7 +156,7 @@ void TreeShap(RegTree const& tree, const RegTree::FVec& feat, float* phi, bst_no
     // if so we undo that split so we can redo it for this node
     std::uint32_t path_index = 0;
     for (; path_index <= unique_depth; ++path_index) {
-      if (static_cast<std::uint32_t>(unique_path[path_index].feature_index) == split_index) break;
+      if (static_cast<std::uint32_t>(unique_path[path_index].feature_index) == split_index_repr) break;
     }
     if (path_index != unique_depth + 1) {
       incoming_zero_fraction = unique_path[path_index].zero_fraction;
@@ -165,28 +168,28 @@ void TreeShap(RegTree const& tree, const RegTree::FVec& feat, float* phi, bst_no
     // divide up the condition_fraction among the recursive calls
     float hot_condition_fraction = condition_fraction;
     float cold_condition_fraction = condition_fraction;
-    if (condition > 0 && split_index == condition_feature) {
+    if (condition > 0 && split_index_repr == condition_feature) {
       cold_condition_fraction = 0;
       unique_depth -= 1;
-    } else if (condition < 0 && split_index == condition_feature) {
+    } else if (condition < 0 && split_index_repr == condition_feature) {
       hot_condition_fraction *= hot_zero_fraction;
       cold_condition_fraction *= cold_zero_fraction;
       unique_depth -= 1;
     }
 
     TreeShap(tree, feat, phi, hot_index, unique_depth + 1, unique_path,
-             hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction, split_index,
-             condition, condition_feature, hot_condition_fraction);
+             hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction, split_index_repr,
+             condition, condition_feature, hot_condition_fraction, feature_reprs);
 
     TreeShap(tree, feat, phi, cold_index, unique_depth + 1, unique_path,
-             cold_zero_fraction * incoming_zero_fraction, 0, split_index, condition,
-             condition_feature, cold_condition_fraction);
+             cold_zero_fraction * incoming_zero_fraction, 0, split_index_repr, condition,
+             condition_feature, cold_condition_fraction, feature_reprs);
   }
 }
 
 void CalculateContributions(RegTree const& tree, const RegTree::FVec& feat,
                             std::vector<float>* mean_values, float* out_contribs, int condition,
-                            std::uint32_t condition_feature) {
+                            std::uint32_t condition_feature, std::uint32_t* feature_reprs) {
   // find the expected value of the tree's predictions
   if (condition == 0) {
     float node_value = (*mean_values)[0];
@@ -198,6 +201,6 @@ void CalculateContributions(RegTree const& tree, const RegTree::FVec& feat,
   std::vector<PathElement> unique_path_data((maxd * (maxd + 1)) / 2);
 
   TreeShap(tree, feat, out_contribs, 0, 0, unique_path_data.data(), 1, 1, -1, condition,
-           condition_feature, 1);
+           condition_feature, 1, feature_reprs);
 }
 } // namespace xgboost
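
A minimal sketch of the mapping convention the new feature_reprs argument assumes (the group layout below is invented for illustration; only the identity-by-default behaviour and the nullptr fallback follow from the diff): feature_reprs[i] holds the index of the group representative of feature i, so ungrouped features map to themselves, while, say, one-hot columns of the same categorical all map to one shared index, which is where the patched TreeShap now resolves their splits on the decision path.

#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical layout: five features, where features 2, 3 and 4 are one-hot
  // columns of the same categorical and should be explained as a single group.
  std::vector<std::uint32_t> feature_reprs(5);
  std::iota(feature_reprs.begin(), feature_reprs.end(), 0u);  // identity: no grouping
  feature_reprs[3] = 2;  // feature 3 is represented by feature 2
  feature_reprs[4] = 2;  // feature 4 is represented by feature 2
  // Passing feature_reprs.data() as the last argument of CalculateContributions
  // (or nullptr to keep plain per-feature SHAP) makes TreeShap treat splits on
  // features 2, 3 and 4 as splits on the same path element.
  return 0;
}
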
src/predictor/cpu_treeshap.h (3 changes: 2 additions & 1 deletion)
@@ -14,9 +14,10 @@ namespace xgboost {
  * \param out_contribs output vector to hold the contributions
  * \param condition fix one feature to either off (-1) on (1) or not fixed (0 default)
  * \param condition_feature the index of the feature to fix
+ * \param feature_reprs mapping of features to group representatives, for groupSHAP calculation
  */
 void CalculateContributions(RegTree const &tree, const RegTree::FVec &feat,
                             std::vector<float> *mean_values, bst_float *out_contribs, int condition,
-                            unsigned condition_feature);
+                            unsigned condition_feature, std::uint32_t* feature_reprs);
 } // namespace xgboost
 #endif // XGBOOST_PREDICTOR_CPU_TREESHAP_H_
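
And a caller-side sketch of the public entry point, assuming compilation inside the XGBoost source tree; predictor, p_fmat, model and n_features are placeholders for objects that already exist in the caller, and the grouping of feature 2 under feature 1 is made up. Omitting the argument, or passing nullptr, keeps the previous per-feature behaviour, since every signature above defaults feature_reprs to nullptr.

#include <xgboost/base.h>                // bst_feature_t
#include <xgboost/data.h>                // DMatrix
#include <xgboost/host_device_vector.h>  // HostDeviceVector
#include <xgboost/predictor.h>           // Predictor

#include <cstddef>
#include <numeric>

void GroupedShap(xgboost::Predictor const* predictor, xgboost::DMatrix* p_fmat,
                 xgboost::gbm::GBTreeModel const& model, std::size_t n_features) {
  // Identity mapping first, then fold feature 2 into the group represented by feature 1.
  xgboost::HostDeviceVector<xgboost::bst_feature_t> feature_reprs;
  auto& h_reprs = feature_reprs.HostVector();
  h_reprs.resize(n_features);
  std::iota(h_reprs.begin(), h_reprs.end(), 0u);
  if (n_features > 2) { h_reprs[2] = 1; }  // hypothetical group

  xgboost::HostDeviceVector<float> out_contribs;
  predictor->PredictContribution(p_fmat, &out_contribs, model,
                                 /*tree_end=*/0, /*tree_weights=*/nullptr,
                                 /*approximate=*/false, /*condition=*/0,
                                 /*condition_feature=*/0, &feature_reprs);
}
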