diff --git a/Cargo.lock b/Cargo.lock index 4605097ac..ba5d9c5f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1294,6 +1294,8 @@ dependencies = [ "bytes", "defguard_common", "defguard_core", + "defguard_enterprise_activity_log_stream", + "defguard_enterprise_license", "defguard_event_logger", "defguard_event_router", "defguard_gateway_manager", @@ -1378,6 +1380,13 @@ dependencies = [ "claims", "defguard_certs", "defguard_common", + "defguard_enterprise_activity_log_stream", + "defguard_enterprise_db", + "defguard_enterprise_directory_sync", + "defguard_enterprise_firewall", + "defguard_enterprise_ldap", + "defguard_enterprise_license", + "defguard_enterprise_snat", "defguard_mail", "defguard_proto", "defguard_static_ip", @@ -1423,7 +1432,6 @@ dependencies = [ "tokio-util", "tonic", "tonic-health", - "tonic-prost-build", "totp-lite", "tower", "tower-http", @@ -1439,6 +1447,133 @@ dependencies = [ "x25519-dalek", ] +[[package]] +name = "defguard_enterprise_activity_log_stream" +version = "0.0.0" +dependencies = [ + "anyhow", + "base64 0.22.1", + "bytes", + "defguard_common", + "defguard_enterprise_db", + "defguard_enterprise_license", + "reqwest", + "sqlx", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "defguard_enterprise_db" +version = "0.0.0" +dependencies = [ + "chrono", + "defguard_common", + "defguard_enterprise_license", + "ipnetwork", + "model_derive", + "rand 0.8.5", + "serde", + "serde_json", + "sha256", + "sqlx", + "struct-patch", + "strum", + "strum_macros", + "thiserror 2.0.18", + "tracing", + "utoipa", +] + +[[package]] +name = "defguard_enterprise_directory_sync" +version = "0.0.0" +dependencies = [ + "chrono", + "defguard_common", + "defguard_enterprise_db", + "defguard_enterprise_ldap", + "defguard_enterprise_license", + "futures", + "ipnetwork", + "jsonwebkey", + "jsonwebtoken", + "parse_link_header", + "paste", + "reqwest", + "secrecy", + "serde", + "serde_json", + "sqlx", + "thiserror 2.0.18", + "tokio", + "tracing", + "trait-variant", +] + +[[package]] +name = "defguard_enterprise_firewall" +version = "0.0.0" +dependencies = [ + "chrono", + "defguard_common", + "defguard_enterprise_db", + "defguard_enterprise_license", + "defguard_proto", + "ipnetwork", + "rand 0.8.5", + "sqlx", + "thiserror 2.0.18", + "tracing", +] + +[[package]] +name = "defguard_enterprise_ldap" +version = "0.0.0" +dependencies = [ + "base64 0.22.1", + "defguard_common", + "defguard_enterprise_license", + "ldap3", + "md4", + "rand 0.8.5", + "sha-1", + "sqlx", + "thiserror 2.0.18", + "tokio", + "tracing", +] + +[[package]] +name = "defguard_enterprise_license" +version = "0.0.0" +dependencies = [ + "anyhow", + "base64 0.22.1", + "chrono", + "defguard_common", + "humantime", + "pgp", + "prost", + "reqwest", + "serde", + "sqlx", + "thiserror 2.0.18", + "tokio", + "tonic", + "tonic-prost-build", + "tracing", +] + +[[package]] +name = "defguard_enterprise_snat" +version = "0.0.0" +dependencies = [ + "sqlx", + "thiserror 2.0.18", +] + [[package]] name = "defguard_event_logger" version = "0.0.0" @@ -1447,6 +1582,7 @@ dependencies = [ "chrono", "defguard_common", "defguard_core", + "defguard_enterprise_db", "defguard_session_manager", "serde_json", "sqlx", @@ -1476,6 +1612,7 @@ dependencies = [ "defguard_certs", "defguard_common", "defguard_core", + "defguard_enterprise_firewall", "defguard_grpc_tls", "defguard_proto", "defguard_version", @@ -1566,6 +1703,11 @@ dependencies = [ "defguard_certs", "defguard_common", "defguard_core", + "defguard_enterprise_db", + "defguard_enterprise_directory_sync", + "defguard_enterprise_firewall", + "defguard_enterprise_ldap", + "defguard_enterprise_license", "defguard_grpc_tls", "defguard_mail", "defguard_proto", diff --git a/Cargo.toml b/Cargo.toml index 2315ef722..fd7fc8a27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,8 +6,8 @@ repository = "https://github.com/DefGuard/defguard" rust-version = "1.87.0" [workspace] -members = ["crates/*", "tools/*"] -default-members = ["crates/*"] +members = ["crates/*", "enterprise/crates/*", "tools/*"] +default-members = ["crates/*", "enterprise/crates/*"] resolver = "2" [workspace.dependencies] @@ -17,6 +17,13 @@ defguard_setup = { path = "./crates/defguard_setup", version = "0.0.0" } defguard_common = { path = "./crates/defguard_common", version = "2.0.0" } defguard_static_ip = { path = "./crates/defguard_static_ip", version = "0.0.0" } defguard_core = { path = "./crates/defguard_core", version = "0.0.0" } +defguard_enterprise_activity_log_stream = { path = "./enterprise/crates/defguard_enterprise_activity_log_stream", version = "0.0.0" } +defguard_enterprise_db = { path = "./enterprise/crates/defguard_enterprise_db", version = "0.0.0" } +defguard_enterprise_directory_sync = { path = "./enterprise/crates/defguard_enterprise_directory_sync", version = "0.0.0" } +defguard_enterprise_firewall = { path = "./enterprise/crates/defguard_enterprise_firewall", version = "0.0.0" } +defguard_enterprise_ldap = { path = "./enterprise/crates/defguard_enterprise_ldap", version = "0.0.0" } +defguard_enterprise_license = { path = "./enterprise/crates/defguard_enterprise_license", version = "0.0.0" } +defguard_enterprise_snat = { path = "./enterprise/crates/defguard_enterprise_snat", version = "0.0.0" } defguard_event_logger = { path = "./crates/defguard_event_logger", version = "0.0.0" } defguard_event_router = { path = "./crates/defguard_event_router", version = "0.0.0" } defguard_gateway_manager = { path = "./crates/defguard_gateway_manager", version = "0.0.0" } diff --git a/crates/defguard/Cargo.toml b/crates/defguard/Cargo.toml index e6fef7e6b..64d08b6a3 100644 --- a/crates/defguard/Cargo.toml +++ b/crates/defguard/Cargo.toml @@ -11,6 +11,8 @@ rust-version.workspace = true # internal crates defguard_common = { workspace = true } defguard_core = { workspace = true } +defguard_enterprise_activity_log_stream = { workspace = true } +defguard_enterprise_license = { workspace = true } defguard_event_router = { workspace = true } defguard_event_logger = { workspace = true } defguard_gateway_manager = { workspace = true } diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index b4287f01d..75a1e1f06 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -17,17 +17,14 @@ use defguard_common::{ use defguard_core::{ auth::failed_login::FailedLoginMap, db::AppEvent, - enterprise::{ - activity_log_stream::activity_log_stream_manager::run_activity_log_stream_manager, - license::{License, run_periodic_license_check, set_cached_license}, - limits::update_counts, - }, events::{ApiEvent, BidiStreamEvent}, grpc::{GatewayEvent, WorkerState, run_grpc_server}, init_dev_env, init_vpn_location, run_web_server, utility_thread::run_utility_thread, version::IncompatibleComponents, }; +use defguard_enterprise_activity_log_stream::activity_log_stream_manager::run_activity_log_stream_manager; +use defguard_enterprise_license::{License, run_periodic_license_check, set_cached_license, update_counts}; use defguard_event_logger::{message::EventLoggerMessage, run_event_logger}; use defguard_event_router::{RouterReceiverSet, run_event_router}; use defguard_gateway_manager::{GatewayManager, GatewayTxSet}; diff --git a/crates/defguard_core/Cargo.toml b/crates/defguard_core/Cargo.toml index 0927fce94..5555ed2a5 100644 --- a/crates/defguard_core/Cargo.toml +++ b/crates/defguard_core/Cargo.toml @@ -10,6 +10,13 @@ rust-version.workspace = true [dependencies] # internal crates defguard_common = { workspace = true } +defguard_enterprise_activity_log_stream = { workspace = true } +defguard_enterprise_db = { workspace = true } +defguard_enterprise_directory_sync = { workspace = true } +defguard_enterprise_firewall = { workspace = true } +defguard_enterprise_ldap = { workspace = true } +defguard_enterprise_license = { workspace = true } +defguard_enterprise_snat = { workspace = true } defguard_mail = { workspace = true } defguard_proto = { workspace = true } defguard_web_ui = { workspace = true } @@ -96,6 +103,3 @@ reqwest = { version = "0.12", features = [ ], default-features = false } serde_qs = "1.0" webauthn-authenticator-rs = { version = "0.5", features = ["softpasskey"] } - -[build-dependencies] -tonic-prost-build.workspace = true diff --git a/crates/defguard_core/src/auth/mod.rs b/crates/defguard_core/src/auth/mod.rs index 56631cd9d..faaabf174 100644 --- a/crates/defguard_core/src/auth/mod.rs +++ b/crates/defguard_core/src/auth/mod.rs @@ -23,11 +23,9 @@ use defguard_common::db::{ }; use sqlx::PgPool; -use crate::{ - enterprise::{db::models::api_tokens::ApiToken, is_business_license_active}, - error::WebError, - handlers::SESSION_COOKIE_NAME, -}; +use crate::{error::WebError, handlers::SESSION_COOKIE_NAME}; +use defguard_enterprise_db::models::api_tokens::ApiToken; +use defguard_enterprise_license::is_business_license_active; pub struct SessionExtractor(pub Session); diff --git a/crates/defguard_core/src/db/models/activity_log/metadata.rs b/crates/defguard_core/src/db/models/activity_log/metadata.rs index 3d13c68c9..d21cf9d64 100644 --- a/crates/defguard_core/src/db/models/activity_log/metadata.rs +++ b/crates/defguard_core/src/db/models/activity_log/metadata.rs @@ -1,26 +1,23 @@ use chrono::NaiveDateTime; use defguard_common::db::{ - Id, models::{ - AuthenticationKey, AuthenticationKeyType, Device, MFAMethod, Settings, WebAuthn, - WireguardNetwork, group::Group, oauth2client::OAuth2Client, proxy::Proxy, settings::{LdapSyncStatus, OpenIdUsernameHandling, SmtpEncryption}, user::User, + AuthenticationKey, AuthenticationKeyType, Device, MFAMethod, Settings, WebAuthn, + WireguardNetwork, }, + Id, }; -use crate::{ - db::WebHook, - enterprise::db::models::{ - activity_log_stream::{ActivityLogStream, ActivityLogStreamType}, - api_tokens::ApiToken, - openid_provider::{DirectorySyncTarget, DirectorySyncUserBehavior, OpenIdProvider}, - snat::UserSnatBinding, - }, - events::ClientMFAMethod, +use crate::{db::WebHook, events::ClientMFAMethod}; +use defguard_enterprise_db::models::{ + activity_log_stream::{ActivityLogStream, ActivityLogStreamType}, + api_tokens::ApiToken, + openid_provider::{DirectorySyncTarget, DirectorySyncUserBehavior, OpenIdProvider}, + snat::UserSnatBinding, }; #[derive(Serialize)] diff --git a/crates/defguard_core/src/enterprise/activity_log_stream/error.rs b/crates/defguard_core/src/enterprise/activity_log_stream/error.rs deleted file mode 100644 index e4eef48d8..000000000 --- a/crates/defguard_core/src/enterprise/activity_log_stream/error.rs +++ /dev/null @@ -1,11 +0,0 @@ -use thiserror::Error; - -#[derive(Debug, Error)] -pub enum ActivityLogStreamError { - #[error("Deserialization of {0} error: {1}")] - ConfigDeserializeError(String, String), - #[error("Sqlx error: {0}")] - SqlxError(#[from] sqlx::Error), - #[error("Parsing http header value failed")] - HeaderValueParsing(), -} diff --git a/crates/defguard_core/src/enterprise/directory_sync_context.rs b/crates/defguard_core/src/enterprise/directory_sync_context.rs new file mode 100644 index 000000000..787cf8833 --- /dev/null +++ b/crates/defguard_core/src/enterprise/directory_sync_context.rs @@ -0,0 +1,37 @@ +use tokio::sync::broadcast::Sender; + +use defguard_enterprise_directory_sync::{DirectorySyncContext, DirectorySyncError}; + +use crate::{grpc::GatewayEvent, user_management}; + +pub fn build_directory_sync_context(wg_tx: Sender) -> DirectorySyncContext { + let disable_tx = wg_tx.clone(); + let delete_tx = wg_tx.clone(); + let sync_tx = wg_tx.clone(); + DirectorySyncContext { + disable_user: Box::new(move |user, conn| { + let disable_tx = disable_tx.clone(); + Box::pin(async move { + user_management::disable_user(user, conn, &disable_tx) + .await + .map_err(|err| DirectorySyncError::UserUpdateError(err.to_string())) + }) + }), + delete_user_and_cleanup_devices: Box::new(move |user, conn| { + let delete_tx = delete_tx.clone(); + Box::pin(async move { + user_management::delete_user_and_cleanup_devices(user, conn, &delete_tx) + .await + .map_err(|err| DirectorySyncError::UserUpdateError(err.to_string())) + }) + }), + sync_allowed_user_devices: Box::new(move |user, conn| { + let sync_tx = sync_tx.clone(); + Box::pin(async move { + user_management::sync_allowed_user_devices(user, conn, &sync_tx) + .await + .map_err(|err| DirectorySyncError::NetworkUpdateError(err.to_string())) + }) + }), + } +} diff --git a/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs b/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs index e23c402c2..a800f19d0 100644 --- a/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs +++ b/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs @@ -4,16 +4,14 @@ use reqwest::Url; use tonic::Status; use crate::{ - enterprise::{ - handlers::openid_login::{extract_state_data, user_from_claims}, - is_business_license_active, - }, + enterprise::handlers::openid_login::{extract_state_data, user_from_claims}, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent}, grpc::{ proxy::client_mfa::{ClientLoginSession, ClientMfaServer}, utils::parse_client_ip_agent, }, }; +use defguard_enterprise_license::is_business_license_active; impl ClientMfaServer { #[instrument(skip_all)] diff --git a/crates/defguard_core/src/enterprise/grpc/polling.rs b/crates/defguard_core/src/enterprise/grpc/polling.rs index fef268fcb..577ae730f 100644 --- a/crates/defguard_core/src/enterprise/grpc/polling.rs +++ b/crates/defguard_core/src/enterprise/grpc/polling.rs @@ -6,7 +6,8 @@ use defguard_proto::proxy::{DeviceInfo, InstanceInfoRequest, InstanceInfoRespons use sqlx::PgPool; use tonic::Status; -use crate::{enterprise::is_business_license_active, grpc::utils::build_device_config_response}; +use crate::grpc::utils::build_device_config_response; +use defguard_enterprise_license::is_business_license_active; pub struct PollingServer { pool: PgPool, diff --git a/crates/defguard_core/src/enterprise/handlers/acl.rs b/crates/defguard_core/src/enterprise/handlers/acl.rs index 192e43ddd..17a4bc6c3 100644 --- a/crates/defguard_core/src/enterprise/handlers/acl.rs +++ b/crates/defguard_core/src/enterprise/handlers/acl.rs @@ -7,7 +7,7 @@ use axum::{ http::StatusCode, }; use chrono::NaiveDateTime; -use defguard_common::db::Id; +use defguard_common::db::{Id, NoId}; use serde_json::{Value, json}; use utoipa::ToSchema; @@ -15,10 +15,19 @@ use super::LicenseInfo; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - enterprise::db::models::acl::{AclAlias, AclRule, AclRuleInfo, Protocol, RuleState}, error::WebError, handlers::{ApiResponse, ApiResult}, }; +use defguard_common::db::models::WireguardNetwork; +use defguard_enterprise_db::models::acl::{ + AclAlias, AclError, AclRule, AclRuleAlias, AclRuleDestinationRange, AclRuleDevice, + AclRuleGroup, AclRuleInfo, AclRuleNetwork, AclRuleUser, AliasKind, AliasState, Protocol, + RuleState, parse_destination_addresses, parse_ports, +}; +use sqlx::error::ErrorKind; +use sqlx::postgres::types::PgRange; +use sqlx::{PgConnection, PgPool, query}; +use std::net::IpAddr; /// API representation of [`AclRule`] used in API responses. /// All relations represented as arrays of IDs. @@ -207,14 +216,24 @@ pub(crate) async fn list_acl_rules( ) -> ApiResult { debug!("User {} listing ACL rules", session.user.username); let mut conn = appstate.pool.acquire().await?; - let rules = AclRule::all(&mut *conn).await?; + let rules: Vec> = sqlx::query_as!( + AclRule, + "SELECT id, parent_id, state \"state: _\", name, allow_all_users, deny_all_users, \ + allow_all_groups, deny_all_groups, allow_all_network_devices, deny_all_network_devices, \ + all_locations, addresses, ports, protocols, enabled, expires, any_address, any_port, \ + any_protocol, use_manual_destination_settings FROM aclrule" + ) + .fetch_all(&mut *conn) + .await?; let mut api_rules = Vec::::with_capacity(rules.len()); for rule in &rules { // TODO: may require optimisation wrt. sql queries - let info = rule.to_info(&mut conn).await.map_err(|err| { - error!("Error retrieving ACL rule {rule:?}: {err}"); - err - })?; + let info = AclRule::::to_info(rule, &mut conn) + .await + .map_err(|err| { + error!("Error retrieving ACL rule {rule:?}: {err}"); + err + })?; api_rules.push(info.into()); } info!("User {} listed ACL rules", session.user.username); @@ -244,12 +263,14 @@ pub(crate) async fn get_acl_rule( let mut conn = appstate.pool.acquire().await?; let (rule, status) = match AclRule::find_by_id(&mut *conn, id).await? { Some(rule) => ( - json!(ApiAclRule::from(rule.to_info(&mut conn).await.map_err( - |err| { - error!("Error retrieving ACL rule {rule:?}: {err}"); - err - } - )?)), + json!(ApiAclRule::from( + AclRule::::to_info(&rule, &mut conn) + .await + .map_err(|err| { + error!("Error retrieving ACL rule {rule:?}: {err}"); + err + })? + )), StatusCode::OK, ), None => (Value::Null, StatusCode::NOT_FOUND), @@ -281,7 +302,7 @@ pub(crate) async fn create_acl_rule( // validate submitted ACL rule data.validate()?; - let rule = AclRule::create_from_api(&appstate.pool, &data) + let rule = create_rule_from_api(&appstate.pool, &data) .await .map_err(|err| { error!("Error creating ACL rule {data:?}: {err}"); @@ -320,7 +341,7 @@ pub(crate) async fn update_acl_rule( // validate submitted ACL rule data.validate()?; - let rule = AclRule::update_from_api(&appstate.pool, id, &data) + let rule = update_rule_from_api(&appstate.pool, id, &data) .await .map_err(|err| { error!("Error updating ACL rule {data:?}: {err}"); @@ -350,7 +371,7 @@ pub(crate) async fn delete_acl_rule( Path(id): Path, ) -> ApiResult { debug!("User {} deleting ACL rule {id}", session.user.username); - AclRule::delete_from_api(&appstate.pool, id) + delete_rule_from_api(&appstate.pool, id) .await .map_err(|err| { error!("Error deleting ACL rule {id}: {err}"); @@ -380,7 +401,7 @@ pub(crate) async fn apply_acl_rules( "User {} applying ACL rules: {:?}", session.user.username, data.rules ); - AclRule::apply_rules(&data.rules, &appstate) + apply_rules_from_api(&appstate.pool, &appstate, &data.rules) .await .map_err(|err| { error!("Error applying ACL rules {data:?}: {err}"); @@ -413,7 +434,7 @@ pub(crate) async fn apply_acl_aliases( "User {} applying ACL aliases: {:?}", session.user.username, data.aliases ); - AclAlias::apply_aliases(&data.aliases, &appstate) + apply_aliases_from_api(&appstate.pool, &data.aliases) .await .map_err(|err| { error!("Error applying ACL aliases {data:?}: {err}"); @@ -425,3 +446,424 @@ pub(crate) async fn apply_acl_aliases( ); Ok(ApiResponse::default()) } + +async fn create_rule_from_api(pool: &PgPool, data: &EditAclRule) -> Result { + let mut transaction = pool.begin().await?; + let (rule, ranges) = build_rule_from_api(data, RuleState::New)?; + let rule: AclRule = rule.save(&mut *transaction).await?; + create_rule_relations(&mut transaction, rule.id, data, &ranges).await?; + transaction.commit().await?; + let mut conn = pool.acquire().await?; + Ok(AclRule::::to_info(&rule, &mut conn).await?.into()) +} + +async fn update_rule_from_api( + pool: &PgPool, + id: Id, + data: &EditAclRule, +) -> Result { + let mut transaction = pool.begin().await?; + let existing: AclRule = AclRule::find_by_id(&mut *transaction, id) + .await? + .ok_or_else(|| { + warn!("Update of nonexistent rule ({id}) failed"); + AclError::RuleNotFoundError(id) + })?; + + if existing.state == RuleState::Deleted { + return Err(AclError::CannotModifyDeletedRuleError(id)); + } + + let target_rule = match existing.state { + RuleState::Applied => { + let result = query!("DELETE FROM aclrule WHERE parent_id = $1", id) + .execute(&mut *transaction) + .await?; + debug!( + "Removed {} old modifications of rule {id}", + result.rows_affected() + ); + + let (mut rule, ranges) = build_rule_from_api(data, RuleState::Modified)?; + rule.parent_id = Some(id); + let rule: AclRule = rule.save(&mut *transaction).await?; + create_rule_relations(&mut transaction, rule.id, data, &ranges).await?; + rule + } + RuleState::New | RuleState::Modified | RuleState::Expired => { + let (rule, ranges) = build_rule_from_api(data, existing.state.clone())?; + let mut rule = rule.with_id(existing.id); + rule.parent_id = existing.parent_id; + rule.save(&mut *transaction).await?; + rule.delete_related_objects(&mut transaction).await?; + create_rule_relations(&mut transaction, rule.id, data, &ranges).await?; + rule + } + RuleState::Deleted => { + return Err(AclError::CannotModifyDeletedRuleError(id)); + } + }; + + transaction.commit().await?; + let mut conn = pool.acquire().await?; + Ok(AclRule::::to_info(&target_rule, &mut conn) + .await? + .into()) +} + +async fn delete_rule_from_api(pool: &PgPool, id: Id) -> Result<(), AclError> { + let mut transaction = pool.begin().await?; + let existing: AclRule = AclRule::find_by_id(&mut *transaction, id) + .await? + .ok_or_else(|| AclError::RuleNotFoundError(id))?; + + match existing.state { + RuleState::New => { + existing.delete_related_objects(&mut transaction).await?; + existing.delete(&mut *transaction).await?; + } + RuleState::Applied => { + let result = query!("DELETE FROM aclrule WHERE parent_id = $1", id) + .execute(&mut *transaction) + .await?; + debug!( + "Removed {} old modifications of rule {id}", + result.rows_affected() + ); + + let mut deleted_rule = existing.clone(); + deleted_rule.state = RuleState::Deleted; + deleted_rule.parent_id = Some(id); + let deleted_rule = deleted_rule.as_noid(); + let deleted_rule = deleted_rule.save(&mut *transaction).await?; + create_rule_relations_from_rule(&mut transaction, deleted_rule.id, &existing).await?; + } + RuleState::Modified | RuleState::Deleted | RuleState::Expired => { + existing.delete_related_objects(&mut transaction).await?; + existing.delete(&mut *transaction).await?; + } + } + + transaction.commit().await?; + Ok(()) +} + +async fn apply_rules_from_api( + pool: &PgPool, + appstate: &AppState, + rule_ids: &[Id], +) -> Result<(), AclError> { + if rule_ids.is_empty() { + return Ok(()); + } + + let mut transaction = pool.begin().await?; + let mut affected_location_ids: Vec = Vec::new(); + + for rule_id in rule_ids { + let rule: AclRule = AclRule::find_by_id(&mut *transaction, *rule_id) + .await? + .ok_or_else(|| AclError::RuleNotFoundError(*rule_id))?; + let location_ids: Vec = if rule.all_locations { + let locations: Vec> = + WireguardNetwork::all(&mut *transaction).await?; + locations.into_iter().map(|location| location.id).collect() + } else { + let locations: Vec> = rule.get_networks(&mut *transaction).await?; + locations.into_iter().map(|location| location.id).collect() + }; + rule.apply(&mut transaction).await?; + affected_location_ids.extend(location_ids); + } + + transaction.commit().await?; + + affected_location_ids.sort_unstable(); + affected_location_ids.dedup(); + for location_id in affected_location_ids { + if let Some(location) = WireguardNetwork::find_by_id(pool, location_id).await? { + let mut conn = pool.acquire().await?; + if let Some(firewall_config) = + defguard_enterprise_firewall::try_get_location_firewall_config(&location, &mut conn) + .await + .map_err(|err| AclError::FirewallError(err.to_string()))? + { + appstate.send_wireguard_event(crate::grpc::GatewayEvent::FirewallConfigChanged( + location.id, + firewall_config, + )); + } + } + } + Ok(()) +} + +async fn apply_aliases_from_api(pool: &PgPool, alias_ids: &[Id]) -> Result<(), AclError> { + if alias_ids.is_empty() { + return Ok(()); + } + + let mut transaction = pool.begin().await?; + for alias_id in alias_ids { + let alias: AclAlias = AclAlias::find_by_id(&mut *transaction, *alias_id) + .await? + .ok_or_else(|| AclError::AliasNotFoundError(*alias_id))?; + if alias.state == AliasState::Applied { + return Err(AclError::AliasAlreadyAppliedError(*alias_id)); + } + alias.apply(&mut transaction).await?; + } + transaction.commit().await?; + Ok(()) +} + +fn build_rule_from_api( + data: &EditAclRule, + state: RuleState, +) -> Result<(AclRule, Vec<(IpAddr, IpAddr)>), AclError> { + let destination = parse_destination_addresses(&data.addresses)?; + validate_destination_ranges(&destination.ranges)?; + let ports = parse_ports(&data.ports)?; + + let rule = AclRule { + id: NoId, + parent_id: None, + state, + name: data.name.clone(), + allow_all_users: data.allow_all_users, + deny_all_users: data.deny_all_users, + allow_all_groups: data.allow_all_groups, + deny_all_groups: data.deny_all_groups, + allow_all_network_devices: data.allow_all_network_devices, + deny_all_network_devices: data.deny_all_network_devices, + all_locations: data.all_locations, + addresses: destination.addrs, + ports: ports + .into_iter() + .map(Into::into) + .collect::>>(), + protocols: data.protocols.clone(), + enabled: data.enabled, + expires: data.expires, + any_address: data.any_address, + any_port: data.any_port, + any_protocol: data.any_protocol, + use_manual_destination_settings: data.use_manual_destination_settings, + }; + + Ok((rule, destination.ranges)) +} + +fn validate_destination_ranges(ranges: &[(IpAddr, IpAddr)]) -> Result<(), AclError> { + for (start, end) in ranges { + if start > end { + return Err(AclError::InvalidIpRangeError(format!("{start}-{end}"))); + } + } + Ok(()) +} + +async fn create_rule_relations( + transaction: &mut PgConnection, + rule_id: Id, + data: &EditAclRule, + ranges: &[(IpAddr, IpAddr)], +) -> Result<(), AclError> { + for location_id in &data.locations { + AclRuleNetwork::new(rule_id, *location_id) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "wireguard_network", *location_id))?; + } + + for user_id in &data.allowed_users { + AclRuleUser::new(rule_id, *user_id, true) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "user", *user_id))?; + } + for user_id in &data.denied_users { + AclRuleUser::new(rule_id, *user_id, false) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "user", *user_id))?; + } + + for group_id in &data.allowed_groups { + AclRuleGroup::new(rule_id, *group_id, true) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "group", *group_id))?; + } + for group_id in &data.denied_groups { + AclRuleGroup::new(rule_id, *group_id, false) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "group", *group_id))?; + } + + for device_id in &data.allowed_network_devices { + AclRuleDevice::new(rule_id, *device_id, true) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "device", *device_id))?; + } + for device_id in &data.denied_network_devices { + AclRuleDevice::new(rule_id, *device_id, false) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "device", *device_id))?; + } + + let mut modified_aliases = Vec::new(); + for alias_id in &data.aliases { + let alias: AclAlias = + AclAlias::find_by_id_and_kind(&mut *transaction, *alias_id, AliasKind::Component) + .await? + .ok_or_else(|| AclError::InvalidRelationError(format!("aclalias({alias_id})")))?; + if alias.state == AliasState::Modified { + modified_aliases.push(*alias_id); + continue; + } + AclRuleAlias::new(rule_id, *alias_id) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "aclalias", *alias_id))?; + } + for alias_id in &data.destinations { + let alias: AclAlias = + AclAlias::find_by_id_and_kind(&mut *transaction, *alias_id, AliasKind::Destination) + .await? + .ok_or_else(|| AclError::InvalidRelationError(format!("aclalias({alias_id})")))?; + if alias.state == AliasState::Modified { + modified_aliases.push(*alias_id); + continue; + } + AclRuleAlias::new(rule_id, *alias_id) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "aclalias", *alias_id))?; + } + + if !modified_aliases.is_empty() { + return Err(AclError::CannotUseModifiedAliasInRuleError( + modified_aliases, + )); + } + + for range in ranges { + AclRuleDestinationRange { + id: NoId, + rule_id, + start: range.0, + end: range.1, + } + .save(&mut *transaction) + .await?; + } + + Ok(()) +} + +async fn create_rule_relations_from_rule( + transaction: &mut PgConnection, + rule_id: Id, + source_rule: &AclRule, +) -> Result<(), AclError> { + if !source_rule.all_locations { + let networks = source_rule.get_networks(&mut *transaction).await?; + for network in networks { + AclRuleNetwork::new(rule_id, network.id) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "wireguard_network", network.id))?; + } + } + + let allowed_users = source_rule.get_users(&mut *transaction, true).await?; + for user in allowed_users { + AclRuleUser::new(rule_id, user.id, true) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "user", user.id))?; + } + let denied_users = source_rule.get_users(&mut *transaction, false).await?; + for user in denied_users { + AclRuleUser::new(rule_id, user.id, false) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "user", user.id))?; + } + + let allowed_groups = source_rule.get_groups(&mut *transaction, true).await?; + for group in allowed_groups { + AclRuleGroup::new(rule_id, group.id, true) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "group", group.id))?; + } + let denied_groups = source_rule.get_groups(&mut *transaction, false).await?; + for group in denied_groups { + AclRuleGroup::new(rule_id, group.id, false) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "group", group.id))?; + } + + let allowed_devices = source_rule + .get_network_devices(&mut *transaction, true) + .await?; + for device in allowed_devices { + AclRuleDevice::new(rule_id, device.id, true) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "device", device.id))?; + } + let denied_devices = source_rule + .get_network_devices(&mut *transaction, false) + .await?; + for device in denied_devices { + AclRuleDevice::new(rule_id, device.id, false) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "device", device.id))?; + } + + let aliases = source_rule.get_aliases(&mut *transaction).await?; + for alias in aliases { + AclRuleAlias::new(rule_id, alias.id) + .save(&mut *transaction) + .await + .map_err(|err| map_relation_error(err, "aclalias", alias.id))?; + } + + let ranges = source_rule + .get_destination_address_ranges(&mut *transaction) + .await?; + for range in ranges { + AclRuleDestinationRange { + id: NoId, + rule_id, + start: range.start, + end: range.end, + } + .save(&mut *transaction) + .await?; + } + + Ok(()) +} + +/// Maps [`sqlx::Error`] to [`AclError`] while checking for [`ErrorKind::ForeignKeyViolation`]. +fn map_relation_error(err: sqlx::Error, class: &str, id: Id) -> AclError { + if let sqlx::Error::Database(dberror) = &err { + if dberror.kind() == ErrorKind::ForeignKeyViolation { + error!( + "Failed to create ACL related object, foreign key violation: {class}({id}): {dberror}" + ); + return AclError::InvalidRelationError(format!("{class}({id})")); + } + } + error!("Failed to create ACL related object: {err}"); + AclError::DbError(err) +} diff --git a/crates/defguard_core/src/enterprise/handlers/acl/alias.rs b/crates/defguard_core/src/enterprise/handlers/acl/alias.rs index f13c38602..49a48e408 100644 --- a/crates/defguard_core/src/enterprise/handlers/acl/alias.rs +++ b/crates/defguard_core/src/enterprise/handlers/acl/alias.rs @@ -6,18 +6,20 @@ use axum::{ use defguard_common::db::{Id, NoId}; use serde_json::{Value, json}; use sqlx::{PgConnection, PgPool, query}; +use sqlx::postgres::types::PgRange; +use std::net::IpAddr; use utoipa::ToSchema; use super::LicenseInfo; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - enterprise::db::models::acl::{ - AclAlias, AclAliasDestinationRange, AclAliasInfo, AclError, AliasKind, AliasState, - Protocol, acl_delete_related_objects, parse_destination_addresses, - }, handlers::{ApiResponse, ApiResult}, }; +use defguard_enterprise_db::models::acl::{ + AclAlias, AclAliasDestinationRange, AclAliasInfo, AclError, AliasKind, AliasState, Protocol, + acl_delete_related_objects, parse_destination_addresses, parse_ports, +}; /// API representation of [`AclAlias`] used in API requests for modification operations. #[derive(Clone, Debug, Deserialize, Serialize, PartialEq, ToSchema)] @@ -34,11 +36,11 @@ impl EditAclAlias { &self, transaction: &mut PgConnection, alias_id: Id, + ranges: &[(IpAddr, IpAddr)], ) -> Result<(), AclError> { debug!("Creating related objects for ACL alias {self:?}"); // save related destination ranges - let destination = parse_destination_addresses(&self.addresses)?; - for range in destination.ranges { + for range in ranges { let obj = AclAliasDestinationRange { id: NoId, alias_id, @@ -77,16 +79,17 @@ impl ApiAclAlias { ) -> Result { let mut transaction = pool.begin().await?; - let alias = AclAlias::try_from(api_alias)? + let (alias, ranges) = build_component_alias_from_api(api_alias, AliasState::Applied)?; + let alias = alias .save(&mut *transaction) .await?; api_alias - .create_related_objects(&mut transaction, alias.id) + .create_related_objects(&mut transaction, alias.id, &ranges) .await?; transaction.commit().await?; - let result = Self::from(alias.to_info(pool).await?); + let result = Self::from(AclAlias::::to_info(&alias, pool).await?); Ok(result) } @@ -107,8 +110,7 @@ impl ApiAclAlias { AclError::AliasNotFoundError(id) })?; - // Convert alias from API to model. - let mut alias = AclAlias::try_from(api_alias)?; + let (mut alias, ranges) = build_component_alias_from_api(api_alias, AliasState::Modified)?; // perform appropriate updates depending on existing alias' state let alias = match existing_alias.state { @@ -125,13 +127,12 @@ impl ApiAclAlias { ); // save as a new alias with appropriate parent_id and state - alias.state = AliasState::Modified; alias.parent_id = Some(id); let alias = alias.save(&mut *transaction).await?; // create related objects api_alias - .create_related_objects(&mut transaction, alias.id) + .create_related_objects(&mut transaction, alias.id, &ranges) .await?; alias @@ -149,7 +150,7 @@ impl ApiAclAlias { // recreate related objects acl_delete_related_objects(&mut transaction, alias.id).await?; api_alias - .create_related_objects(&mut transaction, alias.id) + .create_related_objects(&mut transaction, alias.id, &ranges) .await?; alias @@ -157,7 +158,7 @@ impl ApiAclAlias { }; transaction.commit().await?; - Ok(alias.to_info(pool).await?.into()) + Ok(AclAlias::::to_info(&alias, pool).await?.into()) } } @@ -192,11 +193,12 @@ pub(crate) async fn list_acl_aliases( session: SessionInfo, ) -> ApiResult { debug!("User {} listing ACL aliases", session.user.username); - let aliases = AclAlias::all_of_kind(&appstate.pool, AliasKind::Component).await?; + let aliases: Vec> = + AclAlias::all_of_kind(&appstate.pool, AliasKind::Component).await?; let mut api_aliases = Vec::::with_capacity(aliases.len()); for alias in &aliases { // TODO: may require optimisation wrt. sql queries - let info = alias.to_info(&appstate.pool).await.map_err(|err| { + let info = AclAlias::::to_info(alias, &appstate.pool).await.map_err(|err| { error!("Error retrieving ACL alias {alias:?}: {err}"); err })?; @@ -230,7 +232,7 @@ pub(crate) async fn get_acl_alias( match AclAlias::find_by_id_and_kind(&appstate.pool, id, AliasKind::Component).await? { Some(alias) => ( json!(ApiAclAlias::from( - alias.to_info(&appstate.pool).await.map_err(|err| { + AclAlias::::to_info(&alias, &appstate.pool).await.map_err(|err| { error!("Error retrieving ACL alias {alias:?}: {err}"); err })? @@ -326,12 +328,72 @@ pub(crate) async fn delete_acl_alias( Path(id): Path, ) -> ApiResult { debug!("User {} deleting ACL alias {id}", session.user.username); - AclAlias::delete_from_api(&appstate.pool, id) - .await - .map_err(|err| { - error!("Error deleting ACL alias {id}: {err}"); - err - })?; + let mut transaction = appstate.pool.begin().await?; + let alias = AclAlias::find_by_id_and_kind(&mut *transaction, id, AliasKind::Component) + .await? + .ok_or_else(|| AclError::AliasNotFoundError(id))?; + + match alias.state { + AliasState::Applied => { + let rules = alias.get_rules(&mut *transaction).await?; + if !rules.is_empty() { + return Err(AclError::AliasUsedByRulesError(id).into()); + } + + let result = query!("DELETE FROM aclalias WHERE parent_id = $1", id) + .execute(&mut *transaction) + .await?; + debug!( + "Removed {} old modifications of alias {id}", + result.rows_affected() + ); + + acl_delete_related_objects(&mut transaction, alias.id).await?; + alias.delete(&mut *transaction).await?; + } + AliasState::Modified => { + acl_delete_related_objects(&mut transaction, alias.id).await?; + alias.delete(&mut *transaction).await?; + } + } + transaction.commit().await?; info!("User {} deleted ACL alias {id}", session.user.username); Ok(ApiResponse::default()) } + +fn build_component_alias_from_api( + api_alias: &EditAclAlias, + state: AliasState, +) -> Result<(AclAlias, Vec<(IpAddr, IpAddr)>), AclError> { + let destination = parse_destination_addresses(&api_alias.addresses)?; + validate_destination_ranges(&destination.ranges)?; + let ports = parse_ports(&api_alias.ports)?; + let any_address = api_alias.addresses.trim().is_empty(); + let any_port = api_alias.ports.trim().is_empty(); + let any_protocol = api_alias.protocols.is_empty(); + + let alias = AclAlias::new( + api_alias.name.clone(), + state, + AliasKind::Component, + destination.addrs, + ports.into_iter().map(Into::into).collect::>>(), + api_alias.protocols.clone(), + any_address, + any_port, + any_protocol, + ); + + Ok((alias, destination.ranges)) +} + +fn validate_destination_ranges(ranges: &[(IpAddr, IpAddr)]) -> Result<(), AclError> { + for (start, end) in ranges { + if start > end { + return Err(AclError::InvalidIpRangeError(format!( + "{start}-{end}" + ))); + } + } + Ok(()) +} diff --git a/crates/defguard_core/src/enterprise/handlers/acl/destination.rs b/crates/defguard_core/src/enterprise/handlers/acl/destination.rs index 7814e8b2c..293742272 100644 --- a/crates/defguard_core/src/enterprise/handlers/acl/destination.rs +++ b/crates/defguard_core/src/enterprise/handlers/acl/destination.rs @@ -6,18 +6,20 @@ use defguard_common::db::{Id, NoId}; use reqwest::StatusCode; use serde_json::{Value, json}; use sqlx::{PgConnection, PgPool, query}; +use sqlx::postgres::types::PgRange; +use std::net::IpAddr; use utoipa::ToSchema; use super::LicenseInfo; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - enterprise::db::models::acl::{ - AclAlias, AclAliasDestinationRange, AclAliasInfo, AclError, AliasKind, AliasState, - Protocol, acl_delete_related_objects, parse_destination_addresses, - }, handlers::{ApiResponse, ApiResult}, }; +use defguard_enterprise_db::models::acl::{ + AclAlias, AclAliasDestinationRange, AclAliasInfo, AclError, AliasKind, AliasState, Protocol, + acl_delete_related_objects, parse_destination_addresses, parse_ports, +}; /// API representation of [`AclAlias`] used in API requests for modification operations #[derive(Clone, Debug, Deserialize, Serialize, PartialEq, ToSchema)] @@ -37,11 +39,11 @@ impl EditAclDestination { &self, transaction: &mut PgConnection, alias_id: Id, + ranges: &[(IpAddr, IpAddr)], ) -> Result<(), AclError> { debug!("Creating related objects for ACL alias {self:?}"); // save related destination ranges - let destination = parse_destination_addresses(&self.addresses)?; - for range in destination.ranges { + for range in ranges { let obj = AclAliasDestinationRange { id: NoId, alias_id, @@ -83,16 +85,17 @@ impl ApiAclDestination { ) -> Result { let mut transaction = pool.begin().await?; - let alias = AclAlias::try_from(api_alias)? + let (alias, ranges) = build_destination_alias_from_api(api_alias, AliasState::Applied)?; + let alias = alias .save(&mut *transaction) .await?; api_alias - .create_related_objects(&mut transaction, alias.id) + .create_related_objects(&mut transaction, alias.id, &ranges) .await?; transaction.commit().await?; - let result = Self::from(alias.to_info(pool).await?); + let result = Self::from(AclAlias::::to_info(&alias, pool).await?); Ok(result) } @@ -113,8 +116,7 @@ impl ApiAclDestination { AclError::AliasNotFoundError(id) })?; - // Convert alias from API to model. - let mut alias = AclAlias::try_from(api_alias)?; + let (mut alias, ranges) = build_destination_alias_from_api(api_alias, AliasState::Modified)?; // perform appropriate updates depending on existing alias' state let alias = match existing_alias.state { @@ -131,13 +133,12 @@ impl ApiAclDestination { ); // save as a new alias with appropriate parent_id and state - alias.state = AliasState::Modified; alias.parent_id = Some(id); let alias = alias.save(&mut *transaction).await?; // create related objects api_alias - .create_related_objects(&mut transaction, alias.id) + .create_related_objects(&mut transaction, alias.id, &ranges) .await?; alias @@ -155,7 +156,7 @@ impl ApiAclDestination { // recreate related objects acl_delete_related_objects(&mut transaction, alias.id).await?; api_alias - .create_related_objects(&mut transaction, alias.id) + .create_related_objects(&mut transaction, alias.id, &ranges) .await?; alias @@ -163,7 +164,7 @@ impl ApiAclDestination { }; transaction.commit().await?; - Ok(alias.to_info(pool).await?.into()) + Ok(AclAlias::::to_info(&alias, pool).await?.into()) } } @@ -201,11 +202,12 @@ pub(crate) async fn list_acl_destinations( session: SessionInfo, ) -> ApiResult { debug!("User {} listing ACL destinations", session.user.username); - let aliases = AclAlias::all_of_kind(&appstate.pool, AliasKind::Destination).await?; + let aliases: Vec> = + AclAlias::all_of_kind(&appstate.pool, AliasKind::Destination).await?; let mut api_aliases = Vec::::with_capacity(aliases.len()); for alias in &aliases { // TODO: may require optimisation wrt. sql queries - let info = alias.to_info(&appstate.pool).await.map_err(|err| { + let info = AclAlias::::to_info(alias, &appstate.pool).await.map_err(|err| { error!("Error retrieving ACL destination {alias:?}: {err}"); err })?; @@ -242,7 +244,7 @@ pub(crate) async fn get_acl_destination( match AclAlias::find_by_id_and_kind(&appstate.pool, id, AliasKind::Destination).await? { Some(alias) => ( json!(ApiAclDestination::from( - alias.to_info(&appstate.pool).await.map_err(|err| { + AclAlias::::to_info(&alias, &appstate.pool).await.map_err(|err| { error!("Error retrieving ACL destination {alias:?}: {err}"); err })? @@ -350,15 +352,72 @@ pub(crate) async fn delete_acl_destination( "User {} deleting ACL destination {id}", session.user.username ); - AclAlias::delete_from_api(&appstate.pool, id) - .await - .map_err(|err| { - error!("Error deleting ACL destination {id}: {err}"); - err - })?; + let mut transaction = appstate.pool.begin().await?; + let alias = AclAlias::find_by_id_and_kind(&mut *transaction, id, AliasKind::Destination) + .await? + .ok_or_else(|| AclError::AliasNotFoundError(id))?; + + match alias.state { + AliasState::Applied => { + let rules = alias.get_rules(&mut *transaction).await?; + if !rules.is_empty() { + return Err(AclError::AliasUsedByRulesError(id).into()); + } + + let result = query!("DELETE FROM aclalias WHERE parent_id = $1", id) + .execute(&mut *transaction) + .await?; + debug!( + "Removed {} old modifications of alias {id}", + result.rows_affected() + ); + + acl_delete_related_objects(&mut transaction, alias.id).await?; + alias.delete(&mut *transaction).await?; + } + AliasState::Modified => { + acl_delete_related_objects(&mut transaction, alias.id).await?; + alias.delete(&mut *transaction).await?; + } + } + transaction.commit().await?; info!( "User {} deleted ACL destination {id}", session.user.username ); Ok(ApiResponse::default()) } + +fn build_destination_alias_from_api( + api_alias: &EditAclDestination, + state: AliasState, +) -> Result<(AclAlias, Vec<(IpAddr, IpAddr)>), AclError> { + let destination = parse_destination_addresses(&api_alias.addresses)?; + validate_destination_ranges(&destination.ranges)?; + let ports = parse_ports(&api_alias.ports)?; + + let alias = AclAlias::new( + api_alias.name.clone(), + state, + AliasKind::Destination, + destination.addrs, + ports.into_iter().map(Into::into).collect::>>(), + api_alias.protocols.clone(), + api_alias.any_address, + api_alias.any_port, + api_alias.any_protocol, + ); + + Ok((alias, destination.ranges)) +} + +fn validate_destination_ranges(ranges: &[(IpAddr, IpAddr)]) -> Result<(), AclError> { + for (start, end) in ranges { + if start > end { + return Err(AclError::InvalidIpRangeError(format!( + "{start}-{end}" + ))); + } + } + Ok(()) +} diff --git a/crates/defguard_core/src/enterprise/handlers/activity_log_stream.rs b/crates/defguard_core/src/enterprise/handlers/activity_log_stream.rs index 96cf90597..9678a7894 100644 --- a/crates/defguard_core/src/enterprise/handlers/activity_log_stream.rs +++ b/crates/defguard_core/src/enterprise/handlers/activity_log_stream.rs @@ -9,12 +9,12 @@ use super::LicenseInfo; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - enterprise::db::models::activity_log_stream::{ - ActivityLogStream, ActivityLogStreamConfig, ActivityLogStreamType, - }, events::{ApiEvent, ApiEventType, ApiRequestContext}, handlers::{ApiResponse, ApiResult}, }; +use defguard_enterprise_db::models::activity_log_stream::{ + ActivityLogStream, ActivityLogStreamConfig, ActivityLogStreamType, +}; pub async fn get_activity_log_stream( _admin: AdminRole, diff --git a/crates/defguard_core/src/enterprise/handlers/api_tokens.rs b/crates/defguard_core/src/enterprise/handlers/api_tokens.rs index d0842acc0..71616054b 100644 --- a/crates/defguard_core/src/enterprise/handlers/api_tokens.rs +++ b/crates/defguard_core/src/enterprise/handlers/api_tokens.rs @@ -11,11 +11,11 @@ use super::LicenseInfo; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - enterprise::db::models::api_tokens::{ApiToken, ApiTokenInfo}, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, handlers::{ApiResponse, ApiResult, user_for_admin_or_self}, }; +use defguard_enterprise_db::models::api_tokens::{ApiToken, ApiTokenInfo}; const API_TOKEN_LENGTH: usize = 32; diff --git a/crates/defguard_core/src/enterprise/handlers/enterprise_settings.rs b/crates/defguard_core/src/enterprise/handlers/enterprise_settings.rs index e52811806..3224078fc 100644 --- a/crates/defguard_core/src/enterprise/handlers/enterprise_settings.rs +++ b/crates/defguard_core/src/enterprise/handlers/enterprise_settings.rs @@ -5,9 +5,11 @@ use super::LicenseInfo; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - enterprise::db::models::enterprise_settings::{EnterpriseSettings, EnterpriseSettingsPatch}, handlers::{ApiResponse, ApiResult}, }; +use defguard_enterprise_db::models::enterprise_settings::{ + EnterpriseSettings, EnterpriseSettingsPatch, +}; pub async fn get_enterprise_settings( session: SessionInfo, diff --git a/crates/defguard_core/src/enterprise/handlers/mod.rs b/crates/defguard_core/src/enterprise/handlers/mod.rs index 0cdbcc44f..5983edcf3 100644 --- a/crates/defguard_core/src/enterprise/handlers/mod.rs +++ b/crates/defguard_core/src/enterprise/handlers/mod.rs @@ -1,8 +1,8 @@ use crate::{ auth::{AdminRole, SessionInfo}, - enterprise::get_counts, handlers::{ApiResponse, ApiResult}, }; +use defguard_enterprise_license::get_counts; pub mod acl; pub mod activity_log_stream; @@ -17,10 +17,8 @@ use axum::{ }; use serde::Serialize; -use super::{ - db::models::enterprise_settings::EnterpriseSettings, is_business_license_active, - license::get_cached_license, -}; +use defguard_enterprise_db::models::enterprise_settings::EnterpriseSettings; +use defguard_enterprise_license::{get_cached_license, is_business_license_active}; use crate::{appstate::AppState, error::WebError}; pub struct LicenseInfo { @@ -68,7 +66,7 @@ pub async fn check_enterprise_info(_admin: AdminRole, _session: SessionInfo) -> let license = get_cached_license(); let license_info = license .as_ref() - .map(|license: &crate::enterprise::license::License| { + .map(|license: &defguard_enterprise_license::License| { let counts = get_counts(); let limits_info = license.limits.map(|limits| LicenseLimitsInfo { locations: LimitInfo { diff --git a/crates/defguard_core/src/enterprise/handlers/openid_login.rs b/crates/defguard_core/src/enterprise/handlers/openid_login.rs index 83cc56cea..770065661 100644 --- a/crates/defguard_core/src/enterprise/handlers/openid_login.rs +++ b/crates/defguard_core/src/enterprise/handlers/openid_login.rs @@ -36,11 +36,6 @@ pub const SELECT_ACCOUNT_SUPPORTED_PROVIDERS: &[&str] = &["Google"]; use super::LicenseInfo; use crate::{ appstate::AppState, - enterprise::{ - db::models::openid_provider::OpenIdProvider, - directory_sync::sync_user_groups_if_configured, ldap::utils::ldap_update_user_state, - limits::update_counts, - }, error::WebError, handlers::{ ApiResponse, AuthResponse, SESSION_COOKIE_NAME, SIGN_IN_COOKIE_NAME, @@ -48,6 +43,10 @@ use crate::{ user::{MAX_USERNAME_CHARS, check_username}, }, }; +use defguard_enterprise_db::models::openid_provider::OpenIdProvider; +use defguard_enterprise_directory_sync::sync_user_groups_if_configured; +use defguard_enterprise_ldap::utils::ldap_update_user_state; +use defguard_enterprise_license::update_counts; /// Prune the given username from illegal characters in accordance with the following rules: /// @@ -595,9 +594,10 @@ pub(crate) async fn auth_callback( // since he already managed to login through the provider. Currently, there is no other way to // sync the groups for the MFA enabled user logging in through the provider without firing it on // every login attempt, even for standard, non-provider users. - if let Err(err) = - sync_user_groups_if_configured(&user, &appstate.pool, &appstate.wireguard_tx).await - { + let context = crate::enterprise::directory_sync_context::build_directory_sync_context( + appstate.wireguard_tx.clone(), + ); + if let Err(err) = sync_user_groups_if_configured(&user, &appstate.pool, &context).await { error!( "Failed to sync user groups for user {} with the directory while the user was trying \ to login in through an external provider: {err}", diff --git a/crates/defguard_core/src/enterprise/handlers/openid_providers.rs b/crates/defguard_core/src/enterprise/handlers/openid_providers.rs index c80ef757e..e14dfa231 100644 --- a/crates/defguard_core/src/enterprise/handlers/openid_providers.rs +++ b/crates/defguard_core/src/enterprise/handlers/openid_providers.rs @@ -16,13 +16,11 @@ use super::LicenseInfo; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - enterprise::{ - db::models::openid_provider::{OpenIdProvider, OpenIdProviderKind}, - directory_sync::test_directory_sync_connection, - }, events::{ApiEvent, ApiEventType, ApiRequestContext}, handlers::{ApiResponse, ApiResult}, }; +use defguard_enterprise_db::models::openid_provider::{OpenIdProvider, OpenIdProviderKind}; +use defguard_enterprise_directory_sync::test_directory_sync_connection; #[derive(Deserialize, Serialize, ToSchema)] pub struct AddProviderData { diff --git a/crates/defguard_core/src/enterprise/mod.rs b/crates/defguard_core/src/enterprise/mod.rs index 1c4855ae8..b65119877 100644 --- a/crates/defguard_core/src/enterprise/mod.rs +++ b/crates/defguard_core/src/enterprise/mod.rs @@ -1,123 +1,4 @@ -pub mod activity_log_stream; -pub mod db; -pub mod directory_sync; -pub mod firewall; pub mod grpc; pub mod handlers; -pub mod ldap; -pub mod license; -pub mod limits; pub mod snat; -mod utils; - -use license::{get_cached_license, validate_license}; -use limits::get_counts; - -use crate::enterprise::license::LicenseTier; - -/// Helper function to gate features which require a base license (Team or Business tier) -#[must_use] -pub fn is_business_license_active() -> bool { - is_license_tier_active(LicenseTier::Business) -} - -/// Helper function to gate features which require an Enterprise tier license -#[must_use] -pub fn is_enterprise_license_active() -> bool { - is_license_tier_active(LicenseTier::Enterprise) -} - -/// Shared logic for gating features to specific license tiers -fn is_license_tier_active(tier: LicenseTier) -> bool { - debug!("Checking if features for {tier} license tier should be enabled"); - - // get current object counts - let counts = get_counts(); - - let license = get_cached_license(); - let validation_result = validate_license(license.as_ref(), &counts, tier); - debug!("License validation result: {validation_result:?}"); - validation_result.is_ok() -} - -#[cfg(test)] -mod test { - use chrono::{TimeDelta, Utc}; - - use crate::{ - enterprise::{ - is_business_license_active, is_enterprise_license_active, - license::{License, LicenseTier, set_cached_license}, - limits::{Counts, set_counts}, - }, - grpc::proto::enterprise::license::LicenseLimits, - }; - - #[test] - fn test_feature_gates_no_license() { - set_cached_license(None); - - let counts = Counts::new(1, 1, 1, 1); - set_counts(counts); - - assert!(!is_business_license_active()); - assert!(!is_enterprise_license_active()); - } - - #[test] - fn test_feature_gates_with_license() { - // exceed free limits - let counts = Counts::new(1, 1, 5, 1); - set_counts(counts); - - // set Business license - let users_limit = 15; - let devices_limit = 35; - let locations_limit = 5; - let network_devices_limit = 10; - - let limits = LicenseLimits { - users: users_limit, - devices: devices_limit, - locations: locations_limit, - network_devices: Some(network_devices_limit), - }; - let license = License::new( - "test".to_string(), - true, - Some(Utc::now() + TimeDelta::days(1)), - Some(limits), - None, - LicenseTier::Business, - ); - set_cached_license(Some(license)); - - assert!(is_business_license_active()); - assert!(!is_enterprise_license_active()); - - // set Enterprise license - let users_limit = 15; - let devices_limit = 35; - let locations_limit = 5; - let network_devices_limit = 10; - - let limits = LicenseLimits { - users: users_limit, - devices: devices_limit, - locations: locations_limit, - network_devices: Some(network_devices_limit), - }; - let license = License::new( - "test".to_string(), - true, - Some(Utc::now() + TimeDelta::days(1)), - Some(limits), - None, - LicenseTier::Enterprise, - ); - set_cached_license(Some(license)); - - assert!(is_business_license_active()); - assert!(is_enterprise_license_active()); - } -} +pub mod directory_sync_context; diff --git a/crates/defguard_core/src/enterprise/snat/handlers.rs b/crates/defguard_core/src/enterprise/snat/handlers.rs index f0552c393..f7e080a0f 100644 --- a/crates/defguard_core/src/enterprise/snat/handlers.rs +++ b/crates/defguard_core/src/enterprise/snat/handlers.rs @@ -15,15 +15,15 @@ use utoipa::ToSchema; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - enterprise::{ - db::models::snat::UserSnatBinding, firewall::try_get_location_firewall_config, - handlers::LicenseInfo, snat::error::UserSnatBindingError, - }, + enterprise::handlers::LicenseInfo, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, grpc::GatewayEvent, handlers::{ApiResponse, ApiResult}, }; +use defguard_enterprise_db::models::snat::UserSnatBinding; +use defguard_enterprise_firewall::try_get_location_firewall_config; +use defguard_enterprise_snat::UserSnatBindingError; /// List all SNAT bindings for a WireGuard location /// @@ -138,7 +138,15 @@ pub async fn create_snat_binding( let binding = snat_binding .save(&appstate.pool) .await - .map_err(UserSnatBindingError::from)?; + .map_err(|err| match UserSnatBindingError::from(err) { + UserSnatBindingError::BindingNotFound => { + WebError::ObjectNotFound("SNAT binding not found".into()) + } + UserSnatBindingError::BindingAlreadyExists => { + WebError::ObjectAlreadyExists("SNAT binding already exists".into()) + } + UserSnatBindingError::DbError { source } => WebError::DbError(source.to_string()), + })?; // emit event appstate.emit_event(ApiEvent { diff --git a/crates/defguard_core/src/enterprise/snat/mod.rs b/crates/defguard_core/src/enterprise/snat/mod.rs index 00d363536..c3d449565 100644 --- a/crates/defguard_core/src/enterprise/snat/mod.rs +++ b/crates/defguard_core/src/enterprise/snat/mod.rs @@ -1,2 +1 @@ -pub mod error; pub mod handlers; diff --git a/crates/defguard_core/src/error.rs b/crates/defguard_core/src/error.rs index 10d824d49..746d27f74 100644 --- a/crates/defguard_core/src/error.rs +++ b/crates/defguard_core/src/error.rs @@ -1,8 +1,8 @@ use axum::http::StatusCode; use defguard_common::{ db::models::{ - DeviceError, ModelError, WireguardNetworkError, settings::SettingsValidationError, - user::UserError, + settings::SettingsValidationError, user::UserError, DeviceError, ModelError, + WireguardNetworkError, }, types::UrlParseError, }; @@ -13,15 +13,13 @@ use tokio::sync::mpsc::error::SendError; use utoipa::ToSchema; use crate::{ - auth::failed_login::FailedLoginError, - db::models::enrollment::TokenError, - enterprise::{ - activity_log_stream::error::ActivityLogStreamError, db::models::acl::AclError, - firewall::FirewallError, license::LicenseError, - }, - events::ApiEvent, + auth::failed_login::FailedLoginError, db::models::enrollment::TokenError, events::ApiEvent, location_management::LocationManagementError, }; +use defguard_enterprise_activity_log_stream::error::ActivityLogStreamError; +use defguard_enterprise_db::models::acl::AclError; +use defguard_enterprise_firewall::FirewallError; +use defguard_enterprise_license::LicenseError; /// Represents kinds of error that occurred #[derive(Debug, Error, ToSchema)] diff --git a/crates/defguard_core/src/events.rs b/crates/defguard_core/src/events.rs index 4d756ed82..ee442c01f 100644 --- a/crates/defguard_core/src/events.rs +++ b/crates/defguard_core/src/events.rs @@ -2,20 +2,18 @@ use std::net::IpAddr; use chrono::{NaiveDateTime, Utc}; use defguard_common::db::{ - Id, models::{ - AuthenticationKey, Device, MFAMethod, Settings, User, WebAuthn, WireguardNetwork, - group::Group, oauth2client::OAuth2Client, proxy::Proxy, + group::Group, oauth2client::OAuth2Client, proxy::Proxy, AuthenticationKey, Device, + MFAMethod, Settings, User, WebAuthn, WireguardNetwork, }, + Id, }; use defguard_proto::proxy::MfaMethod; -use crate::{ - db::WebHook, - enterprise::db::models::{ - activity_log_stream::ActivityLogStream, api_tokens::ApiToken, - openid_provider::OpenIdProvider, snat::UserSnatBinding, - }, +use crate::db::WebHook; +use defguard_enterprise_db::models::{ + activity_log_stream::ActivityLogStream, api_tokens::ApiToken, openid_provider::OpenIdProvider, + snat::UserSnatBinding, }; /// Shared context that needs to be added to every API event diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index beb66fd8a..4f5ee85f8 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -26,15 +26,13 @@ use tokio::sync::{broadcast::Sender, mpsc::UnboundedSender}; use crate::{ auth::failed_login::FailedLoginMap, db::AppEvent, - enterprise::{ - db::models::{ - enterprise_settings::{ClientTrafficPolicy, EnterpriseSettings}, - openid_provider::OpenIdProvider, - }, - is_business_license_active, is_enterprise_license_active, - }, grpc::{auth::AuthServer, interceptor::JwtInterceptor, worker::WorkerServer}, }; +use defguard_enterprise_db::models::{ + enterprise_settings::{ClientTrafficPolicy, EnterpriseSettings}, + openid_provider::OpenIdProvider, +}; +use defguard_enterprise_license::{is_business_license_active, is_enterprise_license_active}; mod auth; pub mod client_version; @@ -43,18 +41,11 @@ pub mod proxy; pub mod utils; pub mod worker; -pub mod proto { - pub mod enterprise { - pub mod license { - tonic::include_proto!("enterprise.license"); - } - } -} - use defguard_proto::{ - auth::auth_service_server::AuthServiceServer, enterprise::firewall::FirewallConfig, - gateway::Peer, worker::worker_service_server::WorkerServiceServer, + auth::auth_service_server::AuthServiceServer, gateway::Peer, + worker::worker_service_server::WorkerServiceServer, }; +use defguard_proto::enterprise::firewall::FirewallConfig; use tonic::transport::{Identity, Server, ServerTlsConfig, server::Router}; // gRPC header for passing auth token from clients diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index a2aab9a60..7910d577a 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -38,10 +38,11 @@ use tokio::{ use tonic::{Code, Status}; use crate::{ - enterprise::{db::models::openid_provider::OpenIdProvider, is_business_license_active}, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent}, grpc::{GatewayEvent, utils::parse_client_ip_agent}, }; +use defguard_enterprise_db::models::openid_provider::OpenIdProvider; +use defguard_enterprise_license::is_business_license_active; const CLIENT_SESSION_TIMEOUT: u64 = 60 * 5; // 10 minutes diff --git a/crates/defguard_core/src/grpc/utils.rs b/crates/defguard_core/src/grpc/utils.rs index a9ac22b5f..b04c20f0b 100644 --- a/crates/defguard_core/src/grpc/utils.rs +++ b/crates/defguard_core/src/grpc/utils.rs @@ -19,11 +19,9 @@ use sqlx::PgPool; use tonic::Status; use super::InstanceInfo; -use crate::{ - enterprise::db::models::{ - enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider, - }, - grpc::{client_version::ClientFeature, should_prevent_service_location_usage}, +use crate::grpc::{client_version::ClientFeature, should_prevent_service_location_usage}; +use defguard_enterprise_db::models::{ + enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider, }; pub async fn build_device_config_response( diff --git a/crates/defguard_core/src/handlers/app_info.rs b/crates/defguard_core/src/handlers/app_info.rs index d3afa735b..318fdca8d 100644 --- a/crates/defguard_core/src/handlers/app_info.rs +++ b/crates/defguard_core/src/handlers/app_info.rs @@ -5,9 +5,8 @@ use defguard_common::{ }; use super::{ApiResponse, ApiResult}; -use crate::{ - appstate::AppState, auth::SessionInfo, enterprise::db::models::openid_provider::OpenIdProvider, -}; +use crate::{appstate::AppState, auth::SessionInfo}; +use defguard_enterprise_db::models::openid_provider::OpenIdProvider; #[derive(Serialize)] struct LdapInfo { diff --git a/crates/defguard_core/src/handlers/auth.rs b/crates/defguard_core/src/handlers/auth.rs index c9839bb1b..304f4b26d 100644 --- a/crates/defguard_core/src/handlers/auth.rs +++ b/crates/defguard_core/src/handlers/auth.rs @@ -37,7 +37,6 @@ use crate::{ SessionExtractor, SessionInfo, failed_login::{check_failed_logins, log_failed_login_attempt}, }, - enterprise::ldap::utils::login_through_ldap, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, handlers::{ @@ -48,6 +47,7 @@ use crate::{ headers::{USER_AGENT_PARSER, check_new_device_login, get_user_agent_device}, server_config, }; +use defguard_enterprise_ldap::utils::login_through_ldap; /// Common functionality for `authenticate()` and `auth_callback()`. /// Returns either `AuthResponse` or `MFAInfo`. diff --git a/crates/defguard_core/src/handlers/component_setup.rs b/crates/defguard_core/src/handlers/component_setup.rs index 51fef4ff0..97d2b46ae 100644 --- a/crates/defguard_core/src/handlers/component_setup.rs +++ b/crates/defguard_core/src/handlers/component_setup.rs @@ -39,9 +39,9 @@ use tonic::{ use crate::{ auth::{AdminOrSetupRole, SessionInfo}, - enterprise::is_enterprise_license_active, version::{MIN_GATEWAY_VERSION, MIN_PROXY_VERSION}, }; +use defguard_enterprise_license::is_enterprise_license_active; const TOKEN_CLIENT_ID: &str = "Defguard Core"; const CONNECTION_TIMEOUT: Duration = Duration::from_secs(10); diff --git a/crates/defguard_core/src/handlers/group.rs b/crates/defguard_core/src/handlers/group.rs index e0027a7ac..be5239ec5 100644 --- a/crates/defguard_core/src/handlers/group.rs +++ b/crates/defguard_core/src/handlers/group.rs @@ -15,19 +15,19 @@ use sqlx::query_as; use utoipa::ToSchema; use super::{ApiResponse, ApiResult, EditGroupInfo, GroupInfo, Username}; +use defguard_enterprise_ldap::hashset; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - enterprise::ldap::utils::{ - ldap_add_user_to_groups, ldap_add_users_to_groups, ldap_delete_group, ldap_modify_group, - ldap_remove_user_from_groups, ldap_remove_users_from_groups, ldap_update_user_state, - ldap_update_users_state, - }, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, - hashset, location_management::sync_all_networks, }; +use defguard_enterprise_ldap::utils::{ + ldap_add_user_to_groups, ldap_add_users_to_groups, ldap_delete_group, ldap_modify_group, + ldap_remove_user_from_groups, ldap_remove_users_from_groups, ldap_update_user_state, + ldap_update_users_state, +}; #[derive(Serialize, ToSchema)] pub(crate) struct Groups { diff --git a/crates/defguard_core/src/handlers/mod.rs b/crates/defguard_core/src/handlers/mod.rs index 10a32f447..1b5af25ef 100644 --- a/crates/defguard_core/src/handlers/mod.rs +++ b/crates/defguard_core/src/handlers/mod.rs @@ -20,13 +20,11 @@ use utoipa::ToSchema; use webauthn_rs::prelude::RegisterPublicKeyCredential; use crate::{ - appstate::AppState, - auth::SessionInfo, - db::WebHook, - enterprise::{db::models::acl::AclError, license::LicenseError}, - error::WebError, + appstate::AppState, auth::SessionInfo, db::WebHook, error::WebError, events::ApiRequestContext, }; +use defguard_enterprise_db::models::acl::AclError; +use defguard_enterprise_license::LicenseError; pub(crate) mod activity_log; pub(crate) mod app_info; @@ -186,7 +184,7 @@ impl From for ApiResponse { ), AclError::CannotUseModifiedAliasInRuleError(alias_ids) => ApiResponse::new( json!({"msg": format!("Cannot use modified alias in ACL rule {alias_ids:?}")}), - StatusCode::BAD_REQUEST, + StatusCode::UNPROCESSABLE_ENTITY, ), }, WebError::Http(status) => { diff --git a/crates/defguard_core/src/handlers/network_devices.rs b/crates/defguard_core/src/handlers/network_devices.rs index 10a3fa94e..0ee1e618a 100644 --- a/crates/defguard_core/src/handlers/network_devices.rs +++ b/crates/defguard_core/src/handlers/network_devices.rs @@ -29,11 +29,12 @@ use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, enrollment_management::start_desktop_configuration, - enterprise::{firewall::try_get_location_firewall_config, limits::update_counts}, events::{ApiEvent, ApiEventType, ApiRequestContext}, grpc::GatewayEvent, server_config, }; +use defguard_enterprise_firewall::try_get_location_firewall_config; +use defguard_enterprise_license::update_counts; #[derive(Serialize)] struct NetworkDeviceLocation { diff --git a/crates/defguard_core/src/handlers/settings.rs b/crates/defguard_core/src/handlers/settings.rs index 16fdbfb3b..479b6d3a3 100644 --- a/crates/defguard_core/src/handlers/settings.rs +++ b/crates/defguard_core/src/handlers/settings.rs @@ -10,11 +10,13 @@ use defguard_common::db::models::{ use sqlx::PgPool; use struct_patch::Patch; +use defguard_enterprise_ldap::LDAPConnection; +use defguard_enterprise_license::update_cached_license; use super::{ApiResponse, ApiResult}; use crate::{ AppState, auth::{AdminRole, SessionInfo}, - enterprise::{handlers::LicenseInfo, ldap::LDAPConnection, license::update_cached_license}, + enterprise::handlers::LicenseInfo, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, }; diff --git a/crates/defguard_core/src/handlers/user.rs b/crates/defguard_core/src/handlers/user.rs index bbbb8bfb9..97900d3d3 100644 --- a/crates/defguard_core/src/handlers/user.rs +++ b/crates/defguard_core/src/handlers/user.rs @@ -20,6 +20,13 @@ use serde_json::json; use sqlx::{Error as SqlxError, PgPool}; use utoipa::ToSchema; +use defguard_enterprise_db::models::api_tokens::ApiToken; +use defguard_enterprise_ldap::model::{ldap_sync_allowed_for_user, maybe_update_rdn}; +use defguard_enterprise_ldap::utils::{ + ldap_add_user, ldap_add_user_to_groups, ldap_change_password, ldap_delete_user, + ldap_handle_user_modify, ldap_remove_user_from_groups, ldap_update_user_state, +}; +use defguard_enterprise_license::{get_cached_license, get_counts, update_counts}; use super::{ AddUserData, ApiResponse, ApiResult, PasswordChange, PasswordChangeSelf, StartEnrollmentRequest, Username, mail::EMAIL_PASSWORD_RESET_START_SUBJECT, @@ -33,19 +40,7 @@ use crate::{ models::enrollment::{PASSWORD_RESET_TOKEN_TYPE, Token}, }, enrollment_management::{start_desktop_configuration, start_user_enrollment}, - enterprise::{ - db::models::api_tokens::ApiToken, - handlers::CanManageDevices, - ldap::{ - model::{ldap_sync_allowed_for_user, maybe_update_rdn}, - utils::{ - ldap_add_user, ldap_add_user_to_groups, ldap_change_password, ldap_delete_user, - ldap_handle_user_modify, ldap_remove_user_from_groups, ldap_update_user_state, - }, - }, - license::get_cached_license, - limits::{get_counts, update_counts}, - }, + enterprise::handlers::CanManageDevices, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, is_valid_phone_number, server_config, diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index 2366d0e8c..9bdf0ad19 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -28,14 +28,7 @@ use super::{ApiResponse, ApiResult, WebError, device_for_admin_or_self, user_for use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - enterprise::{ - db::models::{enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider}, - firewall::try_get_location_firewall_config, - handlers::CanManageDevices, - is_business_license_active, is_enterprise_license_active, - license::get_cached_license, - limits::{get_counts, update_counts}, - }, + enterprise::handlers::CanManageDevices, events::{ApiEvent, ApiEventType, ApiRequestContext}, grpc::GatewayEvent, location_management::{ @@ -44,6 +37,14 @@ use crate::{ }, wg_config::{ImportedDevice, parse_wireguard_config}, }; +use defguard_enterprise_db::models::{ + enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider, +}; +use defguard_enterprise_firewall::try_get_location_firewall_config; +use defguard_enterprise_license::{ + get_cached_license, get_counts, is_business_license_active, is_enterprise_license_active, + update_counts, +}; #[derive(Serialize, ToSchema)] pub(crate) struct GatewayInfo { diff --git a/crates/defguard_core/src/location_management/mod.rs b/crates/defguard_core/src/location_management/mod.rs index 400be93f6..61cc66055 100644 --- a/crates/defguard_core/src/location_management/mod.rs +++ b/crates/defguard_core/src/location_management/mod.rs @@ -17,11 +17,8 @@ use sqlx::PgConnection; use thiserror::Error; use tokio::sync::broadcast::Sender; -use crate::{ - enterprise::firewall::{FirewallError, try_get_location_firewall_config}, - grpc::{GatewayEvent, send_multiple_wireguard_events}, - wg_config::ImportedDevice, -}; +use crate::{grpc::{GatewayEvent, send_multiple_wireguard_events}, wg_config::ImportedDevice}; +use defguard_enterprise_firewall::{FirewallError, try_get_location_firewall_config}; pub mod allowed_peers; diff --git a/crates/defguard_core/src/user_management.rs b/crates/defguard_core/src/user_management.rs index 1fc5d7cea..ac78311e4 100644 --- a/crates/defguard_core/src/user_management.rs +++ b/crates/defguard_core/src/user_management.rs @@ -8,11 +8,12 @@ use sqlx::PgConnection; use tokio::sync::broadcast::Sender; use crate::{ - enterprise::{firewall::try_get_location_firewall_config, limits::update_counts}, error::WebError, grpc::{GatewayEvent, send_multiple_wireguard_events, send_wireguard_event}, location_management::sync_allowed_devices_for_user, }; +use defguard_enterprise_firewall::try_get_location_firewall_config; +use defguard_enterprise_license::update_counts; /// Deletes the user and cleans up his devices from gateways pub async fn delete_user_and_cleanup_devices( diff --git a/crates/defguard_core/src/utility_thread.rs b/crates/defguard_core/src/utility_thread.rs index cf071fe5d..01133061c 100644 --- a/crates/defguard_core/src/utility_thread.rs +++ b/crates/defguard_core/src/utility_thread.rs @@ -12,18 +12,16 @@ use tokio::{ use tracing::Instrument; use crate::{ - enterprise::{ - db::models::acl::{AclRule, RuleState}, - directory_sync::{do_directory_sync, get_directory_sync_interval}, - firewall::try_get_location_firewall_config, - is_business_license_active, - ldap::{do_ldap_sync, sync::get_ldap_sync_interval}, - limits::do_count_update, - }, grpc::GatewayEvent, location_management::allowed_peers::get_location_allowed_peers, updates::do_new_version_check, }; +use defguard_enterprise_db::models::acl::{AclRule, RuleState}; +use defguard_enterprise_directory_sync::{do_directory_sync, get_directory_sync_interval}; +use defguard_enterprise_firewall::try_get_location_firewall_config; +use defguard_enterprise_ldap::{do_ldap_sync, sync::get_ldap_sync_interval}; +use defguard_enterprise_license::{do_count_update, is_business_license_active}; +use crate::enterprise::directory_sync_context::build_directory_sync_context; // Times in seconds const UTILITY_THREAD_MAIN_SLEEP_TIME: u64 = 5; @@ -32,6 +30,7 @@ const UPDATES_CHECK_INTERVAL: u64 = 60 * 60 * 6; const EXPIRED_ACL_RULES_CHECK_INTERVAL: u64 = 60 * 5; const ENTERPRISE_STATUS_CHECK_INTERVAL: u64 = 60 * 5; + #[instrument(skip_all)] pub async fn run_utility_thread( pool: &PgPool, @@ -48,8 +47,9 @@ pub async fn run_utility_thread( let mut enterprise_enabled = is_business_license_active(); let directory_sync_task = || async { + let context = build_directory_sync_context(wireguard_tx.clone()); if let Err(e) = Box::pin( - do_directory_sync(pool, &wireguard_tx).instrument(info_span!("directory_sync_task")), + do_directory_sync(pool, &context).instrument(info_span!("directory_sync_task")), ) .await { @@ -255,15 +255,26 @@ async fn expired_acl_rules_check( ); // find affected locations - let mut affected_locations = HashSet::new(); + let mut affected_location_ids = HashSet::new(); for rule in updated_rules { - let locations = rule.get_networks(pool).await?; - for location in locations { - affected_locations.insert(location); + if rule.all_locations { + let locations = WireguardNetwork::all(pool).await?; + for location in locations { + affected_location_ids.insert(location.id); + } + } else { + let locations = rule.get_networks(pool).await?; + for location in locations { + affected_location_ids.insert(location.id); + } } } - let affected_locations: Vec> = affected_locations.into_iter().collect(); + let affected_locations: Vec> = WireguardNetwork::all(pool) + .await? + .into_iter() + .filter(|location| affected_location_ids.contains(&location.id)) + .collect(); debug!( "{} locations affected by expired ACL rules. Sending gateway firewall update events \ for each location", diff --git a/crates/defguard_core/tests/integration/api/acl.rs b/crates/defguard_core/tests/integration/api/acl.rs index 0f0593f50..7ce38b2a7 100644 --- a/crates/defguard_core/tests/integration/api/acl.rs +++ b/crates/defguard_core/tests/integration/api/acl.rs @@ -11,16 +11,14 @@ use defguard_common::{ }, }; use defguard_core::{ - enterprise::{ - db::models::acl::{AclAlias, AclRule, AliasKind, AliasState, RuleState}, - handlers::acl::{ - ApiAclRule, EditAclRule, - alias::{ApiAclAlias, EditAclAlias}, - }, - license::{get_cached_license, set_cached_license}, + enterprise::handlers::acl::{ + ApiAclRule, EditAclRule, + alias::{ApiAclAlias, EditAclAlias}, }, handlers::Auth, }; +use defguard_enterprise_db::models::acl::{AclAlias, AclRule, AliasKind, AliasState, RuleState}; +use defguard_enterprise_license::{get_cached_license, set_cached_license}; use reqwest::StatusCode; use serde_json::{Value, json}; use sqlx::{ @@ -672,9 +670,9 @@ async fn test_invalid_related_objects(_: PgPoolOptions, options: PgConnectOption let mut rule = make_rule(); rule.aliases = vec![1]; let response = client.post("/api/v1/acl/rule").json(&rule).send().await; - assert_eq!(response.status(), StatusCode::BAD_REQUEST); + assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); let response = client.put("/api/v1/acl/rule/1").json(&rule).send().await; - assert_eq!(response.status(), StatusCode::BAD_REQUEST); + assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); } #[sqlx::test] diff --git a/crates/defguard_core/tests/integration/api/api_tokens.rs b/crates/defguard_core/tests/integration/api/api_tokens.rs index 81c563251..43d3107ea 100644 --- a/crates/defguard_core/tests/integration/api/api_tokens.rs +++ b/crates/defguard_core/tests/integration/api/api_tokens.rs @@ -4,12 +4,10 @@ use defguard_common::{ types::user_info::UserInfo, }; use defguard_core::{ - enterprise::{ - db::models::api_tokens::{ApiToken, ApiTokenInfo}, - handlers::api_tokens::{AddApiTokenData, RenameRequest}, - }, + enterprise::handlers::api_tokens::{AddApiTokenData, RenameRequest}, handlers::Auth, }; +use defguard_enterprise_db::models::api_tokens::{ApiToken, ApiTokenInfo}; use reqwest::{StatusCode, header::HeaderName}; use serde::Deserialize; use serde_json::json; diff --git a/crates/defguard_core/tests/integration/api/common/mod.rs b/crates/defguard_core/tests/integration/api/common/mod.rs index 6d9ad0234..ba2f75342 100644 --- a/crates/defguard_core/tests/integration/api/common/mod.rs +++ b/crates/defguard_core/tests/integration/api/common/mod.rs @@ -18,11 +18,11 @@ use defguard_core::{ auth::failed_login::FailedLoginMap, build_webapp, db::AppEvent, - enterprise::license::{License, LicenseTier, set_cached_license}, events::ApiEvent, grpc::{GatewayEvent, WorkerState}, handlers::{Auth, user::UserDetails}, }; +use defguard_enterprise_license::{License, LicenseTier, set_cached_license}; use reqwest::{StatusCode, header::HeaderName}; use semver::Version; use serde_json::json; diff --git a/crates/defguard_core/tests/integration/api/enterprise_settings.rs b/crates/defguard_core/tests/integration/api/enterprise_settings.rs index 878526329..c99cc9630 100644 --- a/crates/defguard_core/tests/integration/api/enterprise_settings.rs +++ b/crates/defguard_core/tests/integration/api/enterprise_settings.rs @@ -1,10 +1,8 @@ -use defguard_core::{ - enterprise::{ - db::models::enterprise_settings::{ClientTrafficPolicy, EnterpriseSettings}, - license::{get_cached_license, set_cached_license}, - }, - handlers::Auth, +use defguard_core::handlers::Auth; +use defguard_enterprise_db::models::enterprise_settings::{ + ClientTrafficPolicy, EnterpriseSettings, }; +use defguard_enterprise_license::{get_cached_license, set_cached_license}; use reqwest::StatusCode; use serde_json::json; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; diff --git a/crates/defguard_core/tests/integration/api/openid_login.rs b/crates/defguard_core/tests/integration/api/openid_login.rs index 69453455d..6b7261567 100644 --- a/crates/defguard_core/tests/integration/api/openid_login.rs +++ b/crates/defguard_core/tests/integration/api/openid_login.rs @@ -4,15 +4,13 @@ use defguard_common::db::{ models::{oauth2client::OAuth2Client, settings::OpenIdUsernameHandling}, }; use defguard_core::{ - enterprise::{ - db::models::openid_provider::{ - DirectorySyncTarget, DirectorySyncUserBehavior, OpenIdProviderKind, - }, - handlers::openid_providers::AddProviderData, - license::{License, LicenseTier, set_cached_license}, - }, + enterprise::handlers::openid_providers::AddProviderData, handlers::{Auth, openid_clients::NewOpenIDClient}, }; +use defguard_enterprise_db::models::openid_provider::{ + DirectorySyncTarget, DirectorySyncUserBehavior, OpenIdProviderKind, +}; +use defguard_enterprise_license::{License, LicenseTier, set_cached_license}; use reqwest::{StatusCode, Url}; use serde::Deserialize; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; diff --git a/crates/defguard_core/tests/integration/api/snat.rs b/crates/defguard_core/tests/integration/api/snat.rs index 3207a4365..d2f555eb7 100644 --- a/crates/defguard_core/tests/integration/api/snat.rs +++ b/crates/defguard_core/tests/integration/api/snat.rs @@ -2,13 +2,11 @@ use std::net::IpAddr; use defguard_common::db::Id; use defguard_core::{ - enterprise::{ - db::models::snat::UserSnatBinding, - license::{get_cached_license, set_cached_license}, - snat::handlers::{EditUserSnatBinding, NewUserSnatBinding}, - }, + enterprise::snat::handlers::{EditUserSnatBinding, NewUserSnatBinding}, handlers::Auth, }; +use defguard_enterprise_db::models::snat::UserSnatBinding; +use defguard_enterprise_license::{get_cached_license, set_cached_license}; use reqwest::StatusCode; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; diff --git a/crates/defguard_core/tests/integration/api/wireguard.rs b/crates/defguard_core/tests/integration/api/wireguard.rs index 396a9ea75..011c14a3f 100644 --- a/crates/defguard_core/tests/integration/api/wireguard.rs +++ b/crates/defguard_core/tests/integration/api/wireguard.rs @@ -13,16 +13,14 @@ use defguard_common::db::{ }, }; use defguard_core::{ - enterprise::{ - db::models::openid_provider::{ - DirectorySyncTarget, DirectorySyncUserBehavior, OpenIdProviderKind, - }, - handlers::openid_providers::AddProviderData, - license::{get_cached_license, set_cached_license}, - }, + enterprise::handlers::openid_providers::AddProviderData, grpc::GatewayEvent, handlers::{Auth, GroupInfo, wireguard::WireguardNetworkData}, }; +use defguard_enterprise_db::models::openid_provider::{ + DirectorySyncTarget, DirectorySyncUserBehavior, OpenIdProviderKind, +}; +use defguard_enterprise_license::{get_cached_license, set_cached_license}; use ipnetwork::IpNetwork; use matches::assert_matches; use reqwest::StatusCode; diff --git a/crates/defguard_core/tests/integration/common.rs b/crates/defguard_core/tests/integration/common.rs index f043201d2..e10eccaed 100644 --- a/crates/defguard_core/tests/integration/common.rs +++ b/crates/defguard_core/tests/integration/common.rs @@ -5,7 +5,7 @@ use defguard_common::{ settings::{initialize_current_settings, update_current_settings}, }, }; -use defguard_core::enterprise::license::{License, LicenseTier, set_cached_license}; +use defguard_enterprise_license::{License, LicenseTier, set_cached_license}; use secrecy::ExposeSecret; use sqlx::PgPool; diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index b771cfbc4..1eb20a754 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -7,13 +7,13 @@ use defguard_common::{ use defguard_core::{ auth::failed_login::FailedLoginMap, db::AppEvent, - enterprise::license::{License, LicenseTier, set_cached_license}, events::GrpcEvent, grpc::{ WorkerState, build_grpc_service_router, gateway::{client_state::ClientMap, events::GatewayEvent, map::GatewayMap}, }, }; +use defguard_enterprise_license::{License, LicenseTier, set_cached_license}; use defguard_mail::Mail; use hyper_util::rt::TokioIo; use sqlx::PgPool; diff --git a/crates/defguard_core/tests/integration/grpc/gateway.rs b/crates/defguard_core/tests/integration/grpc/gateway.rs index dd13941ca..5fb5585c9 100644 --- a/crates/defguard_core/tests/integration/grpc/gateway.rs +++ b/crates/defguard_core/tests/integration/grpc/gateway.rs @@ -15,14 +15,14 @@ use defguard_common::db::{ setup_pool, }; use defguard_core::{ - enterprise::{license::set_cached_license, limits::update_counts}, events::GrpcEvent, grpc::{MIN_GATEWAY_VERSION, gateway::events::GatewayEvent}, }; -use defguard_proto::{ - enterprise::firewall::FirewallPolicy, - gateway::{Configuration, PeerStats, Update, stats_update::Payload, update}, +use defguard_enterprise_license::{set_cached_license, update_counts}; +use defguard_proto::gateway::{ + Configuration, PeerStats, Update, stats_update::Payload, update, }; +use defguard_enterprise_firewall::FirewallPolicy; use semver::Version; use sqlx::{ PgPool, diff --git a/crates/defguard_event_logger/Cargo.toml b/crates/defguard_event_logger/Cargo.toml index 676179437..4a5ba2f2a 100644 --- a/crates/defguard_event_logger/Cargo.toml +++ b/crates/defguard_event_logger/Cargo.toml @@ -11,6 +11,7 @@ rust-version.workspace = true # internal crates defguard_common.workspace = true defguard_core.workspace = true +defguard_enterprise_db.workspace = true defguard_session_manager.workspace = true # external dependencies diff --git a/crates/defguard_event_logger/src/message.rs b/crates/defguard_event_logger/src/message.rs index 651c63cf0..9ed86b944 100644 --- a/crates/defguard_event_logger/src/message.rs +++ b/crates/defguard_event_logger/src/message.rs @@ -2,20 +2,20 @@ use std::net::IpAddr; use chrono::NaiveDateTime; use defguard_common::db::{ - Id, models::{ - AuthenticationKey, Device, MFAMethod, Settings, User, WebAuthn, WireguardNetwork, - group::Group, oauth2client::OAuth2Client, proxy::Proxy, + group::Group, oauth2client::OAuth2Client, proxy::Proxy, AuthenticationKey, Device, + MFAMethod, Settings, User, WebAuthn, WireguardNetwork, }, + Id, }; use defguard_core::{ db::WebHook, - enterprise::db::models::{ - activity_log_stream::ActivityLogStream, api_tokens::ApiToken, - openid_provider::OpenIdProvider, snat::UserSnatBinding, - }, events::{ApiRequestContext, BidiRequestContext, ClientMFAMethod, GrpcRequestContext}, }; +use defguard_enterprise_db::models::{ + activity_log_stream::ActivityLogStream, api_tokens::ApiToken, openid_provider::OpenIdProvider, + snat::UserSnatBinding, +}; use defguard_session_manager::events::SessionManagerEventContext; /// Messages that can be sent to the event logger diff --git a/crates/defguard_gateway_manager/Cargo.toml b/crates/defguard_gateway_manager/Cargo.toml index 9afb53828..541e7ec25 100644 --- a/crates/defguard_gateway_manager/Cargo.toml +++ b/crates/defguard_gateway_manager/Cargo.toml @@ -11,6 +11,7 @@ rust-version.workspace = true defguard_certs.workspace = true defguard_common.workspace = true defguard_core.workspace = true +defguard_enterprise_firewall.workspace = true defguard_grpc_tls.workspace = true defguard_proto.workspace = true defguard_version.workspace = true diff --git a/crates/defguard_gateway_manager/src/error.rs b/crates/defguard_gateway_manager/src/error.rs index 7fde13348..bfe0299bf 100644 --- a/crates/defguard_gateway_manager/src/error.rs +++ b/crates/defguard_gateway_manager/src/error.rs @@ -1,4 +1,5 @@ -use defguard_core::{enterprise::firewall::FirewallError, events::GrpcEvent}; +use defguard_core::events::GrpcEvent; +use defguard_enterprise_firewall::FirewallError; use thiserror::Error; use tokio::sync::mpsc::error::SendError; use tonic::{Code, Status}; diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index 6b93625a4..bb6d4e9d0 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -20,19 +20,19 @@ use defguard_common::{ messages::peer_stats_update::PeerStatsUpdate, }; use defguard_core::{ - enterprise::firewall::try_get_location_firewall_config, grpc::GatewayEvent, - handlers::mail::send_gateway_disconnected_email, + grpc::GatewayEvent, handlers::mail::send_gateway_disconnected_email, location_management::allowed_peers::get_location_allowed_peers, }; +use defguard_enterprise_firewall::try_get_location_firewall_config; #[cfg(not(test))] use defguard_grpc_tls::{certs as tls_certs, connector::HttpsSchemeConnector}; use defguard_proto::{ - enterprise::firewall::FirewallConfig, gateway::{ Configuration, CoreResponse, Peer, PeerStats, Update, core_request, core_response, gateway_client, update, }, }; +use defguard_proto::enterprise::firewall::FirewallConfig; use defguard_version::client::ClientVersionInterceptor; #[cfg(not(test))] use hyper_rustls::HttpsConnectorBuilder; diff --git a/crates/defguard_proxy_manager/Cargo.toml b/crates/defguard_proxy_manager/Cargo.toml index 2d2761671..21eb092fb 100644 --- a/crates/defguard_proxy_manager/Cargo.toml +++ b/crates/defguard_proxy_manager/Cargo.toml @@ -11,6 +11,11 @@ rust-version.workspace = true # internal dependencies defguard_common.workspace = true defguard_core.workspace = true +defguard_enterprise_db.workspace = true +defguard_enterprise_directory_sync.workspace = true +defguard_enterprise_firewall.workspace = true +defguard_enterprise_ldap.workspace = true +defguard_enterprise_license.workspace = true defguard_mail.workspace = true defguard_proto.workspace = true defguard_version.workspace = true diff --git a/crates/defguard_proxy_manager/src/handler.rs b/crates/defguard_proxy_manager/src/handler.rs index 06b24aba0..b29a80de4 100644 --- a/crates/defguard_proxy_manager/src/handler.rs +++ b/crates/defguard_proxy_manager/src/handler.rs @@ -17,14 +17,10 @@ use defguard_core::{ db::models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, enrollment_management::clear_unused_enrollment_tokens, enterprise::{ - db::models::openid_provider::OpenIdProvider, - directory_sync::sync_user_groups_if_configured, grpc::polling::PollingServer, handlers::openid_login::{ SELECT_ACCOUNT_SUPPORTED_PROVIDERS, build_state, make_oidc_client, user_from_claims, }, - is_business_license_active, - ldap::utils::ldap_update_user_state, }, grpc::{ GatewayEvent, @@ -32,6 +28,11 @@ use defguard_core::{ }, version::{IncompatibleComponents, IncompatibleProxyData, is_proxy_version_supported}, }; +use defguard_enterprise_db::models::openid_provider::OpenIdProvider; +use defguard_enterprise_directory_sync::sync_user_groups_if_configured; +use defguard_enterprise_ldap::utils::ldap_update_user_state; +use defguard_enterprise_license::is_business_license_active; +use defguard_core::enterprise::directory_sync_context::build_directory_sync_context; use defguard_grpc_tls::{certs as tls_certs, connector::HttpsSchemeConnector}; use defguard_proto::proxy::{ AuthCallbackResponse, AuthInfoResponse, CoreError, CoreRequest, CoreResponse, InitialInfo, @@ -700,10 +701,13 @@ impl ProxyHandler { { Ok(mut user) => { clear_unused_enrollment_tokens(&user, &pool).await?; + let context = build_directory_sync_context( + wireguard_tx.clone(), + ); if let Err(err) = sync_user_groups_if_configured( &user, &pool, - &wireguard_tx, + &context, ) .await { diff --git a/crates/defguard_proxy_manager/src/servers/enrollment.rs b/crates/defguard_proxy_manager/src/servers/enrollment.rs index bd1bbef18..cbca304db 100644 --- a/crates/defguard_proxy_manager/src/servers/enrollment.rs +++ b/crates/defguard_proxy_manager/src/servers/enrollment.rs @@ -14,12 +14,6 @@ use defguard_common::{ }; use defguard_core::{ db::models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, - enterprise::{ - db::models::{enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider}, - firewall::try_get_location_firewall_config, - ldap::utils::ldap_add_user, - limits::update_counts, - }, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, EnrollmentEvent}, grpc::{ GatewayEvent, InstanceInfo, @@ -33,7 +27,15 @@ use defguard_core::{ headers::get_device_info, is_valid_phone_number, }; +use defguard_enterprise_db::models::{ + enterprise_settings::EnterpriseSettings, + openid_provider::OpenIdProvider, +}; +use defguard_enterprise_firewall::try_get_location_firewall_config; +use defguard_enterprise_ldap::utils::ldap_add_user; +use defguard_enterprise_license::update_counts; use defguard_mail::templates::{TemplateLocation, new_device_added_mail}; +use defguard_proto::enterprise::firewall::FirewallConfig; use defguard_proto::proxy::{ ActivateUserRequest, AdminInfo, CodeMfaSetupFinishRequest, CodeMfaSetupFinishResponse, CodeMfaSetupStartRequest, CodeMfaSetupStartResponse, DeviceConfigResponse, @@ -186,7 +188,7 @@ impl EnrollmentServer { "Retrieving enterprise settings for enrollment of user {}({:?}).", user.username, user.id ); - let enterprise_settings = + let enterprise_settings: EnterpriseSettings = EnterpriseSettings::get(&mut *transaction) .await .map_err(|err| { @@ -201,7 +203,7 @@ impl EnrollmentServer { user.username, user.id ); - let openid_provider = OpenIdProvider::get_current(&self.pool) + let openid_provider: Option> = OpenIdProvider::get_current(&self.pool) .await .map_err(|err| { error!("Failed to get OpenID provider: {err}"); @@ -241,7 +243,7 @@ impl EnrollmentServer { debug!("Admin info {admin_info:?}"); debug!("Creating enrollment start response for user {username}({user_id:?})."); - let enterprise_settings = + let enterprise_settings: EnterpriseSettings = EnterpriseSettings::get(&mut *transaction) .await .map_err(|err| { @@ -495,7 +497,8 @@ impl EnrollmentServer { "Fetching enterprise settings for device creation process for user {}({:?})", user.username, user.id, ); - let enterprise_settings = EnterpriseSettings::get(&self.pool).await.map_err(|err| { + let enterprise_settings: EnterpriseSettings = + EnterpriseSettings::get(&self.pool).await.map_err(|err| { error!( "Failed to fetch enterprise settings for device creation process for user {}({:?}): \ {err}", @@ -731,13 +734,14 @@ impl EnrollmentServer { Status::internal("unexpected error") })? { - if let Some(firewall_config) = + let firewall_config: Option = try_get_location_firewall_config(&location, &mut transaction) .await .map_err(|err| { error!("Failed to get firewall config for location {location}: {err}",); Status::internal("unexpected error") - })? + })?; + if let Some(firewall_config) = firewall_config { debug!( "Sending firewall config update for location {location} affected by adding new device {}, user {}({})", @@ -835,7 +839,7 @@ impl EnrollmentServer { info!("Device {} remote configuration done.", device.name); - let openid_provider = OpenIdProvider::get_current(&self.pool) + let openid_provider: Option> = OpenIdProvider::get_current(&self.pool) .await .map_err(|err| { error!("Failed to get OpenID provider: {err}"); diff --git a/crates/defguard_proxy_manager/src/servers/password_reset.rs b/crates/defguard_proxy_manager/src/servers/password_reset.rs index b6d94f253..52da2e459 100644 --- a/crates/defguard_proxy_manager/src/servers/password_reset.rs +++ b/crates/defguard_proxy_manager/src/servers/password_reset.rs @@ -4,7 +4,6 @@ use defguard_common::{ }; use defguard_core::{ db::models::enrollment::{PASSWORD_RESET_TOKEN_TYPE, Token}, - enterprise::ldap::utils::ldap_change_password, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, PasswordResetEvent}, grpc::utils::parse_client_ip_agent, handlers::{ @@ -13,6 +12,7 @@ use defguard_core::{ }, headers::get_device_info, }; +use defguard_enterprise_ldap::utils::ldap_change_password; use defguard_proto::proxy::{ DeviceInfo, PasswordResetInitializeRequest, PasswordResetRequest, PasswordResetStartRequest, PasswordResetStartResponse, diff --git a/enterprise/crates/defguard_enterprise_activity_log_stream/Cargo.toml b/enterprise/crates/defguard_enterprise_activity_log_stream/Cargo.toml new file mode 100644 index 000000000..3e59a97ce --- /dev/null +++ b/enterprise/crates/defguard_enterprise_activity_log_stream/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "defguard_enterprise_activity_log_stream" +version = "0.0.0" +edition.workspace = true +license-file.workspace = true +homepage.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +defguard_common = { workspace = true } +defguard_enterprise_db = { workspace = true } +defguard_enterprise_license = { workspace = true } +anyhow = { workspace = true } +base64 = { workspace = true } +bytes = { workspace = true } +reqwest = { workspace = true } +sqlx = { workspace = true } +tokio = { workspace = true } +tokio-util = { workspace = true } +tracing = { workspace = true } diff --git a/crates/defguard_core/src/enterprise/activity_log_stream/activity_log_stream_manager.rs b/enterprise/crates/defguard_enterprise_activity_log_stream/src/activity_log_stream/activity_log_stream_manager.rs similarity index 93% rename from crates/defguard_core/src/enterprise/activity_log_stream/activity_log_stream_manager.rs rename to enterprise/crates/defguard_enterprise_activity_log_stream/src/activity_log_stream/activity_log_stream_manager.rs index e1a2a467b..1eefb0d4d 100644 --- a/crates/defguard_core/src/enterprise/activity_log_stream/activity_log_stream_manager.rs +++ b/enterprise/crates/defguard_enterprise_activity_log_stream/src/activity_log_stream/activity_log_stream_manager.rs @@ -4,14 +4,14 @@ use bytes::Bytes; use sqlx::PgPool; use tokio::{sync::broadcast::Receiver, task::JoinSet, time::interval}; use tokio_util::sync::CancellationToken; -use tracing::debug; +use tracing::{debug, error, info, instrument, warn}; use super::ActivityLogStreamReconfigurationNotification; -use crate::enterprise::{ - activity_log_stream::http_stream::{HttpActivityLogStreamConfig, run_http_stream_task}, - db::models::activity_log_stream::{ActivityLogStream, ActivityLogStreamConfig}, - is_business_license_active, +use crate::activity_log_stream::http_stream::{HttpActivityLogStreamConfig, run_http_stream_task}; +use defguard_enterprise_db::models::activity_log_stream::{ + ActivityLogStream, ActivityLogStreamConfig, }; +use defguard_enterprise_license::is_business_license_active; // check if enterprise features are enabled every minute const ENTERPRISE_CHECK_PERIOD_SECS: u64 = 60; diff --git a/enterprise/crates/defguard_enterprise_activity_log_stream/src/activity_log_stream/error.rs b/enterprise/crates/defguard_enterprise_activity_log_stream/src/activity_log_stream/error.rs new file mode 100644 index 000000000..b0f1afd0b --- /dev/null +++ b/enterprise/crates/defguard_enterprise_activity_log_stream/src/activity_log_stream/error.rs @@ -0,0 +1 @@ +pub use defguard_enterprise_db::models::activity_log_stream::ActivityLogStreamError; diff --git a/crates/defguard_core/src/enterprise/activity_log_stream/http_stream.rs b/enterprise/crates/defguard_enterprise_activity_log_stream/src/activity_log_stream/http_stream.rs similarity index 98% rename from crates/defguard_core/src/enterprise/activity_log_stream/http_stream.rs rename to enterprise/crates/defguard_enterprise_activity_log_stream/src/activity_log_stream/http_stream.rs index 491d70d9c..4485db186 100644 --- a/crates/defguard_core/src/enterprise/activity_log_stream/http_stream.rs +++ b/enterprise/crates/defguard_enterprise_activity_log_stream/src/activity_log_stream/http_stream.rs @@ -8,7 +8,7 @@ use tokio::sync::broadcast::Receiver; use tokio_util::sync::CancellationToken; use tracing::{debug, error}; -use crate::enterprise::db::models::activity_log_stream::{ +use defguard_enterprise_db::models::activity_log_stream::{ LogstashHttpActivityLogStream, VectorHttpActivityLogStream, }; diff --git a/crates/defguard_core/src/enterprise/activity_log_stream/mod.rs b/enterprise/crates/defguard_enterprise_activity_log_stream/src/activity_log_stream/mod.rs similarity index 100% rename from crates/defguard_core/src/enterprise/activity_log_stream/mod.rs rename to enterprise/crates/defguard_enterprise_activity_log_stream/src/activity_log_stream/mod.rs diff --git a/enterprise/crates/defguard_enterprise_activity_log_stream/src/lib.rs b/enterprise/crates/defguard_enterprise_activity_log_stream/src/lib.rs new file mode 100644 index 000000000..754ada059 --- /dev/null +++ b/enterprise/crates/defguard_enterprise_activity_log_stream/src/lib.rs @@ -0,0 +1,3 @@ +pub mod activity_log_stream; + +pub use activity_log_stream::*; diff --git a/enterprise/crates/defguard_enterprise_db/Cargo.toml b/enterprise/crates/defguard_enterprise_db/Cargo.toml new file mode 100644 index 000000000..ec89922e8 --- /dev/null +++ b/enterprise/crates/defguard_enterprise_db/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "defguard_enterprise_db" +version = "0.0.0" +edition.workspace = true +license-file.workspace = true +homepage.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +defguard_common = { workspace = true } +defguard_enterprise_license = { workspace = true } +chrono = { workspace = true } +ipnetwork = { workspace = true } +model_derive = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +sha256 = { workspace = true } +sqlx = { workspace = true } +strum = { workspace = true } +strum_macros = { workspace = true } +struct-patch = { workspace = true } +thiserror = { workspace = true } +tracing = { workspace = true } +utoipa = { workspace = true } + +[dev-dependencies] +rand = { workspace = true } diff --git a/crates/defguard_core/src/enterprise/db/mod.rs b/enterprise/crates/defguard_enterprise_db/src/db/mod.rs similarity index 100% rename from crates/defguard_core/src/enterprise/db/mod.rs rename to enterprise/crates/defguard_enterprise_db/src/db/mod.rs diff --git a/crates/defguard_core/src/enterprise/db/models/acl.rs b/enterprise/crates/defguard_enterprise_db/src/db/models/acl.rs similarity index 63% rename from crates/defguard_core/src/enterprise/db/models/acl.rs rename to enterprise/crates/defguard_enterprise_db/src/db/models/acl.rs index c876dff94..89771649f 100644 --- a/crates/defguard_core/src/enterprise/db/models/acl.rs +++ b/enterprise/crates/defguard_enterprise_db/src/db/models/acl.rs @@ -17,24 +17,15 @@ use defguard_common::db::{ }; use ipnetwork::{IpNetwork, IpNetworkError}; use model_derive::Model; +use serde::{Deserialize, Serialize}; use sqlx::{ - Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, Type, error::ErrorKind, - postgres::types::PgRange, query, query_as, query_scalar, + Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, Type, postgres::types::PgRange, + query, query_as, query_scalar, }; use thiserror::Error; +use tracing::{debug, error, info, warn}; use utoipa::ToSchema; -use crate::{ - appstate::AppState, - enterprise::{ - firewall::{FirewallError, try_get_location_firewall_config}, - handlers::acl::{ - ApiAclRule, EditAclRule, alias::EditAclAlias, destination::EditAclDestination, - }, - }, - grpc::GatewayEvent, -}; - #[derive(Debug, Error)] pub enum AclError { #[error("InvalidPortsFormat: {0}")] @@ -59,14 +50,14 @@ pub enum AclError { AliasAlreadyAppliedError(Id), #[error("AliasUsedByRulesError: {0}")] AliasUsedByRulesError(Id), - #[error(transparent)] - FirewallError(#[from] FirewallError), #[error("InvalidIpRangeError: {0}")] InvalidIpRangeError(String), #[error("CannotModifyDeletedRuleError: {0}")] CannotModifyDeletedRuleError(Id), #[error("CannotUseModifiedAliasInRuleError: {0:?}")] CannotUseModifiedAliasInRuleError(Vec), + #[error("FirewallError: {0}")] + FirewallError(String), } /// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/include/uapi/linux/in.h @@ -196,7 +187,7 @@ pub struct AclRuleInfo { impl AclRuleInfo { /// Constructs a [`String`] of comma-separated addresses and address ranges. - pub(crate) fn format_destination(&self) -> String { + pub fn format_destination(&self) -> String { // process single addresses let addrs = match &self.addresses { d if d.is_empty() => String::new(), @@ -221,7 +212,7 @@ impl AclRuleInfo { } /// Constructs a [`String`] of comma-separated ports and port ranges. - pub(crate) fn format_ports(&self) -> String { + pub fn format_ports(&self) -> String { self.ports .iter() .map(ToString::to_string) @@ -299,270 +290,16 @@ impl Default for AclRule { } } -impl AclRule { - /// Creates new [`AclRule`] with all related objects based on [`ApiAclRule`] - pub(crate) async fn create_from_api( - pool: &PgPool, - api_rule: &EditAclRule, - ) -> Result { - let mut transaction = pool.begin().await?; - - // save the rule - let rule: AclRule = api_rule.clone().try_into()?; - let rule = rule.save(&mut *transaction).await?; - - // create related objects - rule.create_related_objects(&mut transaction, api_rule) - .await?; - - let result = ApiAclRule::from(rule.to_info(&mut transaction).await?); - - transaction.commit().await?; - - Ok(result) - } - - /// Updates [`AclRule`] with all it's related objects based on [`ApiAclRule`] - /// - /// State handling: - /// - /// - For rules in `RuleState::Applied` state (rules that are currently active): - /// 1. Any existing modifications of this rule are deleted - /// 2. A copy of the rule is created with `RuleState::Modified` state and the original rule as parent - /// - For rules in `RuleState::Deleted` we return an error since those should not be modified - /// - For rules in other states (`New`, `Modified` ), we directly update the existing rule - /// since they haven't been applied. - /// - /// This approach allows us to track changes to applied rules while maintaining their history. - /// - /// Applied state does NOT guarantee that all locations have received the rule - /// and performed appropriate operations, only that the next time configuration - /// is being sent it will include this rule. - pub(crate) async fn update_from_api( - pool: &PgPool, - id: Id, - api_rule: &EditAclRule, - ) -> Result { - debug!("Updating rule ID {id} with {api_rule:?}"); - let mut transaction = pool.begin().await?; - - // find the existing rule - let existing_rule = AclRule::find_by_id(&mut *transaction, id) - .await? - .ok_or_else(|| { - warn!("Update of nonexistent rule ({id}) failed"); - AclError::RuleNotFoundError(id) - })?; - - // convert API rule to model - let mut rule: AclRule = api_rule.clone().try_into()?; - - // perform appropriate updates depending on existing rule's state - let rule = match existing_rule.state { - RuleState::Applied | RuleState::Expired => { - // create new `RuleState::Modified` rule - debug!( - "Rule {id} state is {:?} - creating new `Modified` rule object", - existing_rule.state - ); - // remove old modifications of this rule - let result = query!("DELETE FROM aclrule WHERE parent_id = $1", id) - .execute(&mut *transaction) - .await?; - debug!( - "Removed {} old modifications of rule {id}", - result.rows_affected(), - ); - - // save as a new rule with appropriate parent_id and state - rule.state = RuleState::Modified; - rule.parent_id = Some(id); - let rule = rule.save(&mut *transaction).await?; - - // create related objects - rule.create_related_objects(&mut transaction, api_rule) - .await?; - - rule - } - RuleState::Deleted => { - error!("Cannot update a deleted ACL rule {id}"); - return Err(AclError::CannotModifyDeletedRuleError(id)); - } - RuleState::New | RuleState::Modified => { - debug!( - "Rule {id} is a modification to rule {:?} - updating the modification", - existing_rule.parent_id, - ); - // update the not-yet applied modification itself - let mut rule = rule.with_id(id); - rule.parent_id = existing_rule.parent_id; - rule.state = existing_rule.state; - rule.save(&mut *transaction).await?; - - // recreate related objects - rule.delete_related_objects(&mut transaction).await?; - rule.create_related_objects(&mut transaction, api_rule) - .await?; - - rule - } - }; - - let rule_details = rule.to_info(&mut transaction).await?.into(); - - transaction.commit().await?; - - info!("Successfully updated rule {rule_details:?}"); - Ok(rule_details) - } - - /// Deletes [`AclRule`] with all it's related objects. - /// - /// State handling: - /// - /// - For rules in `RuleState::Applied` state (rules that are currently active): - /// 1. Any existing modifications of this rule are deleted. - /// 2. A copy of the rule is created with `RuleState::Deleted` state and the original rule as - /// parent. - /// - /// This preserves the original rule while tracking the deletion. - /// - /// - For rules in other states (`New`, `Modified` or `Deleted`): - /// 1. All related objects are deleted - /// 2. The rule itself is deleted from the database - /// - /// Since these rules were not yet applied, we can safely remove them. - pub(crate) async fn delete_from_api(pool: &PgPool, id: Id) -> Result<(), AclError> { - debug!("Deleting rule {id}"); - let mut transaction = pool.begin().await?; - - // find the existing rule - let existing_rule = AclRule::find_by_id(&mut *transaction, id) - .await? - .ok_or_else(|| { - warn!("Deletion of nonexistent rule ({id}) failed"); - AclError::RuleNotFoundError(id) - })?; - - // perform appropriate modifications depending on existing rule's state - match existing_rule.state { - RuleState::Applied | RuleState::Expired => { - // create new `RuleState::Deleted` rule - debug!( - "Rule {id} state is {:?} - creating new `Deleted` rule object", - existing_rule.state, - ); - // delete all modifications of this rule - let result = query!("DELETE FROM aclrule WHERE parent_id = $1", id) - .execute(&mut *transaction) - .await?; - debug!( - "Removed {} old modifications of rule {id}", - result.rows_affected(), - ); - - // prefetch related objects for use later - let rule_info = existing_rule.to_info(&mut transaction).await?; - - // save as a new rule with appropriate parent_id and state - let mut rule = existing_rule.as_noid(); - rule.state = RuleState::Deleted; - rule.parent_id = Some(id); - let rule = rule.save(&mut *transaction).await?; - - // inherit related objects from parent rule - rule.create_related_objects(&mut transaction, &rule_info.into()) - .await?; - } - _ => { - // delete the not-yet applied modification itself - debug!( - "Rule {id} is a modification to rule {:?} - updating the modification", - existing_rule.parent_id, - ); - // delete related objects - existing_rule - .delete_related_objects(&mut transaction) - .await?; - - // delete the rule - existing_rule.delete(&mut *transaction).await?; - } - } - - transaction.commit().await?; - info!("Rule {id} succesfully deleted or marked for deletion"); - Ok(()) - } - - /// Applies pending changes for all specified rules - /// - /// # Errors - /// - /// - `AclError::RuleNotFoundError` - pub async fn apply_rules(rules: &[Id], appstate: &AppState) -> Result<(), AclError> { - debug!("Applying {} ACL rules: {rules:?}", rules.len()); - let mut transaction = appstate.pool.begin().await?; - - // prepare variable for collecting affected locations - let mut affected_locations = HashSet::new(); - - for id in rules { - let rule = AclRule::find_by_id(&mut *transaction, *id) - .await? - .ok_or_else(|| AclError::RuleNotFoundError(*id))?; - let locations = rule.get_networks(&mut *transaction).await?; - for location in locations { - affected_locations.insert(location); - } - rule.apply(&mut transaction).await?; - } - info!("Applied {} ACL rules: {rules:?}", rules.len()); - - let affected_locations: Vec> = - affected_locations.into_iter().collect(); - debug!( - "{} locations affected by applied ACL rules. Sending gateway firewall update events \ - for each location", - affected_locations.len() - ); - - for location in affected_locations { - match try_get_location_firewall_config(&location, &mut transaction).await? { - Some(firewall_config) => { - debug!("Sending firewall update event for location {location}"); - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( - location.id, - firewall_config, - )); - } - None => { - debug!( - "No firewall config generated for location {location}. Not sending a \ - gateway event" - ); - } - } - } - - transaction.commit().await?; - Ok(()) - } -} - #[derive(Debug, Default)] -pub(crate) struct ParsedDestination { - addrs: Vec, - pub(crate) ranges: Vec<(IpAddr, IpAddr)>, +pub struct ParsedDestination { + pub addrs: Vec, + pub ranges: Vec<(IpAddr, IpAddr)>, } /// Perses a destination string into singular ip addresses or networks and address /// ranges. We should be able to parse a string like this one: /// `10.0.0.1/24, 10.1.1.10-10.1.1.20, 192.168.1.10, 10.1.1.1-10.10.1.1` -pub(crate) fn parse_destination_addresses( - destination: &str, -) -> Result { +pub fn parse_destination_addresses(destination: &str) -> Result { debug!("Parsing destination string: {destination}"); let destination: String = destination.chars().filter(|c| !c.is_whitespace()).collect(); let mut result = ParsedDestination::default(); @@ -619,147 +356,9 @@ pub fn parse_ports(ports: &str) -> Result, AclError> { Ok(result) } -/// Maps [`sqlx::Error`] to [`AclError`] while checking for [`ErrorKind::ForeignKeyViolation`]. -fn map_relation_error(err: SqlxError, class: &str, id: Id) -> AclError { - if let SqlxError::Database(dberror) = &err { - if dberror.kind() == ErrorKind::ForeignKeyViolation { - error!( - "Failed to create ACL related object, foreign key violation: {class}({id}): {dberror}" - ); - return AclError::InvalidRelationError(format!("{class}({id})")); - } - } - error!("Failed to create ACL related object: {err}"); - AclError::DbError(err) -} - impl AclRule { - /// Creates relation objects for given [`AclRule`] based on [`EditAclRule`] object - async fn create_related_objects( - &self, - transaction: &mut PgConnection, - api_rule: &EditAclRule, - ) -> Result<(), AclError> { - let rule_id = self.id; - debug!("Creating related objects for ACL rule {api_rule:?}"); - - // save related locations - debug!("Creating related locations for ACL rule {rule_id}"); - for network_id in &api_rule.locations { - AclRuleNetwork::new(rule_id, *network_id) - .save(&mut *transaction) - .await - .map_err(|err| map_relation_error(err, "WireguardNetwork", *network_id))?; - } - - // allowed users - debug!("Creating related allowed users for ACL rule {rule_id}"); - for user_id in &api_rule.allowed_users { - AclRuleUser::new(rule_id, *user_id, true) - .save(&mut *transaction) - .await - .map_err(|err| map_relation_error(err, "User", *user_id))?; - } - - // denied users - debug!("Creating related denied users for ACL rule {rule_id}"); - for user_id in &api_rule.denied_users { - AclRuleUser::new(rule_id, *user_id, false) - .save(&mut *transaction) - .await - .map_err(|err| map_relation_error(err, "User", *user_id))?; - } - - // allowed groups - debug!("Creating related allowed groups for ACL rule {rule_id}"); - for group_id in &api_rule.allowed_groups { - AclRuleGroup::new(rule_id, *group_id, true) - .save(&mut *transaction) - .await - .map_err(|err| map_relation_error(err, "Group", *group_id))?; - } - - // denied groups - debug!("Creating related denied groups for ACL rule {rule_id}"); - for group_id in &api_rule.denied_groups { - AclRuleGroup::new(rule_id, *group_id, false) - .save(&mut *transaction) - .await - .map_err(|err| map_relation_error(err, "Group", *group_id))?; - } - - // save related aliases and destinations - debug!("Creating related aliases and destinations for ACL rule {rule_id}"); - // verify if all aliases have a correct state - // aliases used for tracking modifications (`AliasState::Modified`) cannot be used by ACL - // rules - // FIXME: handle aliases and destinations separately - let all_aliases = [api_rule.aliases.clone(), api_rule.destinations.clone()].concat(); - let invalid_alias_ids: Vec = query_scalar!( - "SELECT id FROM aclalias WHERE id = ANY($1) AND state != 'applied'::aclalias_state", - &all_aliases - ) - .fetch_all(&mut *transaction) - .await?; - if !invalid_alias_ids.is_empty() { - error!( - "Cannot use aliases which have not been applied in an ACL rule. Invalid aliases: \ - {invalid_alias_ids:?}" - ); - return Err(AclError::CannotUseModifiedAliasInRuleError( - invalid_alias_ids, - )); - } - for alias_id in &all_aliases { - AclRuleAlias::new(rule_id, *alias_id) - .save(&mut *transaction) - .await - .map_err(|err| map_relation_error(err, "AclAlias", *alias_id))?; - } - - // allowed devices - debug!("Creating related allowed devices for ACL rule {rule_id}"); - for device_id in &api_rule.allowed_network_devices { - AclRuleDevice::new(rule_id, *device_id, true) - .save(&mut *transaction) - .await - .map_err(|err| map_relation_error(err, "Device", *device_id))?; - } - - // denied devices - debug!("Creating related denied devices for ACL rule {rule_id}"); - for device_id in &api_rule.denied_network_devices { - AclRuleDevice::new(rule_id, *device_id, false) - .save(&mut *transaction) - .await - .map_err(|err| map_relation_error(err, "Device", *device_id))?; - } - - // destination - let destination = parse_destination_addresses(&api_rule.addresses)?; - debug!("Creating related destination ranges for ACL rule {rule_id}"); - for range in destination.ranges { - if range.1 <= range.0 { - return Err(AclError::InvalidIpRangeError(format!( - "{}-{}", - range.0, range.1 - ))); - } - let obj = AclRuleDestinationRange { - id: NoId, - rule_id, - start: range.0, - end: range.1, - }; - obj.save(&mut *transaction).await?; - } - - info!("Created related objects for ACL rule {api_rule:?}"); - Ok(()) - } - /// Deletes relation objects for given [`AclRule`] - async fn delete_related_objects( + pub async fn delete_related_objects( &self, transaction: &mut PgConnection, ) -> Result<(), SqlxError> { @@ -827,38 +426,6 @@ impl AclRule { } } -impl TryFrom for AclRule { - type Error = AclError; - - fn try_from(rule: EditAclRule) -> Result { - Ok(Self { - addresses: parse_destination_addresses(&rule.addresses)?.addrs, - ports: parse_ports(&rule.ports)? - .into_iter() - .map(Into::into) - .collect(), - id: NoId, - parent_id: None, - state: RuleState::default(), - name: rule.name, - allow_all_users: rule.allow_all_users, - deny_all_users: rule.deny_all_users, - allow_all_groups: rule.allow_all_groups, - deny_all_groups: rule.deny_all_groups, - allow_all_network_devices: rule.allow_all_network_devices, - deny_all_network_devices: rule.deny_all_network_devices, - all_locations: rule.all_locations, - protocols: rule.protocols, - enabled: rule.enabled, - expires: rule.expires, - any_address: rule.any_address, - any_port: rule.any_port, - any_protocol: rule.any_protocol, - use_manual_destination_settings: true, - }) - } -} - impl AclRule { /// Applies pending state change if necessary. /// @@ -919,7 +486,7 @@ impl AclRule { } /// Returns all [`WireguardNetwork`]s the rule applies to - pub(crate) async fn get_networks<'e, E>( + pub async fn get_networks<'e, E>( &self, executor: E, ) -> Result>, SqlxError> @@ -947,10 +514,7 @@ impl AclRule { } /// Returns all [`AclAlias`]es the rule applies to - pub(crate) async fn get_aliases<'e, E>( - &self, - executor: E, - ) -> Result>, SqlxError> + pub async fn get_aliases<'e, E>(&self, executor: E) -> Result>, SqlxError> where E: PgExecutor<'e>, { @@ -969,7 +533,7 @@ impl AclRule { } /// Returns **active** [`User`]s that are allowed or denied by the rule - pub(crate) async fn get_users<'e, E>( + pub async fn get_users<'e, E>( &self, executor: E, allowed: bool, @@ -985,10 +549,7 @@ impl AclRule { } /// Returns **active** [`User`]s that are allowed by the rule - pub(crate) async fn get_allowed_users<'e, E>( - &self, - executor: E, - ) -> Result>, SqlxError> + pub async fn get_allowed_users<'e, E>(&self, executor: E) -> Result>, SqlxError> where E: PgExecutor<'e>, { @@ -1011,10 +572,7 @@ impl AclRule { } /// Returns **active** [`User`]s that are denied by the rule - pub(crate) async fn get_denied_users<'e, E>( - &self, - executor: E, - ) -> Result>, SqlxError> + pub async fn get_denied_users<'e, E>(&self, executor: E) -> Result>, SqlxError> where E: PgExecutor<'e>, { @@ -1037,7 +595,7 @@ impl AclRule { } /// Returns [`Group`]s that are allowed or denied by the rule - pub(crate) async fn get_groups<'e, E>( + pub async fn get_groups<'e, E>( &self, executor: E, allowed: bool, @@ -1059,7 +617,7 @@ impl AclRule { } /// Returns [`Device`]s that are allowed or denied by the rule - pub(crate) async fn get_network_devices<'e, E>( + pub async fn get_network_devices<'e, E>( &self, executor: E, allowed: bool, @@ -1074,7 +632,7 @@ impl AclRule { } } - pub(crate) async fn get_allowed_network_devices<'e, E>( + pub async fn get_allowed_network_devices<'e, E>( &self, executor: E, ) -> Result>, SqlxError> @@ -1094,7 +652,7 @@ impl AclRule { .await } - pub(crate) async fn get_denied_network_devices<'e, E>( + pub async fn get_denied_network_devices<'e, E>( &self, executor: E, ) -> Result>, SqlxError> @@ -1115,7 +673,7 @@ impl AclRule { } /// Returns all [`AclRuleDestinationRanges`]es the rule applies to - pub(crate) async fn get_destination_address_ranges<'e, E>( + pub async fn get_destination_address_ranges<'e, E>( &self, executor: E, ) -> Result>, SqlxError> @@ -1190,7 +748,7 @@ impl AclRule { impl AclRuleInfo { /// Wrapper function which combines explicitly specified allowed users with members of allowed /// groups to generate a list of all unique allowed users for a given ACL. - pub(crate) async fn get_all_allowed_users( + pub async fn get_all_allowed_users( &self, conn: &mut PgConnection, ) -> Result>, SqlxError> { @@ -1246,7 +804,7 @@ impl AclRuleInfo { /// Wrapper function which combines explicitly specified denied users with members of denied /// groups to generate a list of all unique denied users for a given ACL. - pub(crate) async fn get_all_denied_users( + pub async fn get_all_denied_users( &self, conn: &mut PgConnection, ) -> Result>, SqlxError> { @@ -1303,7 +861,7 @@ impl AclRuleInfo { /// Returns the list of explicitly configured allowed network devices or /// a list of all devices if 'allow_all_network_devices' flag is enabled. - pub(crate) async fn get_all_allowed_devices<'e, E: sqlx::PgExecutor<'e>>( + pub async fn get_all_allowed_devices<'e, E: sqlx::PgExecutor<'e>>( &self, executor: E, location_id: Id, @@ -1335,7 +893,7 @@ impl AclRuleInfo { /// Returns the list of explicitly configured denied network devices or /// a list of all devices if 'deny_all_network_devices' flag is enabled. - pub(crate) async fn get_all_denied_devices<'e, E: sqlx::PgExecutor<'e>>( + pub async fn get_all_denied_devices<'e, E: sqlx::PgExecutor<'e>>( &self, executor: E, location_id: Id, @@ -1369,7 +927,7 @@ impl AclRuleInfo { /// Helper struct combining all database objects related to given [`AclAlias`]. /// All related objects are stored in vectors. #[derive(Clone, Debug, ToSchema)] -pub(crate) struct AclAliasInfo { +pub struct AclAliasInfo { pub id: Id, pub parent_id: Option, pub name: String, @@ -1389,7 +947,7 @@ pub(crate) struct AclAliasInfo { impl AclAliasInfo { /// Constructs a [`String`] of comma-separated addresses and address ranges - pub(crate) fn format_destination(&self) -> String { + pub fn format_destination(&self) -> String { // process single addresses let addrs = match &self.addresses { d if d.is_empty() => String::new(), @@ -1415,7 +973,7 @@ impl AclAliasInfo { } /// Constructs a [`String`] of comma-separated ports and port ranges - pub(crate) fn format_ports(&self) -> String { + pub fn format_ports(&self) -> String { self.ports .iter() .map(ToString::to_string) @@ -1480,6 +1038,7 @@ pub struct AclAlias { impl AclAlias { #[must_use] + #[allow(clippy::too_many_arguments)] pub fn new>( name: S, state: AliasState, @@ -1505,158 +1064,11 @@ impl AclAlias { any_protocol, } } - - /// Deletes [`AclAlias`] with all it's related objects. - /// - /// State handling: - /// - /// - For aliases in `AliasState::Applied` state (aliases that are currently active): - /// 1. Check if the alias is being used by any ACL rules. Return an error if it is - /// 2. Any existing modifications of this alias are deleted - /// 3. Delete the alias itself - /// - /// - For aliases in `Modified` state (tracking modifications of already applied aliases): - /// 1. All related objects are deleted - /// 2. The alias itself is deleted from the database - /// - /// Since these aliases were not yet applied, we can safely remove them. - pub(crate) async fn delete_from_api(pool: &PgPool, id: Id) -> Result<(), AclError> { - debug!("Deleting alias {id}"); - let mut transaction = pool.begin().await?; - - // find the existing alias - let existing_alias = AclAlias::find_by_id(&mut *transaction, id) - .await? - .ok_or_else(|| { - error!("Deletion of nonexistent alias ({id}) failed"); - AclError::AliasNotFoundError(id) - })?; - - // check if any rules are using this alias - let rules = existing_alias.get_rules(&mut *transaction).await?; - if !rules.is_empty() { - error!( - "Deletion of alias ({id}) failed. Alias is currently used by following ACL rules: {rules:?}" - ); - return Err(AclError::AliasUsedByRulesError(id)); - } - - // delete all modifications of this alias if any exist - let result = query!("DELETE FROM aclalias WHERE parent_id = $1", id) - .execute(&mut *transaction) - .await?; - let removed_modifications = result.rows_affected(); - if removed_modifications > 0 { - debug!("Removed {removed_modifications} old modifications of alias {id}"); - } - - // delete related objects - acl_delete_related_objects(&mut transaction, id).await?; - - // delete the alias itself - existing_alias.delete(&mut *transaction).await?; - - transaction.commit().await?; - Ok(()) - } - - /// Applies pending changes for all specified aliases - /// - /// # Errors - /// - /// - `AclError::AliasNotFoundError` - pub(crate) async fn apply_aliases(aliases: &[Id], appstate: &AppState) -> Result<(), AclError> { - debug!("Applying {} ACL aliases: {aliases:?}", aliases.len()); - let mut transaction = appstate.pool.begin().await?; - - // prepare variable for collecting affected rules - // we are unable to use `HashSet` because `PgRange` does not implement `Hash` trait - let mut affected_rules = Vec::new(); - - for id in aliases { - let alias = AclAlias::find_by_id(&mut *transaction, *id) - .await? - .ok_or_else(|| AclError::AliasNotFoundError(*id))?; - // run `apply` before fetching relations, since they'll get updated - alias.clone().apply(&mut transaction).await?; - - // fetch ACL rules which are using this alias - let rules = alias.get_rules(&mut *transaction).await?; - affected_rules.extend(rules); - } - info!("Applied {} ACL aliases: {aliases:?}", aliases.len()); - - // find locations affected by applying selected aliases - let mut affected_locations = HashSet::new(); - let mut unique_rule_ids = HashSet::new(); - for rule in affected_rules { - if unique_rule_ids.insert(rule.id) { - let locations = rule.get_networks(&mut *transaction).await?; - for location in locations { - affected_locations.insert(location); - } - } - } - - let affected_locations = affected_locations.into_iter().collect::>(); - debug!( - "{} locations affected by applied ACL aliases. Sending gateway firewall update events \ - for each location", - affected_locations.len() - ); - - for location in affected_locations { - match try_get_location_firewall_config(&location, &mut transaction).await? { - Some(firewall_config) => { - debug!("Sending firewall update event for location {location}"); - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( - location.id, - firewall_config, - )); - } - None => { - debug!( - "No firewall config generated for location {location}. Not sending a \ - gateway event" - ); - } - } - } - - transaction.commit().await?; - Ok(()) - } -} - -impl TryFrom<&EditAclAlias> for AclAlias { - type Error = AclError; - - fn try_from(alias: &EditAclAlias) -> Result { - Ok(Self { - addresses: parse_destination_addresses(&alias.addresses)?.addrs, - ports: parse_ports(&alias.ports)? - .into_iter() - .map(Into::into) - .collect(), - id: NoId, - parent_id: None, - name: alias.name.clone(), - kind: AliasKind::Component, - state: AliasState::Applied, - protocols: alias.protocols.clone(), - any_address: true, - any_port: true, - any_protocol: true, - }) - } } impl AclAlias { /// Fetch [`AclAlias`] of a given kind. - pub(crate) async fn all_of_kind<'e, E>( - executor: E, - kind: AliasKind, - ) -> Result, sqlx::Error> + pub async fn all_of_kind<'e, E>(executor: E, kind: AliasKind) -> Result, sqlx::Error> where E: PgExecutor<'e>, { @@ -1692,31 +1104,8 @@ impl AclAlias { } } -impl TryFrom<&EditAclDestination> for AclAlias { - type Error = AclError; - - fn try_from(alias: &EditAclDestination) -> Result { - Ok(Self { - addresses: parse_destination_addresses(&alias.addresses)?.addrs, - ports: parse_ports(&alias.ports)? - .into_iter() - .map(Into::into) - .collect(), - id: NoId, - parent_id: None, - name: alias.name.clone(), - kind: AliasKind::Destination, - state: AliasState::Applied, - protocols: alias.protocols.clone(), - any_address: alias.any_address, - any_port: alias.any_port, - any_protocol: alias.any_protocol, - }) - } -} - /// Deletes relation objects for a given [`AclAlias`]. -pub(crate) async fn acl_delete_related_objects( +pub async fn acl_delete_related_objects( transaction: &mut PgConnection, alias_id: Id, ) -> Result<(), AclError> { @@ -1739,7 +1128,7 @@ pub(crate) async fn acl_delete_related_objects( impl AclAlias { /// Returns all [`AclAliasDestinationRanges`]es the alias applies to - pub(crate) async fn get_destination_ranges<'e, E>( + pub async fn get_destination_ranges<'e, E>( &self, executor: E, ) -> Result>, SqlxError> @@ -1758,7 +1147,7 @@ impl AclAlias { } /// Returns all [`AclRule`]s which use this alias - pub(crate) async fn get_rules<'e, E>(&self, executor: E) -> Result>, SqlxError> + pub async fn get_rules<'e, E>(&self, executor: E) -> Result>, SqlxError> where E: PgExecutor<'e>, { @@ -1779,7 +1168,7 @@ impl AclAlias { /// Retrieves all related objects from the db and converts [`AclAlias`] /// instance to [`AclAliasInfo`]. - pub(crate) async fn to_info(&self, pool: &PgPool) -> Result { + pub async fn to_info(&self, pool: &PgPool) -> Result { let destination_ranges = self.get_destination_ranges(pool).await?; let rules = self.get_rules(pool).await?; @@ -1853,7 +1242,7 @@ impl AclAlias { } #[derive(Model)] -pub(crate) struct AclRuleNetwork { +pub struct AclRuleNetwork { #[allow(dead_code)] id: I, rule_id: Id, @@ -1862,7 +1251,7 @@ pub(crate) struct AclRuleNetwork { impl AclRuleNetwork { #[must_use] - pub(crate) fn new(rule_id: Id, network_id: Id) -> Self { + pub fn new(rule_id: Id, network_id: Id) -> Self { Self { id: NoId, rule_id, @@ -1872,7 +1261,7 @@ impl AclRuleNetwork { } #[derive(Model)] -pub(crate) struct AclRuleUser { +pub struct AclRuleUser { #[allow(dead_code)] id: I, rule_id: Id, @@ -1882,7 +1271,7 @@ pub(crate) struct AclRuleUser { impl AclRuleUser { #[must_use] - pub(crate) fn new(rule_id: Id, user_id: Id, allow: bool) -> Self { + pub fn new(rule_id: Id, user_id: Id, allow: bool) -> Self { Self { id: NoId, rule_id, @@ -1893,7 +1282,7 @@ impl AclRuleUser { } #[derive(Model)] -pub(crate) struct AclRuleGroup { +pub struct AclRuleGroup { #[allow(dead_code)] id: I, rule_id: Id, @@ -1903,7 +1292,7 @@ pub(crate) struct AclRuleGroup { impl AclRuleGroup { #[must_use] - pub(crate) fn new(rule_id: Id, group_id: Id, allow: bool) -> Self { + pub fn new(rule_id: Id, group_id: Id, allow: bool) -> Self { Self { id: NoId, rule_id, @@ -1914,7 +1303,7 @@ impl AclRuleGroup { } #[derive(Model)] -pub(crate) struct AclRuleAlias { +pub struct AclRuleAlias { #[allow(dead_code)] id: I, rule_id: Id, @@ -1923,7 +1312,7 @@ pub(crate) struct AclRuleAlias { impl AclRuleAlias { #[must_use] - pub(crate) fn new(rule_id: Id, alias_id: Id) -> Self { + pub fn new(rule_id: Id, alias_id: Id) -> Self { Self { id: NoId, rule_id, @@ -1933,7 +1322,7 @@ impl AclRuleAlias { } #[derive(Model)] -pub(crate) struct AclRuleDevice { +pub struct AclRuleDevice { #[allow(dead_code)] id: I, rule_id: Id, @@ -1943,7 +1332,7 @@ pub(crate) struct AclRuleDevice { impl AclRuleDevice { #[must_use] - pub(crate) fn new(rule_id: Id, device_id: Id, allow: bool) -> Self { + pub fn new(rule_id: Id, device_id: Id, allow: bool) -> Self { Self { id: NoId, rule_id, @@ -1992,7 +1381,7 @@ impl From<&AclRuleDestinationRange> for RangeInclusive { } #[derive(Clone, Debug, Deserialize, PartialEq, ToSchema)] -pub(crate) struct AclAliasDestinationRange { +pub struct AclAliasDestinationRange { pub id: I, pub alias_id: Id, #[schema(value_type = String)] diff --git a/crates/defguard_core/src/enterprise/db/models/acl/tests.rs b/enterprise/crates/defguard_enterprise_db/src/db/models/acl/tests.rs similarity index 100% rename from crates/defguard_core/src/enterprise/db/models/acl/tests.rs rename to enterprise/crates/defguard_enterprise_db/src/db/models/acl/tests.rs diff --git a/crates/defguard_core/src/enterprise/db/models/activity_log_stream.rs b/enterprise/crates/defguard_enterprise_db/src/db/models/activity_log_stream.rs similarity index 90% rename from crates/defguard_core/src/enterprise/db/models/activity_log_stream.rs rename to enterprise/crates/defguard_enterprise_db/src/db/models/activity_log_stream.rs index 3548de188..95be29bac 100644 --- a/crates/defguard_core/src/enterprise/db/models/activity_log_stream.rs +++ b/enterprise/crates/defguard_enterprise_db/src/db/models/activity_log_stream.rs @@ -3,11 +3,20 @@ use defguard_common::{ secret::SecretStringWrapper, }; use model_derive::Model; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use sqlx::{Error as SqlxError, FromRow, PgExecutor, Type, query_as}; use strum_macros::{Display, EnumString}; +use thiserror::Error; -use crate::enterprise::activity_log_stream::error::ActivityLogStreamError; +#[derive(Debug, Error)] +pub enum ActivityLogStreamError { + #[error("Deserialization of {0} error: {1}")] + ConfigDeserializeError(String, String), + #[error("Sqlx error: {0}")] + SqlxError(#[from] sqlx::Error), + #[error("Parsing http header value failed")] + HeaderValueParsing(), +} #[derive(Debug, Serialize, Deserialize, Type, EnumString, Display, Clone, PartialEq)] #[sqlx(type_name = "text", rename_all = "snake_case")] diff --git a/crates/defguard_core/src/enterprise/db/models/api_tokens.rs b/enterprise/crates/defguard_enterprise_db/src/db/models/api_tokens.rs similarity index 98% rename from crates/defguard_core/src/enterprise/db/models/api_tokens.rs rename to enterprise/crates/defguard_enterprise_db/src/db/models/api_tokens.rs index 7b65a80e9..006a08898 100644 --- a/crates/defguard_core/src/enterprise/db/models/api_tokens.rs +++ b/enterprise/crates/defguard_enterprise_db/src/db/models/api_tokens.rs @@ -1,6 +1,7 @@ use chrono::NaiveDateTime; use defguard_common::db::{Id, NoId}; use model_derive::Model; +use serde::{Deserialize, Serialize}; use sqlx::{Error as SqlxError, PgExecutor, query_as}; #[derive(Clone, Debug, Deserialize, Model, Serialize, PartialEq)] diff --git a/crates/defguard_core/src/enterprise/db/models/enterprise_settings.rs b/enterprise/crates/defguard_enterprise_db/src/db/models/enterprise_settings.rs similarity index 94% rename from crates/defguard_core/src/enterprise/db/models/enterprise_settings.rs rename to enterprise/crates/defguard_enterprise_db/src/db/models/enterprise_settings.rs index 916417a97..7bbb9aa77 100644 --- a/crates/defguard_core/src/enterprise/db/models/enterprise_settings.rs +++ b/enterprise/crates/defguard_enterprise_db/src/db/models/enterprise_settings.rs @@ -1,7 +1,8 @@ +use serde::{Deserialize, Serialize}; use sqlx::{PgExecutor, Type, query, query_as}; use struct_patch::Patch; -use crate::enterprise::is_business_license_active; +use defguard_enterprise_license::is_business_license_active; #[derive(Debug, Deserialize, Patch, Serialize)] #[patch(attribute(derive(Deserialize, Serialize)))] @@ -51,7 +52,7 @@ impl EnterpriseSettings { } } - pub(crate) async fn save<'e, E>(&self, executor: E) -> Result<(), sqlx::Error> + pub async fn save<'e, E>(&self, executor: E) -> Result<(), sqlx::Error> where E: PgExecutor<'e>, { diff --git a/crates/defguard_core/src/enterprise/db/models/mod.rs b/enterprise/crates/defguard_enterprise_db/src/db/models/mod.rs similarity index 100% rename from crates/defguard_core/src/enterprise/db/models/mod.rs rename to enterprise/crates/defguard_enterprise_db/src/db/models/mod.rs diff --git a/crates/defguard_core/src/enterprise/db/models/openid_provider.rs b/enterprise/crates/defguard_enterprise_db/src/db/models/openid_provider.rs similarity index 98% rename from crates/defguard_core/src/enterprise/db/models/openid_provider.rs rename to enterprise/crates/defguard_enterprise_db/src/db/models/openid_provider.rs index 16f15d51f..94886244d 100644 --- a/crates/defguard_core/src/enterprise/db/models/openid_provider.rs +++ b/enterprise/crates/defguard_enterprise_db/src/db/models/openid_provider.rs @@ -2,7 +2,9 @@ use std::fmt; use defguard_common::db::{Id, NoId}; use model_derive::Model; +use serde::{Deserialize, Serialize}; use sqlx::{Error as SqlxError, PgExecutor, PgPool, Type, query, query_as}; +use tracing::warn; use utoipa::ToSchema; // The behavior when a user is deleted from the directory @@ -136,6 +138,7 @@ pub struct OpenIdProvider { impl OpenIdProvider { #[must_use] + #[allow(clippy::too_many_arguments)] pub fn new>( name: S, base_url: S, @@ -181,7 +184,7 @@ impl OpenIdProvider { } } - pub(crate) async fn upsert(self, pool: &PgPool) -> Result, SqlxError> { + pub async fn upsert(self, pool: &PgPool) -> Result, SqlxError> { if let Some(provider) = OpenIdProvider::::get_current(pool).await? { query!( "UPDATE openidprovider SET name = $1, base_url = $2, kind = $3, client_id = $4, \ diff --git a/crates/defguard_core/src/enterprise/db/models/snat.rs b/enterprise/crates/defguard_enterprise_db/src/db/models/snat.rs similarity index 94% rename from crates/defguard_core/src/enterprise/db/models/snat.rs rename to enterprise/crates/defguard_enterprise_db/src/db/models/snat.rs index c4ae69fc8..24c096be7 100644 --- a/crates/defguard_core/src/enterprise/db/models/snat.rs +++ b/enterprise/crates/defguard_enterprise_db/src/db/models/snat.rs @@ -6,7 +6,6 @@ use serde::{Deserialize, Serialize}; use sqlx::{PgExecutor, query_as}; use utoipa::ToSchema; -use crate::enterprise::snat::error::UserSnatBindingError; #[derive(Clone, Debug, Deserialize, Model, Serialize, ToSchema, PartialEq)] #[table(user_snat_binding)] @@ -36,7 +35,7 @@ impl UserSnatBinding { executor: E, location_id: Id, user_id: Id, - ) -> Result + ) -> Result where E: PgExecutor<'e>, { diff --git a/enterprise/crates/defguard_enterprise_db/src/lib.rs b/enterprise/crates/defguard_enterprise_db/src/lib.rs new file mode 100644 index 000000000..d768c745b --- /dev/null +++ b/enterprise/crates/defguard_enterprise_db/src/lib.rs @@ -0,0 +1,3 @@ +pub mod db; + +pub use db::{models, PgAcquire}; diff --git a/enterprise/crates/defguard_enterprise_directory_sync/Cargo.toml b/enterprise/crates/defguard_enterprise_directory_sync/Cargo.toml new file mode 100644 index 000000000..65a1a8657 --- /dev/null +++ b/enterprise/crates/defguard_enterprise_directory_sync/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "defguard_enterprise_directory_sync" +version = "0.0.0" +edition.workspace = true +license-file.workspace = true +homepage.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +defguard_common = { workspace = true } +defguard_enterprise_db = { workspace = true } +defguard_enterprise_ldap = { workspace = true } +defguard_enterprise_license = { workspace = true } +chrono = { workspace = true } +futures = { workspace = true } +jsonwebtoken = { workspace = true } +jsonwebkey = { workspace = true } +parse_link_header = { workspace = true } +paste = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +sqlx = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +trait-variant = { workspace = true } + +[dev-dependencies] +ipnetwork = { workspace = true } +secrecy = { workspace = true } diff --git a/crates/defguard_core/src/enterprise/directory_sync/google.rs b/enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/google.rs similarity index 99% rename from crates/defguard_core/src/enterprise/directory_sync/google.rs rename to enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/google.rs index 642af28e6..4aa7d643b 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/google.rs +++ b/enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/google.rs @@ -2,7 +2,9 @@ use std::collections::HashMap; use chrono::{DateTime, TimeDelta, Utc}; use jsonwebtoken::{Algorithm, EncodingKey, Header, encode}; +use serde::{Deserialize, Serialize}; use tokio::time::sleep; +use tracing::{debug, info}; use super::{ DirectoryGroup, DirectorySync, DirectorySyncError, DirectoryUser, REQUEST_PAGINATION_SLOWDOWN, diff --git a/crates/defguard_core/src/enterprise/directory_sync/jumpcloud.rs b/enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/jumpcloud.rs similarity index 99% rename from crates/defguard_core/src/enterprise/directory_sync/jumpcloud.rs rename to enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/jumpcloud.rs index 93d8f7828..92942780b 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/jumpcloud.rs +++ b/enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/jumpcloud.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; use tokio::time::sleep; +use serde::Deserialize; +use tracing::{debug, error, info, warn}; use super::{ DirectoryGroup, DirectorySync, DirectorySyncError, DirectoryUser, REQUEST_PAGINATION_SLOWDOWN, diff --git a/crates/defguard_core/src/enterprise/directory_sync/microsoft.rs b/enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/microsoft.rs similarity index 99% rename from crates/defguard_core/src/enterprise/directory_sync/microsoft.rs rename to enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/microsoft.rs index 7e02645e7..87f932baa 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/microsoft.rs +++ b/enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/microsoft.rs @@ -1,12 +1,13 @@ use chrono::{TimeDelta, Utc}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use tokio::time::sleep; +use tracing::{debug, info, warn}; use super::{ DirectoryGroup, DirectorySync, DirectorySyncError, DirectoryUser, REQUEST_PAGINATION_SLOWDOWN, make_get_request, parse_response, }; -use crate::enterprise::directory_sync::{DirectoryUserDetails, REQUEST_TIMEOUT}; +use super::{DirectoryUserDetails, REQUEST_TIMEOUT}; pub(crate) struct MicrosoftDirectorySync { access_token: Option, diff --git a/crates/defguard_core/src/enterprise/directory_sync/mod.rs b/enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/mod.rs similarity index 89% rename from crates/defguard_core/src/enterprise/directory_sync/mod.rs rename to enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/mod.rs index 1bb93a409..d16025c29 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/mod.rs +++ b/enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/mod.rs @@ -6,33 +6,27 @@ use std::{ use defguard_common::db::{ Id, - models::{Settings, group::Group, user::User}, + models::{Settings, group::Group, settings::OpenIdUsernameHandling, user::User}, }; -use paste::paste; -use reqwest::header::AUTHORIZATION; -use sqlx::{PgConnection, PgPool, error::Error as SqlxError}; -use thiserror::Error; -use tokio::sync::broadcast::Sender; - -use super::{ - db::models::openid_provider::{DirectorySyncTarget, OpenIdProvider}, - ldap::utils::ldap_update_users_state, +use defguard_enterprise_db::models::openid_provider::{ + DirectorySyncTarget, DirectorySyncUserBehavior, OpenIdProvider, }; #[cfg(not(test))] -use crate::enterprise::is_business_license_active; -use crate::{ - enterprise::{ - db::models::openid_provider::DirectorySyncUserBehavior, - handlers::openid_login::prune_username, - ldap::{ - model::ldap_sync_allowed_for_user, - utils::{ldap_add_users_to_groups, ldap_delete_users, ldap_remove_users_from_groups}, - }, +use defguard_enterprise_license::is_business_license_active; +use defguard_enterprise_ldap::{ + model::ldap_sync_allowed_for_user, + utils::{ + ldap_add_users_to_groups, ldap_delete_users, ldap_remove_users_from_groups, + ldap_update_users_state, }, - grpc::GatewayEvent, - handlers::user::check_username, - user_management::{delete_user_and_cleanup_devices, disable_user, sync_allowed_user_devices}, }; +use futures::future::BoxFuture; +use paste::paste; +use reqwest::header::AUTHORIZATION; +use serde::{Deserialize, Serialize}; +use sqlx::{PgConnection, PgPool, error::Error as SqlxError}; +use thiserror::Error; +use tracing::{debug, error, info}; const REQUEST_TIMEOUT: Duration = Duration::from_secs(10); const REQUEST_PAGINATION_SLOWDOWN: Duration = Duration::from_millis(100); @@ -75,6 +69,36 @@ pub enum DirectorySyncError { MultipleUsersFound(String), } +type DisableUserFn = dyn for<'b> Fn( + &'b mut User, + &'b mut PgConnection, + ) -> BoxFuture<'b, Result<(), DirectorySyncError>> + + Send + + Sync + + 'static; + +type DeleteUserFn = dyn for<'b> Fn( + User, + &'b mut PgConnection, + ) -> BoxFuture<'b, Result<(), DirectorySyncError>> + + Send + + Sync + + 'static; + +type SyncAllowedDevicesFn = dyn for<'b> Fn( + &'b User, + &'b mut PgConnection, + ) -> BoxFuture<'b, Result<(), DirectorySyncError>> + + Send + + Sync + + 'static; + +pub struct DirectorySyncContext { + pub disable_user: Box, + pub delete_user_and_cleanup_devices: Box, + pub sync_allowed_user_devices: Box, +} + impl From for DirectorySyncError { fn from(err: reqwest::Error) -> Self { if err.is_decode() { @@ -91,6 +115,69 @@ impl From for DirectorySyncError { } } +fn check_username(username: &str) -> Result<(), DirectorySyncError> { + const MAX_USERNAME_CHARS: usize = 64; + let length = username.len(); + if !(1..MAX_USERNAME_CHARS).contains(&length) { + return Err(DirectorySyncError::UserCreateError(format!( + "Username ({username}) has incorrect length" + ))); + } + + if let Some(first_char) = username.chars().next() { + if !first_char.is_ascii_alphanumeric() { + return Err(DirectorySyncError::UserCreateError( + "Username must not start with a special character".into(), + )); + } + } + + if !username + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') + { + return Err(DirectorySyncError::UserCreateError( + "Username contains invalid characters".into(), + )); + } + + Ok(()) +} + +#[must_use] +fn prune_username(username: &str, handling: OpenIdUsernameHandling) -> String { + let mut result = username.to_string(); + + result = result + .trim_start_matches(|c: char| !c.is_ascii_alphanumeric()) + .to_string(); + + let is_char_valid = |c: char| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_'; + + match handling { + OpenIdUsernameHandling::RemoveForbidden => { + result.retain(&is_char_valid); + } + OpenIdUsernameHandling::ReplaceForbidden => { + result = result + .chars() + .map(|c| if is_char_valid(c) { c } else { '_' }) + .collect(); + } + OpenIdUsernameHandling::PruneEmailDomain => { + if let Some(at_index) = result.find('@') { + result.truncate(at_index); + } + result = result + .chars() + .map(|c| if is_char_valid(c) { c } else { '_' }) + .collect(); + } + } + + result +} + pub mod google; pub mod jumpcloud; pub mod microsoft; @@ -322,7 +409,7 @@ async fn sync_user_groups( directory_sync: &T, user: &User, pool: &PgPool, - wg_tx: &Sender, + context: &DirectorySyncContext, ) -> Result<(), DirectorySyncError> { info!("Syncing groups of user {} with the directory", user.email); let directory_groups = directory_sync.get_user_groups(&user.email).await?; @@ -369,13 +456,13 @@ async fn sync_user_groups( } } - sync_allowed_user_devices(user, &mut transaction, wg_tx) + (context.sync_allowed_user_devices)(user, &mut transaction) .await .map_err(|err| { DirectorySyncError::NetworkUpdateError(format!( - "Failed to sync allowed devices for user {} during directory synchronization: {err}", - user.email - )) + "Failed to sync allowed devices for user {} during directory synchronization: {err}", + user.email + )) })?; transaction.commit().await?; @@ -390,9 +477,7 @@ async fn sync_user_groups( Ok(()) } -pub(crate) async fn test_directory_sync_connection( - pool: &PgPool, -) -> Result<(), DirectorySyncError> { +pub async fn test_directory_sync_connection(pool: &PgPool) -> Result<(), DirectorySyncError> { #[cfg(not(test))] if !is_business_license_active() { debug!("Enterprise is not enabled, skipping testing directory sync connection"); @@ -416,7 +501,7 @@ pub(crate) async fn test_directory_sync_connection( pub async fn sync_user_groups_if_configured( user: &User, pool: &PgPool, - wg_tx: &Sender, + context: &DirectorySyncContext, ) -> Result<(), DirectorySyncError> { #[cfg(not(test))] if !is_business_license_active() { @@ -433,7 +518,7 @@ pub async fn sync_user_groups_if_configured( match DirectorySyncClient::build(pool).await { Ok(mut dir_sync) => { dir_sync.prepare().await?; - sync_user_groups(&dir_sync, user, pool, wg_tx).await?; + sync_user_groups(&dir_sync, user, pool, context).await?; } Err(err) => { error!("Failed to build directory sync client: {err}"); @@ -480,7 +565,7 @@ async fn create_and_add_to_group( async fn sync_all_users_groups( directory_sync: &T, pool: &PgPool, - wg_tx: &Sender, + context: &DirectorySyncContext, all_users: Option<&[DirectoryUser]>, ) -> Result<(), DirectorySyncError> { info!("Syncing all users' groups with the directory, this may take a while..."); @@ -577,7 +662,9 @@ async fn sync_all_users_groups( create_and_add_to_group(&user, group, pool).await?; } - sync_allowed_user_devices(&user, &mut transaction, wg_tx).await.map_err(|err| { + (context.sync_allowed_user_devices)(&user, &mut transaction) + .await + .map_err(|err| { DirectorySyncError::NetworkUpdateError(format!( "Failed to sync allowed devices for user {} during directory synchronization: {err}", user.email @@ -616,7 +703,7 @@ fn is_directory_sync_enabled(provider: Option<&OpenIdProvider>) -> bool { async fn sync_all_users_state( pool: &PgPool, - wg_tx: &Sender, + context: &DirectorySyncContext, all_users: &[DirectoryUser], ) -> Result<(), DirectorySyncError> { info!("Syncing all users' state with the directory, this may take a while..."); @@ -649,7 +736,7 @@ async fn sync_all_users_state( &mut transaction, &inactive_directory_users, &mut modified_users, - wg_tx, + context, ) .await?; @@ -769,7 +856,9 @@ async fn sync_all_users_state( the admin behavior setting is set to disable", user.email ); - disable_user(&mut user, &mut transaction, wg_tx).await.map_err(|err| { + (context.disable_user)(&mut user, &mut transaction) + .await + .map_err(|err| { DirectorySyncError::UserUpdateError(format!( "Failed to disable admin {} during directory synchronization: {err}", user.email @@ -799,7 +888,7 @@ async fn sync_all_users_state( if ldap_sync_allowed_for_user(&user, &mut *transaction).await? { deleted_users.push(user.clone().as_noid()); } - delete_user_and_cleanup_devices(user, &mut transaction, wg_tx) + (context.delete_user_and_cleanup_devices)(user, &mut transaction) .await .map_err(|err| { DirectorySyncError::UserUpdateError(format!( @@ -824,7 +913,9 @@ async fn sync_all_users_state( the user behavior setting is set to disable", user.email ); - disable_user(&mut user, &mut transaction, wg_tx).await.map_err(|err| { + (context.disable_user)(&mut user, &mut transaction) + .await + .map_err(|err| { DirectorySyncError::UserUpdateError(format!( "Failed to disable user {} during directory synchronization: {err}", user.email @@ -846,7 +937,7 @@ async fn sync_all_users_state( if ldap_sync_allowed_for_user(&user, &mut *transaction).await? { deleted_users.push(user.clone().as_noid()); } - delete_user_and_cleanup_devices(user, &mut transaction, wg_tx) + (context.delete_user_and_cleanup_devices)(user, &mut transaction) .await .map_err(|err| { DirectorySyncError::UserUpdateError(format!( @@ -883,7 +974,7 @@ async fn sync_inactive_directory_users( transaction: &mut PgConnection, inactive_directory_users: &[&DirectoryUser], modified_users: &mut Vec>, - wg_tx: &Sender, + context: &DirectorySyncContext, ) -> Result<(), DirectorySyncError> { // find all active Defguard users disabled in directory let disabled_users_emails = inactive_directory_users @@ -908,7 +999,7 @@ async fn sync_inactive_directory_users( "Disabling user {} because they are disabled in the directory", user.email ); - disable_user(&mut user, transaction, wg_tx) + (context.disable_user)(&mut user, transaction) .await .map_err(|err| { DirectorySyncError::UserUpdateError(format!( @@ -970,7 +1061,7 @@ const DIRECTORY_SYNC_INTERVAL: u64 = 60 * 10; /// Used to inform the utility thread how often it should perform the directory sync job. /// See [`run_utility_thread`] for more details. -pub(crate) async fn get_directory_sync_interval(pool: &PgPool) -> u64 { +pub async fn get_directory_sync_interval(pool: &PgPool) -> u64 { if let Ok(Some(provider_settings)) = OpenIdProvider::get_current(pool).await { provider_settings .directory_sync_interval @@ -982,9 +1073,9 @@ pub(crate) async fn get_directory_sync_interval(pool: &PgPool) -> u64 { } // Performs the directory sync job. This function is called by the utility thread. -pub(crate) async fn do_directory_sync( +pub async fn do_directory_sync( pool: &PgPool, - wireguard_tx: &Sender, + context: &DirectorySyncContext, ) -> Result<(), DirectorySyncError> { #[cfg(not(test))] if !is_business_license_active() { @@ -1021,7 +1112,7 @@ pub(crate) async fn do_directory_sync( DirectorySyncTarget::All | DirectorySyncTarget::Users ) { let users = dir_sync.get_all_users().await?; - sync_all_users_state(pool, wireguard_tx, &users).await?; + sync_all_users_state(pool, context, &users).await?; all_users = Some(users); } if matches!( @@ -1042,8 +1133,7 @@ pub(crate) async fn do_directory_sync( } _ => None, // No need to pass all users for other providers, for the time being. }; - sync_all_users_groups(&dir_sync, pool, wireguard_tx, users_to_pass.as_deref()) - .await?; + sync_all_users_groups(&dir_sync, pool, context, users_to_pass.as_deref()).await?; } } Err(err) => { diff --git a/crates/defguard_core/src/enterprise/directory_sync/okta.rs b/enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/okta.rs similarity index 99% rename from crates/defguard_core/src/enterprise/directory_sync/okta.rs rename to enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/okta.rs index 569f4ee47..fbf49b7eb 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/okta.rs +++ b/enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/okta.rs @@ -3,13 +3,15 @@ use std::str::FromStr; use chrono::{DateTime, TimeDelta, Utc}; use jsonwebtoken::{Algorithm, EncodingKey, Header, encode}; use parse_link_header::parse_with_rel; +use serde::{Deserialize, Serialize}; use tokio::time::sleep; +use tracing::{debug, info}; use super::{ DirectoryGroup, DirectorySync, DirectorySyncError, DirectoryUser, REQUEST_PAGINATION_SLOWDOWN, parse_response, }; -use crate::enterprise::directory_sync::make_get_request; +use super::make_get_request; // Okta suggests using the maximum limit of 200 for the number of results per page. // If this is an issue, we would need to add resource pagination. diff --git a/crates/defguard_core/src/enterprise/directory_sync/testprovider.rs b/enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/testprovider.rs similarity index 90% rename from crates/defguard_core/src/enterprise/directory_sync/testprovider.rs rename to enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/testprovider.rs index b73d5abbe..0005aa506 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/testprovider.rs +++ b/enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/testprovider.rs @@ -53,7 +53,7 @@ impl DirectorySync for TestProviderDirectorySync { email: "testuser@email.com".into(), active: true, id: Some("testuser-id".into()), - user_details: Some(crate::enterprise::directory_sync::DirectoryUserDetails { + user_details: Some(super::DirectoryUserDetails { last_name: "User".into(), first_name: "Test".into(), phone_number: None, @@ -63,7 +63,7 @@ impl DirectorySync for TestProviderDirectorySync { email: "testuserdisabled@email.com".into(), active: false, id: Some("testuserdisabled-id".into()), - user_details: Some(crate::enterprise::directory_sync::DirectoryUserDetails { + user_details: Some(super::DirectoryUserDetails { last_name: "UserDisabled".into(), first_name: "Test".into(), phone_number: None, @@ -73,7 +73,7 @@ impl DirectorySync for TestProviderDirectorySync { email: "testuser2@email.com".into(), active: true, id: Some("testuser2-id".into()), - user_details: Some(crate::enterprise::directory_sync::DirectoryUserDetails { + user_details: Some(super::DirectoryUserDetails { last_name: "User2".into(), first_name: "Test".into(), phone_number: None, diff --git a/crates/defguard_core/src/enterprise/directory_sync/tests.rs b/enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/tests.rs similarity index 81% rename from crates/defguard_core/src/enterprise/directory_sync/tests.rs rename to enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/tests.rs index 671a58e3c..87fbe9faa 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/tests.rs +++ b/enterprise/crates/defguard_enterprise_directory_sync/src/directory_sync/tests.rs @@ -15,11 +15,73 @@ mod test { }; use ipnetwork::IpNetwork; use secrecy::ExposeSecret; + use defguard_common::db::models::device::DeviceInfo; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::sync::broadcast; use super::super::*; - use crate::enterprise::db::models::openid_provider::{DirectorySyncTarget, OpenIdProviderKind}; + use defguard_enterprise_db::models::openid_provider::{ + DirectorySyncTarget, OpenIdProviderKind, + }; + + #[derive(Clone, Debug)] + enum TestGatewayEvent { + DeviceDeleted(DeviceInfo), + DeviceCreated(DeviceInfo), + } + + fn build_test_context( + wg_tx: broadcast::Sender, + ) -> DirectorySyncContext { + let disable_tx = wg_tx.clone(); + let delete_tx = wg_tx.clone(); + let sync_tx = wg_tx.clone(); + DirectorySyncContext { + disable_user: Box::new(move |user: &mut User, conn| { + let disable_tx = disable_tx.clone(); + Box::pin(async move { + user.is_active = false; + user.save(&mut *conn).await?; + user.logout_all_sessions(&mut *conn).await?; + let devices = user.devices(&mut *conn).await?; + for device in devices { + let info = DeviceInfo::from_device(&mut *conn, device) + .await + .map_err(|err| DirectorySyncError::UserUpdateError(err.to_string()))?; + let _ = disable_tx.send(TestGatewayEvent::DeviceDeleted(info)); + } + Ok(()) + }) + }), + delete_user_and_cleanup_devices: Box::new(move |user: User, conn| { + let delete_tx = delete_tx.clone(); + Box::pin(async move { + let devices = user.devices(&mut *conn).await?; + for device in devices { + let info = DeviceInfo::from_device(&mut *conn, device) + .await + .map_err(|err| DirectorySyncError::UserUpdateError(err.to_string()))?; + let _ = delete_tx.send(TestGatewayEvent::DeviceDeleted(info)); + } + user.delete(&mut *conn).await?; + Ok(()) + }) + }), + sync_allowed_user_devices: Box::new(move |user: &User, conn| { + let sync_tx = sync_tx.clone(); + Box::pin(async move { + let devices = user.devices(&mut *conn).await?; + for device in devices { + let info = DeviceInfo::from_device(&mut *conn, device) + .await + .map_err(|err| DirectorySyncError::UserUpdateError(err.to_string()))?; + let _ = sync_tx.send(TestGatewayEvent::DeviceCreated(info)); + } + Ok(()) + }) + }), + } + } async fn get_test_network(pool: &PgPool) -> WireguardNetwork { WireguardNetwork::find_by_name(pool, "test") @@ -140,7 +202,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Keep, @@ -161,7 +223,8 @@ mod test { assert!(get_test_user(&pool, "testuser").await.is_some()); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + let context = build_test_context(wg_tx.clone()); + sync_all_users_state(&pool, &context, &all_users) .await .unwrap(); @@ -180,7 +243,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -202,7 +265,8 @@ mod test { assert!(get_test_user(&pool, "testuser").await.is_some()); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + let context = build_test_context(wg_tx.clone()); + sync_all_users_state(&pool, &context, &all_users) .await .unwrap(); @@ -211,7 +275,7 @@ mod test { assert!(get_test_user(&pool, "testuser").await.is_some()); let event = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event { + if let Ok(TestGatewayEvent::DeviceDeleted(dev)) = event { assert_eq!(dev.device.user_id, user2.id); } else { panic!("Expected a DeviceDeleted event"); @@ -223,7 +287,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); User::init_admin_user(&pool, config.default_admin_password.expose_secret()) .await .unwrap(); @@ -250,7 +314,8 @@ mod test { assert!(get_test_user(&pool, "user2").await.is_some()); assert!(get_test_user(&pool, "testuser").await.is_some()); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + let context = build_test_context(wg_tx.clone()); + sync_all_users_state(&pool, &context, &all_users) .await .unwrap(); @@ -267,7 +332,7 @@ mod test { // Check that we received a device deleted event for whichever admin was removed let event = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event { + if let Ok(TestGatewayEvent::DeviceDeleted(dev)) = event { assert!(dev.device.user_id == user1.id || dev.device.user_id == user3.id); } else { panic!("Expected a DeviceDeleted event"); @@ -280,7 +345,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -306,7 +371,8 @@ mod test { assert!(get_test_user(&pool, "user2").await.is_some()); assert!(get_test_user(&pool, "testuser").await.is_some()); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + let context = build_test_context(wg_tx.clone()); + sync_all_users_state(&pool, &context, &all_users) .await .unwrap(); @@ -323,7 +389,7 @@ mod test { // Check for device deletion events let event1 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event1 { + if let Ok(TestGatewayEvent::DeviceDeleted(dev)) = event1 { assert!( dev.device.user_id == user1.id || dev.device.user_id == user2.id @@ -334,7 +400,7 @@ mod test { } let event2 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event2 { + if let Ok(TestGatewayEvent::DeviceDeleted(dev)) = event2 { assert!( dev.device.user_id == user1.id || dev.device.user_id == user2.id @@ -351,7 +417,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Disable, @@ -393,20 +459,21 @@ mod test { assert!(testuserdisabled.is_active); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + let context = build_test_context(wg_tx.clone()); + sync_all_users_state(&pool, &context, &all_users) .await .unwrap(); // Check for device disconnection events let event1 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event1 { + if let Ok(TestGatewayEvent::DeviceDeleted(dev)) = event1 { assert!(dev.device.user_id == user2.id || dev.device.user_id == testuserdisabled.id); } else { panic!("Expected a DeviceDisconnected event"); } let event2 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event2 { + if let Ok(TestGatewayEvent::DeviceDeleted(dev)) = event2 { assert!(dev.device.user_id == user2.id || dev.device.user_id == testuserdisabled.id); } else { panic!("Expected a DeviceDisconnected event"); @@ -434,7 +501,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); // Added mut wg_rx + let (wg_tx, mut wg_rx) = broadcast::channel::(16); // Added mut wg_rx make_test_provider( &pool, DirectorySyncUserBehavior::Keep, @@ -466,13 +533,14 @@ mod test { assert!(testuserdisabled.is_active); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + let context = build_test_context(wg_tx.clone()); + sync_all_users_state(&pool, &context, &all_users) .await .unwrap(); // Check for device disconnection events let event1 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event1 { + if let Ok(TestGatewayEvent::DeviceDeleted(dev)) = event1 { assert!( dev.device.user_id == user1.id || dev.device.user_id == user3.id @@ -483,7 +551,7 @@ mod test { } let event2 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event2 { + if let Ok(TestGatewayEvent::DeviceDeleted(dev)) = event2 { assert!( dev.device.user_id == user1.id || dev.device.user_id == user3.id @@ -512,7 +580,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -528,7 +596,8 @@ mod test { make_test_user_and_device("testuser2", &pool).await; make_test_user_and_device("testuserdisabled", &pool).await; let all_users = client.get_all_users().await.unwrap(); - sync_all_users_groups(&client, &pool, &wg_tx, Some(&all_users)) + let context = build_test_context(wg_tx.clone()); + sync_all_users_groups(&client, &pool, &context, Some(&all_users)) .await .unwrap(); @@ -569,7 +638,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -583,7 +652,8 @@ mod test { let user = make_test_user_and_device("testuser", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - sync_user_groups_if_configured(&user, &pool, &wg_tx) + let context = build_test_context(wg_tx.clone()); + sync_user_groups_if_configured(&user, &pool, &context) .await .unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); @@ -598,7 +668,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -612,7 +682,8 @@ mod test { let user = make_test_user_and_device("testuser", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + let context = build_test_context(wg_tx.clone()); + do_directory_sync(&pool, &context).await.unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); } @@ -623,7 +694,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -649,24 +720,26 @@ mod test { let user2_pre_sync = make_test_user_and_device("user2", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + let context = build_test_context(wg_tx.clone()); + do_directory_sync(&pool, &context).await.unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 3); let user2 = get_test_user(&pool, "user2").await; assert!(user2.is_none()); let mut transaction = pool.begin().await.unwrap(); - sync_allowed_user_devices(&user, &mut transaction, &wg_tx) + let context = build_test_context(wg_tx.clone()); + (context.sync_allowed_user_devices)(&user, &mut transaction) .await .unwrap(); transaction.commit().await.unwrap(); let event = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event { + if let Ok(TestGatewayEvent::DeviceDeleted(dev)) = event { assert_eq!(dev.device.user_id, user2_pre_sync.id); } else { panic!("Expected a DeviceDeleted event"); } let event = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceCreated(dev)) = event { + if let Ok(TestGatewayEvent::DeviceCreated(dev)) = event { assert_eq!(dev.device.user_id, user.id); } else { panic!("Expected a DeviceDeleted event"); @@ -679,7 +752,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -694,7 +767,8 @@ mod test { make_test_user_and_device("user2", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + let context = build_test_context(wg_tx.clone()); + do_directory_sync(&pool, &context).await.unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 3); let user2 = get_test_user(&pool, "user2").await; @@ -707,7 +781,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -727,7 +801,8 @@ mod test { assert_eq!(user_groups.len(), 1); assert!(user.is_admin(&pool).await.unwrap()); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + let context = build_test_context(wg_tx.clone()); + do_directory_sync(&pool, &context).await.unwrap(); // He should still be an admin as it's the last one assert!(user.is_admin(&pool).await.unwrap()); @@ -736,7 +811,8 @@ mod test { let user2 = make_test_user_and_device("testuser2", &pool).await; user2.add_to_group(&pool, &admin_grp).await.unwrap(); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + let context = build_test_context(wg_tx.clone()); + do_directory_sync(&pool, &context).await.unwrap(); let admins = User::find_admins(&pool).await.unwrap(); // There should be only one admin left @@ -745,7 +821,8 @@ mod test { let defguard_user = make_test_user_and_device("defguard", &pool).await; make_admin(&pool, &defguard_user).await; - do_directory_sync(&pool, &wg_tx).await.unwrap(); + let context = build_test_context(wg_tx.clone()); + do_directory_sync(&pool, &context).await.unwrap(); } #[sqlx::test] @@ -754,7 +831,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -771,7 +848,8 @@ mod test { make_admin(&pool, &defguard_user).await; assert!(defguard_user.is_admin(&pool).await.unwrap()); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + let context = build_test_context(wg_tx.clone()); + do_directory_sync(&pool, &context).await.unwrap(); // The user should still be an admin assert!(defguard_user.is_admin(&pool).await.unwrap()); @@ -783,7 +861,8 @@ mod test { .await .unwrap(); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + let context = build_test_context(wg_tx.clone()); + do_directory_sync(&pool, &context).await.unwrap(); let user = User::find_by_username(&pool, "defguard").await.unwrap(); assert!(user.is_none()); } @@ -794,7 +873,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); // disable prefetching users make_test_provider( @@ -812,7 +891,8 @@ mod test { let defguard_users = User::all(&pool).await.unwrap(); assert!(defguard_users.is_empty()); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + let context = build_test_context(wg_tx.clone()); + do_directory_sync(&pool, &context).await.unwrap(); // no users in Defguard after sync let defguard_users = User::all(&pool).await.unwrap(); @@ -828,7 +908,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); // enable prefetching users make_test_provider( @@ -846,7 +926,8 @@ mod test { let defguard_users = User::all(&pool).await.unwrap(); assert!(defguard_users.is_empty()); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + let context = build_test_context(wg_tx.clone()); + do_directory_sync(&pool, &context).await.unwrap(); // all active directory users were synced let defguard_users = User::all(&pool).await.unwrap(); diff --git a/enterprise/crates/defguard_enterprise_directory_sync/src/lib.rs b/enterprise/crates/defguard_enterprise_directory_sync/src/lib.rs new file mode 100644 index 000000000..31be723a2 --- /dev/null +++ b/enterprise/crates/defguard_enterprise_directory_sync/src/lib.rs @@ -0,0 +1,3 @@ +pub mod directory_sync; + +pub use directory_sync::*; diff --git a/enterprise/crates/defguard_enterprise_firewall/Cargo.toml b/enterprise/crates/defguard_enterprise_firewall/Cargo.toml new file mode 100644 index 000000000..b0222442f --- /dev/null +++ b/enterprise/crates/defguard_enterprise_firewall/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "defguard_enterprise_firewall" +version = "0.0.0" +edition.workspace = true +license-file.workspace = true +homepage.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +defguard_common = { workspace = true } +defguard_enterprise_db = { workspace = true } +defguard_enterprise_license = { workspace = true } +defguard_proto = { workspace = true } +ipnetwork = { workspace = true } +sqlx = { workspace = true } +thiserror = { workspace = true } +tracing = { workspace = true } + +[dev-dependencies] +chrono = { workspace = true } +rand = { workspace = true } diff --git a/crates/defguard_core/src/enterprise/firewall/mod.rs b/enterprise/crates/defguard_enterprise_firewall/src/firewall/mod.rs similarity index 99% rename from crates/defguard_core/src/enterprise/firewall/mod.rs rename to enterprise/crates/defguard_enterprise_firewall/src/firewall/mod.rs index afe826623..9b5cdc34d 100644 --- a/crates/defguard_core/src/enterprise/firewall/mod.rs +++ b/enterprise/crates/defguard_enterprise_firewall/src/firewall/mod.rs @@ -14,18 +14,17 @@ use defguard_proto::enterprise::firewall::{ }; use ipnetwork::IpNetwork; use sqlx::{Error as SqlxError, PgConnection, query_as, query_scalar}; +use tracing::{debug, info}; -use super::{ - db::models::acl::{ - AclAliasDestinationRange, AclRule, AclRuleDestinationRange, AclRuleInfo, PortRange, - Protocol, +use super::utils::merge_ranges; +use defguard_enterprise_db::models::{ + acl::{ + AclAlias, AclAliasDestinationRange, AclRule, AclRuleDestinationRange, AclRuleInfo, + PortRange, Protocol, }, - utils::merge_ranges, -}; -use crate::enterprise::{ - db::models::{acl::AclAlias, snat::UserSnatBinding}, - is_business_license_active, + snat::UserSnatBinding, }; +use defguard_enterprise_license::is_business_license_active; #[derive(Debug, thiserror::Error)] pub enum FirewallError { @@ -205,6 +204,7 @@ async fn get_source_ips( } /// Generates firewall rules for destination manually specified in ACL rule. +#[allow(clippy::too_many_arguments)] async fn get_manual_destination_rules( conn: &mut PgConnection, rule_id: Id, diff --git a/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/all_locations.rs similarity index 97% rename from crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs rename to enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/all_locations.rs index fecc62ee7..4718e25b4 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs +++ b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/all_locations.rs @@ -5,13 +5,11 @@ use ipnetwork::IpNetwork; use rand::thread_rng; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use crate::enterprise::{ - db::models::acl::{AclRule, AclRuleNetwork, RuleState}, - firewall::{ - tests::{create_test_users_and_devices, set_test_license_business}, - try_get_location_firewall_config, - }, +use crate::firewall::{ + tests::{create_test_users_and_devices, set_test_license_business}, + try_get_location_firewall_config, }; +use defguard_enterprise_db::models::acl::{AclRule, AclRuleNetwork, RuleState}; #[sqlx::test] async fn test_acl_rules_all_locations_ipv4(_: PgPoolOptions, options: PgConnectOptions) { diff --git a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/destination.rs similarity index 72% rename from crates/defguard_core/src/enterprise/firewall/tests/destination.rs rename to enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/destination.rs index fd8e62bab..6920db8d0 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs +++ b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/destination.rs @@ -1,10 +1,10 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use defguard_proto::enterprise::firewall::{IpAddress, IpRange, ip_address::Address}; +use defguard_proto::enterprise::firewall::{ip_address::Address, IpAddress, IpRange}; -use crate::enterprise::{ - db::models::acl::AclRuleDestinationRange, firewall::process_destination_addrs, -}; +use crate::firewall::process_destination_addrs; +use crate::firewall::tests::default_destination_range_with_values; +use defguard_common::db::Id; #[test] fn test_process_destination_addrs_v4() { @@ -17,16 +17,16 @@ fn test_process_destination_addrs_v4() { ]; let destination_ranges = [ - AclRuleDestinationRange { - start: IpAddr::V4(Ipv4Addr::new(10, 0, 3, 255)), - end: IpAddr::V4(Ipv4Addr::new(10, 0, 4, 0)), - ..Default::default() - }, - AclRuleDestinationRange { - start: IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), // Should be filtered out - end: IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 100)), - ..Default::default() - }, + default_destination_range_with_values( + Id::default(), + IpAddr::V4(Ipv4Addr::new(10, 0, 3, 255)), + IpAddr::V4(Ipv4Addr::new(10, 0, 4, 0)), + ), + default_destination_range_with_values( + Id::default(), + IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), + IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 100)), + ), ]; let destination_addrs = process_destination_addrs(&destination_ips, &destination_ranges); @@ -72,16 +72,16 @@ fn test_process_destination_addrs_v6() { ]; let destination_ranges = vec![ - AclRuleDestinationRange { - start: IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 4, 0, 0, 0, 0, 1)), - end: IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 4, 0, 0, 0, 0, 3)), - ..Default::default() - }, - AclRuleDestinationRange { - start: IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), // Should be filtered out - end: IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), - ..Default::default() - }, + default_destination_range_with_values( + Id::default(), + IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 4, 0, 0, 0, 0, 1)), + IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 4, 0, 0, 0, 0, 3)), + ), + default_destination_range_with_values( + Id::default(), + IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), + IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), + ), ]; let destination_addrs = process_destination_addrs(&destination_ips, &destination_ranges); diff --git a/crates/defguard_core/src/enterprise/firewall/tests/disabled_rules.rs b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/disabled_rules.rs similarity index 96% rename from crates/defguard_core/src/enterprise/firewall/tests/disabled_rules.rs rename to enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/disabled_rules.rs index b4d6ddfee..36d6ead83 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/disabled_rules.rs +++ b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/disabled_rules.rs @@ -5,13 +5,11 @@ use ipnetwork::IpNetwork; use rand::thread_rng; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use crate::enterprise::{ - db::models::acl::{AclRule, AclRuleNetwork, RuleState}, - firewall::{ - tests::{create_test_users_and_devices, set_test_license_business}, - try_get_location_firewall_config, - }, +use crate::firewall::{ + tests::{create_test_users_and_devices, set_test_license_business}, + try_get_location_firewall_config, }; +use defguard_enterprise_db::models::acl::{AclRule, AclRuleNetwork, RuleState}; #[sqlx::test] async fn test_disabled_acl_rules_ipv4(_: PgPoolOptions, options: PgConnectOptions) { diff --git a/crates/defguard_core/src/enterprise/firewall/tests/expired_rules.rs b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/expired_rules.rs similarity index 97% rename from crates/defguard_core/src/enterprise/firewall/tests/expired_rules.rs rename to enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/expired_rules.rs index aa4c21a7d..f7321b73a 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/expired_rules.rs +++ b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/expired_rules.rs @@ -5,10 +5,8 @@ use defguard_common::db::{NoId, models::WireguardNetwork, setup_pool}; use ipnetwork::IpNetwork; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use crate::enterprise::{ - db::models::acl::{AclRule, AclRuleNetwork, RuleState}, - firewall::{tests::set_test_license_business, try_get_location_firewall_config}, -}; +use crate::firewall::{tests::set_test_license_business, try_get_location_firewall_config}; +use defguard_enterprise_db::models::acl::{AclRule, AclRuleNetwork, RuleState}; #[sqlx::test] async fn test_expired_acl_rules_ipv4(_: PgPoolOptions, options: PgConnectOptions) { diff --git a/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/gh1868.rs similarity index 98% rename from crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs rename to enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/gh1868.rs index c13ad6644..b3e881e62 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs +++ b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/gh1868.rs @@ -13,10 +13,8 @@ use sqlx::{ postgres::{PgConnectOptions, PgPoolOptions}, }; -use crate::enterprise::{ - db::models::acl::{AclRule, RuleState}, - firewall::{tests::set_test_license_business, try_get_location_firewall_config}, -}; +use crate::firewall::{tests::set_test_license_business, try_get_location_firewall_config}; +use defguard_enterprise_db::models::acl::{AclRule, RuleState}; async fn setup_user_and_device( rng: &mut ThreadRng, diff --git a/crates/defguard_core/src/enterprise/firewall/tests/ip_address_handling.rs b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/ip_address_handling.rs similarity index 98% rename from crates/defguard_core/src/enterprise/firewall/tests/ip_address_handling.rs rename to enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/ip_address_handling.rs index 9b4650ad1..966ec2246 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/ip_address_handling.rs +++ b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/ip_address_handling.rs @@ -1,17 +1,15 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use defguard_proto::enterprise::firewall::{ - IpAddress, IpRange, Port, PortRange as PortRangeProto, ip_address::Address, - port::Port as PortInner, + ip_address::Address, port::Port as PortInner, IpAddress, IpRange, Port, + PortRange as PortRangeProto, }; use ipnetwork::Ipv6Network; -use crate::enterprise::{ - db::models::acl::PortRange, - firewall::{ - find_largest_subnet_in_range, get_last_ip_in_v6_subnet, merge_addrs, merge_port_ranges, - }, +use crate::firewall::{ + find_largest_subnet_in_range, get_last_ip_in_v6_subnet, merge_addrs, merge_port_ranges, }; +use defguard_enterprise_db::models::acl::PortRange; #[test] fn test_merge_v4_addrs() { diff --git a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/mod.rs similarity index 99% rename from crates/defguard_core/src/enterprise/firewall/tests/mod.rs rename to enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/mod.rs index 40e6d233b..a3e4ae0bb 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs +++ b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/mod.rs @@ -20,14 +20,12 @@ use sqlx::{ query, }; -use crate::enterprise::{ - db::models::acl::{ - AclAlias, AclRule, AclRuleAlias, AclRuleDestinationRange, AclRuleDevice, AclRuleGroup, - AclRuleInfo, AclRuleNetwork, AclRuleUser, AliasKind, PortRange, RuleState, - }, - firewall::try_get_location_firewall_config, - license::{License, LicenseTier, set_cached_license}, +use crate::firewall::try_get_location_firewall_config; +use defguard_enterprise_db::models::acl::{ + AclAlias, AclRule, AclRuleAlias, AclRuleDestinationRange, AclRuleDevice, AclRuleGroup, + AclRuleInfo, AclRuleNetwork, AclRuleUser, AliasKind, PortRange, RuleState, }; +use defguard_enterprise_license::{License, LicenseTier, set_cached_license}; mod all_locations; mod destination; @@ -38,14 +36,16 @@ mod ip_address_handling; mod source; mod unapplied_rules; -impl Default for AclRuleDestinationRange { - fn default() -> Self { - Self { - id: Id::default(), - rule_id: Id::default(), - start: IpAddr::V4(Ipv4Addr::UNSPECIFIED), - end: IpAddr::V4(Ipv4Addr::UNSPECIFIED), - } +fn default_destination_range_with_values( + rule_id: Id, + start: IpAddr, + end: IpAddr, +) -> AclRuleDestinationRange { + AclRuleDestinationRange { + id: Id::default(), + rule_id, + start, + end, } } @@ -139,6 +139,7 @@ async fn create_test_users_and_devices( } } +#[allow(clippy::too_many_arguments)] async fn create_acl_rule( pool: &PgPool, rule: AclRule, diff --git a/crates/defguard_core/src/enterprise/firewall/tests/source.rs b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/source.rs similarity index 97% rename from crates/defguard_core/src/enterprise/firewall/tests/source.rs rename to enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/source.rs index 063e18b37..473567b82 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/source.rs +++ b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/source.rs @@ -1,9 +1,9 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use defguard_proto::enterprise::firewall::{IpAddress, IpVersion, ip_address::Address}; +use defguard_proto::enterprise::firewall::{ip_address::Address, IpAddress, IpVersion}; use rand::thread_rng; -use crate::enterprise::firewall::{ +use crate::firewall::{ get_source_addrs, get_source_network_devices, get_source_users, tests::{random_network_device_with_id, random_user_with_id}, }; diff --git a/crates/defguard_core/src/enterprise/firewall/tests/unapplied_rules.rs b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/unapplied_rules.rs similarity index 96% rename from crates/defguard_core/src/enterprise/firewall/tests/unapplied_rules.rs rename to enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/unapplied_rules.rs index 10bf3393d..54f2e218d 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/unapplied_rules.rs +++ b/enterprise/crates/defguard_enterprise_firewall/src/firewall/tests/unapplied_rules.rs @@ -5,13 +5,11 @@ use ipnetwork::IpNetwork; use rand::thread_rng; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use crate::enterprise::{ - db::models::acl::{AclRule, AclRuleNetwork, RuleState}, - firewall::{ - tests::{create_test_users_and_devices, set_test_license_business}, - try_get_location_firewall_config, - }, +use crate::firewall::{ + tests::{create_test_users_and_devices, set_test_license_business}, + try_get_location_firewall_config, }; +use defguard_enterprise_db::models::acl::{AclRule, AclRuleNetwork, RuleState}; #[sqlx::test] async fn test_unapplied_acl_rules_ipv4(_: PgPoolOptions, options: PgConnectOptions) { diff --git a/enterprise/crates/defguard_enterprise_firewall/src/lib.rs b/enterprise/crates/defguard_enterprise_firewall/src/lib.rs new file mode 100644 index 000000000..5b67e7868 --- /dev/null +++ b/enterprise/crates/defguard_enterprise_firewall/src/lib.rs @@ -0,0 +1,4 @@ +mod utils; +pub mod firewall; + +pub use firewall::*; diff --git a/crates/defguard_core/src/enterprise/utils.rs b/enterprise/crates/defguard_enterprise_firewall/src/utils.rs similarity index 100% rename from crates/defguard_core/src/enterprise/utils.rs rename to enterprise/crates/defguard_enterprise_firewall/src/utils.rs diff --git a/enterprise/crates/defguard_enterprise_ldap/Cargo.toml b/enterprise/crates/defguard_enterprise_ldap/Cargo.toml new file mode 100644 index 000000000..c7d157e35 --- /dev/null +++ b/enterprise/crates/defguard_enterprise_ldap/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "defguard_enterprise_ldap" +version = "0.0.0" +edition.workspace = true +license-file.workspace = true +homepage.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +defguard_common = { workspace = true } +defguard_enterprise_license = { workspace = true } +base64 = { workspace = true } +ldap3 = { workspace = true } +md4 = { workspace = true } +rand = { workspace = true } +sha1 = { version = "0.10", package = "sha-1" } +sqlx = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } diff --git a/crates/defguard_core/src/enterprise/ldap/client.rs b/enterprise/crates/defguard_enterprise_ldap/src/ldap/client.rs similarity index 99% rename from crates/defguard_core/src/enterprise/ldap/client.rs rename to enterprise/crates/defguard_enterprise_ldap/src/ldap/client.rs index 35567bfd6..b46213279 100644 --- a/crates/defguard_core/src/enterprise/ldap/client.rs +++ b/enterprise/crates/defguard_enterprise_ldap/src/ldap/client.rs @@ -11,10 +11,11 @@ use ldap3::{ }; use super::error::LdapError; -use crate::enterprise::ldap::model::extract_rdn_value; +use crate::ldap::model::extract_rdn_value; +use tracing::{debug, info, warn}; impl super::LDAPConnection { - pub(crate) async fn create() -> Result { + pub async fn create() -> Result { let settings = Settings::get_current_settings(); let config = super::LDAPConfig::try_from(settings.clone())?; let url = settings.ldap_url.ok_or(LdapError::MissingSettings( diff --git a/crates/defguard_core/src/enterprise/ldap/error.rs b/enterprise/crates/defguard_enterprise_ldap/src/ldap/error.rs similarity index 100% rename from crates/defguard_core/src/enterprise/ldap/error.rs rename to enterprise/crates/defguard_enterprise_ldap/src/ldap/error.rs diff --git a/crates/defguard_core/src/enterprise/ldap/hash.rs b/enterprise/crates/defguard_enterprise_ldap/src/ldap/hash.rs similarity index 100% rename from crates/defguard_core/src/enterprise/ldap/hash.rs rename to enterprise/crates/defguard_enterprise_ldap/src/ldap/hash.rs diff --git a/crates/defguard_core/src/enterprise/ldap/mod.rs b/enterprise/crates/defguard_enterprise_ldap/src/ldap/mod.rs similarity index 98% rename from crates/defguard_core/src/enterprise/ldap/mod.rs rename to enterprise/crates/defguard_enterprise_ldap/src/ldap/mod.rs index fde3c1285..577921939 100644 --- a/crates/defguard_core/src/enterprise/ldap/mod.rs +++ b/enterprise/crates/defguard_enterprise_ldap/src/ldap/mod.rs @@ -15,16 +15,14 @@ use model::UserObjectClass; use rand::Rng; use sqlx::PgPool; use sync::{get_ldap_sync_status, is_ldap_desynced, set_ldap_sync_status}; +use tracing::{debug, info, warn}; use self::error::LdapError; -use crate::enterprise::{ - is_business_license_active, - ldap::model::{ - extract_dn_path, ldap_sync_allowed_for_user, user_as_ldap_attrs, user_as_ldap_mod, - user_from_searchentry, - }, - limits::update_counts, +use crate::ldap::model::{ + extract_dn_path, ldap_sync_allowed_for_user, user_as_ldap_attrs, user_as_ldap_mod, + user_from_searchentry, }; +use defguard_enterprise_license::{is_business_license_active, update_counts}; #[cfg(not(test))] pub mod client; @@ -38,13 +36,13 @@ pub mod utils; #[cfg(test)] fn set_test_license_business() { - use crate::enterprise::license::set_cached_license; + use defguard_enterprise_license::set_cached_license; - let license = crate::enterprise::license::License { + let license = defguard_enterprise_license::License { customer_id: "0c4dcb5400544d47ad8617fcdf2704cb".into(), limits: None, subscription: false, - tier: crate::enterprise::license::LicenseTier::Enterprise, + tier: defguard_enterprise_license::LicenseTier::Enterprise, valid_until: None, version_date_limit: None, }; @@ -55,7 +53,7 @@ fn set_test_license_business() { /// /// This function may trigger either full and incremental sync based on the current sync status. /// Sets LDAP sync status to OutOfSync if any errors occur during the process. -pub(crate) async fn do_ldap_sync(pool: &PgPool) -> Result<(), LdapError> { +pub async fn do_ldap_sync(pool: &PgPool) -> Result<(), LdapError> { debug!("Starting LDAP sync, if enabled"); let mut settings = Settings::get_current_settings(); diff --git a/crates/defguard_core/src/enterprise/ldap/model.rs b/enterprise/crates/defguard_enterprise_ldap/src/ldap/model.rs similarity index 93% rename from crates/defguard_core/src/enterprise/ldap/model.rs rename to enterprise/crates/defguard_enterprise_ldap/src/ldap/model.rs index a99d5709e..6ac5d0d91 100644 --- a/crates/defguard_core/src/enterprise/ldap/model.rs +++ b/enterprise/crates/defguard_enterprise_ldap/src/ldap/model.rs @@ -8,7 +8,10 @@ use ldap3::{Mod, SearchEntry}; use sqlx::{Error as SqlxError, PgExecutor}; use super::{LDAPConfig, error::LdapError}; -use crate::{handlers::user::check_username, hashset}; +use crate::hashset; +use tracing::{debug, warn}; + +const MAX_USERNAME_CHARS: usize = 64; pub(crate) enum UserObjectClass { SambaSamAccount, @@ -17,6 +20,28 @@ pub(crate) enum UserObjectClass { User, } +fn check_username(username: &str) -> Result<(), LdapError> { + let length = username.len(); + if !(1..MAX_USERNAME_CHARS).contains(&length) { + return Err(LdapError::InvalidUsername(username.to_string())); + } + + if let Some(first_char) = username.chars().next() { + if !first_char.is_ascii_alphanumeric() { + return Err(LdapError::InvalidUsername(username.to_string())); + } + } + + if !username + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') + { + return Err(LdapError::InvalidUsername(username.to_string())); + } + + Ok(()) +} + impl<'a> From<&'a UserObjectClass> for &'static str { fn from(obj_class: &'a UserObjectClass) -> &'static str { match obj_class { @@ -244,7 +269,7 @@ pub fn user_as_ldap_attrs<'a, I>( } /// Updates the LDAP RDN value of the user in Defguard, if Defguard uses the usernames as RDN. -pub(crate) fn maybe_update_rdn(user: &mut User) { +pub fn maybe_update_rdn(user: &mut User) { debug!("Updating RDN for user {} in Defguard", user.username); let settings = Settings::get_current_settings(); if settings.ldap_using_username_as_rdn() { @@ -259,7 +284,7 @@ pub(crate) fn maybe_update_rdn(user: &mut User) { /// - he is in a group that is allowed to be synced or no such groups are configured /// - he is active (not disabled) /// - he is enrolled -pub(crate) async fn ldap_sync_allowed_for_user<'e, E>( +pub async fn ldap_sync_allowed_for_user<'e, E>( user: &User, executor: E, ) -> Result diff --git a/crates/defguard_core/src/enterprise/ldap/sync.rs b/enterprise/crates/defguard_enterprise_ldap/src/ldap/sync.rs similarity index 99% rename from crates/defguard_core/src/enterprise/ldap/sync.rs rename to enterprise/crates/defguard_enterprise_ldap/src/ldap/sync.rs index 45416ec14..5bf3602a4 100644 --- a/crates/defguard_core/src/enterprise/ldap/sync.rs +++ b/enterprise/crates/defguard_enterprise_ldap/src/ldap/sync.rs @@ -65,8 +65,9 @@ use defguard_common::db::{ use sqlx::{PgConnection, PgPool}; use super::{LDAPConfig, error::LdapError}; +use tracing::{debug, trace, warn}; use crate::{ - enterprise::ldap::model::{ + ldap::model::{ get_users_without_ldap_path, ldap_sync_allowed_for_user, update_from_ldap_user, user_from_searchentry, }, diff --git a/crates/defguard_core/src/enterprise/ldap/test_client.rs b/enterprise/crates/defguard_enterprise_ldap/src/ldap/test_client.rs similarity index 99% rename from crates/defguard_core/src/enterprise/ldap/test_client.rs rename to enterprise/crates/defguard_enterprise_ldap/src/ldap/test_client.rs index 907a6bd4d..6721d8ba3 100644 --- a/crates/defguard_core/src/enterprise/ldap/test_client.rs +++ b/enterprise/crates/defguard_enterprise_ldap/src/ldap/test_client.rs @@ -8,7 +8,7 @@ use defguard_common::db::models::{User, group::Group}; use ldap3::{Mod, SearchEntry}; use super::error::LdapError; -use crate::enterprise::ldap::model::{extract_rdn_value, user_as_ldap_attrs}; +use crate::ldap::model::{extract_rdn_value, user_as_ldap_attrs}; /// Extract attribute value from LDAP filter /// diff --git a/crates/defguard_core/src/enterprise/ldap/tests.rs b/enterprise/crates/defguard_enterprise_ldap/src/ldap/tests.rs similarity index 99% rename from crates/defguard_core/src/enterprise/ldap/tests.rs rename to enterprise/crates/defguard_enterprise_ldap/src/ldap/tests.rs index d8b01b4c2..c23f0ae86 100644 --- a/crates/defguard_core/src/enterprise/ldap/tests.rs +++ b/enterprise/crates/defguard_enterprise_ldap/src/ldap/tests.rs @@ -5,7 +5,7 @@ use ldap3::SearchEntry; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use super::*; -use crate::enterprise::ldap::{ +use crate::ldap::{ model::{extract_rdn_value, get_users_without_ldap_path, user_from_searchentry}, sync::{ Authority, compute_group_sync_changes, compute_user_sync_changes, diff --git a/crates/defguard_core/src/enterprise/ldap/utils.rs b/enterprise/crates/defguard_enterprise_ldap/src/ldap/utils.rs similarity index 93% rename from crates/defguard_core/src/enterprise/ldap/utils.rs rename to enterprise/crates/defguard_enterprise_ldap/src/ldap/utils.rs index 2ae32db26..431253ef7 100644 --- a/crates/defguard_core/src/enterprise/ldap/utils.rs +++ b/enterprise/crates/defguard_enterprise_ldap/src/ldap/utils.rs @@ -10,12 +10,13 @@ use defguard_common::db::{ use sqlx::PgPool; use super::{LDAPConnection, error::LdapError}; -use crate::enterprise::ldap::{model::ldap_sync_allowed_for_user, with_ldap_status}; +use tracing::{debug, info, warn}; +use crate::ldap::{model::ldap_sync_allowed_for_user, with_ldap_status}; /// Retrieves a user from LDAP if they are in the configured LDAP sync groups. /// /// Creates a new user in Defguard if they do not exist and marks them as coming from LDAP. -pub(crate) async fn login_through_ldap( +pub async fn login_through_ldap( pool: &PgPool, username: &str, password: &str, @@ -63,7 +64,7 @@ pub async fn ldap_update_user_state(user: &mut User, pool: &PgPool) { } /// See the [`LDAPConnection::update_users_state`] function for details. -pub(crate) async fn ldap_update_users_state(users: Vec<&mut User>, pool: &PgPool) { +pub async fn ldap_update_users_state(users: Vec<&mut User>, pool: &PgPool) { let _ = Box::pin(with_ldap_status(pool, async { debug!("Updating users state in LDAP"); let mut ldap_connection = LDAPConnection::create().await?; @@ -123,7 +124,7 @@ pub async fn ldap_add_user(user: &mut User, password: Option<&str>, pool: &P /// Warning: This function does not check if the user is allowed to be synced to LDAP. You must do /// that manually before calling this function. For example, by calling the /// [`User::ldap_sync_allowed`] method on the user. -pub(crate) async fn ldap_handle_user_modify( +pub async fn ldap_handle_user_modify( old_username: &str, current_user: &mut User, pool: &PgPool, @@ -157,7 +158,7 @@ pub(crate) async fn ldap_handle_user_modify( /// For example, by calling the [`User::ldap_sync_allowed`] method on the user. // // The mentioned method can't be called here since the user is already dropped from the database -pub(crate) async fn ldap_delete_user(user: &User, pool: &PgPool) { +pub async fn ldap_delete_user(user: &User, pool: &PgPool) { ldap_delete_users(vec![user], pool).await; } @@ -168,7 +169,7 @@ pub(crate) async fn ldap_delete_user(user: &User, pool: &PgPool) { /// For example, by calling the [`User::ldap_sync_allowed`] method on each user. // // The mentioned method can't be called here since the user is already dropped from the database -pub(crate) async fn ldap_delete_users(users: Vec<&User>, pool: &PgPool) { +pub async fn ldap_delete_users(users: Vec<&User>, pool: &PgPool) { let _: Result<(), LdapError> = with_ldap_status(pool, async { debug!("Deleting {:?} users from LDAP", users.len()); let mut ldap_connection = LDAPConnection::create().await?; @@ -184,7 +185,7 @@ pub(crate) async fn ldap_delete_users(users: Vec<&User>, pool: &PgPool) { } /// Remove singular user from multiple groups in ldap. -pub(crate) async fn ldap_remove_user_from_groups( +pub async fn ldap_remove_user_from_groups( user: &User, groups: HashSet<&str>, pool: &PgPool, @@ -195,13 +196,13 @@ pub(crate) async fn ldap_remove_user_from_groups( /// Add singular user to multiple groups in LDAP. Convenience wrapper around /// [`ldap_add_users_to_groups`]. -pub(crate) async fn ldap_add_user_to_groups(user: &User, groups: HashSet<&str>, pool: &PgPool) { +pub async fn ldap_add_user_to_groups(user: &User, groups: HashSet<&str>, pool: &PgPool) { let map = HashMap::from([(user, groups)]); ldap_add_users_to_groups(map, pool).await; } /// Bulk add users to groups in ldap. -pub(crate) async fn ldap_add_users_to_groups( +pub async fn ldap_add_users_to_groups( user_groups: HashMap<&User, HashSet<&str>>, pool: &PgPool, ) { @@ -236,7 +237,7 @@ pub(crate) async fn ldap_add_users_to_groups( } /// Bulk remove users from groups in LDAP. -pub(crate) async fn ldap_remove_users_from_groups( +pub async fn ldap_remove_users_from_groups( user_groups: HashMap<&User, HashSet<&str>>, pool: &PgPool, ) { @@ -309,7 +310,7 @@ pub async fn ldap_change_password(user: &mut User, password: &str, pool: &Pg .await; } -pub(crate) async fn ldap_modify_group(groupname: &str, group: &Group, pool: &PgPool) { +pub async fn ldap_modify_group(groupname: &str, group: &Group, pool: &PgPool) { let _: Result<(), LdapError> = with_ldap_status(pool, async { debug!("Modifying group {groupname} in LDAP"); let mut ldap_connection = LDAPConnection::create().await?; @@ -318,7 +319,7 @@ pub(crate) async fn ldap_modify_group(groupname: &str, group: &Group, pool: .await; } -pub(crate) async fn ldap_delete_group(groupname: &str, pool: &PgPool) { +pub async fn ldap_delete_group(groupname: &str, pool: &PgPool) { let _: Result<(), LdapError> = with_ldap_status(pool, async { debug!("Deleting group {groupname} from LDAP"); let mut ldap_connection = LDAPConnection::create().await?; diff --git a/enterprise/crates/defguard_enterprise_ldap/src/lib.rs b/enterprise/crates/defguard_enterprise_ldap/src/lib.rs new file mode 100644 index 000000000..e6e5680d6 --- /dev/null +++ b/enterprise/crates/defguard_enterprise_ldap/src/lib.rs @@ -0,0 +1,3 @@ +pub mod ldap; + +pub use ldap::*; diff --git a/enterprise/crates/defguard_enterprise_license/Cargo.toml b/enterprise/crates/defguard_enterprise_license/Cargo.toml new file mode 100644 index 000000000..54b9411ad --- /dev/null +++ b/enterprise/crates/defguard_enterprise_license/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "defguard_enterprise_license" +version = "0.0.0" +edition.workspace = true +license-file.workspace = true +homepage.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +defguard_common = { workspace = true } +anyhow = { workspace = true } +base64 = { workspace = true } +chrono = { workspace = true } +humantime = { workspace = true } +pgp = { workspace = true } +prost = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +tonic = { workspace = true } +sqlx = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } + +[build-dependencies] +tonic-prost-build = { workspace = true } diff --git a/crates/defguard_core/src/enterprise/LICENSE.md b/enterprise/crates/defguard_enterprise_license/LICENSE.md similarity index 100% rename from crates/defguard_core/src/enterprise/LICENSE.md rename to enterprise/crates/defguard_enterprise_license/LICENSE.md diff --git a/crates/defguard_core/build.rs b/enterprise/crates/defguard_enterprise_license/build.rs similarity index 60% rename from crates/defguard_core/build.rs rename to enterprise/crates/defguard_enterprise_license/build.rs index d3deeca18..c330d7ae3 100644 --- a/crates/defguard_core/build.rs +++ b/enterprise/crates/defguard_enterprise_license/build.rs @@ -5,10 +5,7 @@ fn main() -> Result<(), Box> { "LicenseLimits", "#[derive(serde::Serialize, serde::Deserialize)]", ) - .compile_protos( - &["src/enterprise/proto/license.proto"], - &["src/enterprise/proto"], - )?; - println!("cargo:rerun-if-changed=src/enterprise/proto"); + .compile_protos(&["proto/license.proto"], &["proto"])?; + println!("cargo:rerun-if-changed=proto"); Ok(()) } diff --git a/crates/defguard_core/src/enterprise/proto/license.proto b/enterprise/crates/defguard_enterprise_license/proto/license.proto similarity index 100% rename from crates/defguard_core/src/enterprise/proto/license.proto rename to enterprise/crates/defguard_enterprise_license/proto/license.proto diff --git a/enterprise/crates/defguard_enterprise_license/src/lib.rs b/enterprise/crates/defguard_enterprise_license/src/lib.rs new file mode 100644 index 000000000..0ec1ac264 --- /dev/null +++ b/enterprise/crates/defguard_enterprise_license/src/lib.rs @@ -0,0 +1,119 @@ +mod license; +mod limits; + +pub mod proto { + pub mod enterprise { + pub mod license { + include!(concat!(env!("OUT_DIR"), "/enterprise.license.rs")); + } + } +} + +pub use license::*; +pub use limits::*; +use tracing::debug; + +/// Helper function to gate features which require a base license (Team or Business tier) +#[must_use] +pub fn is_business_license_active() -> bool { + is_license_tier_active(LicenseTier::Business) +} + +/// Helper function to gate features which require an Enterprise tier license +#[must_use] +pub fn is_enterprise_license_active() -> bool { + is_license_tier_active(LicenseTier::Enterprise) +} + +/// Shared logic for gating features to specific license tiers +fn is_license_tier_active(tier: LicenseTier) -> bool { + debug!("Checking if features for {tier} license tier should be enabled"); + + // get current object counts + let counts = crate::limits::get_counts(); + + let license = crate::license::get_cached_license(); + let validation_result = crate::license::validate_license(license.as_ref(), &counts, tier); + debug!("License validation result: {validation_result:?}"); + validation_result.is_ok() +} + +#[cfg(test)] +mod test { + use chrono::{TimeDelta, Utc}; + + use crate::{ + is_business_license_active, is_enterprise_license_active, + license::{License, LicenseTier, set_cached_license}, + limits::{Counts, set_counts}, + proto::enterprise::license::LicenseLimits, + }; + + #[test] + fn test_feature_gates_no_license() { + set_cached_license(None); + + let counts = Counts::new(1, 1, 1, 1); + set_counts(counts); + + assert!(!is_business_license_active()); + assert!(!is_enterprise_license_active()); + } + + #[test] + fn test_feature_gates_with_license() { + // exceed free limits + let counts = Counts::new(1, 1, 5, 1); + set_counts(counts); + + // set Business license + let users_limit = 15; + let devices_limit = 35; + let locations_limit = 5; + let network_devices_limit = 10; + + let limits = LicenseLimits { + users: users_limit, + devices: devices_limit, + locations: locations_limit, + network_devices: Some(network_devices_limit), + }; + let license = License::new( + "test".to_string(), + true, + Some(Utc::now() + TimeDelta::days(1)), + Some(limits), + None, + LicenseTier::Business, + ); + set_cached_license(Some(license)); + + assert!(is_business_license_active()); + assert!(!is_enterprise_license_active()); + + // set Enterprise license + let users_limit = 15; + let devices_limit = 35; + let locations_limit = 5; + let network_devices_limit = 10; + + let limits = LicenseLimits { + users: users_limit, + devices: devices_limit, + locations: locations_limit, + network_devices: Some(network_devices_limit), + }; + let license = License::new( + "test".to_string(), + true, + Some(Utc::now() + TimeDelta::days(1)), + Some(limits), + None, + LicenseTier::Enterprise, + ); + set_cached_license(Some(license)); + + assert!(is_business_license_active()); + assert!(is_enterprise_license_active()); + } +} diff --git a/crates/defguard_core/src/enterprise/license.rs b/enterprise/crates/defguard_enterprise_license/src/license.rs similarity index 99% rename from crates/defguard_core/src/enterprise/license.rs rename to enterprise/crates/defguard_enterprise_license/src/license.rs index f0895c3a0..f58ce3ceb 100644 --- a/crates/defguard_core/src/enterprise/license.rs +++ b/enterprise/crates/defguard_enterprise_license/src/license.rs @@ -1,6 +1,5 @@ use std::{fmt, time::Duration}; -use anyhow::Result; use base64::prelude::*; use chrono::{DateTime, TimeDelta, Utc}; use defguard_common::{ @@ -16,11 +15,13 @@ use pgp::{ }; use prost::Message; use sqlx::{PgPool, error::Error as SqlxError}; +use tracing::{debug, error, info, instrument, warn}; use thiserror::Error; use tokio::time::sleep; +use serde::{Deserialize, Serialize}; use super::limits::Counts; -use crate::grpc::proto::enterprise::license::{ +use crate::proto::enterprise::license::{ LicenseKey, LicenseLimits, LicenseMetadata, LicenseTier as LicenseTierProto, }; diff --git a/crates/defguard_core/src/enterprise/limits.rs b/enterprise/crates/defguard_enterprise_license/src/limits.rs similarity index 94% rename from crates/defguard_core/src/enterprise/limits.rs rename to enterprise/crates/defguard_enterprise_license/src/limits.rs index c0c47c859..97411e28b 100644 --- a/crates/defguard_core/src/enterprise/limits.rs +++ b/enterprise/crates/defguard_enterprise_license/src/limits.rs @@ -1,5 +1,6 @@ use defguard_common::global_value; use sqlx::{PgPool, error::Error as SqlxError, query}; +use tracing::debug; use super::license::License; #[cfg(test)] @@ -101,23 +102,23 @@ impl Counts { } } - pub(crate) fn user(&self) -> u32 { + pub fn user(&self) -> u32 { self.user } - pub(crate) fn user_device(&self) -> u32 { + pub fn user_device(&self) -> u32 { self.user_device } - pub(crate) fn network_device(&self) -> u32 { + pub fn network_device(&self) -> u32 { self.network_device } - pub(crate) fn location(&self) -> u32 { + pub fn location(&self) -> u32 { self.location } - pub(crate) fn is_over_license_limits(&self, license: &License) -> bool { + pub fn is_over_license_limits(&self, license: &License) -> bool { let limits = &license.limits; match limits { Some(limits) => self.user > limits.users || self.location > limits.locations, @@ -133,8 +134,8 @@ mod test { use super::*; use crate::{ - enterprise::license::{License, LicenseTier, set_cached_license}, - grpc::proto::enterprise::license::LicenseLimits, + license::{License, LicenseTier, set_cached_license}, + proto::enterprise::license::LicenseLimits, }; #[test] diff --git a/crates/defguard_core/src/enterprise/public_key.asc b/enterprise/crates/defguard_enterprise_license/src/public_key.asc similarity index 100% rename from crates/defguard_core/src/enterprise/public_key.asc rename to enterprise/crates/defguard_enterprise_license/src/public_key.asc diff --git a/crates/defguard_core/src/enterprise/test_key.asc b/enterprise/crates/defguard_enterprise_license/src/test_key.asc similarity index 100% rename from crates/defguard_core/src/enterprise/test_key.asc rename to enterprise/crates/defguard_enterprise_license/src/test_key.asc diff --git a/enterprise/crates/defguard_enterprise_snat/Cargo.toml b/enterprise/crates/defguard_enterprise_snat/Cargo.toml new file mode 100644 index 000000000..d31b4286e --- /dev/null +++ b/enterprise/crates/defguard_enterprise_snat/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "defguard_enterprise_snat" +version = "0.0.0" +edition.workspace = true +license-file.workspace = true +homepage.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +sqlx = { workspace = true } +thiserror = { workspace = true } diff --git a/crates/defguard_core/src/enterprise/snat/error.rs b/enterprise/crates/defguard_enterprise_snat/src/error.rs similarity index 56% rename from crates/defguard_core/src/enterprise/snat/error.rs rename to enterprise/crates/defguard_enterprise_snat/src/error.rs index 23dec61dc..173dff4bf 100644 --- a/crates/defguard_core/src/enterprise/snat/error.rs +++ b/enterprise/crates/defguard_enterprise_snat/src/error.rs @@ -1,7 +1,5 @@ use thiserror::Error; -use crate::error::WebError; - #[derive(Debug, Error)] pub enum UserSnatBindingError { #[error("Binding not found")] @@ -23,15 +21,3 @@ impl From for UserSnatBindingError { } } } - -impl From for WebError { - fn from(value: UserSnatBindingError) -> Self { - match value { - UserSnatBindingError::BindingNotFound => WebError::ObjectNotFound(value.to_string()), - UserSnatBindingError::BindingAlreadyExists => { - WebError::ObjectAlreadyExists(value.to_string()) - } - UserSnatBindingError::DbError { source } => WebError::DbError(source.to_string()), - } - } -} diff --git a/enterprise/crates/defguard_enterprise_snat/src/lib.rs b/enterprise/crates/defguard_enterprise_snat/src/lib.rs new file mode 100644 index 000000000..5df182655 --- /dev/null +++ b/enterprise/crates/defguard_enterprise_snat/src/lib.rs @@ -0,0 +1,3 @@ +pub mod error; + +pub use error::*;