@@ -4,7 +4,7 @@
import string
from collections import Counter
from datetime import datetime, timezone
from typing import Generator, List, Optional, Union
from typing import Generator, List, Optional, Set, Union

import backoff
import requests
@@ -19,6 +19,7 @@
from inference_cli.lib.env import API_BASE_URL
from inference_cli.lib.roboflow_cloud.batch_processing.entities import (
AggregationFormat,
CompilationDevice,
ComputeConfigurationV2,
GetJobMetadataResponse,
JobLog,
@@ -32,6 +33,7 @@
MachineType,
StagingBatchInputV1,
TaskStatus,
TRTCompilationJobV1,
WorkflowProcessingJobV1,
WorkflowsProcessingSpecificationV1,
)
@@ -247,7 +249,10 @@ def display_batch_job_details(job_id: str, api_key: str) -> None:
]
error_reports_str = "\n".join(error_reports)
if not error_reports_str:
error_reports_str = "All Good 😃"
if not job_metadata.error:
error_reports_str = "All Good 😃"
else:
error_reports_str = "Nothing found 🕵"
expected_tasks = stage.tasks_number
registered_tasks = len([t for t in job_tasks if t.progress is not None])
tasks_waiting_for_processing = expected_tasks - registered_tasks
@@ -477,6 +482,33 @@ def trigger_job_with_workflows_videos_processing(
return job_id


def trigger_trt_compilation_job(
model_id: str,
job_id: Optional[str],
compilation_devices: Optional[List[CompilationDevice]],
notifications_url: Optional[str],
api_key: str,
) -> str:
if not job_id:
job_id = f"trt-{_generate_random_string(length=12)}"
workspace = get_workspace(api_key=api_key)
if compilation_devices:
compilation_devices = list(set(compilation_devices))
compilation_specification = TRTCompilationJobV1(
type="trt-compilation-v1",
model_id=model_id,
compilation_devices=compilation_devices,
notifications_url=notifications_url,
)
create_batch_job(
workspace=workspace,
job_id=job_id,
job_configuration=compilation_specification,
api_key=api_key,
)
return job_id


@backoff.on_exception(
backoff.constant,
exception=RetryError,
@@ -486,7 +518,7 @@ def trigger_job_with_workflows_videos_processing(
def create_batch_job(
workspace: str,
job_id: str,
job_configuration: WorkflowProcessingJobV1,
job_configuration: Union[WorkflowProcessingJobV1, TRTCompilationJobV1],
api_key: str,
) -> None:
params = {"api_key": api_key}
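
For context, a minimal usage sketch of the new helper follows. The import path of `trigger_trt_compilation_job` is an assumption (the first file in this diff is shown without its header), and the model ID and API key are placeholders. Per the hunks above, the helper generates a `trt-` prefixed job ID when none is given, de-duplicates the requested devices, and submits the job through `create_batch_job`, whose signature is widened to accept `TRTCompilationJobV1`.

```python
# Sketch only: the module path below is assumed (not shown in this diff);
# the model ID and API key are placeholders.
from inference_cli.lib.roboflow_cloud.batch_processing.api_operations import (  # assumed path
    trigger_trt_compilation_job,
)
from inference_cli.lib.roboflow_cloud.batch_processing.entities import CompilationDevice

job_id = trigger_trt_compilation_job(
    model_id="my-workspace/my-model/3",  # hypothetical model ID
    job_id=None,  # None -> a "trt-<12 random chars>" job ID is generated
    compilation_devices=[CompilationDevice.NVIDIA_L4, CompilationDevice.NVIDIA_T4],
    notifications_url=None,
    api_key="<ROBOFLOW_API_KEY>",  # placeholder
)
print(job_id)  # e.g. "trt-a1b2c3d4e5f6" (illustrative)
```
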
inference_cli/lib/roboflow_cloud/batch_processing/core.py (65 additions, 0 deletions)
@@ -13,9 +13,11 @@
restart_batch_job,
trigger_job_with_workflows_images_processing,
trigger_job_with_workflows_videos_processing,
trigger_trt_compilation_job,
)
from inference_cli.lib.roboflow_cloud.batch_processing.entities import (
AggregationFormat,
CompilationDevice,
LogSeverity,
MachineSize,
MachineType,
@@ -422,6 +424,69 @@ def process_videos_with_workflow(
raise typer.Exit(code=1)


@batch_processing_app.command(help="Trigger TRT compilation of a model")
def trt_compile(
model_id: Annotated[
str,
typer.Option("--model-id", "-m", help="Model to be compiled"),
],
compilation_devices: Annotated[
Optional[List[CompilationDevice]],
typer.Option("--device", "-d", help="Target compilation devices"),
],
job_id: Annotated[
Optional[str],
typer.Option(
"--job-id",
"-j",
help="Identifier of job (if not given - will be generated)",
),
] = None,
api_key: Annotated[
Optional[str],
typer.Option(
"--api-key",
"-a",
help="Roboflow API key for your workspace. If not given - env variable `ROBOFLOW_API_KEY` will be used",
),
] = None,
debug_mode: Annotated[
bool,
typer.Option(
"--debug-mode/--no-debug-mode",
help="Flag enabling errors stack traces to be displayed (helpful for debugging)",
),
] = False,
notifications_url: Annotated[
Optional[str],
typer.Option(
"--notifications-url",
help="URL of the Webhook to be used for job state notifications.",
),
] = None,
) -> None:
if api_key is None:
api_key = ROBOFLOW_API_KEY
try:
ensure_api_key_is_set(api_key=api_key)
job_id = trigger_trt_compilation_job(
model_id=model_id,
job_id=job_id,
compilation_devices=compilation_devices,
notifications_url=notifications_url,
api_key=api_key,
)
print(f"Triggered job with ID: {job_id}")
except KeyboardInterrupt:
print("Command interrupted.")
return
except Exception as error:
if debug_mode:
raise error
typer.echo(f"Command failed. Cause: {error}")
raise typer.Exit(code=1)


@batch_processing_app.command(help="Terminate running job")
def abort_job(
job_id: Annotated[
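
A sketch of driving the new `trt_compile` command through Typer's test runner. It assumes Typer's default underscore-to-dash command naming (`trt_compile` registered as `trt-compile`) and uses placeholder values, so it illustrates the CLI wiring rather than a real submission.

```python
# Sketch only: placeholder model ID and API key; exercises option parsing,
# not an actual compilation request.
from typer.testing import CliRunner

from inference_cli.lib.roboflow_cloud.batch_processing.core import batch_processing_app

runner = CliRunner()
result = runner.invoke(
    batch_processing_app,
    [
        "trt-compile",  # assumed name: Typer's default function-name-to-dashes convention
        "--model-id", "my-workspace/my-model/3",  # hypothetical model ID
        "--device", "nvidia-l4",
        "--device", "nvidia-t4",  # repeated flag collects into List[CompilationDevice]
        "--api-key", "<ROBOFLOW_API_KEY>",  # placeholder; otherwise ROBOFLOW_API_KEY env var is used
    ],
)
print(result.output)  # on success, expected to contain "Triggered job with ID: trt-..."
```
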
inference_cli/lib/roboflow_cloud/batch_processing/entities.py (16 additions, 0 deletions)
@@ -148,6 +148,22 @@ class WorkflowProcessingJobV1(BaseModel):
)


class CompilationDevice(str, Enum):
NVIDIA_L4 = "nvidia-l4"
NVIDIA_T4 = "nvidia-t4"


class TRTCompilationJobV1(BaseModel):
type: Literal["trt-compilation-v1"]
model_id: str = Field(serialization_alias="modelId")
compilation_devices: List[CompilationDevice] = Field(
serialization_alias="compilationDevices"
)
notifications_url: Optional[str] = Field(
serialization_alias="notificationsURL", default=None
)


class LogSeverity(str, Enum):
INFO = "info"
ERROR = "error"
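
To make the wire format of the new entity concrete, here is a serialization sketch. It assumes Pydantic v2 semantics (implied by `serialization_alias`); the model ID and the expected output in the comments are illustrative, not taken from this diff.

```python
# Sketch only: shows how the serialization aliases map snake_case fields
# to camelCase keys in the request payload.
from inference_cli.lib.roboflow_cloud.batch_processing.entities import (
    CompilationDevice,
    TRTCompilationJobV1,
)

job = TRTCompilationJobV1(
    type="trt-compilation-v1",
    model_id="my-workspace/my-model/3",  # hypothetical model ID
    compilation_devices=[CompilationDevice.NVIDIA_L4],
)
print(job.model_dump(mode="json", by_alias=True))
# Expected shape (camelCase keys come from the serialization aliases):
# {"type": "trt-compilation-v1", "modelId": "my-workspace/my-model/3",
#  "compilationDevices": ["nvidia-l4"], "notificationsURL": None}
```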