Skip to content

Commit 683b9f9

Browse files
author
Tomasz Forys
committed
GH-4985: MessageChatMemoryAdvisor with conversationId supplier
Implementation of GH-4985 (#4985) * conversationId supplier support on MessageChatMemoryAdvisor Signed-off-by: Tomasz Forys <[email protected]>
1 parent 3b14b4a commit 683b9f9

File tree

2 files changed

+86
-10
lines changed

2 files changed

+86
-10
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.List;
21+
import java.util.concurrent.atomic.AtomicReference;
2122

2223
import org.junit.jupiter.api.Disabled;
2324
import org.junit.jupiter.api.Test;
@@ -133,6 +134,70 @@ void shouldHandleMultipleUserMessagesInPrompt() {
133134
assertThat(followUpAnswer).containsIgnoringCase("David");
134135
}
135136

137+
/**
138+
* Tests that the advisor correctly uses a conversation ID supplier when provided.
139+
*/
140+
@Test
141+
protected void testUseSupplierConversationId() {
142+
// Arrange
143+
ChatMemory chatMemory = MessageWindowChatMemory.builder()
144+
.chatMemoryRepository(new InMemoryChatMemoryRepository())
145+
.build();
146+
147+
// ConversationId circular iterator
148+
String firstConversationId = "conversationId-1";
149+
String secondConversationId = "conversationId-2";
150+
AtomicReference<String> conversationIdHolder = new AtomicReference<>(firstConversationId);
151+
152+
// Create advisor with conversation id supplier returning conversationId interchangeable
153+
var advisor = MessageChatMemoryAdvisor.builder(chatMemory).conversationIdSupplier(conversationIdHolder::get).build();
154+
155+
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
156+
157+
String firstQuestion = "What is the capital of Germany?";
158+
String firstAnswer = chatClient.prompt()
159+
.user(firstQuestion)
160+
.call()
161+
.content();
162+
logger.info("First question: {}", firstQuestion);
163+
logger.info("First answer: {}", firstAnswer);
164+
// Assert response is relevant
165+
assertThat(firstAnswer).containsIgnoringCase("Berlin");
166+
167+
conversationIdHolder.set(secondConversationId);
168+
String secondQuestion = "What is the capital of Poland?";
169+
String secondAnswer = chatClient.prompt()
170+
.user(secondQuestion)
171+
.call()
172+
.content();
173+
logger.info("Second question: {}", secondQuestion);
174+
logger.info("Second answer: {}", secondAnswer);
175+
// Assert response is relevant
176+
assertThat(secondAnswer).containsIgnoringCase("Warsaw");
177+
178+
conversationIdHolder.set(firstConversationId);
179+
String thirdQuestion = "What is the capital of Spain?";
180+
String thirdAnswer = chatClient.prompt()
181+
.user(thirdQuestion)
182+
.call()
183+
.content();
184+
logger.info("Third question: {}", thirdQuestion);
185+
logger.info("Third answer: {}", thirdAnswer);
186+
// Assert response is relevant
187+
assertThat(thirdAnswer).containsIgnoringCase("Madrid");
188+
189+
// Verify first conversation memory contains the firstQuestion, firstAnswer, thirdQuestion and thirdAnswer
190+
List<Message> firstMemoryMessages = chatMemory.get(firstConversationId);
191+
assertThat(firstMemoryMessages).hasSize(4);
192+
assertThat(firstMemoryMessages.get(0).getText()).isEqualTo(firstQuestion);
193+
assertThat(firstMemoryMessages.get(2).getText()).isEqualTo(thirdQuestion);
194+
195+
// Verify second conversation memory contains the secondQuestion and secondAnswer
196+
List<Message> secondMemoryMessages = chatMemory.get(secondConversationId);
197+
assertThat(secondMemoryMessages).hasSize(2);
198+
assertThat(secondMemoryMessages.get(0).getText()).isEqualTo(secondQuestion);
199+
}
200+
136201
@Test
137202
void shouldHandleNonExistentConversation() {
138203
testHandleNonExistentConversation();
@@ -157,7 +222,6 @@ void shouldStoreCompleteContentInStreamingMode() {
157222
String userInput = "Tell me a short joke about programming";
158223

159224
// Collect the streaming responses
160-
List<String> streamedResponses = new ArrayList<>();
161225
chatClient.prompt()
162226
.user(userInput)
163227
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.List;
21+
import java.util.function.Supplier;
2122

2223
import reactor.core.publisher.Flux;
2324
import reactor.core.publisher.Mono;
@@ -42,22 +43,23 @@
4243
* @author Christian Tzolov
4344
* @author Mark Pollack
4445
* @author Thomas Vitale
46+
* @author Tomasz Forys
4547
* @since 1.0.0
4648
*/
4749
public final class MessageChatMemoryAdvisor implements BaseChatMemoryAdvisor {
4850

4951
private final ChatMemory chatMemory;
5052

51-
private final String defaultConversationId;
53+
private final Supplier<String> defaultConversationId;
5254

5355
private final int order;
5456

5557
private final Scheduler scheduler;
5658

57-
private MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int order,
58-
Scheduler scheduler) {
59+
private MessageChatMemoryAdvisor(ChatMemory chatMemory, Supplier<String> defaultConversationId, int order,
60+
Scheduler scheduler) {
5961
Assert.notNull(chatMemory, "chatMemory cannot be null");
60-
Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty");
62+
Assert.hasText(defaultConversationId.get(), "defaultConversationId cannot be null or empty");
6163
Assert.notNull(scheduler, "scheduler cannot be null");
6264
this.chatMemory = chatMemory;
6365
this.defaultConversationId = defaultConversationId;
@@ -77,7 +79,7 @@ public Scheduler getScheduler() {
7779

7880
@Override
7981
public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
80-
String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId);
82+
String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId.get());
8183

8284
// 1. Retrieve the chat memory for the current conversation.
8385
List<Message> memoryMessages = this.chatMemory.get(conversationId);
@@ -108,7 +110,7 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh
108110
.map(g -> (Message) g.getOutput())
109111
.toList();
110112
}
111-
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
113+
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId.get()),
112114
assistantMessages);
113115
return chatClientResponse;
114116
}
@@ -134,7 +136,7 @@ public static Builder builder(ChatMemory chatMemory) {
134136

135137
public static final class Builder {
136138

137-
private String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID;
139+
private Supplier<String> conversationIdSupplier = () -> ChatMemory.DEFAULT_CONVERSATION_ID;
138140

139141
private int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER;
140142

@@ -152,7 +154,17 @@ private Builder(ChatMemory chatMemory) {
152154
* @return the builder
153155
*/
154156
public Builder conversationId(String conversationId) {
155-
this.conversationId = conversationId;
157+
this.conversationIdSupplier = () -> conversationId;
158+
return this;
159+
}
160+
161+
/**
162+
* Set the conversation id supplier.
163+
* @param conversationIdSupplier the conversation id supplier
164+
* @return the builder
165+
*/
166+
public Builder conversationIdSupplier(Supplier<String> conversationIdSupplier) {
167+
this.conversationIdSupplier = conversationIdSupplier;
156168
return this;
157169
}
158170

@@ -176,7 +188,7 @@ public Builder scheduler(Scheduler scheduler) {
176188
* @return the advisor
177189
*/
178190
public MessageChatMemoryAdvisor build() {
179-
return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationId, this.order, this.scheduler);
191+
return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationIdSupplier, this.order, this.scheduler);
180192
}
181193

182194
}

0 commit comments

Comments
 (0)