@@ -109,6 +109,10 @@ where
109109#[ cfg( test) ]
110110mod tests {
111111 use super :: * ;
112+ use jsonrpc_core:: { Call , Value } ;
113+ use std:: future:: Future ;
114+ use std:: pin:: Pin ;
115+ use web3:: RequestId ;
112116
113117 #[ test]
114118 fn test_caip2_parsing ( ) {
@@ -128,4 +132,121 @@ mod tests {
128132 "000000000019d6689c085ae165831e93"
129133 ) ;
130134 }
135+
136+ // Mock transport that returns a predefined chain ID
137+ #[ derive( Debug , Clone ) ]
138+ struct MockTransport {
139+ chain_id_response : String ,
140+ }
141+
142+ impl web3:: Transport for MockTransport {
143+ type Out = Pin < Box < dyn Future < Output = Result < Value , web3:: Error > > > > ;
144+
145+ fn prepare ( & self , method : & str , params : Vec < Value > ) -> ( RequestId , Call ) {
146+ let call = Call :: MethodCall ( jsonrpc_core:: MethodCall {
147+ jsonrpc : Some ( jsonrpc_core:: Version :: V2 ) ,
148+ method : method. to_string ( ) ,
149+ params : jsonrpc_core:: Params :: Array ( params) ,
150+ id : jsonrpc_core:: Id :: Num ( 1 ) ,
151+ } ) ;
152+ ( 1 , call)
153+ }
154+
155+ fn send ( & self , _id : RequestId , request : Call ) -> Self :: Out {
156+ let response = match request {
157+ Call :: MethodCall ( ref call) if call. method == "eth_chainId" => {
158+ Ok ( Value :: String ( self . chain_id_response . clone ( ) ) )
159+ }
160+ Call :: MethodCall ( ref call) => Err ( web3:: Error :: Decoder ( format ! (
161+ "Unexpected method: {}" ,
162+ call. method
163+ ) ) ) ,
164+ _ => Err ( web3:: Error :: Decoder ( "Invalid request type" . to_string ( ) ) ) ,
165+ } ;
166+
167+ Box :: pin ( futures:: future:: ready ( response) )
168+ }
169+ }
170+
171+ #[ tokio:: test]
172+ async fn test_validate_chain_id_success ( ) {
173+ let mock_transport = MockTransport {
174+ chain_id_response : "0x1" . to_string ( ) ,
175+ } ;
176+ let web3 = Web3 :: new ( mock_transport) ;
177+ let chain_id = Caip2ChainId :: from_str ( "eip155:1" ) . unwrap ( ) ;
178+
179+ let result = validate_chain_id ( & web3, & chain_id, "http://test.com" ) . await ;
180+ assert ! ( result. is_ok( ) ) ;
181+ }
182+
183+ #[ tokio:: test]
184+ async fn test_validate_chain_id_mismatch ( ) {
185+ let mock_transport = MockTransport {
186+ chain_id_response : "0x1" . to_string ( ) , // Returns mainnet (1)
187+ } ;
188+ let web3 = Web3 :: new ( mock_transport) ;
189+ let chain_id = Caip2ChainId :: from_str ( "eip155:42161" ) . unwrap ( ) ; // Expects Arbitrum (42161)
190+
191+ let result = validate_chain_id ( & web3, & chain_id, "http://test.com" ) . await ;
192+ assert ! ( result. is_err( ) ) ;
193+ let err_msg = result. unwrap_err ( ) . to_string ( ) ;
194+ assert ! ( err_msg. contains( "Chain ID mismatch" ) ) ;
195+ assert ! ( err_msg. contains( "returned chain ID 1" ) ) ;
196+ assert ! ( err_msg. contains( "expected 42161" ) ) ;
197+ }
198+
199+ #[ tokio:: test]
200+ async fn test_validate_chain_id_hex_variations ( ) {
201+ // Test with different hex formats
202+ let test_cases = vec ! [
203+ ( "0x1" , 1 ) , // 0x1
204+ ( "0x01" , 1 ) , // 0x01
205+ ( "0xa4b1" , 42161 ) , // 0xa4b1 (Arbitrum)
206+ ( "0xaa36a7" , 11155111 ) , // Sepolia
207+ ] ;
208+
209+ for ( hex_response, expected_id) in test_cases {
210+ let mock_transport = MockTransport {
211+ chain_id_response : hex_response. to_string ( ) ,
212+ } ;
213+ let web3 = Web3 :: new ( mock_transport) ;
214+ let chain_id = Caip2ChainId :: from_str ( & format ! ( "eip155:{}" , expected_id) ) . unwrap ( ) ;
215+
216+ let result = validate_chain_id ( & web3, & chain_id, "http://test.com" ) . await ;
217+ assert ! (
218+ result. is_ok( ) ,
219+ "Failed for hex {} expecting {}" ,
220+ hex_response,
221+ expected_id
222+ ) ;
223+ }
224+ }
225+
226+ #[ tokio:: test]
227+ async fn test_validate_chain_id_skips_non_evm ( ) {
228+ // Non-EVM chains should be skipped
229+ let mock_transport = MockTransport {
230+ chain_id_response : "should_not_be_called" . to_string ( ) ,
231+ } ;
232+ let web3 = Web3 :: new ( mock_transport) ;
233+ let chain_id = Caip2ChainId :: from_str ( "bip122:000000000019d6689c085ae165831e93" ) . unwrap ( ) ;
234+
235+ let result = validate_chain_id ( & web3, & chain_id, "http://test.com" ) . await ;
236+ assert ! ( result. is_ok( ) ) ;
237+ }
238+
239+ #[ tokio:: test]
240+ async fn test_validate_chain_id_invalid_hex ( ) {
241+ let mock_transport = MockTransport {
242+ chain_id_response : "invalid_hex" . to_string ( ) ,
243+ } ;
244+ let web3 = Web3 :: new ( mock_transport) ;
245+ let chain_id = Caip2ChainId :: from_str ( "eip155:1" ) . unwrap ( ) ;
246+
247+ let result = validate_chain_id ( & web3, & chain_id, "http://test.com" ) . await ;
248+ assert ! ( result. is_err( ) ) ;
249+ let err_msg = result. unwrap_err ( ) . to_string ( ) ;
250+ assert ! ( err_msg. contains( "Failed to parse chain ID hex" ) ) ;
251+ }
131252}
0 commit comments