Skip to content
62 changes: 31 additions & 31 deletions docs/source/en/guides/inference.md

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class InferenceClient:
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `publicai`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"` or `"zai-org"`.
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
If model is a URL or `base_url` is passed, then `provider` is not used.
token (`str`, *optional*):
Expand Down Expand Up @@ -1321,6 +1321,7 @@ def image_to_image(
>>> image = client.image_to_image("cat.jpg", prompt="turn the cat into a tiger")
>>> image.save("tiger.jpg")
```

"""
model_id = model or self.model
provider_helper = get_provider_helper(self.provider, task="image-to-image", model=model_id)
Expand Down Expand Up @@ -2540,6 +2541,7 @@ def text_to_image(
... )
>>> image.save("astronaut.png")
```

"""
model_id = model or self.model
provider_helper = get_provider_helper(self.provider, task="text-to-image", model=model_id)
Expand All @@ -2560,7 +2562,7 @@ def text_to_image(
api_key=self.token,
)
response = self._inner_post(request_parameters)
response = provider_helper.get_response(response)
response = provider_helper.get_response(response, request_parameters)
return _bytes_to_image(response)

def text_to_video(
Expand Down Expand Up @@ -2638,6 +2640,7 @@ def text_to_video(
>>> with open("cat.mp4", "wb") as file:
... file.write(video)
```

"""
model_id = model or self.model
provider_helper = get_provider_helper(self.provider, task="text-to-video", model=model_id)
Expand Down
7 changes: 5 additions & 2 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class AsyncInferenceClient:
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `publicai`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"` or `"zai-org"`.
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
If model is a URL or `base_url` is passed, then `provider` is not used.
token (`str`, *optional*):
Expand Down Expand Up @@ -1353,6 +1353,7 @@ async def image_to_image(
>>> image = await client.image_to_image("cat.jpg", prompt="turn the cat into a tiger")
>>> image.save("tiger.jpg")
```

"""
model_id = model or self.model
provider_helper = get_provider_helper(self.provider, task="image-to-image", model=model_id)
Expand Down Expand Up @@ -2584,6 +2585,7 @@ async def text_to_image(
... )
>>> image.save("astronaut.png")
```

"""
model_id = model or self.model
provider_helper = get_provider_helper(self.provider, task="text-to-image", model=model_id)
Expand All @@ -2604,7 +2606,7 @@ async def text_to_image(
api_key=self.token,
)
response = await self._inner_post(request_parameters)
response = provider_helper.get_response(response)
response = provider_helper.get_response(response, request_parameters)
return _bytes_to_image(response)

async def text_to_video(
Expand Down Expand Up @@ -2682,6 +2684,7 @@ async def text_to_video(
>>> with open("cat.mp4", "wb") as file:
... file.write(video)
```

"""
model_id = model or self.model
provider_helper = get_provider_helper(self.provider, task="text-to-video", model=model_id)
Expand Down
13 changes: 13 additions & 0 deletions src/huggingface_hub/inference/_providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
from .scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
from .wavespeed import (
WavespeedAIImageToImageTask,
WavespeedAIImageToVideoTask,
WavespeedAITextToImageTask,
WavespeedAITextToVideoTask,
)
from .zai_org import ZaiConversationalTask


Expand All @@ -68,6 +74,7 @@
"sambanova",
"scaleway",
"together",
"wavespeed",
"zai-org",
]

Expand Down Expand Up @@ -179,6 +186,12 @@
"conversational": TogetherConversationalTask(),
"text-generation": TogetherTextGenerationTask(),
},
"wavespeed": {
"text-to-image": WavespeedAITextToImageTask(),
"text-to-video": WavespeedAITextToVideoTask(),
"image-to-image": WavespeedAIImageToImageTask(),
"image-to-video": WavespeedAIImageToVideoTask(),
},
"zai-org": {
"conversational": ZaiConversationalTask(),
},
Expand Down
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_providers/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"sambanova": {},
"scaleway": {},
"together": {},
"wavespeed": {},
"zai-org": {},
}

Expand Down
138 changes: 138 additions & 0 deletions src/huggingface_hub/inference/_providers/wavespeed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import base64
import time
from abc import ABC
from typing import Any, Optional, Union
from urllib.parse import urlparse

from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import get_session, hf_raise_for_status
from huggingface_hub.utils.logging import get_logger


# Module-level logger following huggingface_hub logging conventions.
logger = get_logger(__name__)

# Polling interval (in seconds) between status checks while waiting for an
# asynchronous WaveSpeed AI generation task to complete.
_POLLING_INTERVAL = 0.5


