@@ -51,11 +51,10 @@ use hybrid_array::{
5151 U75 , U80 , U88 , Unsigned ,
5252 } ,
5353} ;
54- use signature:: digest:: Update ;
5554use signature:: { DigestSigner , DigestVerifier , MultipartSigner , MultipartVerifier , Signer } ;
5655
5756#[ cfg( feature = "rand_core" ) ]
58- use signature:: RandomizedDigestSigner ;
57+ use signature:: { RandomizedDigestSigner , RandomizedMultipartSigner , RandomizedSigner } ;
5958
6059#[ cfg( feature = "rand_core" ) ]
6160use rand_core:: { CryptoRng , TryCryptoRng } ;
@@ -171,17 +170,46 @@ where
171170 const ALGORITHM_IDENTIFIER : AlgorithmIdentifierRef < ' static > = P :: ALGORITHM_IDENTIFIER ;
172171}
173172
174- // This method takes a slice of slices so that we can accommodate the varying calculations (direct
175- // for test vectors, 0... for sign/sign_deterministic, 1... for the pre-hashed version) without
176- // having to allocate memory for components.
177- fn message_representative ( tr : & [ u8 ] , Mp : & [ & [ & [ u8 ] ] ] ) -> B64 {
178- let mut h = H :: default ( ) . absorb ( tr) ;
173+ struct MuBuilder ( H ) ;
179174
180- for m in Mp . iter ( ) . copied ( ) . flatten ( ) {
181- h = h. absorb ( m) ;
175+ impl MuBuilder {
176+ fn new ( tr : & [ u8 ] , ctx : & [ u8 ] ) -> Self {
177+ let mut h = H :: default ( ) ;
178+ h = h. absorb ( tr) ;
179+ h = h. absorb ( & [ 0 ] ) ;
180+ h = h. absorb ( & [ Truncate :: truncate ( ctx. len ( ) ) ] ) ;
181+ h = h. absorb ( ctx) ;
182+
183+ Self ( h)
184+ }
185+
186+ fn internal ( tr : & [ u8 ] , Mp : & [ & [ u8 ] ] ) -> B64 {
187+ let mut h = H :: default ( ) . absorb ( tr) ;
188+
189+ for m in Mp {
190+ h = h. absorb ( m) ;
191+ }
192+
193+ h. squeeze_new ( )
182194 }
183195
184- h. squeeze_new ( )
196+ fn message ( mut self , M : & [ & [ u8 ] ] ) -> B64 {
197+ for m in M {
198+ self . 0 = self . 0 . absorb ( m) ;
199+ }
200+
201+ self . 0 . squeeze_new ( )
202+ }
203+
204+ fn finish ( mut self ) -> B64 {
205+ self . 0 . squeeze_new ( )
206+ }
207+ }
208+
209+ impl AsMut < Shake256 > for MuBuilder {
210+ fn as_mut ( & mut self ) -> & mut Shake256 {
211+ self . 0 . updatable ( )
212+ }
185213}
186214
187215/// An ML-DSA key pair
@@ -388,18 +416,7 @@ impl<P: MlDsaParams> SigningKey<P> {
388416 where
389417 P : MlDsaParams ,
390418 {
391- self . raw_sign_internal ( & [ Mp ] , rnd)
392- }
393-
394- fn raw_sign_internal ( & self , Mp : & [ & [ & [ u8 ] ] ] , rnd : & B32 ) -> Signature < P >
395- where
396- P : MlDsaParams ,
397- {
398- // Compute the message representative
399- // XXX(RLB): This line incorporates some of the logic from ML-DSA.sign to avoid computing
400- // the concatenated M'.
401- // XXX(RLB) Should the API represent this as an input?
402- let mu = message_representative ( & self . tr , Mp ) ;
419+ let mu = MuBuilder :: internal ( & self . tr , Mp ) ;
403420 self . raw_sign_mu ( & mu, rnd)
404421 }
405422
@@ -469,6 +486,15 @@ impl<P: MlDsaParams> SigningKey<P> {
469486 M : & [ u8 ] ,
470487 ctx : & [ u8 ] ,
471488 rng : & mut R ,
489+ ) -> Result < Signature < P > , Error > {
490+ self . raw_sign_randomized ( & [ M ] , ctx, rng)
491+ }
492+
493+ fn raw_sign_randomized < R : TryCryptoRng + ?Sized > (
494+ & self ,
495+ Mp : & [ & [ u8 ] ] ,
496+ ctx : & [ u8 ] ,
497+ rng : & mut R ,
472498 ) -> Result < Signature < P > , Error > {
473499 if ctx. len ( ) > 255 {
474500 return Err ( Error :: new ( ) ) ;
@@ -477,8 +503,8 @@ impl<P: MlDsaParams> SigningKey<P> {
477503 let mut rnd = B32 :: default ( ) ;
478504 rng. try_fill_bytes ( & mut rnd) . map_err ( |_| Error :: new ( ) ) ?;
479505
480- let Mp = & [ & [ 0 ] , & [ Truncate :: truncate ( ctx . len ( ) ) ] , ctx, M ] ;
481- Ok ( self . sign_internal ( Mp , & rnd) )
506+ let mu = MuBuilder :: new ( & self . tr , ctx) . message ( Mp ) ;
507+ Ok ( self . raw_sign_mu ( & mu , & rnd) )
482508 }
483509
484510 /// This method reflects the randomized ML-DSA.Sign algorithm with a pre-computed μ.
@@ -517,14 +543,13 @@ impl<P: MlDsaParams> SigningKey<P> {
517543 self . raw_sign_mu ( mu, & rnd)
518544 }
519545
520- fn raw_sign_deterministic ( & self , M : & [ & [ u8 ] ] , ctx : & [ u8 ] ) -> Result < Signature < P > , Error > {
546+ fn raw_sign_deterministic ( & self , Mp : & [ & [ u8 ] ] , ctx : & [ u8 ] ) -> Result < Signature < P > , Error > {
521547 if ctx. len ( ) > 255 {
522548 return Err ( Error :: new ( ) ) ;
523549 }
524550
525- let rnd = B32 :: default ( ) ;
526- let Mp = & [ & [ & [ 0 ] , & [ Truncate :: truncate ( ctx. len ( ) ) ] , ctx] , M ] ;
527- Ok ( self . raw_sign_internal ( Mp , & rnd) )
551+ let mu = MuBuilder :: new ( & self . tr , ctx) . message ( Mp ) ;
552+ Ok ( self . sign_mu_deterministic ( & mu) )
528553 }
529554
530555 /// Encode the key in a fixed-size byte array.
@@ -608,9 +633,9 @@ impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for SigningKey<P> {
608633 & self ,
609634 f : F ,
610635 ) -> Result < Signature < P > , Error > {
611- let mut digest = Shake256 :: default ( ) . chain ( self . tr ) . chain ( [ 0 , 0 ] ) ;
612- f ( & mut digest ) ?;
613- let mu = H :: pre_digest ( digest ) . squeeze_new ( ) ;
636+ let mut mu = MuBuilder :: new ( & self . tr , & [ ] ) ;
637+ f ( mu . as_mut ( ) ) ?;
638+ let mu = mu . finish ( ) ;
614639
615640 Ok ( self . sign_mu_deterministic ( & mu) )
616641 }
@@ -640,13 +665,27 @@ impl<P: MlDsaParams> signature::Keypair for SigningKey<P> {
640665/// context string. If you would like to include a context string, use the
641666/// [`SigningKey::sign_randomized`] method.
642667#[ cfg( feature = "rand_core" ) ]
643- impl < P : MlDsaParams > signature :: RandomizedSigner < Signature < P > > for SigningKey < P > {
668+ impl < P : MlDsaParams > RandomizedSigner < Signature < P > > for SigningKey < P > {
644669 fn try_sign_with_rng < R : TryCryptoRng + ?Sized > (
645670 & self ,
646671 rng : & mut R ,
647672 msg : & [ u8 ] ,
648673 ) -> Result < Signature < P > , Error > {
649- self . sign_randomized ( msg, & [ ] , rng)
674+ self . try_multipart_sign_with_rng ( rng, & [ msg] )
675+ }
676+ }
677+
678+ /// The `RandomizedSigner` implementation for `SigningKey` only supports signing with an empty
679+ /// context string. If you would like to include a context string, use the
680+ /// [`SigningKey::sign_randomized`] method.
681+ #[ cfg( feature = "rand_core" ) ]
682+ impl < P : MlDsaParams > RandomizedMultipartSigner < Signature < P > > for SigningKey < P > {
683+ fn try_multipart_sign_with_rng < R : TryCryptoRng + ?Sized > (
684+ & self ,
685+ rng : & mut R ,
686+ msg : & [ & [ u8 ] ] ,
687+ ) -> Result < Signature < P > , Error > {
688+ self . raw_sign_randomized ( msg, & [ ] , rng)
650689 }
651690}
652691
@@ -663,9 +702,9 @@ impl<P: MlDsaParams> RandomizedDigestSigner<Shake256, Signature<P>> for SigningK
663702 rng : & mut R ,
664703 f : F ,
665704 ) -> Result < Signature < P > , Error > {
666- let mut digest = Shake256 :: default ( ) . chain ( self . tr ) . chain ( [ 0 , 0 ] ) ;
667- f ( & mut digest ) ?;
668- let mu = H :: pre_digest ( digest ) . squeeze_new ( ) ;
705+ let mut mu = MuBuilder :: new ( & self . tr , & [ ] ) ;
706+ f ( mu . as_mut ( ) ) ?;
707+ let mu = mu . finish ( ) ;
669708
670709 self . sign_mu_randomized ( & mu, rng)
671710 }
@@ -736,19 +775,11 @@ impl<P: MlDsaParams> VerifyingKey<P> {
736775 /// include the domain separator that distinguishes between the normal and pre-hashed cases,
737776 /// and it does not separate the context string from the rest of the message.
738777 // Algorithm 8 ML-DSA.Verify_internal
739- pub fn verify_internal ( & self , Mp : & [ & [ u8 ] ] , sigma : & Signature < P > ) -> bool
740- where
741- P : MlDsaParams ,
742- {
743- self . raw_verify_internal ( & [ Mp ] , sigma)
744- }
745-
746- fn raw_verify_internal ( & self , Mp : & [ & [ & [ u8 ] ] ] , sigma : & Signature < P > ) -> bool
778+ pub fn verify_internal ( & self , M : & [ u8 ] , sigma : & Signature < P > ) -> bool
747779 where
748780 P : MlDsaParams ,
749781 {
750- // Compute the message representative
751- let mu = message_representative ( & self . tr , Mp ) ;
782+ let mu = MuBuilder :: internal ( & self . tr , & [ M ] ) ;
752783 self . raw_verify_mu ( & mu, sigma)
753784 }
754785
@@ -793,8 +824,8 @@ impl<P: MlDsaParams> VerifyingKey<P> {
793824 return false ;
794825 }
795826
796- let Mp = & [ & [ & [ 0 ] , & [ Truncate :: truncate ( ctx . len ( ) ) ] , ctx] , M ] ;
797- self . raw_verify_internal ( Mp , sigma)
827+ let mu = MuBuilder :: new ( & self . tr , ctx) . message ( M ) ;
828+ self . verify_mu ( & mu , sigma)
798829 }
799830
800831 fn encode_internal ( rho : & B32 , t1 : & Vector < P :: K > ) -> EncodedVerifyingKey < P > {
@@ -837,9 +868,9 @@ impl<P: MlDsaParams> DigestVerifier<Shake256, Signature<P>> for VerifyingKey<P>
837868 f : F ,
838869 signature : & Signature < P > ,
839870 ) -> Result < ( ) , Error > {
840- let mut digest = Shake256 :: default ( ) . chain ( self . tr ) . chain ( [ 0 , 0 ] ) ;
841- f ( & mut digest ) ?;
842- let mu = H :: pre_digest ( digest ) . squeeze_new ( ) ;
871+ let mut mu = MuBuilder :: new ( & self . tr , & [ ] ) ;
872+ f ( mu . as_mut ( ) ) ?;
873+ let mu = mu . finish ( ) ;
843874
844875 self . raw_verify_mu ( & mu, signature)
845876 . then_some ( ( ) )
@@ -1060,6 +1091,7 @@ where
10601091mod test {
10611092 use super :: * ;
10621093 use crate :: param:: * ;
1094+ use signature:: digest:: Update ;
10631095
10641096 #[ test]
10651097 fn output_sizes ( ) {
@@ -1142,7 +1174,7 @@ mod test {
11421174 let rnd = Array ( [ 0u8 ; 32 ] ) ;
11431175 let sig = sk. sign_internal ( & [ M ] , & rnd) ;
11441176
1145- assert ! ( vk. verify_internal( & [ M ] , & sig) ) ;
1177+ assert ! ( vk. verify_internal( M , & sig) ) ;
11461178 }
11471179
11481180 #[ test]
@@ -1179,7 +1211,7 @@ mod test {
11791211 let sig_dec = Signature :: < P > :: decode ( & sig_enc) . unwrap ( ) ;
11801212
11811213 assert_eq ! ( sig_dec, sig) ;
1182- assert ! ( vk. verify_internal( & [ M ] , & sig_dec) ) ;
1214+ assert ! ( vk. verify_internal( M , & sig_dec) ) ;
11831215 }
11841216 }
11851217
@@ -1202,7 +1234,7 @@ mod test {
12021234
12031235 let M = b"Hello world" ;
12041236 let rnd = Array ( [ 0u8 ; 32 ] ) ;
1205- let mu = message_representative ( & sk. tr , & [ & [ M ] ] ) ;
1237+ let mu = MuBuilder :: internal ( & sk. tr , & [ M ] ) ;
12061238 let sig = sk. raw_sign_mu ( & mu, & rnd) ;
12071239
12081240 assert ! ( vk. raw_verify_mu( & mu, & sig) ) ;
@@ -1224,10 +1256,10 @@ mod test {
12241256
12251257 let M = b"Hello world" ;
12261258 let rnd = Array ( [ 0u8 ; 32 ] ) ;
1227- let mu = message_representative ( & sk. tr , & [ & [ M ] ] ) ;
1259+ let mu = MuBuilder :: internal ( & sk. tr , & [ M ] ) ;
12281260 let sig = sk. raw_sign_mu ( & mu, & rnd) ;
12291261
1230- assert ! ( vk. verify_internal( & [ M ] , & sig) ) ;
1262+ assert ! ( vk. verify_internal( M , & sig) ) ;
12311263 }
12321264 sign_mu_verify_internal :: < MlDsa44 > ( ) ;
12331265 sign_mu_verify_internal :: < MlDsa65 > ( ) ;
@@ -1246,7 +1278,7 @@ mod test {
12461278
12471279 let M = b"Hello world" ;
12481280 let rnd = Array ( [ 0u8 ; 32 ] ) ;
1249- let mu = message_representative ( & sk. tr , & [ & [ M ] ] ) ;
1281+ let mu = MuBuilder :: internal ( & sk. tr , & [ M ] ) ;
12501282 let sig = sk. sign_internal ( & [ M ] , & rnd) ;
12511283
12521284 assert ! ( vk. raw_verify_mu( & mu, & sig) ) ;
0 commit comments