|
1 | 1 | package net.ravendb.client.test.client.documents.AI; |
2 | 2 |
|
3 | 3 | import net.ravendb.client.RemoteTestBase; |
| 4 | +import net.ravendb.client.documents.AI.AiAnswer; |
| 5 | +import net.ravendb.client.documents.AI.AiConversation; |
| 6 | +import net.ravendb.client.documents.AI.AiConversationCreationOptions; |
| 7 | +import net.ravendb.client.documents.AI.AiConversationResult; |
4 | 8 | import net.ravendb.client.documents.IDocumentStore; |
| 9 | +import net.ravendb.client.documents.operations.AI.AiConnectionString; |
| 10 | +import net.ravendb.client.documents.operations.AI.AiModelType; |
| 11 | +import net.ravendb.client.documents.operations.AI.OpenAiSettings; |
5 | 12 | import net.ravendb.client.documents.operations.AI.agents.*; |
6 | 13 | import net.ravendb.client.documents.operations.AI.agents.AiAgentConfiguration; |
| 14 | +import net.ravendb.client.documents.operations.IMaintenanceOperation; |
| 15 | +import net.ravendb.client.documents.operations.connectionStrings.GetConnectionStringsOperation; |
| 16 | +import net.ravendb.client.documents.operations.connectionStrings.GetConnectionStringsResult; |
7 | 17 | import net.ravendb.client.documents.operations.connectionStrings.PutConnectionStringOperation; |
| 18 | +import net.ravendb.client.documents.operations.connectionStrings.PutConnectionStringResult; |
8 | 19 | import net.ravendb.client.documents.operations.etl.RavenConnectionString; |
| 20 | +import net.ravendb.client.documents.session.IDocumentSession; |
9 | 21 | import net.ravendb.client.infrastructure.EnableOnServer; |
10 | 22 | import org.junit.jupiter.api.Assertions; |
| 23 | +import org.junit.jupiter.api.Disabled; |
11 | 24 | import org.junit.jupiter.api.Test; |
12 | 25 | import static org.assertj.core.api.Assertions.assertThat; |
13 | 26 | import static org.junit.jupiter.api.Assertions.assertEquals; |
14 | 27 | import static org.junit.jupiter.api.Assertions.assertNotNull; |
15 | 28 | import java.util.ArrayList; |
| 29 | +import java.util.Arrays; |
16 | 30 | import java.util.Collections; |
| 31 | +import java.util.List; |
17 | 32 |
|
18 | 33 | @EnableOnServer(thresholdVersion = "7.1") |
19 | 34 | public class AiAgentTests extends RemoteTestBase { |
20 | 35 |
|
| 36 | + @Disabled |
| 37 | + @Test |
| 38 | + public void AiAgentClientApiBasicTest(){ |
| 39 | + String apiKey = System.getenv("RAVENDB_JAVA_TESTS_OPENAI_API_KEY"); |
| 40 | + assertNotNull(apiKey, "OpenAI API key is not set in environment variable RAVENDB_JAVA_TESTS_OPENAI_API_KEY"); |
| 41 | + try (IDocumentStore store = getDocumentStore()) { |
| 42 | + try (IDocumentSession session = store.openSession()) { |
| 43 | + OpenAiSettings ai = new OpenAiSettings(apiKey, |
| 44 | + "https://api.openai.com/", |
| 45 | + "gpt-4o-mini", |
| 46 | + null, |
| 47 | + null, |
| 48 | + null, |
| 49 | + 0.0); |
| 50 | + |
| 51 | + AiConnectionString openAiCs = new AiConnectionString(); |
| 52 | + openAiCs.setModelType(AiModelType.Chat); |
| 53 | + openAiCs.setName("openai"); |
| 54 | + openAiCs.setOpenAiSettings(ai); |
| 55 | + |
| 56 | + IMaintenanceOperation<PutConnectionStringResult> putOpenAi = new PutConnectionStringOperation(openAiCs); |
| 57 | + store.maintenance().send(putOpenAi); |
| 58 | + |
| 59 | + GetConnectionStringsResult connectionStrings = store.maintenance().send(new GetConnectionStringsOperation()); |
| 60 | + assertThat(connectionStrings.getAiConnectionStrings()).hasSize(1); |
| 61 | + assertThat(connectionStrings.getAiConnectionStrings().get("openai")).isNotNull(); |
| 62 | + |
| 63 | + AiAgentConfiguration agent = new AiAgentConfiguration("shopping assistant", "openai", "You are an AI agent of an online shop, helping customers answer queries about that topic only. When talking about orders or products, include the ids as well."); |
| 64 | + agent.getParameters().add(new AiAgentParameter("company")); |
| 65 | + List<AiAgentToolQuery> queries = new ArrayList<>(); |
| 66 | + queries.add(new AiAgentToolQuery("ProductSearch", "semantic search the store product catalog", "from Products where vector.search(embedding.text(Name), $query)", "{\"query\": [\"term or phrase to search in the catalog\"]}")); |
| 67 | + queries.add(new AiAgentToolQuery("RecentOrder", "Get the recent orders of the current user", "from Orders where Company = $company order by OrderedAt desc limit 10", "{}")); |
| 68 | + agent.setQueries(queries); |
| 69 | + String identifier = store.ai().createAgent(agent,AnswerSchema.INSTANCE).getIdentifier(); |
| 70 | + AiConversation chat = store.ai().conversation(identifier,"chats/", new AiConversationCreationOptions().addParameter("company", "companies/90-A")); |
| 71 | + |
| 72 | + chat.setUserPrompt("what goes well with my cheese?"); |
| 73 | + AiAnswer<AnswerSchema> result = chat.<AnswerSchema>run().get(); |
| 74 | + assertEquals(AiConversationResult.Done, result.getStatus()); |
| 75 | + assertNotNull(result.getAnswer()); |
| 76 | + assertNotNull(chat.getId()); |
| 77 | + |
| 78 | + chat.setUserPrompt("what goes well with my cheese?"); |
| 79 | + result = chat.<AnswerSchema>run().get(); |
| 80 | + assertEquals(AiConversationResult.Done, result.getStatus()); |
| 81 | + assertNotNull(result.getAnswer()); |
| 82 | + |
| 83 | + chat.setUserPrompt("what cheese goes well with italian food?"); |
| 84 | + result = chat.<AnswerSchema>run().get(); |
| 85 | + assertEquals(AiConversationResult.Done, result.getStatus()); |
| 86 | + assertNotNull(result.getAnswer()); |
| 87 | + } |
| 88 | + } catch (Exception e){ |
| 89 | + throw new RuntimeException(e); |
| 90 | + } |
| 91 | + } |
| 92 | + |
21 | 93 | @Test |
22 | 94 | public void canCreateAiAgent() { |
23 | 95 | try (IDocumentStore store = getDocumentStore()) { |
@@ -177,4 +249,18 @@ public void cannotCreateAgentWithoutSchemaOrSampleObject() { |
177 | 249 | throw new RuntimeException(e); |
178 | 250 | } |
179 | 251 | } |
| 252 | + |
| 253 | + private static class AnswerSchema { |
| 254 | + public static final AnswerSchema INSTANCE = new AnswerSchema(); |
| 255 | + public String answer = "Answer to the user question"; |
| 256 | + public boolean relevant = true; |
| 257 | + public List<String> relevantOrdersId = new ArrayList<>( |
| 258 | + Arrays.asList("The order ids relevant to the query or response") |
| 259 | + ); |
| 260 | + public List<String> matchingProductsId = new ArrayList<>( |
| 261 | + Arrays.asList("All the product ids referenced either by the user or the system") |
| 262 | + ); |
| 263 | + private AnswerSchema() { |
| 264 | + } |
| 265 | + } |
180 | 266 | } |
0 commit comments