Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -26,6 +27,8 @@ internal class AcquireTokenForManagedIdentityParameters : IAcquireTokenParameter

internal Func<AttestationTokenInput, CancellationToken, Task<AttestationTokenResponse>> AttestationTokenProvider { get; set; }

internal X509Certificate2 MtlsCertificate { get; set; }

public void LogParameters(ILoggerAdapter logger)
{
if (logger.IsLoggingEnabled(LogLevel.Info))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,17 @@ public string Claims
}
}

public IAuthenticationOperation AuthenticationScheme => _commonParameters.AuthenticationOperation;
private IAuthenticationOperation _requestOverrideScheme;

/// <summary>
/// Effective authentication operation (scheme) for this request.
/// Defaults to the app's configured operation unless a request-scoped override is applied.
/// </summary>
public IAuthenticationOperation AuthenticationScheme
{
get => _requestOverrideScheme ?? _commonParameters.AuthenticationOperation; // <-- correct fallback
internal set => _requestOverrideScheme = value; // internal set satisfies the “make it settable?” review
}

public IEnumerable<string> PersistedCacheParameters => _commonParameters.AdditionalCacheParameters;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
// Licensed under the MIT License.

using System.Collections.Generic;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client.ApiConfig.Parameters;
using Microsoft.Identity.Client.AuthScheme.PoP;
using Microsoft.Identity.Client.Cache.Items;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.ManagedIdentity;
Expand Down Expand Up @@ -39,6 +41,18 @@ protected override async Task<AuthenticationResult> ExecuteAsync(CancellationTok
{
AuthenticationResult authResult = null;
ILoggerAdapter logger = AuthenticationRequestParameters.RequestContext.Logger;

// Prime the scheme before any cache lookup if we already have a binding cert from a prior mint
if (AuthenticationRequestParameters.IsMtlsPopRequested &&
!(AuthenticationRequestParameters.AuthenticationScheme is MtlsPopAuthenticationOperation))
{
var priorCert = _managedIdentityClient.RuntimeMtlsBindingCertificate;
if (priorCert != null)
{
AuthenticationRequestParameters.AuthenticationScheme = new MtlsPopAuthenticationOperation(priorCert);
AuthenticationRequestParameters.RequestContext.Logger.Info("[ManagedIdentity] Using prior mTLS binding certificate for cache lookup.");
}
}

// 1. FIRST, handle ForceRefresh
if (_managedIdentityParameters.ForceRefresh)
Expand Down Expand Up @@ -209,6 +223,24 @@ await _managedIdentityClient
.SendTokenRequestForManagedIdentityAsync(AuthenticationRequestParameters.RequestContext, _managedIdentityParameters, cancellationToken)
.ConfigureAwait(false);

if (AuthenticationRequestParameters.IsMtlsPopRequested
&& _managedIdentityParameters.MtlsCertificate != null
&& !(AuthenticationRequestParameters.AuthenticationScheme is MtlsPopAuthenticationOperation))
{
// Remember the cert for future requests (same app instance) BEFORE we clear it.
_managedIdentityClient.SetRuntimeMtlsBindingCertificate(_managedIdentityParameters.MtlsCertificate);

// Apply mTLS scheme BEFORE caching so the token is stored under the mtls_pop key.
AuthenticationRequestParameters.AuthenticationScheme =
new MtlsPopAuthenticationOperation(_managedIdentityParameters.MtlsCertificate);

// Do not hold cert past this request boundary
_managedIdentityParameters.MtlsCertificate = null;

AuthenticationRequestParameters.RequestContext.Logger.Info(
"[ManagedIdentity] Applied mtls_pop scheme prior to caching.");
}

var msalTokenResponse = MsalTokenResponse.CreateFromManagedIdentityResponse(managedIdentityResponse);
msalTokenResponse.Scope = AuthenticationRequestParameters.Scope.AsSingleString();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(

ManagedIdentityRequest request = await CreateRequestAsync(resource).ConfigureAwait(false);

if (parameters.IsMtlsPopRequested && request?.MtlsCertificate != null)
{
parameters.MtlsCertificate = request.MtlsCertificate;
}

// Automatically add claims / capabilities if this MI source supports them
if (_sourceType.SupportsClaimsAndCapabilities())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.IO;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.ManagedIdentity.V2;
using System.Security.Cryptography.X509Certificates;

namespace Microsoft.Identity.Client.ManagedIdentity
{
Expand All @@ -22,6 +23,9 @@ internal class ManagedIdentityClient
private const string LinuxHimdsFilePath = "/opt/azcmagent/bin/himds";
internal static ManagedIdentitySource s_sourceName = ManagedIdentitySource.None;

// Holds the most recently minted mTLS binding certificate for this application instance.
internal X509Certificate2 RuntimeMtlsBindingCertificate { get; private set; }

internal static void ResetSourceForTest()
{
s_sourceName = ManagedIdentitySource.None;
Expand Down Expand Up @@ -157,5 +161,13 @@ private static bool ValidateAzureArcEnvironment(string identityEndpoint, string
logger?.Verbose(() => "[Managed Identity] Azure Arc managed identity is not available.");
return false;
}

internal void SetRuntimeMtlsBindingCertificate(X509Certificate2 cert)
{
var old = RuntimeMtlsBindingCertificate;
RuntimeMtlsBindingCertificate = cert;
//dispose prior stored cert on replacement
old?.Dispose();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -314,20 +314,19 @@ public async Task mTLSPopTokenHappyPath(
Assert.IsNotNull(result);
Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(result.TokenType, MTLSPoP);
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
Assert.IsNotNull(result.BindingCertificate);
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);

// TODO: broken until Gladwin's PR is merged in
/*result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
.WithMtlsProofOfPossession()
.WithAttestationProviderForTests(s_fakeAttestationProvider)
.ExecuteAsync().ConfigureAwait(false);

Assert.IsNotNull(result);
Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(result.TokenType, MTLSPoP);
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource);*/
Assert.IsNotNull(result.BindingCertificate);
Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource);
}
}

Expand Down Expand Up @@ -358,19 +357,18 @@ public async Task mTLSPopTokenIsPerIdentity(
Assert.IsNotNull(result);
Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(result.TokenType, MTLSPoP);
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
Assert.IsNotNull(result.BindingCertificate);
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);

// TODO: broken until Gladwin's PR is merged in
/*result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
.WithMtlsProofOfPossession()
.ExecuteAsync().ConfigureAwait(false);

Assert.IsNotNull(result);
Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(result.TokenType, MTLSPoP);
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource);*/
Assert.IsNotNull(result.BindingCertificate);
Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource);
#endregion Identity 1

#region Identity 2
Expand All @@ -393,20 +391,19 @@ public async Task mTLSPopTokenIsPerIdentity(
Assert.IsNotNull(result2);
Assert.IsNotNull(result2.AccessToken);
Assert.AreEqual(result2.TokenType, MTLSPoP);
// Assert.IsNotNull(result2.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
Assert.IsNotNull(result2.BindingCertificate);
Assert.AreEqual(TokenSource.IdentityProvider, result2.AuthenticationResultMetadata.TokenSource);

// TODO: broken until Gladwin's PR is merged in
/*result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
.WithMtlsProofOfPossession()
.WithAttestationProviderForTests(s_fakeAttestationProvider)
.ExecuteAsync().ConfigureAwait(false);

Assert.IsNotNull(result2);
Assert.IsNotNull(result2.AccessToken);
Assert.AreEqual(result2.TokenType, MTLSPoP);
// Assert.IsNotNull(result2.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
Assert.AreEqual(TokenSource.Cache, result2.AuthenticationResultMetadata.TokenSource);*/
Assert.IsNotNull(result2.BindingCertificate);
Assert.AreEqual(TokenSource.Cache, result2.AuthenticationResultMetadata.TokenSource);
#endregion Identity 2

// TODO: Assert.AreEqual(CertificateCache.Count, 2);
Expand Down Expand Up @@ -439,26 +436,22 @@ public async Task mTLSPopTokenIsReAcquiredWhenCertificatIsExpired(
Assert.IsNotNull(result);
Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(result.TokenType, MTLSPoP);
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
Assert.IsNotNull(result.BindingCertificate);
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);

// TODO: Add functionality to check cert expiration in the cache
/**
AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, mTLSPop: true);

result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
.WithMtlsProofOfPossession()
.WithAttestationProviderForTests(s_fakeAttestationProvider)
.ExecuteAsync().ConfigureAwait(false);
//To-Do : Add cert expiry check functionality
//AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, mTLSPop: true);

Assert.IsNotNull(result);
Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(result.TokenType, MTLSPoP);
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
//result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
// .WithMtlsProofOfPossession()
// .WithAttestationProviderForTests(s_fakeAttestationProvider)
// .ExecuteAsync().ConfigureAwait(false);

Assert.AreEqual(CertificateCache.Count, 1); // expired cert was removed from the cache
*/
//Assert.IsNotNull(result);
//Assert.IsNotNull(result.AccessToken);
//Assert.AreEqual(result.TokenType, MTLSPoP);
//Assert.IsNotNull(result.BindingCertificate);
//Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
}
}
#endregion mTLS Pop Token Tests
Expand Down