diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index bb87c5e81..6da64f95b 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -9,7 +9,8 @@ use std::process::ExitCode; use clap::Parser; use figment::Figment; use mas_config::{ - ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig, + ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, ExperimentalConfig, + MatrixConfig, PolicyConfig, }; use mas_storage_pg::PgRepositoryFactory; use tracing::{info, info_span}; @@ -45,8 +46,12 @@ impl Options { PolicyConfig::extract_or_default(figment).map_err(anyhow::Error::from_boxed)?; let matrix_config = MatrixConfig::extract(figment).map_err(anyhow::Error::from_boxed)?; + let experimental_config = + ExperimentalConfig::extract(figment).map_err(anyhow::Error::from_boxed)?; info!("Loading and compiling the policy module"); - let policy_factory = policy_factory_from_config(&config, &matrix_config).await?; + let policy_factory = + policy_factory_from_config(&config, &matrix_config, &experimental_config) + .await?; if with_dynamic_data { let database_config = diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 52465f077..020d24d0f 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -132,7 +132,9 @@ impl Options { // Load and compile the WASM policies (and fallback to the default embedded one) info!("Loading and compiling the policy module"); - let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?; + let policy_factory = + policy_factory_from_config(&config.policy, &config.matrix, &config.experimental) + .await?; let policy_factory = Arc::new(policy_factory); load_policy_factory_dynamic_data_continuously( diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index 4925d9866..a9b9a3132 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -13,7 +13,7 @@ use mas_config::{ PolicyConfig, TemplatesConfig, }; use mas_context::LogContext; -use mas_data_model::{SessionExpirationConfig, SiteConfig}; +use mas_data_model::{SessionExpirationConfig, SessionLimitConfig, SiteConfig}; use mas_email::{MailTransport, Mailer}; use mas_handlers::passwords::PasswordManager; use mas_matrix::{HomeserverConnection, ReadOnlyHomeserverConnection}; @@ -135,6 +135,7 @@ pub fn test_mailer_in_background(mailer: &Mailer, timeout: Duration) { pub async fn policy_factory_from_config( config: &PolicyConfig, matrix_config: &MatrixConfig, + experimental_config: &ExperimentalConfig, ) -> Result { let policy_file = tokio::fs::File::open(&config.wasm_module) .await @@ -147,8 +148,17 @@ pub async fn policy_factory_from_config( email: config.email_entrypoint.clone(), }; - let data = - mas_policy::Data::new(matrix_config.homeserver.clone()).with_rest(config.data.clone()); + let session_limit_config = + experimental_config + .session_limit + .as_ref() + .map(|c| SessionLimitConfig { + soft_limit: c.soft_limit, + hard_limit: c.hard_limit, + }); + + let data = mas_policy::Data::new(matrix_config.homeserver.clone(), session_limit_config) + .with_rest(config.data.clone()); PolicyFactory::load(policy_file, data, entrypoints) .await @@ -225,6 +235,13 @@ pub fn site_config_from_config( session_expiration, login_with_email_allowed: account_config.login_with_email_allowed, plan_management_iframe_uri: experimental_config.plan_management_iframe_uri.clone(), + session_limit: experimental_config + .session_limit + .as_ref() + .map(|c| SessionLimitConfig { + soft_limit: c.soft_limit, + hard_limit: c.hard_limit, + }), }) } diff --git a/crates/config/src/sections/experimental.rs b/crates/config/src/sections/experimental.rs index c6c50e88d..b8f3920b0 100644 --- a/crates/config/src/sections/experimental.rs +++ b/crates/config/src/sections/experimental.rs @@ -4,6 +4,8 @@ // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // Please see LICENSE files in the repository root for full details. +use std::num::NonZeroU64; + use chrono::Duration; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -81,6 +83,13 @@ pub struct ExperimentalConfig { /// validation. #[serde(skip_serializing_if = "Option::is_none")] pub plan_management_iframe_uri: Option, + + /// Experimental feature to limit the number of application sessions per + /// user. + /// + /// Disabled by default. + #[serde(skip_serializing_if = "Option::is_none")] + pub session_limit: Option, } impl Default for ExperimentalConfig { @@ -90,6 +99,7 @@ impl Default for ExperimentalConfig { compat_token_ttl: default_token_ttl(), inactive_session_expiration: None, plan_management_iframe_uri: None, + session_limit: None, } } } @@ -100,9 +110,17 @@ impl ExperimentalConfig { && is_default_token_ttl(&self.compat_token_ttl) && self.inactive_session_expiration.is_none() && self.plan_management_iframe_uri.is_none() + && self.session_limit.is_none() } } impl ConfigurationSection for ExperimentalConfig { const PATH: Option<&'static str> = Some("experimental"); } + +/// Configuration options for the session limit feature +#[derive(Clone, Debug, Deserialize, JsonSchema, Serialize)] +pub struct SessionLimitConfig { + pub soft_limit: NonZeroU64, + pub hard_limit: NonZeroU64, +} diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 962c8be00..fd5c0e633 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -39,7 +39,9 @@ pub use self::{ DeviceCodeGrantState, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState, }, policy_data::PolicyData, - site_config::{CaptchaConfig, CaptchaService, SessionExpirationConfig, SiteConfig}, + site_config::{ + CaptchaConfig, CaptchaService, SessionExpirationConfig, SessionLimitConfig, SiteConfig, + }, tokens::{ AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType, }, diff --git a/crates/data-model/src/site_config.rs b/crates/data-model/src/site_config.rs index 9622203ad..bb92dc3e4 100644 --- a/crates/data-model/src/site_config.rs +++ b/crates/data-model/src/site_config.rs @@ -4,7 +4,10 @@ // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // Please see LICENSE files in the repository root for full details. +use std::num::NonZeroU64; + use chrono::Duration; +use serde::Serialize; use url::Url; /// Which Captcha service is being used @@ -36,6 +39,12 @@ pub struct SessionExpirationConfig { pub compat_session_inactivity_ttl: Option, } +#[derive(Serialize, Debug, Clone)] +pub struct SessionLimitConfig { + pub soft_limit: NonZeroU64, + pub hard_limit: NonZeroU64, +} + /// Random site configuration we want accessible in various places. #[allow(clippy::struct_excessive_bools)] #[derive(Debug, Clone)] @@ -99,4 +108,7 @@ pub struct SiteConfig { /// The iframe URL to show in the plan tab of the UI pub plan_management_iframe_uri: Option, + + /// Limits on the number of application sessions that each user can have + pub session_limit: Option, } diff --git a/crates/handlers/src/oauth2/authorization/consent.rs b/crates/handlers/src/oauth2/authorization/consent.rs index 968aec08a..2587828b5 100644 --- a/crates/handlers/src/oauth2/authorization/consent.rs +++ b/crates/handlers/src/oauth2/authorization/consent.rs @@ -32,7 +32,7 @@ use super::callback::CallbackDestination; use crate::{ BoundActivityTracker, PreferredLanguage, impl_from_error_for_route, oauth2::generate_id_token, - session::{SessionOrFallback, load_session_or_fallback}, + session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback}, }; #[derive(Debug, Error)] @@ -136,10 +136,13 @@ pub(crate) async fn get( let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); + let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?; + let res = policy .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { user: Some(&session.user), client: &client, + session_counts: Some(session_counts), scope: &grant.scope, grant_type: mas_policy::GrantType::AuthorizationCode, requester: mas_policy::Requester { @@ -235,10 +238,13 @@ pub(crate) async fn post( return Err(RouteError::GrantNotPending(grant.id)); } + let session_counts = count_user_sessions_for_limiting(&mut repo, &browser_session.user).await?; + let res = policy .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { user: Some(&browser_session.user), client: &client, + session_counts: Some(session_counts), scope: &grant.scope, grant_type: mas_policy::GrantType::AuthorizationCode, requester: mas_policy::Requester { diff --git a/crates/handlers/src/oauth2/device/consent.rs b/crates/handlers/src/oauth2/device/consent.rs index 30a35aa17..e1d32870f 100644 --- a/crates/handlers/src/oauth2/device/consent.rs +++ b/crates/handlers/src/oauth2/device/consent.rs @@ -27,7 +27,7 @@ use ulid::Ulid; use crate::{ BoundActivityTracker, PreferredLanguage, - session::{SessionOrFallback, load_session_or_fallback}, + session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback}, }; #[derive(Deserialize, Debug)] @@ -103,11 +103,14 @@ pub(crate) async fn get( .context("Client not found") .map_err(InternalError::from_anyhow)?; + let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?; + // Evaluate the policy let res = policy .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { grant_type: mas_policy::GrantType::DeviceCode, client: &client, + session_counts: Some(session_counts), scope: &grant.scope, user: Some(&session.user), requester: mas_policy::Requester { @@ -205,11 +208,14 @@ pub(crate) async fn post( .context("Client not found") .map_err(InternalError::from_anyhow)?; + let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?; + // Evaluate the policy let res = policy .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { grant_type: mas_policy::GrantType::DeviceCode, client: &client, + session_counts: Some(session_counts), scope: &grant.scope, user: Some(&session.user), requester: mas_policy::Requester { diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 4a63d8290..99506ac29 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -781,6 +781,7 @@ async fn client_credentials_grant( .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { user: None, client, + session_counts: None, scope: &scope, grant_type: mas_policy::GrantType::ClientCredentials, requester: mas_policy::Requester { diff --git a/crates/handlers/src/session.rs b/crates/handlers/src/session.rs index cb05510ba..aa3836a26 100644 --- a/crates/handlers/src/session.rs +++ b/crates/handlers/src/session.rs @@ -8,9 +8,13 @@ use axum::response::{Html, IntoResponse as _, Response}; use mas_axum_utils::{SessionInfoExt, cookies::CookieJar, csrf::CsrfExt}; -use mas_data_model::{BrowserSession, Clock}; +use mas_data_model::{BrowserSession, Clock, User}; use mas_i18n::DataLocale; -use mas_storage::{BoxRepository, RepositoryError}; +use mas_policy::model::SessionCounts; +use mas_storage::{ + BoxRepository, RepositoryError, compat::CompatSessionFilter, oauth2::OAuth2SessionFilter, + personal::PersonalSessionFilter, +}; use mas_templates::{AccountInactiveContext, TemplateContext, Templates}; use rand::RngCore; use thiserror::Error; @@ -102,3 +106,62 @@ pub async fn load_session_or_fallback( maybe_session: Some(session), }) } + +/// Get a count of sessions for the given user, for the purposes of session +/// limiting. +/// +/// Includes: +/// - OAuth 2 sessions +/// - Compatibility sessions +/// - Personal sessions (unless owned by a different user) +/// +/// # Backstory +/// +/// Originally, we were only intending to count sessions with devices in this +/// result, because those are the entries that are expensive for Synapse and +/// also would not hinder use of deviceless clients (like Element Admin, an +/// admin dashboard). +/// +/// However, to do so, we would need to count only sessions including device +/// scopes. To do this efficiently, we'd need a partial index on sessions +/// including device scopes. +/// +/// It turns out that this can't be done cleanly (as we need to, in Postgres, +/// match scope lists where one of the scopes matches one of 2 known prefixes), +/// at least not without somewhat uncomfortable stored functions. +/// +/// So for simplicity's sake, we now count all sessions. +/// For practical use cases, it's not likely to make a noticeable difference +/// (and maybe it's good that there's an overall limit). +pub(crate) async fn count_user_sessions_for_limiting( + repo: &mut BoxRepository, + user: &User, +) -> Result { + let oauth2 = repo + .oauth2_session() + .count(OAuth2SessionFilter::new().active_only().for_user(user)) + .await? as u64; + + let compat = repo + .compat_session() + .count(CompatSessionFilter::new().active_only().for_user(user)) + .await? as u64; + + // Only include self-owned personal sessions, not administratively-owned ones + let personal = repo + .personal_session() + .count( + PersonalSessionFilter::new() + .active_only() + .for_actor_user(user) + .for_owner_user(user), + ) + .await? as u64; + + Ok(SessionCounts { + total: oauth2 + compat + personal, + oauth2, + compat, + personal, + }) +} diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index f1859f352..cf0466a9c 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -85,7 +85,7 @@ pub(crate) async fn policy_factory( email: "email/violation".to_owned(), }; - let data = mas_policy::Data::new(server_name.to_owned()).with_rest(data); + let data = mas_policy::Data::new(server_name.to_owned(), None).with_rest(data); let policy_factory = PolicyFactory::load(file, data, entrypoints).await?; let policy_factory = Arc::new(policy_factory); @@ -148,6 +148,7 @@ pub fn test_site_config() -> SiteConfig { session_expiration: None, login_with_email_allowed: true, plan_management_iframe_uri: None, + session_limit: None, } } diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index 3a3a23c3f..8a038aea8 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -9,11 +9,12 @@ pub mod model; use std::sync::Arc; use arc_swap::ArcSwap; -use mas_data_model::Ulid; +use mas_data_model::{SessionLimitConfig, Ulid}; use opa_wasm::{ Runtime, wasmtime::{Config, Engine, Module, OptLevel, Store}, }; +use serde::Serialize; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; @@ -87,16 +88,29 @@ impl Entrypoints { #[derive(Debug)] pub struct Data { - server_name: String, + base: BaseData, + // We will merge this in a custom way, so don't emit as part of the base rest: Option, } +#[derive(Serialize, Debug)] +struct BaseData { + server_name: String, + + /// Limits on the number of application sessions that each user can have + session_limit: Option, +} + impl Data { #[must_use] - pub fn new(server_name: String) -> Self { + pub fn new(server_name: String, session_limit: Option) -> Self { Self { - server_name, + base: BaseData { + server_name, + session_limit, + }, + rest: None, } } @@ -108,9 +122,7 @@ impl Data { } fn to_value(&self) -> Result { - let base = serde_json::json!({ - "server_name": self.server_name, - }); + let base = serde_json::to_value(&self.base)?; if let Some(rest) = &self.rest { merge_data(base, rest.clone()) @@ -458,7 +470,7 @@ mod tests { #[tokio::test] async fn test_register() { - let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({ + let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({ "allowed_domains": ["element.io", "*.element.io"], "banned_domains": ["staging.element.io"], })); @@ -528,7 +540,7 @@ mod tests { #[tokio::test] async fn test_dynamic_data() { - let data = Data::new("example.com".to_owned()); + let data = Data::new("example.com".to_owned(), None); #[allow(clippy::disallowed_types)] let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) @@ -597,7 +609,7 @@ mod tests { #[tokio::test] async fn test_big_dynamic_data() { - let data = Data::new("example.com".to_owned()); + let data = Data::new("example.com".to_owned(), None); #[allow(clippy::disallowed_types)] let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index 2f54ae8bb..b85170025 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -49,6 +49,9 @@ pub enum Code { /// The email address is banned. EmailBanned, + + /// The user has reached their session limit. + TooManySessions, } impl Code { @@ -66,6 +69,7 @@ impl Code { Self::EmailDomainBanned => "email-domain-banned", Self::EmailNotAllowed => "email-not-allowed", Self::EmailBanned => "email-banned", + Self::TooManySessions => "too-many-sessions", } } } @@ -168,6 +172,10 @@ pub struct AuthorizationGrantInput<'a> { #[schemars(with = "Option>")] pub user: Option<&'a User>, + /// How many sessions the user has. + /// Not populated if it's not a user logging in. + pub session_counts: Option, + #[schemars(with = "std::collections::HashMap")] pub client: &'a Client, @@ -179,6 +187,16 @@ pub struct AuthorizationGrantInput<'a> { pub requester: Requester, } +/// Information about how many sessions the user has +#[derive(Serialize, Debug, JsonSchema)] +pub struct SessionCounts { + pub total: u64, + + pub oauth2: u64, + pub compat: u64, + pub personal: u64, +} + /// Input for the email add policy. #[derive(Serialize, Debug, JsonSchema)] #[serde(rename_all = "snake_case")] diff --git a/docs/config.schema.json b/docs/config.schema.json index 524f02c93..409d49fdf 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -2659,6 +2659,14 @@ "plan_management_iframe_uri": { "description": "Experimental feature to show a plan management tab and iframe. This value is passed through \"as is\" to the client without any validation.", "type": "string" + }, + "session_limit": { + "description": "Experimental feature to limit the number of application sessions per user.\n\nDisabled by default.", + "allOf": [ + { + "$ref": "#/definitions/SessionLimitConfig" + } + ] } } }, @@ -2692,6 +2700,26 @@ "type": "boolean" } } + }, + "SessionLimitConfig": { + "description": "Configuration options for the session limit feature", + "type": "object", + "required": [ + "hard_limit", + "soft_limit" + ], + "properties": { + "soft_limit": { + "type": "integer", + "format": "uint64", + "minimum": 1.0 + }, + "hard_limit": { + "type": "integer", + "format": "uint64", + "minimum": 1.0 + } + } } } } \ No newline at end of file diff --git a/policies/authorization_grant/authorization_grant.rego b/policies/authorization_grant/authorization_grant.rego index 79f737af1..e7d1e68e5 100644 --- a/policies/authorization_grant/authorization_grant.rego +++ b/policies/authorization_grant/authorization_grant.rego @@ -153,3 +153,20 @@ violation contains {"msg": sprintf( )} if { common.requester_banned(input.requester, data.requester) } + +violation contains { + "code": "too-many-sessions", + "msg": "user has too many active sessions", +} if { + # Only apply if session limits are enabled in the config + data.session_limit != null + + # Only apply if it's a user logging in (who therefore has countable sessions) + input.session_counts != null + + # For OAuth 2 login, a violation occurs when the soft limit has already been + # reached or exceeded. + # We use the soft limit because the user will be able to interactively remove + # sessions to return under the limit. + data.session_limit.soft_limit <= input.session_counts.total +} diff --git a/policies/authorization_grant/authorization_grant_test.rego b/policies/authorization_grant/authorization_grant_test.rego index 6634eacb9..e2ca74086 100644 --- a/policies/authorization_grant/authorization_grant_test.rego +++ b/policies/authorization_grant/authorization_grant_test.rego @@ -222,3 +222,35 @@ test_mas_scopes if { with input.grant_type as "authorization_code" with input.scope as "urn:mas:admin" } + +test_session_limiting if { + authorization_grant.allow with input.user as user + with input.session_counts as {"total": 1} + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + authorization_grant.allow with input.user as user + with input.session_counts as {"total": 31} + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + not authorization_grant.allow with input.user as user + with input.session_counts as {"total": 32} + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + not authorization_grant.allow with input.user as user + with input.session_counts as {"total": 42} + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + not authorization_grant.allow with input.user as user + with input.session_counts as {"total": 65} + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + # No limit configured + authorization_grant.allow with input.user as user + with input.session_counts as {"total": 1} + with data.session_limit as null + + # Client credentials grant + authorization_grant.allow with input.user as user + with input.session_counts as null + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} +} diff --git a/policies/schema/authorization_grant_input.json b/policies/schema/authorization_grant_input.json index f23bf7a73..a5d49e304 100644 --- a/policies/schema/authorization_grant_input.json +++ b/policies/schema/authorization_grant_input.json @@ -14,6 +14,14 @@ "type": "object", "additionalProperties": true }, + "session_counts": { + "description": "How many sessions the user has. Not populated if it's not a user logging in.", + "allOf": [ + { + "$ref": "#/definitions/SessionCounts" + } + ] + }, "client": { "type": "object", "additionalProperties": true @@ -29,6 +37,38 @@ } }, "definitions": { + "SessionCounts": { + "description": "Information about how many sessions the user has", + "type": "object", + "required": [ + "compat", + "oauth2", + "personal", + "total" + ], + "properties": { + "total": { + "type": "integer", + "format": "uint64", + "minimum": 0.0 + }, + "oauth2": { + "type": "integer", + "format": "uint64", + "minimum": 0.0 + }, + "compat": { + "type": "integer", + "format": "uint64", + "minimum": 0.0 + }, + "personal": { + "type": "integer", + "format": "uint64", + "minimum": 0.0 + } + } + }, "GrantType": { "type": "string", "enum": [ diff --git a/translations/en.json b/translations/en.json index 5881171fb..e551e15c2 100644 --- a/translations/en.json +++ b/translations/en.json @@ -499,7 +499,7 @@ "context": "pages/policy_violation.html:19:25-62", "description": "Displayed when an authorization request is denied by the policy" }, - "heading": "The authorization request was denied the policy enforced by this service", + "heading": "The authorization request was denied by the policy enforced by this service", "@heading": { "context": "pages/policy_violation.html:18:27-60", "description": "Displayed when an authorization request is denied by the policy"