|
21 | 21 | import time |
22 | 22 | from collections import namedtuple |
23 | 23 | from collections.abc import Sequence |
24 | | -from typing import TYPE_CHECKING, Any |
| 24 | +from typing import TYPE_CHECKING, Any, cast |
25 | 25 |
|
26 | 26 | from azure.mgmt.containerinstance.models import ( |
27 | 27 | Container, |
|
33 | 33 | DnsConfiguration, |
34 | 34 | EnvironmentVariable, |
35 | 35 | IpAddress, |
| 36 | + ResourceIdentityType, |
36 | 37 | ResourceRequests, |
37 | 38 | ResourceRequirements, |
| 39 | + UserAssignedIdentities, |
38 | 40 | Volume as _AzureVolume, |
39 | 41 | VolumeMount, |
40 | 42 | ) |
@@ -147,10 +149,13 @@ class AzureContainerInstancesOperator(BaseOperator): |
147 | 149 | }, |
148 | 150 | priority="Regular", |
149 | 151 | identity = { |
150 | | - { |
151 | | - "type": "UserAssigned", |
152 | | - "resource_ids": ["/subscriptions/00000000-0000-0000-0000-00000000000/resourceGroups/my_rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my_identity"], |
153 | | - }, |
| 152 | + "type": "UserAssigned" | "SystemAssigned" | "SystemAssigned,UserAssigned", |
| 153 | + "resource_ids": [ |
| 154 | + "/subscriptions/<sub>/resourceGroups/<rg>/providers/Microsoft.ManagedIdentity/userAssignedIdentities/<id>" |
| 155 | + ] |
| 156 | + "user_assigned_identities": { |
| 157 | + "/subscriptions/.../userAssignedIdentities/<id>": {} |
| 158 | + } |
154 | 159 | } |
155 | 160 | command=["/bin/echo", "world"], |
156 | 161 | task_id="start_container", |
@@ -188,7 +193,7 @@ def __init__( |
188 | 193 | dns_config: DnsConfiguration | None = None, |
189 | 194 | diagnostics: ContainerGroupDiagnostics | None = None, |
190 | 195 | priority: str | None = "Regular", |
191 | | - identity: ContainerGroupIdentity | None = None, |
| 196 | + identity: ContainerGroupIdentity | dict | None = None, |
192 | 197 | **kwargs, |
193 | 198 | ) -> None: |
194 | 199 | super().__init__(**kwargs) |
@@ -231,14 +236,74 @@ def __init__( |
231 | 236 | self.dns_config = dns_config |
232 | 237 | self.diagnostics = diagnostics |
233 | 238 | self.priority = priority |
234 | | - self.identity = identity |
| 239 | + self.identity = self._ensure_identity(identity) |
235 | 240 | if self.priority not in ["Regular", "Spot"]: |
236 | 241 | raise AirflowException( |
237 | 242 | "Invalid value for the priority argument. " |
238 | 243 | "Please set 'Regular' or 'Spot' as the priority. " |
239 | 244 | f"Found `{self.priority}`." |
240 | 245 | ) |
241 | 246 |
|
| 247 | + # helper to accept dict (user-friendly) or ContainerGroupIdentity (SDK object) |
| 248 | + @staticmethod |
| 249 | + def _ensure_identity(identity: ContainerGroupIdentity | dict | None) -> ContainerGroupIdentity | None: |
| 250 | + """ |
| 251 | + Normalize identity input into a ContainerGroupIdentity instance. |
| 252 | +
|
| 253 | + Accepts: |
| 254 | + - None -> returns None |
| 255 | + - ContainerGroupIdentity -> returned as-is |
| 256 | + - dict -> converted to ContainerGroupIdentity |
| 257 | + - any other object -> returned as-is (pass-through) to preserve backwards compatibility |
| 258 | +
|
| 259 | + Expected dict shapes: |
| 260 | + {"type": "UserAssigned", "resource_ids": ["/.../userAssignedIdentities/id1", ...]} |
| 261 | + or |
| 262 | + {"type": "SystemAssigned"} |
| 263 | + or |
| 264 | + {"type": "SystemAssigned,UserAssigned", "resource_ids": [...]} |
| 265 | + """ |
| 266 | + if identity is None: |
| 267 | + return None |
| 268 | + |
| 269 | + if isinstance(identity, ContainerGroupIdentity): |
| 270 | + return identity |
| 271 | + |
| 272 | + if isinstance(identity, dict): |
| 273 | + # require type |
| 274 | + id_type = identity.get("type") |
| 275 | + if not id_type: |
| 276 | + raise AirflowException( |
| 277 | + "identity dict must include 'type' key with value 'UserAssigned' or 'SystemAssigned'" |
| 278 | + ) |
| 279 | + |
| 280 | + # map common string type names to ResourceIdentityType enum values if available |
| 281 | + type_map = { |
| 282 | + "SystemAssigned": ResourceIdentityType.system_assigned, |
| 283 | + "UserAssigned": ResourceIdentityType.user_assigned, |
| 284 | + "SystemAssigned,UserAssigned": ResourceIdentityType.system_assigned_user_assigned, |
| 285 | + "SystemAssigned, UserAssigned": ResourceIdentityType.system_assigned_user_assigned, |
| 286 | + } |
| 287 | + cg_type = type_map.get(id_type, id_type) |
| 288 | + |
| 289 | + # build user_assigned_identities mapping if resource_ids provided |
| 290 | + resource_ids = identity.get("resource_ids") |
| 291 | + if resource_ids: |
| 292 | + if not isinstance(resource_ids, (list, tuple)): |
| 293 | + raise AirflowException("identity['resource_ids'] must be a list of resource id strings") |
| 294 | + user_assigned_identities: dict[str, Any] = {rid: {} for rid in resource_ids} |
| 295 | + else: |
| 296 | + # accept a pre-built mapping if given |
| 297 | + user_assigned_identities = identity.get("user_assigned_identities") or {} |
| 298 | + |
| 299 | + return ContainerGroupIdentity( |
| 300 | + type=cg_type, |
| 301 | + user_assigned_identities=cast( |
| 302 | + "dict[str, UserAssignedIdentities] | None", user_assigned_identities |
| 303 | + ), |
| 304 | + ) |
| 305 | + return identity |
| 306 | + |
242 | 307 | def execute(self, context: Context) -> int: |
243 | 308 | # Check name again in case it was templated. |
244 | 309 | self._check_name(self.name) |
|
0 commit comments