Skip to content

Commit 6c60f1a

Browse files
fix: Replace DefaultAzureCredential with ManagedIdentityCredential for production-safe authentication (#1873)
1 parent 7065008 commit 6c60f1a

File tree

3 files changed

+51
-2
lines changed

3 files changed

+51
-2
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import os
2+
from azure.identity import ManagedIdentityCredential, DefaultAzureCredential
3+
from azure.identity.aio import (
4+
ManagedIdentityCredential as AioManagedIdentityCredential,
5+
DefaultAzureCredential as AioDefaultAzureCredential,
6+
)
7+
8+
9+
async def get_azure_credential_async(client_id=None):
10+
"""
11+
Returns an Azure credential asynchronously based on the application environment.
12+
13+
If the environment is 'dev', it uses AioDefaultAzureCredential.
14+
Otherwise, it uses AioManagedIdentityCredential.
15+
16+
Args:
17+
client_id (str, optional): The client ID for the Managed Identity Credential.
18+
19+
Returns:
20+
Credential object: Either AioDefaultAzureCredential or AioManagedIdentityCredential.
21+
"""
22+
if os.getenv("APP_ENV", "prod").lower() == "dev":
23+
return (
24+
AioDefaultAzureCredential()
25+
) # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development
26+
else:
27+
return AioManagedIdentityCredential(client_id=client_id)
28+
29+
30+
def get_azure_credential(client_id=None):
31+
"""
32+
Returns an Azure credential based on the application environment.
33+
34+
If the environment is 'dev', it uses DefaultAzureCredential.
35+
Otherwise, it uses ManagedIdentityCredential.
36+
37+
Args:
38+
client_id (str, optional): The client ID for the Managed Identity Credential.
39+
40+
Returns:
41+
Credential object: Either DefaultAzureCredential or ManagedIdentityCredential.
42+
"""
43+
if os.getenv("APP_ENV", "prod").lower() == "dev":
44+
return (
45+
DefaultAzureCredential()
46+
) # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development
47+
else:
48+
return ManagedIdentityCredential(client_id=client_id)

scripts/data_scripts/create_postgres_tables.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from azure.identity import DefaultAzureCredential
1+
from azure_credential_utils import get_azure_credential
22
import psycopg2
33
from psycopg2 import sql
44

@@ -61,7 +61,7 @@ def grant_permissions(cursor, dbname, schema_name, principal_name):
6161

6262

6363
# Acquire the access token
64-
cred = DefaultAzureCredential()
64+
cred = get_azure_credential()
6565
access_token = cred.get_token("https://ossrdbms-aad.database.windows.net/.default")
6666

6767
# Combine the token with the connection string to establish the connection.

scripts/run_create_table_script.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ az postgres flexible-server firewall-rule create --resource-group $resourceGroup
2323

2424
# Download the create table python file
2525
curl --output "create_postgres_tables.py" ${baseUrl}"scripts/data_scripts/create_postgres_tables.py"
26+
curl --output "azure_credential_utils.py" ${baseUrl}"scripts/data_scripts/azure_credential_utils.py"
2627

2728
# Download the requirement file
2829
curl --output "$requirementFile" "$requirementFileUrl"

0 commit comments

Comments
 (0)