Skip to content

Commit ffbdfaf

Browse files
committed
[RF] Avoid compute graph desync in RooFixedProdPdf
The RooFixedProdPdf used an antipattern that should be avoided, and it was actually causing some problems in fits with the new CPU evaluation backend. The problem was that the class held on to some RooAbsArgs for its evaluation in some containers that were separate from RooFit server-client interface (the container was the `std::unique_ptr<RooProdPdf::CacheElem>`). This commit fixes the situation by avoiding this cache element member completely and consistently using the server proxies instead.
1 parent 4731fd5 commit ffbdfaf

File tree

3 files changed

+82
-59
lines changed

3 files changed

+82
-59
lines changed

roofit/codegen/src/CodegenImpl.cxx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,10 @@ std::string realSumPdfTranslateImpl(CodegenContext &ctx, RooAbsArg const &arg, R
137137

138138
void codegenImpl(RooFit::Detail::RooFixedProdPdf &arg, CodegenContext &ctx)
139139
{
140-
if (arg.cache()._isRearranged) {
141-
ctx.addResult(&arg, ctx.buildCall(mathFunc("ratio"), *arg.cache()._rearrangedNum, *arg.cache()._rearrangedDen));
140+
if (arg.isRearranged()) {
141+
ctx.addResult(&arg, ctx.buildCall(mathFunc("ratio"), *arg.rearrangedNum(), *arg.rearrangedDen()));
142142
} else {
143-
ctx.addResult(&arg, ctx.buildCall(mathFunc("product"), arg.cache()._partList, arg.cache()._partList.size()));
143+
ctx.addResult(&arg, ctx.buildCall(mathFunc("product"), *arg.partList(), arg.partList()->size()));
144144
}
145145
}
146146

roofit/roofitcore/inc/RooProdPdf.h

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,6 @@ class RooProdPdf : public RooAbsPdf {
178178
std::unique_ptr<RooAbsReal> specializeIntegral(RooAbsReal& orig, const char* targetRangeName) const ;
179179
std::unique_ptr<RooAbsReal> specializeRatio(RooFormulaVar& input, const char* targetRangeName) const ;
180180
double calculate(const RooProdPdf::CacheElem& cache, bool verbose=false) const ;
181-
void doEvalImpl(RooAbsArg const* caller, const RooProdPdf::CacheElem &cache, RooFit::EvalContext &) const;
182181

183182

184183
friend class RooProdGenContext ;
@@ -202,15 +201,10 @@ class RooProdPdf : public RooAbsPdf {
202201
bool _selfNorm = true; ///< Is self-normalized
203202
RooArgSet _defNormSet ; ///< Default normalization set
204203

205-
private:
206-
207-
208-
209204
ClassDefOverride(RooProdPdf,6) // PDF representing a product of PDFs
210205
};
211206

212-
namespace RooFit {
213-
namespace Detail {
207+
namespace RooFit::Detail {
214208

215209
/// A RooProdPdf with a fixed normalization set can be replaced by this class.
216210
/// Its purpose is to provide the right client-server interface for the
@@ -227,7 +221,7 @@ class RooFixedProdPdf : public RooAbsPdf {
227221

228222
inline bool canComputeBatchWithCuda() const override { return true; }
229223

230-
inline void doEval(RooFit::EvalContext &ctx) const override { _prodPdf->doEvalImpl(this, *_cache, ctx); }
224+
void doEval(RooFit::EvalContext &ctx) const override;
231225

232226
inline ExtendMode extendMode() const override { return _prodPdf->extendMode(); }
233227
inline double expectedEvents(const RooArgSet * /*nset*/) const override
@@ -260,22 +254,30 @@ class RooFixedProdPdf : public RooAbsPdf {
260254
return _prodPdf->analyticalIntegral(code, rangeName);
261255
}
262256

263-
RooProdPdf::CacheElem const &cache() const { return *_cache; }
257+
bool isRearranged() const { return _isRearranged; }
264258

265-
private:
266-
void initialize();
259+
RooAbsReal const *rearrangedNum() const
260+
{
261+
return _isRearranged ? static_cast<RooAbsReal const *>(_servers[0]) : nullptr;
262+
}
263+
RooAbsReal const *rearrangedDen() const
264+
{
265+
return _isRearranged ? static_cast<RooAbsReal const *>(_servers[1]) : nullptr;
266+
}
267+
268+
RooArgSet const *partList() const { return !_isRearranged ? static_cast<RooArgSet const *>(&_servers) : nullptr; }
267269

268-
inline double evaluate() const override { return _prodPdf->calculate(*_cache); }
270+
private:
271+
double evaluate() const override;
269272

270273
RooArgSet _normSet;
271-
std::unique_ptr<RooProdPdf::CacheElem> _cache;
272274
RooSetProxy _servers;
273275
std::unique_ptr<RooProdPdf> _prodPdf;
276+
bool _isRearranged = false;
274277

275278
ClassDefOverride(RooFit::Detail::RooFixedProdPdf, 0);
276279
};
277280

278-
} // namespace Detail
279-
} // namespace RooFit
281+
} // namespace RooFit::Detail
280282

281283
#endif

roofit/roofitcore/src/RooProdPdf.cxx

Lines changed: 62 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -409,26 +409,6 @@ double RooProdPdf::calculate(const RooProdPdf::CacheElem& cache, bool /*verbose*
409409
}
410410
}
411411

412-
////////////////////////////////////////////////////////////////////////////////
413-
/// Evaluate product of PDFs in batch mode.
414-
void RooProdPdf::doEvalImpl(RooAbsArg const *caller, const RooProdPdf::CacheElem &cache, RooFit::EvalContext &ctx) const
415-
{
416-
if (cache._isRearranged) {
417-
auto numerator = ctx.at(cache._rearrangedNum.get());
418-
auto denominator = ctx.at(cache._rearrangedDen.get());
419-
RooBatchCompute::compute(ctx.config(caller), RooBatchCompute::Ratio, ctx.output(), {numerator, denominator});
420-
} else {
421-
std::vector<std::span<const double>> factors;
422-
factors.reserve(cache._partList.size());
423-
for (const RooAbsArg *i : cache._partList) {
424-
auto span = ctx.at(i);
425-
factors.push_back(span);
426-
}
427-
std::array<double, 1> special{static_cast<double>(factors.size())};
428-
RooBatchCompute::compute(ctx.config(caller), RooBatchCompute::ProdPdf, ctx.output(), factors, special);
429-
}
430-
}
431-
432412
namespace {
433413

434414
template<class T>
@@ -2185,44 +2165,85 @@ RooProdPdf::compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileC
21852165
return fixedProdPdf;
21862166
}
21872167

2188-
namespace RooFit {
2189-
namespace Detail {
2168+
namespace RooFit::Detail {
21902169

21912170
RooFixedProdPdf::RooFixedProdPdf(std::unique_ptr<RooProdPdf> &&prodPdf, RooArgSet const &normSet)
21922171
: RooAbsPdf(prodPdf->GetName(), prodPdf->GetTitle()),
21932172
_normSet{normSet},
21942173
_servers("!servers", "List of servers", this),
21952174
_prodPdf{std::move(prodPdf)}
21962175
{
2197-
initialize();
2176+
auto cache = _prodPdf->createCacheElem(&_normSet, nullptr);
2177+
_isRearranged = cache->_isRearranged;
2178+
2179+
// The actual servers for a given normalization set depend on whether the
2180+
// cache is rearranged or not. See RooProdPdf::calculate to see
2181+
// which args in the cache are used directly.
2182+
if (_isRearranged) {
2183+
_servers.add(*cache->_rearrangedNum);
2184+
_servers.add(*cache->_rearrangedDen);
2185+
addOwnedComponents(std::move(cache->_rearrangedNum));
2186+
addOwnedComponents(std::move(cache->_rearrangedDen));
2187+
return;
2188+
}
2189+
// We don't want to carry the full cache object around, so we let it go out
2190+
// of scope and transfer the ownership of the args that we actually need.
2191+
cache->_ownedList.releaseOwnership();
2192+
std::vector<std::unique_ptr<RooAbsArg>> owned;
2193+
for (RooAbsArg *arg : cache->_ownedList) {
2194+
owned.emplace_back(arg);
2195+
}
2196+
for (RooAbsArg *arg : cache->_partList) {
2197+
_servers.add(*arg);
2198+
auto found = std::find_if(owned.begin(), owned.end(), [&](auto const &ptr) { return arg == ptr.get(); });
2199+
if (found != owned.end()) {
2200+
addOwnedComponents(std::move(owned[std::distance(owned.begin(), found)]));
2201+
}
2202+
}
21982203
}
21992204

22002205
RooFixedProdPdf::RooFixedProdPdf(const RooFixedProdPdf &other, const char *name)
22012206
: RooAbsPdf(other, name),
22022207
_normSet{other._normSet},
2203-
_servers("!servers", "List of servers", this),
2204-
_prodPdf{static_cast<RooProdPdf *>(other._prodPdf->Clone())}
2208+
//_servers("!servers", "List of servers", this),
2209+
_servers("!servers", this, other._servers),
2210+
_prodPdf{static_cast<RooProdPdf *>(other._prodPdf->Clone())},
2211+
_isRearranged{other._isRearranged}
2212+
{
2213+
}
2214+
2215+
////////////////////////////////////////////////////////////////////////////////
2216+
/// Evaluate product of PDFs in batch mode.
2217+
2218+
void RooFixedProdPdf::doEval(RooFit::EvalContext &ctx) const
22052219
{
2206-
initialize();
2220+
if (_isRearranged) {
2221+
auto numerator = ctx.at(rearrangedNum());
2222+
auto denominator = ctx.at(rearrangedDen());
2223+
RooBatchCompute::compute(ctx.config(this), RooBatchCompute::Ratio, ctx.output(), {numerator, denominator});
2224+
return;
2225+
}
2226+
std::vector<std::span<const double>> factors;
2227+
factors.reserve(partList()->size());
2228+
for (const RooAbsArg *arg : *partList()) {
2229+
auto span = ctx.at(arg);
2230+
factors.push_back(span);
2231+
}
2232+
std::array<double, 1> special{static_cast<double>(factors.size())};
2233+
RooBatchCompute::compute(ctx.config(this), RooBatchCompute::ProdPdf, ctx.output(), factors, special);
22072234
}
22082235

2209-
void RooFixedProdPdf::initialize()
2236+
double RooFixedProdPdf::evaluate() const
22102237
{
2211-
_cache = _prodPdf->createCacheElem(&_normSet, nullptr);
2212-
auto &cache = *_cache;
2238+
if (_isRearranged) {
2239+
return rearrangedNum()->getVal() / rearrangedDen()->getVal();
2240+
}
2241+
double value = 1.0;
22132242

2214-
// The actual servers for a given normalization set depend on whether the
2215-
// cache is rearranged or not. See RooProdPdf::calculateBatch to see
2216-
// which args in the cache are used directly.
2217-
if (cache._isRearranged) {
2218-
_servers.add(*cache._rearrangedNum);
2219-
_servers.add(*cache._rearrangedDen);
2220-
} else {
2221-
for (std::size_t i = 0; i < cache._partList.size(); ++i) {
2222-
_servers.add(cache._partList[i]);
2223-
}
2243+
for (auto *arg : static_range_cast<RooAbsReal *>(*partList())) {
2244+
value *= arg->getVal();
22242245
}
2246+
return value;
22252247
}
22262248

2227-
} // namespace Detail
2228-
} // namespace RooFit
2249+
} // namespace RooFit::Detail

0 commit comments

Comments
 (0)