@@ -116,7 +116,7 @@ def decode_silent_payment_address(address: str, hrp: str = "tsp") -> Tuple[ECPub
116116 return B_scan , B_spend
117117
118118
119- def create_outputs (input_priv_keys : List [Tuple [ECKey , bool ]], outpoints : List [COutPoint ], recipients : List [str ], hrp = "tsp" ) -> List [str ]:
119+ def create_outputs (input_priv_keys : List [Tuple [ECKey , bool ]], outpoints : List [COutPoint ], recipients : List [str ], expected : Dict [ str , any ] = None , hrp = "tsp" ) -> List [str ]:
120120 G = ECKey ().set (1 ).get_pubkey ()
121121 negated_keys = []
122122 for key , is_xonly in input_priv_keys :
@@ -129,10 +129,17 @@ def create_outputs(input_priv_keys: List[Tuple[ECKey, bool]], outpoints: List[CO
129129 if not a_sum .valid :
130130 # Input privkeys sum is zero -> fail
131131 return []
132+ assert ECKey ().set (bytes .fromhex (expected .get ("input_private_key_sum" ))) == a_sum , "a_sum did not match expected input_private_key_sum"
133+
132134 input_hash = get_input_hash (outpoints , a_sum * G )
133135 silent_payment_groups : Dict [ECPubKey , List [ECPubKey ]] = {}
134136 for recipient in recipients :
135- B_scan , B_m = decode_silent_payment_address (recipient , hrp = hrp )
137+ B_scan , B_m = decode_silent_payment_address (recipient ["address" ], hrp = hrp )
138+ # Verify decoded intermediate keys for recipient
139+ expected_B_scan = ECPubKey ().set (bytes .fromhex (recipient ["scan_pub_key" ]))
140+ expected_B_m = ECPubKey ().set (bytes .fromhex (recipient ["spend_pub_key" ]))
141+ assert expected_B_scan == B_scan , "B_scan did not match expected recipient.scan_pub_key"
142+ assert expected_B_m == B_m , "B_m did not match expected recipient.spend_pub_key"
136143 if B_scan in silent_payment_groups :
137144 silent_payment_groups [B_scan ].append (B_m )
138145 else :
@@ -141,6 +148,16 @@ def create_outputs(input_priv_keys: List[Tuple[ECKey, bool]], outpoints: List[CO
141148 outputs = []
142149 for B_scan , B_m_values in silent_payment_groups .items ():
143150 ecdh_shared_secret = input_hash * a_sum * B_scan
151+
152+ expected_shared_secrets = expected .get ("shared_secrets" , {})
153+ # Find the recipient address that corresponds to this B_scan and get its index
154+ for recipient_idx , recipient in enumerate (recipients ):
155+ recipient_B_scan = ECPubKey ().set (bytes .fromhex (recipient ["scan_pub_key" ]))
156+ if recipient_B_scan == B_scan :
157+ expected_shared_secret_hex = expected_shared_secrets [recipient_idx ]
158+ assert ecdh_shared_secret .get_bytes (False ).hex () == expected_shared_secret_hex , f"ecdh_shared_secret did not match expected, recipient { recipient_idx } ({ recipient ['address' ]} ): expected={ expected_shared_secret_hex } "
159+ break
160+
144161 k = 0
145162 for B_m in B_m_values :
146163 t_k = TaggedHash ("BIP0352/SharedSecret" , ecdh_shared_secret .get_bytes (False ) + ser_uint32 (k ))
@@ -151,9 +168,15 @@ def create_outputs(input_priv_keys: List[Tuple[ECKey, bool]], outpoints: List[CO
151168 return list (set (outputs ))
152169
153170
154- def scanning (b_scan : ECKey , B_spend : ECPubKey , A_sum : ECPubKey , input_hash : bytes , outputs_to_check : List [ECPubKey ], labels : Dict [str , str ] = {} ) -> List [Dict [str , str ]]:
171+ def scanning (b_scan : ECKey , B_spend : ECPubKey , A_sum : ECPubKey , input_hash : bytes , outputs_to_check : List [ECPubKey ], labels : Dict [str , str ] = None , expected : Dict [ str , any ] = None ) -> List [Dict [str , str ]]:
155172 G = ECKey ().set (1 ).get_pubkey ()
173+ input_hash_key = ECKey ().set (input_hash )
174+ computed_tweak_point = input_hash_key * A_sum
175+ assert computed_tweak_point .get_bytes (False ).hex () == expected .get ("tweak" ), "tweak did not match expected"
176+
156177 ecdh_shared_secret = input_hash * b_scan * A_sum
178+ assert ecdh_shared_secret .get_bytes (False ).hex () == expected .get ("shared_secret" ), "ecdh_shared_secret did not match expected shared_secret"
179+
157180 k = 0
158181 wallet = []
159182 while True :
@@ -236,11 +259,12 @@ def scanning(b_scan: ECKey, B_spend: ECPubKey, A_sum: ECPubKey, input_hash: byte
236259 is_p2tr (vin .prevout ),
237260 ))
238261 input_pub_keys .append (pubkey )
262+ assert [pk .get_bytes (False ).hex () for pk in input_pub_keys ] == expected .get ("input_pub_keys" ), "input_pub_keys did not match expected"
239263
240264 sending_outputs = []
241265 if (len (input_pub_keys ) > 0 ):
242266 outpoints = [vin .outpoint for vin in vins ]
243- sending_outputs = create_outputs (input_priv_keys , outpoints , given ["recipients" ], hrp = "sp" )
267+ sending_outputs = create_outputs (input_priv_keys , outpoints , given ["recipients" ], expected = expected , hrp = "sp" )
244268
245269 # Note: order doesn't matter for creating/finding the outputs. However, different orderings of the recipient addresses
246270 # will produce different generated outputs if sending to multiple silent payment addresses belonging to the
@@ -303,6 +327,7 @@ def scanning(b_scan: ECKey, B_spend: ECPubKey, A_sum: ECPubKey, input_hash: byte
303327 # Input pubkeys sum is point at infinity -> skip tx
304328 assert expected ["outputs" ] == []
305329 continue
330+ assert A_sum .get_bytes (False ).hex () == expected .get ("input_pub_key_sum" ), "A_sum did not match expected input_pub_key_sum"
306331 input_hash = get_input_hash ([vin .outpoint for vin in vins ], A_sum )
307332 pre_computed_labels = {
308333 (generate_label (b_scan , label ) * G ).get_bytes (False ).hex (): generate_label (b_scan , label ).hex ()
@@ -315,6 +340,7 @@ def scanning(b_scan: ECKey, B_spend: ECPubKey, A_sum: ECPubKey, input_hash: byte
315340 input_hash = input_hash ,
316341 outputs_to_check = outputs_to_check ,
317342 labels = pre_computed_labels ,
343+ expected = expected ,
318344 )
319345
320346 # Check that the private key is correct for the found output public key
0 commit comments