-
Notifications
You must be signed in to change notification settings - Fork 27
fix: use the cloud catalog when doing save_as_table in a cloud session #91
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
| from dataclasses import dataclass | ||
| from datetime import datetime | ||
| from typing import Any, Coroutine, Dict, List, Optional | ||
| from urllib.parse import urlparse | ||
| from uuid import UUID | ||
|
|
||
| import polars as pl | ||
|
|
@@ -57,12 +58,12 @@ | |
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class CatalogKey: | ||
| catalog_name: str | ||
| catalog_id: UUID | ||
|
|
||
| CLOUD_SUPPORTED_SCHEMES = ["s3"] | ||
|
|
||
| class CloudCatalog(BaseCatalog): | ||
| """A catalog for cloud execution mode. Implements the BaseCatalog - | ||
|
|
@@ -113,11 +114,19 @@ def list_catalogs(self) -> List[str]: | |
| remote_catalogs.append(DEFAULT_CATALOG_NAME) | ||
| return remote_catalogs | ||
|
|
||
| def create_catalog(self, catalog_name: str, ignore_if_exists: bool = True) -> bool: | ||
| def create_catalog( | ||
| self, | ||
| catalog_name: str, | ||
| location: str, | ||
| ignore_if_exists: bool = True) -> bool: | ||
| """Create a new catalog.""" | ||
| if compare_object_names(catalog_name, DEFAULT_CATALOG_NAME): | ||
| raise CatalogError("Cannot create a catalog with the default name") | ||
|
|
||
| catalog_location = urlparse(location) | ||
| if catalog_location.scheme not in CLOUD_SUPPORTED_SCHEMES: | ||
| raise CatalogError(f"Unsupported scheme: {catalog_location.scheme}") | ||
|
|
||
| with self.lock: | ||
| if self._does_catalog_exist(catalog_name): | ||
| if ignore_if_exists: | ||
|
|
@@ -132,7 +141,7 @@ def create_catalog(self, catalog_name: str, ignore_if_exists: bool = True) -> bo | |
| created_by_user_id=UUID(self.user_id), | ||
| parent_organization_id=UUID(self.organization_id), | ||
| catalog_type=TypedefCatalogTypeReferenceEnum.INTERNAL_TYPEDEF, | ||
| catalog_warehouse="", | ||
| catalog_warehouse=location, | ||
| ) | ||
| ) | ||
| return True | ||
|
|
@@ -268,14 +277,12 @@ def create_table( | |
| self, | ||
| table_name: str, | ||
| schema: Schema, | ||
| location: str, | ||
| ignore_if_exists: bool = True, | ||
| file_format: Optional[str] = None, | ||
| ) -> bool: | ||
| """Create a new table in the current database.""" | ||
| with self.lock: | ||
| return self._create_table( | ||
| table_name, schema, location, ignore_if_exists, file_format | ||
| table_name, schema, ignore_if_exists | ||
| ) | ||
|
|
||
| def create_view( | ||
|
|
@@ -565,9 +572,7 @@ def _create_table( | |
| self, | ||
| table_name: str, | ||
| schema: Schema, | ||
| location: str, | ||
| ignore_if_exists: bool = True, | ||
| file_format: Optional[str] = None, | ||
| ) -> bool: | ||
| table_identifier = TableIdentifier.from_string(table_name).enrich( | ||
| self.current_catalog_name, self.current_database_name | ||
|
|
@@ -596,11 +601,6 @@ def _create_table( | |
| raise TableAlreadyExistsError(table_identifier.table, table_identifier.db) | ||
|
|
||
| catalog_id = self._get_catalog_id(table_identifier.catalog) | ||
| fixed_file_format = ( | ||
| FileFormat.PARQUET | ||
| if file_format is None | ||
| else FileFormat(file_format.upper()) | ||
| ) | ||
| self._execute_catalog_command( | ||
| self.user_client.sc_create_table( | ||
| dispatch=self._get_catalog_dispatch_input(catalog_id), | ||
|
|
@@ -610,8 +610,8 @@ def _create_table( | |
| canonical_name=table_identifier.table.casefold(), | ||
| description=None, | ||
| external=False, | ||
| location=location, | ||
| file_format=fixed_file_format, | ||
| location=self._get_table_location_from_table_identifier(table_identifier), | ||
| file_format=FileFormat.PARQUET, | ||
| partition_field_names=[], | ||
| schema_=self._get_schema_input_from_schema(schema), | ||
| ), | ||
|
|
@@ -700,3 +700,8 @@ def _get_schema_type_to_pyarrow(schema_type: str): | |
| return pa.float64() | ||
| else: | ||
| return schema_type | ||
|
|
||
| @staticmethod | ||
| def _get_table_location_from_table_identifier(table_identifier: TableIdentifier) -> str: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wouldn't the location be based on the base location in the catalog? if the user provides a bucket or prefix as the catalog's base location, all of the paths for dbs and tables should be relative to that base location, i would think. |
||
| """Gets the key in the s3 bucket for the table based on its database and name.""" | ||
| return f"{table_identifier.db}/{table_identifier.table}" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,13 +28,16 @@ | |
| SaveToFileExecutionRequest, | ||
| ShowExecutionRequest, | ||
| StartExecutionRequest, | ||
| TableIdentifier, | ||
| ) | ||
| from fenic_cloud.protos.engine.v1.engine_pb2 import ( | ||
| TableIdentifier as TableIdentifierProto, | ||
| ) | ||
| from fenic_cloud.protos.engine.v1.engine_pb2_grpc import EngineServiceStub | ||
|
|
||
| from fenic._backends.cloud.metrics import get_query_execution_metrics | ||
| from fenic._backends.schema_serde import deserialize_schema, serialize_schema | ||
| from fenic.core._interfaces import BaseExecution | ||
| from fenic._backends.utils.catalog_utils import TableIdentifier | ||
| from fenic.api.execution import CommonExecution | ||
| from fenic.core._logical_plan.serde import LogicalPlanSerde | ||
| from fenic.core.error import ( | ||
| CloudExecutionError, | ||
|
|
@@ -58,7 +61,7 @@ | |
|
|
||
| CLOUD_SUPPORTED_SCHEMES = ["s3"] | ||
|
|
||
| class CloudExecution(BaseExecution): | ||
| class CloudExecution(CommonExecution): | ||
| def __init__( | ||
| self, session_state: CloudSessionState, engine_stub: EngineServiceStub | ||
| ): | ||
|
|
@@ -165,15 +168,33 @@ def save_as_table( | |
| """Execute the logical plan and save the result as a table.""" | ||
| logger.debug(f"Saving plan {logical_plan} as table: {table_name}") | ||
| # TODO (DY): check that current catalog and schema (if specified in table_name) match session state | ||
| table_identifier = TableIdentifier( | ||
| catalog=self.session_state.catalog, | ||
| schema=self.session_state.schema, | ||
| table_exists, query_metrics = self._validate_table_existance(logical_plan, table_name, mode) | ||
| if not table_exists: | ||
| raise CloudExecutionError( | ||
| f"Cannot save to table '{table_name}' - it does not exist. " | ||
| f"Choose a different approach: " | ||
| f"1) Create the table in question " | ||
| f"2) Use a different table name.") | ||
| elif table_exists and query_metrics: | ||
| # trunk-ignore-begin(bandit/B101) | ||
| assert mode == "ignore", "only mode to fulfill this invariant is ignore." | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. particular reason to choose assert here and not raise an exception? |
||
| # trunk-ignore-end(bandit/B101) | ||
| return query_metrics | ||
| table_identifier = TableIdentifier.from_string(table_name).enrich( | ||
| self.session_state.catalog.get_current_catalog(), | ||
| self.session_state.catalog.get_current_database(), | ||
| ) | ||
|
|
||
| # TODO (DY): check that current catalog and schema (if specified in table_name) match session state | ||
| table_identifier_proto = TableIdentifierProto( | ||
| catalog=table_identifier.catalog, | ||
| schema=table_identifier.db, | ||
german-typedef marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| table=table_name, | ||
| ) | ||
| request = StartExecutionRequest( | ||
| save_as_table=SaveAsTableExecutionRequest( | ||
| serialized_plan=LogicalPlanSerde.serialize(logical_plan), | ||
| table_identifier=table_identifier, | ||
| table_identifier=table_identifier_proto, | ||
| mode=mode, | ||
| ) | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| """Common operations for execution.""" | ||
|
|
||
| import logging | ||
| from typing import Literal, Optional, Tuple | ||
|
|
||
| from fenic.core._interfaces.execution import BaseExecution | ||
| from fenic.core._logical_plan.plans.base import LogicalPlan | ||
| from fenic.core.error import PlanError | ||
| from fenic.core.metrics import QueryMetrics | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class CommonExecution(BaseExecution): | ||
| """Common class for execution operations.""" | ||
| def _validate_table_existance( | ||
| self, | ||
| logical_plan: LogicalPlan, | ||
| table_name: str, | ||
| mode: Literal["error", "append", "overwrite", "ignore"], | ||
| ) -> Tuple[bool, Optional[QueryMetrics]]: | ||
| if self.session_state.catalog.does_table_exist(table_name): | ||
| if mode == "error": | ||
| raise PlanError( | ||
| f"Cannot save to table '{table_name}' - it already exists and mode is 'error'. " | ||
| f"Choose a different approach: " | ||
| f"1) Use mode='overwrite' to replace the existing table, " | ||
| f"2) Use mode='append' to add data to the existing table, " | ||
| f"3) Use mode='ignore' to skip saving if table exists, " | ||
| f"4) Use a different table name.") | ||
| if mode == "ignore": | ||
| logger.warning(f"Table {table_name} already exists, ignoring write.") | ||
| return True, QueryMetrics() | ||
| if mode == "append": | ||
| saved_schema = self.session_state.catalog.describe_table(table_name) | ||
| plan_schema = logical_plan.schema() | ||
| if saved_schema != plan_schema: | ||
| raise PlanError( | ||
| f"Cannot append to table '{table_name}' - schema mismatch detected. " | ||
| f"The existing table has a different schema than your DataFrame. " | ||
| f"Existing schema: {saved_schema} " | ||
| f"Your DataFrame schema: {plan_schema} " | ||
| f"To fix this: " | ||
| f"1) Use mode='overwrite' to replace the table with your DataFrame's schema, " | ||
| f"2) Modify your DataFrame to match the existing table's schema, " | ||
| f"3) Use a different table name.") | ||
| else: | ||
| return True, None | ||
| if mode == "overwrite": | ||
| return True, None | ||
| else: | ||
| return False, None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we'll have to adjust this in #157