2 changes: 1 addition & 1 deletion application_sdk/activities/metadata_extraction/sql.py
@@ -850,7 +850,7 @@ async def transform_data(
dataframe=dataframe, **workflow_args
)
await transformed_output.write_daft_dataframe(transform_metadata)
return await transformed_output.get_statistics()
return await transformed_output.get_statistics(typename=typename)

Bug: Metadata Handling Flaw Breaks Data Transformation

The write_daft_dataframe call sits outside the conditional block that assigns transform_metadata. When the dataframe is empty, transform_metadata is never assigned and the write raises a NameError. Additionally, if multiple dataframes are processed and some are empty, the write incorrectly reuses the previous non-empty transformation result.
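
A hedged sketch of the guard this comment asks for. The full `transform_data` body is not visible in this hunk, so `transformer.transform_metadata` is a stand-in for the real transform call; `is_empty_dataframe` is the helper imported elsewhere in this PR:

```python
from application_sdk.common.dataframe_utils import is_empty_dataframe

# Illustrative fragment of the async transform_data body (stand-in names):
# keep the write inside the same guard that assigns transform_metadata, so an
# empty chunk neither raises NameError nor reuses a previous iteration's result.
transform_metadata = None
if not is_empty_dataframe(dataframe):
    transform_metadata = transformer.transform_metadata(  # stand-in for the real call
        dataframe=dataframe, **workflow_args
    )

if transform_metadata is not None:
    await transformed_output.write_daft_dataframe(transform_metadata)

return await transformed_output.get_statistics(typename=typename)
```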



@activity.defn
@auto_heartbeater
2 changes: 2 additions & 0 deletions application_sdk/decorators/__init__.py
@@ -0,0 +1,2 @@


89 changes: 89 additions & 0 deletions application_sdk/decorators/method_lock.py
@@ -0,0 +1,89 @@
from __future__ import annotations

import asyncio
from functools import wraps
from typing import Any, Awaitable, Callable, Optional

from temporalio import activity

from application_sdk.clients.redis import RedisClientAsync
from application_sdk.constants import (
APPLICATION_NAME,
IS_LOCKING_DISABLED,
)
from application_sdk.observability.logger_adaptor import get_logger

logger = get_logger(__name__)


def lock_per_run(
lock_name: Optional[str] = None, ttl_seconds: int = 10
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
"""Serialize an async method within an activity per workflow run.

    Uses Redis ``SET NX EX`` for acquisition and an owner-verified release.
    The lock key is namespaced and scoped to the current workflow run:
    ``{APPLICATION_NAME}:meth:{method_name}:run:{workflow_run_id}``.

    Args:
        lock_name: Optional explicit lock name. Defaults to the wrapped method's name.
        ttl_seconds: Lock TTL in seconds. Should cover worst-case wait + execution time.

    Returns:
        A decorator for async callables to guard them with a per-run distributed lock.
    """

def _decorate(
fn: Callable[..., Awaitable[Any]]
) -> Callable[..., Awaitable[Any]]:
@wraps(fn)
async def _wrapped(*args: Any, **kwargs: Any) -> Any:
if IS_LOCKING_DISABLED:
return await fn(*args, **kwargs)

run_id = activity.info().workflow_run_id
name = lock_name or fn.__name__

resource_id = f"{APPLICATION_NAME}:meth:{name}:run:{run_id}"
owner_id = f"{APPLICATION_NAME}:{run_id}"

async with RedisClientAsync() as rc:
# Acquire with retry
retry_count = 0
while True:
logger.debug(f"Attempting to acquire lock: {resource_id}, owner: {owner_id}")
acquired = await rc._acquire_lock(
resource_id, owner_id, ttl_seconds
)
if acquired:
logger.info(f"Lock acquired: {resource_id}, owner: {owner_id}")
break
retry_count += 1
logger.debug(
f"Lock not available, retrying (attempt {retry_count}): {resource_id}"
)
await asyncio.sleep(5)

try:
return await fn(*args, **kwargs)
finally:
# Best-effort release; TTL guarantees cleanup if this fails
try:
logger.debug(f"Releasing lock: {resource_id}, owner: {owner_id}")
released, result = await rc._release_lock(resource_id, owner_id)
if released:
logger.info(f"Lock released successfully: {resource_id}")
else:
logger.warning(
f"Lock release failed (may already be released): {resource_id}, result: {result}"
)
except Exception as e:
logger.warning(
f"Exception during lock release for {resource_id}: {e}. TTL will handle cleanup."
)

return _wrapped

return _decorate

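The `_acquire_lock` / `_release_lock` helpers on `RedisClientAsync` are not part of this diff; for context, here is a hedged sketch of the standard `SET NX EX` acquire and owner-verified Lua release pattern the docstring describes, written against plain `redis.asyncio` rather than the SDK client:

```python
import redis.asyncio as aioredis

# The owner check and delete must be atomic, hence a Lua script.
_RELEASE_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("del", KEYS[1])
end
return 0
"""


async def acquire_lock(
    client: aioredis.Redis, resource_id: str, owner_id: str, ttl_seconds: int
) -> bool:
    """Return True only if this owner newly acquired the lock (SET key value NX EX ttl)."""
    return bool(await client.set(resource_id, owner_id, nx=True, ex=ttl_seconds))


async def release_lock(client: aioredis.Redis, resource_id: str, owner_id: str) -> bool:
    """Release only if still held by this owner; expired or stolen locks are left alone."""
    return bool(await client.eval(_RELEASE_SCRIPT, 1, resource_id, owner_id))
```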

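And a hedged usage sketch. The class and method names are hypothetical; because the wrapper calls `activity.info()`, the decorated method must run inside a Temporal activity execution unless `IS_LOCKING_DISABLED` is set:

```python
from application_sdk.decorators.method_lock import lock_per_run


class RunStatisticsWriter:  # hypothetical host class, mirroring how Output uses the decorator
    @lock_per_run(lock_name="update_run_aggregate", ttl_seconds=30)
    async def update_run_aggregate(self, statistics: dict) -> None:
        # Only one call per workflow run executes this read-modify-write at a
        # time; concurrent callers poll every 5 seconds until the lock frees.
        ...
```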
100 changes: 97 additions & 3 deletions application_sdk/outputs/__init__.py
@@ -25,15 +25,18 @@
from temporalio import activity

from application_sdk.activities.common.models import ActivityStatistics
from application_sdk.activities.common.utils import get_object_store_prefix
from application_sdk.activities.common.utils import get_object_store_prefix, build_output_path
from application_sdk.common.dataframe_utils import is_empty_dataframe
from application_sdk.observability.logger_adaptor import get_logger
from application_sdk.observability.metrics_adaptor import MetricType
from application_sdk.services.objectstore import ObjectStore
from application_sdk.constants import TEMPORARY_PATH
from application_sdk.decorators.method_lock import lock_per_run

logger = get_logger(__name__)
activity.logger = logger


if TYPE_CHECKING:
import daft # type: ignore
import pandas as pd
@@ -71,6 +74,19 @@ class Output(ABC):
current_buffer_size_bytes: int
partitions: List[int]

def _infer_phase_from_path(self) -> Optional[str]:
"""Infer phase from output path by checking for raw/transformed directories.

Returns:
Optional[str]: "Extract" for raw, "Transform" for transformed, else None.
"""
path_parts = str(self.output_path).split("/")
if "raw" in path_parts:
return "Extract"
if "transformed" in path_parts:
return "Transform"
return None

def estimate_dataframe_record_size(self, dataframe: "pd.DataFrame") -> int:
"""Estimate File size of a DataFrame by sampling a few records."""
if len(dataframe) == 0:
@@ -330,7 +346,7 @@ async def get_statistics(
Exception: If there's an error writing the statistics
"""
try:
statistics = await self.write_statistics()
statistics = await self.write_statistics(typename)
if not statistics:
raise ValueError("No statistics data available")
statistics = ActivityStatistics.model_validate(statistics)
@@ -390,7 +406,7 @@ async def _flush_buffer(self, chunk: "pd.DataFrame", chunk_part: int):
logger.error(f"Error flushing buffer to files: {str(e)}")
raise e

async def write_statistics(self) -> Optional[Dict[str, Any]]:
async def write_statistics(self, typename: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""Write statistics about the output to a JSON file.

This method writes statistics including total record count and chunk count
@@ -418,6 +434,84 @@ async def write_statistics(self) -> Optional[Dict[str, Any]]:
source=output_file_name,
destination=destination_file_path,
)

if typename:
statistics["typename"] = typename
# Update aggregated statistics at run root in object store
try:
await self._update_run_aggregate(destination_file_path, statistics)
except Exception as e:
logger.warning(f"Failed to update aggregated statistics: {str(e)}")
return statistics
except Exception as e:
logger.error(f"Error writing statistics: {str(e)}")

    # TODO: Do we need locking here?
@lock_per_run()
async def _update_run_aggregate(
self, per_path_destination: str, statistics: Dict[str, Any]
) -> None:
"""Aggregate stats into a single file at the workflow run root.

Args:
per_path_destination: Object store destination path for this stats file
(used as key in the aggregate map)
statistics: The statistics dictionary to store
"""
inferred_phase = self._infer_phase_from_path()
if inferred_phase is None:
logger.info("Phase could not be inferred from path. Skipping aggregation.")
return

logger.info(f"Starting _update_run_aggregate for phase: {inferred_phase}")
workflow_run_root_relative = build_output_path()
output_file_name = f"{TEMPORARY_PATH}{workflow_run_root_relative}/statistics.json.ignore"
destination_file_path = get_object_store_prefix(output_file_name)

# Load existing aggregate from object store if present
# Structure: {"Extract": {"typename": {"record_count": N}}, "Transform": {...}, "Publish": {...}}
aggregate_by_phase: Dict[str, Dict[str, Dict[str, Any]]] = {
"Extract": {},
"Transform": {},
"Publish": {}
}

try:
# Download existing aggregate file if present
await ObjectStore.download_file(
source=destination_file_path,
destination=output_file_name,
)
# Load existing JSON structure
with open(output_file_name, "r") as f:
existing_aggregate = orjson.loads(f.read())
# Phase-based structure
aggregate_by_phase.update(existing_aggregate)
            logger.info("Successfully loaded existing aggregates")
except Exception:
logger.info(
"No existing aggregate found or failed to read. Initializing a new aggregate structure."
)

# Accumulate statistics by typename within the phase
typename = statistics.get("typename", "unknown")

if typename not in aggregate_by_phase[inferred_phase]:
aggregate_by_phase[inferred_phase][typename] = {
"record_count": 0
}

logger.info(f"Accumulating statistics for phase '{inferred_phase}', typename '{typename}': +{statistics['total_record_count']} records")

# Accumulate the record count
aggregate_by_phase[inferred_phase][typename]["record_count"] += statistics["total_record_count"]

with open(output_file_name, "w") as f:
f.write(orjson.dumps(aggregate_by_phase).decode("utf-8"))
logger.info(f"Successfully updated aggregate with accumulated stats for phase '{inferred_phase}'")

# Upload aggregate to object store
await ObjectStore.upload_file(
source=output_file_name,
destination=destination_file_path,
)
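
For completeness, a hedged sketch of reading the aggregate back. `read_run_aggregate` and its path arguments are hypothetical; in the SDK the real paths come from `build_output_path()` and `get_object_store_prefix()` exactly as in `_update_run_aggregate` above:

```python
import orjson

from application_sdk.services.objectstore import ObjectStore


async def read_run_aggregate(remote_path: str, local_path: str) -> dict:
    """Download the per-run statistics.json.ignore file and parse it."""
    await ObjectStore.download_file(source=remote_path, destination=local_path)
    with open(local_path, "r") as f:
        return orjson.loads(f.read())

# Expected shape, matching what _update_run_aggregate writes:
# {"Extract": {"TABLE": {"record_count": 120}}, "Transform": {...}, "Publish": {}}
```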