Skip to content
69 changes: 54 additions & 15 deletions airbyte_cdk/sources/file_based/file_based_stream_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
#

import logging
import time
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from io import IOBase
from os import makedirs, path
from typing import Any, Callable, Iterable, List, MutableMapping, Optional, Set, Tuple
from typing import Any, Iterable, List, MutableMapping, Optional, Set, Tuple

from airbyte_protocol_dataclasses.models import FailureType
from wcmatch.glob import GLOBSTAR, globmatch

from airbyte_cdk.models import AirbyteRecordMessageFileReference
Expand All @@ -19,8 +21,9 @@
preserve_directory_structure,
use_file_transfer,
)
from airbyte_cdk.sources.file_based.exceptions import FileSizeLimitError
from airbyte_cdk.sources.file_based.file_record_data import FileRecordData
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.remote_file import RemoteFile, UploadableRemoteFile


class FileReadMode(Enum):
Expand All @@ -34,6 +37,7 @@ class AbstractFileBasedStreamReader(ABC):
FILE_NAME = "file_name"
LOCAL_FILE_PATH = "local_file_path"
FILE_FOLDER = "file_folder"
FILE_SIZE_LIMIT = 1_500_000_000

def __init__(self) -> None:
self._config = None
Expand Down Expand Up @@ -113,16 +117,6 @@ def filter_files_by_globs_and_start_date(
seen.add(file.uri)
yield file

@abstractmethod
def file_size(self, file: RemoteFile) -> int:
    """Return the size of the given remote file.

    Connectors that support writing to files need a real implementation
    of this method. A connector without that capability can satisfy the
    interface with a subclass implementation that simply does `return 0`.
    """

@staticmethod
def file_matches_globs(file: RemoteFile, globs: List[str]) -> bool:
# Use the GLOBSTAR flag to enable recursive ** matching
Expand Down Expand Up @@ -153,9 +147,8 @@ def include_identities_stream(self) -> bool:
return include_identities_stream(self.config)
return False

@abstractmethod
def upload(
self, file: RemoteFile, local_directory: str, logger: logging.Logger
self, file: UploadableRemoteFile, local_directory: str, logger: logging.Logger
) -> Tuple[FileRecordData, AirbyteRecordMessageFileReference]:
"""
This is required for connectors that will support writing to
Expand All @@ -173,7 +166,53 @@ def upload(
- file_size_bytes (int): The size of the referenced file in bytes.
- source_file_relative_path (str): The relative path to the referenced file in source.
"""
...
if not isinstance(file, UploadableRemoteFile):
raise TypeError(f"Expected UploadableRemoteFile, got {type(file)}")

file_size = file.size

if file_size > self.FILE_SIZE_LIMIT:
message = f"File size exceeds the {self.FILE_SIZE_LIMIT / 1e9} GB limit."
raise FileSizeLimitError(
message=message, internal_message=message, failure_type=FailureType.config_error
)

file_paths = self._get_file_transfer_paths(
source_file_relative_path=file.source_file_relative_path,
staging_directory=local_directory,
)
local_file_path = file_paths[self.LOCAL_FILE_PATH]
file_relative_path = file_paths[self.FILE_RELATIVE_PATH]
file_name = file_paths[self.FILE_NAME]

logger.info(
f"Starting to download the file {file.file_uri_for_logging} with size: {file_size / (1024 * 1024):,.2f} MB ({file_size / (1024 * 1024 * 1024):.2f} GB)"
)
start_download_time = time.time()

file.download_to_local_directory(local_file_path)

write_duration = time.time() - start_download_time
logger.info(
f"Finished downloading the file {file.file_uri_for_logging} and saved to {local_file_path} in {write_duration:,.2f} seconds."
)

file_record_data = FileRecordData(
folder=file_paths[self.FILE_FOLDER],
file_name=file_name,
bytes=file_size,
id=file.id,
mime_type=file.mime_type,
created_at=file.created_at,
updated_at=file.updated_at,
source_uri=file.uri,
)
file_reference = AirbyteRecordMessageFileReference(
staging_file_url=local_file_path,
source_file_relative_path=file_relative_path,
file_size_bytes=file_size,
)
return file_record_data, file_reference

def _get_file_transfer_paths(
self, source_file_relative_path: str, staging_directory: str
Expand Down
4 changes: 2 additions & 2 deletions airbyte_cdk/sources/file_based/file_types/file_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from airbyte_cdk.models import AirbyteRecordMessageFileReference
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader
from airbyte_cdk.sources.file_based.file_record_data import FileRecordData
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.remote_file import UploadableRemoteFile
from airbyte_cdk.sources.utils.files_directory import get_files_directory


Expand All @@ -17,7 +17,7 @@ def __init__(self) -> None:

def upload(
self,
file: RemoteFile,
file: UploadableRemoteFile,
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
) -> Iterable[Tuple[FileRecordData, AirbyteRecordMessageFileReference]]:
Expand Down
41 changes: 40 additions & 1 deletion airbyte_cdk/sources/file_based/remote_file.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional

Expand All @@ -16,3 +16,42 @@ class RemoteFile(BaseModel):
uri: str
last_modified: datetime
mime_type: Optional[str] = None


class UploadableRemoteFile(RemoteFile, ABC):
    """
    A file in a file-based stream that supports uploading (file transfer).

    Concrete subclasses must provide ``size`` and
    ``download_to_local_directory``; ``source_file_relative_path`` has a
    default that returns ``uri``.
    """

    # Optional metadata fields; each one defaults to None when not supplied.
    id: Optional[str] = None
    created_at: Optional[str] = None
    updated_at: Optional[str] = None

    @property
    @abstractmethod
    def size(self) -> int:
        """
        Size of the file, in bytes.
        """

    @abstractmethod
    def download_to_local_directory(self, local_file_path: str) -> None:
        """
        Download the file from the remote source to local storage.

        :param local_file_path: local path handed in by the caller
            (presumably the destination of the download; confirm against
            concrete implementations).
        """

    @property
    def source_file_relative_path(self) -> str:
        """
        Relative path of the source file; this default returns ``uri`` unchanged.
        """
        return self.uri

Comment on lines +21 to +51
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@property
def file_uri_for_logging(self) -> str:
    """
    URI to use when referring to this file in log messages.

    This default returns ``uri`` unchanged.
    """
    logged_uri = self.uri
    return logged_uri
70 changes: 69 additions & 1 deletion unit_tests/sources/file_based/test_file_based_stream_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
from io import IOBase
from os import path
from typing import Any, ClassVar, Dict, Iterable, List, Mapping, Optional, Set
from unittest.mock import MagicMock

import pytest
from pydantic.v1 import AnyUrl

from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec
from airbyte_cdk.sources.file_based.exceptions import FileSizeLimitError
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.remote_file import RemoteFile, UploadableRemoteFile
from airbyte_cdk.sources.utils.files_directory import get_files_directory
from unit_tests.sources.file_based.helpers import make_remote_files

Expand Down Expand Up @@ -64,6 +66,38 @@
}


class TestStreamReaderWithDefaultUpload(AbstractFileBasedStreamReader):
    """
    Stub stream reader that does not define ``upload`` itself, so calls to
    ``upload`` resolve to whatever ``AbstractFileBasedStreamReader`` provides.

    Every other member is a minimal placeholder.
    """

    __test__: ClassVar[bool] = False  # Tell Pytest this is not a Pytest class, despite its name

    @property
    def config(self) -> Optional[AbstractFileBasedSpec]:
        """Return the config currently stored on the reader (``self._config``)."""
        return self._config

    @config.setter
    def config(self, value: AbstractFileBasedSpec) -> None:
        self._config = value

    def get_matching_files(self, globs: List[str]) -> Iterable[RemoteFile]:
        """No-op placeholder; implicitly returns ``None``."""

    def open_file(self, file: RemoteFile) -> IOBase:
        """No-op placeholder; implicitly returns ``None``."""

    def get_file_acl_permissions(self, file: RemoteFile, logger: logging.Logger) -> Dict[str, Any]:
        """Return an empty permissions mapping."""
        return {}

    def load_identity_groups(self, logger: logging.Logger) -> Iterable[Dict[str, Any]]:
        """Return a list holding a single empty identity record."""
        return [{}]

    @property
    def file_permissions_schema(self) -> Dict[str, Any]:
        """Object schema with no properties."""
        return {"type": "object", "properties": {}}

    @property
    def identities_schema(self) -> Dict[str, Any]:
        """Object schema with no properties."""
        return {"type": "object", "properties": {}}


class TestStreamReader(AbstractFileBasedStreamReader):
__test__: ClassVar[bool] = False # Tell Pytest this is not a Pytest class, despite its name

Expand Down Expand Up @@ -458,3 +492,37 @@ def test_preserve_sub_directories_scenarios(
assert file_paths[AbstractFileBasedStreamReader.LOCAL_FILE_PATH] == expected_local_file_path
assert file_paths[AbstractFileBasedStreamReader.FILE_NAME] == path.basename(source_file_path)
assert file_paths[AbstractFileBasedStreamReader.FILE_FOLDER] == path.dirname(source_file_path)


def test_upload_with_file_transfer_reader():
    """Exercise ``upload`` on a reader that does not override it.

    A small file must produce record data and a file reference that carry
    the file's size; a file whose size is above ``FILE_SIZE_LIMIT`` must
    raise ``FileSizeLimitError``.
    """
    stream_reader = TestStreamReaderWithDefaultUpload()

    class TestUploadableRemoteFile(UploadableRemoteFile):
        # The size is read from the mock on every access, so the test can
        # change it after the file object has been built.
        blob: Any

        @property
        def size(self) -> int:
            return self.blob.size

        def download_to_local_directory(self, local_file_path: str) -> None:
            # No real transfer: only the bookkeeping done by ``upload`` is under test.
            pass

    blob = MagicMock()
    blob.size = 200
    uploadable_remote_file = TestUploadableRemoteFile(
        uri="test/uri", last_modified=datetime.now(), blob=blob
    )

    logger = logging.getLogger("airbyte")

    file_record_data, file_reference = stream_reader.upload(
        uploadable_remote_file, "test_directory", logger
    )
    assert file_record_data
    assert file_reference
    # The size exposed by the file must reach both returned objects.
    assert file_record_data.bytes == 200
    assert file_reference.file_size_bytes == 200

    # Far above the limit.
    blob.size = 2_500_000_000
    with pytest.raises(FileSizeLimitError):
        stream_reader.upload(uploadable_remote_file, "test_directory", logger)

    # Boundary: one byte over the limit must be rejected as well. This replaces
    # a verbatim repeat of the assertion above, which added no coverage.
    blob.size = AbstractFileBasedStreamReader.FILE_SIZE_LIMIT + 1
    with pytest.raises(FileSizeLimitError):
        stream_reader.upload(uploadable_remote_file, "test_directory", logger)
Loading