Skip to content

Commit 8ba41cc

Browse files
committed
fix: results for handlers
1 parent 2a3e5bf commit 8ba41cc

File tree

4 files changed

+55
-31
lines changed

4 files changed

+55
-31
lines changed

application_sdk/clients/sql.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ async def run_query(self, query: str, batch_size: int = 100000):
404404
logger.info("Query execution completed")
405405

406406
def _execute_pandas_query(
407-
self, conn, query
407+
self, conn, query, chunksize: Optional[int]
408408
) -> Union["pd.DataFrame", Iterator["pd.DataFrame"]]:
409409
"""Helper function to execute SQL query using pandas.
410410
The function is responsible for using import_optional_dependency method of the pandas library to import sqlalchemy
@@ -424,13 +424,13 @@ def _execute_pandas_query(
424424
from sqlalchemy import text
425425

426426
if import_optional_dependency("sqlalchemy", errors="ignore"):
427-
return pd.read_sql_query(text(query), conn, chunksize=self.chunk_size)
427+
return pd.read_sql_query(text(query), conn, chunksize=chunksize)
428428
else:
429429
dbapi_conn = getattr(conn, "connection", None)
430-
return pd.read_sql_query(query, dbapi_conn, chunksize=self.chunk_size)
430+
return pd.read_sql_query(query, dbapi_conn, chunksize=chunksize)
431431

432432
def _read_sql_query(
433-
self, session: "Session", query: str
433+
self, session: "Session", query: str, chunksize: Optional[int]
434434
) -> Union["pd.DataFrame", Iterator["pd.DataFrame"]]:
435435
"""Execute SQL query using the provided session.
436436
@@ -442,10 +442,10 @@ def _read_sql_query(
442442
or iterator of DataFrames if chunked.
443443
"""
444444
conn = session.connection()
445-
return self._execute_pandas_query(conn, query)
445+
return self._execute_pandas_query(conn, query, chunksize=chunksize)
446446

447447
def _execute_query_daft(
448-
self, query: str
448+
self, query: str, chunksize: Optional[int]
449449
) -> Union["daft.DataFrame", Iterator["daft.DataFrame"]]:
450450
"""Execute SQL query using the provided engine and daft.
451451
@@ -462,15 +462,11 @@ def _execute_query_daft(
462462
raise ValueError("Engine is not initialized. Call load() first.")
463463

464464
if isinstance(self.engine, str):
465-
return daft.read_sql(
466-
query, self.engine, infer_schema_length=self.chunk_size
467-
)
468-
return daft.read_sql(
469-
query, self.engine.connect, infer_schema_length=self.chunk_size
470-
)
465+
return daft.read_sql(query, self.engine, infer_schema_length=chunksize)
466+
return daft.read_sql(query, self.engine.connect, infer_schema_length=chunksize)
471467

472468
def _execute_query(
473-
self, query: str
469+
self, query: str, chunksize: Optional[int]
474470
) -> Union["pd.DataFrame", Iterator["pd.DataFrame"]]:
475471
"""Execute SQL query using the provided engine and pandas.
476472
@@ -482,7 +478,7 @@ def _execute_query(
482478
raise ValueError("Engine is not initialized. Call load() first.")
483479

484480
with self.engine.connect() as conn:
485-
return self._execute_pandas_query(conn, query)
481+
return self._execute_pandas_query(conn, query, chunksize)
486482

487483
async def get_batched_results(
488484
self,
@@ -513,12 +509,14 @@ async def get_batched_results(
513509

514510
if async_session:
515511
async with async_session() as session:
516-
return await session.run_sync(self._read_sql_query, query)
512+
return await session.run_sync(
513+
self._read_sql_query, query, chunksize=self.chunk_size
514+
)
517515
else:
518516
# Run the blocking operation in a thread pool
519517
with concurrent.futures.ThreadPoolExecutor() as executor:
520518
return await asyncio.get_event_loop().run_in_executor( # type: ignore
521-
executor, self._execute_query, query
519+
executor, self._execute_query, query, self.chunk_size
522520
)
523521
except Exception as e:
524522
logger.error(f"Error reading batched data(pandas) from SQL: {str(e)}")
@@ -549,12 +547,14 @@ async def get_results(self, query: str) -> "pd.DataFrame":
549547

550548
if async_session:
551549
async with async_session() as session:
552-
return await session.run_sync(self._read_sql_query, query)
550+
return await session.run_sync(
551+
self._read_sql_query, query, chunksize=None
552+
)
553553
else:
554554
# Run the blocking operation in a thread pool
555555
with concurrent.futures.ThreadPoolExecutor() as executor:
556556
result = await asyncio.get_event_loop().run_in_executor(
557-
executor, self._execute_query, query
557+
executor, self._execute_query, query, None
558558
)
559559
import pandas as pd
560560

tests/unit/activities/metadata_extraction/test_sql.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Unit tests for SQL metadata extraction activities (context-free)."""
22

3+
import os
34
from typing import Any, Dict
45
from unittest.mock import AsyncMock, Mock, patch
56

@@ -370,7 +371,9 @@ async def test_transform_data_success(
370371
assert result is not None
371372
assert isinstance(result, ActivityStatistics)
372373
assert result.total_record_count == 20
373-
mock_download_files.assert_called_once_with("/test/path/raw", ".parquet", None)
374+
# Normalize path for cross-platform compatibility
375+
expected_path = os.path.join("/test/path", "raw")
376+
mock_download_files.assert_called_once_with(expected_path, ".parquet", None)
374377
mock_transform_metadata.assert_called_once()
375378
mock_write_daft_dataframe.assert_called_once()
376379

tests/unit/io/readers/test_json_reader.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,12 @@ async def test_download_file_invoked_for_missing_files() -> None:
6363

6464
def mock_isfile(path):
6565
# Return False for initial local check, True for downloaded files
66-
if path in ["./local/tmp/local/a.json", "./local/tmp/local/b.json"]:
66+
# Normalize paths for cross-platform comparison
67+
expected_paths = [
68+
os.path.join("./local/tmp/local", "a.json"),
69+
os.path.join("./local/tmp/local", "b.json"),
70+
]
71+
if path in expected_paths:
6772
return True
6873
return False
6974

@@ -80,12 +85,24 @@ def mock_isfile(path):
8085
result = await download_files(json_input.path, ".json", json_input.file_names)
8186

8287
# Each file should be attempted to be downloaded - using correct signature (with destination)
88+
# Normalize paths for cross-platform compatibility
8389
expected_calls = [
84-
call(source="local/a.json", destination="./local/tmp/local/a.json"),
85-
call(source="local/b.json", destination="./local/tmp/local/b.json"),
90+
call(
91+
source="local/a.json",
92+
destination=os.path.join("./local/tmp/local", "a.json"),
93+
),
94+
call(
95+
source="local/b.json",
96+
destination=os.path.join("./local/tmp/local", "b.json"),
97+
),
8698
]
8799
mock_download.assert_has_calls(expected_calls, any_order=True)
88-
assert result == ["./local/tmp/local/a.json", "./local/tmp/local/b.json"]
100+
# Normalize result paths for comparison
101+
expected_result = [
102+
os.path.join("./local/tmp/local", "a.json"),
103+
os.path.join("./local/tmp/local", "b.json"),
104+
]
105+
assert result == expected_result
89106

90107

91108
@pytest.mark.asyncio

tests/unit/io/test_base_io.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -231,17 +231,20 @@ async def test_download_files_download_specific_files_success(self):
231231
file_names = ["file1.parquet", "file2.parquet"]
232232
input_instance = MockReader(path, file_names)
233233
# Expected files will be in temporary directory after download
234+
# Normalize paths for cross-platform compatibility
234235
expected_files = [
235-
"./local/tmp/data/file1.parquet",
236-
"./local/tmp/data/file2.parquet",
236+
os.path.join("./local/tmp/data", "file1.parquet"),
237+
os.path.join("./local/tmp/data", "file2.parquet"),
237238
]
238239

239240
def mock_isfile(path):
240241
# Return False for initial local check, True for downloaded files
241-
if path in [
242-
"./local/tmp/data/file1.parquet",
243-
"./local/tmp/data/file2.parquet",
244-
]:
242+
# Normalize paths for cross-platform comparison
243+
expected_paths = [
244+
os.path.join("./local/tmp/data", "file1.parquet"),
245+
os.path.join("./local/tmp/data", "file2.parquet"),
246+
]
247+
if path in expected_paths:
245248
return True
246249
return False
247250

@@ -262,14 +265,15 @@ def mock_isfile(path):
262265
)
263266

264267
# Should download each specific file
268+
# Normalize paths for cross-platform compatibility
265269
assert mock_download.call_count == 2
266270
mock_download.assert_any_call(
267271
source="data/file1.parquet",
268-
destination="./local/tmp/data/file1.parquet",
272+
destination=os.path.join("./local/tmp/data", "file1.parquet"),
269273
)
270274
mock_download.assert_any_call(
271275
source="data/file2.parquet",
272-
destination="./local/tmp/data/file2.parquet",
276+
destination=os.path.join("./local/tmp/data", "file2.parquet"),
273277
)
274278
assert result == expected_files
275279

0 commit comments

Comments
 (0)