-
Notifications
You must be signed in to change notification settings - Fork 2k
Added support for the "think" option for Ollama #3386
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
524f12c
af9958b
91fb151
f77e08a
b188adc
552a346
557a98e
afc3ba2
3411fe2
abfc9e1
f61a99f
2f9417a
dfb2522
b8a1115
64bee7b
faa8b7c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| /* | ||
| * Copyright 2023-2025 the original author or authors. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * https://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.springframework.ai.ollama; | ||
|
|
||
| import io.micrometer.observation.tck.TestObservationRegistry; | ||
| import org.junit.jupiter.api.BeforeEach; | ||
| import org.junit.jupiter.api.Test; | ||
|
|
||
| import org.springframework.ai.chat.metadata.ChatGenerationMetadata; | ||
| import org.springframework.ai.chat.model.ChatResponse; | ||
| import org.springframework.ai.chat.prompt.Prompt; | ||
| import org.springframework.ai.ollama.api.OllamaApi; | ||
| import org.springframework.ai.ollama.api.OllamaModel; | ||
| import org.springframework.ai.ollama.api.OllamaOptions; | ||
| import org.springframework.beans.factory.annotation.Autowired; | ||
| import org.springframework.boot.SpringBootConfiguration; | ||
| import org.springframework.boot.test.context.SpringBootTest; | ||
| import org.springframework.context.annotation.Bean; | ||
|
|
||
| import static org.assertj.core.api.Assertions.assertThat; | ||
|
|
||
| /** | ||
| * Unit Tests for {@link OllamaChatModel} asserting AI metadata. | ||
| * | ||
| * @author Sun Yuhan | ||
| */ | ||
| @SpringBootTest(classes = OllamaChatModelMetadataTests.Config.class) | ||
| class OllamaChatModelMetadataTests extends BaseOllamaIT { | ||
|
|
||
| private static final String MODEL = OllamaModel.QWEN_3_06B.getName(); | ||
|
|
||
| @Autowired | ||
| TestObservationRegistry observationRegistry; | ||
|
|
||
| @Autowired | ||
| OllamaChatModel chatModel; | ||
|
|
||
| @BeforeEach | ||
| void beforeEach() { | ||
| this.observationRegistry.clear(); | ||
| } | ||
|
|
||
| @Test | ||
| void ollamaThinkingMetadataCaptured() { | ||
| var options = OllamaOptions.builder().model(MODEL).think(true).build(); | ||
|
|
||
| Prompt prompt = new Prompt("Why is the sky blue?", options); | ||
|
|
||
| ChatResponse chatResponse = this.chatModel.call(prompt); | ||
| assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); | ||
|
|
||
| chatResponse.getResults().forEach(generation -> { | ||
| ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); | ||
| assertThat(chatGenerationMetadata).isNotNull(); | ||
| assertThat(chatGenerationMetadata.containsKey("thinking")); | ||
| }); | ||
| } | ||
|
|
||
| @Test | ||
| void ollamaThinkingMetadataNotCapturedWhenNotSetThinkFlag() { | ||
| var options = OllamaOptions.builder().model(MODEL).build(); | ||
|
|
||
| Prompt prompt = new Prompt("Why is the sky blue?", options); | ||
|
|
||
| ChatResponse chatResponse = this.chatModel.call(prompt); | ||
| assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); | ||
|
|
||
| chatResponse.getResults().forEach(generation -> { | ||
| ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); | ||
| assertThat(chatGenerationMetadata).isNotNull(); | ||
| var thinking = chatGenerationMetadata.get("thinking"); | ||
| assertThat(thinking).isNull(); | ||
| }); | ||
| } | ||
|
|
||
| @Test | ||
| void ollamaThinkingMetadataNotCapturedWhenSetThinkFlagToFalse() { | ||
| var options = OllamaOptions.builder().model(MODEL).think(false).build(); | ||
|
|
||
| Prompt prompt = new Prompt("Why is the sky blue?", options); | ||
|
|
||
| ChatResponse chatResponse = this.chatModel.call(prompt); | ||
| assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); | ||
|
|
||
| chatResponse.getResults().forEach(generation -> { | ||
| ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); | ||
| assertThat(chatGenerationMetadata).isNotNull(); | ||
| var thinking = chatGenerationMetadata.get("thinking"); | ||
| assertThat(thinking).isNull(); | ||
| }); | ||
| } | ||
|
|
||
| @SpringBootConfiguration | ||
| static class Config { | ||
|
|
||
| @Bean | ||
| public TestObservationRegistry observationRegistry() { | ||
| return TestObservationRegistry.create(); | ||
| } | ||
|
|
||
| @Bean | ||
| public OllamaApi ollamaApi() { | ||
| return initializeOllama(MODEL); | ||
| } | ||
|
|
||
| @Bean | ||
| public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) { | ||
| return OllamaChatModel.builder().ollamaApi(ollamaApi).observationRegistry(observationRegistry).build(); | ||
| } | ||
|
|
||
| } | ||
|
|
||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,10 +33,12 @@ | |
| import org.springframework.ai.ollama.api.OllamaApi.Message.Role; | ||
|
|
||
| import static org.assertj.core.api.Assertions.assertThat; | ||
| import static org.junit.jupiter.api.Assertions.assertNull; | ||
|
|
||
| /** | ||
| * @author Christian Tzolov | ||
| * @author Thomas Vitale | ||
| * @author Sun Yuhan | ||
| */ | ||
| public class OllamaApiIT extends BaseOllamaIT { | ||
|
|
||
|
|
@@ -146,4 +148,88 @@ public void think() { | |
| assertThat(response.message().thinking()).isNotEmpty(); | ||
| } | ||
|
|
||
| @Test | ||
| public void chatWithThinking() { | ||
| var request = ChatRequest.builder(THINKING_MODEL) | ||
| .stream(true) | ||
| .think(true) | ||
| .messages(List.of(Message.builder(Role.USER) | ||
| .content("What is the capital of Bulgaria and what is the size? " + "What it the national anthem?") | ||
| .build())) | ||
| .options(OllamaOptions.builder().temperature(0.9).build().toMap()) | ||
| .build(); | ||
|
|
||
| Flux<ChatResponse> response = getOllamaApi().streamingChat(request); | ||
|
|
||
| List<ChatResponse> responses = response.collectList().block(); | ||
| System.out.println(responses); | ||
|
|
||
| assertThat(responses).isNotNull(); | ||
| assertThat(responses.stream() | ||
| .filter(r -> r.message() != null) | ||
| .map(r -> r.message().thinking()) | ||
| .collect(Collectors.joining(System.lineSeparator()))).contains("Sofia"); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this assertion is failing for me. |
||
|
|
||
| ChatResponse lastResponse = responses.get(responses.size() - 1); | ||
| assertThat(lastResponse.message().content()).isEmpty(); | ||
| assertNull(lastResponse.message().thinking()); | ||
| assertThat(lastResponse.done()).isTrue(); | ||
| } | ||
|
|
||
| @Test | ||
| public void streamChatWithThinking() { | ||
| var request = ChatRequest.builder(THINKING_MODEL) | ||
| .stream(true) | ||
| .think(true) | ||
| .messages(List.of(Message.builder(Role.USER).content("What are the planets in the solar system?").build())) | ||
| .options(OllamaOptions.builder().temperature(0.9).build().toMap()) | ||
| .build(); | ||
|
|
||
| Flux<ChatResponse> response = getOllamaApi().streamingChat(request); | ||
|
|
||
| List<ChatResponse> responses = response.collectList().block(); | ||
| System.out.println(responses); | ||
|
|
||
| assertThat(responses).isNotNull(); | ||
| assertThat(responses.stream() | ||
| .filter(r -> r.message() != null) | ||
| .map(r -> r.message().thinking()) | ||
| .collect(Collectors.joining(System.lineSeparator()))).contains("solar"); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this assertion is failing for me. |
||
|
|
||
| ChatResponse lastResponse = responses.get(responses.size() - 1); | ||
| assertThat(lastResponse.message().content()).isEmpty(); | ||
| assertNull(lastResponse.message().thinking()); | ||
| assertThat(lastResponse.done()).isTrue(); | ||
| } | ||
|
|
||
| @Test | ||
| public void streamChatWithoutThinking() { | ||
| var request = ChatRequest.builder(THINKING_MODEL) | ||
| .stream(true) | ||
| .think(false) | ||
| .messages(List.of(Message.builder(Role.USER).content("What are the planets in the solar system?").build())) | ||
| .options(OllamaOptions.builder().temperature(0.9).build().toMap()) | ||
| .build(); | ||
|
|
||
| Flux<ChatResponse> response = getOllamaApi().streamingChat(request); | ||
|
|
||
| List<ChatResponse> responses = response.collectList().block(); | ||
| System.out.println(responses); | ||
|
|
||
| assertThat(responses).isNotNull(); | ||
|
|
||
| assertThat(responses.stream() | ||
| .filter(r -> r.message() != null) | ||
| .map(r -> r.message().content()) | ||
| .collect(Collectors.joining(System.lineSeparator()))).contains("Earth"); | ||
|
|
||
| assertThat(responses.stream().filter(r -> r.message() != null).allMatch(r -> r.message().thinking() == null)) | ||
| .isTrue(); | ||
|
|
||
| ChatResponse lastResponse = responses.get(responses.size() - 1); | ||
| assertThat(lastResponse.message().content()).isEmpty(); | ||
| assertNull(lastResponse.message().thinking()); | ||
| assertThat(lastResponse.done()).isTrue(); | ||
| } | ||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -83,6 +83,15 @@ public interface ChatOptions extends ModelOptions { | |
| @Nullable | ||
| Double getTopP(); | ||
|
|
||
| /** | ||
| * Returns the think flag to use for the chat. | ||
| * @return the think flag to use for the chat | ||
| */ | ||
| @Nullable | ||
| default Boolean getThink() { | ||
| return false; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can't add 'think' to default chat option as a simple boolean is not sufficient to capture the full range of options across models, e.g. there is often a thinking token budget or other parameters. |
||
| } | ||
|
|
||
| /** | ||
| * Returns a copy of this {@link ChatOptions}. | ||
| * @return a copy of this {@link ChatOptions} | ||
|
|
@@ -158,6 +167,13 @@ interface Builder { | |
| */ | ||
| Builder topP(Double topP); | ||
|
|
||
| /** | ||
| * Sets the think flag to use for the chat. | ||
| * @param think Whether to enable thinking mode | ||
| * @return the builder. | ||
| */ | ||
| Builder think(Boolean think); | ||
|
|
||
| /** | ||
| * Build the {@link ChatOptions}. | ||
| * @return the Chat options. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this assertion isn't passing.