Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions tests/test_env_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,55 @@ class TestTaskCancellation:
wire format, and the server cancels the corresponding asyncio task.
"""

@pytest.mark.asyncio
async def test_cancelled_client_task_should_cancel_server_task_before_request_processing(
self,
):
"""Cancellation should still propagate before process_request enters its body."""
process_request_blocked = asyncio.Event()
original_process_request_entered = asyncio.Event()
server_task_cancelled = asyncio.Event()

async with run_server_and_client() as (server, client):
original_process_request = server.process_request

async def delayed_process_request(
client_id,
request_id_bytes,
payload_bytes,
):
process_request_blocked.set()
try:
await asyncio.Event().wait()
original_process_request_entered.set()
return await original_process_request(
client_id,
request_id_bytes,
payload_bytes,
)
except asyncio.CancelledError:
server_task_cancelled.set()
raise

server.process_request = delayed_process_request # type: ignore[assignment]

client_task = asyncio.create_task(
client.send_request(
make_rollout_request(), RunRolloutResponse, timeout=30
)
)

await asyncio.wait_for(process_request_blocked.wait(), timeout=5)
assert len(server.request_tasks) == 1
assert not original_process_request_entered.is_set()

client_task.cancel()
with pytest.raises(asyncio.CancelledError):
await client_task

await asyncio.wait_for(server_task_cancelled.wait(), timeout=5)
assert not original_process_request_entered.is_set()

@pytest.mark.asyncio
async def test_cancelled_client_task_should_cancel_server_task(self):
"""When the asyncio task awaiting send_request() is cancelled on the
Expand Down Expand Up @@ -358,7 +407,7 @@ async def slow_handle_run_rollout(request):

# Wait for the server to actually start processing
await asyncio.wait_for(server_task_started.wait(), timeout=5)
assert len(server.pending_tasks) == 1
assert len(server.request_tasks) == 1

# Cancel on the client side
client_task.cancel()
Expand Down Expand Up @@ -403,7 +452,7 @@ async def slow_handle_run_rollout(request):

# Confirm the server started processing
await asyncio.wait_for(server_task_started.wait(), timeout=5)
assert len(server.pending_tasks) == 1
assert len(server.request_tasks) == 1

# Give the system time to propagate
await asyncio.sleep(0.5)
Expand Down
61 changes: 61 additions & 0 deletions tests/test_rlm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import pytest
from datasets import Dataset
from prime_sandboxes import UploadTimeoutError

import verifiers as vf
from verifiers.envs.experimental import rlm_env as rlm_module
Expand Down Expand Up @@ -2707,6 +2708,28 @@ async def test_cleanup_calls_executor(self, tmp_path: Path):


class TestFilesystemProvisioning:
def test_rlm_env_threads_sandbox_client_pool_settings_to_executor(self):
dataset = make_dataset({})
with (
patch("verifiers.envs.environment.signal.signal"),
patch(
"verifiers.envs.experimental.sandbox_mixin.ThreadedAsyncSandboxClient"
) as mock_client_cls,
):
mock_client_cls.return_value = MagicMock()
RLMEnv(
dataset=dataset,
sandbox_client_max_workers=123,
sandbox_client_max_connections=234,
sandbox_client_max_keepalive_connections=56,
)

mock_client_cls.assert_called_with(
max_workers=123,
max_connections=234,
max_keepalive_connections=56,
)

@pytest.mark.asyncio
async def test_prepare_filesystem_uploads_and_sets_paths(self, tmp_path: Path):
dataset = make_dataset({})
Expand Down Expand Up @@ -2766,3 +2789,41 @@ async def test_write_sandbox_files_uploads_worker_and_context(self, tmp_path: Pa
await executor._write_sandbox_files(session, state)

assert executor.sandbox_client.upload_file.await_count == 3

@pytest.mark.asyncio
async def test_write_sandbox_files_retries_upload_timeout(self):
dataset = make_dataset({})
env = build_env(
dataset,
repl_language="python",
sandbox_transfer_max_retries=1,
)
state = {
"rollout_id": "rlm_test",
"rlm_fs_root": "/tmp/rlm_rlm_test/rlm_fs",
"model": "m",
"client": MagicMock(),
"interception_url": "http://example.invalid",
"root_tool_url": "http://example.invalid",
}

executor = env._executor
executor._sessions.clear()
session = executor._get_or_create_session(state)
session.sandbox_id = "sbx_1"
session.sandbox_control_dir = "/tmp/rlm_rlm_test/rlm_control"
session.sandbox_fs_root = "/tmp/rlm_rlm_test/rlm_fs"
session.paths = rlm_module._build_worker_paths(session.sandbox_control_dir)

executor.sandbox_client.upload_file = AsyncMock(
side_effect=[
UploadTimeoutError("sbx_1", session.paths.context_file, 300),
MagicMock(),
MagicMock(),
MagicMock(),
]
)

await executor._write_sandbox_files(session, state)

assert executor.sandbox_client.upload_file.await_count == 4
47 changes: 46 additions & 1 deletion tests/test_sandbox_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from prime_sandboxes import CommandTimeoutError, SandboxOOMError, SandboxTimeoutError
import httpx
from prime_sandboxes import (
APIError,
CommandTimeoutError,
DownloadTimeoutError,
SandboxOOMError,
SandboxTimeoutError,
UploadTimeoutError,
)

import verifiers as vf
from verifiers.envs.experimental.sandbox_mixin import (
Expand All @@ -11,6 +19,8 @@
SandboxNotReadyError,
SandboxSetupError,
ThreadedAsyncSandboxClient,
is_retryable_sandbox_api_error,
is_retryable_sandbox_read_error,
)

MODULE = "verifiers.envs.experimental.sandbox_mixin"
Expand Down Expand Up @@ -41,6 +51,28 @@ def test_init_creates_client_and_retry():
assert callable(obj.with_retry)


@pytest.mark.parametrize(
"exception",
[
UploadTimeoutError("sb", "/tmp/file", 300),
DownloadTimeoutError("sb", "/tmp/file", 300),
CommandTimeoutError("sb", "echo hi", 30),
httpx.ReadTimeout("timed out"),
APIError("Upload failed: HTTP 503: retry me"),
APIError("Upload failed: ConnectError at POST /upload: boom"),
],
)
def test_retryable_sandbox_read_error_matches_current_sdk_exceptions(exception):
assert is_retryable_sandbox_read_error(exception) is True


def test_retryable_sandbox_api_error_ignores_non_retryable_api_error():
assert (
is_retryable_sandbox_api_error(APIError("Upload failed: HTTP 400: nope"))
is False
)


# ── create_sandbox ───────────────────────────────────────────────────


Expand Down Expand Up @@ -68,6 +100,19 @@ def test_create_sandbox_creation_fails(mixin):
asyncio.run(mixin.create_sandbox({}, request=MagicMock()))


def test_create_sandbox_max_retries_is_true_retry_count():
obj = ConcreteMixin(max_retries=1, base_delay=0.01)
obj.logger = MagicMock()
sandbox_obj = MagicMock(id="sb-retry")
obj.sandbox_client.create = AsyncMock(side_effect=[Exception("boom"), sandbox_obj])
obj.sandbox_client.wait_for_creation = AsyncMock()

result = asyncio.run(obj.create_sandbox({}, request=MagicMock()))

assert result == "sb-retry"
assert obj.sandbox_client.create.await_count == 2


def test_create_sandbox_not_ready(mixin):
sandbox_obj = MagicMock(id="sb-2")
mixin.sandbox_client.create = AsyncMock(return_value=sandbox_obj)
Expand Down
Loading
Loading