|
6 | 6 | import org.junit.jupiter.api.Test;
|
7 | 7 | import org.junit.jupiter.params.ParameterizedTest;
|
8 | 8 | import org.junit.jupiter.params.provider.EnumSource;
|
| 9 | +import org.slf4j.Logger; |
| 10 | +import org.slf4j.LoggerFactory; |
9 | 11 |
|
10 | 12 | import java.util.Map;
|
11 | 13 | import java.util.concurrent.CompletableFuture;
|
12 | 14 | import java.util.concurrent.atomic.AtomicBoolean;
|
| 15 | +import java.util.concurrent.atomic.AtomicReference; |
13 | 16 |
|
14 | 17 | import static dev.ai4j.openai4j.chat.ChatCompletionModel.GPT_4O;
|
15 | 18 | import static dev.ai4j.openai4j.chat.ChatCompletionTest.*;
|
|
22 | 25 | import static java.util.concurrent.Executors.newSingleThreadExecutor;
|
23 | 26 | import static java.util.concurrent.TimeUnit.SECONDS;
|
24 | 27 | import static org.assertj.core.api.Assertions.assertThat;
|
| 28 | +import static org.junit.jupiter.api.Assertions.fail; |
25 | 29 | import static org.junit.jupiter.params.provider.EnumSource.Mode.EXCLUDE;
|
26 | 30 | import static org.junit.jupiter.params.provider.EnumSource.Mode.INCLUDE;
|
27 | 31 |
|
28 | 32 | class ChatCompletionStreamingTest extends RateLimitAwareTest {
|
29 | 33 |
|
| 34 | + private static final Logger log = LoggerFactory.getLogger(ChatCompletionStreamingTest.class); |
| 35 | + |
30 | 36 | private final OpenAiClient client = OpenAiClient.builder()
|
31 | 37 | .baseUrl(System.getenv("OPENAI_BASE_URL"))
|
32 | 38 | .openAiApiKey(System.getenv("OPENAI_API_KEY"))
|
@@ -106,7 +112,6 @@ void testCustomizableApi(ChatCompletionModel model) throws Exception {
|
106 | 112 | @EnumSource(value = ChatCompletionModel.class, mode = EXCLUDE, names = {
|
107 | 113 | "GPT_3_5_TURBO_0125", // don't have access to it yet
|
108 | 114 | "GPT_4_32K", "GPT_4_32K_0314", "GPT_4_32K_0613", // I don't have access to these models
|
109 |
| - "GPT_4_0314", // Does not support tools/functions |
110 | 115 | "GPT_4_VISION_PREVIEW" // Does not support many things now, including logit_bias and response_format
|
111 | 116 | })
|
112 | 117 | void testTools(ChatCompletionModel model) throws Exception {
|
@@ -224,7 +229,6 @@ void testTools(ChatCompletionModel model) throws Exception {
|
224 | 229 | @EnumSource(value = ChatCompletionModel.class, mode = EXCLUDE, names = {
|
225 | 230 | "GPT_3_5_TURBO_0125", // don't have access to it yet
|
226 | 231 | "GPT_4_32K", "GPT_4_32K_0314", "GPT_4_32K_0613", // I don't have access to these models
|
227 |
| - "GPT_4_0314", // Does not support tools/functions |
228 | 232 | "GPT_4_VISION_PREVIEW" // Does not support many things now, including logit_bias and response_format
|
229 | 233 | })
|
230 | 234 | void testFunctions(ChatCompletionModel model) throws Exception {
|
@@ -322,7 +326,6 @@ void testFunctions(ChatCompletionModel model) throws Exception {
|
322 | 326 | "GPT_3_5_TURBO_0125", // don't have access to it yet
|
323 | 327 | "GPT_4_TURBO_PREVIEW", // keeps returning "felsius" as temp unit
|
324 | 328 | "GPT_4_32K", "GPT_4_32K_0314", "GPT_4_32K_0613", // I don't have access to these models
|
325 |
| - "GPT_4_0314", // Does not support tools/functions |
326 | 329 | "GPT_4_VISION_PREVIEW" // Does not support many things now, including logit_bias and response_format
|
327 | 330 | })
|
328 | 331 | void testToolChoice(ChatCompletionModel model) throws Exception {
|
@@ -440,7 +443,6 @@ void testToolChoice(ChatCompletionModel model) throws Exception {
|
440 | 443 | @EnumSource(value = ChatCompletionModel.class, mode = EXCLUDE, names = {
|
441 | 444 | "GPT_3_5_TURBO_0125", // don't have access to it yet
|
442 | 445 | "GPT_4_32K", "GPT_4_32K_0314", "GPT_4_32K_0613", // I don't have access to these models
|
443 |
| - "GPT_4_0314", // Does not support tools/functions |
444 | 446 | "GPT_4_VISION_PREVIEW"
|
445 | 447 | })
|
446 | 448 | void testFunctionChoice(ChatCompletionModel model) throws Exception {
|
@@ -769,45 +771,27 @@ void testCancelStreamingAfterStreamingStarted() throws Exception {
|
769 | 771 | .logStreamingResponses()
|
770 | 772 | .build();
|
771 | 773 |
|
772 |
| - AtomicBoolean streamingStarted = new AtomicBoolean(false); |
773 |
| - AtomicBoolean streamingCancelled = new AtomicBoolean(false); |
774 |
| - AtomicBoolean cancellationSucceeded = new AtomicBoolean(true); |
| 774 | + final AtomicBoolean streamingCancelled = new AtomicBoolean(false); |
| 775 | + final AtomicReference<ResponseHandle> atomicReference = new AtomicReference<>(); |
| 776 | + final CompletableFuture<Void> completableFuture = new CompletableFuture<>(); |
775 | 777 |
|
776 | 778 | ResponseHandle responseHandle = client.chatCompletion("Write a poem about AI in 10 words")
|
777 | 779 | .onPartialResponse(partialResponse -> {
|
778 |
| - streamingStarted.set(true); |
779 |
| - System.out.println("[[streaming started]]"); |
780 |
| - if (streamingCancelled.get()) { |
781 |
| - cancellationSucceeded.set(false); |
782 |
| - System.out.println("[[cancellation failed]]"); |
| 780 | + if (! streamingCancelled.getAndSet(true)) { |
| 781 | + log.info("Executor thread {}", Thread.currentThread()); |
| 782 | + atomicReference.get().cancel(); |
| 783 | + completableFuture.complete(null); |
783 | 784 | }
|
784 | 785 | })
|
785 |
| - .onComplete(() -> { |
786 |
| - cancellationSucceeded.set(false); |
787 |
| - System.out.println("[[cancellation failed]]"); |
788 |
| - }) |
789 |
| - .onError(e -> { |
790 |
| - cancellationSucceeded.set(false); |
791 |
| - System.out.println("[[cancellation failed]]"); |
792 |
| - }) |
| 786 | + .onComplete(() -> fail("Response completed")) |
| 787 | + .onError(e -> fail("Response errored")) |
793 | 788 | .execute();
|
794 | 789 |
|
795 |
| - while (!streamingStarted.get()) { |
796 |
| - Thread.sleep(10); |
797 |
| - } |
| 790 | + log.info("Test thread {}", Thread.currentThread()); |
| 791 | + atomicReference.set(responseHandle); |
| 792 | + completableFuture.get(); |
798 | 793 |
|
799 |
| - newSingleThreadExecutor().execute(() -> { |
800 |
| - responseHandle.cancel(); |
801 |
| - streamingCancelled.set(true); |
802 |
| - System.out.println("[[streaming cancelled]]"); |
803 |
| - }); |
804 |
| - |
805 |
| - while (!streamingCancelled.get()) { |
806 |
| - Thread.sleep(10); |
807 |
| - } |
808 |
| - Thread.sleep(2000); |
809 |
| - |
810 |
| - assertThat(cancellationSucceeded).isTrue(); |
| 794 | + assertThat(streamingCancelled).isTrue(); |
811 | 795 | }
|
812 | 796 |
|
813 | 797 | @Test
|
|
0 commit comments