Skip to content

Commit 8aa41de

Browse files
authored
ml-dsa: DigestSigner cleanup (#1073)
1 parent 0e4a329 commit 8aa41de

File tree

3 files changed

+96
-60
lines changed

3 files changed

+96
-60
lines changed

ml-dsa/src/crypto.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@ impl<Shake: ExtendableOutput + Default> Default for ShakeState<Shake> {
1818
}
1919

2020
impl<Shake: ExtendableOutput + Default + Clone> ShakeState<Shake> {
21-
pub fn pre_digest(digest: Shake) -> Self {
22-
Self::Absorbing(digest)
21+
pub fn updatable(&mut self) -> &mut Shake {
22+
match self {
23+
Self::Absorbing(sponge) => sponge,
24+
Self::Squeezing(_) => unreachable!(),
25+
}
2326
}
2427

2528
pub fn absorb(mut self, input: &[u8]) -> Self {

ml-dsa/src/lib.rs

Lines changed: 90 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,10 @@ use hybrid_array::{
5151
U75, U80, U88, Unsigned,
5252
},
5353
};
54-
use signature::digest::Update;
5554
use signature::{DigestSigner, DigestVerifier, MultipartSigner, MultipartVerifier, Signer};
5655

5756
#[cfg(feature = "rand_core")]
58-
use signature::RandomizedDigestSigner;
57+
use signature::{RandomizedDigestSigner, RandomizedMultipartSigner, RandomizedSigner};
5958

6059
#[cfg(feature = "rand_core")]
6160
use rand_core::{CryptoRng, TryCryptoRng};
@@ -171,17 +170,46 @@ where
171170
const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = P::ALGORITHM_IDENTIFIER;
172171
}
173172

174-
// This method takes a slice of slices so that we can accommodate the varying calculations (direct
175-
// for test vectors, 0... for sign/sign_deterministic, 1... for the pre-hashed version) without
176-
// having to allocate memory for components.
177-
fn message_representative(tr: &[u8], Mp: &[&[&[u8]]]) -> B64 {
178-
let mut h = H::default().absorb(tr);
173+
struct MuBuilder(H);
179174

180-
for m in Mp.iter().copied().flatten() {
181-
h = h.absorb(m);
175+
impl MuBuilder {
176+
fn new(tr: &[u8], ctx: &[u8]) -> Self {
177+
let mut h = H::default();
178+
h = h.absorb(tr);
179+
h = h.absorb(&[0]);
180+
h = h.absorb(&[Truncate::truncate(ctx.len())]);
181+
h = h.absorb(ctx);
182+
183+
Self(h)
184+
}
185+
186+
fn internal(tr: &[u8], Mp: &[&[u8]]) -> B64 {
187+
let mut h = H::default().absorb(tr);
188+
189+
for m in Mp {
190+
h = h.absorb(m);
191+
}
192+
193+
h.squeeze_new()
194+
}
195+
196+
fn message(mut self, M: &[&[u8]]) -> B64 {
197+
for m in M {
198+
self.0 = self.0.absorb(m);
199+
}
200+
201+
self.0.squeeze_new()
202+
}
203+
204+
fn finish(mut self) -> B64 {
205+
self.0.squeeze_new()
182206
}
207+
}
183208

184-
h.squeeze_new()
209+
impl AsMut<Shake256> for MuBuilder {
210+
fn as_mut(&mut self) -> &mut Shake256 {
211+
self.0.updatable()
212+
}
185213
}
186214

187215
/// An ML-DSA key pair
@@ -388,18 +416,7 @@ impl<P: MlDsaParams> SigningKey<P> {
388416
where
389417
P: MlDsaParams,
390418
{
391-
self.raw_sign_internal(&[Mp], rnd)
392-
}
393-
394-
fn raw_sign_internal(&self, Mp: &[&[&[u8]]], rnd: &B32) -> Signature<P>
395-
where
396-
P: MlDsaParams,
397-
{
398-
// Compute the message representative
399-
// XXX(RLB): This line incorporates some of the logic from ML-DSA.sign to avoid computing
400-
// the concatenated M'.
401-
// XXX(RLB) Should the API represent this as an input?
402-
let mu = message_representative(&self.tr, Mp);
419+
let mu = MuBuilder::internal(&self.tr, Mp);
403420
self.raw_sign_mu(&mu, rnd)
404421
}
405422

@@ -469,6 +486,16 @@ impl<P: MlDsaParams> SigningKey<P> {
469486
M: &[u8],
470487
ctx: &[u8],
471488
rng: &mut R,
489+
) -> Result<Signature<P>, Error> {
490+
self.raw_sign_randomized(&[M], ctx, rng)
491+
}
492+
493+
#[cfg(feature = "rand_core")]
494+
fn raw_sign_randomized<R: TryCryptoRng + ?Sized>(
495+
&self,
496+
Mp: &[&[u8]],
497+
ctx: &[u8],
498+
rng: &mut R,
472499
) -> Result<Signature<P>, Error> {
473500
if ctx.len() > 255 {
474501
return Err(Error::new());
@@ -477,8 +504,8 @@ impl<P: MlDsaParams> SigningKey<P> {
477504
let mut rnd = B32::default();
478505
rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
479506

480-
let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M];
481-
Ok(self.sign_internal(Mp, &rnd))
507+
let mu = MuBuilder::new(&self.tr, ctx).message(Mp);
508+
Ok(self.raw_sign_mu(&mu, &rnd))
482509
}
483510

484511
/// This method reflects the randomized ML-DSA.Sign algorithm with a pre-computed μ.
@@ -517,14 +544,13 @@ impl<P: MlDsaParams> SigningKey<P> {
517544
self.raw_sign_mu(mu, &rnd)
518545
}
519546

520-
fn raw_sign_deterministic(&self, M: &[&[u8]], ctx: &[u8]) -> Result<Signature<P>, Error> {
547+
fn raw_sign_deterministic(&self, Mp: &[&[u8]], ctx: &[u8]) -> Result<Signature<P>, Error> {
521548
if ctx.len() > 255 {
522549
return Err(Error::new());
523550
}
524551

525-
let rnd = B32::default();
526-
let Mp = &[&[&[0], &[Truncate::truncate(ctx.len())], ctx], M];
527-
Ok(self.raw_sign_internal(Mp, &rnd))
552+
let mu = MuBuilder::new(&self.tr, ctx).message(Mp);
553+
Ok(self.sign_mu_deterministic(&mu))
528554
}
529555

530556
/// Encode the key in a fixed-size byte array.
@@ -608,9 +634,9 @@ impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for SigningKey<P> {
608634
&self,
609635
f: F,
610636
) -> Result<Signature<P>, Error> {
611-
let mut digest = Shake256::default().chain(self.tr).chain([0, 0]);
612-
f(&mut digest)?;
613-
let mu = H::pre_digest(digest).squeeze_new();
637+
let mut mu = MuBuilder::new(&self.tr, &[]);
638+
f(mu.as_mut())?;
639+
let mu = mu.finish();
614640

615641
Ok(self.sign_mu_deterministic(&mu))
616642
}
@@ -640,13 +666,27 @@ impl<P: MlDsaParams> signature::Keypair for SigningKey<P> {
640666
/// context string. If you would like to include a context string, use the
641667
/// [`SigningKey::sign_randomized`] method.
642668
#[cfg(feature = "rand_core")]
643-
impl<P: MlDsaParams> signature::RandomizedSigner<Signature<P>> for SigningKey<P> {
669+
impl<P: MlDsaParams> RandomizedSigner<Signature<P>> for SigningKey<P> {
644670
fn try_sign_with_rng<R: TryCryptoRng + ?Sized>(
645671
&self,
646672
rng: &mut R,
647673
msg: &[u8],
648674
) -> Result<Signature<P>, Error> {
649-
self.sign_randomized(msg, &[], rng)
675+
self.try_multipart_sign_with_rng(rng, &[msg])
676+
}
677+
}
678+
679+
/// The `RandomizedSigner` implementation for `SigningKey` only supports signing with an empty
680+
/// context string. If you would like to include a context string, use the
681+
/// [`SigningKey::sign_randomized`] method.
682+
#[cfg(feature = "rand_core")]
683+
impl<P: MlDsaParams> RandomizedMultipartSigner<Signature<P>> for SigningKey<P> {
684+
fn try_multipart_sign_with_rng<R: TryCryptoRng + ?Sized>(
685+
&self,
686+
rng: &mut R,
687+
msg: &[&[u8]],
688+
) -> Result<Signature<P>, Error> {
689+
self.raw_sign_randomized(msg, &[], rng)
650690
}
651691
}
652692

@@ -663,9 +703,9 @@ impl<P: MlDsaParams> RandomizedDigestSigner<Shake256, Signature<P>> for SigningK
663703
rng: &mut R,
664704
f: F,
665705
) -> Result<Signature<P>, Error> {
666-
let mut digest = Shake256::default().chain(self.tr).chain([0, 0]);
667-
f(&mut digest)?;
668-
let mu = H::pre_digest(digest).squeeze_new();
706+
let mut mu = MuBuilder::new(&self.tr, &[]);
707+
f(mu.as_mut())?;
708+
let mu = mu.finish();
669709

670710
self.sign_mu_randomized(&mu, rng)
671711
}
@@ -736,19 +776,11 @@ impl<P: MlDsaParams> VerifyingKey<P> {
736776
/// include the domain separator that distinguishes between the normal and pre-hashed cases,
737777
/// and it does not separate the context string from the rest of the message.
738778
// Algorithm 8 ML-DSA.Verify_internal
739-
pub fn verify_internal(&self, Mp: &[&[u8]], sigma: &Signature<P>) -> bool
740-
where
741-
P: MlDsaParams,
742-
{
743-
self.raw_verify_internal(&[Mp], sigma)
744-
}
745-
746-
fn raw_verify_internal(&self, Mp: &[&[&[u8]]], sigma: &Signature<P>) -> bool
779+
pub fn verify_internal(&self, M: &[u8], sigma: &Signature<P>) -> bool
747780
where
748781
P: MlDsaParams,
749782
{
750-
// Compute the message representative
751-
let mu = message_representative(&self.tr, Mp);
783+
let mu = MuBuilder::internal(&self.tr, &[M]);
752784
self.raw_verify_mu(&mu, sigma)
753785
}
754786

@@ -793,8 +825,8 @@ impl<P: MlDsaParams> VerifyingKey<P> {
793825
return false;
794826
}
795827

796-
let Mp = &[&[&[0], &[Truncate::truncate(ctx.len())], ctx], M];
797-
self.raw_verify_internal(Mp, sigma)
828+
let mu = MuBuilder::new(&self.tr, ctx).message(M);
829+
self.verify_mu(&mu, sigma)
798830
}
799831

800832
fn encode_internal(rho: &B32, t1: &Vector<P::K>) -> EncodedVerifyingKey<P> {
@@ -837,9 +869,9 @@ impl<P: MlDsaParams> DigestVerifier<Shake256, Signature<P>> for VerifyingKey<P>
837869
f: F,
838870
signature: &Signature<P>,
839871
) -> Result<(), Error> {
840-
let mut digest = Shake256::default().chain(self.tr).chain([0, 0]);
841-
f(&mut digest)?;
842-
let mu = H::pre_digest(digest).squeeze_new();
872+
let mut mu = MuBuilder::new(&self.tr, &[]);
873+
f(mu.as_mut())?;
874+
let mu = mu.finish();
843875

844876
self.raw_verify_mu(&mu, signature)
845877
.then_some(())
@@ -1060,6 +1092,7 @@ where
10601092
mod test {
10611093
use super::*;
10621094
use crate::param::*;
1095+
use signature::digest::Update;
10631096

10641097
#[test]
10651098
fn output_sizes() {
@@ -1142,7 +1175,7 @@ mod test {
11421175
let rnd = Array([0u8; 32]);
11431176
let sig = sk.sign_internal(&[M], &rnd);
11441177

1145-
assert!(vk.verify_internal(&[M], &sig));
1178+
assert!(vk.verify_internal(M, &sig));
11461179
}
11471180

11481181
#[test]
@@ -1179,7 +1212,7 @@ mod test {
11791212
let sig_dec = Signature::<P>::decode(&sig_enc).unwrap();
11801213

11811214
assert_eq!(sig_dec, sig);
1182-
assert!(vk.verify_internal(&[M], &sig_dec));
1215+
assert!(vk.verify_internal(M, &sig_dec));
11831216
}
11841217
}
11851218

@@ -1202,7 +1235,7 @@ mod test {
12021235

12031236
let M = b"Hello world";
12041237
let rnd = Array([0u8; 32]);
1205-
let mu = message_representative(&sk.tr, &[&[M]]);
1238+
let mu = MuBuilder::internal(&sk.tr, &[M]);
12061239
let sig = sk.raw_sign_mu(&mu, &rnd);
12071240

12081241
assert!(vk.raw_verify_mu(&mu, &sig));
@@ -1224,10 +1257,10 @@ mod test {
12241257

12251258
let M = b"Hello world";
12261259
let rnd = Array([0u8; 32]);
1227-
let mu = message_representative(&sk.tr, &[&[M]]);
1260+
let mu = MuBuilder::internal(&sk.tr, &[M]);
12281261
let sig = sk.raw_sign_mu(&mu, &rnd);
12291262

1230-
assert!(vk.verify_internal(&[M], &sig));
1263+
assert!(vk.verify_internal(M, &sig));
12311264
}
12321265
sign_mu_verify_internal::<MlDsa44>();
12331266
sign_mu_verify_internal::<MlDsa65>();
@@ -1246,7 +1279,7 @@ mod test {
12461279

12471280
let M = b"Hello world";
12481281
let rnd = Array([0u8; 32]);
1249-
let mu = message_representative(&sk.tr, &[&[M]]);
1282+
let mu = MuBuilder::internal(&sk.tr, &[M]);
12501283
let sig = sk.sign_internal(&[M], &rnd);
12511284

12521285
assert!(vk.raw_verify_mu(&mu, &sig));

ml-dsa/tests/sig-ver.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ fn verify<P: MlDsaParams>(tg: &acvp::TestGroup, tc: &acvp::TestCase) {
3535

3636
// Verify the signature if it successfully decoded
3737
let test_passed = sig
38-
.map(|sig| vk.verify_internal(&[&tc.message], &sig))
38+
.map(|sig| vk.verify_internal(&tc.message, &sig))
3939
.unwrap_or_default();
4040
assert_eq!(test_passed, tc.test_passed);
4141
}

0 commit comments

Comments
 (0)