Skip to content

Commit 6b9a23e

Browse files
fix: use the cloud catalog when doing save_as_table in a cloud session
1 parent d3842b0 commit 6b9a23e

File tree

6 files changed

+125
-60
lines changed

6 files changed

+125
-60
lines changed

src/fenic/_backends/cloud/catalog.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from dataclasses import dataclass
55
from datetime import datetime
66
from typing import Any, Coroutine, Dict, List, Optional
7+
from urllib.parse import urlparse
78
from uuid import UUID
89

910
import polars as pl
@@ -57,12 +58,12 @@
5758

5859
logger = logging.getLogger(__name__)
5960

60-
6161
@dataclass(frozen=True)
class CatalogKey:
    """Immutable (and therefore hashable) pairing of a catalog's name with its UUID."""

    catalog_name: str
    catalog_id: UUID
6565

66+
# URL schemes accepted for a catalog location; create_catalog raises CatalogError for any other scheme.
CLOUD_SUPPORTED_SCHEMES = ["s3"]
6667

6768
class CloudCatalog(BaseCatalog):
6869
"""A catalog for cloud execution mode. Implements the BaseCatalog -
@@ -113,11 +114,18 @@ def list_catalogs(self) -> List[str]:
113114
remote_catalogs.append(DEFAULT_CATALOG_NAME)
114115
return remote_catalogs
115116

116-
def create_catalog(self, catalog_name: str, ignore_if_exists: bool = True) -> bool:
117+
def create_catalog(
118+
self,
119+
catalog_name: str,
120+
location: str,
121+
ignore_if_exists: bool = True) -> bool:
117122
"""Create a new catalog."""
118123
if compare_object_names(catalog_name, DEFAULT_CATALOG_NAME):
119124
raise CatalogError("Cannot create a catalog with the default name")
120125

126+
if urlparse(location).scheme not in CLOUD_SUPPORTED_SCHEMES:
127+
raise CatalogError(f"Unsupported scheme: {urlparse(location).scheme}")
128+
121129
with self.lock:
122130
if self._does_catalog_exist(catalog_name):
123131
if ignore_if_exists:
@@ -132,7 +140,7 @@ def create_catalog(self, catalog_name: str, ignore_if_exists: bool = True) -> bo
132140
created_by_user_id=UUID(self.user_id),
133141
parent_organization_id=UUID(self.organization_id),
134142
catalog_type=TypedefCatalogTypeReferenceEnum.INTERNAL_TYPEDEF,
135-
catalog_warehouse="",
143+
catalog_warehouse=location,
136144
)
137145
)
138146
return True
@@ -268,14 +276,12 @@ def create_table(
268276
self,
269277
table_name: str,
270278
schema: Schema,
271-
location: str,
272279
ignore_if_exists: bool = True,
273-
file_format: Optional[str] = None,
274280
) -> bool:
275281
"""Create a new table in the current database."""
276282
with self.lock:
277283
return self._create_table(
278-
table_name, schema, location, ignore_if_exists, file_format
284+
table_name, schema, ignore_if_exists
279285
)
280286

281287
def create_view(
@@ -565,9 +571,7 @@ def _create_table(
565571
self,
566572
table_name: str,
567573
schema: Schema,
568-
location: str,
569574
ignore_if_exists: bool = True,
570-
file_format: Optional[str] = None,
571575
) -> bool:
572576
table_identifier = TableIdentifier.from_string(table_name).enrich(
573577
self.current_catalog_name, self.current_database_name
@@ -596,11 +600,6 @@ def _create_table(
596600
raise TableAlreadyExistsError(table_identifier.table, table_identifier.db)
597601

598602
catalog_id = self._get_catalog_id(table_identifier.catalog)
599-
fixed_file_format = (
600-
FileFormat.PARQUET
601-
if file_format is None
602-
else FileFormat(file_format.upper())
603-
)
604603
self._execute_catalog_command(
605604
self.user_client.sc_create_table(
606605
dispatch=self._get_catalog_dispatch_input(catalog_id),
@@ -610,8 +609,8 @@ def _create_table(
610609
canonical_name=table_identifier.table.casefold(),
611610
description=None,
612611
external=False,
613-
location=location,
614-
file_format=fixed_file_format,
612+
location=self._get_table_location_from_table_identifier(table_identifier),
613+
file_format=FileFormat.PARQUET,
615614
partition_field_names=[],
616615
schema_=self._get_schema_input_from_schema(schema),
617616
),
@@ -700,3 +699,8 @@ def _get_schema_type_to_pyarrow(schema_type: str):
700699
return pa.float64()
701700
else:
702701
return schema_type
702+
703+
@staticmethod
def _get_table_location_from_table_identifier(table_identifier: TableIdentifier) -> str:
    """Build the table's key within the s3 bucket as ``<database>/<table>``.

    Args:
        table_identifier: Identifier whose ``db`` and ``table`` parts form the key.

    Returns:
        The database name and table name joined by a forward slash.
    """
    database, table = table_identifier.db, table_identifier.table
    return f"{database}/{table}"

src/fenic/_backends/cloud/execution.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,16 @@
2828
SaveToFileExecutionRequest,
2929
ShowExecutionRequest,
3030
StartExecutionRequest,
31-
TableIdentifier,
31+
)
32+
from fenic_cloud.protos.engine.v1.engine_pb2 import (
33+
TableIdentifier as TableIdentifierProto,
3234
)
3335
from fenic_cloud.protos.engine.v1.engine_pb2_grpc import EngineServiceStub
3436

3537
from fenic._backends.cloud.metrics import get_query_execution_metrics
3638
from fenic._backends.schema_serde import deserialize_schema, serialize_schema
37-
from fenic.core._interfaces import BaseExecution
39+
from fenic._backends.utils.catalog_utils import TableIdentifier
40+
from fenic.api.execution import CommonExecution
3841
from fenic.core._logical_plan.serde import LogicalPlanSerde
3942
from fenic.core.error import (
4043
CloudExecutionError,
@@ -58,7 +61,7 @@
5861

5962
CLOUD_SUPPORTED_SCHEMES = ["s3"]
6063

61-
class CloudExecution(BaseExecution):
64+
class CloudExecution(CommonExecution):
6265
def __init__(
6366
self, session_state: CloudSessionState, engine_stub: EngineServiceStub
6467
):
@@ -165,15 +168,33 @@ def save_as_table(
165168
"""Execute the logical plan and save the result as a table."""
166169
logger.debug(f"Saving plan {logical_plan} as table: {table_name}")
167170
# TODO (DY): check that current catalog and schema (if specified in table_name) match session state
168-
table_identifier = TableIdentifier(
169-
catalog=self.session_state.catalog,
170-
schema=self.session_state.schema,
171+
table_exists, query_metrics = self._validate_table_existance(logical_plan, table_name, mode)
172+
if not table_exists:
173+
raise CloudExecutionError(
174+
f"Cannot save to table '{table_name}' - it does not exist. "
175+
f"Choose a different approach: "
176+
f"1) Create the table in question "
177+
f"2) Use a different table name.")
178+
elif table_exists and query_metrics:
179+
# trunk-ignore-begin(bandit/B101)
180+
assert mode == "ignore", "only mode to fulfill this invariant is ignore."
181+
# trunk-ignore-end(bandit/B101)
182+
return query_metrics
183+
table_identifier = TableIdentifier.from_string(table_name).enrich(
184+
self.session_state.catalog.get_current_catalog(),
185+
self.session_state.catalog.get_current_database(),
186+
)
187+
188+
# TODO (DY): check that current catalog and schema (if specified in table_name) match session state
189+
table_identifier_proto = TableIdentifierProto(
190+
catalog=table_identifier.catalog,
191+
schema=table_identifier.db,
171192
table=table_name,
172193
)
173194
request = StartExecutionRequest(
174195
save_as_table=SaveAsTableExecutionRequest(
175196
serialized_plan=LogicalPlanSerde.serialize(logical_plan),
176-
table_identifier=table_identifier,
197+
table_identifier=table_identifier_proto,
177198
mode=mode,
178199
)
179200
)

src/fenic/_backends/local/execution.py

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
does_path_exist,
1313
query_files,
1414
)
15-
from fenic.core._interfaces.execution import BaseExecution
16-
from fenic.core._logical_plan.plans.base import LogicalPlan
15+
from fenic.api.execution import CommonExecution
16+
from fenic.core._logical_plan import LogicalPlan
1717
from fenic.core._utils.schema import (
1818
convert_polars_schema_to_custom_schema,
1919
)
@@ -34,7 +34,7 @@
3434
from fenic._backends.local.session_state import LocalSessionState
3535

3636

37-
class LocalExecution(BaseExecution):
37+
class LocalExecution(CommonExecution):
3838
session_state: LocalSessionState
3939
transpiler: LocalTranspiler
4040

@@ -101,35 +101,12 @@ def save_as_table(
101101
) -> QueryMetrics:
102102
"""Execute the logical plan and save the result as a table in the current database."""
103103
self.session_state._check_active()
104-
table_exists = self.session_state.catalog.does_table_exist(table_name)
105-
106-
if table_exists:
107-
if mode == "error":
108-
raise PlanError(
109-
f"Cannot save to table '{table_name}' - it already exists and mode is 'error'. "
110-
f"Choose a different approach: "
111-
f"1) Use mode='overwrite' to replace the existing table, "
112-
f"2) Use mode='append' to add data to the existing table, "
113-
f"3) Use mode='ignore' to skip saving if table exists, "
114-
f"4) Use a different table name."
115-
)
116-
if mode == "ignore":
117-
logger.warning(f"Table {table_name} already exists, ignoring write.")
118-
return QueryMetrics()
119-
if mode == "append":
120-
saved_schema = self.session_state.catalog.describe_table(table_name)
121-
plan_schema = logical_plan.schema()
122-
if saved_schema != plan_schema:
123-
raise PlanError(
124-
f"Cannot append to table '{table_name}' - schema mismatch detected. "
125-
f"The existing table has a different schema than your DataFrame. "
126-
f"Existing schema: {saved_schema} "
127-
f"Your DataFrame schema: {plan_schema} "
128-
f"To fix this: "
129-
f"1) Use mode='overwrite' to replace the table with your DataFrame's schema, "
130-
f"2) Modify your DataFrame to match the existing table's schema, "
131-
f"3) Use a different table name."
132-
)
104+
table_exists, query_metrics = self._validate_table_existance(logical_plan, table_name, mode)
105+
if table_exists and query_metrics:
106+
# trunk-ignore-begin(bandit/B101)
107+
assert mode == "ignore", "only mode to fulfill this invariant is ignore."
108+
# trunk-ignore-end(bandit/B101)
109+
return query_metrics
133110
physical_plan = self.transpiler.transpile(logical_plan)
134111
try:
135112
_, metrics = physical_plan.execute()

src/fenic/_backends/utils/catalog_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,13 @@ def enrich(self, catalog_name: str, db_name: str) -> "TableIdentifier":
106106
table=self.table,
107107
)
108108

109+
def __str__(self) -> str:
# Renders the identifier as a dotted name, starting from the bare table name and
# prefixing the database and catalog parts when they are set (truthy).
# NOTE(review): indentation is not preserved in this view; the catalog check is presumably
# a sibling of the db check rather than nested inside it — confirm against the repository.
110+
str_identifier = self.table
111+
if self.db:
112+
str_identifier = f"{self.db}.{str_identifier}"
113+
if self.catalog:
114+
str_identifier = f"{self.catalog}.{str_identifier}"
115+
return str_identifier
109116

110117
@dataclass(frozen=True)
111118
class DBIdentifier(BaseIdentifier):
@@ -135,6 +142,12 @@ def enrich(self, catalog_name: str) -> "DBIdentifier":
135142
return self
136143
return DBIdentifier(catalog=catalog_name, db=self.db)
137144

145+
def __str__(self) -> str:
    """Render the identifier as ``catalog.db``, or just ``db`` when no catalog is set."""
    if not self.catalog:
        return self.db
    return f"{self.catalog}.{self.db}"
150+
138151

139152
def compare_object_names(object_name_1: str, object_name_2: str) -> bool:
140153
"""Compare two object names, ignoring case."""

src/fenic/api/execution.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Common operations for execution."""
2+
3+
import logging
4+
from typing import Literal, Optional, Tuple
5+
6+
from fenic.core._interfaces.execution import BaseExecution
7+
from fenic.core._logical_plan.plans.base import LogicalPlan
8+
from fenic.core.error import PlanError
9+
from fenic.core.metrics import QueryMetrics
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class CommonExecution(BaseExecution):
15+
"""Common class for execution operations."""
16+
# Shared check called by the save_as_table implementations: looks the table up in the
# session catalog and applies the requested write mode.
# Returns a (table_exists, query_metrics) tuple; query_metrics is only non-None on the
# "ignore" path, where the caller is expected to return it without writing.
# Raises PlanError when the table exists and mode is "error", or when mode is "append"
# and the stored schema differs from the plan's schema.
# NOTE(review): "existance" is a misspelling of "existence"; renaming it requires updating
# every caller of this method as well.
def _validate_table_existance(
17+
self,
18+
logical_plan: LogicalPlan,
19+
table_name: str,
20+
mode: Literal["error", "append", "overwrite", "ignore"],
21+
) -> Tuple[bool, Optional[QueryMetrics]]:
22+
if self.session_state.catalog.does_table_exist(table_name):
23+
# mode="error": an already existing table is a hard failure.
if mode == "error":
24+
raise PlanError(
25+
f"Cannot save to table '{table_name}' - it already exists and mode is 'error'. "
26+
f"Choose a different approach: "
27+
f"1) Use mode='overwrite' to replace the existing table, "
28+
f"2) Use mode='append' to add data to the existing table, "
29+
f"3) Use mode='ignore' to skip saving if table exists, "
30+
f"4) Use a different table name.")
31+
# mode="ignore": log a warning and hand back a default-constructed QueryMetrics.
if mode == "ignore":
32+
logger.warning(f"Table {table_name} already exists, ignoring write.")
33+
return True, QueryMetrics()
34+
# mode="append": the catalog's stored schema must equal the plan's schema.
if mode == "append":
35+
saved_schema = self.session_state.catalog.describe_table(table_name)
36+
plan_schema = logical_plan.schema()
37+
if saved_schema != plan_schema:
38+
raise PlanError(
39+
f"Cannot append to table '{table_name}' - schema mismatch detected. "
40+
f"The existing table has a different schema than your DataFrame. "
41+
f"Existing schema: {saved_schema} "
42+
f"Your DataFrame schema: {plan_schema} "
43+
f"To fix this: "
44+
f"1) Use mode='overwrite' to replace the table with your DataFrame's schema, "
45+
f"2) Modify your DataFrame to match the existing table's schema, "
46+
f"3) Use a different table name.")
47+
else:
48+
return True, None
49+
# NOTE(review): indentation is not preserved in this view, so it is unclear whether the
# branches below sit inside or outside the table-exists check — confirm against the
# repository before relying on which inputs produce (False, None).
if mode == "overwrite":
50+
return True, None
51+
else:
52+
return False, None

