Skip to content

Commit bdf6c91

Browse files
authored
adds _ensure_identity, modifies files, and tests (#58563)
1 parent 7a51b05 commit bdf6c91

File tree

2 files changed

+107
-7
lines changed

2 files changed

+107
-7
lines changed

providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import time
2222
from collections import namedtuple
2323
from collections.abc import Sequence
24-
from typing import TYPE_CHECKING, Any
24+
from typing import TYPE_CHECKING, Any, cast
2525

2626
from azure.mgmt.containerinstance.models import (
2727
Container,
@@ -33,8 +33,10 @@
3333
DnsConfiguration,
3434
EnvironmentVariable,
3535
IpAddress,
36+
ResourceIdentityType,
3637
ResourceRequests,
3738
ResourceRequirements,
39+
UserAssignedIdentities,
3840
Volume as _AzureVolume,
3941
VolumeMount,
4042
)
@@ -147,10 +149,13 @@ class AzureContainerInstancesOperator(BaseOperator):
147149
},
148150
priority="Regular",
149151
identity = {
150-
{
151-
"type": "UserAssigned",
152-
"resource_ids": ["/subscriptions/00000000-0000-0000-0000-00000000000/resourceGroups/my_rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my_identity"],
153-
},
152+
"type": "UserAssigned" | "SystemAssigned" | "SystemAssigned,UserAssigned",
153+
"resource_ids": [
154+
"/subscriptions/<sub>/resourceGroups/<rg>/providers/Microsoft.ManagedIdentity/userAssignedIdentities/<id>"
155+
]
156+
"user_assigned_identities": {
157+
"/subscriptions/.../userAssignedIdentities/<id>": {}
158+
}
154159
}
155160
command=["/bin/echo", "world"],
156161
task_id="start_container",
@@ -188,7 +193,7 @@ def __init__(
188193
dns_config: DnsConfiguration | None = None,
189194
diagnostics: ContainerGroupDiagnostics | None = None,
190195
priority: str | None = "Regular",
191-
identity: ContainerGroupIdentity | None = None,
196+
identity: ContainerGroupIdentity | dict | None = None,
192197
**kwargs,
193198
) -> None:
194199
super().__init__(**kwargs)
@@ -231,14 +236,74 @@ def __init__(
231236
self.dns_config = dns_config
232237
self.diagnostics = diagnostics
233238
self.priority = priority
234-
self.identity = identity
239+
self.identity = self._ensure_identity(identity)
235240
if self.priority not in ["Regular", "Spot"]:
236241
raise AirflowException(
237242
"Invalid value for the priority argument. "
238243
"Please set 'Regular' or 'Spot' as the priority. "
239244
f"Found `{self.priority}`."
240245
)
241246

247+
# helper to accept dict (user-friendly) or ContainerGroupIdentity (SDK object)
248+
@staticmethod
249+
def _ensure_identity(identity: ContainerGroupIdentity | dict | None) -> ContainerGroupIdentity | None:
250+
"""
251+
Normalize identity input into a ContainerGroupIdentity instance.
252+
253+
Accepts:
254+
- None -> returns None
255+
- ContainerGroupIdentity -> returned as-is
256+
- dict -> converted to ContainerGroupIdentity
257+
- any other object -> returned as-is (pass-through) to preserve backwards compatibility
258+
259+
Expected dict shapes:
260+
{"type": "UserAssigned", "resource_ids": ["/.../userAssignedIdentities/id1", ...]}
261+
or
262+
{"type": "SystemAssigned"}
263+
or
264+
{"type": "SystemAssigned,UserAssigned", "resource_ids": [...]}
265+
"""
266+
if identity is None:
267+
return None
268+
269+
if isinstance(identity, ContainerGroupIdentity):
270+
return identity
271+
272+
if isinstance(identity, dict):
273+
# require type
274+
id_type = identity.get("type")
275+
if not id_type:
276+
raise AirflowException(
277+
"identity dict must include 'type' key with value 'UserAssigned' or 'SystemAssigned'"
278+
)
279+
280+
# map common string type names to ResourceIdentityType enum values if available
281+
type_map = {
282+
"SystemAssigned": ResourceIdentityType.system_assigned,
283+
"UserAssigned": ResourceIdentityType.user_assigned,
284+
"SystemAssigned,UserAssigned": ResourceIdentityType.system_assigned_user_assigned,
285+
"SystemAssigned, UserAssigned": ResourceIdentityType.system_assigned_user_assigned,
286+
}
287+
cg_type = type_map.get(id_type, id_type)
288+
289+
# build user_assigned_identities mapping if resource_ids provided
290+
resource_ids = identity.get("resource_ids")
291+
if resource_ids:
292+
if not isinstance(resource_ids, (list, tuple)):
293+
raise AirflowException("identity['resource_ids'] must be a list of resource id strings")
294+
user_assigned_identities: dict[str, Any] = {rid: {} for rid in resource_ids}
295+
else:
296+
# accept a pre-built mapping if given
297+
user_assigned_identities = identity.get("user_assigned_identities") or {}
298+
299+
return ContainerGroupIdentity(
300+
type=cg_type,
301+
user_assigned_identities=cast(
302+
"dict[str, UserAssignedIdentities] | None", user_assigned_identities
303+
),
304+
)
305+
return identity
306+
242307
def execute(self, context: Context) -> int:
243308
# Check name again in case it was templated.
244309
self._check_name(self.name)

providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_container_instances.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,41 @@ def test_execute_with_identity(self, aci_mock):
611611

612612
assert called_cg.identity == identity
613613

614+
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
615+
def test_execute_with_identity_dict(self, aci_mock):
616+
# New test: pass a dict and verify operator converts it to ContainerGroupIdentity
617+
resource_id = "/subscriptions/00000000-0000-0000-0000-00000000000/resourceGroups/my_rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my_identity"
618+
identity_dict = {
619+
"type": "UserAssigned",
620+
"resource_ids": [resource_id],
621+
}
622+
623+
aci_mock.return_value.get_state.return_value = make_mock_container(
624+
state="Terminated", exit_code=0, detail_status="test"
625+
)
626+
627+
aci_mock.return_value.exists.return_value = False
628+
629+
aci = AzureContainerInstancesOperator(
630+
ci_conn_id=None,
631+
registry_conn_id=None,
632+
resource_group="resource-group",
633+
name="container-name",
634+
image="container-image",
635+
region="region",
636+
task_id="task",
637+
identity=identity_dict,
638+
)
639+
aci.execute(None)
640+
assert aci_mock.return_value.create_or_update.call_count == 1
641+
(_, _, called_cg), _ = aci_mock.return_value.create_or_update.call_args
642+
643+
# verify the operator converted dict -> ContainerGroupIdentity with proper mapping
644+
assert hasattr(called_cg, "identity")
645+
assert called_cg.identity is not None
646+
# user_assigned_identities should contain the resource id as a key
647+
assert resource_id in (called_cg.identity.user_assigned_identities or {})
648+
614649

615650
class XcomMock:
616651
def __init__(self) -> None:

0 commit comments

Comments
 (0)