Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion packages/inference/src/lib/getInferenceProviderMapping.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,16 @@ export async function getInferenceProviderMapping(

export async function resolveProvider(
provider?: InferenceProviderOrPolicy,
modelId?: string
modelId?: string,
endpointUrl?: string
): Promise<InferenceProvider> {
if (endpointUrl) {
if (provider) {
throw new Error("Specifying both endpointUrl and provider is not supported.");
}
// Default to the "hf-inference" provider when a custom endpoint URL is given,
// so its request helpers and API conventions are used.
return "hf-inference";
}
if (!provider) {
console.log(
"Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export async function audioClassification(
args: AudioClassificationArgs,
options?: Options
): Promise<AudioClassificationOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "audio-classification");
const payload = preparePayload(args);
const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export async function automaticSpeechRecognition(
args: AutomaticSpeechRecognitionArgs,
options?: Options
): Promise<AutomaticSpeechRecognitionOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
const payload = await buildPayload(args);
const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/audio/textToSpeech.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ interface OutputUrlTextToSpeechGeneration {
* Recommended model: espnet/kan-bayashi_ljspeech_vits
*/
export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "text-to-speech");
const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
...options,
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/custom/request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export async function request<T>(
console.warn(
"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
);
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, options?.task);
const result = await innerRequest<T>(args, providerHelper, options);
return result.data;
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/custom/streamingRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export async function* streamingRequest<T>(
console.warn(
"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
);
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, options?.task);
yield* innerStreamingRequest(args, providerHelper, options);
}
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/cv/imageClassification.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export async function imageClassification(
args: ImageClassificationArgs,
options?: Options
): Promise<ImageClassificationOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "image-classification");
const payload = preparePayload(args);
const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/cv/imageSegmentation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export async function imageSegmentation(
args: ImageSegmentationArgs,
options?: Options
): Promise<ImageSegmentationOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "image-segmentation");
const payload = preparePayload(args);
const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, providerHelper, {
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/cv/imageToImage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
* Recommended model: lllyasviel/sd-controlnet-depth
*/
export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "image-to-image");
let reqArgs: RequestArgs;
if (!args.parameters) {
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/cv/imageToText.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
* This task reads some image input and outputs the text caption.
*/
export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "image-to-text");
const payload = preparePayload(args);
const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, {
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/cv/objectDetection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage
* Recommended model: facebook/detr-resnet-50
*/
export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "object-detection");
const payload = preparePayload(args);
const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, providerHelper, {
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/cv/textToImage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export async function textToImage(
options?: TextToImageOptions & { outputType?: undefined | "blob" }
): Promise<Blob>;
export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "text-to-image");
const { data: res } = await innerRequest<Record<string, unknown>>(args, providerHelper, {
...options,
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/cv/textToVideo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export type TextToVideoArgs = BaseArgs & TextToVideoInput;
export type TextToVideoOutput = Blob;

export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "text-to-video");
const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(
args,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export async function zeroShotImageClassification(
args: ZeroShotImageClassificationArgs,
options?: Options
): Promise<ZeroShotImageClassificationOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
const payload = await preparePayload(args);
const { data: res } = await innerRequest<ZeroShotImageClassificationOutput>(payload, providerHelper, {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export async function documentQuestionAnswering(
args: DocumentQuestionAnsweringArgs,
options?: Options
): Promise<DocumentQuestionAnsweringOutput[number]> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "document-question-answering");
const reqArgs: RequestArgs = {
...args,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export async function visualQuestionAnswering(
args: VisualQuestionAnsweringArgs,
options?: Options
): Promise<VisualQuestionAnsweringOutput[number]> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "visual-question-answering");
const reqArgs: RequestArgs = {
...args,
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/chatCompletion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export async function chatCompletion(
args: BaseArgs & ChatCompletionInput,
options?: Options
): Promise<ChatCompletionOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "conversational");
const { data: response } = await innerRequest<ChatCompletionOutput>(args, providerHelper, {
...options,
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/chatCompletionStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export async function* chatCompletionStream(
args: BaseArgs & ChatCompletionInput,
options?: Options
): AsyncGenerator<ChatCompletionStreamOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "conversational");
yield* innerStreamingRequest<ChatCompletionStreamOutput>(args, providerHelper, {
...options,
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/featureExtraction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export async function featureExtraction(
args: FeatureExtractionArgs,
options?: Options
): Promise<FeatureExtractionOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "feature-extraction");
const { data: res } = await innerRequest<FeatureExtractionOutput>(args, providerHelper, {
...options,
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/fillMask.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export type FillMaskArgs = BaseArgs & FillMaskInput;
* Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
*/
export async function fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "fill-mask");
const { data: res } = await innerRequest<FillMaskOutput>(args, providerHelper, {
...options,
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/questionAnswering.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export async function questionAnswering(
args: QuestionAnsweringArgs,
options?: Options
): Promise<QuestionAnsweringOutput[number]> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "question-answering");
const { data: res } = await innerRequest<QuestionAnsweringOutput | QuestionAnsweringOutput[number]>(
args,
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/sentenceSimilarity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export async function sentenceSimilarity(
args: SentenceSimilarityArgs,
options?: Options
): Promise<SentenceSimilarityOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "sentence-similarity");
const { data: res } = await innerRequest<SentenceSimilarityOutput>(args, providerHelper, {
...options,
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/summarization.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export type SummarizationArgs = BaseArgs & SummarizationInput;
* This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
*/
export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "summarization");
const { data: res } = await innerRequest<SummarizationOutput[]>(args, providerHelper, {
...options,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export async function tableQuestionAnswering(
args: TableQuestionAnsweringArgs,
options?: Options
): Promise<TableQuestionAnsweringOutput[number]> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "table-question-answering");
const { data: res } = await innerRequest<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(
args,
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/textClassification.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export async function textClassification(
args: TextClassificationArgs,
options?: Options
): Promise<TextClassificationOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "text-classification");
const { data: res } = await innerRequest<TextClassificationOutput>(args, providerHelper, {
...options,
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/textGeneration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export async function textGeneration(
args: BaseArgs & TextGenerationInput,
options?: Options
): Promise<TextGenerationOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "text-generation");
const { data: response } = await innerRequest<
HyperbolicTextCompletionOutput | TextGenerationOutput | TextGenerationOutput[]
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/textGenerationStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ export async function* textGenerationStream(
args: BaseArgs & TextGenerationInput,
options?: Options
): AsyncGenerator<TextGenerationStreamOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "text-generation");
yield* innerStreamingRequest<TextGenerationStreamOutput>(args, providerHelper, {
...options,
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/tokenClassification.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export async function tokenClassification(
args: TokenClassificationArgs,
options?: Options
): Promise<TokenClassificationOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "token-classification");
const { data: res } = await innerRequest<TokenClassificationOutput[number] | TokenClassificationOutput>(
args,
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/translation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export type TranslationArgs = BaseArgs & TranslationInput;
* This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
*/
export async function translation(args: TranslationArgs, options?: Options): Promise<TranslationOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "translation");
const { data: res } = await innerRequest<TranslationOutput>(args, providerHelper, {
...options,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export async function zeroShotClassification(
args: ZeroShotClassificationArgs,
options?: Options
): Promise<ZeroShotClassificationOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "zero-shot-classification");
const { data: res } = await innerRequest<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(
args,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export async function tabularClassification(
args: TabularClassificationArgs,
options?: Options
): Promise<TabularClassificationOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "tabular-classification");
const { data: res } = await innerRequest<TabularClassificationOutput>(args, providerHelper, {
...options,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export async function tabularRegression(
args: TabularRegressionArgs,
options?: Options
): Promise<TabularRegressionOutput> {
const provider = await resolveProvider(args.provider, args.model);
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "tabular-regression");
const { data: res } = await innerRequest<TabularRegressionOutput>(args, providerHelper, {
...options,
Expand Down
20 changes: 18 additions & 2 deletions packages/inference/test/InferenceClient.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -845,9 +845,25 @@ describe.skip("InferenceClient", () => {
it("endpoint - makes request to specified endpoint", async () => {
const ep = hf.endpoint("https://router.huggingface.co/hf-inference/models/openai-community/gpt2");
const { generated_text } = await ep.textGeneration({
inputs: "one plus two equals",
inputs: "one plus one is equal to",
parameters: {
max_new_tokens: 1,
},
});
assert.include(generated_text, "two");
});

it("endpoint - makes request to specified endpoint - alternative syntax", async () => {
const epClient = new InferenceClient(env.HF_TOKEN, {
endpointUrl: "https://router.huggingface.co/hf-inference/models/openai-community/gpt2",
});
const { generated_text } = await epClient.textGeneration({
inputs: "one plus one is equal to",
parameters: {
max_new_tokens: 1,
},
});
assert.include(generated_text, "two");
});

it("chatCompletion modelId - OpenAI Specs", async () => {
Expand Down