@@ -352,7 +352,7 @@ static size_t secp256k1_strauss_scratch_size(size_t n_points) {
352352 return n_points * point_size ;
353353}
354354
355- static int secp256k1_ecmult_strauss_batch (const secp256k1_callback * error_callback , secp256k1_scratch * scratch , secp256k1_gej * r , const secp256k1_scalar * inp_g_sc , secp256k1_ecmult_multi_callback cb , void * cbdata , size_t n_points , size_t cb_offset ) {
355+ static int secp256k1_ecmult_strauss_batch (const secp256k1_callback * error_callback , secp256k1_scratch * scratch , secp256k1_gej * r , secp256k1_scalar * scratch_sclrs , secp256k1_gej * scratch_pts , const secp256k1_scalar * inp_g_sc , secp256k1_ecmult_multi_callback cb , void * cbdata , size_t n_points , size_t cb_offset ) {
356356 secp256k1_gej * points ;
357357 secp256k1_scalar * scalars ;
358358 struct secp256k1_strauss_state state ;
@@ -367,8 +367,13 @@ static int secp256k1_ecmult_strauss_batch(const secp256k1_callback* error_callba
367367 /* We allocate STRAUSS_SCRATCH_OBJECTS objects on the scratch space. If these
368368 * allocations change, make sure to update the STRAUSS_SCRATCH_OBJECTS
369369 * constant and strauss_scratch_size accordingly. */
370- points = (secp256k1_gej * )secp256k1_scratch_alloc (error_callback , scratch , n_points * sizeof (secp256k1_gej ));
371- scalars = (secp256k1_scalar * )secp256k1_scratch_alloc (error_callback , scratch , n_points * sizeof (secp256k1_scalar ));
370+ if (scratch_sclrs == NULL && scratch_pts == NULL ) {
371+ points = (secp256k1_gej * )secp256k1_scratch_alloc (error_callback , scratch , n_points * sizeof (secp256k1_gej ));
372+ scalars = (secp256k1_scalar * )secp256k1_scratch_alloc (error_callback , scratch , n_points * sizeof (secp256k1_scalar ));
373+ } else {
374+ points = scratch_pts ;
375+ scalars = scratch_sclrs ;
376+ }
372377 state .aux = (secp256k1_fe * )secp256k1_scratch_alloc (error_callback , scratch , n_points * ECMULT_TABLE_SIZE (WINDOW_A ) * sizeof (secp256k1_fe ));
373378 state .pre_a = (secp256k1_ge * )secp256k1_scratch_alloc (error_callback , scratch , n_points * ECMULT_TABLE_SIZE (WINDOW_A ) * sizeof (secp256k1_ge ));
374379 state .ps = (struct secp256k1_strauss_point_state * )secp256k1_scratch_alloc (error_callback , scratch , n_points * sizeof (struct secp256k1_strauss_point_state ));
@@ -378,13 +383,15 @@ static int secp256k1_ecmult_strauss_batch(const secp256k1_callback* error_callba
378383 return 0 ;
379384 }
380385
381- for (i = 0 ; i < n_points ; i ++ ) {
382- secp256k1_ge point ;
383- if (!cb (& scalars [i ], & point , i + cb_offset , cbdata )) {
384- secp256k1_scratch_apply_checkpoint (error_callback , scratch , scratch_checkpoint );
385- return 0 ;
386+ if (scratch_sclrs == NULL && scratch_pts == NULL ) {
387+ for (i = 0 ; i < n_points ; i ++ ) {
388+ secp256k1_ge point ;
389+ if (!cb (& scalars [i ], & point , i + cb_offset , cbdata )) {
390+ secp256k1_scratch_apply_checkpoint (error_callback , scratch , scratch_checkpoint );
391+ return 0 ;
392+ }
393+ secp256k1_gej_set_ge (& points [i ], & point );
386394 }
387- secp256k1_gej_set_ge (& points [i ], & point );
388395 }
389396 secp256k1_ecmult_strauss_wnaf (& state , r , n_points , points , scalars , inp_g_sc );
390397 secp256k1_scratch_apply_checkpoint (error_callback , scratch , scratch_checkpoint );
@@ -393,7 +400,7 @@ static int secp256k1_ecmult_strauss_batch(const secp256k1_callback* error_callba
393400
394401/* Wrapper for secp256k1_ecmult_multi_func interface */
395402static int secp256k1_ecmult_strauss_batch_single (const secp256k1_callback * error_callback , secp256k1_scratch * scratch , secp256k1_gej * r , const secp256k1_scalar * inp_g_sc , secp256k1_ecmult_multi_callback cb , void * cbdata , size_t n ) {
396- return secp256k1_ecmult_strauss_batch (error_callback , scratch , r , inp_g_sc , cb , cbdata , n , 0 );
403+ return secp256k1_ecmult_strauss_batch (error_callback , scratch , r , NULL , NULL , inp_g_sc , cb , cbdata , n , 0 );
397404}
398405
399406static size_t secp256k1_strauss_max_points (const secp256k1_callback * error_callback , secp256k1_scratch * scratch ) {
@@ -838,10 +845,7 @@ static int secp256k1_ecmult_multi_var(const secp256k1_callback* error_callback,
838845 if (n_batch_points >= ECMULT_PIPPENGER_THRESHOLD ) {
839846 f = secp256k1_ecmult_pippenger_batch ;
840847 } else {
841- if (!secp256k1_ecmult_multi_batch_size_helper (& n_batches , & n_batch_points , secp256k1_strauss_max_points (error_callback , scratch ), n )) {
842- return secp256k1_ecmult_multi_simple_var (r , inp_g_sc , cb , cbdata , n );
843- }
844- f = secp256k1_ecmult_strauss_batch ;
848+ return secp256k1_ecmult_multi_simple_var (r , inp_g_sc , cb , cbdata , n );
845849 }
846850 for (i = 0 ; i < n_batches ; i ++ ) {
847851 size_t nbp = n < n_batch_points ? n : n_batch_points ;
0 commit comments