2 changes: 1 addition & 1 deletion application_sdk/activities/metadata_extraction/sql.py
@@ -850,7 +850,7 @@ async def transform_data(
dataframe=dataframe, **workflow_args
)
await transformed_output.write_daft_dataframe(transform_metadata)
return await transformed_output.get_statistics()
return await transformed_output.get_statistics(typename=typename)

Bug: Metadata Handling Flaw Breaks Data Transformation

The write_daft_dataframe call is outside the conditional block that assigns transform_metadata. When the dataframe is empty, transform_metadata is never assigned, so the write raises a NameError. Additionally, if multiple dataframes are processed and some are empty, the write incorrectly reuses the previous non-empty transformation result.
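
A minimal sketch of one possible fix, keeping the write inside the guard so it only runs for a freshly assigned transform_metadata. The surrounding loop, the dataframes variable, and the transformer.transform_metadata call are assumptions, since the full body of transform_data is not visible in this hunk; only transform_metadata, write_daft_dataframe, and get_statistics(typename=typename) appear here:

transform_metadata = None
for dataframe in dataframes:  # assumed: iterate over the extracted chunks
    if is_empty_dataframe(dataframe):
        continue  # skip empty frames; never reuse the previous result
    transform_metadata = transformer.transform_metadata(  # assumed call site
        dataframe=dataframe, **workflow_args
    )
    # Write inside the guard so the persisted output always matches this dataframe.
    await transformed_output.write_daft_dataframe(transform_metadata)
return await transformed_output.get_statistics(typename=typename)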



@activity.defn
@auto_heartbeater
98 changes: 95 additions & 3 deletions application_sdk/outputs/__init__.py
@@ -25,15 +25,17 @@
from temporalio import activity

from application_sdk.activities.common.models import ActivityStatistics
from application_sdk.activities.common.utils import get_object_store_prefix
from application_sdk.activities.common.utils import get_object_store_prefix, build_output_path
from application_sdk.common.dataframe_utils import is_empty_dataframe
from application_sdk.observability.logger_adaptor import get_logger
from application_sdk.observability.metrics_adaptor import MetricType
from application_sdk.services.objectstore import ObjectStore
from application_sdk.constants import TEMPORARY_PATH

logger = get_logger(__name__)
activity.logger = logger


if TYPE_CHECKING:
import daft # type: ignore
import pandas as pd
@@ -71,6 +73,19 @@ class Output(ABC):
current_buffer_size_bytes: int
partitions: List[int]

def _infer_phase_from_path(self) -> Optional[str]:
"""Infer phase from output path by checking for raw/transformed directories.

Returns:
Optional[str]: "Extract" for raw, "Transform" for transformed, else None.
"""
path_parts = str(self.output_path).split("/")
if "raw" in path_parts:
return "Extract"
if "transformed" in path_parts:
return "Transform"
return None

def estimate_dataframe_record_size(self, dataframe: "pd.DataFrame") -> int:
"""Estimate File size of a DataFrame by sampling a few records."""
if len(dataframe) == 0:
@@ -330,7 +345,7 @@ async def get_statistics(
Exception: If there's an error writing the statistics
"""
try:
statistics = await self.write_statistics()
statistics = await self.write_statistics(typename)
if not statistics:
raise ValueError("No statistics data available")
statistics = ActivityStatistics.model_validate(statistics)
@@ -390,7 +405,7 @@ async def _flush_buffer(self, chunk: "pd.DataFrame", chunk_part: int):
logger.error(f"Error flushing buffer to files: {str(e)}")
raise e

async def write_statistics(self) -> Optional[Dict[str, Any]]:
async def write_statistics(self, typename: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""Write statistics about the output to a JSON file.

This method writes statistics including total record count and chunk count
@@ -418,6 +433,83 @@ async def write_statistics(self) -> Optional[Dict[str, Any]]:
source=output_file_name,
destination=destination_file_path,
)

if typename:
statistics["typename"] = typename
# Update aggregated statistics at run root in object store
try:
await self._update_run_aggregate(destination_file_path, statistics)
except Exception as e:
logger.warning(f"Failed to update aggregated statistics: {str(e)}")
return statistics
except Exception as e:
logger.error(f"Error writing statistics: {str(e)}")

# TODO: Do we need locking here?
async def _update_run_aggregate(
self, per_path_destination: str, statistics: Dict[str, Any]
) -> None:
"""Aggregate stats into a single file at the workflow run root.

Args:
per_path_destination: Object store destination path for this stats file
(used as key in the aggregate map)
statistics: The statistics dictionary to store
"""
inferred_phase = self._infer_phase_from_path()
if inferred_phase is None:
logger.info("Phase could not be inferred from path. Skipping aggregation.")
return

logger.info(f"Starting _update_run_aggregate for phase: {inferred_phase}")
workflow_run_root_relative = build_output_path()
output_file_name = f"{TEMPORARY_PATH}{workflow_run_root_relative}/statistics.json.ignore"
destination_file_path = get_object_store_prefix(output_file_name)

# Load existing aggregate from object store if present
# Structure: {"Extract": {"typename": {"record_count": N}}, "Transform": {...}, "Publish": {...}}
aggregate_by_phase: Dict[str, Dict[str, Dict[str, Any]]] = {
"Extract": {},
"Transform": {},
"Publish": {}
}

try:
# Download existing aggregate file if present
await ObjectStore.download_file(
source=destination_file_path,
destination=output_file_name,
)
# Load existing JSON structure
with open(output_file_name, "r") as f:
existing_aggregate = orjson.loads(f.read())
# Phase-based structure
aggregate_by_phase.update(existing_aggregate)
logger.info(f"Successfully loaded existing aggregates")
except Exception:
logger.info(
"No existing aggregate found or failed to read. Initializing a new aggregate structure."
)

# Accumulate statistics by typename within the phase
typename = statistics.get("typename", "unknown")

if typename not in aggregate_by_phase[inferred_phase]:
aggregate_by_phase[inferred_phase][typename] = {
"record_count": 0
}

logger.info(f"Accumulating statistics for phase '{inferred_phase}', typename '{typename}': +{statistics['total_record_count']} records")

# Accumulate the record count
aggregate_by_phase[inferred_phase][typename]["record_count"] += statistics["total_record_count"]

with open(output_file_name, "w") as f:
f.write(orjson.dumps(aggregate_by_phase).decode("utf-8"))
logger.info(f"Successfully updated aggregate with accumulated stats for phase '{inferred_phase}'")

# Upload aggregate to object store
await ObjectStore.upload_file(
source=output_file_name,
destination=destination_file_path,
)
114 changes: 114 additions & 0 deletions tests/unit/outputs/test_output.py
@@ -1,6 +1,7 @@
"""Unit tests for output interface."""

from typing import Any
import json
from unittest.mock import AsyncMock, mock_open, patch

import pandas as pd
@@ -161,3 +162,116 @@ async def test_write_statistics_error(self):
assert result is None
mock_logger.assert_called_once()
assert "Error writing statistics" in mock_logger.call_args[0][0]


Bug: Error handling mismatch breaks tests.

The test expects write_statistics to return None on error, but the implementation now re-raises the exception (line 455 adds an explicit raise). Because the test still assumes the old error-swallowing behavior, it fails.
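
If write_statistics does re-raise as described, a minimal sketch of an updated test would assert the exception rather than a None return. It assumes the same class fixtures (self.output) as the surrounding tests and that pytest is importable in this module; the patched upload target mirrors the patches used elsewhere in this file:

import pytest

async def test_write_statistics_error(self):
    """write_statistics should surface upload failures to the caller."""
    with patch(
        "application_sdk.services.objectstore.ObjectStore.upload_file",
        new_callable=AsyncMock,
        side_effect=Exception("upload failed"),
    ):
        with pytest.raises(Exception, match="upload failed"):
            await self.output.write_statistics()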


async def test__update_run_aggregate_skips_when_phase_unknown(self):
"""Skips aggregation when phase cannot be inferred from output_path."""
# Ensure no 'raw' or 'transformed' in path so phase is None
self.output.output_path = "/tmp/no-phase/path"
stats = {"typename": "table", "total_record_count": 10}

with patch(
"application_sdk.services.objectstore.ObjectStore.download_file",
new_callable=AsyncMock,
) as mock_dl, patch(
"application_sdk.services.objectstore.ObjectStore.upload_file",
new_callable=AsyncMock,
) as mock_ul, patch("builtins.open", mock_open()) as m:
await self.output._update_run_aggregate("ignored", stats)
mock_dl.assert_not_awaited()
mock_ul.assert_not_awaited()
# No reads/writes when phase is unknown
assert m.call_count == 0

async def test__update_run_aggregate_creates_new_aggregate_extract(self):
"""Creates a new aggregate structure and writes stats for Extract phase."""
# Make phase inference return "Extract"
self.output.output_path = "/tmp/run/raw/path"
stats = {"typename": "table", "total_record_count": 7}

with patch(
"application_sdk.outputs.build_output_path", return_value="workflow/run"
), patch(
"application_sdk.outputs.get_object_store_prefix",
return_value="os://bucket/statistics.json.ignore",
), patch(
"application_sdk.services.objectstore.ObjectStore.download_file",
new_callable=AsyncMock,
) as mock_dl, patch(
"application_sdk.services.objectstore.ObjectStore.upload_file",
new_callable=AsyncMock,
) as mock_ul, patch("builtins.open", mock_open()) as m:
# Simulate no existing aggregate in object store
mock_dl.side_effect = Exception("not found")

await self.output._update_run_aggregate("ignored", stats)

handle = m()
# One write with the aggregated payload
write_calls = handle.write.call_args_list
assert len(write_calls) == 1
payload = write_calls[0].args[0]
data = json.loads(payload)

assert data["Extract"]["table"]["record_count"] == 7
# Other phases should exist, even if empty
assert "Transform" in data and "Publish" in data

mock_ul.assert_awaited_once()

async def test__update_run_aggregate_accumulates_existing_transform(self):
"""Accumulates total_record_count into existing Transform aggregate."""
# Make phase inference return "Transform"
self.output.output_path = "/tmp/run/transformed/path"
stats = {"typename": "table", "total_record_count": 3}

existing = {"Extract": {}, "Transform": {"table": {"record_count": 5}}, "Publish": {}}
m = mock_open(read_data=json.dumps(existing))

with patch(
"application_sdk.outputs.build_output_path", return_value="workflow/run"
), patch(
"application_sdk.outputs.get_object_store_prefix",
return_value="os://bucket/statistics.json.ignore",
), patch(
"application_sdk.services.objectstore.ObjectStore.download_file",
new_callable=AsyncMock,
) as mock_dl, patch(
"application_sdk.services.objectstore.ObjectStore.upload_file",
new_callable=AsyncMock,
) as mock_ul, patch("builtins.open", m) as mo:
await self.output._update_run_aggregate("ignored", stats)

handle = mo()
written = handle.write.call_args[0][0]
data = json.loads(written)

# 5 existing + 3 new
assert data["Transform"]["table"]["record_count"] == 8
mock_dl.assert_awaited_once()
mock_ul.assert_awaited_once()

async def test__update_run_aggregate_defaults_unknown_typename(self):
"""Uses 'unknown' typename when not provided in statistics."""
self.output.output_path = "/tmp/run/raw/path"
stats = {"total_record_count": 4} # no 'typename'

with patch(
"application_sdk.outputs.build_output_path", return_value="workflow/run"
), patch(
"application_sdk.outputs.get_object_store_prefix",
return_value="os://bucket/statistics.json.ignore",
), patch(
"application_sdk.services.objectstore.ObjectStore.download_file",
new_callable=AsyncMock,
) as mock_dl, patch(
"application_sdk.services.objectstore.ObjectStore.upload_file",
new_callable=AsyncMock,
) as mock_ul, patch("builtins.open", mock_open()) as m:
mock_dl.side_effect = Exception("not found")

await self.output._update_run_aggregate("ignored", stats)

payload = m().write.call_args[0][0]
data = json.loads(payload)
assert data["Extract"]["unknown"]["record_count"] == 4
mock_ul.assert_awaited_once()