Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nemo_curator/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
# We cannot use a free port between 10000 and 19999 as it is used by Ray.
DEFAULT_RAY_MIN_WORKER_PORT = 10002
DEFAULT_RAY_MAX_WORKER_PORT = 19999
RAY_CLUSTER_START_VERIFICATION_TIMEOUT = 300
62 changes: 62 additions & 0 deletions nemo_curator/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
from typing import TYPE_CHECKING

import ray
import tenacity
from loguru import logger
from ray._private.services import canonicalize_bootstrap_address, find_gcs_addresses

from nemo_curator.core.constants import (
DEFAULT_RAY_AUTOSCALER_METRIC_PORT,
DEFAULT_RAY_DASHBOARD_METRIC_PORT,
DEFAULT_RAY_MAX_WORKER_PORT,
DEFAULT_RAY_MIN_WORKER_PORT,
RAY_CLUSTER_START_VERIFICATION_TIMEOUT,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -68,6 +71,47 @@ def _logger_custom_deserializer(
return logger


@tenacity.retry(
wait=tenacity.wait_fixed(1),
stop=tenacity.stop_after_delay(RAY_CLUSTER_START_VERIFICATION_TIMEOUT),
retry=tenacity.retry_if_result(lambda x: x is False),
reraise=True,
)
def _verify_gcs_running(expected_address: str, proc: subprocess.Popen) -> bool:
"""Verify that the Ray GCS is running at the expected address.

Args:
expected_address: The expected GCS address (ip:port format)
proc: The subprocess running the Ray cluster

Returns:
True if GCS is running at expected address, False otherwise

Raises:
RuntimeError: If the Ray process exited with an error
"""
# Check if the process exited with an error
returncode = proc.poll()
if returncode is not None:
msg = f"Ray cluster failed to start. Process exited with code {returncode}."
logger.error(msg)
raise RuntimeError(msg)

# Check if GCS is running at the expected address
gcs_addresses = find_gcs_addresses()
if gcs_addresses:
# Canonicalize both addresses for comparison
canonical_gcs_addresses = []
for gcs_address in gcs_addresses:
canonical_gcs_addresses.append(canonicalize_bootstrap_address(gcs_address))
canonical_expected_address = canonicalize_bootstrap_address(expected_address)
if canonical_expected_address in canonical_gcs_addresses:
logger.info(f"Ray cluster successfully started at {expected_address}")
return True
logger.debug(f"Found GCS at {gcs_addresses}, waiting for {expected_address}")
return False


def init_cluster( # noqa: PLR0913
ray_port: int,
ray_temp_dir: str,
Expand Down Expand Up @@ -123,4 +167,22 @@ def init_cluster( # noqa: PLR0913

proc = subprocess.Popen(ray_command, shell=False) # noqa: S603
logger.info(f"Ray start command: {' '.join(ray_command)}")

# Verify that Ray cluster actually started successfully using tenacity retry logic
expected_address = f"{ip_address}:{ray_port}"
try:
_verify_gcs_running(expected_address, proc)
except tenacity.RetryError:
# Check one final time if process failed
returncode = proc.poll()
if returncode is not None:
msg = f"Ray cluster failed to start. Process exited with code {returncode}."
logger.error(msg)
raise RuntimeError(msg) # noqa: B904

# Process is still running but GCS not detected
msg = f"Ray cluster verification timeout after {RAY_CLUSTER_START_VERIFICATION_TIMEOUT}s. GCS address not detected at {expected_address}."
logger.error(msg)
raise RuntimeError(msg) # noqa: B904

return proc
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ dependencies = [
"pandas>=2.1.0",
"pyarrow",
"ray[default,data]>=2.49",
"tenacity",
"torch",
"transformers==4.55.2",
]
Expand Down
Loading