@qgallouedec qgallouedec commented Sep 21, 2025

import random

from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

factors = [(random.randint(0, 9999), random.randint(0, 9999)) for _ in range(100)]
dataset = Dataset.from_dict(
    {
        "prompt": [[{"role": "user", "content": f"Multiply {a} and {b}."}] for a, b in factors],
        "result": [a * b for a, b in factors],
    }
)


def multiply(a: int, b: int) -> int:
    """Multiply two integers.

    Args:
        a: The first integer.
        b: The second integer.

    Returns:
        The product of the two integers.
    """
    return a * b

def accuracy_reward(completions, result, **kwargs):
    # Reward 1 if the ground-truth product appears in the completion's first message.
    return [int(str(r) in c[0]["content"]) for c, r in zip(completions, result)]


trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B",
    args=GRPOConfig(use_vllm=True, vllm_mode="colocate", vllm_importance_sampling_correction=False),
    reward_funcs=accuracy_reward,
    train_dataset=dataset,
    tools=[multiply],
)
trainer.train()

@qgallouedec changed the title from "a bit messy!" to "Multi-turn tool calling support" on Sep 21, 2025
"""
Given a list of strings, extract all <tool_call> JSON blocks and return them as a list of dictionaries.
"""
pattern = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

There is unfortunately no standardisation of the tool call tags across model families, but Matt is working on extending the chat templates so they can auto-parse tool calls (internal Slack thread): https://huggingface.slack.com/archives/C06JKEMK6BZ/p1757691450090859

Note that vLLM works around this by providing a dedicated set of parsers that can be set when spinning up the server: https://docs.vllm.ai/en/stable/features/tool_calling.html

I'm not sure we want to go down this route, since in my experience it's quite messy to match the parser to the desired model (e.g. some Qwen models use the hermes parser, others don't).

So in the meantime, we might want to give users the ability to provide their own parsing function and default to yours (which is the most common one I've seen).
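A minimal sketch of what that could look like, assuming a pluggable `parser` callable with a default for the common `<tool_call> ... </tool_call>` convention. The names `default_tool_call_parser` and `parse_completion` are illustrative, not existing TRL API:

```python
import json
import re
from typing import Callable, Optional

_TOOL_CALL_PATTERN = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def default_tool_call_parser(text: str) -> Optional[dict]:
    """Return the first well-formed <tool_call> JSON block, or None."""
    for match in _TOOL_CALL_PATTERN.findall(text):
        try:
            return json.loads(match)
        except json.JSONDecodeError:
            continue
    return None

def parse_completion(text: str, parser: Callable = default_tool_call_parser):
    # Users targeting a model family with different tags pass their own parser.
    return parser(text)
```

Model families with other tag conventions would then only need to swap the callable, without touching the trainer.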


That is the approach we usually followed on smolagents: to provide a sensible default, but allow users to fully customize the function/object instance.

[prompt_mask[needs_tool], completion_mask[needs_tool]], dim=1
).sum(-1)
tool_ids = [ids[-num:] for ids, num in zip(new_prompt_ids, num_tool_ids)]
tool_mask = [torch.ones_like(ids) for ids in tool_ids]

Would be cool to have a unit test for this masking so we're confident it is behaving as expected
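A framework-free sketch of such a test, using plain lists in place of tensors (the slicing semantics are identical); `extract_tool_ids` is a stand-in name for the suffix-slicing logic quoted above:

```python
def extract_tool_ids(new_prompt_ids, num_tool_ids):
    # Tool tokens are assumed to sit at the end of each new prompt, so a
    # negative slice recovers exactly the appended tool span.
    return [ids[-num:] for ids, num in zip(new_prompt_ids, num_tool_ids)]

def test_tool_id_extraction():
    # two sequences, with 2 and 3 tool tokens appended respectively
    new_prompt_ids = [[10, 11, 12, 90, 91], [20, 21, 95, 96, 97]]
    num_tool_ids = [2, 3]
    tool_ids = extract_tool_ids(new_prompt_ids, num_tool_ids)
    assert tool_ids == [[90, 91], [95, 96, 97]]
    # the mask is all-ones over exactly the tool span
    tool_mask = [[1] * len(ids) for ids in tool_ids]
    assert [len(m) for m in tool_mask] == num_tool_ids

test_tool_id_extraction()
```

One edge worth covering explicitly: `ids[-0:]` returns the whole sequence in Python, so a completion with zero tool tokens would need special handling in both the implementation and the test.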

@albertvillanova left a comment

Awesome!! Looking forward to having this feature!

Some comments, suggestions, and questions below.

"""
Given a list of strings, extract all <tool_call> JSON blocks and return them as a list of dictionaries.
"""
pattern = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is the approach we usually followed on smolagents: to provide a sensible default, but allow users to fully customize the function/object instance.

RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


def extract_tool_calls(text: str) -> dict[str, Any]:

What about moving this function to a non-specific trainer module, so it can be used by any trainer in the future?

"""
Given a list of strings, extract all <tool_call> JSON blocks and return them as a list of dictionaries.
"""
pattern = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
@albertvillanova commented Sep 23, 2025

Note that re.compile will run on every call to this function. If performance matters here, we should hoist the compiled pattern to a module-level constant so it is built only once, at import time.
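A sketch of the hoisting being suggested. (CPython's `re` module also caches recently compiled patterns internally, so the practical speedup is usually small; the module-level constant mainly avoids the per-call cache lookup and makes the intent explicit.)

```python
import re

def extract_with_local_compile(text):
    # compiled (or fetched from re's internal cache) on every call
    pattern = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
    return pattern.findall(text)

# Module-level constant: compiled exactly once, at import time.
TOOL_CALL_PATTERN = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def extract_with_module_constant(text):
    return TOOL_CALL_PATTERN.findall(text)
```

Both variants return identical results; only the compilation cost moves.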

callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
peft_config: Optional["PeftConfig"] = None,
tools=None,

You will need to add the tools param to the trainer docstring, and give it a type hint.


for match in pattern.findall(text):
    try:
        return json.loads(match)

You only return the first match?
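For comparison, a sketch of a variant that collects every well-formed block instead of stopping at the first; the function name is illustrative, not the PR's actual implementation:

```python
import json
import re

TOOL_CALL_PATTERN = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def extract_all_tool_calls(text):
    """Collect every well-formed <tool_call> JSON block, not just the first."""
    calls = []
    for match in TOOL_CALL_PATTERN.findall(text):
        try:
            calls.append(json.loads(match))
        except json.JSONDecodeError:
            continue  # skip malformed blocks rather than aborting
    return calls
```

Downstream code would then iterate over the list, executing each call in order.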

Comment on lines +1430 to +1435
tool_calls = [extract_tool_calls(completion) for completion in completions]
tool_results = [self._tool_dict[tc["name"]](**tc["arguments"]) if tc else None for tc in tool_calls]
tool_messages = [
    [{"role": "tool", "name": tc["name"], "content": str(tr)}] if tc else None
    for tc, tr in zip(tool_calls, tool_results)
]

Not sure if this handles potential multiple tool calls in a single completion...

Comment on lines +1430 to +1435
tool_calls = [extract_tool_calls(completion) for completion in completions]
tool_results = [self._tool_dict[tc["name"]](**tc["arguments"]) if tc else None for tc in tool_calls]
tool_messages = [
    [{"role": "tool", "name": tc["name"], "content": str(tr)}] if tc else None
    for tc, tr in zip(tool_calls, tool_results)
]

Before the messages with the tool results ("role": "tool"), shouldn't we prepend the messages with the tool calls themselves ("role": "assistant", "tool_calls": ...)? I'm not sure about this though... a real question! 😅
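For reference, a sketch of the conversation shape many chat templates expect, with the assistant turn that issued the call preceding the `tool` turn. Field names follow the common Hugging Face / OpenAI-style convention; whether the trainer must prepend the assistant turn depends on the chat template in use:

```python
tool_call = {"name": "multiply", "arguments": {"a": 3, "b": 4}}
tool_result = 12

messages = [
    {"role": "user", "content": "Multiply 3 and 4."},
    # the assistant turn that issued the call
    {"role": "assistant", "tool_calls": [{"type": "function", "function": tool_call}]},
    # the tool turn carrying the execution result
    {"role": "tool", "name": tool_call["name"], "content": str(tool_result)},
]
```

Some templates raise or silently misrender when a `tool` turn appears without a preceding assistant turn containing `tool_calls`, which is what makes the question above worth settling.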

@August-murr (Contributor) commented

I suggest that we remove any references to including tools as a parameter in the GRPOTrainer:

trainer = GRPOTrainer(
    …,
    tools=[tool_1, tool_2],
)

While I am not necessarily against chat templates or tool-parsing implementations, I believe the tools attribute should be eliminated.

To maximize scalability, a better approach would be to create an all-in-one tools sandbox. The only tool the training script would interact with is a sandboxed code executor, with all the necessary tools defined within it, for example by initializing the sandbox from a Docker image that contains the tools' dependencies.

Therefore, I propose that we only add an Environment parameter, which would encompass the code executor along with all the initialized tools.

The environment would be responsible for handling tool usage and execution.

You could then create built-in environments using your preferred parsing and chat templates.

qgallouedec commented Sep 23, 2025

Yes, I have seen some work suggesting that when scaling up, multiplying tools works less well than a smaller, more generic set of tools, such as the one you describe. However, the approach proposed here seems compatible with that:

trainer = GRPOTrainer(
    …,
    tools=[my_big_containerized_all_in_one_tool],
)

In the end it's up to the user to decide how to design it; the most important thing here is to allow for this flexibility.

@August-murr (Contributor) commented

But the name tools as a param in GRPOTrainer is misleading. What I suggest is more like:

MyEnv = DefaultEnv(code_executor=my_big_containerized_all_in_one_tool)

trainer = GRPOTrainer(
    …,
    Environment=MyEnv,
)

The user can then customize tool use in their own environment.
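To make the proposal concrete, here is a hypothetical sketch; none of these names exist in TRL, they only illustrate the shape of the suggested abstraction:

```python
class DefaultEnv:
    """Owns tool registration and execution, so the trainer sees one object."""

    def __init__(self, tools):
        # register each tool under its function name
        self._tools = {tool.__name__: tool for tool in tools}

    def execute(self, name, arguments):
        # dispatch a parsed tool call to the matching registered tool
        return self._tools[name](**arguments)

def multiply(a, b):
    return a * b

env = DefaultEnv(tools=[multiply])
```

Parsing and chat-template handling would live on the environment as well, keeping the trainer agnostic to how tool calls are formatted.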

Base automatically changed from generate-method to main September 26, 2025 02:48