Skip to content

Commit b2f7efe

Browse files
[v3-1-test] fix(asset-alias): Preserve Asset.extra when using AssetAlias (#58038) (#58712)
Co-authored-by: Wei Lee <[email protected]>
1 parent 0522dc0 commit b2f7efe

File tree

6 files changed

+58
-35
lines changed

6 files changed

+58
-35
lines changed

airflow-core/src/airflow/models/taskinstance.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,7 +1381,7 @@ def register_asset_changes_in_db(
13811381
session=session,
13821382
)
13831383

1384-
def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, str], set[str]]:
1384+
def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, str, str], set[str]]:
13851385
d = defaultdict(set)
13861386
for event in outlet_events:
13871387
try:
@@ -1391,31 +1391,38 @@ def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, str], set[s
13911391
if alias_name not in outlet_alias_names:
13921392
continue
13931393
asset_key = AssetUniqueKey(**event["dest_asset_key"])
1394-
extra_json = json.dumps(event["extra"], sort_keys=True)
1395-
d[asset_key, extra_json].add(alias_name)
1394+
# fallback for backward compatibility
1395+
asset_extra_json = json.dumps(event.get("dest_asset_extra", {}), sort_keys=True)
1396+
asset_event_extra_json = json.dumps(event["extra"], sort_keys=True)
1397+
d[asset_key, asset_extra_json, asset_event_extra_json].add(alias_name)
13961398
return d
13971399

13981400
outlet_alias_names = {o.name for o in task_outlets if o.type == AssetAlias.__name__ and o.name}
13991401
if outlet_alias_names and (event_extras_from_aliases := _asset_event_extras_from_aliases()):
1400-
for (asset_key, extra_json), event_aliase_names in event_extras_from_aliases.items():
1401-
extra = json.loads(extra_json)
1402+
for (
1403+
asset_key,
1404+
asset_extra_json,
1405+
asset_event_extras_json,
1406+
), event_aliase_names in event_extras_from_aliases.items():
1407+
asset_event_extra = json.loads(asset_event_extras_json)
1408+
asset = Asset(name=asset_key.name, uri=asset_key.uri, extra=json.loads(asset_extra_json))
14021409
ti.log.debug("register event for asset %s with aliases %s", asset_key, event_aliase_names)
14031410
event = asset_manager.register_asset_change(
14041411
task_instance=ti,
1405-
asset=asset_key,
1412+
asset=asset,
14061413
source_alias_names=event_aliase_names,
1407-
extra=extra,
1414+
extra=asset_event_extra,
14081415
session=session,
14091416
)
14101417
if event is None:
14111418
ti.log.info("Dynamically creating AssetModel %s", asset_key)
1412-
session.add(AssetModel(name=asset_key.name, uri=asset_key.uri))
1419+
session.add(AssetModel.from_public(asset))
14131420
session.flush() # So event can set up its asset fk.
14141421
asset_manager.register_asset_change(
14151422
task_instance=ti,
1416-
asset=asset_key,
1423+
asset=asset,
14171424
source_alias_names=event_aliase_names,
1418-
extra=extra,
1425+
extra=asset_event_extra,
14191426
session=session,
14201427
)
14211428

airflow-core/src/airflow/serialization/serialized_objects.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,8 @@ def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor:
398398
dest_asset_key=AssetUniqueKey(
399399
name=e["dest_asset_key"]["name"], uri=e["dest_asset_key"]["uri"]
400400
),
401+
# fallback for backward compatibility
402+
dest_asset_extra=e.get("dest_asset_extra", {}),
401403
extra=e["extra"],
402404
)
403405
for e in asset_alias_events

airflow-core/tests/unit/serialization/test_serialized_objects.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,8 @@ def __len__(self) -> int:
409409
AssetAliasEvent(
410410
source_alias_name="test_alias",
411411
dest_asset_key=AssetUniqueKey(name="test_name", uri="test://asset-uri"),
412-
extra={},
412+
dest_asset_extra={"extra": "from asset itself"},
413+
extra={"extra": "from event"},
413414
)
414415
],
415416
),

task-sdk/src/airflow/sdk/definitions/asset/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,4 +689,5 @@ class AssetAliasEvent(attrs.AttrsInstance):
689689

690690
source_alias_name: str
691691
dest_asset_key: AssetUniqueKey
692+
dest_asset_extra: dict[str, JsonValue]
692693
extra: dict[str, JsonValue]

task-sdk/src/airflow/sdk/execution_time/context.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -425,9 +425,9 @@ def __hash__(self):
425425

426426

427427
class _AssetRefResolutionMixin:
428-
_asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {}
428+
_asset_ref_cache: dict[AssetRef, tuple[AssetUniqueKey, dict[str, JsonValue]]] = {}
429429

430-
def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey:
430+
def _resolve_asset_ref(self, ref: AssetRef) -> tuple[AssetUniqueKey, dict[str, JsonValue]]:
431431
with contextlib.suppress(KeyError):
432432
return self._asset_ref_cache[ref]
433433

@@ -442,8 +442,8 @@ def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey:
442442
raise TypeError(f"Unimplemented asset ref: {type(ref)}")
443443
unique_key = AssetUniqueKey.from_asset(asset)
444444
for ref in refs_to_cache:
445-
self._asset_ref_cache[ref] = unique_key
446-
return unique_key
445+
self._asset_ref_cache[ref] = (unique_key, asset.extra)
446+
return (unique_key, asset.extra)
447447

