
Commit fe02ea2

😴 Add vllm_enable_sleep_mode to RLOO Trainer (#4107)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 68408d7 commit fe02ea2

File tree

3 files changed: +23 −0 lines

docs/source/rloo_trainer.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -230,6 +230,8 @@ We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-mem
 
 If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.
 
+If you still get out-of-memory errors, set `vllm_enable_sleep_mode` to `True` so that the vLLM parameters and cache are offloaded during the optimization step. For more information, see [Reducing Memory Usage with vLLM Sleep Mode](reducing_memory_usage#vllm-sleep-mode).
+
 </Tip>
 
 <Tip>
```
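For context, enabling the new flag from user code would look something like the following minimal sketch; the model name, dataset, and toy reward function are illustrative placeholders, and the trainer signature is assumed to follow TRL's GRPO-style API:

```python
from datasets import load_dataset
from trl import RLOOConfig, RLOOTrainer

# Toy reward function (placeholder): prefer shorter completions.
def reward_len(completions, **kwargs):
    return [-float(len(c)) for c in completions]

dataset = load_dataset("trl-lib/tldr", split="train")  # placeholder dataset

training_args = RLOOConfig(
    output_dir="rloo-sleep-demo",
    use_vllm=True,
    vllm_mode="colocate",         # sleep mode applies to colocated vLLM
    vllm_enable_sleep_mode=True,  # offload vLLM weights/cache during optimization
)

trainer = RLOOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```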

trl/trainer/rloo_config.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -141,6 +141,9 @@ class RLOOConfig(TrainingArguments):
             Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to
             `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
             launching the vLLM server via the `--vllm_tensor_parallel_size` flag.
+        vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`):
+            Whether to enable sleep mode for vLLM. If `True`, vLLM sleeps during the optimization step and is
+            woken up for weight sync and generation.
 
     > Parameters that control the training
@@ -507,6 +510,13 @@ class RLOOConfig(TrainingArguments):
             "model implementation."
         },
     )
+    vllm_enable_sleep_mode: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to enable sleep mode for vLLM. If `True`, vLLM sleeps during the optimization step "
+            "and is woken up for weight sync and generation."
+        },
+    )
     vllm_guided_decoding_regex: Optional[str] = field(
         default=None,
         metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
```

trl/trainer/rloo_trainer.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -607,7 +607,10 @@ def decode(example, tokenizer):
                 # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory
                 max_num_batched_tokens=4096,
                 model_impl=self.args.vllm_model_impl,
+                enable_sleep_mode=self.args.vllm_enable_sleep_mode,
             )
+            if self.args.vllm_enable_sleep_mode:
+                self.llm.sleep(level=1)
         else:
             raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")
@@ -1130,6 +1133,11 @@ def _generate_and_score_completions(
 
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
+            if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
+                # Wake up colocated vLLM instances if needed
+                torch.cuda.empty_cache()  # required to avoid OOM in some cases
+                self.llm.wake_up()
+
             # First, update the vLLM weights if needed
             if self.state.global_step != self._last_loaded_step:
                 self._move_model_to_vllm()
@@ -1240,6 +1248,9 @@ def _generate_and_score_completions(
                 tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
                 completion_ids = completion_ids[tp_slice]
 
+        if self.args.vllm_enable_sleep_mode:
+            self.llm.sleep(level=1)
+
         # Pad the completions, and concatenate them with the prompts
         completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
         completion_ids = pad(completion_ids, padding_value=self.pad_token_id)
```
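The `sleep`/`wake_up` calls above use vLLM's sleep-mode API. As a standalone reference, here is a minimal sketch of the lifecycle the trainer follows (the model name is a placeholder; sleep mode must be requested when constructing the engine):

```python
from vllm import LLM, SamplingParams

# Sleep mode must be enabled at engine construction time.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)

# Level 1 offloads model weights to CPU and discards the KV cache,
# freeing GPU memory for the optimization step.
llm.sleep(level=1)

# ... optimizer step / weight updates happen here ...

# Wake the engine back up before weight sync and generation.
llm.wake_up()
outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=8))
```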
