Skip to content

Commit bf1e08b

Browse files
Check for cancellation of Rust task in test
1 parent 0e1ec39 commit bf1e08b

File tree

13 files changed

+261
-28
lines changed

13 files changed

+261
-28
lines changed

java/shared/java/org/signal/libsignal/internal/NativeTesting.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,10 @@ private NativeTesting() {}
9595
public static native long TESTING_FakeChatServer_Create();
9696
public static native CompletableFuture<Long> TESTING_FakeChatServer_GetNextRemote(long asyncRuntime, long server);
9797
public static native CompletableFuture<Long> TESTING_FakeRegistrationSession_CreateSession(long asyncRuntime, Object createSession, long chat);
98+
public static native long TESTING_FutureCancellationCounter_Create(int initialValue);
99+
public static native CompletableFuture TESTING_FutureCancellationCounter_WaitForCount(long asyncRuntime, long count, int target);
98100
public static native CompletableFuture<Integer> TESTING_FutureFailure(long asyncRuntime, int input);
101+
public static native CompletableFuture TESTING_FutureIncrementOnCancel(long asyncRuntime, long guard);
99102
public static native CompletableFuture<Long> TESTING_FutureProducesOtherPointerType(long asyncRuntime, String input);
100103
public static native CompletableFuture<Long> TESTING_FutureProducesPointerType(long asyncRuntime, int input);
101104
public static native CompletableFuture<Integer> TESTING_FutureSuccess(long asyncRuntime, int input);
@@ -104,7 +107,6 @@ private NativeTesting() {}
104107
public static native String TESTING_JoinStringArray(Object[] array, String joinWith);
105108
public static native void TESTING_NonSuspendingBackgroundThreadRuntime_Destroy(long handle);
106109
public static native long TESTING_NonSuspendingBackgroundThreadRuntime_New();
107-
public static native CompletableFuture TESTING_OnlyCompletesByCancellation(long asyncRuntime);
108110
public static native String TESTING_OtherTestingHandleType_getValue(long handle);
109111
public static native void TESTING_PanicInBodyAsync(Object input);
110112
public static native CompletableFuture TESTING_PanicInBodyIo(long asyncRuntime, Object input);
@@ -139,6 +141,8 @@ private NativeTesting() {}
139141
public static native int TESTING_TestingHandleType_getValue(long handle);
140142
public static native CompletableFuture<Integer> TESTING_TokioAsyncFuture(long asyncRuntime, int input);
141143

144+
public static native void TestingFutureCancellationCounter_Destroy(long handle);
145+
142146
public static native void TestingHandleType_Destroy(long handle);
143147

144148
public static native int test_only_fn_returns_123();

java/shared/test/java/org/signal/libsignal/internal/FutureTest.java

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,26 @@ public void testFutureFromRustCancel() {
6868
assertTrue(testFuture.isDone());
6969
}
7070

71-
@Test
71+
@Test(timeout = 5000)
7272
@SuppressWarnings("unchecked")
73-
public void testFutureOnlyCompletesByCancellation() {
73+
public void testFutureOnlyCompletesByCancellation() throws Exception {
7474
TokioAsyncContext context = new TokioAsyncContext();
75+
var counter =
76+
new NativeHandleGuard.SimpleOwner(
77+
NativeTesting.TESTING_FutureCancellationCounter_Create(0)) {
78+
@Override
79+
protected void release(long nativeHandle) {
80+
NativeTesting.TestingFutureCancellationCounter_Destroy(nativeHandle);
81+
}
82+
};
7583
org.signal.libsignal.internal.CompletableFuture<Integer> testFuture =
7684
context
7785
.guardedMap(
7886
(nativeContextHandle) ->
79-
NativeTesting.TESTING_OnlyCompletesByCancellation(nativeContextHandle))
87+
counter.guardedMap(
88+
counterHandle ->
89+
NativeTesting.TESTING_FutureIncrementOnCancel(
90+
nativeContextHandle, counterHandle)))
8091
.makeCancelable(context);
8192
assertTrue(testFuture.cancel(true));
8293
ExecutionException e = assertThrows(ExecutionException.class, () -> testFuture.get());
@@ -85,6 +96,16 @@ public void testFutureOnlyCompletesByCancellation() {
8596
e.getCause() instanceof java.util.concurrent.CancellationException);
8697
assertTrue(testFuture.isCancelled());
8798
assertTrue(testFuture.isDone());
99+
100+
// Hangs if the count never gets incremented.
101+
context
102+
.guardedMap(
103+
(nativeContextHandle) ->
104+
counter.guardedMap(
105+
counterHandle ->
106+
NativeTesting.TESTING_FutureCancellationCounter_WaitForCount(
107+
nativeContextHandle, counterHandle, 1)))
108+
.get();
88109
}
89110

