Skip to content

Commit f795b49

Browse files
committed
Allow sequential tool calls
1 parent 318f863 commit f795b49

File tree

2 files changed

+33
-16
lines changed

2 files changed

+33
-16
lines changed

src/OllamaSharp/Chat.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ public class Chat
8787
/// </summary>
8888
public ThinkValue? Think { get; set; }
8989

90+
/// <summary>
91+
/// Allow recursive tool calls when a model decides to call different tools after each other.
92+
/// </summary>
93+
public bool AllowRecursiveToolCalls { get; set; } = true;
94+
9095
/// <summary>
9196
/// Initializes a new instance of the <see cref="Chat"/> class.
9297
/// This basic constructor sets up the chat without a predefined system prompt.
@@ -452,10 +457,11 @@ public async IAsyncEnumerable<string> SendAsAsync(ChatRole role, string message,
452457
var answerMessage = messageBuilder.ToMessage();
453458
Messages.Add(answerMessage);
454459

455-
if (ToolInvoker is not null && role != ChatRole.Tool)
460+
// support recursive tool calls when a model decides to call tools several times in a row
461+
if (ToolInvoker is not null && answerMessage.ToolCalls?.Any() == true && AllowRecursiveToolCalls)
456462
{
457463
var toolResultMessages = new List<Message>();
458-
foreach (var toolCall in answerMessage.ToolCalls ?? [])
464+
foreach (var toolCall in answerMessage.ToolCalls)
459465
{
460466
// call tools if available and requested by the AI model and yield the results
461467
OnToolCall?.Invoke(this, toolCall);

test/ChatTests.cs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,18 @@ namespace Tests;
1010

1111
public class ChatTests
1212
{
13-
private readonly TestOllamaApiClient _ollama = new();
14-
1513
public class SendMethod : ChatTests
1614
{
1715
[Test]
1816
public async Task Sends_Assistant_Answer_To_Streamer()
1917
{
20-
_ollama.SetExpectedChatResponses(
18+
var ollama = new TestOllamaApiClient();
19+
ollama.SetExpectedChatResponses(
2120
new ChatResponseStream { Message = CreateMessage(ChatRole.Assistant, "Hi hu") },
2221
new ChatResponseStream { Message = CreateMessage(ChatRole.Assistant, "man, how") },
2322
new ChatResponseStream { Message = CreateMessage(ChatRole.Assistant, " are you?") });
2423

25-
var chat = new Chat(_ollama);
24+
var chat = new Chat(ollama);
2625
var answer = await chat.SendAsync("henlo", CancellationToken.None).StreamToEndAsync();
2726

2827
answer.ShouldBe("Hi human, how are you?");
@@ -34,7 +33,8 @@ public async Task Sends_Assistant_Answer_To_Streamer()
3433
[Test]
3534
public async Task Sends_Assistant_ToolsCall_To_Streamer()
3635
{
37-
_ollama.SetExpectedChatResponses(
36+
var ollama = new TestOllamaApiClient();
37+
ollama.SetExpectedChatResponses(
3838
new ChatResponseStream
3939
{
4040
Message = new Message
@@ -50,7 +50,7 @@ public async Task Sends_Assistant_ToolsCall_To_Streamer()
5050
Arguments = new Dictionary<string, object?>()
5151
{
5252
["format"] = "celsius",
53-
["location"] = "Los Angeles, CA",
53+
["location"] = "Los Angeles2, CA",
5454
["number"] = 30,
5555
}
5656
}
@@ -59,7 +59,8 @@ public async Task Sends_Assistant_ToolsCall_To_Streamer()
5959
}
6060
});
6161

62-
var chat = new Chat(_ollama) { ToolInvoker = null }; // we have no tool implementation in this test
62+
var chat = new Chat(ollama) { ToolInvoker = null }; // we have no tool implementation in this test
63+
chat.AllowRecursiveToolCalls = false; // this is required because the expected chat response contains a tool call that would cause an infinite loop
6364
await chat.SendAsync("How is the weather in LA?", CancellationToken.None).StreamToEndAsync();
6465

6566
chat.Messages.Last().Role.ShouldBe(ChatRole.Assistant);
@@ -70,7 +71,8 @@ public async Task Sends_Assistant_ToolsCall_To_Streamer()
7071
[Test]
7172
public async Task Sends_System_Prompt_Message()
7273
{
73-
var chat = new Chat(_ollama, "Speak like a pirate.");
74+
var ollama = new TestOllamaApiClient();
75+
var chat = new Chat(ollama, "Speak like a pirate.");
7476
await chat.SendAsync("henlo", CancellationToken.None).StreamToEndAsync();
7577

7678
chat.Messages.First().Role.ShouldBe(ChatRole.System);
@@ -80,7 +82,9 @@ public async Task Sends_System_Prompt_Message()
8082
[Test]
8183
public async Task Sends_Messages_As_User()
8284
{
83-
var chat = new Chat(_ollama);
85+
var ollama = new TestOllamaApiClient();
86+
var chat = new Chat(ollama);
87+
8488
await chat.SendAsync("henlo", CancellationToken.None).StreamToEndAsync();
8589

8690
chat.Messages.First().Role.ShouldBe(ChatRole.User);
@@ -93,7 +97,9 @@ public async Task Sends_Image_Bytes_As_Base64()
9397
var bytes1 = System.Text.Encoding.ASCII.GetBytes("ABC");
9498
var bytes2 = System.Text.Encoding.ASCII.GetBytes("ABD");
9599

96-
var chat = new Chat(_ollama);
100+
var ollama = new TestOllamaApiClient();
101+
102+
var chat = new Chat(ollama);
97103
await chat.SendAsync("", [bytes1, bytes2], CancellationToken.None).StreamToEndAsync();
98104

99105
chat.Messages.Single(m => m.Role == ChatRole.User).Images.ShouldBe(["QUJD", "QUJE"], ignoreOrder: true);
@@ -105,11 +111,12 @@ public class SendAsMethod : ChatTests
105111
[Test]
106112
public async Task Sends_Messages_As_Defined_Role()
107113
{
108-
_ollama.SetExpectedChatResponses(
114+
var ollama = new TestOllamaApiClient();
115+
ollama.SetExpectedChatResponses(
109116
new ChatResponseStream { Message = CreateMessage(ChatRole.Assistant, "Hi") },
110117
new ChatResponseStream { Message = CreateMessage(ChatRole.Assistant, " tool.") });
111118

112-
var chat = new Chat(_ollama);
119+
var chat = new Chat(ollama);
113120
await chat.SendAsAsync(ChatRole.Tool, "Henlo assistant.", CancellationToken.None).StreamToEndAsync();
114121

115122
var history = chat.Messages.ToArray();
@@ -123,10 +130,12 @@ public async Task Sends_Messages_As_Defined_Role()
123130
[Test]
124131
public async Task Sends_Image_Bytes_As_Base64()
125132
{
133+
var ollama = new TestOllamaApiClient();
134+
126135
var bytes1 = System.Text.Encoding.ASCII.GetBytes("ABC");
127136
var bytes2 = System.Text.Encoding.ASCII.GetBytes("ABD");
128137

129-
var chat = new Chat(_ollama);
138+
var chat = new Chat(ollama);
130139
await chat.SendAsAsync(ChatRole.User, "", [bytes1, bytes2], CancellationToken.None).StreamToEndAsync();
131140

132141
chat.Messages.Single(m => m.Role == ChatRole.User).Images.ShouldBe(["QUJD", "QUJE"], ignoreOrder: true);
@@ -138,7 +147,9 @@ public class MessagesPropertyMethod : ChatTests
138147
[Test]
139148
public void Replaces_Chat_History()
140149
{
141-
var chat = new Chat(_ollama)
150+
var ollama = new TestOllamaApiClient();
151+
152+
var chat = new Chat(ollama)
142153
{
143154
Messages = [new Message { Content = "A", Role = ChatRole.System }]
144155
};

0 commit comments

Comments
 (0)