Commit c9484b1

Align docstring parameters with function definitions (#4017)
1 parent f5c2fec commit c9484b1

File tree

13 files changed: +70 −10 lines

examples/notebooks/gpt2-sentiment.ipynb

Lines changed: 6 additions & 0 deletions
@@ -147,8 +147,14 @@
     "    customize this function to train the model on its own dataset.\n",
     "\n",
     "    Args:\n",
+    "        config (`PPOConfig`):\n",
+    "            The configuration of the PPO training.\n",
     "        dataset_name (`str`):\n",
     "            The name of the dataset to be loaded.\n",
+    "        input_min_text_length (`int`, defaults to 5):\n",
+    "            The minimum length of the input text.\n",
+    "        input_max_text_length (`int`, defaults to 10):\n",
+    "            The maximum length of the input text.\n",
     "\n",
     "    Returns:\n",
     "        dataloader (`torch.utils.data.DataLoader`):\n",

examples/research_projects/stack_llama/scripts/rl_training.py

Lines changed: 2 additions & 0 deletions
@@ -126,6 +126,8 @@ def build_dataset(
     customize this function to train the model on its own dataset.
 
     Args:
+        tokenizer (`transformers.PreTrainedTokenizer`):
+            The tokenizer used for the model.
         dataset_name (`str`):
             The name of the dataset to be loaded.
 

examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py

Lines changed: 6 additions & 0 deletions
@@ -104,8 +104,14 @@ def build_dataset(
     customize this function to train the model on its own dataset.
 
     Args:
+        config (`PPOConfig`):
+            The configuration of the PPO training.
         dataset_name (`str`):
             The name of the dataset to be loaded.
+        input_min_text_length (`int`, defaults to 5):
+            The minimum length of the input text.
+        input_max_text_length (`int`, defaults to 10):
+            The maximum length of the input text.
 
     Returns:
         dataloader (`torch.utils.data.DataLoader`):

trl/data_utils.py

Lines changed: 1 addition & 1 deletion
@@ -691,7 +691,7 @@ def truncate_dataset(
     Args:
         dataset (`Dataset` or `DatasetDict`):
             Dataset to truncate.
-        seq_length (`int`):
+        max_length (`int`):
             Maximum sequence length to truncate to.
         map_kwargs (`dict` or `None`, *optional*, defaults to `None`):
             Additional keyword arguments to pass to the dataset's map method when truncating examples.
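With the docstring now matching the real parameter name, a hedged usage sketch follows; the dataset contents, the choice of an `input_ids` column, and the expected output are assumptions for illustration only.

    # Hedged usage sketch: the (dataset, max_length, map_kwargs) signature is
    # taken from the corrected docstring; column name and output are assumed.
    from datasets import Dataset
    from trl.data_utils import truncate_dataset

    ds = Dataset.from_dict({"input_ids": [[1, 2, 3, 4, 5], [6, 7]]})
    truncated = truncate_dataset(ds, max_length=3)  # keyword matches the docs now
    print(truncated["input_ids"])  # expected: [[1, 2, 3], [6, 7]]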

trl/extras/dataset_formatting.py

Lines changed: 2 additions & 0 deletions
@@ -94,6 +94,8 @@ def get_formatting_func_from_dataset(
     Args:
         dataset (Dataset): User dataset
         tokenizer (AutoTokenizer): Tokenizer used for formatting
+        tools (list, *optional*): List of tools (callable functions) that will be accessible to the model.
+            If the template does not support function calling, this argument will have no effect.
 
     Returns:
         Callable: Formatting function if the dataset format is supported else None
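The new `tools` entry lines up with the `tools` argument that chat templates in `transformers` accept. A hedged sketch follows; the checkpoint name is only an example, and as the docstring notes, the tools are ignored if the template lacks function-calling support.

    # Hedged sketch of a `tools` list; apply_chat_template is the transformers
    # API that consumes it. Checkpoint name is illustrative.
    from transformers import AutoTokenizer

    def get_weather(city: str) -> str:
        """Get the current weather in a given city."""
        return f"Sunny in {city}"

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": "What is the weather in Paris?"}],
        tools=[get_weather],  # ignored if the template lacks tool support
        tokenize=False,
        add_generation_prompt=True,
    )
    print(prompt)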

trl/models/modeling_sd_base.py

Lines changed: 3 additions & 3 deletions
@@ -205,6 +205,7 @@ def scheduler_step(
     from the learned model outputs (most often the predicted noise).
 
     Args:
+        self: scheduler.
         model_output (`torch.FloatTensor`): direct output from learned diffusion model.
         timestep (`int`): current discrete timestep in the diffusion chain.
         sample (`torch.FloatTensor`):
@@ -215,9 +216,7 @@ def scheduler_step(
             `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would coincide
             with the one provided as input and `use_clipped_model_output` will have not effect.
         generator: random number generator.
-        variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
-            can directly provide the noise for the variance itself. This is useful for methods such as CycleDiffusion.
-            (https://huggingface.co/papers/2210.05559)
+        prev_sample (`torch.FloatTensor`, *optional*): if not `None`, the previous sample to be used
 
     Returns:
         `DDPOSchedulerOutput`: the predicted sample at the previous timestep and the log probability of the sample
@@ -564,6 +563,7 @@ def pipeline_step_with_grad(
     Function to get RGB image with gradients attached to the model weights.
 
     Args:
+        pipeline (`StableDiffusionPipeline`): Pipeline to be used for image generation.
         prompt (`str` or `list[str]`, *optional*, defaults to `None`):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
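The `prev_sample` argument exists because this scheduler step returns the log probability of the chosen previous sample. A minimal sketch of that scoring step, a simplified Gaussian log-density rather than TRL's exact code:

    # Sketch: score prev_sample under N(mean, std^2), reduced over non-batch
    # dims, as a DDPO-style scheduler step does. Simplified for illustration.
    import math

    import torch

    def gaussian_log_prob(prev_sample: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        log_prob = (
            -((prev_sample - mean) ** 2) / (2 * std**2)
            - torch.log(std)
            - 0.5 * math.log(2 * math.pi)
        )
        return log_prob.mean(dim=tuple(range(1, log_prob.ndim)))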

trl/models/utils.py

Lines changed: 5 additions & 4 deletions
@@ -14,6 +14,7 @@
 
 import itertools
 import warnings
+from collections.abc import Callable
 from contextlib import contextmanager
 from copy import deepcopy
 from dataclasses import dataclass
@@ -431,18 +432,18 @@ class _ForwardRedirection:
     """
 
     def __call__(
-        self, wrapper_module: nn.Module, original_module: nn.Module, method: callable, *args: Any, **kwargs: Any
+        self, wrapper_module: nn.Module, original_module: nn.Module, method: Callable, *args: Any, **kwargs: Any
     ):
         """Reroutes a method call through the `wrapper_module`'s `forward` method.
 
         Args:
            wrapper_module: The module that has `original_module` wrapped.
            original_module: The module that was wrapped inside `wrapper_module`.
-           method_name: The name of the method that should be called on the `original_module` after inputs get
+           method: The method that should be called on the `original_module` after inputs get
               redirected through the `wrapper_module`'s `forward` method.
-           *args: The positional arguments to the method `method_name`. They will get passed to a patched
+           *args: The positional arguments to the `method`. They will get passed to a patched
               `forward` method instead.
-           **kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
+           **kwargs: The keyword arguments to the `method`. They will get passed to a patched
              `forward` method instead.
 
         """

trl/trainer/alignprop_trainer.py

Lines changed: 1 addition & 0 deletions
@@ -329,6 +329,7 @@ def _generate_samples(self, batch_size, with_grad=True, prompts=None):
     Args:
         batch_size (int): Batch size to use for sampling
         with_grad (bool): Whether the generated RGBs should have gradients attached to it.
+        prompts (list[str], *optional*): If provided, use these prompts instead of generating new ones.
 
     Returns:
         prompt_image_pairs (dict[Any])

trl/trainer/bco_trainer.py

Lines changed: 7 additions & 0 deletions
@@ -1035,6 +1035,12 @@ def get_batch_logps(
         average_log_prob:
             If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
             log probabilities of the (non-masked) tokens.
+        label_pad_token_id:
+            The label value to ignore when computing log probabilities.
+        is_encoder_decoder:
+            Whether the model is an encoder-decoder model. If True, the labels are not shifted, and the logits are
+            assumed to already be aligned with the labels. If False, the labels are shifted to the right by one
+            position, and the logits are assumed to be aligned with the shifted labels.
 
     Returns:
         A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
@@ -1144,6 +1150,7 @@ def bco_loss(
             batch_size,)
         chosen_embeddings: embeddings of desirable prompts
         rejected_embeddings: embeddings of undesirable prompts
+        do_train: whether to update the running delta value. Default is True.
 
     Returns:
         A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta). The losses tensor contains the
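A condensed sketch of the alignment the new `is_encoder_decoder` text describes. In the decoder-only case the first label and the last logit are dropped so that position t's logits score label t+1; `-100` is the conventional pad value.

    # Sketch of the label/logit alignment described above (not the full
    # get_batch_logps implementation).
    import torch

    def batch_logps(logits, labels, label_pad_token_id=-100, is_encoder_decoder=False):
        if not is_encoder_decoder:
            labels = labels[:, 1:].clone()  # shift: logits[t] scores labels[t + 1]
            logits = logits[:, :-1, :]
        mask = labels != label_pad_token_id
        labels = labels.masked_fill(~mask, 0)  # any valid index; masked out below
        per_token = torch.gather(logits.log_softmax(-1), 2, labels.unsqueeze(2)).squeeze(2)
        return (per_token * mask).sum(-1)  # sum variant; average divides by mask.sum(-1)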

trl/trainer/dpo_trainer.py

Lines changed: 23 additions & 0 deletions
@@ -1015,6 +1015,29 @@ def dpo_loss(
             Log probabilities of the reference model for the chosen responses. Shape: `(batch_size,)`.
         ref_rejected_logps (`torch.FloatTensor`):
             Log probabilities of the reference model for the rejected responses. Shape: `(batch_size,)`.
+        loss_type (`str`, defaults to `"sigmoid"`):
+            The type of loss to compute. One of:
+            - `"sigmoid"`: Sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
+            - `"hinge"`: Hinge loss on the normalized likelihood from the
+              [SLiC](https://huggingface.co/papers/2305.10425) paper.
+            - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
+            - `"exo_pair"`: Pairwise EXO loss from the [EXO](https://huggingface.co/papers/2402.00856) paper.
+            - `"nca_pair"`: Pairwise NCA loss from the [NCA](https://huggingface.co/papers/2402.05369) paper.
+            - `"robust"`: Unbiased estimate of the DPO loss that is robust to preference noise from the [Robust
+              DPO](https://huggingface.co/papers/2403.00409) paper.
+            - `"bco_pair"`: Pairwise BCO loss from the [BCO](https://huggingface.co/papers/2404.04656) paper.
+            - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675)
+              paper.
+            - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper.
+            - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882)
+              paper.
+            - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the
+              [DiscoPOP](https://huggingface.co/papers/2406.08414) paper.
+            - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
+            - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
+            - `"sft"`: Negative log-likelihood loss (standard supervised fine-tuning loss).
+        model_output (`dict[str, torch.FloatTensor]`, *optional*):
+            The output of the model's forward pass. This is used to compute auxiliary losses if enabled.
 
     Returns:
         A tuple of three tensors: `(losses, chosen_rewards, rejected_rewards)`. The losses tensor contains the DPO
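For reference, a minimal sketch of the default `"sigmoid"` branch from the original DPO paper; this is illustrative, not the full `dpo_loss` with all the listed variants, and the `beta` value is an example.

    # Sketch of the "sigmoid" loss type; matches the (losses, chosen_rewards,
    # rejected_rewards) return described in the docstring.
    import torch.nn.functional as F

    def dpo_sigmoid_loss(chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1):
        logits = (chosen_logps - rejected_logps) - (ref_chosen_logps - ref_rejected_logps)
        losses = -F.logsigmoid(beta * logits)
        chosen_rewards = beta * (chosen_logps - ref_chosen_logps).detach()
        rejected_rewards = beta * (rejected_logps - ref_rejected_logps).detach()
        return losses, chosen_rewards, rejected_rewards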
