Skip to content

Commit 9e6bed7

Browse files
authored
Merge pull request #10 from saturncloud/bugfix/s3anon
Adding option to set anon=True which will let user access S3 bucket anonymously
2 parents c11e134 + 4655249 commit 9e6bed7

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

dask_pytorch_ddp/data.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,19 @@ def _list_all_files(bucket: str, prefix: str, s3_client=None) -> List[str]:
3434
return all_files
3535

3636

37-
def _read_s3_fileobj(bucket, path, fileobj):
37+
def _read_s3_fileobj(bucket, path, fileobj, anon=False):
3838
"""
3939
read an obj from s3 to a file like object
4040
"""
4141
import boto3 # pylint: disable=import-outside-toplevel
42+
from botocore import UNSIGNED # pylint: disable=import-outside-toplevel
43+
from botocore.client import Config # pylint: disable=import-outside-toplevel
44+
45+
if anon:
46+
s3 = boto3.resource("s3", config=Config(signature_version=UNSIGNED))
47+
else:
48+
s3 = boto3.resource("s3")
4249

43-
s3 = boto3.resource("s3")
4450
bucket = s3.Bucket(bucket)
4551
bucket.download_fileobj(path, fileobj)
4652
fileobj.seek(0)
@@ -59,12 +65,16 @@ class S3ImageFolder(Dataset):
5965
An image folder that lives in S3. Directories containing the image are classes.
6066
"""
6167

68+
# pylint: disable=too-many-instance-attributes
69+
# pylint: disable=too-many-arguments
70+
6271
def __init__(
6372
self,
6473
s3_bucket: str,
6574
s3_prefix: str,
6675
transform: Optional[Callable] = None,
6776
target_transform: Optional[Callable] = None,
77+
anon: Optional[bool] = False,
6878
):
6979
self.s3_bucket = s3_bucket
7080
self.s3_prefix = s3_prefix
@@ -73,6 +83,7 @@ def __init__(
7383
self.class_to_idx = {k: idx for idx, k in enumerate(self.classes)}
7484
self.transform = transform
7585
self.target_transform = target_transform
86+
self.anon = anon
7687

7788
@classmethod
7889
def _get_class(cls, path):
@@ -88,7 +99,7 @@ def __getitem__(self, idx):
8899
path = self.all_files[idx]
89100
label = self.class_to_idx[self._get_class(path)]
90101
with tempfile.TemporaryFile() as f:
91-
f = _read_s3_fileobj(self.s3_bucket, path, f)
102+
f = _read_s3_fileobj(self.s3_bucket, path, f, self.anon)
92103
img = _load_image_obj(f)
93104
if self.transform is not None:
94105
img = self.transform(img)

tests/test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_image_folder_getitem():
4040
read_s3_fileobj.return_value = Mock()
4141
load_image_obj.return_value = Mock()
4242
val, label = folder[0]
43-
read_s3_fileobj.assert_called_once_with("fake-bucket", fake_file_list[0], ANY)
43+
read_s3_fileobj.assert_called_once_with("fake-bucket", fake_file_list[0], ANY, False)
4444
load_image_obj.assert_called_once_with(read_s3_fileobj())
4545
assert val == load_image_obj.return_value
4646
assert label == 1

0 commit comments

Comments
 (0)