24 | 24 | from concurrent.futures import Future, ThreadPoolExecutor |
25 | 25 | from dataclasses import asdict, dataclass, field |
26 | 26 | from datetime import datetime |
| 27 | +from enum import Enum |
27 | 28 | from functools import wraps |
28 | 29 | from itertools import islice |
29 | 30 | from pathlib import Path |
@@ -1646,6 +1647,11 @@ def _inner(self, *args, **kwargs): |
1646 | 1647 | return _inner # type: ignore |
1647 | 1648 |
1648 | 1649 |
| 1650 | +class EndpointsScalingMetric(Enum): |
| 1651 | + PENDING_REQUESTS = "pendingRequests" |
| 1652 | + HARDWARE_USAGE = "hardwareUsage" |
| 1653 | + |
| 1654 | + |
1649 | 1655 | class HfApi: |
1650 | 1656 | """ |
1651 | 1657 | Client to interact with the Hugging Face Hub via HTTP. |
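
Since the payload code further down coerces whatever the caller passed through the enum constructor, here is a minimal standalone sketch of that by-value lookup (member names copied from the hunk above; the `cpuLoad` value is made up to show the failure mode):

    from enum import Enum

    class EndpointsScalingMetric(Enum):
        PENDING_REQUESTS = "pendingRequests"
        HARDWARE_USAGE = "hardwareUsage"

    # Lookup by value accepts a raw string...
    metric = EndpointsScalingMetric("pendingRequests")
    assert metric is EndpointsScalingMetric.PENDING_REQUESTS

    # ...and passing an existing member through is a no-op, which is why
    # `EndpointsScalingMetric(scaling_metric)` below handles both input types.
    assert EndpointsScalingMetric(metric) is metric

    # Anything else fails fast:
    try:
        EndpointsScalingMetric("cpuLoad")
    except ValueError as err:
        print(err)  # 'cpuLoad' is not a valid EndpointsScalingMetric

This also means an unsupported `scaling_metric` string raises client-side, before any HTTP request is made.
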
@@ -7391,6 +7397,8 @@ def create_inference_endpoint( |
7391 | 7397 | account_id: Optional[str] = None, |
7392 | 7398 | min_replica: int = 1, |
7393 | 7399 | max_replica: int = 1, |
| 7400 | + scaling_metric: Optional[Union[str, EndpointsScalingMetric]] = None, |
| 7401 | + scaling_threshold: Optional[int] = None, |
7394 | 7402 | scale_to_zero_timeout: Optional[int] = None, |
7395 | 7403 | revision: Optional[str] = None, |
7396 | 7404 | task: Optional[str] = None, |
@@ -7431,6 +7439,12 @@ def create_inference_endpoint( |
7431 | 7439 | scaling to zero, set this value to 0 and adjust `scale_to_zero_timeout` accordingly. Defaults to 1. |
7432 | 7440 | max_replica (`int`, *optional*): |
7433 | 7441 | The maximum number of replicas (instances) to scale to for the Inference Endpoint. Defaults to 1. |
| 7442 | + scaling_metric (`str` or `EndpointsScalingMetric`, *optional*): |
| 7443 | + The metric the autoscaler watches: either `"pendingRequests"` or `"hardwareUsage"`. Defaults to None |
| 7444 | + (meaning: let the Inference Endpoints service pick the default scaling metric). |
| 7445 | + scaling_threshold (`int`, *optional*): |
| 7446 | + The metric value that triggers a scale-up. Ignored when `scaling_metric` is not provided. Defaults |
| 7447 | + to None (meaning: let the Inference Endpoints service pick the default threshold). |
7434 | 7448 | scale_to_zero_timeout (`int`, *optional*): |
7435 | 7449 | The duration in minutes before an inactive endpoint is scaled to zero, or no scaling to zero if |
7436 | 7450 | set to None and `min_replica` is not 0. Defaults to None. |
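
Putting the new arguments together, a usage sketch for endpoint creation (the endpoint name, hardware choices, and threshold value are illustrative, not prescribed by this diff):

    from huggingface_hub import HfApi

    api = HfApi()
    # Scale up whenever more than 100 requests are queued.
    endpoint = api.create_inference_endpoint(
        "my-endpoint",
        repository="gpt2",
        framework="pytorch",
        accelerator="cpu",
        instance_size="x2",
        instance_type="intel-icl",
        region="us-east-1",
        vendor="aws",
        min_replica=1,
        max_replica=4,
        scaling_metric="pendingRequests",  # or EndpointsScalingMetric.PENDING_REQUESTS
        scaling_threshold=100,
    )
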
@@ -7582,6 +7596,9 @@ def create_inference_endpoint( |
7582 | 7596 | }, |
7583 | 7597 | "type": type, |
7584 | 7598 | } |
| 7599 | + if scaling_metric is not None: |
| 7600 | + scaling_metric = EndpointsScalingMetric(scaling_metric) |
| 7601 | + payload["compute"]["scaling"]["measure"] = {scaling_metric.value: scaling_threshold} |
7585 | 7602 | if env: |
7586 | 7603 | payload["model"]["env"] = env |
7587 | 7604 | if secrets: |
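
For reference, the `measure` fragment those three lines add to the request body, mirrored as a tiny standalone helper (`build_measure` is hypothetical, for illustration only):

    def build_measure(metric, threshold):
        # Keyed by the enum's wire value, exactly like the payload code above.
        return {EndpointsScalingMetric(metric).value: threshold}

    print(build_measure("pendingRequests", 100))
    # {'pendingRequests': 100}
    print(build_measure(EndpointsScalingMetric.HARDWARE_USAGE, None))
    # {'hardwareUsage': None} -- a null threshold defers the choice to the service
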
@@ -7746,6 +7763,8 @@ def update_inference_endpoint( |
7746 | 7763 | min_replica: Optional[int] = None, |
7747 | 7764 | max_replica: Optional[int] = None, |
7748 | 7765 | scale_to_zero_timeout: Optional[int] = None, |
| 7766 | + scaling_metric: Optional[Union[str, EndpointsScalingMetric]] = None, |
| 7767 | + scaling_threshold: Optional[int] = None, |
7749 | 7768 | # Model update |
7750 | 7769 | repository: Optional[str] = None, |
7751 | 7770 | framework: Optional[str] = None, |
@@ -7786,7 +7805,12 @@ def update_inference_endpoint( |
7786 | 7805 | The maximum number of replicas (instances) to scale to for the Inference Endpoint. |
7787 | 7806 | scale_to_zero_timeout (`int`, *optional*): |
7788 | 7807 | The duration in minutes before an inactive endpoint is scaled to zero. |
7789 | | -
| 7808 | + scaling_metric (`str` or `EndpointsScalingMetric`, *optional*): |
| 7809 | + The metric the autoscaler watches: either `"pendingRequests"` or `"hardwareUsage"`. Defaults to None |
| 7810 | + (meaning: let the Inference Endpoints service pick the default scaling metric). |
| 7811 | + scaling_threshold (`int`, *optional*): |
| 7812 | + The metric value that triggers a scale-up. Ignored when `scaling_metric` is not provided. Defaults |
| 7813 | + to None (meaning: let the Inference Endpoints service pick the default threshold). |
7790 | 7814 | repository (`str`, *optional*): |
7791 | 7815 | The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`). |
7792 | 7816 | framework (`str`, *optional*): |
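
And the corresponding update call, again with illustrative values (whether `hardwareUsage` is interpreted as a utilization percentage is an assumption about the service, not something this diff states):

    # Switch an existing endpoint to hardware-based autoscaling.
    api.update_inference_endpoint(
        "my-endpoint",
        max_replica=8,
        scaling_metric=EndpointsScalingMetric.HARDWARE_USAGE,
        scaling_threshold=80,
    )
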
@@ -7840,6 +7864,9 @@ def update_inference_endpoint( |
7840 | 7864 | payload["compute"]["scaling"]["minReplica"] = min_replica |
7841 | 7865 | if scale_to_zero_timeout is not None: |
7842 | 7866 | payload["compute"]["scaling"]["scaleToZeroTimeout"] = scale_to_zero_timeout |
| 7867 | + if scaling_metric is not None: |
| 7868 | + scaling_metric = EndpointsScalingMetric(scaling_metric) |
| 7869 | + payload["compute"]["scaling"]["measure"] = {scaling_metric.value: scaling_threshold} |
7843 | 7870 | if repository is not None: |
7844 | 7871 | payload["model"]["repository"] = repository |
7845 | 7872 | if framework is not None: |
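
Note that the chained assignment `payload["compute"]["scaling"]["measure"] = ...` presumes the nested dicts already exist. A sketch of one structure under which that holds, a recursive defaultdict (an assumption about the surrounding update code, which this hunk does not show):

    from collections import defaultdict

    def nested_dict():
        # Every missing key materializes another nested dict, so chained
        # assignment never raises KeyError.
        return defaultdict(nested_dict)

    payload = nested_dict()
    payload["compute"]["scaling"]["measure"] = {"pendingRequests": 100}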
|