Skip to content

Commit 47f1167

Browse files
committed
test(sdk): Handle more cases with ExpectedAccessToken
Signed-off-by: Kévin Commaille <[email protected]>
1 parent e6e2768 commit 47f1167

File tree

5 files changed

+102
-93
lines changed

5 files changed

+102
-93
lines changed

crates/matrix-sdk/src/authentication/oauth/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,7 @@ impl OAuth {
11731173
if let Some(save_session_callback) = self.client.auth_ctx().save_session_callback.get() {
11741174
// Satisfies the save_session_callback invariant: set_session_tokens has
11751175
// been called just above.
1176+
tracing::debug!("call save_session_callback");
11761177
if let Err(err) = save_session_callback(self.client.clone()) {
11771178
error!("when saving session after refresh: {err}");
11781179
}
@@ -1183,6 +1184,7 @@ impl OAuth {
11831184
lock.save_in_memory_and_db(&tokens_clone).await?;
11841185
}
11851186

1187+
tracing::debug!("broadcast session changed");
11861188
_ = self.client.auth_ctx().session_change_sender.send(SessionChange::TokensRefreshed);
11871189

11881190
Ok(())

crates/matrix-sdk/src/test_utils/mocks/mod.rs

Lines changed: 65 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1861,7 +1861,7 @@ pub struct MockEndpoint<'a, T> {
18611861

18621862
impl<'a, T> MockEndpoint<'a, T> {
18631863
fn new(server: &'a MockServer, mock: MockBuilder, endpoint: T) -> Self {
1864-
Self { server, mock, endpoint, expected_access_token: ExpectedAccessToken::None }
1864+
Self { server, mock, endpoint, expected_access_token: ExpectedAccessToken::Ignore }
18651865
}
18661866

18671867
/// Expect authentication with the default access token on this endpoint.
@@ -1876,12 +1876,30 @@ impl<'a, T> MockEndpoint<'a, T> {
18761876
self
18771877
}
18781878

1879-
/// Don't expect authentication with an access token on this endpoint.
1879+
/// Expect authentication with any access token on this endpoint, regardless
1880+
/// of its value.
18801881
///
1881-
/// This should be used to override the default behavior of the endpoint,
1882-
/// when the access token is unknown for example.
1883-
pub fn do_not_expect_access_token(mut self) -> Self {
1884-
self.expected_access_token = ExpectedAccessToken::None;
1882+
/// This is useful if we don't want to track the value of the access token.
1883+
pub fn expect_any_access_token(mut self) -> Self {
1884+
self.expected_access_token = ExpectedAccessToken::Any;
1885+
self
1886+
}
1887+
1888+
/// Expect no authentication on this endpoint.
1889+
///
1890+
/// This means that the endpoint will not match if an `AUTHENTICATION`
1891+
/// header is present.
1892+
pub fn expect_missing_access_token(mut self) -> Self {
1893+
self.expected_access_token = ExpectedAccessToken::Missing;
1894+
self
1895+
}
1896+
1897+
/// Ignore the access token on this endpoint.
1898+
///
1899+
/// This should be used to override the default behavior of an endpoint that
1900+
/// requires access tokens.
1901+
pub fn ignore_access_token(mut self) -> Self {
1902+
self.expected_access_token = ExpectedAccessToken::Ignore;
18851903
self
18861904
}
18871905

@@ -1927,10 +1945,7 @@ impl<'a, T> MockEndpoint<'a, T> {
19271945
/// # anyhow::Ok(()) });
19281946
/// ```
19291947
pub fn respond_with<R: Respond + 'static>(self, func: R) -> MatrixMock<'a> {
1930-
let mock = self
1931-
.expected_access_token
1932-
.maybe_match_authorization_header(self.mock)
1933-
.respond_with(func);
1948+
let mock = self.mock.and(self.expected_access_token).respond_with(func);
19341949
MatrixMock { mock, server: self.server }
19351950
}
19361951

@@ -1984,6 +1999,17 @@ impl<'a, T> MockEndpoint<'a, T> {
19841999
})))
19852000
}
19862001

2002+
/// Returns a mocked endpoint that emulates an unknown token error, i.e
2003+
/// responds with a 401 HTTP status code and an `M_UNKNOWN_TOKEN` Matrix
2004+
/// error code.
2005+
pub fn error_unknown_token(self, soft_logout: bool) -> MatrixMock<'a> {
2006+
self.respond_with(ResponseTemplate::new(401).set_body_json(json!({
2007+
"errcode": "M_UNKNOWN_TOKEN",
2008+
"error": "Unrecognized access token",
2009+
"soft_logout": soft_logout,
2010+
})))
2011+
}
2012+
19872013
/// Internal helper to return an `{ event_id }` JSON struct along with a 200
19882014
/// ok response.
19892015
fn ok_with_event_id(self, event_id: OwnedEventId) -> MatrixMock<'a> {
@@ -2036,25 +2062,44 @@ impl<'a, T> MockEndpoint<'a, T> {
20362062

20372063
/// The access token to expect on an endpoint.
20382064
enum ExpectedAccessToken {
2039-
/// We don't expect an access token.
2040-
None,
2065+
/// Ignore any access token or lack thereof.
2066+
Ignore,
20412067

20422068
/// We expect the default access token.
20432069
Default,
20442070

20452071
/// We expect the given access token.
20462072
Custom(&'static str),
2073+
2074+
/// We expect any access token.
2075+
Any,
2076+
2077+
/// We expect that there is no access token.
2078+
Missing,
20472079
}
20482080

20492081
impl ExpectedAccessToken {
2050-
/// Match an `Authorization` header on the given mock if one is expected.
2051-
fn maybe_match_authorization_header(&self, mock: MockBuilder) -> MockBuilder {
2052-
let token = match self {
2053-
Self::None => return mock,
2054-
Self::Default => "1234",
2055-
Self::Custom(token) => token,
2056-
};
2057-
mock.and(header(http::header::AUTHORIZATION, format!("Bearer {token}")))
2082+
/// Get the access token from the given request.
2083+
fn access_token(request: &Request) -> Option<&str> {
2084+
request
2085+
.headers
2086+
.get(&http::header::AUTHORIZATION)?
2087+
.to_str()
2088+
.ok()?
2089+
.strip_prefix("Bearer ")
2090+
.filter(|token| !token.is_empty())
2091+
}
2092+
}
2093+
2094+
impl wiremock::Match for ExpectedAccessToken {
2095+
fn matches(&self, request: &Request) -> bool {
2096+
match self {
2097+
Self::Ignore => true,
2098+
Self::Default => Self::access_token(request) == Some("1234"),
2099+
Self::Custom(token) => Self::access_token(request) == Some(token),
2100+
Self::Any => Self::access_token(request).is_some(),
2101+
Self::Missing => request.headers.get(&http::header::AUTHORIZATION).is_none(),
2102+
}
20582103
}
20592104
}
20602105

@@ -3483,14 +3528,6 @@ impl<'a> MockEndpoint<'a, WhoAmIEndpoint> {
34833528
"device_id": device_id,
34843529
})))
34853530
}
3486-
3487-
/// Returns an error response with an `M_UNKNOWN_TOKEN`.
3488-
pub fn err_unknown_token(self) -> MatrixMock<'a> {
3489-
self.respond_with(ResponseTemplate::new(401).set_body_json(json!({
3490-
"errcode": "M_UNKNOWN_TOKEN",
3491-
"error": "Invalid token"
3492-
})))
3493-
}
34943531
}
34953532

