|
2 | 2 |
|
3 | 3 | import com.fasterxml.jackson.core.JsonGenerator; |
4 | 4 | import com.fasterxml.jackson.core.type.TypeReference; |
5 | | -import com.fasterxml.jackson.databind.MapperFeature; |
6 | 5 | import com.fasterxml.jackson.databind.ObjectMapper; |
7 | 6 | import com.fasterxml.jackson.databind.node.ObjectNode; |
8 | 7 | import net.ravendb.client.documents.conventions.DocumentConventions; |
9 | | -import net.ravendb.client.documents.AI.AiStreamCallback; |
10 | | -import net.ravendb.client.documents.operations.AI.agents.AiAgentActionResponse; |
11 | | -import net.ravendb.client.documents.AI.AiConversationCreationOptions; |
12 | 8 | import net.ravendb.client.documents.operations.AI.agents.ConversationResult; |
| 9 | +import net.ravendb.client.documents.operations.AI.agents.RunConversationOperation; |
13 | 10 | import net.ravendb.client.http.IRaftCommand; |
14 | 11 | import net.ravendb.client.http.RavenCommand; |
15 | 12 | import net.ravendb.client.http.RavenCommandResponseType; |
16 | 13 | import net.ravendb.client.http.ServerNode; |
17 | 14 | import net.ravendb.client.json.ContentProviderHttpEntity; |
18 | | -import net.ravendb.client.util.RaftIdGenerator; |
19 | 15 | import net.ravendb.client.util.UrlUtils; |
20 | 16 | import org.apache.hc.client5.http.classic.methods.HttpPost; |
21 | 17 | import org.apache.hc.client5.http.classic.methods.HttpUriRequestBase; |
| 18 | +import org.apache.hc.core5.http.ClassicHttpResponse; |
22 | 19 | import org.apache.hc.core5.http.ContentType; |
23 | 20 | import java.io.*; |
24 | 21 | import java.nio.charset.StandardCharsets; |
25 | | -import java.util.List; |
| 22 | +import java.util.UUID; |
26 | 23 | import java.util.concurrent.CompletableFuture; |
27 | 24 | import java.util.concurrent.CompletionException; |
| 25 | +import java.util.concurrent.ExecutionException; |
28 | 26 | import java.util.stream.Collectors; |
| 27 | +import static net.ravendb.client.extensions.JsonExtensions.createDefaultJsonSerializer; |
| 28 | + |
29 | 29 |
|
30 | 30 | public class RunConversationCommand<TAnswer> |
31 | 31 | extends RavenCommand<ConversationResult<TAnswer>> |
32 | 32 | implements IRaftCommand { |
33 | 33 |
|
34 | | - private final String conversationId; |
35 | | - private final String agentId; |
36 | | - private final String prompt; |
37 | | - private final List<AiAgentActionResponse> actionResponses; |
38 | | - private final AiConversationCreationOptions options; |
39 | | - private final String changeVector; |
40 | | - private final String streamPropertyPath; |
41 | | - private final AiStreamCallback streamCallback; |
| 34 | + private final RunConversationOperation<TAnswer> parent; |
42 | 35 | private final DocumentConventions conventions; |
43 | 36 | private String raftId; |
44 | 37 |
|
45 | | - public RunConversationCommand( |
46 | | - String conversationId, |
47 | | - String agentId, |
48 | | - String prompt, |
49 | | - List<AiAgentActionResponse> actionResponses, |
50 | | - AiConversationCreationOptions options, |
51 | | - String changeVector, |
52 | | - DocumentConventions conventions, |
53 | | - String streamPropertyPath, |
54 | | - AiStreamCallback streamCallback){ |
| 38 | + public RunConversationCommand(RunConversationOperation<TAnswer> parent, DocumentConventions conventions) { |
55 | 39 | super((Class<ConversationResult<TAnswer>>) (Class<?>) ConversationResult.class); |
56 | | - this.conversationId = conversationId; |
57 | | - this.agentId = agentId; |
58 | | - this.prompt = prompt; |
59 | | - this.actionResponses = actionResponses; |
60 | | - this.options = options; |
61 | | - this.changeVector = changeVector; |
62 | | - this.streamPropertyPath = streamPropertyPath; |
63 | | - this.streamCallback = streamCallback; |
64 | 40 | this.conventions = conventions; |
| 41 | + this.parent = parent; |
65 | 42 |
|
66 | | - if (this.streamPropertyPath != null && this.streamCallback != null) { |
| 43 | + if (parent.getStreamPropertyPath() != null) |
67 | 44 | this.responseType = RavenCommandResponseType.RAW; |
68 | | - } |
69 | | - |
70 | | - if (conversationId != null && conversationId.endsWith("|")) { |
71 | | - this.raftId = RaftIdGenerator.newId(); |
72 | | - } |
73 | 45 | } |
74 | 46 |
|
75 | 47 | @Override |
76 | 48 | public boolean isReadRequest() { |
77 | 49 | return false; |
78 | 50 | } |
79 | 51 |
|
80 | | - @Override |
81 | | - public String getRaftUniqueRequestId() { |
82 | | - return raftId; |
83 | | - } |
84 | | - |
85 | 52 | @Override |
86 | 53 | public HttpUriRequestBase createRequest(ServerNode node) { |
87 | 54 | StringBuilder uriBuilder = new StringBuilder(); |
88 | 55 | uriBuilder.append(node.getUrl()) |
89 | 56 | .append("/databases/") |
90 | 57 | .append(node.getDatabase()) |
91 | 58 | .append("/ai/agent?") |
92 | | - .append("conversationId=").append(UrlUtils.escapeDataString(this.conversationId)) |
93 | | - .append("&agentId=").append(UrlUtils.escapeDataString(this.agentId)); |
| 59 | + .append("conversationId=").append(UrlUtils.escapeDataString(this.parent.getConversationId())) |
| 60 | + .append("&agentId=").append(UrlUtils.escapeDataString(this.parent.getAgentId())); |
| 61 | + |
| 62 | + if (this.parent.getConversationId().charAt(this.parent.getConversationId().length() - 1) == '|') { |
| 63 | + this.raftId = UUID.randomUUID().toString(); |
| 64 | + } |
94 | 65 |
|
95 | | - if (this.changeVector != null && !this.changeVector.isEmpty()) { |
96 | | - uriBuilder.append("&changeVector=").append(UrlUtils.escapeDataString(this.changeVector)); |
| 66 | + if (this.parent.getChangeVector() != null && !this.parent.getChangeVector().isEmpty()) { |
| 67 | + uriBuilder.append("&changeVector=").append(UrlUtils.escapeDataString(this.parent.getChangeVector())); |
97 | 68 | } |
98 | | - if (this.streamPropertyPath != null) { |
99 | | - uriBuilder.append("&streamPropertyPath=").append(UrlUtils.escapeDataString(this.streamPropertyPath)); |
100 | | - uriBuilder.append("&streaming=").append(UrlUtils.escapeDataString("true")); |
| 69 | + if (this.parent.getStreamPropertyPath() != null) { |
| 70 | + uriBuilder.append("&streamPropertyPath=").append(UrlUtils.escapeDataString(this.parent.getStreamPropertyPath())); |
| 71 | + uriBuilder.append("&streaming=true"); |
101 | 72 | } |
102 | 73 |
|
103 | 74 | HttpPost request = new HttpPost(uriBuilder.toString()); |
104 | 75 |
|
105 | 76 | request.setEntity(new ContentProviderHttpEntity(outputStream -> { |
106 | 77 | try (JsonGenerator generator = createSafeJsonGenerator(outputStream)) { |
107 | 78 | ObjectNode bodyObj = mapper.createObjectNode(); |
108 | | - bodyObj.set("ActionResponses", mapper.valueToTree(this.actionResponses)); |
109 | | - bodyObj.put("UserPrompt", this.prompt); |
110 | | - bodyObj.set("CreationOptions", mapper.valueToTree(this.options)); |
| 79 | + bodyObj.set("ActionResponses", mapper.valueToTree(this.parent.getActionResponses())); |
| 80 | + bodyObj.set("UserPrompt", mapper.valueToTree(this.parent.getPromptParts())); |
| 81 | + bodyObj.set("CreationOptions", mapper.valueToTree(this.parent.getOptions())); |
111 | 82 | generator.writeTree(bodyObj); |
112 | 83 | } |
113 | 84 | }, ContentType.APPLICATION_JSON,conventions)); |
114 | 85 |
|
115 | 86 | return request; |
116 | 87 | } |
117 | 88 |
|
| 89 | + @Override |
| 90 | + public String getRaftUniqueRequestId() { |
| 91 | + return raftId; |
| 92 | + } |
| 93 | + |
118 | 94 | @Override |
119 | 95 | public CompletableFuture<String> setResponseAsync(InputStream bodyStream, boolean fromCache) { |
120 | 96 | if (bodyStream == null ) { |
121 | 97 | this.throwInvalidResponse(); |
122 | 98 | } |
123 | 99 |
|
124 | | - if (this.streamPropertyPath != null && this.streamCallback != null) { |
| 100 | + if (this.parent.getStreamPropertyPath() != null && this.parent.getStreamCallback() != null) { |
125 | 101 | return processStreamingResponse(bodyStream); |
126 | 102 | } |
127 | 103 | return this.parseResponseDefaultAsync(bodyStream); |
128 | 104 | } |
129 | 105 |
|
| 106 | + @Override |
| 107 | + public void setResponseRaw(ClassicHttpResponse response, InputStream stream) throws IOException { |
| 108 | + try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"))) { |
| 109 | + String line; |
| 110 | + while ((line = reader.readLine()) != null) { |
| 111 | + line = line.trim(); |
| 112 | + |
| 113 | + if (line.startsWith("{")) { |
| 114 | + this.result = mapper.readValue(line, ConversationResult.class); |
| 115 | + break; |
| 116 | + } |
| 117 | + |
| 118 | + String unescaped = mapper.readValue(line, String.class); |
| 119 | + |
| 120 | + if (this.parent.getStreamCallback() != null) { |
| 121 | + this.parent.getStreamCallback().onChunk(unescaped).get(); |
| 122 | + } |
| 123 | + } |
| 124 | + } catch (IOException e) { |
| 125 | + throw new RuntimeException("Failed to read conversation stream", e); |
| 126 | + } catch (ExecutionException e) { |
| 127 | + throw new RuntimeException(e); |
| 128 | + } catch (InterruptedException e) { |
| 129 | + throw new RuntimeException(e); |
| 130 | + } |
| 131 | + } |
| 132 | + |
130 | 133 | @Override |
131 | 134 | public void setResponse(String response, boolean fromCache) throws IOException { |
132 | | - ObjectMapper mapper = new ObjectMapper(); |
133 | | - mapper.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true); |
| 135 | + ObjectMapper mapper = createDefaultJsonSerializer(); |
134 | 136 | this.result = mapper.readValue(response, ConversationResult.class); |
135 | 137 | } |
136 | 138 |
|
@@ -169,9 +171,9 @@ private CompletableFuture<String> processStreamingResponse(InputStream bodyStrea |
169 | 171 | } else { |
170 | 172 | chunk = new ObjectMapper().writeValueAsString(parsed); |
171 | 173 | } |
172 | | - streamCallback.onChunk(chunk).get(); |
| 174 | + this.parent.getStreamCallback().onChunk(chunk).get(); |
173 | 175 | } catch (Exception e) { |
174 | | - streamCallback.onChunk(line).get(); |
| 176 | + this.parent.getStreamCallback().onChunk(line).get(); |
175 | 177 | } |
176 | 178 | } |
177 | 179 |
|
|
0 commit comments