Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdks/python/apache_beam/runners/worker/operations.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ cdef class DoOperation(Operation):
cdef dict timer_specs
cdef public object input_info
cdef object fn
cdef object scoped_timer_processing_state


cdef class SdfProcessSizedElements(DoOperation):
Expand Down
22 changes: 13 additions & 9 deletions sdks/python/apache_beam/runners/worker/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,10 @@ def __init__(
self.tagged_receivers = None # type: Optional[_TaggedReceivers]
# A mapping of timer tags to the input "PCollections" they come in on.
self.input_info = None # type: Optional[OpInputInfo]

self.scoped_timer_processing_state = self.state_sampler.scoped_state(
self.name_context,
'process-timers',
metrics_container=self.metrics_container)
# See fn_data in dataflow_runner.py
# TODO: Store all the items from spec?
self.fn, _, _, _, _ = (pickler.loads(self.spec.serialized_fn))
Expand Down Expand Up @@ -971,14 +974,15 @@ def add_timer_info(self, timer_family_id, timer_info):
self.user_state_context.add_timer_info(timer_family_id, timer_info)

def process_timer(self, tag, timer_data):
  """Fires the user timer registered under ``tag``.

  The whole call, including the timer spec lookup, runs inside
  ``self.scoped_timer_processing_state`` so that the time spent is
  attributed to that sampled state.

  Args:
    tag: key into ``self.timer_specs`` selecting the timer spec to fire.
    timer_data: object carrying ``user_key``, ``windows``,
      ``fire_timestamp``, ``paneinfo`` and ``dynamic_timer_tag``.

  Raises:
    KeyError: if ``tag`` is not present in ``self.timer_specs``.
  """
  with self.scoped_timer_processing_state:
    spec = self.timer_specs[tag]
    # Only the first window of the timer is forwarded to the runner.
    first_window = timer_data.windows[0]
    self.dofn_runner.process_user_timer(
        spec,
        timer_data.user_key,
        first_window,
        timer_data.fire_timestamp,
        timer_data.paneinfo,
        timer_data.dynamic_timer_tag)

def finish(self):
# type: () -> None
Expand Down
185 changes: 185 additions & 0 deletions sdks/python/apache_beam/runners/worker/statesampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,56 @@
import logging
import time
import unittest
from unittest import mock
from unittest.mock import Mock
from unittest.mock import patch

from tenacity import retry
from tenacity import stop_after_attempt

from apache_beam.internal import pickler
from apache_beam.runners import common
from apache_beam.runners.worker import operation_specs
from apache_beam.runners.worker import operations
from apache_beam.runners.worker import statesampler
from apache_beam.transforms import core
from apache_beam.transforms import userstate
from apache_beam.transforms.core import GlobalWindows
from apache_beam.transforms.core import Windowing
from apache_beam.transforms.window import GlobalWindow
from apache_beam.utils.counters import CounterFactory
from apache_beam.utils.counters import CounterName
from apache_beam.utils.windowed_value import PaneInfo

_LOGGER = logging.getLogger(__name__)


class TimerDoFn(core.DoFn):
  """A DoFn whose watermark timer callback optionally sleeps."""
  TIMER_SPEC = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)

  def __init__(self, sleep_duration_s=0):
    """Stores how long, in seconds, the timer callback should sleep.

    Args:
      sleep_duration_s: sleep time in seconds; a falsy value disables the
        sleep entirely.
    """
    self._sleep_duration_s = sleep_duration_s

  @userstate.on_timer(TIMER_SPEC)
  def on_timer_f(self):
    """Timer callback: sleeps for the configured duration, if any."""
    delay_s = self._sleep_duration_s
    if not delay_s:
      return
    time.sleep(delay_s)


class ExceptionTimerDoFn(core.DoFn):
  """A DoFn whose timer callback optionally sleeps and then always raises."""
  TIMER_SPEC = userstate.TimerSpec('ts-timer', userstate.TimeDomain.WATERMARK)

  def __init__(self, sleep_duration_s=0):
    """Stores how long, in seconds, the callback sleeps before raising.

    Args:
      sleep_duration_s: sleep time in seconds; a falsy value skips the sleep.
    """
    self._sleep_duration_s = sleep_duration_s

  @userstate.on_timer(TIMER_SPEC)
  def on_timer_f(self):
    """Timer callback: sleeps if configured, then raises RuntimeError.

    Raises:
      RuntimeError: unconditionally, after the optional sleep.
    """
    delay_s = self._sleep_duration_s
    if delay_s:
      # Sleep first so the time is spent inside the callback before it fails.
      time.sleep(delay_s)
    raise RuntimeError("Test exception from timer")


class StateSamplerTest(unittest.TestCase):

# Due to somewhat non-deterministic nature of state sampling and sleep,
Expand Down Expand Up @@ -127,6 +166,152 @@ def test_sampler_transition_overhead(self):
# debug mode).
self.assertLess(overhead_us, 20.0)

