diff --git a/src/together/lib/cli/api/fine_tuning.py b/src/together/lib/cli/api/fine_tuning.py index 26ab89a5..25c6f63b 100644 --- a/src/together/lib/cli/api/fine_tuning.py +++ b/src/together/lib/cli/api/fine_tuning.py @@ -426,9 +426,7 @@ def list(ctx: click.Context) -> None: "Price": f"""${ finetune_price_to_dollars(float(str(i.total_price))) }""", # convert to string for mypy typing - "Progress": generate_progress_bar( - i, datetime.now().astimezone(), use_rich=False - ), + "Progress": generate_progress_bar(i, datetime.now().astimezone(), use_rich=False), } ) table = tabulate(display_list, headers="keys", tablefmt="grid", showindex=True) @@ -449,9 +447,7 @@ def retrieve(ctx: click.Context, fine_tune_id: str) -> None: response.events = None rprint(JSON.from_data(response.model_json_schema())) - progress_text = generate_progress_bar( - response, datetime.now().astimezone(), use_rich=True - ) + progress_text = generate_progress_bar(response, datetime.now().astimezone(), use_rich=True) prefix = f"Status: [bold]{response.status}[/bold]," rprint(f"{prefix} {progress_text}") diff --git a/src/together/lib/cli/api/utils.py b/src/together/lib/cli/api/utils.py index eabb4d07..3242b5c0 100644 --- a/src/together/lib/cli/api/utils.py +++ b/src/together/lib/cli/api/utils.py @@ -28,9 +28,7 @@ def convert( # pyright: ignore[reportImplicitOverride] return int(value) except ValueError: self.fail( - _("{value!r} is not a valid {number_type}.").format( - value=value, number_type=self.name - ), + _("{value!r} is not a valid {number_type}.").format(value=value, number_type=self.name), param, ctx, ) @@ -39,7 +37,7 @@ def convert( # pyright: ignore[reportImplicitOverride] class BooleanWithAutoParamType(click.ParamType): name = "boolean_or_auto" - def convert( # pyright: ignore[reportImplicitOverride] + def convert( # pyright: ignore[reportImplicitOverride] self, value: str, param: click.Parameter | None, ctx: click.Context | None ) -> bool | Literal["auto"] | None: if value == "auto": @@ -48,9 +46,7 @@ def convert( # pyright: ignore[reportImplicitOverride] return bool(value) except ValueError: self.fail( - _("{value!r} is not a valid {type}.").format( - value=value, type=self.name - ), + _("{value!r} is not a valid {type}.").format(value=value, type=self.name), param, ctx, ) @@ -119,17 +115,13 @@ def generate_progress_bar( return progress elapsed_time = (current_time - update_at).total_seconds() - ratio_filled = min( - elapsed_time / finetune_job.progress.seconds_remaining, 1.0 - ) + ratio_filled = min(elapsed_time / finetune_job.progress.seconds_remaining, 1.0) percentage = ratio_filled * 100 filled = math.ceil(ratio_filled * _PROGRESS_BAR_WIDTH) bar = "█" * filled + "░" * (_PROGRESS_BAR_WIDTH - filled) time_left = "N/A" if finetune_job.progress.seconds_remaining > elapsed_time: - time_left = _human_readable_time( - finetune_job.progress.seconds_remaining - elapsed_time - ) + time_left = _human_readable_time(finetune_job.progress.seconds_remaining - elapsed_time) time_text = f"{time_left} left" progress = f"Progress: {bar} [bold]{percentage:>3.0f}%[/bold] [yellow]{time_text}[/yellow]" diff --git a/src/together/lib/resources/fine_tuning.py b/src/together/lib/resources/fine_tuning.py index de96dcdf..9b44749e 100644 --- a/src/together/lib/resources/fine_tuning.py +++ b/src/together/lib/resources/fine_tuning.py @@ -238,7 +238,10 @@ def create_finetune_request( return finetune_request, training_type_pe, training_method_pe -def create_price_estimation_params(finetune_request: FinetuneRequest) -> tuple[pe_params.TrainingType, pe_params.TrainingMethod]: + +def create_price_estimation_params( + finetune_request: FinetuneRequest, +) -> tuple[pe_params.TrainingType, pe_params.TrainingMethod]: training_type_cls: pe_params.TrainingType if isinstance(finetune_request.training_type, FullTrainingType): training_type_cls = pe_params.TrainingTypeFullTrainingType( @@ -275,6 +278,7 @@ def create_price_estimation_params(finetune_request: FinetuneRequest) -> tuple[p return training_type_cls, training_method_cls + def get_model_limits(client: Together, model: str) -> FinetuneTrainingLimits: """ Requests training limits for a specific model diff --git a/src/together/lib/types/fine_tuning.py b/src/together/lib/types/fine_tuning.py index d3888857..96bfd5e3 100644 --- a/src/together/lib/types/fine_tuning.py +++ b/src/together/lib/types/fine_tuning.py @@ -189,6 +189,7 @@ class TrainingMethodUnknown(BaseModel): method: str + TrainingMethod: TypeAlias = Union[ TrainingMethodSFT, TrainingMethodDPO, @@ -249,6 +250,7 @@ class EmptyLRScheduler(BaseModel): lr_scheduler_type: Literal[""] lr_scheduler_args: None = None + class UnknownLRScheduler(BaseModel): """ Unknown learning rate scheduler diff --git a/src/together/resources/fine_tuning.py b/src/together/resources/fine_tuning.py index e7130bf8..841356b8 100644 --- a/src/together/resources/fine_tuning.py +++ b/src/together/resources/fine_tuning.py @@ -53,6 +53,7 @@ "Proceed at your own risk." ) + class FineTuningResource(SyncAPIResource): @cached_property def with_raw_response(self) -> FineTuningResourceWithRawResponse: @@ -232,7 +233,6 @@ def create( hf_output_repo_name=hf_output_repo_name, ) - price_estimation_result = self.estimate_price( training_file=training_file, from_checkpoint=from_checkpoint or Omit(), @@ -244,7 +244,6 @@ def create( training_method=training_method_cls, ) - if verbose: rprint( "Submitting a fine-tuning job with the following parameters:", @@ -254,7 +253,7 @@ def create( rprint( "[red]" + _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format( - price_estimation_result.estimated_total_price # pyright: ignore[reportPossiblyUnboundVariable] + price_estimation_result.estimated_total_price # pyright: ignore[reportPossiblyUnboundVariable] ) + "[/red]", ) @@ -764,7 +763,6 @@ async def create( hf_output_repo_name=hf_output_repo_name, ) - price_estimation_result = await self.estimate_price( training_file=training_file, from_checkpoint=from_checkpoint or Omit(), @@ -776,7 +774,6 @@ async def create( training_method=training_method_cls, ) - if verbose: rprint( "Submitting a fine-tuning job with the following parameters:", @@ -786,7 +783,7 @@ async def create( rprint( "[red]" + _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format( - price_estimation_result.estimated_total_price # pyright: ignore[reportPossiblyUnboundVariable] + price_estimation_result.estimated_total_price # pyright: ignore[reportPossiblyUnboundVariable] ) + "[/red]", ) diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py index 230fe34f..9d10dbe1 100644 --- a/tests/test_cli_utils.py +++ b/tests/test_cli_utils.py @@ -91,8 +91,7 @@ def test_progress_bar_at_midpoint(self): # 60 seconds elapsed / 60 seconds remaining = 1.0 ratio = 100% progress # 1.0 * 40 = 40 filled bars assert ( - result - == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]" + result == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]" ) def test_progress_bar_near_completion(self): @@ -107,8 +106,7 @@ def test_progress_bar_near_completion(self): # 300 seconds elapsed / 30 seconds remaining = 10.0 ratio = 1000% progress # 10.0 * 40 = 400, ceil(400) = 400, but width is 40 so all filled assert ( - result - == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]" + result == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]" ) def test_progress_bar_contains_rich_formatting(self): @@ -123,8 +121,7 @@ def test_progress_bar_contains_rich_formatting(self): # 30 seconds elapsed / 60 seconds remaining = 0.5 ratio = 50% progress # 0.5 * 40 = 20 filled bars assert ( - result - == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]" + result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]" ) @@ -140,9 +137,7 @@ def test_rich_formatting_removed_when_use_rich_false(self): result = generate_progress_bar(finetune_job, current_time, use_rich=False) - assert ( - result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left" - ) + assert result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left" def test_rich_formatting_preserved_when_use_rich_true(self): """Test that rich formatting tags are preserved when use_rich=True.""" @@ -154,16 +149,13 @@ def test_rich_formatting_preserved_when_use_rich_true(self): result = generate_progress_bar(finetune_job, current_time, use_rich=True) assert ( - result - == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]" + result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]" ) def test_completed_status_formatting_removed(self): """Test that completed status formatting is removed when use_rich=False.""" current_time = datetime(2024, 1, 1, 12, 0, 10, tzinfo=timezone.utc) - finetune_job = create_finetune_response( - status=FinetuneJobStatus.STATUS_COMPLETED, progress=None - ) + finetune_job = create_finetune_response(status=FinetuneJobStatus.STATUS_COMPLETED, progress=None) result = generate_progress_bar(finetune_job, current_time, use_rich=False) @@ -187,9 +179,7 @@ def test_rich_formatting_removed_at_completion(self): result = generate_progress_bar(finetune_job, current_time, use_rich=False) - assert ( - result == "Progress: ████████████████████████████████████████ 100% N/A left" - ) + assert result == "Progress: ████████████████████████████████████████ 100% N/A left" def test_default_behavior_strips_formatting(self): """Test that rich formatting is removed by default (use_rich not specified).""" @@ -200,9 +190,7 @@ def test_default_behavior_strips_formatting(self): result = generate_progress_bar(finetune_job, current_time) - assert ( - result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left" - ) + assert result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left" def test_content_consistency_between_modes(self): """Test that use_rich=True and use_rich=False have same content, just different formatting.""" @@ -213,12 +201,8 @@ def test_content_consistency_between_modes(self): progress=FinetuneProgress(estimate_available=True, seconds_remaining=60.0) ) - result_with_rich = generate_progress_bar( - finetune_job, current_time, use_rich=True - ) - result_without_rich = generate_progress_bar( - finetune_job, current_time, use_rich=False - ) + result_with_rich = generate_progress_bar(finetune_job, current_time, use_rich=True) + result_without_rich = generate_progress_bar(finetune_job, current_time, use_rich=False) stripped_rich = re.sub(r"\[/?[^\]]+\]", "", result_with_rich) assert stripped_rich == result_without_rich @@ -228,19 +212,13 @@ def test_all_rich_tag_types_removed(self): current_time = datetime(2024, 1, 1, 12, 0, 10, tzinfo=timezone.utc) # Test with completed status (has [bold green] tags) - completed_job = create_finetune_response( - status=FinetuneJobStatus.STATUS_COMPLETED, progress=None - ) - result_completed = generate_progress_bar( - completed_job, current_time, use_rich=False - ) + completed_job = create_finetune_response(status=FinetuneJobStatus.STATUS_COMPLETED, progress=None) + result_completed = generate_progress_bar(completed_job, current_time, use_rich=False) assert result_completed == "Progress: completed" # Test with unavailable status (has [bold red] tags) unavailable_job = create_finetune_response(progress=None) - result_unavailable = generate_progress_bar( - unavailable_job, current_time, use_rich=False - ) + result_unavailable = generate_progress_bar(unavailable_job, current_time, use_rich=False) assert result_unavailable == "Progress: unavailable" @pytest.mark.parametrize( @@ -265,9 +243,7 @@ def test_rich_parameter_with_different_statuses( current_time = datetime(2024, 1, 1, 12, 0, 10, tzinfo=timezone.utc) # Test completed status - completed_job = create_finetune_response( - status=FinetuneJobStatus.STATUS_COMPLETED, progress=None - ) + completed_job = create_finetune_response(status=FinetuneJobStatus.STATUS_COMPLETED, progress=None) result = generate_progress_bar(completed_job, current_time, use_rich=use_rich) assert result == expected_completed @@ -286,10 +262,7 @@ def test_progress_percentage_1_percent(self): ) result = generate_progress_bar(finetune_job, current_time, use_rich=False) - assert ( - result - == "Progress: █░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ 1% 16min 30s left" - ) + assert result == "Progress: █░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ 1% 16min 30s left" def test_progress_percentage_75_percent(self): """Test progress bar at 75% completion.""" @@ -299,9 +272,7 @@ def test_progress_percentage_75_percent(self): ) result = generate_progress_bar(finetune_job, current_time, use_rich=False) - assert ( - result == "Progress: ██████████████████████████████░░░░░░░░░░ 75% 15s left" - ) + assert result == "Progress: ██████████████████████████████░░░░░░░░░░ 75% 15s left" class TestGenerateProgressBarCornerCases: @@ -328,17 +299,14 @@ def test_very_small_remaining_time(self): result = generate_progress_bar(finetune_job, current_time, use_rich=True) assert ( - result - == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]" + result == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]" ) def test_very_large_remaining_time(self): """Test with very large remaining time (hours).""" current_time = datetime(2024, 1, 1, 12, 0, 30, tzinfo=timezone.utc) finetune_job = create_finetune_response( - progress=FinetuneProgress( - estimate_available=True, seconds_remaining=36000.0 - ) + progress=FinetuneProgress(estimate_available=True, seconds_remaining=36000.0) ) result = generate_progress_bar(finetune_job, current_time, use_rich=True) @@ -358,8 +326,7 @@ def test_job_exceeding_estimate(self): result = generate_progress_bar(finetune_job, current_time, use_rich=True) assert ( - result - == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]" + result == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]" ) def test_timezone_aware_datetime(self): @@ -373,8 +340,7 @@ def test_timezone_aware_datetime(self): result = generate_progress_bar(finetune_job, current_time, use_rich=True) assert ( - result - == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]" + result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]" ) def test_estimate_unavailable_flag(self): @@ -409,6 +375,4 @@ def test_unicode_progress_bars_preserved(self): result = generate_progress_bar(finetune_job, current_time, use_rich=False) - assert ( - result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left" - ) + assert result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left" diff --git a/uv.lock b/uv.lock index 0143d6bf..6fb90734 100644 --- a/uv.lock +++ b/uv.lock @@ -1963,7 +1963,7 @@ wheels = [ [[package]] name = "together" -version = "2.0.0a10" +version = "2.0.0a11" source = { editable = "." } dependencies = [ { name = "anyio" },