diff --git a/docs/changelog/137677.yaml b/docs/changelog/137677.yaml new file mode 100644 index 0000000000000..56b41374dc73c --- /dev/null +++ b/docs/changelog/137677.yaml @@ -0,0 +1,5 @@ +pr: 137677 +summary: "[Inference] Implementing the completion task type on EIS" +area: "Inference" +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 8a6d2626ad521..e1d3e3785748a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -88,6 +88,7 @@ public class ElasticInferenceService extends SenderService { public static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of( TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, + TaskType.COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING ); @@ -103,6 +104,7 @@ public class ElasticInferenceService extends SenderService { */ private static final EnumSet SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of( TaskType.SPARSE_EMBEDDING, + TaskType.COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING ); @@ -154,7 +156,7 @@ protected void doUnifiedCompletionInfer( TimeValue timeout, ActionListener listener ) { - if (model instanceof ElasticInferenceServiceCompletionModel == false) { + if (model instanceof ElasticInferenceServiceCompletionModel == false || model.getTaskType() != TaskType.CHAT_COMPLETION) { listener.onFailure(createInvalidModelException(model)); return; } @@ -359,7 +361,7 @@ private static ElasticInferenceServiceModel createModel( context, chunkingSettings ); - case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel( + case CHAT_COMPLETION, COMPLETION -> new ElasticInferenceServiceCompletionModel( inferenceEntityId, taskType, NAME, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java index e980d7f713495..180ecd84aed3e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java @@ -11,22 +11,28 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity; import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceRerankResponseEntity; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceDenseTextEmbeddingsRequest; import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequest; +import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.telemetry.TraceContext; +import java.util.Map; import java.util.Objects; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; @@ -35,6 +41,8 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor { + public static final String USER_ROLE = "user"; + static final ResponseHandler DENSE_TEXT_EMBEDDINGS_HANDLER = new ElasticInferenceServiceResponseHandler( "elastic dense text embedding", ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::fromResponse @@ -45,6 +53,11 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer (request, response) -> ElasticInferenceServiceRerankResponseEntity.fromResponse(response) ); + static final ResponseHandler COMPLETION_HANDLER = new ElasticInferenceServiceResponseHandler( + "elastic completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); + private final Sender sender; private final ServiceComponents serviceComponents; @@ -108,4 +121,25 @@ public ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel m var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Elastic dense text embeddings"); return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage); } + + @Override + public ExecutableAction create(ElasticInferenceServiceCompletionModel model, Map taskSettings) { + var threadPool = serviceComponents.threadPool(); + + var manager = new GenericRequestManager<>( + threadPool, + model, + COMPLETION_HANDLER, + (chatCompletionInput) -> new ElasticInferenceServiceUnifiedChatCompletionRequest( + new UnifiedChatInput(chatCompletionInput.getInputs(), USER_ROLE, chatCompletionInput.stream()), + model, + traceContext, + extractRequestMetadataFromThreadContext(threadPool.getThreadContext()) + ), + ChatCompletionInput.class + ); + + var errorMessage = constructFailedToSendRequestMessage(Strings.format("%s completion", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)); + return new SenderExecutableAction(sender, manager, errorMessage); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java index 4f8a9c9ec20a4..550c1a119fb66 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java @@ -8,10 +8,13 @@ package org.elasticsearch.xpack.inference.services.elastic.action; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; +import java.util.Map; + public interface ElasticInferenceServiceActionVisitor { ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model); @@ -19,4 +22,6 @@ public interface ElasticInferenceServiceActionVisitor { ExecutableAction create(ElasticInferenceServiceRerankModel model); ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model); + + ExecutableAction create(ElasticInferenceServiceCompletionModel model, Map taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java index 969bf06d47fe0..6757cb335d112 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java @@ -18,16 +18,18 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceExecutableActionModel; +import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionVisitor; import java.net.URI; import java.net.URISyntaxException; import java.util.Map; import java.util.Objects; -public class ElasticInferenceServiceCompletionModel extends ElasticInferenceServiceModel { +public class ElasticInferenceServiceCompletionModel extends ElasticInferenceServiceExecutableActionModel { public static ElasticInferenceServiceCompletionModel of( ElasticInferenceServiceCompletionModel model, @@ -49,7 +51,7 @@ public ElasticInferenceServiceCompletionModel( String service, Map serviceSettings, Map taskSettings, - Map secrets, + @Nullable Map secrets, ElasticInferenceServiceComponents elasticInferenceServiceComponents, ConfigurationParseContext context ) { @@ -70,7 +72,6 @@ public ElasticInferenceServiceCompletionModel( ) { super(model, serviceSettings); this.uri = createUri(); - } public ElasticInferenceServiceCompletionModel( @@ -88,9 +89,7 @@ public ElasticInferenceServiceCompletionModel( serviceSettings, elasticInferenceServiceComponents ); - this.uri = createUri(); - } @Override @@ -120,5 +119,8 @@ private URI createUri() throws ElasticsearchStatusException { } } - // TODO create/refactor the Configuration class to be extensible for different task types (i.e completion, sparse embeddings). + @Override + public ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map taskSettings) { + return visitor.create(this, taskSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequest.java index 2f41285a3345b..13684f36e77e8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequest.java @@ -82,6 +82,6 @@ public String getInferenceEntityId() { @Override public boolean isStreaming() { - return true; + return unifiedChatInput.stream(); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 4b17cab04471a..de9910ebf2771 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -491,7 +491,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws thrownException.getMessage(), is( "Inference entity [model_id] does not support task type [chat_completion] " - + "for inference, the task type must be one of [text_embedding, sparse_embedding, rerank]. " + + "for inference, the task type must be one of [text_embedding, sparse_embedding, rerank, completion]. " + "The task type for the inference entity is chat_completion, " + "please use the _inference/chat_completion/model_id/_stream URL." ) @@ -1132,7 +1132,7 @@ private InferenceEventsAssertion testUnifiedStream(int responseCode, String resp webServer.enqueue(new MockResponse().setResponseCode(responseCode).setBody(responseJson)); var model = new ElasticInferenceServiceCompletionModel( "id", - TaskType.COMPLETION, + TaskType.CHAT_COMPLETION, "elastic", new ElasticInferenceServiceCompletionServiceSettings("model_id"), EmptyTaskSettings.INSTANCE, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index d0d3a67b2d9d5..49b98039afd4e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -229,8 +229,7 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID, - // EIS does not yet support completions so this model will be ignored - EnumSet.of(TaskType.COMPLETION) + EnumSet.noneOf(TaskType.class) ) ) ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java index 58750e7d8c456..89ea08edc46fc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java @@ -47,4 +47,63 @@ public void testOverridingModelId() { assertThat(overriddenModel.getServiceSettings().modelId(), is("new_model_id")); assertThat(overriddenModel.getTaskType(), is(TaskType.COMPLETION)); } + + public void testUriCreation() { + var url = "http://eis-gateway.com"; + var model = createModel(url, "my-model-id"); + + var uri = model.uri(); + assertThat(uri.toString(), is(url + "/api/v1/chat")); + } + + public void testGetServiceSettings() { + var modelId = "test-model"; + var model = createModel("http://eis-gateway.com", modelId); + + var serviceSettings = model.getServiceSettings(); + assertThat(serviceSettings.modelId(), is(modelId)); + } + + public void testGetTaskType() { + var model = createModel("http://eis-gateway.com", "my-model-id"); + assertThat(model.getTaskType(), is(TaskType.COMPLETION)); + } + + public void testGetInferenceEntityId() { + var inferenceEntityId = "test-id"; + var model = new ElasticInferenceServiceCompletionModel( + inferenceEntityId, + TaskType.COMPLETION, + "elastic", + new ElasticInferenceServiceCompletionServiceSettings("my-model-id"), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + ElasticInferenceServiceComponents.of("http://eis-gateway.com") + ); + + assertThat(model.getInferenceEntityId(), is(inferenceEntityId)); + } + + public void testModelWithOverriddenServiceSettings() { + var originalModel = createModel("http://eis-gateway.com", "original-model"); + var newServiceSettings = new ElasticInferenceServiceCompletionServiceSettings("new-model"); + + var overriddenModel = new ElasticInferenceServiceCompletionModel(originalModel, newServiceSettings); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("new-model")); + assertThat(overriddenModel.getTaskType(), is(TaskType.COMPLETION)); + assertThat(overriddenModel.uri().toString(), is(originalModel.uri().toString())); + } + + public static ElasticInferenceServiceCompletionModel createModel(String url, String modelId) { + return new ElasticInferenceServiceCompletionModel( + "id", + TaskType.COMPLETION, + "elastic", + new ElasticInferenceServiceCompletionServiceSettings(modelId), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + ElasticInferenceServiceComponents.of(url) + ); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java index b067350e26aa5..aaf57345a8512 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java @@ -14,11 +14,13 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModelTests; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.openai.request.OpenAiUnifiedChatCompletionRequestEntity; import java.io.IOException; import java.util.ArrayList; +import java.util.List; import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel; @@ -67,4 +69,152 @@ public void testModelUserFieldsSerialization() throws IOException { assertJsonEquals(jsonString, expectedJson); } + public void testSerialization_NonStreaming_ForCompletion() throws IOException { + // Test non-streaming case (used for COMPLETION task type) + var unifiedChatInput = new UnifiedChatInput(List.of("What is 2+2?"), ROLE, false); + var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id"); + var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "What is 2+2?", + "role": "user" + } + ], + "model": "my-model-id", + "n": 1, + "stream": false + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_MultipleInputs_NonStreaming() throws IOException { + // Test multiple inputs converted to messages (used for COMPLETION task type) + var unifiedChatInput = new UnifiedChatInput(List.of("What is 2+2?", "What is the capital of France?"), ROLE, false); + var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id"); + var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "What is 2+2?", + "role": "user" + }, + { + "content": "What is the capital of France?", + "role": "user" + } + ], + "model": "my-model-id", + "n": 1, + "stream": false + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_EmptyInput_NonStreaming() throws IOException { + var unifiedChatInput = new UnifiedChatInput(List.of(""), ROLE, false); + var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id"); + var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "", + "role": "user" + } + ], + "model": "my-model-id", + "n": 1, + "stream": false + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_AlwaysSetsNToOne_NonStreaming() throws IOException { + // Verify n is always 1 regardless of number of inputs + var unifiedChatInput = new UnifiedChatInput(List.of("input1", "input2", "input3"), ROLE, false); + var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id"); + var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "input1", + "role": "user" + }, + { + "content": "input2", + "role": "user" + }, + { + "content": "input3", + "role": "user" + } + ], + "model": "my-model-id", + "n": 1, + "stream": false + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_AllMessagesHaveUserRole_NonStreaming() throws IOException { + // Verify all messages have "user" role when converting from simple inputs + var unifiedChatInput = new UnifiedChatInput(List.of("first", "second", "third"), ROLE, false); + var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "test-model"); + var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "first", + "role": "user" + }, + { + "content": "second", + "role": "user" + }, + { + "content": "third", + "role": "user" + } + ], + "model": "test-model", + "n": 1, + "stream": false + } + """; + assertJsonEquals(jsonString, expectedJson); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestTests.java new file mode 100644 index 0000000000000..48dc306287b2b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestTests.java @@ -0,0 +1,243 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModelTests; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; + +public class ElasticInferenceServiceUnifiedChatCompletionRequestTests extends ESTestCase { + + public void testCreateHttpRequest_SingleInput() throws IOException { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + var input = "What is 2+2?"; + + var request = createRequest(url, modelId, List.of(input), false); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is(url + "/api/v1/chat")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(4)); + assertThat(requestMap.get("model"), is(modelId)); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + @SuppressWarnings("unchecked") + var messages = (List>) requestMap.get("messages"); + assertThat(messages.size(), is(1)); + assertThat(messages.get(0).get("content"), is(input)); + assertThat(messages.get(0).get("role"), is("user")); + } + + public void testCreateHttpRequest_MultipleInputs() throws IOException { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + var inputs = List.of("What is 2+2?", "What is the capital of France?"); + + var request = createRequest(url, modelId, inputs, false); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + @SuppressWarnings("unchecked") + var messages = (List>) requestMap.get("messages"); + assertThat(messages.size(), is(2)); + assertThat(messages.get(0).get("content"), is(inputs.get(0))); + assertThat(messages.get(0).get("role"), is("user")); + assertThat(messages.get(1).get("content"), is(inputs.get(1))); + assertThat(messages.get(1).get("role"), is("user")); + } + + public void testCreateHttpRequest_NonStreaming() throws IOException { + // Test non-streaming case (used for COMPLETION task type) + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + var input = "What is 2+2?"; + + var request = createRequest(url, modelId, List.of(input), false); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.get("stream"), is(false)); + assertFalse(request.isStreaming()); + } + + public void testCreateHttpRequest_Streaming() throws IOException { + // Test streaming case (used for CHAT_COMPLETION task type) + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + var input = "What is 2+2?"; + + var request = createRequest(url, modelId, List.of(input), true); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.get("stream"), is(true)); + assertTrue(request.isStreaming()); + } + + public void testGetURI() { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + + var request = createRequest(url, modelId, List.of("input"), false); + + assertThat(request.getURI().toString(), is(url + "/api/v1/chat")); + } + + public void testGetInferenceEntityId() { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + var inferenceEntityId = "test-endpoint-id"; + + var model = new ElasticInferenceServiceCompletionModel( + inferenceEntityId, + TaskType.COMPLETION, + "elastic", + new ElasticInferenceServiceCompletionServiceSettings(modelId), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + ElasticInferenceServiceComponents.of(url) + ); + + var unifiedChatInput = new UnifiedChatInput(List.of("input"), "user", false); + var request = new ElasticInferenceServiceUnifiedChatCompletionRequest( + unifiedChatInput, + model, + new TraceContext("trace-parent", "trace-state"), + randomElasticInferenceServiceRequestMetadata() + ); + + assertThat(request.getInferenceEntityId(), is(inferenceEntityId)); + } + + public void testTruncate_ReturnsSameInstance() throws IOException { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + var input = "What is 2+2?"; + + var request = createRequest(url, modelId, List.of(input), false); + var truncatedRequest = request.truncate(); + + // Should return the same instance (no truncation) + assertThat(truncatedRequest, is(request)); + + // Verify content is unchanged + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + @SuppressWarnings("unchecked") + var messages = (List>) requestMap.get("messages"); + assertThat(messages.size(), is(1)); + assertThat(messages.get(0).get("content"), is(input)); + } + + public void testGetTruncationInfo_ReturnsNull() { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + + var request = createRequest(url, modelId, List.of("input"), false); + + assertThat(request.getTruncationInfo(), nullValue()); + } + + public void testIsStreaming_NonStreamingReturnsFalse() { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + + var request = createRequest(url, modelId, List.of("input"), false); + + assertFalse(request.isStreaming()); + } + + public void testIsStreaming_StreamingReturnsTrue() { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + + var request = createRequest(url, modelId, List.of("input"), true); + + assertTrue(request.isStreaming()); + } + + public void testTraceContextPropagatedThroughHTTPHeaders() { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + var traceParent = randomAlphaOfLength(10); + var traceState = randomAlphaOfLength(10); + + var model = ElasticInferenceServiceCompletionModelTests.createModel(url, modelId); + var unifiedChatInput = new UnifiedChatInput(List.of("input"), "user", false); + var request = new ElasticInferenceServiceUnifiedChatCompletionRequest( + unifiedChatInput, + model, + new TraceContext(traceParent, traceState), + randomElasticInferenceServiceRequestMetadata() + ); + + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(Task.TRACE_PARENT_HTTP_HEADER).getValue(), is(traceParent)); + assertThat(httpPost.getLastHeader(Task.TRACE_STATE).getValue(), is(traceState)); + } + + private ElasticInferenceServiceUnifiedChatCompletionRequest createRequest( + String url, + String modelId, + List inputs, + boolean stream + ) { + var model = ElasticInferenceServiceCompletionModelTests.createModel(url, modelId); + var unifiedChatInput = new UnifiedChatInput(inputs, "user", stream); + + return new ElasticInferenceServiceUnifiedChatCompletionRequest( + unifiedChatInput, + model, + new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), + randomElasticInferenceServiceRequestMetadata() + ); + } +}