class WavespeedAITask(TaskProviderHelper, ABC):
    """Base helper for WaveSpeed AI tasks: submit a job, then poll until the result is ready."""

    def __init__(self, task: str):
        super().__init__(provider="wavespeed", base_url="https://api.wavespeed.ai", task=task)

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        """Return the v3 endpoint path named after the provider-mapped model."""
        return f"/api/v3/{mapped_model}"

    def get_response(
        self,
        response: Union[bytes, dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Resolve an asynchronous submission into the generated bytes.

        The initial response only carries a status URL; this method polls that
        URL every `_POLLING_INTERVAL` seconds until the task reports
        `"completed"`, then downloads and returns the first output.

        Raises:
            ValueError: if the response carries no result URL, no
                `RequestParameters` is provided, the task fails, or an
                unknown status is returned.
        """
        submission = _as_dict(response)
        status_link = submission.get("data", {}).get("urls", {}).get("get")

        if not status_link:
            raise ValueError("No result URL found in the response")
        if request_params is None:
            raise ValueError("A `RequestParameters` object should be provided to get responses with WaveSpeed AI.")

        # Rebuild the polling base from the URL the request actually went to,
        # inserting the provider segment when routed through the HF proxy.
        origin = urlparse(request_params.url)
        base_url = f"{origin.scheme}://{origin.netloc}"
        if origin.netloc == "router.huggingface.co":
            base_url += "/wavespeed"

        # `status_link` may be an absolute URL; keep only its path component.
        poll_path = urlparse(status_link).path if isinstance(status_link, str) else status_link
        poll_url = f"{base_url}{poll_path}"

        logger.info("Processing request, polling for results...")

        # NOTE(review): there is no upper bound on polling — a task that never
        # leaves "processing" loops forever; confirm against provider guarantees.
        while True:
            time.sleep(_POLLING_INTERVAL)
            poll_response = get_session().get(poll_url, headers=request_params.headers)
            hf_raise_for_status(poll_response)

            task_data = poll_response.json().get("data", {})
            status = task_data.get("status")

            if status in ("processing", "created"):
                continue  # still running — keep polling
            if status == "completed":
                outputs = task_data.get("outputs")
                if not outputs:
                    raise ValueError("No output URL in completed response")
                # Only the first output is returned to the caller.
                return get_session().get(outputs[0]).content
            if status == "failed":
                error_msg = task_data.get("error", "Task failed with no specific error message")
                raise ValueError(f"WaveSpeed AI task failed: {error_msg}")
            raise ValueError(f"Unknown status: {status}")


class WavespeedAITextToImageTask(WavespeedAITask):
    """Text-to-image task for WaveSpeed AI."""

    def __init__(self):
        super().__init__("text-to-image")

    def _prepare_payload_as_dict(
        self,
        inputs: Any,
        parameters: dict,
        provider_mapping_info: InferenceProviderMapping,
    ) -> Optional[dict]:
        """Build the JSON payload: the prompt plus any non-None generation parameters."""
        payload: dict = {"prompt": inputs}
        # Parameters are merged last so an explicit "prompt" parameter, if any,
        # takes precedence — same override order as a single dict literal.
        payload.update(filter_none(parameters))
        return payload


class WavespeedAITextToVideoTask(WavespeedAITextToImageTask):
    """Text-to-video task; reuses the text-to-image payload format."""

    def __init__(self):
        # Bypass the parent's __init__ (which hardcodes "text-to-image") and
        # register directly with the shared base helper under "text-to-video".
        super(WavespeedAITextToImageTask, self).__init__("text-to-video")


class WavespeedAIImageToImageTask(WavespeedAITask):
    """Image-to-image task for WaveSpeed AI."""

    def __init__(self):
        super().__init__("image-to-image")

    def _prepare_payload_as_dict(
        self,
        inputs: Any,
        parameters: dict,
        provider_mapping_info: InferenceProviderMapping,
    ) -> Optional[dict]:
        """Build the JSON payload for image-to-image generation.

        Args:
            inputs: Source image as an http(s) URL, a local file path, or raw bytes.
            parameters: Extra generation parameters; an optional "prompt" key is
                extracted and appended last so it always appears in the payload.
            provider_mapping_info: Unused here; part of the helper interface.

        Returns:
            The request payload dict.
        """
        image = self._image_as_url_or_data_uri(inputs)
        # Fix: work on a copy so the caller's `parameters` dict is not mutated
        # (the original popped "prompt" out of the shared argument in place).
        params = dict(parameters)
        prompt = params.pop("prompt", None)
        payload = {"image": image, **filter_none(params)}
        if prompt is not None:
            payload["prompt"] = prompt
        return payload

    @staticmethod
    def _image_as_url_or_data_uri(inputs: Any) -> str:
        """Return the input image as an http(s) URL or a base64 data URI."""
        if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
            # Already a remote URL — pass through untouched.
            return inputs
        if isinstance(inputs, str):
            # Any other string is treated as a local file path.
            with open(inputs, "rb") as f:
                raw = f.read()
        else:
            # Otherwise assume raw binary image content.
            raw = inputs
        # NOTE(review): the MIME type is hardcoded to image/jpeg regardless of
        # the actual format — presumably the API sniffs the content; confirm.
        image_b64 = base64.b64encode(raw).decode("utf-8")
        return f"data:image/jpeg;base64,{image_b64}"


class WavespeedAIImageToVideoTask(WavespeedAIImageToImageTask):
    """Image-to-video task; reuses the image-to-image payload format."""

    def __init__(self):
        # Bypass the parent's __init__ (which hardcodes "image-to-image") and
        # register directly with the shared base helper under "image-to-video".
        super(WavespeedAIImageToImageTask, self).__init__("image-to-video")
Loading
Loading