Skip to content

Commit 206b960

Browse files
committed
move custom exit code and signal handlers away from retry decorator
1 parent 9c72cff commit 206b960

File tree

7 files changed

+64
-53
lines changed

7 files changed

+64
-53
lines changed

metaflow/plugins/argo/argo_workflows.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@
5454
from metaflow.plugins.kubernetes.kube_utils import qos_requests_and_limits
5555

5656
from metaflow.plugins.kubernetes.kubernetes_jobsets import KubernetesArgoJobSet
57-
from metaflow.plugins.retry_decorator import PLATFORM_EVICTED_EXITCODE, RetryEvents
57+
from metaflow.plugins.kubernetes.kubernetes import SPOT_INTERRUPT_EXITCODE
58+
from metaflow.plugins.retry_decorator import RetryEvents
5859
from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
5960
from metaflow.user_configs.config_options import ConfigInput
6061
from metaflow.util import (
@@ -1553,7 +1554,7 @@ def _container_templates(self):
15531554
event_to_expr = {
15541555
RetryEvents.STEP: "asInt(lastRetry.exitCode) == 1",
15551556
RetryEvents.PREEMPT: "asInt(lastRetry.exitCode) == %s"
1556-
% PLATFORM_EVICTED_EXITCODE,
1557+
% SPOT_INTERRUPT_EXITCODE,
15571558
}
15581559
retry_expr = None
15591560
if retry_conditions:

metaflow/plugins/aws/batch/batch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
STDOUT_PATH = os.path.join(LOGS_DIR, STDOUT_FILE)
4444
STDERR_PATH = os.path.join(LOGS_DIR, STDERR_FILE)
4545

46+
SPOT_INTERRUPT_EXITCODE = 234
47+
4648

4749
class BatchException(MetaflowException):
4850
headline = "AWS Batch error"
@@ -488,7 +490,7 @@ def wait_for_launch(job, child_jobs):
488490
if self.job.is_crashed:
489491

490492
# Custom exception for spot instance terminations
491-
if self.job.status_code == 234:
493+
if self.job.status_code == SPOT_INTERRUPT_EXITCODE:
492494
raise BatchSpotInstanceTerminated()
493495

494496
msg = next(

metaflow/plugins/aws/batch/batch_decorator.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
import platform
3+
import signal
34
import sys
5+
import threading
46
import time
57

68
from metaflow import R, current
@@ -24,7 +26,7 @@
2426
get_docker_registry,
2527
get_ec2_instance_metadata,
2628
)
27-
from .batch import BatchException
29+
from .batch import SPOT_INTERRUPT_EXITCODE, BatchException
2830

2931

3032
class BatchDecorator(StepDecorator):
@@ -298,6 +300,29 @@ def task_pre_step(
298300
self._save_logs_sidecar = Sidecar("save_logs_periodically")
299301
self._save_logs_sidecar.start()
300302

303+
# Set up signal handling for spot termination
304+
main_pid = os.getpid()
305+
306+
def _termination_timer():
307+
time.sleep(30)
308+
os.kill(main_pid, signal.SIGALRM)
309+
310+
def _spot_term_signal_handler(*args, **kwargs):
311+
if os.path.isfile(current.spot_termination_notice):
312+
print(
313+
"Spot instance termination detected. Starting a timer to end the Metaflow task."
314+
)
315+
timer_thread = threading.Thread(
316+
target=_termination_timer, daemon=True
317+
)
318+
timer_thread.start()
319+
320+
def _curtain_call(*args, **kwargs):
321+
# custom exit code in case of Spot termination
322+
sys.exit(SPOT_INTERRUPT_EXITCODE)
323+
324+
signal.signal(signal.SIGUSR1, _spot_term_signal_handler)
325+
signal.signal(signal.SIGALRM, _curtain_call)
301326
# Start spot termination monitor sidecar.
302327
# TODO: A nicer way to pass the main process id to a Sidecar, in order to allow sidecars to send signals back to the main process.
303328
os.environ["MF_MAIN_PID"] = str(os.getpid())

metaflow/plugins/aws/step_functions/step_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH,
1919
)
2020
from metaflow.parameters import deploy_time_eval
21-
from metaflow.plugins.retry_decorator import PLATFORM_EVICTED_EXITCODE, RetryEvents
21+
from metaflow.plugins.retry_decorator import RetryEvents
2222
from metaflow.user_configs.config_options import ConfigInput
2323
from metaflow.util import dict_to_cli_options, to_pascalcase
2424

25-
from ..batch.batch import Batch
25+
from ..batch.batch import Batch, SPOT_INTERRUPT_EXITCODE
2626
from .event_bridge_client import EventBridgeClient
2727
from .step_functions_client import StepFunctionsClient
2828

@@ -842,7 +842,7 @@ def _batch(self, node):
842842
RetryEvents.STEP: {"action": "RETRY", "onExitCode": "1"},
843843
RetryEvents.PREEMPT: {
844844
"action": "RETRY",
845-
"onExitCode": str(PLATFORM_EVICTED_EXITCODE),
845+
"onExitCode": str(SPOT_INTERRUPT_EXITCODE),
846846
},
847847
}
848848
retry_expr = None

