Skip to content

Commit d148ca2

Browse files
alxkmilayaperumalg
authored andcommitted
test: Add comprehensive test coverage for DefaultToolCallingManager
Co-authored-by: Oleksandr Klymenko <[email protected]> Signed-off-by: Oleksandr Klymenko <[email protected]>
1 parent 0ae56f7 commit d148ca2

File tree

1 file changed

+268
-0
lines changed

1 file changed

+268
-0
lines changed

spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,4 +155,272 @@ public String call(String toolInput) {
155155
assertThatNoException().isThrownBy(() -> managerWithCallback.executeToolCalls(prompt, chatResponse));
156156
}
157157

158+
@Test
159+
void shouldHandleMultipleToolCallsInSingleResponse() {
160+
// Create mock tool callbacks
161+
ToolCallback toolCallback1 = new ToolCallback() {
162+
@Override
163+
public ToolDefinition getToolDefinition() {
164+
return DefaultToolDefinition.builder()
165+
.name("tool1")
166+
.description("First tool")
167+
.inputSchema("{\"type\": \"object\", \"properties\": {\"param\": {\"type\": \"string\"}}}")
168+
.build();
169+
}
170+
171+
@Override
172+
public ToolMetadata getToolMetadata() {
173+
return ToolMetadata.builder().build();
174+
}
175+
176+
@Override
177+
public String call(String toolInput) {
178+
return "{\"result\": \"tool1_success\"}";
179+
}
180+
};
181+
182+
ToolCallback toolCallback2 = new ToolCallback() {
183+
@Override
184+
public ToolDefinition getToolDefinition() {
185+
return DefaultToolDefinition.builder()
186+
.name("tool2")
187+
.description("Second tool")
188+
.inputSchema("{\"type\": \"object\", \"properties\": {\"value\": {\"type\": \"number\"}}}")
189+
.build();
190+
}
191+
192+
@Override
193+
public ToolMetadata getToolMetadata() {
194+
return ToolMetadata.builder().build();
195+
}
196+
197+
@Override
198+
public String call(String toolInput) {
199+
return "{\"result\": \"tool2_success\"}";
200+
}
201+
};
202+
203+
// Create multiple ToolCalls
204+
AssistantMessage.ToolCall toolCall1 = new AssistantMessage.ToolCall("1", "function", "tool1",
205+
"{\"param\": \"test\"}");
206+
AssistantMessage.ToolCall toolCall2 = new AssistantMessage.ToolCall("2", "function", "tool2",
207+
"{\"value\": 42}");
208+
209+
// Create ChatResponse with multiple tool calls
210+
AssistantMessage assistantMessage = AssistantMessage.builder()
211+
.content("")
212+
.properties(Map.of())
213+
.toolCalls(List.of(toolCall1, toolCall2))
214+
.build();
215+
Generation generation = new Generation(assistantMessage);
216+
ChatResponse chatResponse = new ChatResponse(List.of(generation));
217+
218+
Prompt prompt = new Prompt(List.of(new UserMessage("test multiple tools")));
219+
220+
DefaultToolCallingManager manager = DefaultToolCallingManager.builder()
221+
.observationRegistry(ObservationRegistry.NOOP)
222+
.toolCallbackResolver(toolName -> {
223+
if ("tool1".equals(toolName)) {
224+
return toolCallback1;
225+
}
226+
if ("tool2".equals(toolName)) {
227+
return toolCallback2;
228+
}
229+
return null;
230+
})
231+
.build();
232+
233+
assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse));
234+
}
235+
236+
@Test
237+
void shouldHandleToolCallWithComplexJsonArguments() {
238+
ToolCallback complexToolCallback = new ToolCallback() {
239+
@Override
240+
public ToolDefinition getToolDefinition() {
241+
return DefaultToolDefinition.builder()
242+
.name("complexTool")
243+
.description("A tool with complex JSON input")
244+
.inputSchema("{\"type\": \"object\", \"properties\": {\"nested\": {\"type\": \"object\"}}}")
245+
.build();
246+
}
247+
248+
@Override
249+
public ToolMetadata getToolMetadata() {
250+
return ToolMetadata.builder().build();
251+
}
252+
253+
@Override
254+
public String call(String toolInput) {
255+
assertThat(toolInput).contains("nested");
256+
assertThat(toolInput).contains("array");
257+
return "{\"result\": \"processed\"}";
258+
}
259+
};
260+
261+
String complexJson = "{\"nested\": {\"level1\": {\"level2\": \"value\"}}, \"array\": [1, 2, 3]}";
262+
AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "function", "complexTool", complexJson);
263+
264+
AssistantMessage assistantMessage = AssistantMessage.builder()
265+
.content("")
266+
.properties(Map.of())
267+
.toolCalls(List.of(toolCall))
268+
.build();
269+
Generation generation = new Generation(assistantMessage);
270+
ChatResponse chatResponse = new ChatResponse(List.of(generation));
271+
272+
Prompt prompt = new Prompt(List.of(new UserMessage("test complex json")));
273+
274+
DefaultToolCallingManager manager = DefaultToolCallingManager.builder()
275+
.observationRegistry(ObservationRegistry.NOOP)
276+
.toolCallbackResolver(toolName -> "complexTool".equals(toolName) ? complexToolCallback : null)
277+
.build();
278+
279+
assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse));
280+
}
281+
282+
@Test
283+
void shouldHandleToolCallWithMalformedJson() {
284+
ToolCallback toolCallback = new ToolCallback() {
285+
@Override
286+
public ToolDefinition getToolDefinition() {
287+
return DefaultToolDefinition.builder()
288+
.name("testTool")
289+
.description("Test tool")
290+
.inputSchema("{}")
291+
.build();
292+
}
293+
294+
@Override
295+
public ToolMetadata getToolMetadata() {
296+
return ToolMetadata.builder().build();
297+
}
298+
299+
@Override
300+
public String call(String toolInput) {
301+
// Should still receive some input even if malformed
302+
assertThat(toolInput).isNotNull();
303+
return "{\"result\": \"handled\"}";
304+
}
305+
};
306+
307+
// Malformed JSON as tool arguments
308+
AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "function", "testTool",
309+
"{invalid json}");
310+
311+
AssistantMessage assistantMessage = AssistantMessage.builder()
312+
.content("")
313+
.properties(Map.of())
314+
.toolCalls(List.of(toolCall))
315+
.build();
316+
Generation generation = new Generation(assistantMessage);
317+
ChatResponse chatResponse = new ChatResponse(List.of(generation));
318+
319+
Prompt prompt = new Prompt(List.of(new UserMessage("test malformed json")));
320+
321+
DefaultToolCallingManager manager = DefaultToolCallingManager.builder()
322+
.observationRegistry(ObservationRegistry.NOOP)
323+
.toolCallbackResolver(toolName -> "testTool".equals(toolName) ? toolCallback : null)
324+
.build();
325+
326+
assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse));
327+
}
328+
329+
@Test
330+
void shouldHandleToolCallReturningNull() {
331+
ToolCallback toolCallback = new ToolCallback() {
332+
@Override
333+
public ToolDefinition getToolDefinition() {
334+
return DefaultToolDefinition.builder()
335+
.name("nullReturningTool")
336+
.description("Tool that returns null")
337+
.inputSchema("{}")
338+
.build();
339+
}
340+
341+
@Override
342+
public ToolMetadata getToolMetadata() {
343+
return ToolMetadata.builder().build();
344+
}
345+
346+
@Override
347+
public String call(String toolInput) {
348+
return null; // Return null
349+
}
350+
};
351+
352+
AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "function", "nullReturningTool", "{}");
353+
354+
AssistantMessage assistantMessage = AssistantMessage.builder()
355+
.content("")
356+
.properties(Map.of())
357+
.toolCalls(List.of(toolCall))
358+
.build();
359+
Generation generation = new Generation(assistantMessage);
360+
ChatResponse chatResponse = new ChatResponse(List.of(generation));
361+
362+
Prompt prompt = new Prompt(List.of(new UserMessage("test null return")));
363+
364+
DefaultToolCallingManager manager = DefaultToolCallingManager.builder()
365+
.observationRegistry(ObservationRegistry.NOOP)
366+
.toolCallbackResolver(toolName -> "nullReturningTool".equals(toolName) ? toolCallback : null)
367+
.build();
368+
369+
assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse));
370+
}
371+
372+
@Test
373+
void shouldHandleMultipleGenerationsWithToolCalls() {
374+
ToolCallback toolCallback = new ToolCallback() {
375+
@Override
376+
public ToolDefinition getToolDefinition() {
377+
return DefaultToolDefinition.builder()
378+
.name("multiGenTool")
379+
.description("Tool for multiple generations")
380+
.inputSchema("{}")
381+
.build();
382+
}
383+
384+
@Override
385+
public ToolMetadata getToolMetadata() {
386+
return ToolMetadata.builder().build();
387+
}
388+
389+
@Override
390+
public String call(String toolInput) {
391+
return "{\"result\": \"success\"}";
392+
}
393+
};
394+
395+
// Create multiple generations with tool calls
396+
AssistantMessage.ToolCall toolCall1 = new AssistantMessage.ToolCall("1", "function", "multiGenTool", "{}");
397+
AssistantMessage.ToolCall toolCall2 = new AssistantMessage.ToolCall("2", "function", "multiGenTool", "{}");
398+
399+
AssistantMessage assistantMessage1 = AssistantMessage.builder()
400+
.content("")
401+
.properties(Map.of())
402+
.toolCalls(List.of(toolCall1))
403+
.build();
404+
405+
AssistantMessage assistantMessage2 = AssistantMessage.builder()
406+
.content("")
407+
.properties(Map.of())
408+
.toolCalls(List.of(toolCall2))
409+
.build();
410+
411+
Generation generation1 = new Generation(assistantMessage1);
412+
Generation generation2 = new Generation(assistantMessage2);
413+
414+
ChatResponse chatResponse = new ChatResponse(List.of(generation1, generation2));
415+
416+
Prompt prompt = new Prompt(List.of(new UserMessage("test multiple generations")));
417+
418+
DefaultToolCallingManager manager = DefaultToolCallingManager.builder()
419+
.observationRegistry(ObservationRegistry.NOOP)
420+
.toolCallbackResolver(toolName -> "multiGenTool".equals(toolName) ? toolCallback : null)
421+
.build();
422+
423+
assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse));
424+
}
425+
158426
}

0 commit comments

Comments
 (0)