Skip to content

Commit 62643db

Browse files
committed
[BEAM-36736] Add state sampling for timer processing
1 parent f2860fa commit 62643db

File tree

3 files changed

+62
-11
lines changed

3 files changed

+62
-11
lines changed

sdks/python/apache_beam/runners/worker/operations.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,13 @@ def __init__(
809809
self.tagged_receivers = None # type: Optional[_TaggedReceivers]
810810
# A mapping of timer tags to the input "PCollections" they come in on.
811811
self.input_info = None # type: Optional[OpInputInfo]
812-
812+
self.scoped_timer_processing_state = None
813+
if self.state_sampler:
814+
self.scoped_timer_processing_state = self.state_sampler.scoped_state(
815+
self.name_context,
816+
'process-timers',
817+
metrics_container=self.metrics_container,
818+
suffix="-millis")
813819
# See fn_data in dataflow_runner.py
814820
# TODO: Store all the items from spec?
815821
self.fn, _, _, _, _ = (pickler.loads(self.spec.serialized_fn))
@@ -971,14 +977,21 @@ def add_timer_info(self, timer_family_id, timer_info):
971977
self.user_state_context.add_timer_info(timer_family_id, timer_info)
972978

973979
def process_timer(self, tag, timer_data):
974-
timer_spec = self.timer_specs[tag]
975-
self.dofn_runner.process_user_timer(
976-
timer_spec,
977-
timer_data.user_key,
978-
timer_data.windows[0],
979-
timer_data.fire_timestamp,
980-
timer_data.paneinfo,
981-
timer_data.dynamic_timer_tag)
980+
def process_timer_logic():
981+
timer_spec = self.timer_specs[tag]
982+
self.dofn_runner.process_user_timer(
983+
timer_spec,
984+
timer_data.user_key,
985+
timer_data.windows[0],
986+
timer_data.fire_timestamp,
987+
timer_data.paneinfo,
988+
timer_data.dynamic_timer_tag)
989+
990+
if self.scoped_timer_processing_state:
991+
with self.scoped_timer_processing_state:
992+
process_timer_logic()
993+
else:
994+
process_timer_logic()
982995

983996
def finish(self):
984997
# type: () -> None

sdks/python/apache_beam/runners/worker/statesampler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ def scoped_state(
134134
name_context: Union[str, 'common.NameContext'],
135135
state_name: str,
136136
io_target=None,
137-
metrics_container: Optional['MetricsContainer'] = None
137+
metrics_container: Optional['MetricsContainer'] = None,
138+
suffix: str = '-msecs'
138139
) -> statesampler_impl.ScopedState:
139140
"""Returns a ScopedState object associated to a Step and a State.
140141
@@ -152,7 +153,7 @@ def scoped_state(
152153
name_context = common.NameContext(name_context)
153154

154155
counter_name = CounterName(
155-
state_name + '-msecs',
156+
state_name + suffix,
156157
stage_name=self._prefix,
157158
step_name=name_context.metrics_name(),
158159
io_target=io_target)

sdks/python/apache_beam/runners/worker/statesampler_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,43 @@ def test_sampler_transition_overhead(self):
127127
# debug mode).
128128
self.assertLess(overhead_us, 20.0)
129129

130+
@retry(reraise=True, stop=stop_after_attempt(3))
131+
def test_timer_sampler(self):
132+
# Set up state sampler.
133+
counter_factory = CounterFactory()
134+
sampler = statesampler.StateSampler(
135+
'timer', counter_factory, sampling_period_ms=1)
136+
137+
# Duration of the timer processing.
138+
state_duration_ms = 100
139+
margin_of_error = 0.25
140+
141+
sampler.start()
142+
with sampler.scoped_state(
143+
'step1', 'process-timers', suffix='-millis'):
144+
time.sleep(state_duration_ms / 1000)
145+
sampler.stop()
146+
sampler.commit_counters()
147+
148+
if not statesampler.FAST_SAMPLER:
149+
# The slow sampler does not implement sampling, so we won't test it.
150+
return
151+
152+
# Test that sampled state timings are close to their expected values.
153+
expected_counter_values = {
154+
CounterName('process-timers-millis', step_name='step1', stage_name='timer'):
155+
state_duration_ms,
156+
}
157+
for counter in counter_factory.get_counters():
158+
self.assertIn(counter.name, expected_counter_values)
159+
expected_value = expected_counter_values[counter.name]
160+
actual_value = counter.value()
161+
deviation = float(abs(actual_value - expected_value)) / expected_value
162+
_LOGGER.info('Sampling deviation from expectation: %f', deviation)
163+
self.assertGreater(actual_value, expected_value * (1.0 - margin_of_error))
164+
self.assertLess(actual_value, expected_value * (1.0 + margin_of_error))
165+
166+
130167

131168
if __name__ == '__main__':
132169
logging.getLogger().setLevel(logging.INFO)

0 commit comments

Comments
 (0)