44from dataclasses import dataclass
55from datetime import datetime
66from typing import Any , Coroutine , Dict , List , Optional
7+ from urllib .parse import urlparse
78from uuid import UUID
89
910import polars as pl
5657
5758logger = logging .getLogger (__name__ )
5859
59-
@dataclass(frozen=True)
class CatalogKey:
    """Immutable pairing of a catalog's name with its UUID.

    Declared with ``frozen=True``, so fields cannot be reassigned after
    construction and instances compare and hash by field value.
    """

    # Name of the catalog.
    catalog_name: str
    # UUID of the catalog.
    catalog_id: UUID
6464
65+ CLOUD_SUPPORTED_SCHEMES = ["s3" ]
6566
6667class CloudCatalog (BaseCatalog ):
6768 """A catalog for cloud execution mode. Implements the BaseCatalog -
@@ -112,11 +113,18 @@ def list_catalogs(self) -> List[str]:
112113 remote_catalogs .append (DEFAULT_CATALOG_NAME )
113114 return remote_catalogs
114115
115- def create_catalog (self , catalog_name : str , ignore_if_exists : bool = True ) -> bool :
116+ def create_catalog (
117+ self ,
118+ catalog_name : str ,
119+ location : str ,
120+ ignore_if_exists : bool = True ) -> bool :
116121 """Create a new catalog."""
117122 if compare_object_names (catalog_name , DEFAULT_CATALOG_NAME ):
118123 raise CatalogError ("Cannot create a catalog with the default name" )
119124
125+ if urlparse (location ).scheme not in CLOUD_SUPPORTED_SCHEMES :
126+ raise CatalogError (f"Unsupported scheme: { urlparse (location ).scheme } " )
127+
120128 with self .lock :
121129 if self ._does_catalog_exist (catalog_name ):
122130 if ignore_if_exists :
@@ -131,7 +139,7 @@ def create_catalog(self, catalog_name: str, ignore_if_exists: bool = True) -> bo
131139 created_by_user_id = UUID (self .user_id ),
132140 parent_organization_id = UUID (self .organization_id ),
133141 catalog_type = TypedefCatalogTypeReferenceEnum .INTERNAL_TYPEDEF ,
134- catalog_warehouse = "" ,
142+ catalog_warehouse = location ,
135143 )
136144 )
137145 return True
@@ -267,14 +275,12 @@ def create_table(
267275 self ,
268276 table_name : str ,
269277 schema : Schema ,
270- location : str ,
271278 ignore_if_exists : bool = True ,
272- file_format : Optional [str ] = None ,
273279 ) -> bool :
274280 """Create a new table in the current database."""
275281 with self .lock :
276282 return self ._create_table (
277- table_name , schema , location , ignore_if_exists , file_format
283+ table_name , schema , ignore_if_exists
278284 )
279285
280286
@@ -526,9 +532,7 @@ def _create_table(
526532 self ,
527533 table_name : str ,
528534 schema : Schema ,
529- location : str ,
530535 ignore_if_exists : bool = True ,
531- file_format : Optional [str ] = None ,
532536 ) -> bool :
533537 table_identifier = TableIdentifier .from_string (table_name ).enrich (
534538 self .current_catalog_name , self .current_database_name
@@ -557,11 +561,6 @@ def _create_table(
557561 raise TableAlreadyExistsError (table_identifier .table , table_identifier .db )
558562
559563 catalog_id = self ._get_catalog_id (table_identifier .catalog )
560- fixed_file_format = (
561- FileFormat .PARQUET
562- if file_format is None
563- else FileFormat (file_format .upper ())
564- )
565564 self ._execute_catalog_command (
566565 self .user_client .sc_create_table (
567566 dispatch = self ._get_catalog_dispatch_input (catalog_id ),
@@ -571,8 +570,8 @@ def _create_table(
571570 canonical_name = table_identifier .table .casefold (),
572571 description = None ,
573572 external = False ,
574- location = location ,
575- file_format = fixed_file_format ,
573+ location = self . _get_table_location_from_table_identifier ( table_identifier ) ,
574+ file_format = FileFormat . PARQUET ,
576575 partition_field_names = [],
577576 schema_ = self ._get_schema_input_from_schema (schema ),
578577 ),
@@ -661,3 +660,8 @@ def _get_schema_type_to_pyarrow(schema_type: str):
661660 return pa .float64 ()
662661 else :
663662 return schema_type
663+
@staticmethod
def _get_table_location_from_table_identifier(
    table_identifier: TableIdentifier,
) -> str:
    """Build a table's location string in the form ``<db>/<table>``.

    Args:
        table_identifier: Identifier whose ``db`` and ``table`` attributes
            are read to build the location.

    Returns:
        The ``db`` and ``table`` values joined by a single forward slash.
    """
    database = table_identifier.db
    table = table_identifier.table
    return f"{database}/{table}"
0 commit comments