Skip to content

Commit e465a0e

Browse files
feat: add price estimation (#400)
* add price estimation * comments from the review * code style * add combined * address None and comments from review * Update test_finetune_resources.py * Update test_finetune_resources.py * fix from_checkpoint and hf_model * comments from code review * Apply suggestions from code review Co-authored-by: Max Ryabinin <[email protected]> --------- Co-authored-by: Max Ryabinin <[email protected]>
1 parent bcb672b commit e465a0e

File tree

5 files changed

+389
-5
lines changed

5 files changed

+389
-5
lines changed

src/together/cli/api/finetune.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
DownloadCheckpointType,
1818
FinetuneEventType,
1919
FinetuneTrainingLimits,
20+
FullTrainingType,
21+
LoRATrainingType,
2022
)
2123
from together.utils import (
2224
finetune_price_to_dollars,
@@ -29,13 +31,21 @@
2931

3032
_CONFIRMATION_MESSAGE = (
3133
"You are about to create a fine-tuning job. "
32-
"The cost of your job will be determined by the model size, the number of tokens "
34+
"The estimated price of this job is {price}. "
35+
"The actual cost of your job will be determined by the model size, the number of tokens "
3336
"in the training file, the number of tokens in the validation file, the number of epochs, and "
34-
"the number of evaluations. Visit https://www.together.ai/pricing to get a price estimate.\n"
37+
"the number of evaluations. Visit https://www.together.ai/pricing to learn more about fine-tuning pricing.\n"
38+
"{warning}"
3539
"You can pass `-y` or `--confirm` to your command to skip this message.\n\n"
3640
"Do you want to proceed?"
3741
)
3842

43+
_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
44+
"The estimated price of this job is significantly greater than your current credit limit and balance combined. "
45+
"It will likely get cancelled due to insufficient funds. "
46+
"Consider increasing your credit limit at https://api.together.xyz/settings/profile\n"
47+
)
48+
3949

4050
class DownloadCheckpointTypeChoice(click.Choice):
4151
def __init__(self) -> None:
@@ -357,12 +367,36 @@ def create(
357367
"You have specified a number of evaluation loops but no validation file."
358368
)
359369

360-
if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
370+
finetune_price_estimation_result = client.fine_tuning.estimate_price(
371+
training_file=training_file,
372+
validation_file=validation_file,
373+
model=model,
374+
n_epochs=n_epochs,
375+
n_evals=n_evals,
376+
training_type="lora" if lora else "full",
377+
training_method=training_method,
378+
)
379+
380+
price = click.style(
381+
f"${finetune_price_estimation_result.estimated_total_price:.2f}",
382+
bold=True,
383+
)
384+
385+
if not finetune_price_estimation_result.allowed_to_proceed:
386+
warning = click.style(_WARNING_MESSAGE_INSUFFICIENT_FUNDS, fg="red", bold=True)
387+
else:
388+
warning = ""
389+
390+
confirmation_message = _CONFIRMATION_MESSAGE.format(
391+
price=price,
392+
warning=warning,
393+
)
394+
395+
if confirm or click.confirm(confirmation_message, default=True, show_default=True):
361396
response = client.fine_tuning.create(
362397
**training_args,
363398
verbose=True,
364399
)
365-
366400
report_string = f"Successfully submitted a fine-tuning job {response.id}"
367401
if response.created_at is not None:
368402
created_time = datetime.strptime(

src/together/resources/finetune.py

Lines changed: 204 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
FinetuneLRScheduler,
2121
FinetuneRequest,
2222
FinetuneResponse,
23+
FinetunePriceEstimationRequest,
24+
FinetunePriceEstimationResponse,
2325
FinetuneTrainingLimits,
2426
FullTrainingType,
2527
LinearLRScheduler,
@@ -31,7 +33,7 @@
3133
TrainingMethodSFT,
3234
TrainingType,
3335
)
34-
from together.types.finetune import DownloadCheckpointType
36+
from together.types.finetune import DownloadCheckpointType, TrainingMethod
3537
from together.utils import log_warn_once, normalize_key
3638

3739

@@ -42,6 +44,12 @@
4244
TrainingMethodSFT().method,
4345
TrainingMethodDPO().method,
4446
}
47+
_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
48+
"The estimated price of the fine-tuning job is {} which is significantly "
49+
"greater than your current credit limit and balance combined. "
50+
"It will likely get cancelled due to insufficient funds. "
51+
"Proceed at your own risk."
52+
)
4553

4654

