18
18
from collections import defaultdict
19
19
from functools import partial
20
20
from ipaddress import IPv4Address , IPv6Address , IPv4Network , IPv6Network
21
- from typing import Optional , TYPE_CHECKING , Tuple , Sequence , Set , Dict , Iterable , Any , Mapping
21
+ from typing import (Optional , TYPE_CHECKING , Tuple , Sequence , Set , Dict , Iterable , Any , Mapping ,
22
+ Callable )
22
23
import asyncio
23
24
24
25
import attr
25
26
from aiorpcx import (Event , JSONRPCAutoDetect , JSONRPCConnection ,
26
- ReplyAndDisconnect , Request , RPCError , RPCSession ,
27
+ ReplyAndDisconnect , Request , Notification , RPCError , RPCSession ,
27
28
handler_invocation , serve_rs , serve_ws , sleep ,
28
29
NewlineFramer , TaskTimeout , timeout_after , run_in_thread )
30
+ from aiorpcx .jsonrpc import SingleRequest
29
31
30
32
import electrumx
31
33
import electrumx .lib .util as util
@@ -938,6 +940,8 @@ class SessionBase(RPCSession):
938
940
MAX_CHUNK_SIZE = 2016
939
941
session_counter = itertools .count ()
940
942
log_new = False
943
+ request_handlers : Dict [str , Callable ]
944
+ notification_handlers : Dict [str , Callable ]
941
945
942
946
def __init__ (
943
947
self ,
@@ -1025,14 +1029,13 @@ def sub_count_txoutpoints(self):
1025
1029
def sub_count_total (self ):
1026
1030
return self .sub_count_scripthashes () + self .sub_count_txoutpoints ()
1027
1031
1028
- async def handle_request (self , request ):
1029
- '''Handle an incoming request. ElectrumX doesn't receive
1030
- notifications from client sessions.
1031
- '''
1032
+ async def handle_request (self , request : SingleRequest ):
1033
+ '''Handle an incoming request.'''
1034
+ handler = None
1032
1035
if isinstance (request , Request ):
1033
1036
handler = self .request_handlers .get (request .method )
1034
- else :
1035
- handler = None
1037
+ elif isinstance ( request , Notification ) :
1038
+ handler = self . notification_handlers . get ( request . method )
1036
1039
method = 'invalid method' if handler is None else request .method
1037
1040
1038
1041
# Version negotiation must happen before any other messages.
@@ -1642,12 +1645,25 @@ async def estimatefee(self, number, mode=None):
1642
1645
cache [(number , mode )] = (blockhash , feerate , lock )
1643
1646
return feerate
1644
1647
1645
- async def ping (self ):
1648
+ async def ping (self , pong_len = 0 , data = "" ):
1646
1649
'''Serves as a connection keep-alive mechanism and for the client to
1647
- confirm the server is still responding.
1650
+ confirm the server is still responding. It can also be used to obfuscate
1651
+ traffic patterns.
1648
1652
'''
1649
1653
self .bump_cost (0.1 )
1650
- return None
1654
+ if self .protocol_tuple < (1 , 6 ):
1655
+ return None
1656
+ assert_hex_str (data )
1657
+ pong_len = non_negative_integer (pong_len )
1658
+ if pong_len > self .env .max_send :
1659
+ raise RPCError (BAD_REQUEST , f'pong_len value too high' )
1660
+ pong_data = pong_len * "0"
1661
+ return {"data" : pong_data }
1662
+
1663
+ async def on_ping_notification (self , data = "" ):
1664
+ self .bump_cost (0.1 ) # note: the bandwidth cost for receiving 'data' has already been incurred
1665
+ assert_hex_str (data )
1666
+ # nothing to do
1651
1667
1652
1668
async def server_version (self , client_name = '' , protocol_version = None ):
1653
1669
'''Returns the server version as a string.
@@ -1856,6 +1872,7 @@ def set_request_handlers(self, ptuple):
1856
1872
'server.ping' : self .ping ,
1857
1873
'server.version' : self .server_version ,
1858
1874
}
1875
+ notif_handlers = {}
1859
1876
1860
1877
if ptuple < (1 , 6 ):
1861
1878
handlers ['blockchain.scripthash.get_balance' ] = self .scripthash_get_balance
@@ -1878,8 +1895,10 @@ def set_request_handlers(self, ptuple):
1878
1895
handlers ['blockchain.scriptpubkey.listunspent' ] = self .scriptpubkey_listunspent
1879
1896
handlers ['blockchain.scriptpubkey.subscribe' ] = self .scriptpubkey_subscribe
1880
1897
handlers ['blockchain.scriptpubkey.unsubscribe' ] = self .scriptpubkey_unsubscribe
1898
+ notif_handlers ['server.ping' ] = self .on_ping_notification
1881
1899
1882
1900
self .request_handlers = handlers
1901
+ self .notification_handlers = notif_handlers
1883
1902
1884
1903
1885
1904
class LocalRPC (SessionBase ):
@@ -1893,6 +1912,8 @@ def __init__(self, *args, **kwargs):
1893
1912
self .sv_negotiated .set ()
1894
1913
self .client = 'RPC'
1895
1914
self .connection .max_response_size = 0
1915
+ # note: self.request_handlers are set on the class, in SessionManager.__init__
1916
+ self .notification_handlers = {}
1896
1917
1897
1918
def protocol_version_string (self ):
1898
1919
return 'RPC'
0 commit comments