2 changes: 2 additions & 0 deletions docs/source/rloo_trainer.md
@@ -230,6 +230,8 @@ We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-mem

If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.

If you still get out-of-memory errors, set `vllm_enable_sleep_mode` to `True`; the vLLM parameters and cache will then be offloaded during the optimization step. For more information, see [Reducing Memory Usage with vLLM Sleep Mode](reducing_memory_usage#vllm-sleep-mode).

</Tip>

<Tip>
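From the user's side, enabling the feature amounts to flipping the new config flag. Below is a minimal sketch, assuming the rest of the RLOO setup (model, reward function, dataset) is configured elsewhere; apart from `vllm_enable_sleep_mode`, the arguments are assumed to be existing `RLOOConfig`/`TrainingArguments` fields, and the values are illustrative.

```python
# Minimal sketch: turn on vLLM sleep mode through the new RLOOConfig flag.
from trl import RLOOConfig

training_args = RLOOConfig(
    output_dir="rloo-output",      # illustrative output directory
    use_vllm=True,                 # completions are generated with vLLM
    vllm_mode="colocate",          # vLLM runs inside the training process
    vllm_enable_sleep_mode=True,   # offload vLLM weights/cache during the optimization step
)
```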
10 changes: 10 additions & 0 deletions trl/trainer/rloo_config.py
@@ -141,6 +141,9 @@ class RLOOConfig(TrainingArguments):
Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to
`"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
launching the vLLM server via the `--vllm_tensor_parallel_size` flag.
vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`):
Whether to enable sleep mode for vLLM. If `True`, vLLM sleeps during the optimization step and is woken
up for weight sync and generation.

> Parameters that control the training

@@ -507,6 +510,13 @@ class RLOOConfig(TrainingArguments):
"model implementation."
},
)
vllm_enable_sleep_mode: bool = field(
default=False,
metadata={
"help": "Whether to enable sleep mode for vLLM. If `True`, vLLM will sleep during the optimization step "
"and woken for weight sync and generation."
},
)
vllm_guided_decoding_regex: Optional[str] = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
11 changes: 11 additions & 0 deletions trl/trainer/rloo_trainer.py
@@ -607,7 +607,10 @@ def decode(example, tokenizer):
# Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory
max_num_batched_tokens=4096,
model_impl=self.args.vllm_model_impl,
enable_sleep_mode=self.args.vllm_enable_sleep_mode,
)
if self.args.vllm_enable_sleep_mode:
self.llm.sleep(level=1)
else:
raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")
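For reference, the sleep/wake cycle the trainer now drives can be exercised directly against vLLM. The following is a standalone sketch, not trainer code; it assumes a vLLM build where `enable_sleep_mode`, `LLM.sleep()`, and `LLM.wake_up()` are available, and the model name is only an example.

```python
# Standalone illustration of the vLLM sleep/wake cycle used by the colocated engine.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)

llm.sleep(level=1)   # level 1: offload weights to CPU and discard the KV cache
# ... the optimization step for the policy model would run here ...
llm.wake_up()        # reload weights onto the GPU for weight sync and generation
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
llm.sleep(level=1)   # sleep again until the next generation round
```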

@@ -1130,6 +1133,11 @@ def _generate_and_score_completions(

# Generate completions using either vLLM or regular generation
if self.use_vllm:
if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
# wake up colocated vLLM instances if needed
torch.cuda.empty_cache() # required to avoid OOM in some cases
self.llm.wake_up()

# First, update the vLLM weights if needed
if self.state.global_step != self._last_loaded_step:
self._move_model_to_vllm()
@@ -1240,6 +1248,9 @@ def _generate_and_score_completions(
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
completion_ids = completion_ids[tp_slice]

if self.args.vllm_enable_sleep_mode:
self.llm.sleep(level=1)

# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.pad_token_id)
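The bracketing added in these two hunks boils down to: free cached memory, wake the engine, sync weights if stale, generate, then put the engine back to sleep. A compact sketch follows, not trainer code; `sync_weights` and `generate` are hypothetical callables standing in for the trainer's `_move_model_to_vllm()` and its vLLM generation logic.

```python
# Compact sketch of how generation is bracketed by wake_up()/sleep() in colocate mode.
import torch


def generate_with_sleep_mode(llm, prompts, sync_weights, generate, weights_stale=True):
    torch.cuda.empty_cache()   # release cached allocator blocks; avoids OOM on wake-up in some cases
    llm.wake_up()              # bring the vLLM weights back onto the GPU
    if weights_stale:
        sync_weights(llm)      # hypothetical stand-in for self._move_model_to_vllm()
    completions = generate(llm, prompts)  # hypothetical stand-in for vLLM generation
    llm.sleep(level=1)         # offload again before the next optimization step
    return completions
```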