Multi-turn tool calling support #4115
…_thw` in GRPO and RLOO trainers; update `split_pixel_values_by_grid` to use `image_grid_thw`
""" | ||
Given a list of strings, extract all <tool_call> JSON blocks and return them as a list of dictionaries. | ||
""" | ||
pattern = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL) |
There is unfortunately no standardisation of the tool call tags across model families, but Matt is working on extending the chat templates so they can auto-parse tool calls (internal Slack thread): https://huggingface.slack.com/archives/C06JKEMK6BZ/p1757691450090859

Note that vllm works around this by providing a dedicated set of parsers that can be set when spinning up the server: https://docs.vllm.ai/en/stable/features/tool_calling.html

I'm not sure we want to go down this route, since in my experience it's quite messy to match the parser to the desired model (e.g. some Qwen models use the hermes parser, others don't).

So in the meantime, we might want to give users the ability to provide their own parsing function and default to yours (which is the most common I've seen).
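A pluggable parser could look like the following sketch. The `tool_call_parser` argument name is hypothetical, not an existing TRL parameter; the tag patterns are only illustrative:

```python
import re

def hermes_style_parser(text: str) -> list[str]:
    """Default: extract <tool_call>...</tool_call> JSON blocks (the common tags)."""
    return re.findall(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", text, re.DOTALL)

def bracket_style_parser(text: str) -> list[str]:
    """Example of a user-supplied parser for a model family with different tags."""
    return re.findall(r"\[TOOL_CALLS\]\s*(\{.*?\})", text, re.DOTALL)

# Hypothetical usage, if such an argument existed:
# trainer = GRPOTrainer(..., tool_call_parser=bracket_style_parser)
```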
That is the approach we usually followed on smolagents: to provide a sensible default, but allow users to fully customize the function/object instance.
```python
    [prompt_mask[needs_tool], completion_mask[needs_tool]], dim=1
).sum(-1)
tool_ids = [ids[-num:] for ids, num in zip(new_prompt_ids, num_tool_ids)]
tool_mask = [torch.ones_like(ids) for ids in tool_ids]
```
Would be cool to have a unit test for this masking so we're confident it is behaving as expected
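A minimal sketch of such a unit test, using toy fixtures rather than the trainer's real data (the sequence contents and counts below are purely illustrative):

```python
import torch

def test_tool_mask_covers_appended_tool_ids():
    # Two sequences, with 2 and 1 tool-result tokens appended respectively.
    new_prompt_ids = [torch.tensor([1, 2, 3, 4, 5]), torch.tensor([7, 8, 9])]
    num_tool_ids = [2, 1]

    # Same slicing/masking logic as in the diff hunk above.
    tool_ids = [ids[-num:] for ids, num in zip(new_prompt_ids, num_tool_ids)]
    tool_mask = [torch.ones_like(ids) for ids in tool_ids]

    # The mask should cover exactly the appended tool tokens, nothing more.
    assert torch.equal(tool_ids[0], torch.tensor([4, 5]))
    assert torch.equal(tool_ids[1], torch.tensor([9]))
    assert all(m.sum().item() == n for m, n in zip(tool_mask, num_tool_ids))
```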
Awesome!! Looking forward to having this feature!
Some comments, suggestions and questions below.
""" | ||
Given a list of strings, extract all <tool_call> JSON blocks and return them as a list of dictionaries. | ||
""" | ||
pattern = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is the approach we usually followed on smolagents: to provide a sensible default, but allow users to fully customize the function/object instance.
```python
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


def extract_tool_calls(text: str) -> dict[str, Any]:
```
What about moving this function to a non-specific trainer module, so it can be used by any trainer in the future?
""" | ||
Given a list of strings, extract all <tool_call> JSON blocks and return them as a list of dictionaries. | ||
""" | ||
pattern = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL) |
Note that `re.compile` will be called on every invocation of this function. If optimizing performance matters here, we should define the compiled pattern as a module-level constant, so it is compiled only once at import time.
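For illustration, the module-level version could look like this (the constant and function names are only a sketch; note the `re` module also caches compiled patterns internally, so the win is mostly avoiding the repeated call and cache lookup):

```python
import re

# Compiled once at import time instead of on every call.
TOOL_CALL_PATTERN = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def extract_tool_call_blocks(text: str) -> list[str]:
    """Return the raw JSON strings of every <tool_call> block."""
    return TOOL_CALL_PATTERN.findall(text)
```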
```python
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
peft_config: Optional["PeftConfig"] = None,
tools=None,
```
You will need to add the `tools` param to the trainer docstring, and give it a type hint.
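As a sketch of what that hint might look like (the `Tool` alias and `index_tools` helper are hypothetical, not existing TRL names; the dict-or-callable shape mirrors transformers' chat-template tool conventions):

```python
from typing import Any, Callable, Optional, Union

# A tool is either a JSON-schema dict or a plain Python callable.
Tool = Union[dict[str, Any], Callable[..., Any]]

def index_tools(tools: Optional[list[Tool]] = None) -> dict[str, Tool]:
    """Illustrative helper: index tools by name for dispatch, as a trainer might."""
    if tools is None:
        return {}
    return {t["name"] if isinstance(t, dict) else t.__name__: t for t in tools}
```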
```python
for match in pattern.findall(text):
    try:
        return json.loads(match)
```
You only return the first match?
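A version that collects every parsable block instead of returning on the first could look like this sketch (assuming malformed blocks should be skipped rather than abort the whole parse):

```python
import json
import re

TOOL_CALL_PATTERN = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def extract_tool_calls(text: str) -> list[dict]:
    """Return every parsable <tool_call> block as a dict, not just the first."""
    calls = []
    for match in TOOL_CALL_PATTERN.findall(text):
        try:
            calls.append(json.loads(match))
        except json.JSONDecodeError:
            continue  # skip malformed blocks rather than aborting
    return calls
```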
```python
tool_calls = [extract_tool_calls(completion) for completion in completions]
tool_results = [self._tool_dict[tc["name"]](**tc["arguments"]) if tc else None for tc in tool_calls]
tool_messages = [
    [{"role": "tool", "name": tc["name"], "content": str(tr)}] if tc else None
    for tc, tr in zip(tool_calls, tool_results)
]
```
Not sure if this handles potential multiple tool calls in a single completion...
```python
tool_calls = [extract_tool_calls(completion) for completion in completions]
tool_results = [self._tool_dict[tc["name"]](**tc["arguments"]) if tc else None for tc in tool_calls]
tool_messages = [
    [{"role": "tool", "name": tc["name"], "content": str(tr)}] if tc else None
    for tc, tr in zip(tool_calls, tool_results)
]
```
Before the messages with the tool results (`"role": "tool"`), shouldn't we prepend the messages with the tool calls themselves (`"role": "assistant", "tool_calls": ...`)? Not sure about this though... A real question! 😅
I suggest that we remove any references to including tools as a parameter in the trainer.

While I am not necessarily against chat templates or tool parsing implementations, I believe the `tools` attribute doesn't scale. To maximize scalability, a better approach would be to create an all-in-one tools sandbox. This means that the only tool the training script interacts with would be the sandboxed code executor, with all the necessary tools defined within it, like using a Docker image that contains the dependencies of the tools to initialize the sandbox.

Therefore, I propose that we only add an environment parameter. The environment needs to be responsible for handling tool usage and execution. Then you could create your own built-in environments using your preferred parsing and chat templates.
Yes, I have seen some works suggesting that when scaling up, multiplying tools works less well than a smaller, more generic set of tools, such as the one you describe. However, it seems to me that the approach proposed here is in fact compatible:

```python
trainer = GRPOTrainer(
    ...,
    tools=[my_big_containerized_all_in_one_tool],
)
```

In the end it's up to the user to decide how to design it; most important here is to allow for this flexibility.
but the name

```python
MyEnv = DefaultEnv(code_executer=my_big_containerized_all_in_one_tool)
trainer = GRPOTrainer(
    ...,
    Environment=MyEnv,
)
```

The user can then customize tool use in their own environment.
Co-authored-by: Albert Villanova del Moral <[email protected]>