2020 FinetuneLRScheduler ,
2121 FinetuneRequest ,
2222 FinetuneResponse ,
23+ FinetunePriceEstimationRequest ,
24+ FinetunePriceEstimationResponse ,
2325 FinetuneTrainingLimits ,
2426 FullTrainingType ,
2527 LinearLRScheduler ,
3133 TrainingMethodSFT ,
3234 TrainingType ,
3335)
34- from together .types .finetune import DownloadCheckpointType
36+ from together .types .finetune import DownloadCheckpointType , TrainingMethod
3537from together .utils import log_warn_once , normalize_key
3638
3739
4244 TrainingMethodSFT ().method ,
4345 TrainingMethodDPO ().method ,
4446}
# Warning (not an error) printed when a job's estimated price exceeds the
# user's combined credit limit and balance; `{}` is filled with the
# estimated total price. Submission still proceeds after the warning.
_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
    "The estimated price of the fine-tuning job is {} which is significantly "
    "greater than your current credit limit and balance combined. "
    "It will likely get cancelled due to insufficient funds. "
    "Proceed at your own risk."
)
4553
4654
4755def create_finetune_request (
@@ -473,12 +481,34 @@ def create(
473481 hf_api_token = hf_api_token ,
474482 hf_output_repo_name = hf_output_repo_name ,
475483 )
484+ if from_checkpoint is None and from_hf_model is None :
485+ price_estimation_result = self .estimate_price (
486+ training_file = training_file ,
487+ validation_file = validation_file ,
488+ model = model_name ,
489+ n_epochs = finetune_request .n_epochs ,
490+ n_evals = finetune_request .n_evals ,
491+ training_type = "lora" if lora else "full" ,
492+ training_method = training_method ,
493+ )
494+ price_limit_passed = price_estimation_result .allowed_to_proceed
495+ else :
496+ # unsupported case
497+ price_limit_passed = True
476498
477499 if verbose :
478500 rprint (
479501 "Submitting a fine-tuning job with the following parameters:" ,
480502 finetune_request ,
481503 )
504+ if not price_limit_passed :
505+ rprint (
506+ "[red]"
507+ + _WARNING_MESSAGE_INSUFFICIENT_FUNDS .format (
508+ price_estimation_result .estimated_total_price
509+ )
510+ + "[/red]" ,
511+ )
482512 parameter_payload = finetune_request .model_dump (exclude_none = True )
483513
484514 response , _ , _ = requestor .request (
@@ -493,6 +523,81 @@ def create(
493523
494524 return FinetuneResponse (** response .data )
495525
def estimate_price(
    self,
    *,
    training_file: str,
    model: str,
    validation_file: str | None = None,
    n_epochs: int | None = 1,
    n_evals: int | None = 0,
    training_type: str = "lora",
    training_method: str = "sft",
) -> FinetunePriceEstimationResponse:
    """
    Estimates the price of a fine-tuning job

    Args:
        training_file (str): File-ID of a file uploaded to the Together API
        model (str): Name of the base model to run fine-tune job on
        validation_file (str, optional): File ID of a file uploaded to the Together API for validation.
        n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
        n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
        training_type (str, optional): Training type, "lora" or "full" (case-insensitive).
            Defaults to "lora".
        training_method (str, optional): Training method, "sft" or "dpo" (case-insensitive).
            Defaults to "sft".

    Raises:
        ValueError: If `training_method` or `training_type` is not recognized.

    Returns:
        FinetunePriceEstimationResponse: Object containing the price estimation result.
    """
    training_type_cls: TrainingType
    training_method_cls: TrainingMethod

    # Normalize once so both selectors are matched case-insensitively,
    # consistent with each other.
    method_normalized = training_method.lower()
    type_normalized = training_type.lower()

    if method_normalized == "sft":
        training_method_cls = TrainingMethodSFT(method="sft")
    elif method_normalized == "dpo":
        training_method_cls = TrainingMethodDPO(method="dpo")
    else:
        raise ValueError(f"Unknown training method: {training_method}")

    if type_normalized == "lora":
        # parameters of lora are unused in price estimation
        # but we need to set them to valid values
        training_type_cls = LoRATrainingType(
            type="Lora",
            lora_r=16,
            lora_alpha=16,
            lora_dropout=0.0,
            lora_trainable_modules="all-linear",
        )
    elif type_normalized == "full":
        training_type_cls = FullTrainingType(type="Full")
    else:
        raise ValueError(f"Unknown training type: {training_type}")

    request = FinetunePriceEstimationRequest(
        training_file=training_file,
        validation_file=validation_file,
        model=model,
        n_epochs=n_epochs,
        n_evals=n_evals,
        training_type=training_type_cls,
        training_method=training_method_cls,
    )
    # exclude_none drops unset optional fields (e.g. validation_file)
    # from the request payload.
    parameter_payload = request.model_dump(exclude_none=True)
    requestor = api_requestor.APIRequestor(
        client=self._client,
    )

    response, _, _ = requestor.request(
        options=TogetherRequest(
            method="POST", url="fine-tunes/estimate-price", params=parameter_payload
        ),
        stream=False,
    )
    assert isinstance(response, TogetherResponse)

    return FinetunePriceEstimationResponse(**response.data)
600+
496601 def list (self ) -> FinetuneList :
497602 """
498603 Lists fine-tune job history
@@ -941,11 +1046,34 @@ async def create(
9411046 hf_output_repo_name = hf_output_repo_name ,
9421047 )
9431048
1049+ if from_checkpoint is None and from_hf_model is None :
1050+ price_estimation_result = await self .estimate_price (
1051+ training_file = training_file ,
1052+ validation_file = validation_file ,
1053+ model = model_name ,
1054+ n_epochs = finetune_request .n_epochs ,
1055+ n_evals = finetune_request .n_evals ,
1056+ training_type = "lora" if lora else "full" ,
1057+ training_method = training_method ,
1058+ )
1059+ price_limit_passed = price_estimation_result .allowed_to_proceed
1060+ else :
1061+ # unsupported case
1062+ price_limit_passed = True
1063+
9441064 if verbose :
9451065 rprint (
9461066 "Submitting a fine-tuning job with the following parameters:" ,
9471067 finetune_request ,
9481068 )
1069+ if not price_limit_passed :
1070+ rprint (
1071+ "[red]"
1072+ + _WARNING_MESSAGE_INSUFFICIENT_FUNDS .format (
1073+ price_estimation_result .estimated_total_price
1074+ )
1075+ + "[/red]" ,
1076+ )
9491077 parameter_payload = finetune_request .model_dump (exclude_none = True )
9501078
9511079 response , _ , _ = await requestor .arequest (
@@ -961,6 +1089,81 @@ async def create(
9611089
9621090 return FinetuneResponse (** response .data )
9631091
async def estimate_price(
    self,
    *,
    training_file: str,
    model: str,
    validation_file: str | None = None,
    n_epochs: int | None = 1,
    n_evals: int | None = 0,
    training_type: str = "lora",
    training_method: str = "sft",
) -> FinetunePriceEstimationResponse:
    """
    Async method to estimate the price of a fine-tuning job

    Args:
        training_file (str): File-ID of a file uploaded to the Together API
        model (str): Name of the base model to run fine-tune job on
        validation_file (str, optional): File ID of a file uploaded to the Together API for validation.
        n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
        n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
        training_type (str, optional): Training type, "lora" or "full" (case-insensitive).
            Defaults to "lora".
        training_method (str, optional): Training method, "sft" or "dpo" (case-insensitive).
            Defaults to "sft".

    Raises:
        ValueError: If `training_method` or `training_type` is not recognized.

    Returns:
        FinetunePriceEstimationResponse: Object containing the price estimation result.
    """
    training_type_cls: TrainingType
    training_method_cls: TrainingMethod

    # Normalize once so both selectors are matched case-insensitively,
    # consistent with each other.
    method_normalized = training_method.lower()
    type_normalized = training_type.lower()

    if method_normalized == "sft":
        training_method_cls = TrainingMethodSFT(method="sft")
    elif method_normalized == "dpo":
        training_method_cls = TrainingMethodDPO(method="dpo")
    else:
        raise ValueError(f"Unknown training method: {training_method}")

    if type_normalized == "lora":
        # parameters of lora are unused in price estimation
        # but we need to set them to valid values
        training_type_cls = LoRATrainingType(
            type="Lora",
            lora_r=16,
            lora_alpha=16,
            lora_dropout=0.0,
            lora_trainable_modules="all-linear",
        )
    elif type_normalized == "full":
        training_type_cls = FullTrainingType(type="Full")
    else:
        raise ValueError(f"Unknown training type: {training_type}")

    request = FinetunePriceEstimationRequest(
        training_file=training_file,
        validation_file=validation_file,
        model=model,
        n_epochs=n_epochs,
        n_evals=n_evals,
        training_type=training_type_cls,
        training_method=training_method_cls,
    )
    # exclude_none drops unset optional fields (e.g. validation_file)
    # from the request payload.
    parameter_payload = request.model_dump(exclude_none=True)
    requestor = api_requestor.APIRequestor(
        client=self._client,
    )

    response, _, _ = await requestor.arequest(
        options=TogetherRequest(
            method="POST", url="fine-tunes/estimate-price", params=parameter_payload
        ),
        stream=False,
    )
    assert isinstance(response, TogetherResponse)

    return FinetunePriceEstimationResponse(**response.data)
1166+
9641167 async def list (self ) -> FinetuneList :
9651168 """
9661169 Async method to list fine-tune job history
0 commit comments