Skip to content

Commit 715aa06

Browse files
committed
ML-DSA DigestSigner cleanup
1 parent b648424 commit 715aa06

File tree

3 files changed

+95
-60
lines changed

3 files changed

+95
-60
lines changed

ml-dsa/src/crypto.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@ impl<Shake: ExtendableOutput + Default> Default for ShakeState<Shake> {
1818
}
1919

2020
impl<Shake: ExtendableOutput + Default + Clone> ShakeState<Shake> {
21-
pub fn pre_digest(digest: Shake) -> Self {
22-
Self::Absorbing(digest)
21+
pub fn updatable(&mut self) -> &mut Shake {
22+
match self {
23+
Self::Absorbing(sponge) => sponge,
24+
Self::Squeezing(_) => unreachable!(),
25+
}
2326
}
2427

2528
pub fn absorb(mut self, input: &[u8]) -> Self {

ml-dsa/src/lib.rs

Lines changed: 89 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,10 @@ use hybrid_array::{
5151
U75, U80, U88, Unsigned,
5252
},
5353
};
54-
use signature::digest::Update;
5554
use signature::{DigestSigner, DigestVerifier, MultipartSigner, MultipartVerifier, Signer};
5655

5756
#[cfg(feature = "rand_core")]
58-
use signature::RandomizedDigestSigner;
57+
use signature::{RandomizedDigestSigner, RandomizedMultipartSigner, RandomizedSigner};
5958

6059
#[cfg(feature = "rand_core")]
6160
use rand_core::{CryptoRng, TryCryptoRng};
@@ -171,17 +170,46 @@ where
171170
const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = P::ALGORITHM_IDENTIFIER;
172171
}
173172

174-
// This method takes a slice of slices so that we can accommodate the varying calculations (direct
175-
// for test vectors, 0... for sign/sign_deterministic, 1... for the pre-hashed version) without
176-
// having to allocate memory for components.
177-
fn message_representative(tr: &[u8], Mp: &[&[&[u8]]]) -> B64 {
178-
let mut h = H::default().absorb(tr);
173+
struct MuBuilder(H);
179174

180-
for m in Mp.iter().copied().flatten() {
181-
h = h.absorb(m);
175+
impl MuBuilder {
176+
fn new(tr: &[u8], ctx: &[u8]) -> Self {
177+
let mut h = H::default();
178+
h = h.absorb(tr);
179+
h = h.absorb(&[0]);
180+
h = h.absorb(&[Truncate::truncate(ctx.len())]);
181+
h = h.absorb(ctx);
182+
183+
Self(h)
184+
}
185+
186+
fn internal(tr: &[u8], Mp: &[&[u8]]) -> B64 {
187+
let mut h = H::default().absorb(tr);
188+
189+
for m in Mp {
190+
h = h.absorb(m);
191+
}
192+
193+
h.squeeze_new()
182194
}
183195

184-
h.squeeze_new()
196+
fn message(mut self, M: &[&[u8]]) -> B64 {
197+
for m in M {
198+
self.0 = self.0.absorb(m);
199+
}
200+
201+
self.0.squeeze_new()
202+
}
203+
204+
fn finish(mut self) -> B64 {
205+
self.0.squeeze_new()
206+
}
207+
}
208+
209+
impl AsMut<Shake256> for MuBuilder {
210+
fn as_mut(&mut self) -> &mut Shake256 {
211+
self.0.updatable()
212+
}
185213
}
186214

187215
/// An ML-DSA key pair
@@ -388,18 +416,7 @@ impl<P: MlDsaParams> SigningKey<P> {
388416
where
389417
P: MlDsaParams,
390418
{
391-
self.raw_sign_internal(&[Mp], rnd)
392-
}
393-
394-
fn raw_sign_internal(&self, Mp: &[&[&[u8]]], rnd: &B32) -> Signature<P>
395-
where
396-
P: MlDsaParams,
397-
{
398-
// Compute the message representative
399-
// XXX(RLB): This line incorporates some of the logic from ML-DSA.sign to avoid computing
400-
// the concatenated M'.
401-
// XXX(RLB) Should the API represent this as an input?
402-
let mu = message_representative(&self.tr, Mp);
419+
let mu = MuBuilder::internal(&self.tr, Mp);
403420
self.raw_sign_mu(&mu, rnd)
404421
}
405422

@@ -469,6 +486,15 @@ impl<P: MlDsaParams> SigningKey<P> {
469486
M: &[u8],
470487
ctx: &[u8],
471488
rng: &mut R,
489+
) -> Result<Signature<P>, Error> {
490+
self.raw_sign_randomized(&[M], ctx, rng)
491+
}
492+
493+
fn raw_sign_randomized<R: TryCryptoRng + ?Sized>(
494+
&self,
495+
Mp: &[&[u8]],
496+
ctx: &[u8],
497+
rng: &mut R,
472498
) -> Result<Signature<P>, Error> {
473499
if ctx.len() > 255 {
474500
return Err(Error::new());
@@ -477,8 +503,8 @@ impl<P: MlDsaParams> SigningKey<P> {
477503
let mut rnd = B32::default();
478504
rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
479505

480-
let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M];
481-
Ok(self.sign_internal(Mp, &rnd))
506+
let mu = MuBuilder::new(&self.tr, ctx).message(Mp);
507+
Ok(self.raw_sign_mu(&mu, &rnd))
482508
}
483509

484510
/// This method reflects the randomized ML-DSA.Sign algorithm with a pre-computed μ.
@@ -517,14 +543,13 @@ impl<P: MlDsaParams> SigningKey<P> {
517543
self.raw_sign_mu(mu, &rnd)
518544
}
519545

520-
fn raw_sign_deterministic(&self, M: &[&[u8]], ctx: &[u8]) -> Result<Signature<P>, Error> {
546+
fn raw_sign_deterministic(&self, Mp: &[&[u8]], ctx: &[u8]) -> Result<Signature<P>, Error> {
521547
if ctx.len() > 255 {
522548
return Err(Error::new());
523549
}
524550

525-
let rnd = B32::default();
526-
let Mp = &[&[&[0], &[Truncate::truncate(ctx.len())], ctx], M];
527-
Ok(self.raw_sign_internal(Mp, &rnd))
551+
let mu = MuBuilder::new(&self.tr, ctx).message(Mp);
552+
Ok(self.sign_mu_deterministic(&mu))
528553
}
529554

530555
/// Encode the key in a fixed-size byte array.
@@ -608,9 +633,9 @@ impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for SigningKey<P> {
608633
&self,
609634
f: F,
610635
) -> Result<Signature<P>, Error> {
611-
let mut digest = Shake256::default().chain(self.tr).chain([0, 0]);
612-
f(&mut digest)?;
613-
let mu = H::pre_digest(digest).squeeze_new();
636+
let mut mu = MuBuilder::new(&self.tr, &[]);
637+
f(mu.as_mut())?;
638+
let mu = mu.finish();
614639

615640
Ok(self.sign_mu_deterministic(&mu))
616641
}
@@ -640,13 +665,27 @@ impl<P: MlDsaParams> signature::Keypair for SigningKey<P> {
640665
/// context string. If you would like to include a context string, use the
641666
/// [`SigningKey::sign_randomized`] method.
642667
#[cfg(feature = "rand_core")]
643-
impl<P: MlDsaParams> signature::RandomizedSigner<Signature<P>> for SigningKey<P> {
668+
impl<P: MlDsaParams> RandomizedSigner<Signature<P>> for SigningKey<P> {
644669
fn try_sign_with_rng<R: TryCryptoRng + ?Sized>(
645670
&self,
646671
rng: &mut R,
647672
msg: &[u8],
648673
) -> Result<Signature<P>, Error> {
649-
self.sign_randomized(msg, &[], rng)
674+
self.try_multipart_sign_with_rng(rng, &[msg])
675+
}
676+
}
677+
678+
/// The `RandomizedSigner` implementation for `SigningKey` only supports signing with an empty
679+
/// context string. If you would like to include a context string, use the
680+
/// [`SigningKey::sign_randomized`] method.
681+
#[cfg(feature = "rand_core")]
682+
impl<P: MlDsaParams> RandomizedMultipartSigner<Signature<P>> for SigningKey<P> {
683+
fn try_multipart_sign_with_rng<R: TryCryptoRng + ?Sized>(
684+
&self,
685+
rng: &mut R,
686+
msg: &[&[u8]],
687+
) -> Result<Signature<P>, Error> {
688+
self.raw_sign_randomized(msg, &[], rng)
650689
}
651690
}
652691

@@ -663,9 +702,9 @@ impl<P: MlDsaParams> RandomizedDigestSigner<Shake256, Signature<P>> for SigningK
663702
rng: &mut R,
664703
f: F,
665704
) -> Result<Signature<P>, Error> {
666-
let mut digest = Shake256::default().chain(self.tr).chain([0, 0]);
667-
f(&mut digest)?;
668-
let mu = H::pre_digest(digest).squeeze_new();
705+
let mut mu = MuBuilder::new(&self.tr, &[]);
706+
f(mu.as_mut())?;
707+
let mu = mu.finish();
669708

670709
self.sign_mu_randomized(&mu, rng)
671710
}
@@ -736,19 +775,11 @@ impl<P: MlDsaParams> VerifyingKey<P> {
736775
/// include the domain separator that distinguishes between the normal and pre-hashed cases,
737776
/// and it does not separate the context string from the rest of the message.
738777
// Algorithm 8 ML-DSA.Verify_internal
739-
pub fn verify_internal(&self, Mp: &[&[u8]], sigma: &Signature<P>) -> bool
740-
where
741-
P: MlDsaParams,
742-
{
743-
self.raw_verify_internal(&[Mp], sigma)
744-
}
745-
746-
fn raw_verify_internal(&self, Mp: &[&[&[u8]]], sigma: &Signature<P>) -> bool
778+
pub fn verify_internal(&self, M: &[u8], sigma: &Signature<P>) -> bool
747779
where
748780
P: MlDsaParams,
749781
{
750-
// Compute the message representative
751-
let mu = message_representative(&self.tr, Mp);
782+
let mu = MuBuilder::internal(&self.tr, &[M]);
752783
self.raw_verify_mu(&mu, sigma)
753784
}
754785

@@ -793,8 +824,8 @@ impl<P: MlDsaParams> VerifyingKey<P> {
793824
return false;
794825
}
795826

796-
let Mp = &[&[&[0], &[Truncate::truncate(ctx.len())], ctx], M];
797-
self.raw_verify_internal(Mp, sigma)
827+
let mu = MuBuilder::new(&self.tr, ctx).message(M);
828+
self.verify_mu(&mu, sigma)
798829
}
799830

800831
fn encode_internal(rho: &B32, t1: &Vector<P::K>) -> EncodedVerifyingKey<P> {
@@ -837,9 +868,9 @@ impl<P: MlDsaParams> DigestVerifier<Shake256, Signature<P>> for VerifyingKey<P>
837868
f: F,
838869
signature: &Signature<P>,
839870
) -> Result<(), Error> {
840-
let mut digest = Shake256::default().chain(self.tr).chain([0, 0]);
841-
f(&mut digest)?;
842-
let mu = H::pre_digest(digest).squeeze_new();
871+
let mut mu = MuBuilder::new(&self.tr, &[]);
872+
f(mu.as_mut())?;
873+
let mu = mu.finish();
843874

844875
self.raw_verify_mu(&mu, signature)
845876
.then_some(())
@@ -1060,6 +1091,7 @@ where
10601091
mod test {
10611092
use super::*;
10621093
use crate::param::*;
1094+
use signature::digest::Update;
10631095

10641096
#[test]
10651097
fn output_sizes() {
@@ -1142,7 +1174,7 @@ mod test {
11421174
let rnd = Array([0u8; 32]);
11431175
let sig = sk.sign_internal(&[M], &rnd);
11441176

1145-
assert!(vk.verify_internal(&[M], &sig));
1177+
assert!(vk.verify_internal(M, &sig));
11461178
}
11471179

11481180
#[test]
@@ -1179,7 +1211,7 @@ mod test {
11791211
let sig_dec = Signature::<P>::decode(&sig_enc).unwrap();
11801212

11811213
assert_eq!(sig_dec, sig);
1182-
assert!(vk.verify_internal(&[M], &sig_dec));
1214+
assert!(vk.verify_internal(M, &sig_dec));
11831215
}
11841216
}
11851217

@@ -1202,7 +1234,7 @@ mod test {
12021234

12031235
let M = b"Hello world";
12041236
let rnd = Array([0u8; 32]);
1205-
let mu = message_representative(&sk.tr, &[&[M]]);
1237+
let mu = MuBuilder::internal(&sk.tr, &[M]);
12061238
let sig = sk.raw_sign_mu(&mu, &rnd);
12071239

12081240
assert!(vk.raw_verify_mu(&mu, &sig));
@@ -1224,10 +1256,10 @@ mod test {
12241256

12251257
let M = b"Hello world";
12261258
let rnd = Array([0u8; 32]);
1227-
let mu = message_representative(&sk.tr, &[&[M]]);
1259+
let mu = MuBuilder::internal(&sk.tr, &[M]);
12281260
let sig = sk.raw_sign_mu(&mu, &rnd);
12291261

1230-
assert!(vk.verify_internal(&[M], &sig));
1262+
assert!(vk.verify_internal(M, &sig));
12311263
}
12321264
sign_mu_verify_internal::<MlDsa44>();
12331265
sign_mu_verify_internal::<MlDsa65>();
@@ -1246,7 +1278,7 @@ mod test {
12461278

12471279
let M = b"Hello world";
12481280
let rnd = Array([0u8; 32]);
1249-
let mu = message_representative(&sk.tr, &[&[M]]);
1281+
let mu = MuBuilder::internal(&sk.tr, &[M]);
12501282
let sig = sk.sign_internal(&[M], &rnd);
12511283

12521284
assert!(vk.raw_verify_mu(&mu, &sig));

ml-dsa/tests/sig-ver.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ fn verify<P: MlDsaParams>(tg: &acvp::TestGroup, tc: &acvp::TestCase) {
3535

3636
// Verify the signature if it successfully decoded
3737
let test_passed = sig
38-
.map(|sig| vk.verify_internal(&[&tc.message], &sig))
38+
.map(|sig| vk.verify_internal(&tc.message, &sig))
3939
.unwrap_or_default();
4040
assert_eq!(test_passed, tc.test_passed);
4141
}

0 commit comments

Comments
 (0)