Skip to content

Commit ac6f27c

Browse files
committed
[RF] Avoid compute graph desync in RooFixedProdPdf
The RooFixedProdPdf used an antipattern that should be avoided, and it was actually causing some problems in fits with the new CPU evaluation backend. The problem was that the class held on to some RooAbsArgs for its evaluation in some containers that were separate from RooFit server-client interface (the container was the `std::unique_ptr<RooProdPdf::CacheElem>`). This commit fixes the situation by avoiding this cache element member completely and consistently using the server proxies instead.
1 parent 4731fd5 commit ac6f27c

File tree

3 files changed

+83
-61
lines changed

3 files changed

+83
-61
lines changed

roofit/codegen/src/CodegenImpl.cxx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,10 @@ std::string realSumPdfTranslateImpl(CodegenContext &ctx, RooAbsArg const &arg, R
137137

138138
void codegenImpl(RooFit::Detail::RooFixedProdPdf &arg, CodegenContext &ctx)
139139
{
140-
if (arg.cache()._isRearranged) {
141-
ctx.addResult(&arg, ctx.buildCall(mathFunc("ratio"), *arg.cache()._rearrangedNum, *arg.cache()._rearrangedDen));
140+
if (arg.isRearranged()) {
141+
ctx.addResult(&arg, ctx.buildCall(mathFunc("ratio"), *arg.rearrangedNum(), *arg.rearrangedDen()));
142142
} else {
143-
ctx.addResult(&arg, ctx.buildCall(mathFunc("product"), arg.cache()._partList, arg.cache()._partList.size()));
143+
ctx.addResult(&arg, ctx.buildCall(mathFunc("product"), *arg.partList(), arg.partList()->size()));
144144
}
145145
}
146146

roofit/roofitcore/inc/RooProdPdf.h

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,7 @@ class RooProdPdf : public RooAbsPdf {
177177
void rearrangeProduct(CacheElem&) const;
178178
std::unique_ptr<RooAbsReal> specializeIntegral(RooAbsReal& orig, const char* targetRangeName) const ;
179179
std::unique_ptr<RooAbsReal> specializeRatio(RooFormulaVar& input, const char* targetRangeName) const ;
180-
double calculate(const RooProdPdf::CacheElem& cache, bool verbose=false) const ;
181-
void doEvalImpl(RooAbsArg const* caller, const RooProdPdf::CacheElem &cache, RooFit::EvalContext &) const;
182-
180+
double calculate(const RooProdPdf::CacheElem &cache, bool verbose = false) const;
183181

184182
friend class RooProdGenContext ;
185183
friend class RooFit::Detail::RooFixedProdPdf ;
@@ -202,15 +200,10 @@ class RooProdPdf : public RooAbsPdf {
202200
bool _selfNorm = true; ///< Is self-normalized
203201
RooArgSet _defNormSet ; ///< Default normalization set
204202

205-
private:
206-
207-
208-
209203
ClassDefOverride(RooProdPdf,6) // PDF representing a product of PDFs
210204
};
211205

212-
namespace RooFit {
213-
namespace Detail {
206+
namespace RooFit::Detail {
214207

215208
/// A RooProdPdf with a fixed normalization set can be replaced by this class.
216209
/// Its purpose is to provide the right client-server interface for the
@@ -227,7 +220,7 @@ class RooFixedProdPdf : public RooAbsPdf {
227220

228221
inline bool canComputeBatchWithCuda() const override { return true; }
229222

230-
inline void doEval(RooFit::EvalContext &ctx) const override { _prodPdf->doEvalImpl(this, *_cache, ctx); }
223+
void doEval(RooFit::EvalContext &ctx) const override;
231224

232225
inline ExtendMode extendMode() const override { return _prodPdf->extendMode(); }
233226
inline double expectedEvents(const RooArgSet * /*nset*/) const override
@@ -260,22 +253,30 @@ class RooFixedProdPdf : public RooAbsPdf {
260253
return _prodPdf->analyticalIntegral(code, rangeName);
261254
}
262255

263-
RooProdPdf::CacheElem const &cache() const { return *_cache; }
256+
bool isRearranged() const { return _isRearranged; }
264257

265-
private:
266-
void initialize();
258+
RooAbsReal const *rearrangedNum() const
259+
{
260+
return _isRearranged ? static_cast<RooAbsReal const *>(_servers[0]) : nullptr;
261+
}
262+
RooAbsReal const *rearrangedDen() const
263+
{
264+
return _isRearranged ? static_cast<RooAbsReal const *>(_servers[1]) : nullptr;
265+
}
266+
267+
RooArgSet const *partList() const { return !_isRearranged ? static_cast<RooArgSet const *>(&_servers) : nullptr; }
267268

268-
inline double evaluate() const override { return _prodPdf->calculate(*_cache); }
269+
private:
270+
double evaluate() const override;
269271

270272
RooArgSet _normSet;
271-
std::unique_ptr<RooProdPdf::CacheElem> _cache;
272273
RooSetProxy _servers;
273274
std::unique_ptr<RooProdPdf> _prodPdf;
275+
bool _isRearranged = false;
274276

275277
ClassDefOverride(RooFit::Detail::RooFixedProdPdf, 0);
276278
};
277279

278-
} // namespace Detail
279-
} // namespace RooFit
280+
} // namespace RooFit::Detail
280281

281282
#endif

roofit/roofitcore/src/RooProdPdf.cxx

Lines changed: 62 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -409,26 +409,6 @@ double RooProdPdf::calculate(const RooProdPdf::CacheElem& cache, bool /*verbose*
409409
}
410410
}
411411

412-
////////////////////////////////////////////////////////////////////////////////
413-
/// Evaluate product of PDFs in batch mode.
414-
void RooProdPdf::doEvalImpl(RooAbsArg const *caller, const RooProdPdf::CacheElem &cache, RooFit::EvalContext &ctx) const
415-
{
416-
if (cache._isRearranged) {
417-
auto numerator = ctx.at(cache._rearrangedNum.get());
418-
auto denominator = ctx.at(cache._rearrangedDen.get());
419-
RooBatchCompute::compute(ctx.config(caller), RooBatchCompute::Ratio, ctx.output(), {numerator, denominator});
420-
} else {
421-
std::vector<std::span<const double>> factors;
422-
factors.reserve(cache._partList.size());
423-
for (const RooAbsArg *i : cache._partList) {
424-
auto span = ctx.at(i);
425-
factors.push_back(span);
426-
}
427-
std::array<double, 1> special{static_cast<double>(factors.size())};
428-
RooBatchCompute::compute(ctx.config(caller), RooBatchCompute::ProdPdf, ctx.output(), factors, special);
429-
}
430-
}
431-
432412
namespace {
433413

434414
template<class T>
@@ -2185,44 +2165,85 @@ RooProdPdf::compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileC
21852165
return fixedProdPdf;
21862166
}
21872167

2188-
namespace RooFit {
2189-
namespace Detail {
2168+
namespace RooFit::Detail {
21902169

21912170
RooFixedProdPdf::RooFixedProdPdf(std::unique_ptr<RooProdPdf> &&prodPdf, RooArgSet const &normSet)
21922171
: RooAbsPdf(prodPdf->GetName(), prodPdf->GetTitle()),
21932172
_normSet{normSet},
21942173
_servers("!servers", "List of servers", this),
21952174
_prodPdf{std::move(prodPdf)}
21962175
{
2197-
initialize();
2176+
auto cache = _prodPdf->createCacheElem(&_normSet, nullptr);
2177+
_isRearranged = cache->_isRearranged;
2178+
2179+
// The actual servers for a given normalization set depend on whether the
2180+
// cache is rearranged or not. See RooProdPdf::calculate to see
2181+
// which args in the cache are used directly.
2182+
if (_isRearranged) {
2183+
_servers.add(*cache->_rearrangedNum);
2184+
_servers.add(*cache->_rearrangedDen);
2185+
addOwnedComponents(std::move(cache->_rearrangedNum));
2186+
addOwnedComponents(std::move(cache->_rearrangedDen));
2187+
return;
2188+
}
2189+
// We don't want to carry the full cache object around, so we let it go out
2190+
// of scope and transfer the ownership of the args that we actually need.
2191+
cache->_ownedList.releaseOwnership();
2192+
std::vector<std::unique_ptr<RooAbsArg>> owned;
2193+
for (RooAbsArg *arg : cache->_ownedList) {
2194+
owned.emplace_back(arg);
2195+
}
2196+
for (RooAbsArg *arg : cache->_partList) {
2197+
_servers.add(*arg);
2198+
auto found = std::find_if(owned.begin(), owned.end(), [&](auto const &ptr) { return arg == ptr.get(); });
2199+
if (found != owned.end()) {
2200+
addOwnedComponents(std::move(owned[std::distance(owned.begin(), found)]));
2201+
}
2202+
}
21982203
}
21992204

22002205
RooFixedProdPdf::RooFixedProdPdf(const RooFixedProdPdf &other, const char *name)
22012206
: RooAbsPdf(other, name),
22022207
_normSet{other._normSet},
2203-
_servers("!servers", "List of servers", this),
2204-
_prodPdf{static_cast<RooProdPdf *>(other._prodPdf->Clone())}
2208+
//_servers("!servers", "List of servers", this),
2209+
_servers("!servers", this, other._servers),
2210+
_prodPdf{static_cast<RooProdPdf *>(other._prodPdf->Clone())},
2211+
_isRearranged{other._isRearranged}
2212+
{
2213+
}
2214+
2215+
////////////////////////////////////////////////////////////////////////////////
2216+
/// Evaluate product of PDFs in batch mode.
2217+
2218+
void RooFixedProdPdf::doEval(RooFit::EvalContext &ctx) const
22052219
{
2206-
initialize();
2220+
if (_isRearranged) {
2221+
auto numerator = ctx.at(rearrangedNum());
2222+
auto denominator = ctx.at(rearrangedDen());
2223+
RooBatchCompute::compute(ctx.config(this), RooBatchCompute::Ratio, ctx.output(), {numerator, denominator});
2224+
return;
2225+
}
2226+
std::vector<std::span<const double>> factors;
2227+
factors.reserve(partList()->size());
2228+
for (const RooAbsArg *arg : *partList()) {
2229+
auto span = ctx.at(arg);
2230+
factors.push_back(span);
2231+
}
2232+
std::array<double, 1> special{static_cast<double>(factors.size())};
2233+
RooBatchCompute::compute(ctx.config(this), RooBatchCompute::ProdPdf, ctx.output(), factors, special);
22072234
}
22082235

2209-
void RooFixedProdPdf::initialize()
2236+
double RooFixedProdPdf::evaluate() const
22102237
{
2211-
_cache = _prodPdf->createCacheElem(&_normSet, nullptr);
2212-
auto &cache = *_cache;
2238+
if (_isRearranged) {
2239+
return rearrangedNum()->getVal() / rearrangedDen()->getVal();
2240+
}
2241+
double value = 1.0;
22132242

2214-
// The actual servers for a given normalization set depend on whether the
2215-
// cache is rearranged or not. See RooProdPdf::calculateBatch to see
2216-
// which args in the cache are used directly.
2217-
if (cache._isRearranged) {
2218-
_servers.add(*cache._rearrangedNum);
2219-
_servers.add(*cache._rearrangedDen);
2220-
} else {
2221-
for (std::size_t i = 0; i < cache._partList.size(); ++i) {
2222-
_servers.add(cache._partList[i]);
2223-
}
2243+
for (auto *arg : static_range_cast<RooAbsReal *>(*partList())) {
2244+
value *= arg->getVal();
22242245
}
2246+
return value;
22252247
}
22262248

2227-
} // namespace Detail
2228-
} // namespace RooFit
2249+
} // namespace RooFit::Detail

0 commit comments

Comments
 (0)