diff --git a/sdk/src/Services/S3/Custom/Transfer/Internal/MultipartDownloadManager.cs b/sdk/src/Services/S3/Custom/Transfer/Internal/MultipartDownloadManager.cs index 010243c8c7bd..939cedffad55 100644 --- a/sdk/src/Services/S3/Custom/Transfer/Internal/MultipartDownloadManager.cs +++ b/sdk/src/Services/S3/Custom/Transfer/Internal/MultipartDownloadManager.cs @@ -379,19 +379,38 @@ public async Task StartDownloadsAsync(DownloadDiscoveryResult discoveryResult, E _logger.DebugFormat("MultipartDownloadManager: [Part {0}] Waiting for buffer space", partNum); // Acquire capacity sequentially - guarantees Part 2 before Part 3, etc. - await _dataHandler.WaitForCapacityAsync(cancellationToken).ConfigureAwait(false); + await _dataHandler.WaitForCapacityAsync(internalCts.Token).ConfigureAwait(false); _logger.DebugFormat("MultipartDownloadManager: [Part {0}] Buffer space acquired", partNum); - var task = CreateDownloadTaskAsync(partNum, discoveryResult.ObjectSize, wrappedCallback, internalCts.Token); - downloadTasks.Add(task); + _logger.DebugFormat("MultipartDownloadManager: [Part {0}] Waiting for HTTP concurrency slot (Available: {1}/{2})", + partNum, _httpConcurrencySlots.CurrentCount, _config.ConcurrentServiceRequests); + + // Acquire HTTP slot in the loop before creating task + // Loop will block here if all slots are in use + await _httpConcurrencySlots.WaitAsync(internalCts.Token).ConfigureAwait(false); + + _logger.DebugFormat("MultipartDownloadManager: [Part {0}] HTTP concurrency slot acquired", partNum); + + try + { + var task = CreateDownloadTaskAsync(partNum, discoveryResult.ObjectSize, wrappedCallback, internalCts.Token); + downloadTasks.Add(task); + } + catch (Exception ex) + { + // If task creation fails, release the HTTP slot we just acquired + _httpConcurrencySlots.Release(); + _logger.DebugFormat("MultipartDownloadManager: [Part {0}] HTTP concurrency slot released due to task creation failure: {1}", partNum, ex); + throw; + } } var expectedTaskCount = downloadTasks.Count; _logger.DebugFormat("MultipartDownloadManager: Background task waiting for {0} download tasks", expectedTaskCount); // Wait for all downloads to complete (fails fast on first exception) - await TaskHelpers.WhenAllOrFirstExceptionAsync(downloadTasks, cancellationToken).ConfigureAwait(false); + await TaskHelpers.WhenAllOrFirstExceptionAsync(downloadTasks, internalCts.Token).ConfigureAwait(false); _logger.DebugFormat("MultipartDownloadManager: All download tasks completed successfully"); @@ -418,7 +437,27 @@ public async Task StartDownloadsAsync(DownloadDiscoveryResult discoveryResult, E catch (Exception ex) { _downloadException = ex; - _logger.Error(ex, "MultipartDownloadManager: Background download task failed"); + + + + // Cancel all remaining downloads immediately to prevent cascading timeout errors + // This ensures that when one part fails, other tasks stop gracefully instead of + // continuing until they hit their own timeout/cancellation errors + // Check if cancellation was already requested to avoid ObjectDisposedException + if (!internalCts.IsCancellationRequested) + { + try + { + internalCts.Cancel(); + _logger.DebugFormat("MultipartDownloadManager: Cancelled all in-flight downloads due to error"); + } + catch (ObjectDisposedException) + { + // CancellationTokenSource was already disposed, ignore + _logger.DebugFormat("MultipartDownloadManager: CancellationTokenSource already disposed during cancellation"); + } + } + _dataHandler.OnDownloadComplete(ex); throw; } @@ -440,6 +479,22 @@ public async Task StartDownloadsAsync(DownloadDiscoveryResult discoveryResult, E _downloadException = ex; _logger.Error(ex, "MultipartDownloadManager: Download failed"); + // Cancel all remaining downloads immediately to prevent cascading timeout errors + // Check if cancellation was already requested to avoid ObjectDisposedException + if (!internalCts.IsCancellationRequested) + { + try + { + internalCts.Cancel(); + _logger.DebugFormat("MultipartDownloadManager: Cancelled all in-flight downloads due to error"); + } + catch (ObjectDisposedException) + { + // CancellationTokenSource was already disposed, ignore + _logger.DebugFormat("MultipartDownloadManager: CancellationTokenSource already disposed during cancellation"); + } + } + _dataHandler.OnDownloadComplete(ex); // Dispose the CancellationTokenSource if background task was never started @@ -459,15 +514,8 @@ private async Task CreateDownloadTaskAsync(int partNumber, long objectSize, Even try { - _logger.DebugFormat("MultipartDownloadManager: [Part {0}] Waiting for HTTP concurrency slot (Available: {1}/{2})", - partNumber, _httpConcurrencySlots.CurrentCount, _config.ConcurrentServiceRequests); - - // Limit HTTP concurrency for both network download AND disk write - // The semaphore is held until AFTER ProcessPartAsync completes to ensure - // ConcurrentServiceRequests controls the entire I/O operation - await _httpConcurrencySlots.WaitAsync(cancellationToken).ConfigureAwait(false); - - _logger.DebugFormat("MultipartDownloadManager: [Part {0}] HTTP concurrency slot acquired", partNumber); + // HTTP slot was already acquired in the for loop before this task was created + // We just need to use it and release it when done try { @@ -544,7 +592,7 @@ private async Task CreateDownloadTaskAsync(int partNumber, long objectSize, Even finally { // Release semaphore after BOTH network download AND disk write complete - // This ensures ConcurrentServiceRequests limits the entire I/O operation + // Slot was acquired in the for loop before this task was created _httpConcurrencySlots.Release(); _logger.DebugFormat("MultipartDownloadManager: [Part {0}] HTTP concurrency slot released (Available: {1}/{2})", partNumber, _httpConcurrencySlots.CurrentCount, _config.ConcurrentServiceRequests); diff --git a/sdk/test/Services/S3/UnitTests/Custom/MultipartDownloadManagerTests.cs b/sdk/test/Services/S3/UnitTests/Custom/MultipartDownloadManagerTests.cs index fbbd1a410975..ecf90cbc4d67 100644 --- a/sdk/test/Services/S3/UnitTests/Custom/MultipartDownloadManagerTests.cs +++ b/sdk/test/Services/S3/UnitTests/Custom/MultipartDownloadManagerTests.cs @@ -3337,5 +3337,261 @@ public async Task ProgressCallback_MultiplePartsComplete_AggregatesCorrectly() } #endregion + + #region Cancellation Enhancement Tests + + [TestMethod] + public async Task StartDownloadsAsync_BackgroundPartFails_CancelsInternalToken() + { + // Arrange - Deterministic test using TaskCompletionSource to control execution order + // This ensures Part 3 waits at synchronization point, Part 2 fails, then Part 3 checks cancellation + var totalParts = 3; + var partSize = 8 * 1024 * 1024; + var totalObjectSize = totalParts * partSize; + + var part2Failed = false; + var part3SawCancellation = false; + + // Synchronization primitives to control execution order + var part3ReachedSyncPoint = new TaskCompletionSource(); + var part2CanFail = new TaskCompletionSource(); + var part3CanCheckCancellation = new TaskCompletionSource(); + + var mockDataHandler = new Mock(); + + // Capacity acquisition succeeds for all parts + mockDataHandler + .Setup(x => x.WaitForCapacityAsync(It.IsAny())) + .Returns(Task.CompletedTask); + + // PrepareAsync succeeds + mockDataHandler + .Setup(x => x.PrepareAsync(It.IsAny(), It.IsAny())) + .Returns(Task.CompletedTask); + + // ProcessPartAsync: Controlled execution order using TaskCompletionSource + mockDataHandler + .Setup(x => x.ProcessPartAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(async (partNum, response, ct) => + { + if (partNum == 1) + { + return; // Part 1 succeeds immediately + } + else if (partNum == 2) + { + // Part 2 waits for Part 3 to reach sync point before failing + await part2CanFail.Task; + part2Failed = true; + throw new InvalidOperationException("Simulated Part 2 failure"); + } + else // Part 3 + { + // Part 3 reaches sync point and signals to Part 2 + part3ReachedSyncPoint.SetResult(true); + + // Wait for Part 2 to fail and cancellation to propagate + await part3CanCheckCancellation.Task; + + // Now check if cancellation was received from internalCts + if (ct.IsCancellationRequested) + { + part3SawCancellation = true; + throw new OperationCanceledException("Part 3 cancelled due to Part 2 failure"); + } + } + }); + + mockDataHandler.Setup(x => x.ReleaseCapacity()); + mockDataHandler.Setup(x => x.OnDownloadComplete(It.IsAny())); + + var mockClient = MultipartDownloadTestHelpers.CreateMockS3ClientForMultipart( + totalParts, partSize, totalObjectSize, "test-etag", usePartStrategy: true); + + var request = MultipartDownloadTestHelpers.CreateOpenStreamRequest( + downloadType: MultipartDownloadType.PART); + var config = MultipartDownloadTestHelpers.CreateBufferedDownloadConfiguration(concurrentRequests: 2); + var coordinator = new MultipartDownloadManager(mockClient.Object, request, config, mockDataHandler.Object); + + var discoveryResult = await coordinator.DiscoverDownloadStrategyAsync(CancellationToken.None); + + // Act - Start downloads + await coordinator.StartDownloadsAsync(discoveryResult, null, CancellationToken.None); + + // Wait for Part 3 to reach synchronization point + await part3ReachedSyncPoint.Task; + + // Allow Part 2 to fail + part2CanFail.SetResult(true); + + // Give cancellation time to propagate + await Task.Delay(100); + + // Allow Part 3 to check cancellation + part3CanCheckCancellation.SetResult(true); + + // Wait for background task to complete + try + { + await coordinator.DownloadCompletionTask; + } + catch (InvalidOperationException) + { + // Expected failure from Part 2 + } + + // Assert - Deterministic verification that cancellation propagated + Assert.IsTrue(part2Failed, "Part 2 should have failed"); + Assert.IsTrue(part3SawCancellation, + "Part 3 should have received cancellation via internalCts.Token (deterministic with TaskCompletionSource)"); + + Assert.IsNotNull(coordinator.DownloadException, + "Download exception should be captured when background part fails"); + Assert.IsInstanceOfType(coordinator.DownloadException, typeof(InvalidOperationException), + "Download exception should be the Part 2 failure"); + } + + [TestMethod] + public async Task StartDownloadsAsync_MultiplePartsFail_HandlesGracefully() + { + // Arrange - Test simultaneous failures from multiple parts + var totalParts = 4; + var partSize = 8 * 1024 * 1024; + var totalObjectSize = totalParts * partSize; + + var failedParts = new System.Collections.Concurrent.ConcurrentBag(); + var mockDataHandler = new Mock(); + + mockDataHandler + .Setup(x => x.WaitForCapacityAsync(It.IsAny())) + .Returns(Task.CompletedTask); + + mockDataHandler + .Setup(x => x.PrepareAsync(It.IsAny(), It.IsAny())) + .Returns(Task.CompletedTask); + + // Part 1 succeeds, Parts 2, 3, 4 all fail + mockDataHandler + .Setup(x => x.ProcessPartAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((partNum, response, ct) => + { + if (partNum == 1) + { + return Task.CompletedTask; + } + + failedParts.Add(partNum); + throw new InvalidOperationException($"Simulated Part {partNum} failure"); + }); + + mockDataHandler.Setup(x => x.ReleaseCapacity()); + mockDataHandler.Setup(x => x.OnDownloadComplete(It.IsAny())); + + var mockClient = MultipartDownloadTestHelpers.CreateMockS3ClientForMultipart( + totalParts, partSize, totalObjectSize, "test-etag", usePartStrategy: true); + + var request = MultipartDownloadTestHelpers.CreateOpenStreamRequest( + downloadType: MultipartDownloadType.PART); + var config = MultipartDownloadTestHelpers.CreateBufferedDownloadConfiguration(concurrentRequests: 3); + var coordinator = new MultipartDownloadManager(mockClient.Object, request, config, mockDataHandler.Object); + + var discoveryResult = await coordinator.DiscoverDownloadStrategyAsync(CancellationToken.None); + + // Act + await coordinator.StartDownloadsAsync(discoveryResult, null, CancellationToken.None); + + try + { + await coordinator.DownloadCompletionTask; + } + catch (InvalidOperationException) + { + // Expected - at least one part failed + } + + // Assert - Should handle multiple failures gracefully + Assert.IsTrue(failedParts.Count > 0, "At least one part should have failed"); + Assert.IsNotNull(coordinator.DownloadException, "Download exception should be captured"); + } + + [TestMethod] + public async Task StartDownloadsAsync_CancellationRacesWithDispose_HandlesGracefully() + { + // Arrange - Test race condition between Cancel() and Dispose() + var totalParts = 3; + var partSize = 8 * 1024 * 1024; + var totalObjectSize = totalParts * partSize; + + var objectDisposedExceptionCaught = false; + var mockDataHandler = new Mock(); + + mockDataHandler + .Setup(x => x.WaitForCapacityAsync(It.IsAny())) + .Returns(Task.CompletedTask); + + mockDataHandler + .Setup(x => x.PrepareAsync(It.IsAny(), It.IsAny())) + .Returns(Task.CompletedTask); + + // Part 1 succeeds, Part 2 fails triggering cancellation + mockDataHandler + .Setup(x => x.ProcessPartAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((partNum, response, ct) => + { + if (partNum == 1) + { + return Task.CompletedTask; + } + + // Part 2 failure will trigger Cancel() in catch block + // The enhancement should check IsCancellationRequested to avoid ObjectDisposedException + throw new InvalidOperationException("Simulated Part 2 failure"); + }); + + mockDataHandler.Setup(x => x.ReleaseCapacity()); + mockDataHandler + .Setup(x => x.OnDownloadComplete(It.IsAny())) + .Callback(ex => + { + // Check if ObjectDisposedException was handled + if (ex is ObjectDisposedException) + { + objectDisposedExceptionCaught = true; + } + }); + + var mockClient = MultipartDownloadTestHelpers.CreateMockS3ClientForMultipart( + totalParts, partSize, totalObjectSize, "test-etag", usePartStrategy: true); + + var request = MultipartDownloadTestHelpers.CreateOpenStreamRequest( + downloadType: MultipartDownloadType.PART); + var config = MultipartDownloadTestHelpers.CreateBufferedDownloadConfiguration(concurrentRequests: 2); + var coordinator = new MultipartDownloadManager(mockClient.Object, request, config, mockDataHandler.Object); + + var discoveryResult = await coordinator.DiscoverDownloadStrategyAsync(CancellationToken.None); + + // Act + await coordinator.StartDownloadsAsync(discoveryResult, null, CancellationToken.None); + + try + { + await coordinator.DownloadCompletionTask; + } + catch (InvalidOperationException) + { + // Expected failure + } + + // Assert - The enhancement should prevent ObjectDisposedException from being thrown + // by checking IsCancellationRequested before calling Cancel() + Assert.IsFalse(objectDisposedExceptionCaught, + "ObjectDisposedException should not propagate due to IsCancellationRequested check"); + Assert.IsNotNull(coordinator.DownloadException, + "Download exception should be the original failure, not ObjectDisposedException"); + Assert.IsInstanceOfType(coordinator.DownloadException, typeof(InvalidOperationException), + "Download exception should be the original InvalidOperationException from Part 2 failure"); + } + + #endregion } }