Skip to content

Commit 5906010

Browse files
fix: make catalog use sc method to implement does_table_exist
1 parent 356086a commit 5906010

File tree

3 files changed

+119
-76
lines changed

3 files changed

+119
-76
lines changed

src/fenic/_backends/cloud/catalog.py

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
TypedefCatalogTypeReferenceEnum,
1818
)
1919
from fenic_cloud.hasura_client.generated_graphql_client.input_types import (
20-
CatalogNamespaceInsertInput,
2120
CreateTableInput,
2221
NestedFieldInput,
2322
SchemaInput,
@@ -30,7 +29,6 @@
3029
)
3130

3231
from fenic._backends.cloud.manager import CloudSessionManager
33-
from fenic._backends.cloud.session_state import CloudSessionState
3432
from fenic._backends.local.catalog import (
3533
DEFAULT_CATALOG_NAME,
3634
DEFAULT_DATABASE_NAME,
@@ -70,16 +68,21 @@ class CloudCatalog(BaseCatalog):
7068
all table reads and writes should go through this class for unified table name canonicalization.
7169
"""
7270

73-
def __init__(self, session_state: CloudSessionState, cloud_session_manager: CloudSessionManager):
71+
def __init__(self,
72+
ephemeral_catalog_id: str,
73+
asyncio_loop: asyncio.AbstractEventLoop,
74+
cloud_session_manager: CloudSessionManager):
7475
"""Initialize the remote catalog."""
75-
self.session_state = session_state
76+
self.cloud_session_manager = cloud_session_manager
7677
self.lock = threading.Lock()
77-
self.current_catalog_id: UUID = UUID(session_state.ephemeral_catalog_id)
78+
self.asyncio_loop = asyncio_loop
79+
self.ephemeral_catalog_id: UUID = UUID(ephemeral_catalog_id)
80+
self.current_catalog_id: UUID = self.ephemeral_catalog_id
7881
self.current_catalog_name: str = DEFAULT_CATALOG_NAME
7982
self.current_database_name: str = DEFAULT_DATABASE_NAME
80-
self.user_id = cloud_session_manager.user_id
81-
self.organization_id = cloud_session_manager.organization_id
82-
self.user_client = cloud_session_manager.hasura_user_client
83+
self.user_id = self.cloud_session_manager._client_id
84+
self.organization_id = self.cloud_session_manager._organization_id
85+
self.user_client = self.cloud_session_manager.hasura_user_client
8386

8487
def does_catalog_exist(self, catalog_name: str) -> bool:
8588
"""Checks if a catalog with the specified name exists."""
@@ -129,7 +132,6 @@ def create_catalog(self, catalog_name: str, ignore_if_exists: bool = True) -> bo
129132
parent_organization_id=UUID(self.organization_id),
130133
catalog_type=TypedefCatalogTypeReferenceEnum.INTERNAL_TYPEDEF,
131134
catalog_warehouse="",
132-
catalog_description=None,
133135
)
134136
)
135137
return True
@@ -342,14 +344,12 @@ def _create_database(
342344
raise DatabaseAlreadyExistsError(database_name)
343345

344346
self._execute_catalog_command(
345-
self.user_client.create_namespace(
346-
namespace=CatalogNamespaceInsertInput(
347-
name=db_identifier.db,
348-
canonical_name=db_identifier.db.casefold(),
349-
parent_organization_id=self.organization_id,
350-
catalog_id=self.current_catalog_id,
351-
created_by_user_id=self.user_id,
352-
)
347+
self.user_client.sc_create_namespace(
348+
dispatch=self._get_catalog_dispatch_input(self.current_catalog_id),
349+
name=db_identifier.db,
350+
canonical_name=db_identifier.db.casefold(),
351+
description=None,
352+
properties=[],
353353
)
354354
)
355355
return True
@@ -365,8 +365,8 @@ def _does_table_exist(
365365
if not self._does_database_exist(catalog_name, db_name):
366366
return False
367367

368-
tables = self._get_tables_for_database(catalog_name, db_name)
369-
return any(compare_object_names(table, table_name) for table in tables)
368+
table = self._get_table(catalog_name, db_name, table_name)
369+
return table is not None
370370

371371
def _set_current_catalog(self, catalog_name: str) -> None:
372372
if not catalog_name:
@@ -376,7 +376,7 @@ def _set_current_catalog(self, catalog_name: str) -> None:
376376
return
377377

378378
if compare_object_names(catalog_name, DEFAULT_CATALOG_NAME):
379-
self.current_catalog_id = UUID(self.session_state.ephemeral_catalog_id)
379+
self.current_catalog_id = self.ephemeral_catalog_id
380380
self.current_catalog_name = DEFAULT_CATALOG_NAME
381381
return
382382

@@ -421,7 +421,7 @@ def _does_database_exist(self, catalog_name: str, database_name: str) -> bool:
421421

422422
def _execute_catalog_command(self, command: Coroutine[Any, Any, Any]) -> Any:
423423
return asyncio.run_coroutine_threadsafe(
424-
command, self.session_state.asyncio_loop
424+
command, self.asyncio_loop
425425
).result()
426426

427427
def _get_catalog_by_name(self, catalog_name: str) -> Optional[CatalogKey]:
@@ -481,18 +481,35 @@ def _get_tables_for_database(
481481
)
482482
return [dataset.name for dataset in result.catalog_dataset]
483483

484+
def _get_table(
485+
self,
486+
catalog_name: str,
487+
db_name: str,
488+
table_name: str,
489+
ignore_if_not_exists: bool = True,
490+
) -> LoadTableSimpleCatalogLoadTable:
491+
catalog_id = self._get_catalog_id(catalog_name)
492+
try:
493+
result = self._execute_catalog_command(
494+
self.user_client.load_table(
495+
dispatch=self._get_catalog_dispatch_input(catalog_id),
496+
namespace=db_name,
497+
name=table_name,
498+
)
499+
)
500+
return result.simple_catalog.load_table
501+
except Exception as e:
502+
if ignore_if_not_exists:
503+
return None
504+
logger.debug(f"Error getting table {table_name} from catalog {catalog_name} and database {db_name}: {e}")
505+
raise e
506+
507+
484508
def _get_table_details(
485509
self, catalog_name: str, db_name: str, table_name: str
486510
) -> Schema:
487-
catalog_id = self._get_catalog_id(catalog_name)
488-
result = self._execute_catalog_command(
489-
self.user_client.load_table(
490-
dispatch=self._get_catalog_dispatch_input(catalog_id),
491-
namespace=db_name,
492-
name=table_name,
493-
)
494-
)
495-
return self._get_table_schema(result.simple_catalog.load_table)
511+
load_table = self._get_table(catalog_name, db_name, table_name, ignore_if_not_exists=False)
512+
return self._get_table_schema(load_table)
496513

497514
def _get_catalog_dispatch_input(
498515
self,

src/fenic/_backends/cloud/session_state.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from fenic._backends.cloud.execution import CloudExecution
2626
from fenic._backends.cloud.settings import CloudSettings
2727
from fenic.core._interfaces import BaseSessionState
28+
from fenic.core._interfaces.catalog import BaseCatalog
2829
from fenic.core._resolved_session_config import (
2930
CloudExecutorSize,
3031
ResolvedSessionConfig,
@@ -61,6 +62,7 @@ class CloudSessionState(BaseSessionState):
6162
session_id: Optional[str] = None
6263
session_name: Optional[str] = None
6364
session_canonical_name: Optional[str] = None
65+
cloud_catalog: BaseCatalog = None
6466

6567
def __init__(
6668
self,
@@ -119,7 +121,16 @@ def execution(self) -> CloudExecution:
119121

120122
@property
121123
def catalog(self):
122-
pass
124+
from fenic._backends.cloud.catalog import CloudCatalog
125+
from fenic._backends.cloud.manager import CloudSessionManager
126+
127+
if self.cloud_catalog is None:
128+
self.cloud_catalog = CloudCatalog(
129+
ephemeral_catalog_id=self.ephemeral_catalog_id,
130+
asyncio_loop=self.asyncio_loop,
131+
cloud_session_manager=CloudSessionManager(),
132+
)
133+
return self.cloud_catalog
123134

124135
# properties and methods referencing dynamic state managed by the CloudSessionManager
125136
@property
@@ -231,6 +242,7 @@ async def _entrypoint_get_or_create_session_engine(self):
231242
self.session_canonical_name = response.canonical_name
232243
self.engine_uri = response.uris.remote_actions_uri
233244
self.arrow_ipc_uri = response.uris.remote_results_uri_prefix
245+
self.ephemeral_catalog_id = response.ephemeral_catalog_id
234246
logger.info(
235247
f"{'Found' if existing else 'Created'} Executor with session_id: {self.session_uuid}"
236248
)

tests/_backends/cloud/catalog/test_cloud_catalog.py

Lines changed: 59 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
pytest.importorskip("fenic_cloud")
1212
from fenic_cloud.hasura_client.generated_graphql_client import Client
13+
from fenic_cloud.hasura_client.generated_graphql_client.client import (
14+
CatalogDispatchInput,
15+
)
1316
from fenic_cloud.hasura_client.generated_graphql_client.create_namespace import (
1417
CreateNamespace,
1518
CreateNamespaceInsertCatalogNamespaceOne,
@@ -103,6 +106,7 @@
103106
TEST_NEW_TABLE_NAME = "new_table"
104107
TEST_SAMPLE_LOCATION = "s3://test-bucket/test-path"
105108
TEST_NEW_CATALOG_NAME = "new_catalog"
109+
TEST_NONEXISTENT_IDENTIFIER = "nonexistent"
106110

107111

108112
@pytest.fixture(scope="function")
@@ -121,6 +125,55 @@ def schema(): # noqa: D103
121125
]
122126
)
123127

128+
def _load_table_side_effect(
129+
dispatch: CatalogDispatchInput,
130+
namespace: str,
131+
name: str) -> LoadTable:
132+
if TEST_NONEXISTENT_IDENTIFIER in name or name in [TEST_NEW_TABLE_NAME]:
133+
raise Exception("Table not found")
134+
return LoadTable(
135+
simple_catalog=LoadTableSimpleCatalog(
136+
load_table=LoadTableSimpleCatalogLoadTable(
137+
created_at=datetime.now(),
138+
updated_at=datetime.now(),
139+
schema_=SimpleCatalogTableDetailsSchema(
140+
schema_id=1,
141+
identifier_field_ids=[1, 2, 3],
142+
fields=[
143+
SimpleCatalogSchemaDetailsFields(
144+
id=1,
145+
name="id",
146+
data_type="int64",
147+
arrow_data_type="int64",
148+
nullable=False,
149+
metadata=None,
150+
),
151+
SimpleCatalogSchemaDetailsFields(
152+
id=2,
153+
name="name",
154+
data_type="string",
155+
arrow_data_type="string",
156+
nullable=False,
157+
metadata=None,
158+
),
159+
SimpleCatalogSchemaDetailsFields(
160+
id=3,
161+
name="account_balance",
162+
data_type="Decimal128",
163+
arrow_data_type="Decimal128",
164+
nullable=False,
165+
metadata=None,
166+
),
167+
],
168+
),
169+
name=TEST_TABLE_NAME_1,
170+
location=None,
171+
external=True,
172+
file_format=None,
173+
partition_field_names=None,
174+
),
175+
),
176+
)
124177

125178
@pytest.fixture
126179
def mock_user_client(schema):
@@ -181,49 +234,7 @@ def mock_user_client(schema):
181234
)
182235
)
183236

184-
user_client.load_table.return_value = LoadTable(
185-
simple_catalog=LoadTableSimpleCatalog(
186-
load_table=LoadTableSimpleCatalogLoadTable(
187-
created_at=datetime.now(),
188-
updated_at=datetime.now(),
189-
schema_=SimpleCatalogTableDetailsSchema(
190-
schema_id=1,
191-
identifier_field_ids=[1, 2, 3],
192-
fields=[
193-
SimpleCatalogSchemaDetailsFields(
194-
id=1,
195-
name="id",
196-
data_type="int64",
197-
arrow_data_type="int64",
198-
nullable=False,
199-
metadata=None,
200-
),
201-
SimpleCatalogSchemaDetailsFields(
202-
id=2,
203-
name="name",
204-
data_type="string",
205-
arrow_data_type="string",
206-
nullable=False,
207-
metadata=None,
208-
),
209-
SimpleCatalogSchemaDetailsFields(
210-
id=3,
211-
name="account_balance",
212-
data_type="Decimal128",
213-
arrow_data_type="Decimal128",
214-
nullable=False,
215-
metadata=None,
216-
),
217-
],
218-
),
219-
name=TEST_TABLE_NAME_1,
220-
location=None,
221-
external=True,
222-
file_format=None,
223-
partition_field_names=None,
224-
),
225-
),
226-
)
237+
user_client.load_table.side_effect = _load_table_side_effect
227238

228239
user_client.create_namespace.return_value = CreateNamespace(
229240
insert_catalog_namespace_one=CreateNamespaceInsertCatalogNamespaceOne(
@@ -590,9 +601,12 @@ def _init_cloud_catalog(
590601
) -> Any:
591602
os.environ["TYPEDEF_USER_ID"] = "mock_user_id"
592603
os.environ["TYPEDEF_USER_SECRET"] = "mock_user_secret" # nosec B105
593-
os.environ["HASURA_GRAPHQL_ADMIN_SECRET"] = "mock_admin_secret" # nosec B105
594604
os.environ["REMOTE_SESSION_AUTH_PROVIDER_URI"] = "mock_auth_provider_uri"
595-
cloud_catalog = CloudCatalog(session_state, cloud_session_manager)
605+
cloud_catalog = CloudCatalog(
606+
ephemeral_catalog_id=session_state.ephemeral_catalog_id,
607+
asyncio_loop=session_state.asyncio_loop,
608+
cloud_session_manager=cloud_session_manager,
609+
)
596610
cloud_catalog.user_client = client
597611
cloud_catalog.user_id = TEST_DEFAULT_USER_ID
598612
cloud_catalog.organization_id = TEST_DEFAULT_ORGANIZATION_ID

0 commit comments

Comments
 (0)