Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions ml-dsa/src/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ impl<Shake: ExtendableOutput + Default> Default for ShakeState<Shake> {
}

impl<Shake: ExtendableOutput + Default + Clone> ShakeState<Shake> {
pub fn pre_digest(digest: Shake) -> Self {
Self::Absorbing(digest)
pub fn updatable(&mut self) -> &mut Shake {
match self {
Self::Absorbing(sponge) => sponge,
Self::Squeezing(_) => unreachable!(),
}
}

pub fn absorb(mut self, input: &[u8]) -> Self {
Expand Down
147 changes: 90 additions & 57 deletions ml-dsa/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,10 @@ use hybrid_array::{
U75, U80, U88, Unsigned,
},
};
use signature::digest::Update;
use signature::{DigestSigner, DigestVerifier, MultipartSigner, MultipartVerifier, Signer};

#[cfg(feature = "rand_core")]
use signature::RandomizedDigestSigner;
use signature::{RandomizedDigestSigner, RandomizedMultipartSigner, RandomizedSigner};

#[cfg(feature = "rand_core")]
use rand_core::{CryptoRng, TryCryptoRng};
Expand Down Expand Up @@ -171,17 +170,46 @@ where
const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = P::ALGORITHM_IDENTIFIER;
}

// This method takes a slice of slices so that we can accommodate the varying calculations (direct
// for test vectors, 0... for sign/sign_deterministic, 1... for the pre-hashed version) without
// having to allocate memory for components.
fn message_representative(tr: &[u8], Mp: &[&[&[u8]]]) -> B64 {
let mut h = H::default().absorb(tr);
struct MuBuilder(H);

for m in Mp.iter().copied().flatten() {
h = h.absorb(m);
impl MuBuilder {
fn new(tr: &[u8], ctx: &[u8]) -> Self {
let mut h = H::default();
h = h.absorb(tr);
h = h.absorb(&[0]);
h = h.absorb(&[Truncate::truncate(ctx.len())]);
h = h.absorb(ctx);

Self(h)
}

fn internal(tr: &[u8], Mp: &[&[u8]]) -> B64 {
let mut h = H::default().absorb(tr);

for m in Mp {
h = h.absorb(m);
}

h.squeeze_new()
}

fn message(mut self, M: &[&[u8]]) -> B64 {
for m in M {
self.0 = self.0.absorb(m);
}

self.0.squeeze_new()
}

fn finish(mut self) -> B64 {
self.0.squeeze_new()
}
}

h.squeeze_new()
impl AsMut<Shake256> for MuBuilder {
fn as_mut(&mut self) -> &mut Shake256 {
self.0.updatable()
}
}

/// An ML-DSA key pair
Expand Down Expand Up @@ -388,18 +416,7 @@ impl<P: MlDsaParams> SigningKey<P> {
where
P: MlDsaParams,
{
self.raw_sign_internal(&[Mp], rnd)
}

fn raw_sign_internal(&self, Mp: &[&[&[u8]]], rnd: &B32) -> Signature<P>
where
P: MlDsaParams,
{
// Compute the message representative
// XXX(RLB): This line incorporates some of the logic from ML-DSA.sign to avoid computing
// the concatenated M'.
// XXX(RLB) Should the API represent this as an input?
let mu = message_representative(&self.tr, Mp);
let mu = MuBuilder::internal(&self.tr, Mp);
self.raw_sign_mu(&mu, rnd)
}

Expand Down Expand Up @@ -469,6 +486,16 @@ impl<P: MlDsaParams> SigningKey<P> {
M: &[u8],
ctx: &[u8],
rng: &mut R,
) -> Result<Signature<P>, Error> {
self.raw_sign_randomized(&[M], ctx, rng)
}

#[cfg(feature = "rand_core")]
fn raw_sign_randomized<R: TryCryptoRng + ?Sized>(
&self,
Mp: &[&[u8]],
ctx: &[u8],
rng: &mut R,
) -> Result<Signature<P>, Error> {
if ctx.len() > 255 {
return Err(Error::new());
Expand All @@ -477,8 +504,8 @@ impl<P: MlDsaParams> SigningKey<P> {
let mut rnd = B32::default();
rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;

let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M];
Ok(self.sign_internal(Mp, &rnd))
let mu = MuBuilder::new(&self.tr, ctx).message(Mp);
Ok(self.raw_sign_mu(&mu, &rnd))
}

/// This method reflects the randomized ML-DSA.Sign algorithm with a pre-computed μ.
Expand Down Expand Up @@ -517,14 +544,13 @@ impl<P: MlDsaParams> SigningKey<P> {
self.raw_sign_mu(mu, &rnd)
}

fn raw_sign_deterministic(&self, M: &[&[u8]], ctx: &[u8]) -> Result<Signature<P>, Error> {
fn raw_sign_deterministic(&self, Mp: &[&[u8]], ctx: &[u8]) -> Result<Signature<P>, Error> {
if ctx.len() > 255 {
return Err(Error::new());
}

let rnd = B32::default();
let Mp = &[&[&[0], &[Truncate::truncate(ctx.len())], ctx], M];
Ok(self.raw_sign_internal(Mp, &rnd))
let mu = MuBuilder::new(&self.tr, ctx).message(Mp);
Ok(self.sign_mu_deterministic(&mu))
}

/// Encode the key in a fixed-size byte array.
Expand Down Expand Up @@ -608,9 +634,9 @@ impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for SigningKey<P> {
&self,
f: F,
) -> Result<Signature<P>, Error> {
let mut digest = Shake256::default().chain(self.tr).chain([0, 0]);
f(&mut digest)?;
let mu = H::pre_digest(digest).squeeze_new();
let mut mu = MuBuilder::new(&self.tr, &[]);
f(mu.as_mut())?;
let mu = mu.finish();

Ok(self.sign_mu_deterministic(&mu))
}
Expand Down Expand Up @@ -640,13 +666,27 @@ impl<P: MlDsaParams> signature::Keypair for SigningKey<P> {
/// context string. If you would like to include a context string, use the
/// [`SigningKey::sign_randomized`] method.
#[cfg(feature = "rand_core")]
impl<P: MlDsaParams> signature::RandomizedSigner<Signature<P>> for SigningKey<P> {
impl<P: MlDsaParams> RandomizedSigner<Signature<P>> for SigningKey<P> {
fn try_sign_with_rng<R: TryCryptoRng + ?Sized>(
&self,
rng: &mut R,
msg: &[u8],
) -> Result<Signature<P>, Error> {
self.sign_randomized(msg, &[], rng)
self.try_multipart_sign_with_rng(rng, &[msg])
}
}

/// The `RandomizedSigner` implementation for `SigningKey` only supports signing with an empty
/// context string. If you would like to include a context string, use the
/// [`SigningKey::sign_randomized`] method.
#[cfg(feature = "rand_core")]
impl<P: MlDsaParams> RandomizedMultipartSigner<Signature<P>> for SigningKey<P> {
fn try_multipart_sign_with_rng<R: TryCryptoRng + ?Sized>(
&self,
rng: &mut R,
msg: &[&[u8]],
) -> Result<Signature<P>, Error> {
self.raw_sign_randomized(msg, &[], rng)
}
}

Expand All @@ -663,9 +703,9 @@ impl<P: MlDsaParams> RandomizedDigestSigner<Shake256, Signature<P>> for SigningK
rng: &mut R,
f: F,
) -> Result<Signature<P>, Error> {
let mut digest = Shake256::default().chain(self.tr).chain([0, 0]);
f(&mut digest)?;
let mu = H::pre_digest(digest).squeeze_new();
let mut mu = MuBuilder::new(&self.tr, &[]);
f(mu.as_mut())?;
let mu = mu.finish();

self.sign_mu_randomized(&mu, rng)
}
Expand Down Expand Up @@ -736,19 +776,11 @@ impl<P: MlDsaParams> VerifyingKey<P> {
/// include the domain separator that distinguishes between the normal and pre-hashed cases,
/// and it does not separate the context string from the rest of the message.
// Algorithm 8 ML-DSA.Verify_internal
pub fn verify_internal(&self, Mp: &[&[u8]], sigma: &Signature<P>) -> bool
where
P: MlDsaParams,
{
self.raw_verify_internal(&[Mp], sigma)
}

fn raw_verify_internal(&self, Mp: &[&[&[u8]]], sigma: &Signature<P>) -> bool
pub fn verify_internal(&self, M: &[u8], sigma: &Signature<P>) -> bool
where
P: MlDsaParams,
{
// Compute the message representative
let mu = message_representative(&self.tr, Mp);
let mu = MuBuilder::internal(&self.tr, &[M]);
self.raw_verify_mu(&mu, sigma)
}

Expand Down Expand Up @@ -793,8 +825,8 @@ impl<P: MlDsaParams> VerifyingKey<P> {
return false;
}

let Mp = &[&[&[0], &[Truncate::truncate(ctx.len())], ctx], M];
self.raw_verify_internal(Mp, sigma)
let mu = MuBuilder::new(&self.tr, ctx).message(M);
self.verify_mu(&mu, sigma)
}

fn encode_internal(rho: &B32, t1: &Vector<P::K>) -> EncodedVerifyingKey<P> {
Expand Down Expand Up @@ -837,9 +869,9 @@ impl<P: MlDsaParams> DigestVerifier<Shake256, Signature<P>> for VerifyingKey<P>
f: F,
signature: &Signature<P>,
) -> Result<(), Error> {
let mut digest = Shake256::default().chain(self.tr).chain([0, 0]);
f(&mut digest)?;
let mu = H::pre_digest(digest).squeeze_new();
let mut mu = MuBuilder::new(&self.tr, &[]);
f(mu.as_mut())?;
let mu = mu.finish();

self.raw_verify_mu(&mu, signature)
.then_some(())
Expand Down Expand Up @@ -1060,6 +1092,7 @@ where
mod test {
use super::*;
use crate::param::*;
use signature::digest::Update;

#[test]
fn output_sizes() {
Expand Down Expand Up @@ -1142,7 +1175,7 @@ mod test {
let rnd = Array([0u8; 32]);
let sig = sk.sign_internal(&[M], &rnd);

assert!(vk.verify_internal(&[M], &sig));
assert!(vk.verify_internal(M, &sig));
}

#[test]
Expand Down Expand Up @@ -1179,7 +1212,7 @@ mod test {
let sig_dec = Signature::<P>::decode(&sig_enc).unwrap();

assert_eq!(sig_dec, sig);
assert!(vk.verify_internal(&[M], &sig_dec));
assert!(vk.verify_internal(M, &sig_dec));
}
}

Expand All @@ -1202,7 +1235,7 @@ mod test {

let M = b"Hello world";
let rnd = Array([0u8; 32]);
let mu = message_representative(&sk.tr, &[&[M]]);
let mu = MuBuilder::internal(&sk.tr, &[M]);
let sig = sk.raw_sign_mu(&mu, &rnd);

assert!(vk.raw_verify_mu(&mu, &sig));
Expand All @@ -1224,10 +1257,10 @@ mod test {

let M = b"Hello world";
let rnd = Array([0u8; 32]);
let mu = message_representative(&sk.tr, &[&[M]]);
let mu = MuBuilder::internal(&sk.tr, &[M]);
let sig = sk.raw_sign_mu(&mu, &rnd);

assert!(vk.verify_internal(&[M], &sig));
assert!(vk.verify_internal(M, &sig));
}
sign_mu_verify_internal::<MlDsa44>();
sign_mu_verify_internal::<MlDsa65>();
Expand All @@ -1246,7 +1279,7 @@ mod test {

let M = b"Hello world";
let rnd = Array([0u8; 32]);
let mu = message_representative(&sk.tr, &[&[M]]);
let mu = MuBuilder::internal(&sk.tr, &[M]);
let sig = sk.sign_internal(&[M], &rnd);

assert!(vk.raw_verify_mu(&mu, &sig));
Expand Down
2 changes: 1 addition & 1 deletion ml-dsa/tests/sig-ver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn verify<P: MlDsaParams>(tg: &acvp::TestGroup, tc: &acvp::TestCase) {

// Verify the signature if it successfully decoded
let test_passed = sig
.map(|sig| vk.verify_internal(&[&tc.message], &sig))
.map(|sig| vk.verify_internal(&tc.message, &sig))
.unwrap_or_default();
assert_eq!(test_passed, tc.test_passed);
}
Expand Down