diff --git a/hamilton/experimental/h_cache.py b/hamilton/experimental/h_cache.py
index c6db958a4..ba0a74c99 100644
--- a/hamilton/experimental/h_cache.py
+++ b/hamilton/experimental/h_cache.py
@@ -33,7 +33,7 @@ def write_feather(data: object, filepath: str, name: str) -> None:
 
 
 @singledispatch
-def read_feather(data: object, filepath: str) -> Any:
+def read_feather(data: object, filepath: str, **kwargs) -> Any:
     """Reads from a feather file"""
     raise NotImplementedError(f"No feather reader for type {type(data)} registered.")
 
@@ -45,7 +45,7 @@ def write_parquet(data: object, filepath: str, name: str) -> None:
 
 
 @singledispatch
-def read_parquet(data: object, filepath: str) -> Any:
+def read_parquet(data: object, filepath: str, **kwargs) -> Any:
     """Reads from a parquet file"""
     raise NotImplementedError(f"No parquet reader for type {type(data)} registered.")
 
@@ -57,7 +57,7 @@ def write_json(data: object, filepath: str, name: str) -> None:
 
 
 @singledispatch
-def read_json(data: object, filepath: str) -> Any:
+def read_json(data: object, filepath: str, **kwargs) -> Any:
     """Reads from a json file"""
     raise NotImplementedError(f"No json reader for type {type(data)} registered.")
 
@@ -69,7 +69,7 @@ def write_pickle(data: Any, filepath: str, name: str) -> None:
 
 
 @singledispatch
-def read_pickle(data: Any, filtepath: str) -> object:
+def read_pickle(data: Any, filtepath: str, **kwargs) -> object:
     """Reads from a pickle file"""
     raise NotImplementedError(f"No object reader for type {type(data)} registered.")
 
@@ -89,13 +89,13 @@ def write_json_pd2(data: pd.Series, filepath: str, name: str) -> None:
         return _df.to_json(filepath)
 
     @read_json.register(pd.Series)
-    def read_json_pd1(data: pd.Series, filepath: str) -> pd.Series:
+    def read_json_pd1(data: pd.Series, filepath: str, **kwargs) -> pd.Series:
         """Reads a series from a feather file."""
         _df = pd.read_json(filepath)
         return _df[_df.columns[0]]
 
     @read_json.register(pd.DataFrame)
-    def read_json_pd2(data: pd.DataFrame, filepath: str) -> pd.DataFrame:
+    def read_json_pd2(data: pd.DataFrame, filepath: str, **kwargs) -> pd.DataFrame:
         """Reads a dataframe from a feather file."""
         return pd.read_json(filepath)
 
@@ -113,12 +113,12 @@ def write_feather_pd2(data: pd.Series, filepath: str, name: str) -> None:
         data.to_frame(name=name).to_feather(filepath)
 
     @read_feather.register(pd.DataFrame)
-    def read_feather_pd1(data: pd.DataFrame, filepath: str) -> pd.DataFrame:
+    def read_feather_pd1(data: pd.DataFrame, filepath: str, **kwargs) -> pd.DataFrame:
         """Reads a dataframe from a feather file."""
         return pd.read_feather(filepath)
 
     @read_feather.register(pd.Series)
-    def read_feather_pd2(data: pd.Series, filepath: str) -> pd.Series:
+    def read_feather_pd2(data: pd.Series, filepath: str, **kwargs) -> pd.Series:
         """Reads a series from a feather file."""
         _df = pd.read_feather(filepath)
         return _df[_df.columns[0]]
@@ -134,12 +134,12 @@ def write_parquet_pd2(data: pd.Series, filepath: str, name: str) -> None:
         data.to_frame(name=name).to_parquet(filepath)
 
     @read_parquet.register(pd.DataFrame)
-    def read_parquet_pd1(data: pd.DataFrame, filepath: str) -> pd.DataFrame:
+    def read_parquet_pd1(data: pd.DataFrame, filepath: str, **kwargs) -> pd.DataFrame:
         """Reads a dataframe from a parquet file."""
         return pd.read_parquet(filepath)
 
     @read_parquet.register(pd.Series)
-    def read_parquet_pd2(data: pd.Series, filepath: str) -> pd.Series:
+    def read_parquet_pd2(data: pd.Series, filepath: str, **kwargs) -> pd.Series:
         """Reads a series from a parquet file."""
         _df = pd.read_parquet(filepath)
         return _df[_df.columns[0]]
