Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions crates/cli/src/commands/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use std::process::ExitCode;
use clap::Parser;
use figment::Figment;
use mas_config::{
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PolicyConfig,
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, ExperimentalConfig,
MatrixConfig, PolicyConfig,
};
use mas_storage_pg::PgRepositoryFactory;
use tracing::{info, info_span};
Expand Down Expand Up @@ -45,8 +46,12 @@ impl Options {
PolicyConfig::extract_or_default(figment).map_err(anyhow::Error::from_boxed)?;
let matrix_config =
MatrixConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
let experimental_config =
ExperimentalConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
info!("Loading and compiling the policy module");
let policy_factory = policy_factory_from_config(&config, &matrix_config).await?;
let policy_factory =
policy_factory_from_config(&config, &matrix_config, &experimental_config)
.await?;

if with_dynamic_data {
let database_config =
Expand Down
4 changes: 3 additions & 1 deletion crates/cli/src/commands/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ impl Options {

// Load and compile the WASM policies (and fallback to the default embedded one)
info!("Loading and compiling the policy module");
let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?;
let policy_factory =
policy_factory_from_config(&config.policy, &config.matrix, &config.experimental)
.await?;
let policy_factory = Arc::new(policy_factory);

load_policy_factory_dynamic_data_continuously(
Expand Down
23 changes: 20 additions & 3 deletions crates/cli/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use mas_config::{
PolicyConfig, TemplatesConfig,
};
use mas_context::LogContext;
use mas_data_model::{SessionExpirationConfig, SiteConfig};
use mas_data_model::{SessionExpirationConfig, SessionLimitConfig, SiteConfig};
use mas_email::{MailTransport, Mailer};
use mas_handlers::passwords::PasswordManager;
use mas_matrix::{HomeserverConnection, ReadOnlyHomeserverConnection};
Expand Down Expand Up @@ -135,6 +135,7 @@ pub fn test_mailer_in_background(mailer: &Mailer, timeout: Duration) {
pub async fn policy_factory_from_config(
config: &PolicyConfig,
matrix_config: &MatrixConfig,
experimental_config: &ExperimentalConfig,
) -> Result<PolicyFactory, anyhow::Error> {
let policy_file = tokio::fs::File::open(&config.wasm_module)
.await
Expand All @@ -147,8 +148,17 @@ pub async fn policy_factory_from_config(
email: config.email_entrypoint.clone(),
};

let data =
mas_policy::Data::new(matrix_config.homeserver.clone()).with_rest(config.data.clone());
let session_limit_config =
experimental_config
.session_limit
.as_ref()
.map(|c| SessionLimitConfig {
soft_limit: c.soft_limit,
hard_limit: c.hard_limit,
});

let data = mas_policy::Data::new(matrix_config.homeserver.clone(), session_limit_config)
.with_rest(config.data.clone());

PolicyFactory::load(policy_file, data, entrypoints)
.await
Expand Down Expand Up @@ -225,6 +235,13 @@ pub fn site_config_from_config(
session_expiration,
login_with_email_allowed: account_config.login_with_email_allowed,
plan_management_iframe_uri: experimental_config.plan_management_iframe_uri.clone(),
session_limit: experimental_config
.session_limit
.as_ref()
.map(|c| SessionLimitConfig {
soft_limit: c.soft_limit,
hard_limit: c.hard_limit,
}),
})
}

Expand Down
15 changes: 15 additions & 0 deletions crates/config/src/sections/experimental.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ pub struct ExperimentalConfig {
/// validation.
#[serde(skip_serializing_if = "Option::is_none")]
pub plan_management_iframe_uri: Option<String>,

/// Experimental feature to limit the number of application sessions per
/// user.
///
/// Disabled by default.
#[serde(skip_serializing_if = "Option::is_none")]
pub session_limit: Option<SessionLimitConfig>,
}

impl Default for ExperimentalConfig {
Expand All @@ -90,6 +97,7 @@ impl Default for ExperimentalConfig {
compat_token_ttl: default_token_ttl(),
inactive_session_expiration: None,
plan_management_iframe_uri: None,
session_limit: None,
}
}
}
Expand All @@ -106,3 +114,10 @@ impl ExperimentalConfig {
impl ConfigurationSection for ExperimentalConfig {
const PATH: Option<&'static str> = Some("experimental");
}

/// Configuration options for the inactive session expiration feature
#[derive(Clone, Debug, Deserialize, JsonSchema, Serialize)]
pub struct SessionLimitConfig {
pub soft_limit: u64,
pub hard_limit: u64,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBD, how does it behave when a limit is set to zero? Would it be worth using NonZeroU64?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
4 changes: 3 additions & 1 deletion crates/data-model/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ pub use self::{
DeviceCodeGrantState, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState,
},
policy_data::PolicyData,
site_config::{CaptchaConfig, CaptchaService, SessionExpirationConfig, SiteConfig},
site_config::{
CaptchaConfig, CaptchaService, SessionExpirationConfig, SessionLimitConfig, SiteConfig,
},
tokens::{
AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType,
},
Expand Down
10 changes: 10 additions & 0 deletions crates/data-model/src/site_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// Please see LICENSE files in the repository root for full details.

use chrono::Duration;
use serde::Serialize;
use url::Url;

/// Which Captcha service is being used
Expand Down Expand Up @@ -36,6 +37,12 @@ pub struct SessionExpirationConfig {
pub compat_session_inactivity_ttl: Option<Duration>,
}

#[derive(Serialize, Debug, Clone)]
pub struct SessionLimitConfig {
pub soft_limit: u64,
pub hard_limit: u64,
}

/// Random site configuration we want accessible in various places.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -99,4 +106,7 @@ pub struct SiteConfig {

/// The iframe URL to show in the plan tab of the UI
pub plan_management_iframe_uri: Option<String>,

/// Limits on the number of application sessions that each user can have
pub session_limit: Option<SessionLimitConfig>,
}
12 changes: 11 additions & 1 deletion crates/handlers/src/oauth2/authorization/consent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use super::callback::CallbackDestination;
use crate::{
BoundActivityTracker, PreferredLanguage, impl_from_error_for_route,
oauth2::generate_id_token,
session::{SessionOrFallback, load_session_or_fallback},
session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
};

#[derive(Debug, Error)]
Expand Down Expand Up @@ -136,10 +136,15 @@ pub(crate) async fn get(

let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);

let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user)
.await
.map_err(|e| RouteError::Internal(e.into()))?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

surprised RouteError doesn't implement From that error?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, anyhow, that's why. Fine then

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to be fair, maybe I should have double-thought it and passed through the database error

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&session.user),
client: &client,
session_counts: Some(session_counts),
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
requester: mas_policy::Requester {
Expand Down Expand Up @@ -235,10 +240,15 @@ pub(crate) async fn post(
return Err(RouteError::GrantNotPending(grant.id));
}

let session_counts = count_user_sessions_for_limiting(&mut repo, &browser_session.user)
.await
.map_err(|e| RouteError::Internal(e.into()))?;

let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&browser_session.user),
client: &client,
session_counts: Some(session_counts),
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
requester: mas_policy::Requester {
Expand Down
12 changes: 11 additions & 1 deletion crates/handlers/src/oauth2/device/consent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use ulid::Ulid;

use crate::{
BoundActivityTracker, PreferredLanguage,
session::{SessionOrFallback, load_session_or_fallback},
session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
};

#[derive(Deserialize, Debug)]
Expand Down Expand Up @@ -103,11 +103,16 @@ pub(crate) async fn get(
.context("Client not found")
.map_err(InternalError::from_anyhow)?;

let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user)
.await
.map_err(InternalError::from_anyhow)?;

// Evaluate the policy
let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
grant_type: mas_policy::GrantType::DeviceCode,
client: &client,
session_counts: Some(session_counts),
scope: &grant.scope,
user: Some(&session.user),
requester: mas_policy::Requester {
Expand Down Expand Up @@ -205,11 +210,16 @@ pub(crate) async fn post(
.context("Client not found")
.map_err(InternalError::from_anyhow)?;

let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user)
.await
.map_err(InternalError::from_anyhow)?;

// Evaluate the policy
let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
grant_type: mas_policy::GrantType::DeviceCode,
client: &client,
session_counts: Some(session_counts),
scope: &grant.scope,
user: Some(&session.user),
requester: mas_policy::Requester {
Expand Down
1 change: 1 addition & 0 deletions crates/handlers/src/oauth2/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ async fn client_credentials_grant(
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: None,
client,
session_counts: None,
scope: &scope,
grant_type: mas_policy::GrantType::ClientCredentials,
requester: mas_policy::Requester {
Expand Down
67 changes: 65 additions & 2 deletions crates/handlers/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@

use axum::response::{Html, IntoResponse as _, Response};
use mas_axum_utils::{SessionInfoExt, cookies::CookieJar, csrf::CsrfExt};
use mas_data_model::{BrowserSession, Clock};
use mas_data_model::{BrowserSession, Clock, User};
use mas_i18n::DataLocale;
use mas_storage::{BoxRepository, RepositoryError};
use mas_policy::model::SessionCounts;
use mas_storage::{
BoxRepository, RepositoryError, compat::CompatSessionFilter, oauth2::OAuth2SessionFilter,
personal::PersonalSessionFilter,
};
use mas_templates::{AccountInactiveContext, TemplateContext, Templates};
use rand::RngCore;
use thiserror::Error;
Expand Down Expand Up @@ -102,3 +106,62 @@ pub async fn load_session_or_fallback(
maybe_session: Some(session),
})
}

/// Get a count of sessions for the given user, for the purposes of session
/// limiting.
///
/// Includes:
/// - OAuth 2 sessions
/// - Compatibility sessions
/// - Personal sessions (unless owned by a different user)
///
/// # Backstory
///
/// Originally, we were only intending to count sessions with devices in this
/// result, because those are the entries that are expensive for Synapse and
/// also would not hinder use of deviceless clients (like Element Admin, an
/// admin dashboard).
///
/// However, to do so, we would need to count only sessions including device
/// scopes. To do this efficiently, we'd need a partial index on sessions
/// including device scopes.
///
/// It turns out that this can't be done cleanly (as we need to, in Postgres,
/// match scope lists where one of the scopes matches one of 2 known prefixes),
/// at least not without somewhat uncomfortable stored functions.
///
/// So for simplicity's sake, we now count all sessions.
/// For practical use cases, it's not likely to make a noticeable difference
/// (and maybe it's good that there's an overall limit).
pub(crate) async fn count_user_sessions_for_limiting(
repo: &mut BoxRepository,
user: &User,
) -> anyhow::Result<SessionCounts> {
let oauth2 = repo
.oauth2_session()
.count(OAuth2SessionFilter::new().active_only().for_user(user))
.await? as u64;

let compat = repo
.compat_session()
.count(CompatSessionFilter::new().active_only().for_user(user))
.await? as u64;

// Only include self-owned personal sessions, not administratively-owned ones
let personal = repo
.personal_session()
.count(
PersonalSessionFilter::new()
.active_only()
.for_actor_user(user)
.for_owner_user(user),
)
.await? as u64;

Ok(SessionCounts {
total: oauth2 + compat + personal,
oauth2,
compat,
personal,
})
}
3 changes: 2 additions & 1 deletion crates/handlers/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ pub(crate) async fn policy_factory(
email: "email/violation".to_owned(),
};

let data = mas_policy::Data::new(server_name.to_owned()).with_rest(data);
let data = mas_policy::Data::new(server_name.to_owned(), None).with_rest(data);

let policy_factory = PolicyFactory::load(file, data, entrypoints).await?;
let policy_factory = Arc::new(policy_factory);
Expand Down Expand Up @@ -148,6 +148,7 @@ pub fn test_site_config() -> SiteConfig {
session_expiration: None,
login_with_email_allowed: true,
plan_management_iframe_uri: None,
session_limit: None,
}
}

Expand Down
Loading
Loading