Skip to content

Commit 51160d0

Browse files
committed
fix(asset): handle asset ref cases
1 parent ede9241 commit 51160d0

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

task-sdk/src/airflow/sdk/execution_time/context.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -428,9 +428,9 @@ def __hash__(self):
428428

429429

430430
class _AssetRefResolutionMixin:
431-
_asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {}
431+
_asset_ref_cache: dict[AssetRef, tuple[AssetUniqueKey, dict[str, JsonValue]]] = {}
432432

433-
def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey:
433+
def _resolve_asset_ref(self, ref: AssetRef) -> tuple[AssetUniqueKey, dict[str, JsonValue]]:
434434
with contextlib.suppress(KeyError):
435435
return self._asset_ref_cache[ref]
436436

@@ -445,8 +445,8 @@ def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey:
445445
raise TypeError(f"Unimplemented asset ref: {type(ref)}")
446446
unique_key = AssetUniqueKey.from_asset(asset)
447447
for ref in refs_to_cache:
448-
self._asset_ref_cache[ref] = unique_key
449-
return unique_key
448+
self._asset_ref_cache[ref] = (unique_key, asset.extra)
449+
return (unique_key, asset.extra)
450450

451451
# TODO: This is temporary to avoid code duplication between here & airflow/models/taskinstance.py
452452
@staticmethod
@@ -491,15 +491,16 @@ def add(self, asset: Asset | AssetRef, extra: dict[str, JsonValue] | None = None
491491
return
492492

493493
if isinstance(asset, AssetRef):
494-
asset_key = self._resolve_asset_ref(asset)
494+
asset_key, asset_extra = self._resolve_asset_ref(asset)
495495
else:
496496
asset_key = AssetUniqueKey.from_asset(asset)
497+
asset_extra = asset.extra
497498

498499
asset_alias_name = self.key.name
499500
event = AssetAliasEvent(
500501
source_alias_name=asset_alias_name,
501502
dest_asset_key=asset_key,
502-
dest_asset_extra=asset.extra if isinstance(asset, Asset) else {},
503+
dest_asset_extra=asset_extra,
503504
extra=extra or {},
504505
)
505506
self.asset_alias_events.append(event)
@@ -560,7 +561,7 @@ def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> OutletEventAccessor
560561
elif isinstance(key, AssetAlias):
561562
hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
562563
elif isinstance(key, AssetRef):
563-
hashable_key = self._resolve_asset_ref(key)
564+
hashable_key, _ = self._resolve_asset_ref(key)
564565
else:
565566
raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}")
566567

@@ -769,7 +770,7 @@ def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> Sequence[AssetEvent
769770
if isinstance(key, Asset):
770771
hashable_key = AssetUniqueKey.from_asset(key)
771772
elif isinstance(key, AssetRef):
772-
hashable_key = self._resolve_asset_ref(key)
773+
hashable_key, _ = self._resolve_asset_ref(key)
773774
elif isinstance(key, AssetAlias):
774775
hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
775776
else:

task-sdk/tests/task_sdk/execution_time/test_context.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -346,12 +346,13 @@ def test_nested_context(self):
346346

347347
class TestOutletEventAccessor:
348348
@pytest.mark.parametrize(
349-
"add_arg",
349+
"add_args",
350350
[
351-
Asset("name", "uri"),
352-
Asset.ref(name="name"),
353-
Asset.ref(uri="uri"),
351+
(Asset("name", "uri", extra={"extra": "from asset itself"}), {"extra": "from event"}),
352+
(Asset.ref(name="name"), {"extra": "from event"}),
353+
(Asset.ref(uri="uri"), {"extra": "from event"}),
354354
],
355+
ids=["asset", "asset name ref", "asset uri ref"],
355356
)
356357
@pytest.mark.parametrize(
357358
("key", "asset_alias_events"),
@@ -363,18 +364,21 @@ class TestOutletEventAccessor:
363364
AssetAliasEvent(
364365
source_alias_name="test_alias",
365366
dest_asset_key=AssetUniqueKey(name="name", uri="uri"),
366-
dest_asset_extra={},
367-
extra={},
367+
dest_asset_extra={"extra": "from asset itself"},
368+
extra={"extra": "from event"},
368369
)
369370
],
370371
),
371372
),
373+
ids=["inactive asset", "active asset"],
372374
)
373-
def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms):
374-
mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="")
375+
def test_add(self, add_args, key, asset_alias_events, mock_supervisor_comms):
376+
mock_supervisor_comms.send.return_value = AssetResponse(
377+
name="name", uri="uri", group="", extra={"extra": "from asset itself"}
378+
)
375379

376380
outlet_event_accessor = OutletEventAccessor(key=key, extra={})
377-
outlet_event_accessor.add(add_arg)
381+
outlet_event_accessor.add(*add_args)
378382
assert outlet_event_accessor.asset_alias_events == asset_alias_events
379383

380384
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)