|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import math |
3 | 4 | import os |
4 | 5 | import shutil |
5 | 6 | import stat |
6 | 7 | import tempfile |
7 | 8 | import uuid |
| 9 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
8 | 10 | from functools import partial |
9 | 11 | from pathlib import Path |
10 | | -from typing import Tuple |
| 12 | +from typing import Any, Dict, List, Tuple |
11 | 13 |
|
12 | 14 | import requests |
13 | 15 | from filelock import FileLock |
14 | 16 | from requests.structures import CaseInsensitiveDict |
15 | 17 | from tqdm import tqdm |
16 | | -from tqdm.utils import CallbackIOWrapper |
17 | 18 |
|
18 | | -import together.utils |
19 | 19 | from together.abstract import api_requestor |
20 | | -from together.constants import DISABLE_TQDM, DOWNLOAD_BLOCK_SIZE, MAX_RETRIES |
| 20 | +from together.constants import ( |
| 21 | + DISABLE_TQDM, |
| 22 | + DOWNLOAD_BLOCK_SIZE, |
| 23 | + MAX_CONCURRENT_PARTS, |
| 24 | + MAX_FILE_SIZE_GB, |
| 25 | + MAX_RETRIES, |
| 26 | + MIN_PART_SIZE_MB, |
| 27 | + NUM_BYTES_IN_GB, |
| 28 | + TARGET_PART_SIZE_MB, |
| 29 | + MAX_MULTIPART_PARTS, |
| 30 | + MULTIPART_UPLOAD_TIMEOUT, |
| 31 | +) |
21 | 32 | from together.error import ( |
22 | 33 | APIError, |
23 | 34 | AuthenticationError, |
24 | 35 | DownloadError, |
25 | 36 | FileTypeError, |
| 37 | + ResponseError, |
26 | 38 | ) |
27 | 39 | from together.together_response import TogetherResponse |
28 | 40 | from together.types import ( |
|
32 | 44 | TogetherClient, |
33 | 45 | TogetherRequest, |
34 | 46 | ) |
| 47 | +from tqdm.utils import CallbackIOWrapper |
| 48 | +import together.utils |
35 | 49 |
|
36 | 50 |
|
37 | 51 | def chmod_and_replace(src: Path, dst: Path) -> None: |
@@ -339,7 +353,7 @@ def upload( |
339 | 353 | ) |
340 | 354 | redirect_url, file_id = self.get_upload_url(url, file, purpose, filetype) |
341 | 355 |
|
342 | | - file_size = os.stat(file.as_posix()).st_size |
| 356 | + file_size = os.stat(file).st_size |
343 | 357 |
|
344 | 358 | with tqdm( |
345 | 359 | total=file_size, |
@@ -385,3 +399,214 @@ def upload( |
385 | 399 | assert isinstance(response, TogetherResponse) |
386 | 400 |
|
387 | 401 | return FileResponse(**response.data) |
| 402 | + |
| 403 | + |
class MultipartUploadManager:
    """Handles multipart uploads for large files.

    Flow: ``upload()`` validates the file, asks the backend to initiate a
    multipart upload (which returns one presigned URL per part), uploads the
    parts concurrently, then completes the upload — aborting it server-side
    if anything fails along the way.
    """

    def __init__(self, client: TogetherClient) -> None:
        self._client = client
        # Upper bound on simultaneously in-flight part uploads.
        self.max_concurrent_parts = MAX_CONCURRENT_PARTS

    def upload(
        self,
        url: str,
        file: Path,
        purpose: FilePurpose,
    ) -> FileResponse:
        """Upload a large file using multipart upload.

        Args:
            url: Files endpoint base URL (forwarded to the per-step helpers).
            file: Path of the local file to upload.
            purpose: Purpose tag recorded with the uploaded file.

        Returns:
            ``FileResponse`` describing the uploaded file.

        Raises:
            FileTypeError: If the file exceeds ``MAX_FILE_SIZE_GB``.
            ValueError: If the file extension is unsupported.
        """
        file_size = os.stat(file).st_size

        file_size_gb = file_size / NUM_BYTES_IN_GB
        if file_size_gb > MAX_FILE_SIZE_GB:
            raise FileTypeError(
                f"File size {file_size_gb:.1f}GB exceeds maximum supported size of {MAX_FILE_SIZE_GB}GB"
            )

        # Validate the extension before contacting the backend.
        file_type = self._get_file_type(file)
        part_size, num_parts = self._calculate_parts(file_size)

        upload_info = None
        try:
            upload_info = self._initiate_upload(
                url, file, file_size, num_parts, purpose, file_type
            )

            completed_parts = self._upload_parts_concurrent(
                file, upload_info, part_size
            )

            return self._complete_upload(
                url, upload_info["upload_id"], upload_info["file_id"], completed_parts
            )

        except Exception:
            # Best-effort server-side cleanup. An abort failure must never
            # mask the original error, so it is swallowed deliberately.
            if upload_info is not None:
                try:
                    self._abort_upload(
                        url, upload_info["upload_id"], upload_info["file_id"]
                    )
                except Exception:
                    pass
            # Bare `raise` preserves the original exception and traceback.
            raise

    def _get_file_type(self, file: Path) -> str:
        """Get file type from extension, raising ValueError for unsupported extensions"""
        extension_to_type = {
            ".jsonl": "jsonl",
            ".parquet": "parquet",
            ".csv": "csv",
        }
        try:
            return extension_to_type[file.suffix]
        except KeyError:
            raise ValueError(
                f"Unsupported file extension: '{file.suffix}'. "
                f"Supported extensions: .jsonl, .parquet, .csv"
            ) from None

    def _calculate_parts(self, file_size: int) -> Tuple[int, int]:
        """Calculate optimal part size and count.

        Aims for parts of ``TARGET_PART_SIZE_MB``, capped at
        ``MAX_MULTIPART_PARTS`` parts, while never dropping below the
        provider's minimum part size (``MIN_PART_SIZE_MB``).

        Returns:
            ``(part_size_bytes, num_parts)``.
        """
        min_part_size = MIN_PART_SIZE_MB * 1024 * 1024
        target_part_size = TARGET_PART_SIZE_MB * 1024 * 1024

        # Small files go up as a single part.
        if file_size <= target_part_size:
            return file_size, 1

        num_parts = min(MAX_MULTIPART_PARTS, math.ceil(file_size / target_part_size))
        part_size = math.ceil(file_size / num_parts)

        # Respect the minimum part size; recompute the count to match.
        if part_size < min_part_size:
            part_size = min_part_size
            num_parts = math.ceil(file_size / part_size)

        return part_size, num_parts

    def _initiate_upload(
        self,
        url: str,
        file: Path,
        file_size: int,
        num_parts: int,
        purpose: FilePurpose,
        file_type: str,
    ) -> Any:
        """Initiate multipart upload with backend.

        Returns the backend payload, expected to carry ``upload_id``,
        ``file_id`` and a ``parts`` list of presigned part descriptors.
        """
        requestor = api_requestor.APIRequestor(client=self._client)

        payload = {
            "file_name": file.name,
            "file_size": file_size,
            "num_parts": num_parts,
            "purpose": purpose.value,
            "file_type": file_type,
        }

        response, _, _ = requestor.request(
            options=TogetherRequest(
                method="POST",
                url="files/multipart/initiate",
                params=payload,
            ),
        )

        return response.data

    def _upload_parts_concurrent(
        self, file: Path, upload_info: Dict[str, Any], part_size: int
    ) -> List[Dict[str, Any]]:
        """Upload file parts concurrently with progress tracking.

        Each worker reads its own byte range from the file, so at most
        ``max_concurrent_parts`` part buffers are resident at once instead
        of buffering the entire file up front.

        Returns:
            ``{"part_number", "etag"}`` dicts sorted by part number.
        """
        parts = upload_info["parts"]
        completed_parts: List[Dict[str, Any]] = []

        with ThreadPoolExecutor(max_workers=self.max_concurrent_parts) as executor:
            with tqdm(total=len(parts), desc="Uploading parts", unit="part") as pbar:
                future_to_part = {
                    executor.submit(
                        self._upload_part_from_file, file, part_info, part_size
                    ): part_info["PartNumber"]
                    for part_info in parts
                }

                # Collect results as they finish; fail fast on the first error.
                for future in as_completed(future_to_part):
                    part_number = future_to_part[future]
                    try:
                        etag = future.result()
                    except Exception as e:
                        raise Exception(
                            f"Failed to upload part {part_number}: {e}"
                        ) from e
                    completed_parts.append(
                        {"part_number": part_number, "etag": etag}
                    )
                    pbar.update(1)

        # The completion call expects parts in ascending order.
        completed_parts.sort(key=lambda x: x["part_number"])
        return completed_parts

    def _upload_part_from_file(
        self, file: Path, part_info: Dict[str, Any], part_size: int
    ) -> str:
        """Read one part's byte range from the file and upload it."""
        with open(file, "rb") as f:
            # Parts are 1-indexed; each covers a contiguous part_size slice.
            f.seek((part_info["PartNumber"] - 1) * part_size)
            part_data = f.read(part_size)
        return self._upload_single_part(part_info, part_data)

    def _upload_single_part(self, part_info: Dict[str, Any], part_data: bytes) -> str:
        """Upload a single part and return ETag"""
        response = requests.put(
            part_info["URL"],
            data=part_data,
            headers=part_info.get("Headers", {}),
            timeout=MULTIPART_UPLOAD_TIMEOUT,
        )
        response.raise_for_status()

        # The ETag comes back quoted; the completion payload wants it bare.
        etag = response.headers.get("ETag", "").strip('"')
        if not etag:
            raise ResponseError(f"No ETag returned for part {part_info['PartNumber']}")

        return etag

    def _complete_upload(
        self,
        url: str,
        upload_id: str,
        file_id: str,
        completed_parts: List[Dict[str, Any]],
    ) -> FileResponse:
        """Complete the multipart upload"""
        requestor = api_requestor.APIRequestor(client=self._client)

        payload = {
            "upload_id": upload_id,
            "file_id": file_id,
            "parts": completed_parts,
        }

        response, _, _ = requestor.request(
            options=TogetherRequest(
                method="POST",
                url="files/multipart/complete",
                params=payload,
            ),
        )

        # Some responses nest the file record under "file"; fall back to data.
        return FileResponse(**response.data.get("file", response.data))

    def _abort_upload(self, url: str, upload_id: str, file_id: str) -> None:
        """Abort the multipart upload"""
        requestor = api_requestor.APIRequestor(client=self._client)

        payload = {
            "upload_id": upload_id,
            "file_id": file_id,
        }

        requestor.request(
            options=TogetherRequest(
                method="POST",
                url="files/multipart/abort",
                params=payload,
            ),
        )
0 commit comments