Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/137677.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 137677
summary: "[Inference] Implementing the completion task type on EIS"
area: "Machine Learning, Inference"
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionCreator;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
Expand Down Expand Up @@ -98,6 +99,7 @@ public class ElasticInferenceService extends SenderService {
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
TaskType.SPARSE_EMBEDDING,
TaskType.CHAT_COMPLETION,
TaskType.COMPLETION,
TaskType.RERANK,
TaskType.TEXT_EMBEDDING
);
Expand Down Expand Up @@ -129,6 +131,7 @@ public class ElasticInferenceService extends SenderService {
*/
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(
TaskType.SPARSE_EMBEDDING,
TaskType.COMPLETION,
TaskType.RERANK,
TaskType.TEXT_EMBEDDING
);
Expand Down Expand Up @@ -188,7 +191,7 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
return Map.of(
DEFAULT_CHAT_COMPLETION_MODEL_ID_V1,
new DefaultModelConfig(
new ElasticInferenceServiceCompletionModel(
new ElasticInferenceServiceChatCompletionModel(
DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1,
TaskType.CHAT_COMPLETION,
NAME,
Expand Down Expand Up @@ -303,7 +306,7 @@ protected void doUnifiedCompletionInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
if (model instanceof ElasticInferenceServiceCompletionModel == false) {
if (model instanceof ElasticInferenceServiceChatCompletionModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}
Expand All @@ -313,8 +316,8 @@ protected void doUnifiedCompletionInfer(
// generating a different "traceparent" as every task and every REST request creates a new span).
var currentTraceInfo = getCurrentTraceInfo();

var completionModel = (ElasticInferenceServiceCompletionModel) model;
var overriddenModel = ElasticInferenceServiceCompletionModel.of(completionModel, inputs.getRequest());
var completionModel = (ElasticInferenceServiceChatCompletionModel) model;
var overriddenModel = ElasticInferenceServiceChatCompletionModel.of(completionModel, inputs.getRequest());
var errorMessage = constructFailedToSendRequestMessage(
String.format(Locale.ROOT, "%s completions", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
);
Expand Down Expand Up @@ -506,7 +509,17 @@ private static ElasticInferenceServiceModel createModel(
context,
chunkingSettings
);
case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
case CHAT_COMPLETION -> new ElasticInferenceServiceChatCompletionModel(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once we switch to using a single model class, we should be able to follow what OpenAI does here and use `case CHAT_COMPLETION, COMPLETION -> ...`.

inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings,
elasticInferenceServiceComponents,
context
);
case COMPLETION -> new ElasticInferenceServiceCompletionModel(
inferenceEntityId,
taskType,
NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
Expand All @@ -32,7 +32,7 @@ public class ElasticInferenceServiceUnifiedCompletionRequestManager extends Elas
private static final ResponseHandler HANDLER = createCompletionHandler();

public static ElasticInferenceServiceUnifiedCompletionRequestManager of(
ElasticInferenceServiceCompletionModel model,
ElasticInferenceServiceChatCompletionModel model,
ThreadPool threadPool,
TraceContext traceContext
) {
Expand All @@ -43,11 +43,11 @@ public static ElasticInferenceServiceUnifiedCompletionRequestManager of(
);
}

private final ElasticInferenceServiceCompletionModel model;
private final ElasticInferenceServiceChatCompletionModel model;
private final TraceContext traceContext;

private ElasticInferenceServiceUnifiedCompletionRequestManager(
ElasticInferenceServiceCompletionModel model,
ElasticInferenceServiceChatCompletionModel model,
ThreadPool threadPool,
TraceContext traceContext
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
Expand All @@ -20,13 +21,17 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceCompletionRequest;
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceDenseTextEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequest;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
Expand All @@ -45,6 +50,11 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer
(request, response) -> ElasticInferenceServiceRerankResponseEntity.fromResponse(response)
);

static final ResponseHandler COMPLETION_HANDLER = new ElasticInferenceServiceResponseHandler(
"elastic completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

private final Sender sender;

private final ServiceComponents serviceComponents;
Expand Down Expand Up @@ -108,4 +118,25 @@ public ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel m
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Elastic dense text embeddings");
return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage);
}

@Override
public ExecutableAction create(ElasticInferenceServiceCompletionModel model, Map<String, Object> taskSettings) {
    // Wire up a request manager that turns chat-completion inputs into EIS completion requests.
    var pool = serviceComponents.threadPool();

    var requestManager = new GenericRequestManager<>(
        pool,
        model,
        COMPLETION_HANDLER,
        input -> new ElasticInferenceServiceCompletionRequest(
            input.getInputs(),
            model,
            traceContext,
            extractRequestMetadataFromThreadContext(pool.getThreadContext())
        ),
        ChatCompletionInput.class
    );

    var failureMessage = constructFailedToSendRequestMessage(Strings.format("%s completion", ELASTIC_INFERENCE_SERVICE_IDENTIFIER));
    return new SenderExecutableAction(sender, requestManager, failureMessage);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@
package org.elasticsearch.xpack.inference.services.elastic.action;

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;

import java.util.Map;

/**
 * Visitor that builds the {@link ExecutableAction} appropriate for each Elastic
 * Inference Service model type. Models dispatch to the matching overload.
 */
public interface ElasticInferenceServiceActionVisitor {

    /** Builds the action that executes sparse embeddings requests for the given model. */
    ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model);

    /** Builds the action that executes rerank requests for the given model. */
    ExecutableAction create(ElasticInferenceServiceRerankModel model);

    /** Builds the action that executes dense text embeddings requests for the given model. */
    ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model);

    /**
     * Builds the action that executes (non-chat) completion requests for the given model.
     *
     * @param taskSettings per-request task settings; not all implementations use them — confirm with callers
     */
    ExecutableAction create(ElasticInferenceServiceCompletionModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.elastic.completion;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.EmptySecretSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.SecretSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.Objects;

public class ElasticInferenceServiceChatCompletionModel extends ElasticInferenceServiceModel {

public static ElasticInferenceServiceChatCompletionModel of(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember correctly, this is the only method that makes this class different from the ElasticInferenceServiceCompletionModel. How about we just use ElasticInferenceServiceCompletionModel and add this method there?

ElasticInferenceServiceChatCompletionModel model,
UnifiedCompletionRequest request
) {
var originalModelServiceSettings = model.getServiceSettings();
var overriddenServiceSettings = new ElasticInferenceServiceCompletionServiceSettings(
Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId())
);

return new ElasticInferenceServiceChatCompletionModel(model, overriddenServiceSettings);
}

private final URI uri;

/**
 * Creates the model from raw configuration maps (as parsed from a persisted
 * config or an inference endpoint request).
 *
 * Task settings and secrets are fixed to their empty singletons — this model
 * carries no task settings or secrets; only the service settings map is parsed.
 *
 * @param serviceSettings raw service settings map, parsed via
 *        {@code ElasticInferenceServiceCompletionServiceSettings.fromMap}
 * @param taskSettings ignored — presumably accepted only for signature symmetry
 *        with other model constructors; TODO confirm
 * @param secrets ignored — same note as {@code taskSettings}
 * @param context whether the maps come from persistent storage or a request
 */
public ElasticInferenceServiceChatCompletionModel(
    String inferenceEntityId,
    TaskType taskType,
    String service,
    Map<String, Object> serviceSettings,
    Map<String, Object> taskSettings,
    Map<String, Object> secrets,
    ElasticInferenceServiceComponents elasticInferenceServiceComponents,
    ConfigurationParseContext context
) {
    this(
        inferenceEntityId,
        taskType,
        service,
        ElasticInferenceServiceCompletionServiceSettings.fromMap(serviceSettings, context),
        EmptyTaskSettings.INSTANCE,
        EmptySecretSettings.INSTANCE,
        elasticInferenceServiceComponents
    );
}

/**
 * Copy constructor that replaces the service settings of an existing model —
 * used to override the model id from a unified completion request (see {@code of}).
 */
public ElasticInferenceServiceChatCompletionModel(
    ElasticInferenceServiceChatCompletionModel model,
    ElasticInferenceServiceCompletionServiceSettings serviceSettings
) {
    super(model, serviceSettings);
    // Recompute the endpoint URI since the components (base URL) come from the copied model.
    this.uri = createUri();

}

/**
 * Primary constructor with already-parsed settings objects.
 *
 * @param taskSettings may be null; stored via {@link ModelConfigurations}
 * @param secretSettings may be null; stored via {@link ModelSecrets}
 * @throws ElasticsearchStatusException if the configured EIS base URL does not
 *         form a valid chat endpoint URI (propagated from {@code createUri})
 */
public ElasticInferenceServiceChatCompletionModel(
    String inferenceEntityId,
    TaskType taskType,
    String service,
    ElasticInferenceServiceCompletionServiceSettings serviceSettings,
    @Nullable TaskSettings taskSettings,
    @Nullable SecretSettings secretSettings,
    ElasticInferenceServiceComponents elasticInferenceServiceComponents
) {
    super(
        new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
        new ModelSecrets(secretSettings),
        serviceSettings,
        elasticInferenceServiceComponents
    );

    // URI is derived from the EIS components set by the super constructor above.
    this.uri = createUri();

}

/** Narrows the inherited service settings to the completion-specific type. */
@Override
public ElasticInferenceServiceCompletionServiceSettings getServiceSettings() {
    return (ElasticInferenceServiceCompletionServiceSettings) super.getServiceSettings();
}

/** Returns the fully resolved chat endpoint URI, computed once at construction. */
public URI uri() {
    return uri;
}

/**
 * Resolves the chat endpoint URI by appending {@code /api/v1/chat} to the
 * configured EIS base URL.
 *
 * @throws ElasticsearchStatusException (BAD_REQUEST) if the resulting string is not a valid URI
 */
private URI createUri() throws ElasticsearchStatusException {
    // TODO, consider transforming the base URL into a URI for better error handling.
    var endpoint = elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/chat";
    try {
        return new URI(endpoint);
    } catch (URISyntaxException e) {
        var message = "Failed to create URI for service ["
            + this.getConfigurations().getService()
            + "] with taskType ["
            + this.getTaskType()
            + "]: "
            + e.getMessage();
        throw new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST, e);
    }
}

// TODO create/refactor the Configuration class to be extensible for different task types (i.e completion, sparse embeddings).
}
Loading