Skip to content

Commit 4c55628

Browse files
committed
feat(sdk): Keep track of whether the access token is expired
This will allow to handle automatically whether to send an access token or not on endpoints that don't require it in contexts were can't refresh it. We also don't cache calls to GET /versions that were not authenticated, because they might lack some features compared to an authenticated request. Signed-off-by: Kévin Commaille <[email protected]>
1 parent 47f1167 commit 4c55628

File tree

2 files changed

+65
-13
lines changed

2 files changed

+65
-13
lines changed

crates/matrix-sdk/src/authentication/mod.rs

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,20 @@ impl fmt::Debug for SessionTokens {
4848
}
4949
}
5050

51+
/// The tokens for a user session and their state.
52+
pub(crate) struct SessionTokensState {
53+
/// The inner tokens.
54+
inner: SessionTokens,
55+
56+
/// Whether the access token is expired.
57+
///
58+
/// We keep track of this information here, rather than dropping the access
59+
/// token, because we still want to make most requests with the expired
60+
/// access token to try to refresh it, or wait for it to be refreshed. If we
61+
/// make a request without the access token we will get the wrong error.
62+
access_token_expired: bool,
63+
}
64+
5165
pub(crate) type SessionCallbackError = Box<dyn std::error::Error + Send + Sync>;
5266

5367
#[cfg(not(target_family = "wasm"))]
@@ -83,8 +97,8 @@ pub(crate) struct AuthCtx {
8397
/// Authentication data to keep in memory.
8498
pub(crate) auth_data: OnceCell<AuthData>,
8599

86-
/// The current session tokens.
87-
tokens: OnceCell<Mutex<SessionTokens>>,
100+
/// The current session tokens and their state.
101+
tokens: OnceCell<Mutex<SessionTokensState>>,
88102

89103
/// A callback called whenever we need an absolute source of truth for the
90104
/// current session tokens.
@@ -119,22 +133,47 @@ impl AuthCtx {
119133

120134
/// The current session tokens.
121135
pub(crate) fn session_tokens(&self) -> Option<SessionTokens> {
122-
Some(self.tokens.get()?.lock().clone())
136+
Some(self.tokens.get()?.lock().inner.clone())
123137
}
124138

125139
/// The current access token.
126140
pub(crate) fn access_token(&self) -> Option<String> {
127-
Some(self.tokens.get()?.lock().access_token.clone())
141+
Some(self.tokens.get()?.lock().inner.access_token.clone())
142+
}
143+
144+
/// Whether we have a valid session token.
145+
pub(crate) fn has_valid_access_token(&self) -> bool {
146+
self.tokens.get().is_some_and(|tokens| !tokens.lock().access_token_expired)
128147
}
129148

130149
/// Set the current session tokens.
131150
pub(crate) fn set_session_tokens(&self, session_tokens: SessionTokens) {
151+
let session_tokens = SessionTokensState {
152+
inner: session_tokens,
153+
// We just got the tokens, so we assume that they are not expired.
154+
access_token_expired: false,
155+
};
156+
132157
if let Some(tokens) = self.tokens.get() {
133158
*tokens.lock() = session_tokens;
134159
} else {
135160
let _ = self.tokens.set(Mutex::new(session_tokens));
136161
}
137162
}
163+
164+
/// Set the given access token as expired.
165+
///
166+
/// We take the value of the access token to make sure that we don't mark
167+
/// the wrong access token as expired.
168+
pub(crate) fn set_access_token_expired(&self, access_token: &str) {
169+
if let Some(tokens) = self.tokens.get() {
170+
let mut tokens = tokens.lock();
171+
172+
if tokens.inner.access_token == access_token {
173+
tokens.access_token_expired = true;
174+
}
175+
}
176+
}
138177
}
139178

140179
/// An enum over all the possible authentication APIs.

crates/matrix-sdk/src/client/mod.rs

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1936,7 +1936,8 @@ impl Client {
19361936
let path_builder_input =
19371937
Request::PathBuilder::get_path_builder_input(self, skip_auth).await?;
19381938

1939-
self.inner
1939+
let result = self
1940+
.inner
19401941
.http_client
19411942
.send(
19421943
request,
@@ -1946,7 +1947,17 @@ impl Client {
19461947
path_builder_input,
19471948
send_progress,
19481949
)
1949-
.await
1950+
.await;
1951+
1952+
if let Err(Some(ErrorKind::UnknownToken { .. })) =
1953+
result.as_ref().map_err(HttpError::client_api_error_kind)
1954+
&& let Some(access_token) = &access_token
1955+
{
1956+
// Mark the access token as expired.
1957+
self.auth_ctx().set_access_token_expired(access_token);
1958+
}
1959+
1960+
result
19501961
}
19511962

19521963
fn broadcast_unknown_token(&self, soft_logout: &bool) {
@@ -2030,9 +2041,9 @@ impl Client {
20302041
unstable_features: server_versions.unstable_features,
20312042
};
20322043

2033-
// Attempt to cache the result in storage.
2034-
{
2035-
if let Err(err) = self
2044+
// Only attempt to cache the result in storage if the request was authenticated.
2045+
if self.auth_ctx().has_valid_access_token()
2046+
&& let Err(err) = self
20362047
.state_store()
20372048
.set_kv_data(
20382049
StateStoreDataKey::SupportedVersions,
@@ -2041,9 +2052,8 @@ impl Client {
20412052
)),
20422053
)
20432054
.await
2044-
{
2045-
warn!("error when caching supported versions: {err}");
2046-
}
2055+
{
2056+
warn!("error when caching supported versions: {err}");
20472057
}
20482058

20492059
Ok(supported_versions)
@@ -2105,7 +2115,10 @@ impl Client {
21052115
supported_versions.versions = [MatrixVersion::V1_0].into();
21062116
}
21072117

2108-
*guarded_supported_versions = CachedValue::Cached(supported_versions.clone());
2118+
// Only cache the result if the request was authenticated.
2119+
if self.auth_ctx().has_valid_access_token() {
2120+
*guarded_supported_versions = CachedValue::Cached(supported_versions.clone());
2121+
}
21092122

21102123
Ok(supported_versions)
21112124
}

0 commit comments

Comments
 (0)