Skip to content

Commit 0cf374c

Browse files
authored
Implement multipart upload (#364)
* implement core multipart upload + test * fixed whitespaces * refactor * Fix multipart name mismatch * Add test_data/ to gitignore to prevent committing large test files * Fix
1 parent eced2af commit 0cf374c

File tree

8 files changed

+970
-12
lines changed

8 files changed

+970
-12
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,4 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
test_data/

src/together/constants.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,20 @@
1515
DOWNLOAD_BLOCK_SIZE = 10 * 1024 * 1024 # 10 MB
1616
DISABLE_TQDM = False
1717

18+
# Upload defaults
19+
MAX_CONCURRENT_PARTS = 4 # Maximum concurrent parts for multipart upload
20+
21+
# Multipart upload constants
22+
MIN_PART_SIZE_MB = 5 # Minimum part size (S3 requirement)
23+
TARGET_PART_SIZE_MB = 100 # Target part size for optimal performance
24+
MAX_MULTIPART_PARTS = 250 # Maximum parts per upload (S3 limit)
25+
MULTIPART_UPLOAD_TIMEOUT = 300 # Timeout in seconds for uploading each part
26+
MULTIPART_THRESHOLD_GB = 5.0 # threshold for switching to multipart upload
27+
28+
# maximum number of GB sized files we support finetuning for
29+
MAX_FILE_SIZE_GB = 25.0
30+
31+
1832
# Messages
1933
MISSING_API_KEY_MESSAGE = """TOGETHER_API_KEY not found.
2034
Please set it as an environment variable or set it as together.api_key
@@ -26,8 +40,6 @@
2640
# the number of bytes in a gigabyte, used to convert bytes to GB for readable comparison
2741
NUM_BYTES_IN_GB = 2**30
2842

29-
# maximum number of GB sized files we support finetuning for
30-
MAX_FILE_SIZE_GB = 4.9
3143

3244
# expected columns for Parquet files
3345
PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]

src/together/filemanager.py

Lines changed: 230 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,40 @@
11
from __future__ import annotations
22

3+
import math
34
import os
45
import shutil
56
import stat
67
import tempfile
78
import uuid
9+
from concurrent.futures import ThreadPoolExecutor, as_completed
810
from functools import partial
911
from pathlib import Path
10-
from typing import Tuple
12+
from typing import Any, Dict, List, Tuple
1113

1214
import requests
1315
from filelock import FileLock
1416
from requests.structures import CaseInsensitiveDict
1517
from tqdm import tqdm
16-
from tqdm.utils import CallbackIOWrapper
1718

18-
import together.utils
1919
from together.abstract import api_requestor
20-
from together.constants import DISABLE_TQDM, DOWNLOAD_BLOCK_SIZE, MAX_RETRIES
20+
from together.constants import (
21+
DISABLE_TQDM,
22+
DOWNLOAD_BLOCK_SIZE,
23+
MAX_CONCURRENT_PARTS,
24+
MAX_FILE_SIZE_GB,
25+
MAX_RETRIES,
26+
MIN_PART_SIZE_MB,
27+
NUM_BYTES_IN_GB,
28+
TARGET_PART_SIZE_MB,
29+
MAX_MULTIPART_PARTS,
30+
MULTIPART_UPLOAD_TIMEOUT,
31+
)
2132
from together.error import (
2233
APIError,
2334
AuthenticationError,
2435
DownloadError,
2536
FileTypeError,
37+
ResponseError,
2638
)
2739
from together.together_response import TogetherResponse
2840
from together.types import (
@@ -32,6 +44,8 @@
3244
TogetherClient,
3345
TogetherRequest,
3446
)
47+
from tqdm.utils import CallbackIOWrapper
48+
import together.utils
3549

3650

3751
def chmod_and_replace(src: Path, dst: Path) -> None:
@@ -339,7 +353,7 @@ def upload(
339353
)
340354
redirect_url, file_id = self.get_upload_url(url, file, purpose, filetype)
341355

342-
file_size = os.stat(file.as_posix()).st_size
356+
file_size = os.stat(file).st_size
343357

344358
with tqdm(
345359
total=file_size,
@@ -385,3 +399,214 @@ def upload(
385399
assert isinstance(response, TogetherResponse)
386400

387401
return FileResponse(**response.data)
402+
403+
404+
class MultipartUploadManager:
    """Handles multipart uploads for large files.

    Splits a local file into S3-compatible parts, uploads the parts
    concurrently against presigned URLs handed out by the backend, then
    asks the backend to stitch them together. On any failure the upload
    session is aborted best-effort so the backend does not keep a
    dangling multipart session.
    """

    def __init__(self, client: TogetherClient) -> None:
        self._client = client
        # Upper bound on simultaneously in-flight part uploads; also
        # bounds peak memory (see _upload_parts_concurrent).
        self.max_concurrent_parts = MAX_CONCURRENT_PARTS

    def upload(
        self,
        url: str,
        file: Path,
        purpose: FilePurpose,
    ) -> FileResponse:
        """Upload large file using multipart upload.

        Args:
            url: Base endpoint; forwarded to the helper methods (the
                multipart endpoints themselves are fixed paths under
                ``files/multipart/``).
            file: Path of the local file to upload.
            purpose: Purpose tag sent to the backend when the upload is
                initiated.

        Returns:
            FileResponse describing the completed upload.

        Raises:
            FileTypeError: If the file exceeds MAX_FILE_SIZE_GB.
            ValueError: If the file extension is unsupported.
        """
        file_size = os.stat(file).st_size

        file_size_gb = file_size / NUM_BYTES_IN_GB
        if file_size_gb > MAX_FILE_SIZE_GB:
            raise FileTypeError(
                f"File size {file_size_gb:.1f}GB exceeds maximum supported size of {MAX_FILE_SIZE_GB}GB"
            )

        part_size, num_parts = self._calculate_parts(file_size)

        # Validate the extension before talking to the backend so an
        # unsupported file never opens an upload session.
        file_type = self._get_file_type(file)
        upload_info = None

        try:
            upload_info = self._initiate_upload(
                url, file, file_size, num_parts, purpose, file_type
            )

            completed_parts = self._upload_parts_concurrent(
                file, upload_info, part_size
            )

            return self._complete_upload(
                url, upload_info["upload_id"], upload_info["file_id"], completed_parts
            )

        except Exception:
            # Cleanup on failure: abort only if a session was actually
            # opened, then re-raise the original error with its
            # traceback intact (bare `raise`, not `raise e`).
            if upload_info is not None:
                self._abort_upload(
                    url, upload_info["upload_id"], upload_info["file_id"]
                )
            raise

    def _get_file_type(self, file: Path) -> str:
        """Get file type from extension, raising ValueError for unsupported extensions"""
        supported = {".jsonl": "jsonl", ".parquet": "parquet", ".csv": "csv"}
        try:
            return supported[file.suffix]
        except KeyError:
            raise ValueError(
                f"Unsupported file extension: '{file.suffix}'. "
                f"Supported extensions: .jsonl, .parquet, .csv"
            ) from None

    def _calculate_parts(self, file_size: int) -> tuple[int, int]:
        """Calculate optimal part size and count.

        Targets TARGET_PART_SIZE_MB parts, capped at MAX_MULTIPART_PARTS
        parts, and never goes below the S3-required MIN_PART_SIZE_MB.

        Returns:
            (part_size_bytes, num_parts)
        """
        min_part_size = MIN_PART_SIZE_MB * 1024 * 1024  # 5MB
        target_part_size = TARGET_PART_SIZE_MB * 1024 * 1024  # 100MB

        # Small enough for a single part: upload the whole file at once.
        if file_size <= target_part_size:
            return file_size, 1

        num_parts = min(MAX_MULTIPART_PARTS, math.ceil(file_size / target_part_size))
        part_size = math.ceil(file_size / num_parts)

        # Hitting the MAX_MULTIPART_PARTS cap can shrink parts below the
        # S3 minimum; re-derive the count from the minimum size instead.
        if part_size < min_part_size:
            part_size = min_part_size
            num_parts = math.ceil(file_size / part_size)

        return part_size, num_parts

    def _initiate_upload(
        self,
        url: str,
        file: Path,
        file_size: int,
        num_parts: int,
        purpose: FilePurpose,
        file_type: str,
    ) -> Any:
        """Initiate multipart upload with backend.

        Returns the backend payload, expected to contain ``upload_id``,
        ``file_id`` and a ``parts`` list of presigned-URL descriptors.
        """
        requestor = api_requestor.APIRequestor(client=self._client)

        payload = {
            "file_name": file.name,
            "file_size": file_size,
            "num_parts": num_parts,
            "purpose": purpose.value,
            "file_type": file_type,
        }

        response, _, _ = requestor.request(
            options=TogetherRequest(
                method="POST",
                url="files/multipart/initiate",
                params=payload,
            ),
        )

        return response.data

    def _upload_parts_concurrent(
        self, file: Path, upload_info: Dict[str, Any], part_size: int
    ) -> List[Dict[str, Any]]:
        """Upload file parts concurrently with progress tracking.

        Each worker reads its own byte range from disk on demand, so peak
        memory is bounded by ``max_concurrent_parts * part_size`` instead
        of the whole file (eagerly reading every part before submission
        would hold up to the full file size in queued futures).

        Returns:
            ``[{"part_number": int, "etag": str}, ...]`` sorted by
            part number, as required by the complete call.
        """
        parts = upload_info["parts"]
        completed_parts: List[Dict[str, Any]] = []

        with ThreadPoolExecutor(max_workers=self.max_concurrent_parts) as executor:
            with tqdm(total=len(parts), desc="Uploading parts", unit="part") as pbar:
                future_to_part = {
                    executor.submit(
                        self._upload_part_from_file, file, part_info, part_size
                    ): part_info["PartNumber"]
                    for part_info in parts
                }

                # Collect results as they finish, in completion order.
                for future in as_completed(future_to_part):
                    part_number = future_to_part[future]
                    try:
                        etag = future.result()
                    except Exception as e:
                        # Chain the cause so the underlying HTTP/IO error
                        # is preserved in the traceback.
                        raise Exception(
                            f"Failed to upload part {part_number}: {e}"
                        ) from e
                    completed_parts.append(
                        {"part_number": part_number, "etag": etag}
                    )
                    pbar.update(1)

        completed_parts.sort(key=lambda x: x["part_number"])
        return completed_parts

    def _upload_part_from_file(
        self, file: Path, part_info: Dict[str, Any], part_size: int
    ) -> str:
        """Read one part's byte range from disk and upload it.

        Opens a private file handle per part so concurrent workers never
        share seek position.
        """
        with open(file, "rb") as f:
            f.seek((part_info["PartNumber"] - 1) * part_size)
            part_data = f.read(part_size)
        return self._upload_single_part(part_info, part_data)

    def _upload_single_part(self, part_info: Dict[str, Any], part_data: bytes) -> str:
        """Upload a single part and return ETag"""
        response = requests.put(
            part_info["URL"],
            data=part_data,
            headers=part_info.get("Headers", {}),
            timeout=MULTIPART_UPLOAD_TIMEOUT,
        )
        response.raise_for_status()

        # S3 returns the ETag wrapped in double quotes; strip them so the
        # complete call sends the bare value.
        etag = response.headers.get("ETag", "").strip('"')
        if not etag:
            raise ResponseError(f"No ETag returned for part {part_info['PartNumber']}")

        return etag

    def _complete_upload(
        self,
        url: str,
        upload_id: str,
        file_id: str,
        completed_parts: List[Dict[str, Any]],
    ) -> FileResponse:
        """Complete the multipart upload"""
        requestor = api_requestor.APIRequestor(client=self._client)

        payload = {
            "upload_id": upload_id,
            "file_id": file_id,
            "parts": completed_parts,
        }

        response, _, _ = requestor.request(
            options=TogetherRequest(
                method="POST",
                url="files/multipart/complete",
                params=payload,
            ),
        )

        # Backend may wrap the file record under a "file" key or return
        # it directly; accept either shape.
        return FileResponse(**response.data.get("file", response.data))

    def _abort_upload(self, url: str, upload_id: str, file_id: str) -> None:
        """Abort the multipart upload"""
        requestor = api_requestor.APIRequestor(client=self._client)

        payload = {
            "upload_id": upload_id,
            "file_id": file_id,
        }

        requestor.request(
            options=TogetherRequest(
                method="POST",
                url="files/multipart/abort",
                params=payload,
            ),
        )

src/together/resources/files.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

3+
import os
34
from pathlib import Path
45
from pprint import pformat
56

67
from together.abstract import api_requestor
8+
from together.constants import MULTIPART_THRESHOLD_GB, NUM_BYTES_IN_GB
79
from together.error import FileTypeError
8-
from together.filemanager import DownloadManager, UploadManager
10+
from together.filemanager import DownloadManager, UploadManager, MultipartUploadManager
911
from together.together_response import TogetherResponse
1012
from together.types import (
1113
FileDeleteResponse,
@@ -30,7 +32,6 @@ def upload(
3032
purpose: FilePurpose | str = FilePurpose.FineTune,
3133
check: bool = True,
3234
) -> FileResponse:
33-
upload_manager = UploadManager(self._client)
3435

3536
if check and purpose == FilePurpose.FineTune:
3637
report_dict = check_file(file)
@@ -47,7 +48,15 @@ def upload(
4748

4849
assert isinstance(purpose, FilePurpose)
4950

50-
return upload_manager.upload("files", file, purpose=purpose, redirect=True)
51+
file_size = os.stat(file).st_size
52+
file_size_gb = file_size / NUM_BYTES_IN_GB
53+
54+
if file_size_gb > MULTIPART_THRESHOLD_GB:
55+
multipart_manager = MultipartUploadManager(self._client)
56+
return multipart_manager.upload("files", file, purpose)
57+
else:
58+
upload_manager = UploadManager(self._client)
59+
return upload_manager.upload("files", file, purpose=purpose, redirect=True)
5160

5261
def list(self) -> FileList:
5362
requestor = api_requestor.APIRequestor(

src/together/types/files.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class FileResponse(BaseModel):
5252
"""
5353

5454
id: str
55-
object: Literal[ObjectType.File]
55+
object: str
5656
# created timestamp
5757
created_at: int | None = None
5858
type: FileType | None = None

src/together/utils/files.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def check_file(
6565
else:
6666
report_dict["found"] = True
6767

68-
file_size = os.stat(file.as_posix()).st_size
68+
file_size = os.stat(file).st_size
6969

7070
if file_size > MAX_FILE_SIZE_GB * NUM_BYTES_IN_GB:
7171
report_dict["message"] = (

0 commit comments

Comments
 (0)