90111
@Test

node/Native.d.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,14 +630,16 @@ export function TESTING_FakeChatSentRequest_TakeHttpRequest(request: Wrapper<Fak
630630
export function TESTING_FakeChatServer_Create(): FakeChatServer;
631631
export function TESTING_FakeChatServer_GetNextRemote(asyncRuntime: Wrapper<TokioAsyncContext>, server: Wrapper<FakeChatServer>): CancellablePromise<FakeChatRemoteEnd>;
632632
export function TESTING_FakeRegistrationSession_CreateSession(asyncRuntime: Wrapper<TokioAsyncContext>, createSession: RegistrationCreateSessionRequest, chat: Wrapper<FakeChatServer>): CancellablePromise<RegistrationService>;
633+
export function TESTING_FutureCancellationCounter_Create(initialValue: number): TestingFutureCancellationCounter;
634+
export function TESTING_FutureCancellationCounter_WaitForCount(asyncRuntime: Wrapper<TokioAsyncContext>, count: Wrapper<TestingFutureCancellationCounter>, target: number): CancellablePromise<void>;
633635
export function TESTING_FutureFailure(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, _input: number): CancellablePromise<number>;
636+
export function TESTING_FutureIncrementOnCancel(asyncRuntime: Wrapper<TokioAsyncContext>, _guard: TestingFutureCancellationGuard): CancellablePromise<void>;
634637
export function TESTING_FutureProducesOtherPointerType(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, input: string): CancellablePromise<OtherTestingHandleType>;
635638
export function TESTING_FutureProducesPointerType(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, input: number): CancellablePromise<TestingHandleType>;
636639
export function TESTING_FutureSuccess(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, input: number): CancellablePromise<number>;
637640
export function TESTING_InputStreamReadIntoZeroLengthSlice(capsAlphabetInput: InputStream): Promise<Buffer>;
638641
export function TESTING_JoinStringArray(array: string[], joinWith: string): string;
639642
export function TESTING_NonSuspendingBackgroundThreadRuntime_New(): NonSuspendingBackgroundThreadRuntime;
640-
export function TESTING_OnlyCompletesByCancellation(asyncRuntime: Wrapper<TokioAsyncContext>): CancellablePromise<void>;
641643
export function TESTING_OtherTestingHandleType_getValue(handle: Wrapper<OtherTestingHandleType>): string;
642644
export function TESTING_PanicInBodyAsync(_input: null): Promise<void>;
643645
export function TESTING_PanicInBodyIo(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, _input: null): CancellablePromise<void>;
@@ -771,6 +773,7 @@ interface SessionRecord { readonly __type: unique symbol; }
771773
interface SgxClientState { readonly __type: unique symbol; }
772774
interface SignalMessage { readonly __type: unique symbol; }
773775
interface SignedPreKeyRecord { readonly __type: unique symbol; }
776+
interface TestingFutureCancellationCounter { readonly __type: unique symbol; }
774777
interface TestingHandleType { readonly __type: unique symbol; }
775778
interface TokioAsyncContext { readonly __type: unique symbol; }
776779
interface UnauthenticatedChatConnection { readonly __type: unique symbol; }

node/ts/test/FutureTest.ts

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,22 @@ function makeAsyncRuntime(): Native.Wrapper<Native.NonSuspendingBackgroundThread
1818
};
1919
}
2020

