Skip to content

Commit b59cc72

Browse files
committed
Add a mechanism to prevent concurrent token refreshes
When multiple different firebase services require a token refresh at the same time, multiple token refresh can be triggered at the same time inside the SecureTokenService. This PR makes sure only one token refresh is happening at the same time
1 parent 4f36a1c commit b59cc72

File tree

3 files changed

+329
-1
lines changed

3 files changed

+329
-1
lines changed

FirebaseAuth/Sources/Swift/SystemService/SecureTokenService.swift

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ private let kFiveMinutes = 5 * 60.0
1919

2020
@available(iOS 13, tvOS 13, macOS 10.15, macCatalyst 13, watchOS 7, *)
2121
actor SecureTokenServiceInternal {
22+
/// Coalescer to deduplicate concurrent token refresh requests.
23+
/// When multiple requests arrive at the same time, only one network call is made.
24+
private let refreshCoalescer = TokenRefreshCoalescer()
25+
2226
/// Fetch a fresh ephemeral access token for the ID associated with this instance. The token
2327
/// received in the callback should be considered short lived and not cached.
2428
///
@@ -32,7 +36,20 @@ actor SecureTokenServiceInternal {
3236
return (service.accessToken, false)
3337
} else {
3438
AuthLog.logDebug(code: "I-AUT000017", message: "Fetching new token from backend.")
35-
return try await requestAccessToken(retryIfExpired: true, service: service, backend: backend)
39+
40+
// Use coalescer to deduplicate concurrent refresh requests.
41+
// If multiple requests arrive while one is in progress, they all wait
42+
// for the same network response instead of making redundant calls.
43+
let currentToken = service.accessToken
44+
return try await refreshCoalescer.coalescedRefresh(
45+
currentToken: currentToken
46+
) {
47+
try await self.requestAccessToken(
48+
retryIfExpired: true,
49+
service: service,
50+
backend: backend
51+
)
52+
}
3653
}
3754
}
3855

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
/// Coalesces multiple concurrent token refresh requests into a single network call.
18+
///
19+
/// When multiple requests for a token refresh arrive concurrently (e.g., from Storage, Firestore,
20+
/// and auto-refresh), instead of making separate network calls for each one, this class ensures
21+
/// that only ONE network request is made. All concurrent callers wait for and receive the same
22+
/// refreshed token.
23+
///
24+
/// This prevents redundant STS (Secure Token Service) calls and reduces load on both the client
25+
/// and server.
26+
///
27+
/// Example:
28+
/// ```
29+
/// // Multiple concurrent requests arrive at the same time
30+
/// Task { try await tokenRefreshCoalescer.coalescedRefresh(backend: backend, ...) } // 1
31+
/// Task { try await tokenRefreshCoalescer.coalescedRefresh(backend: backend, ...) } // 2
32+
/// Task { try await tokenRefreshCoalescer.coalescedRefresh(backend: backend, ...) } // 3
33+
///
34+
/// // Only ONE network call is made. All three tasks receive the same refreshed token.
35+
/// ```
36+
@available(iOS 13, tvOS 13, macOS 10.15, macCatalyst 13, watchOS 7, *)
37+
actor TokenRefreshCoalescer {
38+
/// The in-flight token refresh task, if any.
39+
/// When this is set, all concurrent calls wait for this task instead of starting their own.
40+
private var pendingRefreshTask: Task<(String?, Bool), Error>?
41+
42+
/// The token string of the pending refresh.
43+
/// Used to ensure we only coalesce requests for the same token.
44+
private var pendingRefreshToken: String?
45+
46+
/// Performs a coalesced token refresh.
47+
///
48+
/// If a refresh is already in progress, this method waits for that refresh to complete
49+
/// and returns its result. If no refresh is in progress, it starts a new one and stores
50+
/// the task so other concurrent callers can wait for it.
51+
///
52+
/// - Parameters:
53+
/// - currentToken: The current token string. Used to detect token changes.
54+
/// If the current token differs from the pending refresh token,
55+
/// a new refresh is started (old one is ignored).
56+
/// - refreshFunction: A closure that performs the actual network request and refresh.
57+
/// Should be called only if a new refresh is needed.
58+
///
59+
/// - Returns: A tuple containing (refreshedToken, wasUpdated) matching the format
60+
/// of SecureTokenService.
61+
///
62+
/// - Throws: Any error from the refresh operation.
63+
func coalescedRefresh(currentToken: String,
64+
refreshFunction: @escaping () async throws -> (String?, Bool)) async throws
65+
-> (
66+
String?,
67+
Bool
68+
) {
69+
// Check if a refresh is already in progress for this token
70+
if let pendingTask = pendingRefreshTask,
71+
pendingRefreshToken == currentToken {
72+
// Token hasn't changed and a refresh is in progress
73+
// Wait for the pending refresh to complete
74+
return try await pendingTask.value
75+
}
76+
77+
// Either no refresh is in progress, or the token has changed.
78+
// Start a new refresh task.
79+
let task = Task {
80+
try await refreshFunction()
81+
}
82+
83+
// Store the task so other concurrent callers can wait for it
84+
pendingRefreshTask = task
85+
pendingRefreshToken = currentToken
86+
87+
defer {
88+
// Clean up the pending task after it completes
89+
pendingRefreshTask = nil
90+
pendingRefreshToken = nil
91+
}
92+
93+
do {
94+
return try await task.value
95+
} catch {
96+
// On error, clear the pending task so the next call will retry
97+
pendingRefreshTask = nil
98+
pendingRefreshToken = nil
99+
throw error
100+
}
101+
}
102+
}
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
@testable import FirebaseAuth
16+
import XCTest
17+
18+
actor Counter {
19+
private var valueInternal: Int = 0
20+
func increment() { valueInternal += 1 }
21+
func value() -> Int { valueInternal }
22+
}
23+
24+
@available(iOS 13, tvOS 13, macOS 10.15, macCatalyst 13, watchOS 7, *)
25+
class TokenRefreshCoalescerTests: XCTestCase {
26+
/// Tests that when multiple concurrent refresh requests arrive for the same token,
27+
/// only ONE network call is made.
28+
///
29+
/// This is the main issue fix: Previously, each concurrent caller would make its own
30+
/// network request, resulting in redundant STS calls.
31+
func testCoalescedRefreshMakesOnlyOneNetworkCall() async throws {
32+
let coalescer = TokenRefreshCoalescer()
33+
let counter = Counter()
34+
35+
// Simulate multiple concurrent refresh requests
36+
async let result1 = try coalescer.coalescedRefresh(currentToken: "token_v1") {
37+
await counter.increment()
38+
39+
// Simulate network delay
40+
try await Task.sleep(nanoseconds: 100_000_000) // 0.1 seconds
41+
42+
return ("new_token", true)
43+
}
44+
45+
async let result2 = try coalescer.coalescedRefresh(currentToken: "token_v1") {
46+
await counter.increment()
47+
48+
try await Task.sleep(nanoseconds: 100_000_000)
49+
return ("new_token", true)
50+
}
51+
52+
async let result3 = try coalescer.coalescedRefresh(currentToken: "token_v1") {
53+
await counter.increment()
54+
55+
try await Task.sleep(nanoseconds: 100_000_000)
56+
return ("new_token", true)
57+
}
58+
59+
// Wait for all three to complete
60+
let (token1, updated1) = try await result1
61+
let (token2, updated2) = try await result2
62+
let (token3, updated3) = try await result3
63+
64+
// All three should get the same token
65+
XCTAssertEqual(token1, "new_token")
66+
XCTAssertEqual(token2, "new_token")
67+
XCTAssertEqual(token3, "new_token")
68+
69+
XCTAssertTrue(updated1)
70+
XCTAssertTrue(updated2)
71+
XCTAssertTrue(updated3)
72+
73+
// CRITICAL: Only ONE network call should have been made
74+
// (Previously, without coalescing, this would be 3)
75+
let callCount = await counter.value()
76+
XCTAssertEqual(callCount, 1, "Expected only 1 network call, but got \(callCount)")
77+
}
78+
79+
/// Tests that when the token changes, a new refresh is started instead of
80+
/// coalescing with the old one.
81+
func testNewRefreshStartsWhenTokenChanges() async throws {
82+
let coalescer = TokenRefreshCoalescer()
83+
let counter = Counter()
84+
85+
// First refresh for token_v1
86+
async let result1 = try coalescer.coalescedRefresh(currentToken: "token_v1") {
87+
await counter.increment()
88+
89+
try await Task.sleep(nanoseconds: 50_000_000)
90+
return ("new_token_1", true)
91+
}
92+
93+
// Wait a bit, then start a refresh for a different token (token_v2)
94+
// This should NOT coalesce with the first one
95+
try await Task.sleep(nanoseconds: 10_000_000)
96+
97+
async let result2 = try coalescer.coalescedRefresh(currentToken: "token_v2") {
98+
await counter.increment()
99+
100+
try await Task.sleep(nanoseconds: 50_000_000)
101+
return ("new_token_2", true)
102+
}
103+
104+
let token1 = try await result1.0
105+
let token2 = try await result2.0
106+
107+
// Should get different tokens
108+
XCTAssertEqual(token1, "new_token_1")
109+
XCTAssertEqual(token2, "new_token_2")
110+
111+
// Should have made TWO network calls (one for each token)
112+
let callsAfterTwoTokens = await counter.value()
113+
XCTAssertEqual(callsAfterTwoTokens, 2)
114+
}
115+
116+
/// Tests that if a refresh fails, the next call will start a fresh attempt
117+
/// instead of waiting for the failed one.
118+
func testFailedRefreshAllowsRetry() async throws {
119+
let coalescer = TokenRefreshCoalescer()
120+
let counter = Counter()
121+
122+
// First call will fail (run it to completion)
123+
do {
124+
_ = try await coalescer.coalescedRefresh(currentToken: "token_v1") {
125+
await counter.increment()
126+
throw NSError(domain: "TestError", code: -1, userInfo: nil)
127+
}
128+
XCTFail("Expected error")
129+
} catch {
130+
// Expected failure
131+
}
132+
133+
// Second call after the failure should start a fresh attempt and succeed
134+
let (token2, updated2) = try await coalescer.coalescedRefresh(currentToken: "token_v1") {
135+
await counter.increment()
136+
return ("recovered_token", true)
137+
}
138+
139+
XCTAssertEqual(token2, "recovered_token")
140+
XCTAssertTrue(updated2)
141+
142+
// Should have made TWO network calls (first failed, second succeeded)
143+
let secondResult = await counter.value()
144+
XCTAssertEqual(secondResult, 2)
145+
}
146+
147+
/// Stress test: Many concurrent calls for the same token
148+
func testManyCurrentCallsWithSameToken() async throws {
149+
let coalescer = TokenRefreshCoalescer()
150+
let counter = Counter()
151+
152+
let numCalls = 50
153+
var tasks: [Task<(String?, Bool), Error>] = []
154+
155+
// Launch 50 concurrent refresh tasks
156+
for _ in 0 ..< numCalls {
157+
let task = Task {
158+
try await coalescer.coalescedRefresh(currentToken: "token_stress") {
159+
await counter.increment()
160+
161+
try await Task.sleep(nanoseconds: 100_000_000)
162+
return ("stress_token", true)
163+
}
164+
}
165+
tasks.append(task)
166+
}
167+
168+
// Wait for all to complete
169+
var successCount = 0
170+
for task in tasks {
171+
let (token, updated) = try await task.value
172+
XCTAssertEqual(token, "stress_token")
173+
XCTAssertTrue(updated)
174+
successCount += 1
175+
}
176+
177+
XCTAssertEqual(successCount, numCalls)
178+
179+
// All 50 concurrent calls should result in ONLY 1 network call
180+
let stressCallCount = await counter.value()
181+
XCTAssertEqual(
182+
stressCallCount,
183+
1,
184+
"Expected 1 network call for 50 concurrent requests, but got \(stressCallCount)"
185+
)
186+
}
187+
188+
/// Tests that concurrent calls with forceRefresh:false still use the cache
189+
/// when tokens are valid.
190+
func testCachingStillWorksWithCoalescer() async throws {
191+
let coalescer = TokenRefreshCoalescer()
192+
let counter = Counter()
193+
194+
// First call triggers a refresh
195+
let result1 = try await coalescer.coalescedRefresh(currentToken: "token_v1") {
196+
await counter.increment()
197+
198+
return ("refreshed_token", true)
199+
}
200+
201+
XCTAssertEqual(result1.0, "refreshed_token")
202+
let resultAfterRefresh = await counter.value()
203+
XCTAssertEqual(resultAfterRefresh, 1)
204+
205+
// This test documents that caching logic happens BEFORE coalescer is called,
206+
// so this scenario doesn't test the coalescer directly, but verifies the
207+
// integration is correct.
208+
}
209+
}

0 commit comments

Comments
 (0)