Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nemo_curator/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
# We cannot use a free port between 10000 and 19999 as it is used by Ray.
DEFAULT_RAY_MIN_WORKER_PORT = 10002
DEFAULT_RAY_MAX_WORKER_PORT = 19999
# Maximum time, in seconds, to keep polling for the Ray GCS address after launching the cluster before giving up.
RAY_CLUSTER_START_VERIFICATION_TIMEOUT = 300
62 changes: 62 additions & 0 deletions nemo_curator/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
from typing import TYPE_CHECKING

import ray
import tenacity
from loguru import logger
from ray._private.services import canonicalize_bootstrap_address, find_gcs_addresses

from nemo_curator.core.constants import (
DEFAULT_RAY_AUTOSCALER_METRIC_PORT,
DEFAULT_RAY_DASHBOARD_METRIC_PORT,
DEFAULT_RAY_MAX_WORKER_PORT,
DEFAULT_RAY_MIN_WORKER_PORT,
RAY_CLUSTER_START_VERIFICATION_TIMEOUT,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -68,6 +71,47 @@ def _logger_custom_deserializer(
return logger


@tenacity.retry(
    wait=tenacity.wait_fixed(1),
    stop=tenacity.stop_after_delay(RAY_CLUSTER_START_VERIFICATION_TIMEOUT),
    retry=tenacity.retry_if_result(lambda x: x is False),
    reraise=True,
)
def _verify_gcs_running(expected_address: str, proc: subprocess.Popen) -> bool:
    """Perform one check for a Ray GCS listening at ``expected_address``.

    The ``tenacity.retry`` decorator calls this function again, with a fixed
    one-second wait between attempts, for as long as it returns ``False`` and
    until ``RAY_CLUSTER_START_VERIFICATION_TIMEOUT`` seconds have elapsed.

    Args:
        expected_address: Address the GCS should be found at (ip:port format).
        proc: Handle of the subprocess that was launched to start the Ray cluster.

    Returns:
        True when the canonicalized expected address is among the canonicalized
        GCS addresses that were discovered, False otherwise.

    Raises:
        RuntimeError: If the Ray subprocess has already exited.
    """
    # A finished launcher process means the cluster is not coming up; fail fast
    # instead of letting the retry loop keep polling.
    returncode = proc.poll()
    if returncode is not None:
        failure = f"Ray cluster failed to start. Process exited with code {returncode}."
        logger.error(failure)
        raise RuntimeError(failure)

    gcs_addresses = find_gcs_addresses()
    if not gcs_addresses:
        # No GCS discovered yet; returning False triggers another attempt.
        return False

    # Compare canonical forms so that equivalent spellings of an address match.
    known = [canonicalize_bootstrap_address(addr) for addr in gcs_addresses]
    target = canonicalize_bootstrap_address(expected_address)
    if target not in known:
        logger.debug(f"Found GCS at {gcs_addresses}, waiting for {expected_address}")
        return False

    logger.info(f"Ray cluster successfully started at {expected_address}")
    return True


def init_cluster( # noqa: PLR0913
ray_port: int,
ray_temp_dir: str,
Expand Down Expand Up @@ -123,4 +167,22 @@ def init_cluster( # noqa: PLR0913

proc = subprocess.Popen(ray_command, shell=False) # noqa: S603
logger.info(f"Ray start command: {' '.join(ray_command)}")

# Verify that Ray cluster actually started successfully using tenacity retry logic
expected_address = f"{ip_address}:{ray_port}"
try:
_verify_gcs_running(expected_address, proc)
except tenacity.RetryError:
# Check one final time if process failed
returncode = proc.poll()
if returncode is not None:
msg = f"Ray cluster failed to start. Process exited with code {returncode}."
logger.error(msg)
raise RuntimeError(msg) # noqa: B904

# Process is still running but GCS not detected
msg = f"Ray cluster verification timeout after {RAY_CLUSTER_START_VERIFICATION_TIMEOUT}s. GCS address not detected at {expected_address}."
logger.error(msg)
raise RuntimeError(msg) # noqa: B904

return proc
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ dependencies = [
"pandas>=2.1.0",
"pyarrow",
"ray[default,data]>=2.49",
"tenacity",
"torch",
"transformers==4.55.2",
]
Expand Down
11 changes: 11 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.