21+
class CancelCounter {
22+
readonly _nativeHandle: Native.TestingFutureCancellationCounter;
23+
constructor(initialValue: number = 0) {
24+
this._nativeHandle =
25+
Native.TESTING_FutureCancellationCounter_Create(initialValue);
26+
}
27+
28+
public async waitForCount(context: TokioAsyncContext, target: number) {
29+
await Native.TESTING_FutureCancellationCounter_WaitForCount(
30+
context,
31+
this,
32+
target
33+
);
34+
}
35+
}
36+
2137
describe('Async runtime not on the Node executor', () => {
2238
it('handles success', async () => {
2339
const runtime = makeAsyncRuntime();
@@ -42,28 +58,36 @@ describe('TokioAsyncContext', () => {
4258
it('supports cancellation of running future', async () => {
4359
const runtime = new TokioAsyncContext(Native.TokioAsyncContext_new());
4460
const abortController = new AbortController();
61+
const counter = new CancelCounter();
4562
const pending = runtime.makeCancellable(
4663
abortController.signal,
47-
Native.TESTING_OnlyCompletesByCancellation(runtime)
64+
Native.TESTING_FutureIncrementOnCancel(runtime, counter)
4865
);
4966
const timeout = setTimeout(200, 'timed out');
5067
assert.equal('timed out', await Promise.race([pending, timeout]));
5168
abortController.abort();
52-
return expect(pending)
53-
.to.eventually.be.rejectedWith(LibSignalErrorBase)
54-
.and.have.property('code', ErrorCode.Cancelled);
69+
return Promise.all([
70+
expect(pending)
71+
.to.eventually.be.rejectedWith(LibSignalErrorBase)
72+
.and.have.property('code', ErrorCode.Cancelled),
73+
expect(counter.waitForCount(runtime, 1)).to.be.fulfilled,
74+
]);
5575
});
5676

5777
it('supports pre-cancellation of not-yet-running future', async () => {
5878
const runtime = new TokioAsyncContext(Native.TokioAsyncContext_new());
5979
const abortController = new AbortController();
80+
const counter = new CancelCounter();
6081
abortController.abort();
6182
const pending = runtime.makeCancellable(
6283
abortController.signal,
63-
Native.TESTING_OnlyCompletesByCancellation(runtime)
84+
Native.TESTING_FutureIncrementOnCancel(runtime, counter)
6485
);
65-
return expect(pending)
66-
.to.eventually.be.rejectedWith(LibSignalErrorBase)
67-
.and.have.property('code', ErrorCode.Cancelled);
86+
return Promise.all([
87+
expect(pending)
88+
.to.eventually.be.rejectedWith(LibSignalErrorBase)
89+
.and.have.property('code', ErrorCode.Cancelled),
90+
expect(counter.waitForCount(runtime, 1)).to.be.fulfilled,
91+
]);
6892
});
6993
});

rust/bridge/shared/macros/src/jni.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,16 @@ fn bridge_io_body(
158158
// Wrap the actual work to catch any panics.
159159
let __future = jni::catch_unwind(std::panic::AssertUnwindSafe(async {
160160
#(#input_loading)*
161-
let __result = #orig_name(#(#input_names),*).await;
162-
// If the original function can't fail, wrap the result in Ok for uniformity.
163-
// See TransformHelper::ok_if_needed.
164-
Ok(TransformHelper(__result).ok_if_needed()?.0)
161+
::tokio::select! {
162+
__result = #orig_name(#(#input_names),*) => {
163+
// If the original function can't fail, wrap the result in Ok for uniformity.
164+
// See TransformHelper::ok_if_needed.
165+
Ok(TransformHelper(__result).ok_if_needed()?.0)
166+
}
167+
_ = __cancel => {
168+
Err(jni::FutureCancelled.into())
169+
}
170+
}
165171
}));
166172
// Pass the stored inputs to the reporter to drop them while attached to the JVM.
167173

rust/bridge/shared/testing/src/convert.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,30 @@ async fn TESTING_FutureFailure(_input: u8) -> Result<i32, SignalProtocolError> {
7474
Err(SignalProtocolError::InvalidArgument("failure".to_string()))
7575
}
7676

77+
bridge_handle_fns!(TestingFutureCancellationCounter, clone = false);
78+
79+
#[bridge_fn]
80+
fn TESTING_FutureCancellationCounter_Create(initial_value: u8) -> TestingFutureCancellationCounter {
81+
TestingFutureCancellationCounter(tokio::sync::Semaphore::new(initial_value.into()).into())
82+
}
83+
84+
#[bridge_io(TokioAsyncContext)]
85+
async fn TESTING_FutureCancellationCounter_WaitForCount(
86+
count: &TestingFutureCancellationCounter,
87+
target: u8,
88+
) {
89+
let _permits = count
90+
.0
91+
.acquire_many(target.into())
92+
.await
93+
.expect("not closed");
94+
}
95+
96+
#[bridge_io(TokioAsyncContext)]
97+
async fn TESTING_FutureIncrementOnCancel(_guard: TestingFutureCancellationGuard) {
98+
std::future::pending().await
99+
}
100+
77101
#[bridge_io(TokioAsyncContext)]
78102
async fn TESTING_TokioAsyncFuture(input: u8) -> i32 {
79103
i32::from(input) * 3

rust/bridge/shared/testing/src/net.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,6 @@ async fn TESTING_CdsiLookupResponseConvert() -> LookupResponse {
4949
}
5050
}
5151

52-
#[bridge_io(TokioAsyncContext)]
53-
async fn TESTING_OnlyCompletesByCancellation() {
54-
std::future::pending::<()>().await
55-
}
56-
5752
macro_rules! make_error_testing_enum {
5853
(enum $name:ident for $orig:ident {
5954
$($orig_case:ident => $case:ident,)*

rust/bridge/shared/testing/src/types.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
// SPDX-License-Identifier: AGPL-3.0-only
44
//
55

6+
use std::panic::{RefUnwindSafe, UnwindSafe};
7+
use std::sync::Arc;
8+
69
#[allow(unused_imports)]
710
use libsignal_protocol::SignalProtocolError;
811

@@ -347,3 +350,74 @@ impl<'a> node::ResultTypeInfo<'a> for PanicOnReturn {
347350
panic!("deliberate panic");
348351
}
349352
}
353+
354+
/// Counter for future cancellations
355+
pub struct TestingFutureCancellationCounter(pub(crate) Arc<tokio::sync::Semaphore>);
356+
357+
impl UnwindSafe for TestingFutureCancellationCounter {}
358+
impl RefUnwindSafe for TestingFutureCancellationCounter {}
359+
360+
/// RAII guard that increments a counter on `Drop`.
361+
///
362+
/// This is bridged as a reference to a [`TestingFutureCancellationCounter`].
363+
pub struct TestingFutureCancellationGuard {
364+
increment_on_drop: Arc<tokio::sync::Semaphore>,
365+
}
366+
367+
impl Drop for TestingFutureCancellationGuard {
368+
fn drop(&mut self) {
369+
self.increment_on_drop.add_permits(1);
370+
}
371+
}
372+
373+
#[cfg(feature = "ffi")]
374+
impl ffi::SimpleArgTypeInfo for TestingFutureCancellationGuard {
375+
type ArgType = <&'static TestingFutureCancellationCounter as ffi::SimpleArgTypeInfo>::ArgType;
376+
377+
fn convert_from(foreign: Self::ArgType) -> ffi::SignalFfiResult<Self> {
378+
<&TestingFutureCancellationCounter as ffi::SimpleArgTypeInfo>::convert_from(foreign).map(
379+
|TestingFutureCancellationCounter(counter)| TestingFutureCancellationGuard {
380+
increment_on_drop: Arc::clone(counter),
381+
},
382+
)
383+
}
384+
}
385+
386+
#[cfg(feature = "jni")]
387+
impl<'a> jni::SimpleArgTypeInfo<'a> for TestingFutureCancellationGuard {
388+
type ArgType = <&'a TestingFutureCancellationCounter as jni::SimpleArgTypeInfo<'a>>::ArgType;
389+
390+
fn convert_from(
391+
env: &mut jni::JNIEnv<'a>,
392+
foreign: &Self::ArgType,
393+
) -> Result<Self, jni::BridgeLayerError> {
394+
<&TestingFutureCancellationCounter as jni::SimpleArgTypeInfo>::convert_from(env, foreign)
395+
.map(
396+
|TestingFutureCancellationCounter(counter)| TestingFutureCancellationGuard {
397+
increment_on_drop: Arc::clone(counter),
398+
},
399+
)
400+
}
401+
}
402+
403+
#[cfg(feature = "node")]
404+
impl<'storage> node::AsyncArgTypeInfo<'storage> for TestingFutureCancellationGuard {
405+
type ArgType = node::JsObject;
406+
type StoredType = Option<node::DefaultFinalize<Self>>;
407+
fn save_async_arg(
408+
cx: &mut neon::prelude::FunctionContext,
409+
foreign: neon::prelude::Handle<Self::ArgType>,
410+
) -> neon::prelude::NeonResult<Self::StoredType> {
411+
<&TestingFutureCancellationCounter as node::AsyncArgTypeInfo>::save_async_arg(cx, foreign)
412+
.map(move |handle| {
413+
Some(node::DefaultFinalize(TestingFutureCancellationGuard {
414+
increment_on_drop: Arc::clone(&handle.0),
415+
}))
416+
})
417+
}
418+
fn load_async_arg(stored: &'storage mut Self::StoredType) -> Self {
419+
stored.take().unwrap().0
420+
}
421+
}
422+
423+
bridge_as_handle!(TestingFutureCancellationCounter);

rust/bridge/shared/types/src/ffi/convert.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,8 @@ macro_rules! ffi_arg_type {
10821082
(Ignored<$typ:ty>) => (*const std::ffi::c_void);
10831083
(AsType<$typ:ident, $bridged:ident>) => (ffi_arg_type!($bridged));
10841084

1085+
(TestingFutureCancellationGuard) => (ffi_arg_type!(&TestingFutureCancellationCounter));
1086+
10851087
// In order to provide a fixed-sized array of the correct length,
10861088
// a serialized type FooBar must have a constant FOO_BAR_LEN that's in scope (and exposed to C).
10871089
(Serialized<$typ:ident>) => (*const [std::ffi::c_uchar; ::paste::paste!([<$typ:snake:upper _LEN>])]);

rust/bridge/shared/types/src/jni/convert.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,6 +1933,7 @@ macro_rules! jni_arg_type {
19331933
(CreateSession) => {
19341934
$crate::jni::JObject<'local>
19351935
};
1936+
(TestingFutureCancellationGuard) => { ::jni::sys::jlong };
19361937

19371938
(Ignored<$typ:ty>) => (::jni::objects::JObject<'local>);
19381939
}

0 commit comments

Comments
 (0)