Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions crates/cli/src/commands/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use std::process::ExitCode;
use clap::Parser;
use figment::Figment;
use mas_config::{
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig,
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, ExperimentalConfig,
MatrixConfig, PolicyConfig,
};
use mas_storage_pg::PgRepositoryFactory;
use tracing::{info, info_span};
Expand Down Expand Up @@ -45,8 +46,12 @@ impl Options {
PolicyConfig::extract_or_default(figment).map_err(anyhow::Error::from_boxed)?;
let matrix_config =
MatrixConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
let experimental_config =
ExperimentalConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
info!("Loading and compiling the policy module");
let policy_factory = policy_factory_from_config(&config, &matrix_config).await?;
let policy_factory =
policy_factory_from_config(&config, &matrix_config, &experimental_config)
.await?;

if with_dynamic_data {
let database_config =
Expand Down
4 changes: 3 additions & 1 deletion crates/cli/src/commands/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ impl Options {

// Load and compile the WASM policies (and fallback to the default embedded one)
info!("Loading and compiling the policy module");
let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?;
let policy_factory =
policy_factory_from_config(&config.policy, &config.matrix, &config.experimental)
.await?;
let policy_factory = Arc::new(policy_factory);

load_policy_factory_dynamic_data_continuously(
Expand Down
23 changes: 20 additions & 3 deletions crates/cli/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use mas_config::{
PolicyConfig, TemplatesConfig,
};
use mas_context::LogContext;
use mas_data_model::{SessionExpirationConfig, SiteConfig};
use mas_data_model::{SessionExpirationConfig, SessionLimitConfig, SiteConfig};
use mas_email::{MailTransport, Mailer};
use mas_handlers::passwords::PasswordManager;
use mas_matrix::{HomeserverConnection, ReadOnlyHomeserverConnection};
Expand Down Expand Up @@ -135,6 +135,7 @@ pub fn test_mailer_in_background(mailer: &Mailer, timeout: Duration) {
pub async fn policy_factory_from_config(
config: &PolicyConfig,
matrix_config: &MatrixConfig,
experimental_config: &ExperimentalConfig,
) -> Result<PolicyFactory, anyhow::Error> {
let policy_file = tokio::fs::File::open(&config.wasm_module)
.await
Expand All @@ -147,8 +148,17 @@ pub async fn policy_factory_from_config(
email: config.email_entrypoint.clone(),
};

let data =
mas_policy::Data::new(matrix_config.homeserver.clone()).with_rest(config.data.clone());
let session_limit_config =
experimental_config
.session_limit
.as_ref()
.map(|c| SessionLimitConfig {
soft_limit: c.soft_limit,
hard_limit: c.hard_limit,
});

let data = mas_policy::Data::new(matrix_config.homeserver.clone(), session_limit_config)
.with_rest(config.data.clone());

PolicyFactory::load(policy_file, data, entrypoints)
.await
Expand Down Expand Up @@ -225,6 +235,13 @@ pub fn site_config_from_config(
session_expiration,
login_with_email_allowed: account_config.login_with_email_allowed,
plan_management_iframe_uri: experimental_config.plan_management_iframe_uri.clone(),
session_limit: experimental_config
.session_limit
.as_ref()
.map(|c| SessionLimitConfig {
soft_limit: c.soft_limit,
hard_limit: c.hard_limit,
}),
})
}

Expand Down
17 changes: 17 additions & 0 deletions crates/config/src/sections/experimental.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

use std::num::NonZeroU64;

use chrono::Duration;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -81,6 +83,13 @@ pub struct ExperimentalConfig {
/// validation.
#[serde(skip_serializing_if = "Option::is_none")]
pub plan_management_iframe_uri: Option<String>,

/// Experimental feature to limit the number of application sessions per
/// user.
///
/// Disabled by default.
#[serde(skip_serializing_if = "Option::is_none")]
pub session_limit: Option<SessionLimitConfig>,
}

impl Default for ExperimentalConfig {
Expand All @@ -90,6 +99,7 @@ impl Default for ExperimentalConfig {
compat_token_ttl: default_token_ttl(),
inactive_session_expiration: None,
plan_management_iframe_uri: None,
session_limit: None,
}
}
}
Expand All @@ -106,3 +116,10 @@ impl ExperimentalConfig {
impl ConfigurationSection for ExperimentalConfig {
const PATH: Option<&'static str> = Some("experimental");
}

/// Configuration options for the inactive session expiration feature
#[derive(Clone, Debug, Deserialize, JsonSchema, Serialize)]
pub struct SessionLimitConfig {
pub soft_limit: NonZeroU64,
pub hard_limit: NonZeroU64,
}
4 changes: 3 additions & 1 deletion crates/data-model/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ pub use self::{
DeviceCodeGrantState, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState,
},
policy_data::PolicyData,
site_config::{CaptchaConfig, CaptchaService, SessionExpirationConfig, SiteConfig},
site_config::{
CaptchaConfig, CaptchaService, SessionExpirationConfig, SessionLimitConfig, SiteConfig,
},
tokens::{
AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType,
},
Expand Down
12 changes: 12 additions & 0 deletions crates/data-model/src/site_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

use std::num::NonZeroU64;

use chrono::Duration;
use serde::Serialize;
use url::Url;

/// Which Captcha service is being used
Expand Down Expand Up @@ -36,6 +39,12 @@ pub struct SessionExpirationConfig {
pub compat_session_inactivity_ttl: Option<Duration>,
}

#[derive(Serialize, Debug, Clone)]
pub struct SessionLimitConfig {
pub soft_limit: NonZeroU64,
pub hard_limit: NonZeroU64,
}

/// Random site configuration we want accessible in various places.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -99,4 +108,7 @@ pub struct SiteConfig {

/// The iframe URL to show in the plan tab of the UI
pub plan_management_iframe_uri: Option<String>,

/// Limits on the number of application sessions that each user can have
pub session_limit: Option<SessionLimitConfig>,
}
8 changes: 7 additions & 1 deletion crates/handlers/src/oauth2/authorization/consent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use super::callback::CallbackDestination;
use crate::{
BoundActivityTracker, PreferredLanguage, impl_from_error_for_route,
oauth2::generate_id_token,
session::{SessionOrFallback, load_session_or_fallback},
session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
};

#[derive(Debug, Error)]
Expand Down Expand Up @@ -136,10 +136,13 @@ pub(crate) async fn get(

let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);

let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;

let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&session.user),
client: &client,
session_counts: Some(session_counts),
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
requester: mas_policy::Requester {
Expand Down Expand Up @@ -235,10 +238,13 @@ pub(crate) async fn post(
return Err(RouteError::GrantNotPending(grant.id));
}

let session_counts = count_user_sessions_for_limiting(&mut repo, &browser_session.user).await?;

let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&browser_session.user),
client: &client,
session_counts: Some(session_counts),
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
requester: mas_policy::Requester {
Expand Down
8 changes: 7 additions & 1 deletion crates/handlers/src/oauth2/device/consent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use ulid::Ulid;

use crate::{
BoundActivityTracker, PreferredLanguage,
session::{SessionOrFallback, load_session_or_fallback},
session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
};

#[derive(Deserialize, Debug)]
Expand Down Expand Up @@ -103,11 +103,14 @@ pub(crate) async fn get(
.context("Client not found")
.map_err(InternalError::from_anyhow)?;

let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;

// Evaluate the policy
let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
grant_type: mas_policy::GrantType::DeviceCode,
client: &client,
session_counts: Some(session_counts),
scope: &grant.scope,
user: Some(&session.user),
requester: mas_policy::Requester {
Expand Down Expand Up @@ -205,11 +208,14 @@ pub(crate) async fn post(
.context("Client not found")
.map_err(InternalError::from_anyhow)?;

let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;

// Evaluate the policy
let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
grant_type: mas_policy::GrantType::DeviceCode,
client: &client,
session_counts: Some(session_counts),
scope: &grant.scope,
user: Some(&session.user),
requester: mas_policy::Requester {
Expand Down
1 change: 1 addition & 0 deletions crates/handlers/src/oauth2/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ async fn client_credentials_grant(
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: None,
client,
session_counts: None,
scope: &scope,
grant_type: mas_policy::GrantType::ClientCredentials,
requester: mas_policy::Requester {
Expand Down
67 changes: 65 additions & 2 deletions crates/handlers/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@

use axum::response::{Html, IntoResponse as _, Response};
use mas_axum_utils::{SessionInfoExt, cookies::CookieJar, csrf::CsrfExt};
use mas_data_model::{BrowserSession, Clock};
use mas_data_model::{BrowserSession, Clock, User};
use mas_i18n::DataLocale;
use mas_storage::{BoxRepository, RepositoryError};
use mas_policy::model::SessionCounts;
use mas_storage::{
BoxRepository, RepositoryError, compat::CompatSessionFilter, oauth2::OAuth2SessionFilter,
personal::PersonalSessionFilter,
};
use mas_templates::{AccountInactiveContext, TemplateContext, Templates};
use rand::RngCore;
use thiserror::Error;
Expand Down Expand Up @@ -102,3 +106,62 @@ pub async fn load_session_or_fallback(
maybe_session: Some(session),
})
}

/// Get a count of sessions for the given user, for the purposes of session
/// limiting.
///
/// Includes:
/// - OAuth 2 sessions
/// - Compatibility sessions
/// - Personal sessions (unless owned by a different user)
///
/// # Backstory
///
/// Originally, we were only intending to count sessions with devices in this
/// result, because those are the entries that are expensive for Synapse and
/// also would not hinder use of deviceless clients (like Element Admin, an
/// admin dashboard).
///
/// However, to do so, we would need to count only sessions including device
/// scopes. To do this efficiently, we'd need a partial index on sessions
/// including device scopes.
///
/// It turns out that this can't be done cleanly (as we need to, in Postgres,
/// match scope lists where one of the scopes matches one of 2 known prefixes),
/// at least not without somewhat uncomfortable stored functions.
///
/// So for simplicity's sake, we now count all sessions.
/// For practical use cases, it's not likely to make a noticeable difference
/// (and maybe it's good that there's an overall limit).
pub(crate) async fn count_user_sessions_for_limiting(
repo: &mut BoxRepository,
user: &User,
) -> Result<SessionCounts, RepositoryError> {
let oauth2 = repo
.oauth2_session()
.count(OAuth2SessionFilter::new().active_only().for_user(user))
.await? as u64;

let compat = repo
.compat_session()
.count(CompatSessionFilter::new().active_only().for_user(user))
.await? as u64;

// Only include self-owned personal sessions, not administratively-owned ones
let personal = repo
.personal_session()
.count(
PersonalSessionFilter::new()
.active_only()
.for_actor_user(user)
.for_owner_user(user),
)
.await? as u64;

Ok(SessionCounts {
total: oauth2 + compat + personal,
oauth2,
compat,
personal,
})
}
3 changes: 2 additions & 1 deletion crates/handlers/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ pub(crate) async fn policy_factory(
email: "email/violation".to_owned(),
};

let data = mas_policy::Data::new(server_name.to_owned()).with_rest(data);
let data = mas_policy::Data::new(server_name.to_owned(), None).with_rest(data);

let policy_factory = PolicyFactory::load(file, data, entrypoints).await?;
let policy_factory = Arc::new(policy_factory);
Expand Down Expand Up @@ -148,6 +148,7 @@ pub fn test_site_config() -> SiteConfig {
session_expiration: None,
login_with_email_allowed: true,
plan_management_iframe_uri: None,
session_limit: None,
}
}

Expand Down
Loading
Loading