docs/changelog/137677.yaml: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
pr: 137677
summary: "[Inference] Implementing the completion task type on EIS"
area: "Inference"
type: enhancement
issues: []
@@ -98,6 +98,7 @@ public class ElasticInferenceService extends SenderService {
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
TaskType.SPARSE_EMBEDDING,
TaskType.CHAT_COMPLETION,
TaskType.COMPLETION,
TaskType.RERANK,
TaskType.TEXT_EMBEDDING
);
@@ -129,6 +130,7 @@ public class ElasticInferenceService extends SenderService {
*/
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(
TaskType.SPARSE_EMBEDDING,
TaskType.COMPLETION,
TaskType.RERANK,
TaskType.TEXT_EMBEDDING
);
@@ -303,7 +305,7 @@ protected void doUnifiedCompletionInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
-if (model instanceof ElasticInferenceServiceCompletionModel == false) {
+if (model instanceof ElasticInferenceServiceCompletionModel == false || model.getTaskType() != TaskType.CHAT_COMPLETION) {
listener.onFailure(createInvalidModelException(model));
return;
}
@@ -506,7 +508,7 @@ private static ElasticInferenceServiceModel createModel(
context,
chunkingSettings
);
-case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
+case CHAT_COMPLETION, COMPLETION -> new ElasticInferenceServiceCompletionModel(
inferenceEntityId,
taskType,
NAME,
@@ -11,22 +11,28 @@
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceRerankResponseEntity;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceDenseTextEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequest;
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
@@ -35,6 +41,8 @@

public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor {

public static final String USER_ROLE = "user";

static final ResponseHandler DENSE_TEXT_EMBEDDINGS_HANDLER = new ElasticInferenceServiceResponseHandler(
"elastic dense text embedding",
ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::fromResponse
@@ -45,6 +53,11 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor {
(request, response) -> ElasticInferenceServiceRerankResponseEntity.fromResponse(response)
);

static final ResponseHandler COMPLETION_HANDLER = new ElasticInferenceServiceResponseHandler(
"elastic completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

private final Sender sender;

private final ServiceComponents serviceComponents;
@@ -108,4 +121,25 @@ public ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model) {
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Elastic dense text embeddings");
return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage);
}

@Override
public ExecutableAction create(ElasticInferenceServiceCompletionModel model, Map<String, Object> taskSettings) {
var threadPool = serviceComponents.threadPool();

var manager = new GenericRequestManager<>(
threadPool,
model,
COMPLETION_HANDLER,
(chatCompletionInput) -> new ElasticInferenceServiceUnifiedChatCompletionRequest(
new UnifiedChatInput(chatCompletionInput.getInputs(), USER_ROLE, false),
[Inline review comment from a Contributor, on the UnifiedChatInput line above]
The inference API supports streaming and non-streaming for the completion task type. I think we can use this constructor and it'll handle it for us:

new UnifiedChatInput(inputs, USER_ROLE)

(A hedged sketch of this suggestion follows the end of this file's diff.)

model,
traceContext,
extractRequestMetadataFromThreadContext(threadPool.getThreadContext())
),
ChatCompletionInput.class
);

var errorMessage = constructFailedToSendRequestMessage(Strings.format("%s completion", ELASTIC_INFERENCE_SERVICE_IDENTIFIER));
return new SenderExecutableAction(sender, manager, errorMessage);
}
}
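
For context, a minimal sketch of the request lambda with the reviewer's suggestion applied. This is not part of the PR; it assumes UnifiedChatInput has a two-argument constructor that accepts the ChatCompletionInput and reads the streaming flag from it, as the comment implies:

(chatCompletionInput) -> new ElasticInferenceServiceUnifiedChatCompletionRequest(
    // Assumed two-argument constructor: it would take the stream flag from
    // chatCompletionInput itself rather than hard-coding `false` as the
    // current diff does.
    new UnifiedChatInput(chatCompletionInput, USER_ROLE),
    model,
    traceContext,
    extractRequestMetadataFromThreadContext(threadPool.getThreadContext())
)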
@@ -8,15 +8,20 @@
package org.elasticsearch.xpack.inference.services.elastic.action;

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;

import java.util.Map;

public interface ElasticInferenceServiceActionVisitor {

ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model);

ExecutableAction create(ElasticInferenceServiceRerankModel model);

ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model);

ExecutableAction create(ElasticInferenceServiceCompletionModel model, Map<String, Object> taskSettings);
}
@@ -18,16 +18,18 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
-import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceExecutableActionModel;
import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionVisitor;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.Objects;

-public class ElasticInferenceServiceCompletionModel extends ElasticInferenceServiceModel {
+public class ElasticInferenceServiceCompletionModel extends ElasticInferenceServiceExecutableActionModel {

public static ElasticInferenceServiceCompletionModel of(
ElasticInferenceServiceCompletionModel model,
@@ -49,7 +51,7 @@ public ElasticInferenceServiceCompletionModel(
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
-Map<String, Object> secrets,
+@Nullable Map<String, Object> secrets,
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
ConfigurationParseContext context
) {
@@ -70,7 +72,6 @@ public ElasticInferenceServiceCompletionModel(
) {
super(model, serviceSettings);
this.uri = createUri();

}

public ElasticInferenceServiceCompletionModel(
@@ -88,9 +89,7 @@ public ElasticInferenceServiceCompletionModel(
serviceSettings,
elasticInferenceServiceComponents
);

this.uri = createUri();

}

@Override
@@ -120,5 +119,8 @@ private URI createUri() throws ElasticsearchStatusException {
}
}

// TODO create/refactor the Configuration class to be extensible for different task types (i.e completion, sparse embeddings).
@Override
public ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map<String, Object> taskSettings) {
return visitor.create(this, taskSettings);
}
}
@@ -82,6 +82,6 @@ public String getInferenceEntityId() {

@Override
public boolean isStreaming() {
-return true;
+return unifiedChatInput.stream();
}
}
@@ -499,7 +499,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws
thrownException.getMessage(),
is(
"Inference entity [model_id] does not support task type [chat_completion] "
+ "for inference, the task type must be one of [text_embedding, sparse_embedding, rerank]. "
+ "for inference, the task type must be one of [text_embedding, sparse_embedding, rerank, completion]. "
+ "The task type for the inference entity is chat_completion, "
+ "please use the _inference/chat_completion/model_id/_stream URL."
)
@@ -1371,7 +1371,7 @@ private InferenceEventsAssertion testUnifiedStream(int responseCode, String responseJson)
webServer.enqueue(new MockResponse().setResponseCode(responseCode).setBody(responseJson));
var model = new ElasticInferenceServiceCompletionModel(
"id",
-TaskType.COMPLETION,
+TaskType.CHAT_COMPLETION,
"elastic",
new ElasticInferenceServiceCompletionServiceSettings("model_id"),
EmptyTaskSettings.INSTANCE,
@@ -47,4 +47,63 @@ public void testOverridingModelId() {
assertThat(overriddenModel.getServiceSettings().modelId(), is("new_model_id"));
assertThat(overriddenModel.getTaskType(), is(TaskType.COMPLETION));
}

public void testUriCreation() {
var url = "http://eis-gateway.com";
var model = createModel(url, "my-model-id");

var uri = model.uri();
assertThat(uri.toString(), is(url + "/api/v1/chat"));
}

public void testGetServiceSettings() {
var modelId = "test-model";
var model = createModel("http://eis-gateway.com", modelId);

var serviceSettings = model.getServiceSettings();
assertThat(serviceSettings.modelId(), is(modelId));
}

public void testGetTaskType() {
var model = createModel("http://eis-gateway.com", "my-model-id");
assertThat(model.getTaskType(), is(TaskType.COMPLETION));
}

public void testGetInferenceEntityId() {
var inferenceEntityId = "test-id";
var model = new ElasticInferenceServiceCompletionModel(
inferenceEntityId,
TaskType.COMPLETION,
"elastic",
new ElasticInferenceServiceCompletionServiceSettings("my-model-id"),
EmptyTaskSettings.INSTANCE,
EmptySecretSettings.INSTANCE,
ElasticInferenceServiceComponents.of("http://eis-gateway.com")
);

assertThat(model.getInferenceEntityId(), is(inferenceEntityId));
}

public void testModelWithOverriddenServiceSettings() {
var originalModel = createModel("http://eis-gateway.com", "original-model");
var newServiceSettings = new ElasticInferenceServiceCompletionServiceSettings("new-model");

var overriddenModel = new ElasticInferenceServiceCompletionModel(originalModel, newServiceSettings);

assertThat(overriddenModel.getServiceSettings().modelId(), is("new-model"));
assertThat(overriddenModel.getTaskType(), is(TaskType.COMPLETION));
assertThat(overriddenModel.uri().toString(), is(originalModel.uri().toString()));
}

public static ElasticInferenceServiceCompletionModel createModel(String url, String modelId) {
return new ElasticInferenceServiceCompletionModel(
"id",
TaskType.COMPLETION,
"elastic",
new ElasticInferenceServiceCompletionServiceSettings(modelId),
EmptyTaskSettings.INSTANCE,
EmptySecretSettings.INSTANCE,
ElasticInferenceServiceComponents.of(url)
);
}
}