34963533
/// A prebuilt mock for `POST /keys/upload` request.

crates/matrix-sdk/tests/integration/client.rs

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,25 +1534,11 @@ async fn test_server_version_without_auth() {
15341534

15351535
// If we provide an access token, we encounter a failure, likely because the
15361536
// token has expired.
1537-
Mock::given(method("GET"))
1538-
.and(path_regex(r"^/_matrix/client/versions"))
1539-
.and(header("authorization", "Bearer 1234"))
1540-
.respond_with(ResponseTemplate::new(401))
1541-
.mount(server.server())
1542-
.await;
1537+
server.mock_versions().expect_default_access_token().error_unknown_token(true).mount().await;
15431538

15441539
// If we do not provide an access token, all is fine as the endpoint does not
15451540
// require one.
1546-
Mock::given(method("GET"))
1547-
.and(path_regex(r"^/_matrix/client/versions"))
1548-
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
1549-
"versions": [
1550-
"r0.0.1",
1551-
"v1.1"
1552-
]
1553-
})))
1554-
.mount(server.server())
1555-
.await;
1541+
server.mock_versions().expect_missing_access_token().ok_with_unstable_features().mount().await;
15561542

15571543
let request_config = RequestConfig::new().disable_retry();
15581544
client

crates/matrix-sdk/tests/integration/encryption/shared_history.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ async fn test_shared_history_out_of_order() {
140140
let bundle_info = bundle_info.await;
141141
matrix_mock_server
142142
.mock_authed_media_download()
143-
.do_not_expect_access_token()
143+
.expect_any_access_token()
144144
.ok_bytes(bundle)
145145
.mock_once()
146146
.named("media_download")

crates/matrix-sdk/tests/integration/refresh_token.rs

Lines changed: 32 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@ use assert_matches::assert_matches;
77
use assert_matches2::assert_let;
88
use matrix_sdk::{
99
HttpError, RefreshTokenError, SessionChange, SessionTokens,
10-
authentication::matrix::MatrixSession,
10+
authentication::{matrix::MatrixSession, oauth::OAuthError},
1111
config::RequestConfig,
1212
executor::spawn,
1313
store::RoomLoadSettings,
1414
test_utils::{
15-
client::mock_session_meta,
15+
client::{
16+
mock_prev_session_tokens_with_refresh, mock_session_meta,
17+
mock_session_tokens_with_refresh, oauth::mock_session,
18+
},
1619
logged_in_client_with_server,
1720
mocks::{LoginResponseTemplate200, MatrixMockServer},
1821
no_retry_test_client_with_server, test_client_builder_with_server,
@@ -30,7 +33,7 @@ use serde_json::json;
3033
use tokio::sync::{broadcast::error::TryRecvError, mpsc};
3134
use wiremock::{
3235
Mock, ResponseTemplate,
33-
matchers::{body_partial_json, header, method, path, path_regex},
36+
matchers::{body_partial_json, header, method, path},
3437
};
3538

3639
fn session() -> MatrixSession {
@@ -563,17 +566,12 @@ async fn test_refresh_token_handled_other_error() {
563566

564567
#[async_test]
565568
async fn test_oauth_refresh_token_handled_success() {
566-
use matrix_sdk::test_utils::{
567-
client::{mock_prev_session_tokens_with_refresh, oauth::mock_session},
568-
mocks::MatrixMockServer,
569-
};
570-
571569
let server = MatrixMockServer::new().await;
572570
// Return an error first so the token is refreshed.
573571
server
574572
.mock_who_am_i()
575573
.expect_access_token("prev-access-token")
576-
.err_unknown_token()
574+
.error_unknown_token(true)
577575
.expect(1)
578576
.named("whoami_unknown_token")
579577
.mount()
@@ -619,20 +617,12 @@ async fn test_oauth_refresh_token_handled_success() {
619617

620618
#[async_test]
621619
async fn test_oauth_refresh_token_handled_failure() {
622-
use matrix_sdk::{
623-
authentication::oauth::OAuthError,
624-
test_utils::{
625-
client::{mock_prev_session_tokens_with_refresh, oauth::mock_session},
626-
mocks::MatrixMockServer,
627-
},
628-
};
629-
630620
let server = MatrixMockServer::new().await;
631621
// Return an error first so the token is refreshed.
632622
server
633623
.mock_who_am_i()
634624
.expect_access_token("prev-access-token")
635-
.err_unknown_token()
625+
.error_unknown_token(false)
636626
.expect(1)
637627
.named("whoami_unknown_token")
638628
.mount()
@@ -688,14 +678,6 @@ async fn test_oauth_refresh_token_handled_failure() {
688678

689679
#[async_test]
690680
async fn test_oauth_handle_refresh_tokens() {
691-
use matrix_sdk::test_utils::{
692-
client::{
693-
mock_prev_session_tokens_with_refresh, mock_session_tokens_with_refresh,
694-
oauth::mock_session,
695-
},
696-
mocks::MatrixMockServer,
697-
};
698-
699681
let server = MatrixMockServer::new().await;
700682
let oauth_server = server.oauth();
701683

@@ -767,45 +749,47 @@ async fn test_oauth_handle_refresh_tokens() {
767749

768750
#[async_test]
769751
async fn test_oauth_handle_refresh_tokens_without_versions() {
770-
use matrix_sdk::test_utils::{
771-
client::{
772-
mock_prev_session_tokens_with_refresh, mock_session_tokens_with_refresh,
773-
oauth::mock_session,
774-
},
775-
mocks::MatrixMockServer,
776-
};
777-
778752
let server = MatrixMockServer::new().await;
779753
let oauth_server = server.oauth();
780754

781755
// If we provide an access token, we encounter a failure, likely because the
782756
// token has expired.
783-
Mock::given(method("GET"))
784-
.and(path_regex(r"^/_matrix/client/versions"))
785-
.and(header("authorization", "Bearer prev-access-token"))
786-
.respond_with(ResponseTemplate::new(401))
787-
.mount(server.server())
757+
server
758+
.mock_versions()
759+
.expect_access_token("prev-access-token")
760+
.error_unknown_token(true)
761+
.expect(1..)
762+
.named("versions with expired token")
763+
.mount()
788764
.await;
789765

790766
// If we do not provide an access token, all is fine as the endpoint does not
791767
// require one.
792-
Mock::given(method("GET"))
793-
.and(path_regex(r"^/_matrix/client/versions"))
794-
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
795-
"versions": [
796-
"r0.0.1",
797-
"v1.1"
798-
]
799-
})))
768+
server
769+
.mock_versions()
770+
.expect_missing_access_token()
771+
.ok_with_unstable_features()
800772
.expect(1..)
801-
.mount(server.server())
773+
.named("unauthenticated versions")
774+
.mount()
775+
.await;
776+
777+
// If we provide the new access token, all is fine.
778+
server
779+
.mock_versions()
780+
.expect_default_access_token()
781+
.ok_with_unstable_features()
782+
.expect(1)
783+
.named("versions with fresh token")
784+
.mount()
802785
.await;
803786

804787
oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
805788

806789
let client = server
807790
.client_builder()
808791
.unlogged()
792+
.no_server_versions()
809793
.on_builder(|builder| builder.handle_refresh_tokens())
810794
.build()
811795
.await;

0 commit comments

Comments
 (0)