@retry(reraise=True, stop=stop_after_attempt(3))
# Patch get_dofn_specs so that it reports only TimerDoFn's timer spec.
@patch('apache_beam.transforms.userstate.get_dofn_specs')
def test_do_operation_process_timer(self, mock_get_dofn_specs):
  """Checks that DoOperation.process_timer time lands in 'process-timers'.

  A TimerDoFn that sleeps for ``state_duration_ms`` is fired once while the
  sampler is running; the 'process-timers-msecs' counter for the step must
  then exist and exceed ``state_duration_ms * (1 - margin_of_error)``.
  """
  # Decide whether to skip before doing any other setup work.
  if not statesampler.FAST_SAMPLER:
    self.skipTest('DoOperation test requires FAST_SAMPLER')

  # TIMER_SPEC is a class attribute; no DoFn instance is needed to read it.
  mock_get_dofn_specs.return_value = ([], [TimerDoFn.TIMER_SPEC])

  state_duration_ms = 200
  margin_of_error = 0.75

  counter_factory = CounterFactory()
  sampler = statesampler.StateSampler(
      'test_do_op', counter_factory, sampling_period_ms=1)

  fn_for_spec = TimerDoFn(sleep_duration_s=state_duration_ms / 1000.0)

  spec = operation_specs.WorkerDoFn(
      serialized_fn=pickler.dumps(
          (fn_for_spec, [], {}, [], Windowing(GlobalWindows()))),
      output_tags=[],
      input=None,
      side_inputs=[],
      output_coders=[])

  mock_user_state_context = mock.MagicMock()
  op = operations.DoOperation(
      common.NameContext('step1'),
      spec,
      counter_factory,
      sampler,
      user_state_context=mock_user_state_context)

  op.setup()

  timer_data = Mock()
  timer_data.user_key = None
  timer_data.windows = [GlobalWindow()]
  timer_data.fire_timestamp = 0
  timer_data.paneinfo = PaneInfo(
      is_first=False,
      is_last=False,
      timing=0,
      index=0,
      nonspeculative_index=0)
  timer_data.dynamic_timer_tag = ''

  sampler.start()
  try:
    # NOTE(review): the tag is 'ts-timer' although the spec was created as
    # 'timer' -- presumably TimerSpec prefixes its name with 'ts-'; confirm
    # against userstate.TimerSpec.
    op.process_timer('ts-timer', timer_data=timer_data)
  finally:
    # Always stop the sampler so a failure here does not leave it running
    # while the test is retried.
    sampler.stop()
  sampler.commit_counters()

  expected_name = CounterName(
      'process-timers-msecs', step_name='step1', stage_name='test_do_op')

  found_counter = next(
      (
          counter for counter in counter_factory.get_counters()
          if counter.name == expected_name),
      None)

  self.assertIsNotNone(
      found_counter, f"Expected counter '{expected_name}' to be created.")

  actual_value = found_counter.value()
  # Use the module logger rather than the root logger.
  _LOGGER.info("Actual value %d", actual_value)
  self.assertGreater(
      actual_value, state_duration_ms * (1.0 - margin_of_error))

@retry(reraise=True, stop=stop_after_attempt(3))
# Patch get_dofn_specs so that it reports only ExceptionTimerDoFn's timer
# spec.
@patch('apache_beam.runners.worker.operations.userstate.get_dofn_specs')
def test_do_operation_process_timer_with_exception(self, mock_get_dofn_specs):
  """Checks 'process-timers' time is recorded even when the timer raises.

  An ExceptionTimerDoFn that sleeps for ``state_duration_ms`` and then
  raises is fired once while the sampler is running. The RuntimeError must
  propagate out of process_timer, and the 'process-timers-msecs' counter
  must still exist and exceed ``state_duration_ms * (1 - margin_of_error)``.
  """
  # Decide whether to skip before doing any other setup work.
  if not statesampler.FAST_SAMPLER:
    self.skipTest('DoOperation test requires FAST_SAMPLER')

  # TIMER_SPEC is a class attribute; no DoFn instance is needed to read it.
  mock_get_dofn_specs.return_value = ([], [ExceptionTimerDoFn.TIMER_SPEC])

  state_duration_ms = 200
  margin_of_error = 0.50

  counter_factory = CounterFactory()
  sampler = statesampler.StateSampler(
      'test_do_op_exception', counter_factory, sampling_period_ms=1)

  fn_for_spec = ExceptionTimerDoFn(
      sleep_duration_s=state_duration_ms / 1000.0)

  spec = operation_specs.WorkerDoFn(
      serialized_fn=pickler.dumps(
          (fn_for_spec, [], {}, [], Windowing(GlobalWindows()))),
      output_tags=[],
      input=None,
      side_inputs=[],
      output_coders=[])

  mock_user_state_context = mock.MagicMock()
  op = operations.DoOperation(
      common.NameContext('step1'),
      spec,
      counter_factory,
      sampler,
      user_state_context=mock_user_state_context)

  op.setup()

  timer_data = Mock()
  timer_data.user_key = None
  timer_data.windows = [GlobalWindow()]
  timer_data.fire_timestamp = 0
  timer_data.paneinfo = PaneInfo(
      is_first=False,
      is_last=False,
      timing=0,
      index=0,
      nonspeculative_index=0)
  timer_data.dynamic_timer_tag = ''

  sampler.start()
  try:
    # Assert that the expected exception is raised.
    # NOTE(review): the tag is 'ts-ts-timer' because the spec itself was
    # created as 'ts-timer' -- presumably TimerSpec adds another 'ts-'
    # prefix; confirm against userstate.TimerSpec.
    with self.assertRaises(RuntimeError):
      op.process_timer('ts-ts-timer', timer_data=timer_data)
  finally:
    # Always stop the sampler, including when the assertion above fails, so
    # it is not left running while the test is retried.
    sampler.stop()
  sampler.commit_counters()

  expected_name = CounterName(
      'process-timers-msecs',
      step_name='step1',
      stage_name='test_do_op_exception')

  found_counter = next(
      (
          counter for counter in counter_factory.get_counters()
          if counter.name == expected_name),
      None)

  self.assertIsNotNone(
      found_counter, f"Expected counter '{expected_name}' to be created.")

  actual_value = found_counter.value()
  self.assertGreater(
      actual_value, state_duration_ms * (1.0 - margin_of_error))
  _LOGGER.info("Exception test finished successfully.")


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down
Loading