448448
# TODO: This is temporary to avoid code duplication between here & airflow/models/taskinstance.py
449449
@staticmethod
@@ -488,14 +488,16 @@ def add(self, asset: Asset | AssetRef, extra: dict[str, JsonValue] | None = None
488488
return
489489

490490
if isinstance(asset, AssetRef):
491-
asset_key = self._resolve_asset_ref(asset)
491+
asset_key, asset_extra = self._resolve_asset_ref(asset)
492492
else:
493493
asset_key = AssetUniqueKey.from_asset(asset)
494+
asset_extra = asset.extra
494495

495496
asset_alias_name = self.key.name
496497
event = AssetAliasEvent(
497498
source_alias_name=asset_alias_name,
498499
dest_asset_key=asset_key,
500+
dest_asset_extra=asset_extra,
499501
extra=extra or {},
500502
)
501503
self.asset_alias_events.append(event)
@@ -556,7 +558,7 @@ def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> OutletEventAccessor
556558
elif isinstance(key, AssetAlias):
557559
hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
558560
elif isinstance(key, AssetRef):
559-
hashable_key = self._resolve_asset_ref(key)
561+
hashable_key, _ = self._resolve_asset_ref(key)
560562
else:
561563
raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}")
562564

@@ -684,7 +686,7 @@ def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> Sequence[AssetEvent
684686
if isinstance(key, Asset):
685687
hashable_key = AssetUniqueKey.from_asset(key)
686688
elif isinstance(key, AssetRef):
687-
hashable_key = self._resolve_asset_ref(key)
689+
hashable_key, _ = self._resolve_asset_ref(key)
688690
elif isinstance(key, AssetAlias):
689691
hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
690692
else:

task-sdk/tests/task_sdk/execution_time/test_context.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -346,12 +346,13 @@ def test_nested_context(self):
346346

347347
class TestOutletEventAccessor:
348348
@pytest.mark.parametrize(
349-
"add_arg",
349+
"add_args",
350350
[
351-
Asset("name", "uri"),
352-
Asset.ref(name="name"),
353-
Asset.ref(uri="uri"),
351+
(Asset("name", "uri", extra={"extra": "from asset itself"}), {"extra": "from event"}),
352+
(Asset.ref(name="name"), {"extra": "from event"}),
353+
(Asset.ref(uri="uri"), {"extra": "from event"}),
354354
],
355+
ids=["asset", "asset name ref", "asset uri ref"],
355356
)
356357
@pytest.mark.parametrize(
357358
"key, asset_alias_events",
@@ -363,26 +364,31 @@ class TestOutletEventAccessor:
363364
AssetAliasEvent(
364365
source_alias_name="test_alias",
365366
dest_asset_key=AssetUniqueKey(name="name", uri="uri"),
366-
extra={},
367+
dest_asset_extra={"extra": "from asset itself"},
368+
extra={"extra": "from event"},
367369
)
368370
],
369371
),
370372
),
373+
ids=["inactive asset", "active asset"],
371374
)
372-
def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms):
373-
mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="")
375+
def test_add(self, add_args, key, asset_alias_events, mock_supervisor_comms):
376+
mock_supervisor_comms.send.return_value = AssetResponse(
377+
name="name", uri="uri", group="", extra={"extra": "from asset itself"}
378+
)
374379

375380
outlet_event_accessor = OutletEventAccessor(key=key, extra={})
376-
outlet_event_accessor.add(add_arg)
381+
outlet_event_accessor.add(*add_args)
377382
assert outlet_event_accessor.asset_alias_events == asset_alias_events
378383

379384
@pytest.mark.parametrize(
380-
"add_arg",
385+
"add_args",
381386
[
382-
Asset("name", "uri"),
383-
Asset.ref(name="name"),
384-
Asset.ref(uri="uri"),
387+
(Asset(name="name", uri="uri", extra={"extra": "from asset itself"}), {"extra": "from event"}),
388+
(Asset.ref(name="name"), {"extra": "from event"}),
389+
(Asset.ref(uri="uri"), {"extra": "from event"}),
385390
],
391+
ids=["asset", "asset name ref", "asset uri ref"],
386392
)
387393
@pytest.mark.parametrize(
388394
"key, asset_alias_events",
@@ -394,17 +400,21 @@ def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms):
394400
AssetAliasEvent(
395401
source_alias_name="test_alias",
396402
dest_asset_key=AssetUniqueKey(name="name", uri="uri"),
397-
extra={},
403+
dest_asset_extra={"extra": "from asset itself"},
404+
extra={"extra": "from event"},
398405
)
399406
],
400407
),
401408
),
409+
ids=["inactive asset", "active asset"],
402410
)
403-
def test_add_with_db(self, add_arg, key, asset_alias_events, mock_supervisor_comms):
404-
mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="")
411+
def test_add_with_db(self, add_args, key, asset_alias_events, mock_supervisor_comms):
412+
mock_supervisor_comms.send.return_value = AssetResponse(
413+
name="name", uri="uri", group="", extra={"extra": "from asset itself"}
414+
)
405415

406-
outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""})
407-
outlet_event_accessor.add(add_arg, extra={})
416+
outlet_event_accessor = OutletEventAccessor(key=key)
417+
outlet_event_accessor.add(*add_args)
408418
assert outlet_event_accessor.asset_alias_events == asset_alias_events
409419

410420

0 commit comments

Comments
 (0)