@@ -148,6 +148,23 @@ def read_parquet_pd2(data: pd.Series, filepath: str) -> pd.Series:
 
     pass
 
+except ImportError:
+    pass
+
+try:
+    import pyspark.sql as ps
+
+    @write_parquet.register(ps.DataFrame)
+    def write_parquet_ps(data: ps.DataFrame, filepath: str, name: str) -> None:
+        """Writes a pyspark dataframe to a parquet file."""
+        data.write.parquet(filepath, mode="overwrite")
+
+    @read_parquet.register(ps.DataFrame)
+    def read_parquet_ps(data: ps.DataFrame, filepath: str, **kwargs) -> ps.DataFrame:
+        """Reads a dataframe from a parquet file."""
+        spark = kwargs["spark_session"]
+        return spark.read.parquet(filepath)
+
 except ImportError:
     pass
 
@@ -163,7 +180,7 @@ def write_json_dict(data: dict, filepath: str, name: str) -> None:
 
 
 @read_json.register(dict)
-def read_json_dict(data: dict, filepath: str) -> dict:
+def read_json_dict(data: dict, filepath: str, **kwargs) -> dict:
     """Reads a dictionary from a JSON file."""
     with open(filepath, "r", encoding="utf8") as file:
         return json.load(file)
@@ -180,7 +197,7 @@ def write_pickle_object(data: object, filepath: str, name: str) -> None:
 
 
 @read_pickle.register(object)
-def read_pickle_object(data: object, filepath: str) -> object:
+def read_pickle_object(data: object, filepath: str, **kwargs) -> object:
     """Reads a pickle file"""
     print(filepath)
     with open(filepath, "rb") as file:
@@ -213,7 +230,7 @@ class CachingGraphAdapter(SimplePythonGraphAdapter):
     and `name` is the name of the node that is being written.
 
     Reader functions need to have the following signature:
-    `def read_(data: Any, filepath: str) -> Any: ...`
+    `def read_(data: Any, filepath: str, **kwargs) -> Any: ...`
 
     where `data` is an EMPTY OBJECT of the type you wish to instantiate,
     and `filepath` is the path to the file to be read from.
@@ -227,7 +244,7 @@ def write_json_pd1(data: T, filepath: str, name: str) -> None:
         ...
 
     @read_json.register(T)
-    def read_json_dict(data: T, filepath: str) -> T:
+    def read_json_dict(data: T, filepath: str, **kwargs) -> T:
         ...
 
     Usage
@@ -280,6 +297,8 @@ def __init__(
         force_compute: Optional[Set[str]] = None,
         writers: Optional[Dict[str, Callable[[Any, str, str], None]]] = None,
         readers: Optional[Dict[str, Callable[[Any, str], Any]]] = None,
+        read_kwargs: Optional[Dict[str, Any]] = None,
+        read_after_write: bool = False,
         **kwargs,
     ):
         """Constructs the adapter.
@@ -288,6 +307,7 @@ def __init__(
         :param force_compute: Set of nodes that should be forced to compute even if cache exists.
         :param writers: A dictionary of writers for custom formats.
         :param readers: A dictionary of readers for custom formats.
+        :param read_kwargs: A dictionary of keyword arguments to pass to the readers.
""" super().__init__(*args, **kwargs) @@ -298,6 +318,8 @@ def __init__( self.writers = writers or {} self.readers = readers or {} + self.read_kwargs = read_kwargs or {} + self.read_after_write = read_after_write self._init_default_readers_writers() def _init_default_readers_writers(self): @@ -331,7 +353,7 @@ def _write_cache(self, fmt: str, data: Any, filepath: str, node_name: str) -> No def _read_cache(self, fmt: str, expected_type: Any, filepath: str) -> None: self._check_format(fmt) - return self.readers[fmt](expected_type, filepath) + return self.readers[fmt](expected_type, filepath, **self.read_kwargs) def _get_empty_expected_type(self, expected_type: Type) -> Any: if typing_inspect.is_generic_type(expected_type): @@ -364,6 +386,10 @@ def execute_node(self, node: Node, kwargs: Dict[str, Any]) -> Any: cache_format, ) self._write_cache(cache_format, result, filepath, node.name) + if self.read_after_write: + # this could be useful for delayed execution type things as a means to reset + # that they have set internally + result = self._read_cache(cache_format, result, filepath) self.computed_nodes.add(node.name) return result empty_expected_type = self._get_empty_expected_type(node.type)