Skip to content

Commit 10d7bd3

Browse files
committed
[Python][UHI] Add implementation of __iter__ to fix list(h)
1 parent 0a34f6c commit 10d7bd3

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,6 @@ def _process_index_for_axis(self, index, axis, include_flow_bins=False, is_slice
169169
if index == -1:
170170
return _overflow(self, axis) - 1
171171

172-
if index == _overflow(self, axis):
173-
return index + (1 if include_flow_bins else 0)
174-
175172
# Shift the indices by 1 to align with the UHI convention,
176173
# where 0 corresponds to the first bin, unlike ROOT where 0 represents underflow and 1 is the first bin.
177174
nbins = _get_axis_len(self, axis) + (1 if is_slice_stop else 0)
@@ -379,9 +376,16 @@ def _setitem(self, index, value):
379376
_slice_set(self, uhi_index, index, value)
380377

381378

379+
def _iter(self):
380+
array = _values_by_copy(self, include_flow_bins=True)
381+
for val in array.flat:
382+
yield val.item()
383+
384+
382385
def _add_indexing_features(klass: Any) -> None:
383386
klass.__getitem__ = _getitem
384387
klass.__setitem__ = _setitem
388+
klass.__iter__ = _iter
385389

386390

387391
"""
@@ -492,18 +496,20 @@ def _values_default(self) -> np.typing.NDArray[Any]: # noqa: F821
492496

493497
# Special case for TH1K: we need the array length to correspond to the number of bins
494498
# according to the UHI plotting protocol
495-
def _values_by_copy(self) -> np.typing.NDArray[Any]: # noqa: F821
499+
def _values_by_copy(self, include_flow_bins=False) -> np.typing.NDArray[Any]: # noqa: F821
496500
from itertools import product
497501

498502
import numpy as np
499503

504+
offset = 0 if include_flow_bins else 1
500505
dimensions = [
501-
range(1, _get_axis_len(self, axis, include_flow_bins=False) + 1) for axis in range(self.GetDimension())
506+
range(offset, _get_axis_len(self, axis, include_flow_bins=include_flow_bins) + offset)
507+
for axis in range(self.GetDimension())
502508
]
503509
bin_combinations = product(*dimensions)
504510

505511
return np.array([self.GetBinContent(*bin) for bin in bin_combinations]).reshape(
506-
_shape(self, include_flow_bins=False)
512+
_shape(self, include_flow_bins=include_flow_bins)
507513
)
508514

509515

bindings/pyroot/pythonizations/test/uhi_indexing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,16 @@ def test_statistics_slice(self, hist_setup):
315315
assert hist_setup.GetStdDev() == pytest.approx(sliced_hist.GetStdDev(), rel=10e-5)
316316
assert hist_setup.GetMean() == pytest.approx(sliced_hist.GetMean(), rel=10e-5)
317317

318+
def test_list_iter(self, hist_setup):
319+
import numpy as np
320+
321+
if _special_setting(hist_setup):
322+
pytest.skip("Setting cannot be tested here")
323+
324+
expected = np.full(_shape(hist_setup), 3, dtype=np.int64)
325+
hist_setup[...] = expected
326+
assert list(hist_setup) == expected.flatten().tolist()
327+
318328

319329
if __name__ == "__main__":
320330
raise SystemExit(pytest.main(args=[__file__]))

0 commit comments

Comments
 (0)