Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/together/lib/cli/api/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,7 @@ def list(ctx: click.Context) -> None:
"Price": f"""${
finetune_price_to_dollars(float(str(i.total_price)))
}""", # convert to string for mypy typing
"Progress": generate_progress_bar(
i, datetime.now().astimezone(), use_rich=False
),
"Progress": generate_progress_bar(i, datetime.now().astimezone(), use_rich=False),
}
)
table = tabulate(display_list, headers="keys", tablefmt="grid", showindex=True)
Expand All @@ -449,9 +447,7 @@ def retrieve(ctx: click.Context, fine_tune_id: str) -> None:
response.events = None

rprint(JSON.from_data(response.model_json_schema()))
progress_text = generate_progress_bar(
response, datetime.now().astimezone(), use_rich=True
)
progress_text = generate_progress_bar(response, datetime.now().astimezone(), use_rich=True)
prefix = f"Status: [bold]{response.status}[/bold],"
rprint(f"{prefix} {progress_text}")

Expand Down
18 changes: 5 additions & 13 deletions src/together/lib/cli/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ def convert( # pyright: ignore[reportImplicitOverride]
return int(value)
except ValueError:
self.fail(
_("{value!r} is not a valid {number_type}.").format(
value=value, number_type=self.name
),
_("{value!r} is not a valid {number_type}.").format(value=value, number_type=self.name),
param,
ctx,
)
Expand All @@ -39,7 +37,7 @@ def convert( # pyright: ignore[reportImplicitOverride]
class BooleanWithAutoParamType(click.ParamType):
name = "boolean_or_auto"

def convert( # pyright: ignore[reportImplicitOverride]
def convert( # pyright: ignore[reportImplicitOverride]
self, value: str, param: click.Parameter | None, ctx: click.Context | None
) -> bool | Literal["auto"] | None:
if value == "auto":
Expand All @@ -48,9 +46,7 @@ def convert( # pyright: ignore[reportImplicitOverride]
return bool(value)
except ValueError:
self.fail(
_("{value!r} is not a valid {type}.").format(
value=value, type=self.name
),
_("{value!r} is not a valid {type}.").format(value=value, type=self.name),
param,
ctx,
)
Expand Down Expand Up @@ -119,17 +115,13 @@ def generate_progress_bar(
return progress

elapsed_time = (current_time - update_at).total_seconds()
ratio_filled = min(
elapsed_time / finetune_job.progress.seconds_remaining, 1.0
)
ratio_filled = min(elapsed_time / finetune_job.progress.seconds_remaining, 1.0)
percentage = ratio_filled * 100
filled = math.ceil(ratio_filled * _PROGRESS_BAR_WIDTH)
bar = "█" * filled + "░" * (_PROGRESS_BAR_WIDTH - filled)
time_left = "N/A"
if finetune_job.progress.seconds_remaining > elapsed_time:
time_left = _human_readable_time(
finetune_job.progress.seconds_remaining - elapsed_time
)
time_left = _human_readable_time(finetune_job.progress.seconds_remaining - elapsed_time)
time_text = f"{time_left} left"
progress = f"Progress: {bar} [bold]{percentage:>3.0f}%[/bold] [yellow]{time_text}[/yellow]"

Expand Down
6 changes: 5 additions & 1 deletion src/together/lib/resources/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,10 @@ def create_finetune_request(

return finetune_request, training_type_pe, training_method_pe

def create_price_estimation_params(finetune_request: FinetuneRequest) -> tuple[pe_params.TrainingType, pe_params.TrainingMethod]:

def create_price_estimation_params(
finetune_request: FinetuneRequest,
) -> tuple[pe_params.TrainingType, pe_params.TrainingMethod]:
training_type_cls: pe_params.TrainingType
if isinstance(finetune_request.training_type, FullTrainingType):
training_type_cls = pe_params.TrainingTypeFullTrainingType(
Expand Down Expand Up @@ -275,6 +278,7 @@ def create_price_estimation_params(finetune_request: FinetuneRequest) -> tuple[p

return training_type_cls, training_method_cls


def get_model_limits(client: Together, model: str) -> FinetuneTrainingLimits:
"""
Requests training limits for a specific model
Expand Down
2 changes: 2 additions & 0 deletions src/together/lib/types/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ class TrainingMethodUnknown(BaseModel):

method: str


TrainingMethod: TypeAlias = Union[
TrainingMethodSFT,
TrainingMethodDPO,
Expand Down Expand Up @@ -249,6 +250,7 @@ class EmptyLRScheduler(BaseModel):
lr_scheduler_type: Literal[""]
lr_scheduler_args: None = None


class UnknownLRScheduler(BaseModel):
"""
Unknown learning rate scheduler
Expand Down
9 changes: 3 additions & 6 deletions src/together/resources/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"Proceed at your own risk."
)


class FineTuningResource(SyncAPIResource):
@cached_property
def with_raw_response(self) -> FineTuningResourceWithRawResponse:
Expand Down Expand Up @@ -232,7 +233,6 @@ def create(
hf_output_repo_name=hf_output_repo_name,
)


price_estimation_result = self.estimate_price(
training_file=training_file,
from_checkpoint=from_checkpoint or Omit(),
Expand All @@ -244,7 +244,6 @@ def create(
training_method=training_method_cls,
)


if verbose:
rprint(
"Submitting a fine-tuning job with the following parameters:",
Expand All @@ -254,7 +253,7 @@ def create(
rprint(
"[red]"
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
price_estimation_result.estimated_total_price # pyright: ignore[reportPossiblyUnboundVariable]
price_estimation_result.estimated_total_price # pyright: ignore[reportPossiblyUnboundVariable]
)
+ "[/red]",
)
Expand Down Expand Up @@ -764,7 +763,6 @@ async def create(
hf_output_repo_name=hf_output_repo_name,
)


price_estimation_result = await self.estimate_price(
training_file=training_file,
from_checkpoint=from_checkpoint or Omit(),
Expand All @@ -776,7 +774,6 @@ async def create(
training_method=training_method_cls,
)


if verbose:
rprint(
"Submitting a fine-tuning job with the following parameters:",
Expand All @@ -786,7 +783,7 @@ async def create(
rprint(
"[red]"
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
price_estimation_result.estimated_total_price # pyright: ignore[reportPossiblyUnboundVariable]
price_estimation_result.estimated_total_price # pyright: ignore[reportPossiblyUnboundVariable]
)
+ "[/red]",
)
Expand Down
78 changes: 21 additions & 57 deletions tests/test_cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ def test_progress_bar_at_midpoint(self):
# 60 seconds elapsed / 60 seconds remaining = 1.0 ratio = 100% progress
# 1.0 * 40 = 40 filled bars
assert (
result
== "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]"
result == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]"
)

def test_progress_bar_near_completion(self):
Expand All @@ -107,8 +106,7 @@ def test_progress_bar_near_completion(self):
# 300 seconds elapsed / 30 seconds remaining = 10.0 ratio = 1000% progress
# 10.0 * 40 = 400, ceil(400) = 400, but width is 40 so all filled
assert (
result
== "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]"
result == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]"
)

def test_progress_bar_contains_rich_formatting(self):
Expand All @@ -123,8 +121,7 @@ def test_progress_bar_contains_rich_formatting(self):
# 30 seconds elapsed / 60 seconds remaining = 0.5 ratio = 50% progress
# 0.5 * 40 = 20 filled bars
assert (
result
== "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]"
result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]"
)


Expand All @@ -140,9 +137,7 @@ def test_rich_formatting_removed_when_use_rich_false(self):

result = generate_progress_bar(finetune_job, current_time, use_rich=False)

assert (
result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left"
)
assert result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left"

def test_rich_formatting_preserved_when_use_rich_true(self):
"""Test that rich formatting tags are preserved when use_rich=True."""
Expand All @@ -154,16 +149,13 @@ def test_rich_formatting_preserved_when_use_rich_true(self):
result = generate_progress_bar(finetune_job, current_time, use_rich=True)

assert (
result
== "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]"
result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]"
)

def test_completed_status_formatting_removed(self):
"""Test that completed status formatting is removed when use_rich=False."""
current_time = datetime(2024, 1, 1, 12, 0, 10, tzinfo=timezone.utc)
finetune_job = create_finetune_response(
status=FinetuneJobStatus.STATUS_COMPLETED, progress=None
)
finetune_job = create_finetune_response(status=FinetuneJobStatus.STATUS_COMPLETED, progress=None)

result = generate_progress_bar(finetune_job, current_time, use_rich=False)

Expand All @@ -187,9 +179,7 @@ def test_rich_formatting_removed_at_completion(self):

result = generate_progress_bar(finetune_job, current_time, use_rich=False)

assert (
result == "Progress: ████████████████████████████████████████ 100% N/A left"
)
assert result == "Progress: ████████████████████████████████████████ 100% N/A left"

def test_default_behavior_strips_formatting(self):
"""Test that rich formatting is removed by default (use_rich not specified)."""
Expand All @@ -200,9 +190,7 @@ def test_default_behavior_strips_formatting(self):

result = generate_progress_bar(finetune_job, current_time)

assert (
result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left"
)
assert result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left"

def test_content_consistency_between_modes(self):
"""Test that use_rich=True and use_rich=False have same content, just different formatting."""
Expand All @@ -213,12 +201,8 @@ def test_content_consistency_between_modes(self):
progress=FinetuneProgress(estimate_available=True, seconds_remaining=60.0)
)

result_with_rich = generate_progress_bar(
finetune_job, current_time, use_rich=True
)
result_without_rich = generate_progress_bar(
finetune_job, current_time, use_rich=False
)
result_with_rich = generate_progress_bar(finetune_job, current_time, use_rich=True)
result_without_rich = generate_progress_bar(finetune_job, current_time, use_rich=False)

stripped_rich = re.sub(r"\[/?[^\]]+\]", "", result_with_rich)
assert stripped_rich == result_without_rich
Expand All @@ -228,19 +212,13 @@ def test_all_rich_tag_types_removed(self):
current_time = datetime(2024, 1, 1, 12, 0, 10, tzinfo=timezone.utc)

# Test with completed status (has [bold green] tags)
completed_job = create_finetune_response(
status=FinetuneJobStatus.STATUS_COMPLETED, progress=None
)
result_completed = generate_progress_bar(
completed_job, current_time, use_rich=False
)
completed_job = create_finetune_response(status=FinetuneJobStatus.STATUS_COMPLETED, progress=None)
result_completed = generate_progress_bar(completed_job, current_time, use_rich=False)
assert result_completed == "Progress: completed"

# Test with unavailable status (has [bold red] tags)
unavailable_job = create_finetune_response(progress=None)
result_unavailable = generate_progress_bar(
unavailable_job, current_time, use_rich=False
)
result_unavailable = generate_progress_bar(unavailable_job, current_time, use_rich=False)
assert result_unavailable == "Progress: unavailable"

@pytest.mark.parametrize(
Expand All @@ -265,9 +243,7 @@ def test_rich_parameter_with_different_statuses(
current_time = datetime(2024, 1, 1, 12, 0, 10, tzinfo=timezone.utc)

# Test completed status
completed_job = create_finetune_response(
status=FinetuneJobStatus.STATUS_COMPLETED, progress=None
)
completed_job = create_finetune_response(status=FinetuneJobStatus.STATUS_COMPLETED, progress=None)
result = generate_progress_bar(completed_job, current_time, use_rich=use_rich)
assert result == expected_completed

Expand All @@ -286,10 +262,7 @@ def test_progress_percentage_1_percent(self):
)

result = generate_progress_bar(finetune_job, current_time, use_rich=False)
assert (
result
== "Progress: █░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ 1% 16min 30s left"
)
assert result == "Progress: █░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ 1% 16min 30s left"

def test_progress_percentage_75_percent(self):
"""Test progress bar at 75% completion."""
Expand All @@ -299,9 +272,7 @@ def test_progress_percentage_75_percent(self):
)

result = generate_progress_bar(finetune_job, current_time, use_rich=False)
assert (
result == "Progress: ██████████████████████████████░░░░░░░░░░ 75% 15s left"
)
assert result == "Progress: ██████████████████████████████░░░░░░░░░░ 75% 15s left"


class TestGenerateProgressBarCornerCases:
Expand All @@ -328,17 +299,14 @@ def test_very_small_remaining_time(self):
result = generate_progress_bar(finetune_job, current_time, use_rich=True)

assert (
result
== "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]"
result == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]"
)

def test_very_large_remaining_time(self):
"""Test with very large remaining time (hours)."""
current_time = datetime(2024, 1, 1, 12, 0, 30, tzinfo=timezone.utc)
finetune_job = create_finetune_response(
progress=FinetuneProgress(
estimate_available=True, seconds_remaining=36000.0
)
progress=FinetuneProgress(estimate_available=True, seconds_remaining=36000.0)
)

result = generate_progress_bar(finetune_job, current_time, use_rich=True)
Expand All @@ -358,8 +326,7 @@ def test_job_exceeding_estimate(self):
result = generate_progress_bar(finetune_job, current_time, use_rich=True)

assert (
result
== "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]"
result == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]"
)

def test_timezone_aware_datetime(self):
Expand All @@ -373,8 +340,7 @@ def test_timezone_aware_datetime(self):
result = generate_progress_bar(finetune_job, current_time, use_rich=True)

assert (
result
== "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]"
result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]"
)

def test_estimate_unavailable_flag(self):
Expand Down Expand Up @@ -409,6 +375,4 @@ def test_unicode_progress_bars_preserved(self):

result = generate_progress_bar(finetune_job, current_time, use_rich=False)

assert (
result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left"
)
assert result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left"
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.