diff --git a/src/main/java/dev/ai4j/openai4j/chat/ChatCompletionRequest.java b/src/main/java/dev/ai4j/openai4j/chat/ChatCompletionRequest.java index 6ea32e6..38e5a67 100644 --- a/src/main/java/dev/ai4j/openai4j/chat/ChatCompletionRequest.java +++ b/src/main/java/dev/ai4j/openai4j/chat/ChatCompletionRequest.java @@ -64,6 +64,12 @@ public final class ChatCompletionRequest { @JsonProperty private final Boolean parallelToolCalls; @JsonProperty + private final Boolean store; + @JsonProperty + private final Map metadata; + @JsonProperty + private final String serviceTier; + @JsonProperty @Deprecated private final List functions; @JsonProperty @@ -90,6 +96,9 @@ private ChatCompletionRequest(Builder builder) { this.tools = builder.tools; this.toolChoice = builder.toolChoice; this.parallelToolCalls = builder.parallelToolCalls; + this.store = builder.store; + this.metadata = builder.metadata; + this.serviceTier = builder.serviceTier; this.functions = builder.functions; this.functionCall = builder.functionCall; } @@ -170,6 +179,18 @@ public Boolean parallelToolCalls() { return parallelToolCalls; } + public Boolean store() { + return store; + } + + public Map metadata() { + return metadata; + } + + public String serviceTier() { + return serviceTier; + } + @Deprecated public List functions() { return functions; @@ -207,6 +228,9 @@ private boolean equalTo(ChatCompletionRequest another) { && Objects.equals(tools, another.tools) && Objects.equals(toolChoice, another.toolChoice) && Objects.equals(parallelToolCalls, another.parallelToolCalls) + && Objects.equals(store, another.store) + && Objects.equals(metadata, another.metadata) + && Objects.equals(serviceTier, another.serviceTier) && Objects.equals(functions, another.functions) && Objects.equals(functionCall, another.functionCall); } @@ -233,6 +257,9 @@ public int hashCode() { h += (h << 5) + Objects.hashCode(tools); h += (h << 5) + Objects.hashCode(toolChoice); h += (h << 5) + Objects.hashCode(parallelToolCalls); + h += (h << 5) + Objects.hashCode(store); + h += (h << 5) + Objects.hashCode(metadata); + h += (h << 5) + Objects.hashCode(serviceTier); h += (h << 5) + Objects.hashCode(functions); h += (h << 5) + Objects.hashCode(functionCall); return h; @@ -260,6 +287,9 @@ public String toString() { + ", tools=" + tools + ", toolChoice=" + toolChoice + ", parallelToolCalls=" + parallelToolCalls + + ", store=" + store + + ", metadata=" + metadata + + ", serviceTier=" + serviceTier + ", functions=" + functions + ", functionCall=" + functionCall + "}"; @@ -293,6 +323,9 @@ public static final class Builder { private List tools; private Object toolChoice; private Boolean parallelToolCalls; + private Boolean store; + private Map metadata; + private String serviceTier; @Deprecated private List functions; @Deprecated @@ -321,6 +354,9 @@ public Builder from(ChatCompletionRequest instance) { tools(instance.tools); toolChoice(instance.toolChoice); parallelToolCalls(instance.parallelToolCalls); + store(instance.store); + metadata(instance.metadata); + serviceTier(instance.serviceTier); functions(instance.functions); functionCall(instance.functionCall); return this; @@ -503,6 +539,23 @@ public Builder parallelToolCalls(Boolean parallelToolCalls) { return this; } + public Builder store(Boolean store) { + this.store = store; + return this; + } + + public Builder metadata(Map metadata) { + if (metadata != null) { + this.metadata = unmodifiableMap(metadata); + } + return this; + } + + public Builder serviceTier(String serviceTier) { + this.serviceTier = serviceTier; + return this; + } + @Deprecated public Builder functions(Function... functions) { return functions(asList(functions)); diff --git a/src/main/java/dev/ai4j/openai4j/chat/ChatCompletionResponse.java b/src/main/java/dev/ai4j/openai4j/chat/ChatCompletionResponse.java index 8db78a8..8aae68b 100644 --- a/src/main/java/dev/ai4j/openai4j/chat/ChatCompletionResponse.java +++ b/src/main/java/dev/ai4j/openai4j/chat/ChatCompletionResponse.java @@ -31,6 +31,8 @@ public final class ChatCompletionResponse { private final Usage usage; @JsonProperty private final String systemFingerprint; + @JsonProperty + private final String serviceTier; private ChatCompletionResponse(Builder builder) { this.id = builder.id; @@ -39,6 +41,7 @@ private ChatCompletionResponse(Builder builder) { this.choices = builder.choices; this.usage = builder.usage; this.systemFingerprint = builder.systemFingerprint; + this.serviceTier = builder.serviceTier; } public String id() { @@ -65,6 +68,10 @@ public String systemFingerprint() { return systemFingerprint; } + public String serviceTier() { + return serviceTier; + } + /** * Convenience method to get the content of the message from the first choice. */ @@ -85,7 +92,8 @@ private boolean equalTo(ChatCompletionResponse another) { && Objects.equals(model, another.model) && Objects.equals(choices, another.choices) && Objects.equals(usage, another.usage) - && Objects.equals(systemFingerprint, another.systemFingerprint); + && Objects.equals(systemFingerprint, another.systemFingerprint) + && Objects.equals(serviceTier, another.serviceTier); } @Override @@ -97,6 +105,7 @@ public int hashCode() { h += (h << 5) + Objects.hashCode(choices); h += (h << 5) + Objects.hashCode(usage); h += (h << 5) + Objects.hashCode(systemFingerprint); + h += (h << 5) + Objects.hashCode(serviceTier); return h; } @@ -109,6 +118,7 @@ public String toString() { + ", choices=" + choices + ", usage=" + usage + ", systemFingerprint=" + systemFingerprint + + ", serviceTier=" + serviceTier + "}"; } @@ -127,6 +137,7 @@ public static final class Builder { private List choices; private Usage usage; private String systemFingerprint; + private String serviceTier; private Builder() { } @@ -163,6 +174,11 @@ public Builder systemFingerprint(String systemFingerprint) { return this; } + public Builder serviceTier(String serviceTier) { + this.serviceTier = serviceTier; + return this; + } + public ChatCompletionResponse build() { return new ChatCompletionResponse(this); } diff --git a/src/main/java/dev/ai4j/openai4j/shared/PromptTokensDetails.java b/src/main/java/dev/ai4j/openai4j/shared/PromptTokensDetails.java new file mode 100644 index 0000000..310f9b9 --- /dev/null +++ b/src/main/java/dev/ai4j/openai4j/shared/PromptTokensDetails.java @@ -0,0 +1,77 @@ +package dev.ai4j.openai4j.shared; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder; + +import java.util.Objects; + +@JsonDeserialize(builder = PromptTokensDetails.Builder.class) +@JsonInclude(JsonInclude.Include.NON_NULL) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public final class PromptTokensDetails { + + @JsonProperty + private final Integer cachedTokens; + + private PromptTokensDetails(Builder builder) { + this.cachedTokens = builder.cachedTokens; + } + + public Integer cachedTokens() { + return cachedTokens; + } + + @Override + public boolean equals(Object another) { + if (this == another) return true; + return another instanceof PromptTokensDetails + && equalTo((PromptTokensDetails) another); + } + + private boolean equalTo(PromptTokensDetails another) { + return Objects.equals(cachedTokens, another.cachedTokens); + } + + @Override + public int hashCode() { + int h = 5381; + h += (h << 5) + Objects.hashCode(cachedTokens); + return h; + } + + @Override + public String toString() { + return "PromptTokensDetails{" + + "cachedTokens=" + cachedTokens + + "}"; + } + + public static Builder builder() { + return new Builder(); + } + + @JsonPOJOBuilder(withPrefix = "") + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) + public static final class Builder { + + private Integer cachedTokens; + + private Builder() { + } + + public Builder cachedTokens(Integer cachedTokens) { + this.cachedTokens = cachedTokens; + return this; + } + + public PromptTokensDetails build() { + return new PromptTokensDetails(this); + } + } +} diff --git a/src/main/java/dev/ai4j/openai4j/shared/Usage.java b/src/main/java/dev/ai4j/openai4j/shared/Usage.java index efe0e25..168feba 100644 --- a/src/main/java/dev/ai4j/openai4j/shared/Usage.java +++ b/src/main/java/dev/ai4j/openai4j/shared/Usage.java @@ -20,6 +20,8 @@ public final class Usage { @JsonProperty private final Integer promptTokens; @JsonProperty + private final PromptTokensDetails promptTokensDetails; + @JsonProperty private final Integer completionTokens; @JsonProperty private final CompletionTokensDetails completionTokensDetails; @@ -27,6 +29,7 @@ public final class Usage { private Usage(Builder builder) { this.totalTokens = builder.totalTokens; this.promptTokens = builder.promptTokens; + this.promptTokensDetails = builder.promptTokensDetails; this.completionTokens = builder.completionTokens; this.completionTokensDetails = builder.completionTokensDetails; } @@ -39,6 +42,10 @@ public Integer promptTokens() { return promptTokens; } + public PromptTokensDetails promptTokensDetails() { + return promptTokensDetails; + } + public Integer completionTokens() { return completionTokens; } @@ -57,6 +64,7 @@ public boolean equals(Object another) { private boolean equalTo(Usage another) { return Objects.equals(totalTokens, another.totalTokens) && Objects.equals(promptTokens, another.promptTokens) + && Objects.equals(promptTokensDetails, another.promptTokensDetails) && Objects.equals(completionTokens, another.completionTokens) && Objects.equals(completionTokensDetails, another.completionTokensDetails); } @@ -66,6 +74,7 @@ public int hashCode() { int h = 5381; h += (h << 5) + Objects.hashCode(totalTokens); h += (h << 5) + Objects.hashCode(promptTokens); + h += (h << 5) + Objects.hashCode(promptTokensDetails); h += (h << 5) + Objects.hashCode(completionTokens); h += (h << 5) + Objects.hashCode(completionTokensDetails); return h; @@ -76,6 +85,7 @@ public String toString() { return "Usage{" + "totalTokens=" + totalTokens + ", promptTokens=" + promptTokens + + ", promptTokensDetails=" + promptTokensDetails + ", completionTokens=" + completionTokens + ", completionTokensDetails=" + completionTokensDetails + "}"; @@ -92,6 +102,7 @@ public static final class Builder { private Integer totalTokens; private Integer promptTokens; + private PromptTokensDetails promptTokensDetails; private Integer completionTokens; private CompletionTokensDetails completionTokensDetails; @@ -108,6 +119,11 @@ public Builder promptTokens(Integer promptTokens) { return this; } + public Builder promptTokensDetails(PromptTokensDetails promptTokensDetails) { + this.promptTokensDetails = promptTokensDetails; + return this; + } + public Builder completionTokens(Integer completionTokens) { this.completionTokens = completionTokens; return this; diff --git a/src/main/resources/META-INF/native-image/dev.ai4j/openai4j/reflect-config.json b/src/main/resources/META-INF/native-image/dev.ai4j/openai4j/reflect-config.json index 0244edf..89afd1e 100644 --- a/src/main/resources/META-INF/native-image/dev.ai4j/openai4j/reflect-config.json +++ b/src/main/resources/META-INF/native-image/dev.ai4j/openai4j/reflect-config.json @@ -458,6 +458,15 @@ "allDeclaredFields": true, "allPublicFields": true }, + { + "name": "dev.ai4j.openai4j.shared.PromptTokensDetails", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, { "name": "dev.ai4j.openai4j.shared.StreamOptions", "allDeclaredConstructors": true, diff --git a/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionTest.java b/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionTest.java index aeb3296..93ee120 100644 --- a/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionTest.java +++ b/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionTest.java @@ -7,6 +7,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; @@ -87,12 +88,24 @@ void testCustomizableApi(ChatCompletionModel model) { .user("Klaus") .responseFormat(TEXT) .seed(42) + .store(true) + .metadata(new HashMap(){{ + put("one", "1"); + put("two", "2"); + }}) + .serviceTier("default") .build(); // when ChatCompletionResponse response = client.chatCompletion(request).execute(); // then + assertThat(response.id()).isNotBlank(); + assertThat(response.created()).isPositive(); + assertThat(response.model()).isNotBlank(); + // TODO assertThat(response.systemFingerprint()).isNotBlank(); + assertThat(response.serviceTier()).isNotBlank(); + assertThat(response.choices()).hasSize(1); assertThat(response.choices().get(0).message().content()).containsIgnoringCase("hello world"); @@ -100,9 +113,12 @@ void testCustomizableApi(ChatCompletionModel model) { Usage usage = response.usage(); assertThat(usage.promptTokens()).isGreaterThan(0); + assertThat(usage.promptTokensDetails().cachedTokens()).isEqualTo(0); + assertThat(usage.completionTokens()).isGreaterThan(0); assertThat(usage.completionTokensDetails().reasoningTokens()).isEqualTo(0); - assertThat(usage.totalTokens()).isGreaterThan(usage.promptTokens() + usage.completionTokens()); + + assertThat(usage.totalTokens()).isEqualTo(usage.promptTokens() + usage.completionTokens()); } @ParameterizedTest