4755
def create_finetune_request(
@@ -473,12 +481,34 @@ def create(
473481
hf_api_token=hf_api_token,
474482
hf_output_repo_name=hf_output_repo_name,
475483
)
484+
if from_checkpoint is None and from_hf_model is None:
485+
price_estimation_result = self.estimate_price(
486+
training_file=training_file,
487+
validation_file=validation_file,
488+
model=model_name,
489+
n_epochs=finetune_request.n_epochs,
490+
n_evals=finetune_request.n_evals,
491+
training_type="lora" if lora else "full",
492+
training_method=training_method,
493+
)
494+
price_limit_passed = price_estimation_result.allowed_to_proceed
495+
else:
496+
# unsupported case
497+
price_limit_passed = True
476498

477499
if verbose:
478500
rprint(
479501
"Submitting a fine-tuning job with the following parameters:",
480502
finetune_request,
481503
)
504+
if not price_limit_passed:
505+
rprint(
506+
"[red]"
507+
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
508+
price_estimation_result.estimated_total_price
509+
)
510+
+ "[/red]",
511+
)
482512
parameter_payload = finetune_request.model_dump(exclude_none=True)
483513

484514
response, _, _ = requestor.request(
@@ -493,6 +523,81 @@ def create(
493523

494524
return FinetuneResponse(**response.data)
495525

526+
def estimate_price(
    self,
    *,
    training_file: str,
    model: str,
    validation_file: str | None = None,
    n_epochs: int | None = 1,
    n_evals: int | None = 0,
    training_type: str = "lora",
    training_method: str = "sft",
) -> FinetunePriceEstimationResponse:
    """
    Estimates the price of a fine-tuning job

    Args:
        training_file (str): File-ID of a file uploaded to the Together API
        model (str): Name of the base model to run fine-tune job on
        validation_file (str, optional): File ID of a file uploaded to the Together API for validation.
        n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
        n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
        training_type (str, optional): Training type. Defaults to "lora".
        training_method (str, optional): Training method. Defaults to "sft".

    Returns:
        FinetunePriceEstimationResponse: Object containing the price estimation result.

    Raises:
        ValueError: If `training_method` or `training_type` is not recognized.
    """
    # Resolve the training-method string into its request-schema object.
    resolved_method: TrainingMethod
    if training_method == "sft":
        resolved_method = TrainingMethodSFT(method="sft")
    elif training_method == "dpo":
        resolved_method = TrainingMethodDPO(method="dpo")
    else:
        raise ValueError(f"Unknown training method: {training_method}")

    # Resolve the training-type string (case-insensitively) into its request-schema object.
    resolved_type: TrainingType
    normalized_type = training_type.lower()
    if normalized_type == "lora":
        # LoRA hyperparameters are ignored by price estimation,
        # but the request schema requires valid values.
        resolved_type = LoRATrainingType(
            type="Lora",
            lora_r=16,
            lora_alpha=16,
            lora_dropout=0.0,
            lora_trainable_modules="all-linear",
        )
    elif normalized_type == "full":
        resolved_type = FullTrainingType(type="Full")
    else:
        raise ValueError(f"Unknown training type: {training_type}")

    # Build the request payload, dropping unset optional fields.
    payload = FinetunePriceEstimationRequest(
        training_file=training_file,
        validation_file=validation_file,
        model=model,
        n_epochs=n_epochs,
        n_evals=n_evals,
        training_type=resolved_type,
        training_method=resolved_method,
    ).model_dump(exclude_none=True)

    requestor = api_requestor.APIRequestor(client=self._client)
    response, _, _ = requestor.request(
        options=TogetherRequest(
            method="POST", url="fine-tunes/estimate-price", params=payload
        ),
        stream=False,
    )
    assert isinstance(response, TogetherResponse)

    return FinetunePriceEstimationResponse(**response.data)
600+
496601
def list(self) -> FinetuneList:
497602
"""
498603
Lists fine-tune job history
@@ -941,11 +1046,34 @@ async def create(
9411046
hf_output_repo_name=hf_output_repo_name,
9421047
)
9431048

1049+
if from_checkpoint is None and from_hf_model is None:
1050+
price_estimation_result = await self.estimate_price(
1051+
training_file=training_file,
1052+
validation_file=validation_file,
1053+
model=model_name,
1054+
n_epochs=finetune_request.n_epochs,
1055+
n_evals=finetune_request.n_evals,
1056+
training_type="lora" if lora else "full",
1057+
training_method=training_method,
1058+
)
1059+
price_limit_passed = price_estimation_result.allowed_to_proceed
1060+
else:
1061+
# unsupported case
1062+
price_limit_passed = True
1063+
9441064
if verbose:
9451065
rprint(
9461066
"Submitting a fine-tuning job with the following parameters:",
9471067
finetune_request,
9481068
)
1069+
if not price_limit_passed:
1070+
rprint(
1071+
"[red]"
1072+
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
1073+
price_estimation_result.estimated_total_price
1074+
)
1075+
+ "[/red]",
1076+
)
9491077
parameter_payload = finetune_request.model_dump(exclude_none=True)
9501078

9511079
response, _, _ = await requestor.arequest(
@@ -961,6 +1089,81 @@ async def create(
9611089

9621090
return FinetuneResponse(**response.data)
9631091

1092+
async def estimate_price(
    self,
    *,
    training_file: str,
    model: str,
    validation_file: str | None = None,
    n_epochs: int | None = 1,
    n_evals: int | None = 0,
    training_type: str = "lora",
    training_method: str = "sft",
) -> FinetunePriceEstimationResponse:
    """
    Estimates the price of a fine-tuning job

    Args:
        training_file (str): File-ID of a file uploaded to the Together API
        model (str): Name of the base model to run fine-tune job on
        validation_file (str, optional): File ID of a file uploaded to the Together API for validation.
        n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
        n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
        training_type (str, optional): Training type. Defaults to "lora".
        training_method (str, optional): Training method. Defaults to "sft".

    Returns:
        FinetunePriceEstimationResponse: Object containing the price estimation result.

    Raises:
        ValueError: If `training_method` or `training_type` is not recognized.
    """
    # Resolve the training-method string into its request-schema object.
    resolved_method: TrainingMethod
    if training_method == "sft":
        resolved_method = TrainingMethodSFT(method="sft")
    elif training_method == "dpo":
        resolved_method = TrainingMethodDPO(method="dpo")
    else:
        raise ValueError(f"Unknown training method: {training_method}")

    # Resolve the training-type string (case-insensitively) into its request-schema object.
    resolved_type: TrainingType
    normalized_type = training_type.lower()
    if normalized_type == "lora":
        # LoRA hyperparameters are ignored by price estimation,
        # but the request schema requires valid values.
        resolved_type = LoRATrainingType(
            type="Lora",
            lora_r=16,
            lora_alpha=16,
            lora_dropout=0.0,
            lora_trainable_modules="all-linear",
        )
    elif normalized_type == "full":
        resolved_type = FullTrainingType(type="Full")
    else:
        raise ValueError(f"Unknown training type: {training_type}")

    # Build the request payload, dropping unset optional fields.
    payload = FinetunePriceEstimationRequest(
        training_file=training_file,
        validation_file=validation_file,
        model=model,
        n_epochs=n_epochs,
        n_evals=n_evals,
        training_type=resolved_type,
        training_method=resolved_method,
    ).model_dump(exclude_none=True)

    requestor = api_requestor.APIRequestor(client=self._client)
    response, _, _ = await requestor.arequest(
        options=TogetherRequest(
            method="POST", url="fine-tunes/estimate-price", params=payload
        ),
        stream=False,
    )
    assert isinstance(response, TogetherResponse)

    return FinetunePriceEstimationResponse(**response.data)
1166+
9641167
async def list(self) -> FinetuneList:
9651168
"""
9661169
Async method to list fine-tune job history

src/together/types/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
FinetuneListEvents,
5555
FinetuneRequest,
5656
FinetuneResponse,
57+
FinetunePriceEstimationRequest,
58+
FinetunePriceEstimationResponse,
5759
FinetuneDeleteResponse,
5860
FinetuneTrainingLimits,
5961
FullTrainingType,
@@ -103,6 +105,8 @@
103105
"FinetuneDeleteResponse",
104106
"FinetuneDownloadResult",
105107
"FinetuneLRScheduler",
108+
"FinetunePriceEstimationRequest",
109+
"FinetunePriceEstimationResponse",
106110
"LinearLRScheduler",
107111
"LinearLRSchedulerArgs",
108112
"CosineLRScheduler",

src/together/types/finetune.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,32 @@ def validate_training_type(cls, v: TrainingType) -> TrainingType:
308308
raise ValueError("Unknown training type")
309309

310310

311+
class FinetunePriceEstimationRequest(BaseModel):
    """
    Fine-tune price estimation request type

    Payload for POST fine-tunes/estimate-price; mirrors the subset of the
    fine-tune creation parameters that affect pricing.
    """

    # File-ID of a file uploaded to the Together API with the training data
    training_file: str
    # Optional File-ID of an uploaded validation file
    validation_file: str | None = None
    # Name of the base model to fine-tune
    model: str
    # Number of fine-tuning epochs
    n_epochs: int
    # Number of evaluation loops to run
    n_evals: int
    # Training configuration (full or LoRA); LoRA hyperparameters are not
    # used for pricing but must be schema-valid
    training_type: TrainingType
    # Training method configuration (SFT or DPO)
    training_method: TrainingMethod
323+
324+
325+
class FinetunePriceEstimationResponse(BaseModel):
    """
    Fine-tune price estimation response type
    """

    # Estimated total price of the job (rendered as a dollar amount by callers)
    estimated_total_price: float
    # The user's spending limit — presumably credit limit and balance combined;
    # TODO confirm against the API docs
    user_limit: float
    # Estimated number of tokens processed during training
    estimated_train_token_count: int
    # Estimated number of tokens processed during evaluation
    estimated_eval_token_count: int
    # Whether the estimated price is within the user's limit; callers warn
    # about likely cancellation due to insufficient funds when False
    allowed_to_proceed: bool
335+
336+
311337
class FinetuneList(BaseModel):
312338
# object type
313339
object: Literal["list"] | None = None

0 commit comments

Comments
 (0)