|
3 | 3 | // SPDX-License-Identifier: AGPL-3.0-only |
4 | 4 | // |
5 | 5 |
|
| 6 | +use std::panic::{RefUnwindSafe, UnwindSafe}; |
| 7 | +use std::sync::Arc; |
| 8 | + |
6 | 9 | #[allow(unused_imports)] |
7 | 10 | use libsignal_protocol::SignalProtocolError; |
8 | 11 |
|
@@ -347,3 +350,74 @@ impl<'a> node::ResultTypeInfo<'a> for PanicOnReturn { |
347 | 350 | panic!("deliberate panic"); |
348 | 351 | } |
349 | 352 | } |
| 353 | + |
| 354 | +/// Counter for future cancellations |
| 355 | +pub struct TestingFutureCancellationCounter(pub(crate) Arc<tokio::sync::Semaphore>); |
| 356 | + |
| 357 | +impl UnwindSafe for TestingFutureCancellationCounter {} |
| 358 | +impl RefUnwindSafe for TestingFutureCancellationCounter {} |
| 359 | + |
| 360 | +/// RAII guard that increments a counter on `Drop`. |
| 361 | +/// |
| 362 | +/// This is bridged as a reference to a [`TestingFutureCancellationCounter`]. |
| 363 | +pub struct TestingFutureCancellationGuard { |
| 364 | + increment_on_drop: Arc<tokio::sync::Semaphore>, |
| 365 | +} |
| 366 | + |
| 367 | +impl Drop for TestingFutureCancellationGuard { |
| 368 | + fn drop(&mut self) { |
| 369 | + self.increment_on_drop.add_permits(1); |
| 370 | + } |
| 371 | +} |
| 372 | + |
| 373 | +#[cfg(feature = "ffi")] |
| 374 | +impl ffi::SimpleArgTypeInfo for TestingFutureCancellationGuard { |
| 375 | + type ArgType = <&'static TestingFutureCancellationCounter as ffi::SimpleArgTypeInfo>::ArgType; |
| 376 | + |
| 377 | + fn convert_from(foreign: Self::ArgType) -> ffi::SignalFfiResult<Self> { |
| 378 | + <&TestingFutureCancellationCounter as ffi::SimpleArgTypeInfo>::convert_from(foreign).map( |
| 379 | + |TestingFutureCancellationCounter(counter)| TestingFutureCancellationGuard { |
| 380 | + increment_on_drop: Arc::clone(counter), |
| 381 | + }, |
| 382 | + ) |
| 383 | + } |
| 384 | +} |
| 385 | + |
| 386 | +#[cfg(feature = "jni")] |
| 387 | +impl<'a> jni::SimpleArgTypeInfo<'a> for TestingFutureCancellationGuard { |
| 388 | + type ArgType = <&'a TestingFutureCancellationCounter as jni::SimpleArgTypeInfo<'a>>::ArgType; |
| 389 | + |
| 390 | + fn convert_from( |
| 391 | + env: &mut jni::JNIEnv<'a>, |
| 392 | + foreign: &Self::ArgType, |
| 393 | + ) -> Result<Self, jni::BridgeLayerError> { |
| 394 | + <&TestingFutureCancellationCounter as jni::SimpleArgTypeInfo>::convert_from(env, foreign) |
| 395 | + .map( |
| 396 | + |TestingFutureCancellationCounter(counter)| TestingFutureCancellationGuard { |
| 397 | + increment_on_drop: Arc::clone(counter), |
| 398 | + }, |
| 399 | + ) |
| 400 | + } |
| 401 | +} |
| 402 | + |
| 403 | +#[cfg(feature = "node")] |
| 404 | +impl<'storage> node::AsyncArgTypeInfo<'storage> for TestingFutureCancellationGuard { |
| 405 | + type ArgType = node::JsObject; |
| 406 | + type StoredType = Option<node::DefaultFinalize<Self>>; |
| 407 | + fn save_async_arg( |
| 408 | + cx: &mut neon::prelude::FunctionContext, |
| 409 | + foreign: neon::prelude::Handle<Self::ArgType>, |
| 410 | + ) -> neon::prelude::NeonResult<Self::StoredType> { |
| 411 | + <&TestingFutureCancellationCounter as node::AsyncArgTypeInfo>::save_async_arg(cx, foreign) |
| 412 | + .map(move |handle| { |
| 413 | + Some(node::DefaultFinalize(TestingFutureCancellationGuard { |
| 414 | + increment_on_drop: Arc::clone(&handle.0), |
| 415 | + })) |
| 416 | + }) |
| 417 | + } |
| 418 | + fn load_async_arg(stored: &'storage mut Self::StoredType) -> Self { |
| 419 | + stored.take().unwrap().0 |
| 420 | + } |
| 421 | +} |
| 422 | + |
| 423 | +bridge_as_handle!(TestingFutureCancellationCounter); |
0 commit comments