metaflow/plugins/kubernetes/kubernetes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
"{METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE}"
6464
)
6565

66+
SPOT_INTERRUPT_EXITCODE = 234
67+
6668

6769
class KubernetesException(MetaflowException):
6870
headline = "Kubernetes error"
@@ -768,7 +770,7 @@ def _has_updates():
768770
)
769771
if int(exit_code) == 134:
770772
raise KubernetesException("%s (exit code %s)" % (msg, exit_code))
771-
if int(exit_code) == 234:
773+
if int(exit_code) == SPOT_INTERRUPT_EXITCODE:
772774
# NOTE. K8S exit codes are mod 256
773775
raise KubernetesSpotInstanceTerminated()
774776
else:

metaflow/plugins/kubernetes/kubernetes_decorator.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import json
22
import os
33
import platform
4+
import signal
45
import sys
6+
import threading
57
import time
68

79
from metaflow import current
@@ -37,7 +39,7 @@
3739
from metaflow.unbounded_foreach import UBF_CONTROL
3840

3941
from ..aws.aws_utils import get_docker_registry, get_ec2_instance_metadata
40-
from .kubernetes import KubernetesException
42+
from .kubernetes import KubernetesException, SPOT_INTERRUPT_EXITCODE
4143
from .kube_utils import validate_kube_labels, parse_kube_keyvalue_list
4244

4345
try:
@@ -548,6 +550,29 @@ def task_pre_step(
548550
self._save_logs_sidecar = Sidecar("save_logs_periodically")
549551
self._save_logs_sidecar.start()
550552

553+
# Set up signal handling for spot termination
554+
main_pid = os.getpid()
555+
556+
def _termination_timer():
557+
time.sleep(30)
558+
os.kill(main_pid, signal.SIGALRM)
559+
560+
def _spot_term_signal_handler(*args, **kwargs):
561+
if os.path.isfile(current.spot_termination_notice):
562+
print(
563+
"Spot instance termination detected. Starting a timer to end the Metaflow task."
564+
)
565+
timer_thread = threading.Thread(
566+
target=_termination_timer, daemon=True
567+
)
568+
timer_thread.start()
569+
570+
def _curtain_call(*args, **kwargs):
571+
# custom exit code in case of Spot termination
572+
sys.exit(SPOT_INTERRUPT_EXITCODE)
573+
574+
signal.signal(signal.SIGUSR1, _spot_term_signal_handler)
575+
signal.signal(signal.SIGALRM, _curtain_call)
551576
# Start spot termination monitor sidecar.
552577
# TODO: A nicer way to pass the main process id to a Sidecar, in order to allow sidecars to send signals back to the main process.
553578
os.environ["MF_MAIN_PID"] = str(os.getpid())

metaflow/plugins/retry_decorator.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,15 @@
11
from enum import Enum
2-
import os
3-
import signal
4-
import sys
5-
import threading
6-
from time import sleep
72

83
from metaflow.decorators import StepDecorator
94
from metaflow.exception import MetaflowException
105
from metaflow.metaflow_config import MAX_ATTEMPTS
11-
from metaflow import current
126

137

148
class RetryEvents(Enum):
159
STEP = "step"
1610
PREEMPT = "instance-preemption"
1711

1812

19-
PLATFORM_EVICTED_EXITCODE = 234
20-
21-
2213
class RetryDecorator(StepDecorator):
2314
"""
2415
Specifies the number of times the task corresponding
@@ -76,40 +67,5 @@ def _known_event(event: str):
7667
% ", ".join("*%s*" % event for event in unsupported_events)
7768
)
7869

79-
def task_pre_step(
80-
self,
81-
step_name,
82-
task_datastore,
83-
metadata,
84-
run_id,
85-
task_id,
86-
flow,
87-
graph,
88-
retry_count,
89-
max_user_code_retries,
90-
ubf_context,
91-
inputs,
92-
):
93-
pid = os.getpid()
94-
95-
def _termination_timer():
96-
sleep(30)
97-
os.kill(pid, signal.SIGALRM)
98-
99-
def _spot_term_signal_handler(*args, **kwargs):
100-
if os.path.isfile(current.spot_termination_notice):
101-
print(
102-
"Spot instance termination detected. Starting a timer to end the Metaflow task."
103-
)
104-
timer_thread = threading.Thread(target=_termination_timer, daemon=True)
105-
timer_thread.start()
106-
107-
def _curtain_call(*args, **kwargs):
108-
# custom exit code in case of Spot termination
109-
sys.exit(PLATFORM_EVICTED_EXITCODE)
110-
111-
signal.signal(signal.SIGUSR1, _spot_term_signal_handler)
112-
signal.signal(signal.SIGALRM, _curtain_call)
113-
11470
def step_task_retry_count(self):
11571
return int(self.attributes["times"]), 0

0 commit comments

Comments
 (0)