tests/_backends/cloud/catalog/test_cloud_catalog.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -531,14 +531,12 @@ def test_create_table(cloud_catalog, schema): # noqa: D103
531531
cloud_catalog.create_table(
532532
TEST_TABLE_NAME_1,
533533
schema=schema,
534-
location=TEST_SAMPLE_LOCATION,
535534
ignore_if_exists=False,
536535
)
537536
with pytest.raises(CatalogError):
538537
cloud_catalog.create_table(
539538
"some_catalog.some_database.some_table",
540539
schema=schema,
541-
location=TEST_SAMPLE_LOCATION,
542540
)
543541

544542

@@ -566,16 +564,16 @@ def test_drop_table(cloud_catalog): # noqa: D103
566564

567565
def test_create_catalog(cloud_catalog): # noqa: D103
568566
with pytest.raises(CatalogError):
569-
cloud_catalog.create_catalog(DEFAULT_CATALOG_NAME)
567+
cloud_catalog.create_catalog(DEFAULT_CATALOG_NAME, TEST_SAMPLE_LOCATION)
570568

571-
assert cloud_catalog.create_catalog(TEST_NEW_CATALOG_NAME)
569+
assert cloud_catalog.create_catalog(TEST_NEW_CATALOG_NAME, TEST_SAMPLE_LOCATION)
572570

573571
# The catalog already exists, so we should return False (default for ignore_if_exists is True)
574-
assert not cloud_catalog.create_catalog(TEST_CATALOG_NAME)
572+
assert not cloud_catalog.create_catalog(TEST_CATALOG_NAME, TEST_SAMPLE_LOCATION)
575573

576574
# The catalog already exists, so we should raise an error if ignore_if_exists is False
577575
with pytest.raises(CatalogAlreadyExistsError):
578-
cloud_catalog.create_catalog(TEST_CATALOG_NAME, ignore_if_exists=False)
576+
cloud_catalog.create_catalog(TEST_CATALOG_NAME, TEST_SAMPLE_LOCATION, ignore_if_exists=False)
579577

580578

581579
def test_drop_catalog(cloud_catalog): # noqa: D103

0 commit comments

Comments
 (0)