Skip to content

Commit d8a779a

Browse files
authored
[bindings] Parity with unofficial bindings (#3374)
1 parent f1c121e commit d8a779a

File tree

9 files changed

+300
-52
lines changed

9 files changed

+300
-52
lines changed

bindings/rust/s2n-tls-tokio/tests/common/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ use tokio::{
1111

1212
mod stream;
1313
pub use stream::*;
14+
mod time;
15+
pub use time::*;
1416

1517
/// NOTE: this certificate and key are used for testing purposes only!
1618
pub static CERT_PEM: &[u8] = include_bytes!(concat!(
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use s2n_tls::callbacks::MonotonicClock;
5+
use std::time::Duration;
6+
use tokio::time::Instant;
7+
8+
/// A monotonic clock that allows the s2n-tls C library time
9+
/// to follow the tokio::time::pause behavior.
10+
pub struct TokioTime(Instant);
11+
12+
impl Default for TokioTime {
13+
fn default() -> Self {
14+
TokioTime(Instant::now())
15+
}
16+
}
17+
18+
impl MonotonicClock for TokioTime {
19+
fn get_time(&self) -> Duration {
20+
self.0.elapsed()
21+
}
22+
}

bindings/rust/s2n-tls-tokio/tests/handshake.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use s2n_tls::{
66
config::Config,
77
connection::{Connection, ModifiedBuilder},
88
enums::{ClientAuthType, Mode, Version},
9-
error::Error,
9+
error::{Error, ErrorType},
1010
pool::ConfigPoolBuilder,
1111
security::DEFAULT_TLS13,
1212
};
@@ -150,10 +150,13 @@ async fn handshake_error() -> Result<(), Box<dyn std::error::Error>> {
150150

151151
#[tokio::test(start_paused = true)]
152152
async fn handshake_error_with_blinding() -> Result<(), Box<dyn std::error::Error>> {
153+
let clock = common::TokioTime::default();
154+
153155
// Config::builder() does not include a trust store.
154156
// The client will reject the server certificate as untrusted.
155157
let mut bad_config = Config::builder();
156158
bad_config.set_security_policy(&DEFAULT_TLS13)?;
159+
bad_config.set_monotonic_clock(clock)?;
157160
let client_config = bad_config.build()?;
158161
let server_config = common::server_config()?.build()?;
159162

@@ -178,6 +181,7 @@ async fn handshake_error_with_blinding() -> Result<(), Box<dyn std::error::Error
178181
.await;
179182
let result = timeout?;
180183
assert!(result.is_err());
184+
assert_eq!(result.unwrap_err().kind(), Some(ErrorType::ProtocolError));
181185

182186
Ok(())
183187
}

bindings/rust/s2n-tls-tokio/tests/shutdown.rs

Lines changed: 7 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,20 @@ async fn shutdown_after_split() -> Result<(), Box<dyn std::error::Error>> {
8686

8787
#[tokio::test(start_paused = true)]
8888
async fn shutdown_with_blinding() -> Result<(), Box<dyn std::error::Error>> {
89+
let clock = common::TokioTime::default();
90+
let mut server_config = common::server_config()?;
91+
server_config.set_monotonic_clock(clock)?;
92+
8993
let client = TlsConnector::new(common::client_config()?.build()?);
90-
let server = TlsAcceptor::new(common::server_config()?.build()?);
94+
let server = TlsAcceptor::new(server_config.build()?);
9195

9296
let (server_stream, client_stream) = common::get_streams().await?;
9397
let server_stream = common::TestStream::new(server_stream);
9498
let overrides = server_stream.overrides();
9599
let (mut client, mut server) =
96100
common::run_negotiate(&client, client_stream, &server, server_stream).await?;
97101

98-
// Trigger a blinded error.
102+
// Trigger a blinded error for the server.
99103
overrides.next_read(Some(Box::new(|_, _, buf| {
100104
// Parsing the header is one of the blinded operations
101105
// in s2n_recv, so provide a malformed header.
@@ -117,8 +121,7 @@ async fn shutdown_with_blinding() -> Result<(), Box<dyn std::error::Error>> {
117121
// Shutdown MUST eventually complete after blinding.
118122
//
119123
// We check for completion, but not for success. At the moment, the
120-
// call to s2n_shutdown will fail. See `shutdown_with_blinding_slow()`
121-
// for verification that s2n_shutdown eventually suceeds.
124+
// call to s2n_shutdown will fail due to issues in the underlying C library.
122125
let (timeout, _) = join!(
123126
time::timeout(common::MAX_BLINDING_SECS, server.shutdown()),
124127
time::timeout(common::MAX_BLINDING_SECS, read_until_shutdown(&mut client)),
@@ -127,47 +130,3 @@ async fn shutdown_with_blinding() -> Result<(), Box<dyn std::error::Error>> {
127130

128131
Ok(())
129132
}
130-
131-
// Ignore because:
132-
// 1) This test is slow. We can avoid Tokio sleeps with time::pause,
133-
// but I couldn't find a good way to do the same to the system time the underlying C uses.
134-
// 2) This test currently fails due to bugs in the underlying s2n_shutdown method.
135-
#[ignore]
136-
#[tokio::test]
137-
async fn shutdown_with_blinding_slow() -> Result<(), Box<dyn std::error::Error>> {
138-
let client = TlsConnector::new(common::client_config()?.build()?);
139-
let server = TlsAcceptor::new(common::server_config()?.build()?);
140-
141-
let (server_stream, client_stream) = common::get_streams().await?;
142-
let server_stream = common::TestStream::new(server_stream);
143-
let overrides = server_stream.overrides();
144-
let (mut client, mut server) =
145-
common::run_negotiate(&client, client_stream, &server, server_stream).await?;
146-
147-
// Trigger a blinded error.
148-
overrides.next_read(Some(Box::new(|_, _, buf| {
149-
// Parsing the header is one of the blinded operations
150-
// in s2n_recv, so provide a malformed header.
151-
let zeroed_header = [23, 0, 0, 0, 0];
152-
buf.put_slice(&zeroed_header);
153-
Ok(()).into()
154-
})));
155-
let mut received = [0; 1];
156-
let result = server.read_exact(&mut received).await;
157-
assert!(result.is_err());
158-
159-
// Shutdown MUST eventually gracefully complete after blinding
160-
let (timeout, _) = join!(
161-
time::timeout(common::MAX_BLINDING_SECS.mul_f32(1.1), server.shutdown()),
162-
time::timeout(
163-
common::MAX_BLINDING_SECS.mul_f32(1.1),
164-
read_until_shutdown(&mut client)
165-
),
166-
);
167-
168-
// Verify shutdown succeeded
169-
let result = timeout?;
170-
assert!(result.is_ok());
171-
172-
Ok(())
173-
}

bindings/rust/s2n-tls/src/callbacks.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
//! can be used to register the task for wakeup. See [`ClientHelloCallback`] as an example.
2424
2525
use crate::{connection::Connection, enums::CallbackResult, error::Error};
26-
use core::{mem::ManuallyDrop, ptr::NonNull, task::Poll};
26+
use core::{mem::ManuallyDrop, ptr::NonNull, task::Poll, time::Duration};
2727
use s2n_tls_sys::s2n_connection;
2828

2929
const READY_OK: Poll<Result<(), Error>> = Poll::Ready(Ok(()));
@@ -153,3 +153,13 @@ impl AsyncCallback for AsyncClientHelloCallback {
153153
pub trait VerifyHostNameCallback {
154154
fn verify_host_name(&self, host_name: &str) -> bool;
155155
}
156+
157+
/// A trait for the callback used to retrieve the system / wall clock time.
158+
pub trait WallClock {
159+
fn get_time_since_epoch(&self) -> Duration;
160+
}
161+
162+
/// A trait for the callback used to retrieve the monotonic time.
163+
pub trait MonotonicClock {
164+
fn get_time(&self) -> Duration;
165+
}

bindings/rust/s2n-tls/src/config.rs

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use core::{convert::TryInto, ptr::NonNull};
1111
use s2n_tls_sys::*;
1212
use std::{
1313
ffi::{c_void, CString},
14+
path::Path,
1415
sync::atomic::{AtomicUsize, Ordering},
1516
};
1617

@@ -73,7 +74,7 @@ impl Config {
7374
}
7475

7576
/// Retrieve a mutable reference to the [`Context`] stored on the config.
76-
fn context_mut(&mut self) -> &mut Context {
77+
pub(crate) fn context_mut(&mut self) -> &mut Context {
7778
let mut ctx = core::ptr::null_mut();
7879
unsafe {
7980
s2n_config_get_ctx(self.as_mut_ptr(), &mut ctx)
@@ -230,6 +231,12 @@ impl Builder {
230231
Ok(self)
231232
}
232233

234+
pub fn add_dhparams(&mut self, pem: &[u8]) -> Result<&mut Self, Error> {
235+
let cstring = CString::new(pem).map_err(|_| Error::InvalidInput)?;
236+
unsafe { s2n_config_add_dhparams(self.as_mut_ptr(), cstring.as_ptr()).into_result() }?;
237+
Ok(self)
238+
}
239+
233240
pub fn load_pem(&mut self, certificate: &[u8], private_key: &[u8]) -> Result<&mut Self, Error> {
234241
let certificate = CString::new(certificate).map_err(|_| Error::InvalidInput)?;
235242
let private_key = CString::new(private_key).map_err(|_| Error::InvalidInput)?;
@@ -252,6 +259,41 @@ impl Builder {
252259
Ok(self)
253260
}
254261

262+
pub fn trust_location(
263+
&mut self,
264+
file: Option<&Path>,
265+
dir: Option<&Path>,
266+
) -> Result<&mut Self, Error> {
267+
fn to_cstr(input: Option<&Path>) -> Result<Option<CString>, Error> {
268+
Ok(match input {
269+
Some(input) => {
270+
let string = input.to_str().ok_or(Error::InvalidInput)?;
271+
let cstring = CString::new(string).map_err(|_| Error::InvalidInput)?;
272+
Some(cstring)
273+
}
274+
None => None,
275+
})
276+
}
277+
278+
let file_cstr = to_cstr(file)?;
279+
let file_ptr = file_cstr
280+
.as_ref()
281+
.map(|f| f.as_ptr())
282+
.unwrap_or(core::ptr::null());
283+
284+
let dir_cstr = to_cstr(dir)?;
285+
let dir_ptr = dir_cstr
286+
.as_ref()
287+
.map(|f| f.as_ptr())
288+
.unwrap_or(core::ptr::null());
289+
290+
unsafe {
291+
s2n_config_set_verification_ca_location(self.as_mut_ptr(), file_ptr, dir_ptr)
292+
.into_result()
293+
}?;
294+
Ok(self)
295+
}
296+
255297
pub fn wipe_trust_store(&mut self) -> Result<&mut Self, Error> {
256298
unsafe { s2n_config_wipe_trust_store(self.as_mut_ptr()).into_result()? };
257299
Ok(self)
@@ -267,6 +309,39 @@ impl Builder {
267309
Ok(self)
268310
}
269311

312+
/// Clients will request OCSP stapling from the server.
313+
pub fn enable_ocsp(&mut self) -> Result<&mut Self, Error> {
314+
unsafe {
315+
s2n_config_set_status_request_type(self.as_mut_ptr(), s2n_status_request_type::OCSP)
316+
.into_result()
317+
}?;
318+
Ok(self)
319+
}
320+
321+
/// Sets the OCSP data for the default certificate chain associated with the Config.
322+
///
323+
/// Servers will send the data in response to OCSP stapling requests from clients.
324+
//
325+
// NOTE: this modifies a certificate chain, NOT the Config itself. This is currently safe
326+
// because the certificate chain is set with s2n_config_add_cert_chain_and_key, which
327+
// creates a new certificate chain only accessible by the given config. It will
328+
// NOT be safe when we add support for the newer s2n_config_add_cert_chain_and_key_to_store API,
329+
// which allows certificate chains to be shared across configs.
330+
// In that case, we'll need additional guard rails either in these bindings or in the underlying C.
331+
pub fn set_ocsp_data(&mut self, data: &[u8]) -> Result<&mut Self, Error> {
332+
let size: u32 = data.len().try_into().map_err(|_| Error::InvalidInput)?;
333+
unsafe {
334+
s2n_config_set_extension_data(
335+
self.as_mut_ptr(),
336+
s2n_tls_extension_type::OCSP_STAPLING,
337+
data.as_ptr(),
338+
size,
339+
)
340+
.into_result()
341+
}?;
342+
self.enable_ocsp()
343+
}
344+
270345
/// Set a custom callback function which is run during client certificate validation during
271346
/// a mutual TLS handshake.
272347
///
@@ -358,6 +433,80 @@ impl Builder {
358433
Ok(self)
359434
}
360435

436+
/// Set a callback function that will be used to get the system time.
437+
///
438+
/// The wall clock time is the best-guess at the real time, measured since the epoch.
439+
/// Unlike monotonic time, it CAN move backwards.
440+
/// It is used by s2n-tls for timestamps.
441+
pub fn set_wall_clock<T: 'static + WallClock>(
442+
&mut self,
443+
handler: T,
444+
) -> Result<&mut Self, Error> {
445+
unsafe extern "C" fn clock_cb(
446+
context: *mut ::libc::c_void,
447+
time_in_nanos: *mut u64,
448+
) -> libc::c_int {
449+
let context = &mut *(context as *mut Context);
450+
if let Some(handler) = context.wall_clock.as_mut() {
451+
if let Ok(nanos) = handler.get_time_since_epoch().as_nanos().try_into() {
452+
*time_in_nanos = nanos;
453+
return CallbackResult::Success.into();
454+
}
455+
}
456+
CallbackResult::Failure.into()
457+
}
458+
459+
let handler = Box::new(handler);
460+
let context = self.0.context_mut();
461+
context.wall_clock = Some(handler);
462+
unsafe {
463+
s2n_config_set_wall_clock(
464+
self.as_mut_ptr(),
465+
Some(clock_cb),
466+
self.0.context_mut() as *mut _ as *mut c_void,
467+
)
468+
.into_result()?;
469+
}
470+
Ok(self)
471+
}
472+
473+
/// Set a callback function that will be used to get the monotonic time.
474+
///
475+
/// The monotonic time is the time since an arbitrary, unspecified point.
476+
/// Unlike wall clock time, it MUST never move backwards.
477+
/// It is used by s2n-tls for timers.
478+
pub fn set_monotonic_clock<T: 'static + MonotonicClock>(
479+
&mut self,
480+
handler: T,
481+
) -> Result<&mut Self, Error> {
482+
unsafe extern "C" fn clock_cb(
483+
context: *mut ::libc::c_void,
484+
time_in_nanos: *mut u64,
485+
) -> libc::c_int {
486+
let context = &mut *(context as *mut Context);
487+
if let Some(handler) = context.monotonic_clock.as_mut() {
488+
if let Ok(nanos) = handler.get_time().as_nanos().try_into() {
489+
*time_in_nanos = nanos;
490+
return CallbackResult::Success.into();
491+
}
492+
}
493+
CallbackResult::Failure.into()
494+
}
495+
496+
let handler = Box::new(handler);
497+
let context = self.0.context_mut();
498+
context.monotonic_clock = Some(handler);
499+
unsafe {
500+
s2n_config_set_monotonic_clock(
501+
self.as_mut_ptr(),
502+
Some(clock_cb),
503+
self.0.context_mut() as *mut _ as *mut c_void,
504+
)
505+
.into_result()?;
506+
}
507+
Ok(self)
508+
}
509+
361510
pub fn build(self) -> Result<Config, Error> {
362511
Ok(self.0)
363512
}
@@ -379,6 +528,8 @@ pub(crate) struct Context {
379528
refcount: AtomicUsize,
380529
pub(crate) client_hello_callback: Option<Box<dyn ClientHelloCallback>>,
381530
pub(crate) verify_host_callback: Option<Box<dyn VerifyHostNameCallback>>,
531+
pub(crate) wall_clock: Option<Box<dyn WallClock>>,
532+
pub(crate) monotonic_clock: Option<Box<dyn MonotonicClock>>,
382533
}
383534

384535
impl Default for Context {
@@ -391,6 +542,8 @@ impl Default for Context {
391542
refcount,
392543
client_hello_callback: None,
393544
verify_host_callback: None,
545+
wall_clock: None,
546+
monotonic_clock: None,
394547
}
395548
}
396549
}

0 commit comments

Comments
 (0)