Skip to content

Commit 9951ec1

Browse files
authored
Merge pull request #419 from tomato42/faster-prf
Faster handshake
2 parents e8da6cf + a005589 commit 9951ec1

File tree

10 files changed

+208
-139
lines changed

10 files changed

+208
-139
lines changed

tlslite/bufferedsocket.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(self, socket):
2525
self.socket = socket
2626
self._write_queue = deque()
2727
self.buffer_writes = False
28+
self._read_buffer = bytearray()
2829

2930
def send(self, data):
3031
"""Send data to the socket"""
@@ -51,7 +52,11 @@ def flush(self):
5152

5253
def recv(self, bufsize):
5354
"""Receive data from socket (socket emulation)"""
54-
return self.socket.recv(bufsize)
55+
if not self._read_buffer:
56+
self._read_buffer += self.socket.recv(max(4096, bufsize))
57+
ret = self._read_buffer[:bufsize]
58+
del self._read_buffer[:bufsize]
59+
return ret
5560

5661
def getsockname(self):
5762
"""Return the socket's own address (socket emulation)."""

tlslite/defragmenter.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def size_handler(data):
8989
else:
9090
parser = Parser(data)
9191
# skip the header
92-
parser.getFixBytes(size_offset)
92+
parser.skip_bytes(size_offset)
9393

9494
payload_length = parser.get(size_of_size)
9595
if parser.getRemainingLength() < payload_length:
@@ -110,14 +110,15 @@ def add_data(self, msg_type, data):
110110
def get_message(self):
111111
"""Extract the highest priority complete message from buffer"""
112112
for msg_type in self.priorities:
113-
length = self.decoders[msg_type](self.buffers[msg_type])
113+
buf = self.buffers[msg_type]
114+
length = self.decoders[msg_type](buf)
114115
if length is None:
115116
continue
116117

117118
# extract message
118-
data = self.buffers[msg_type][:length]
119+
data = buf[:length]
119120
# remove it from buffer
120-
self.buffers[msg_type] = self.buffers[msg_type][length:]
121+
del buf[:length]
121122
return (msg_type, data)
122123
return None
123124

tlslite/mathtls.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -676,19 +676,27 @@ def paramStrength(param):
676676
return 256 # NIST SP 800-57
677677

678678

679-
def P_hash(macFunc, secret, seed, length):
680-
bytes = bytearray(length)
679+
def P_hash(mac_name, secret, seed, length):
680+
"""Internal method for calculation the PRF in TLS."""
681+
ret = bytearray(length)
682+
seed = compatHMAC(seed)
681683
A = seed
682684
index = 0
683-
while 1:
684-
A = macFunc(secret, A)
685-
output = macFunc(secret, A + seed)
686-
for c in output:
687-
if index >= length:
688-
return bytes
689-
bytes[index] = c
690-
index += 1
691-
return bytes
685+
mac = hmac.HMAC(compatHMAC(secret), digestmod=mac_name)
686+
while index < length:
687+
a_fun = mac.copy()
688+
a_fun.update(A)
689+
A = a_fun.digest()
690+
out_fun = mac.copy()
691+
out_fun.update(A)
692+
out_fun.update(seed)
693+
output = out_fun.digest()
694+
695+
how_many = min(length - index, len(output))
696+
ret[index:index+how_many] = output[:how_many]
697+
index += how_many
698+
return ret
699+
692700

693701
def PRF(secret, label, seed, length):
694702
#Split the secret into left and right halves
@@ -697,8 +705,8 @@ def PRF(secret, label, seed, length):
697705
S2 = secret[ int(math.floor(len(secret)/2.0)) : ]
698706

699707
#Run the left half through P_MD5 and the right half through P_SHA1
700-
p_md5 = P_hash(HMAC_MD5, S1, label + seed, length)
701-
p_sha1 = P_hash(HMAC_SHA1, S2, label + seed, length)
708+
p_md5 = P_hash("md5", S1, label + seed, length)
709+
p_sha1 = P_hash("sha1", S2, label + seed, length)
702710

703711
#XOR the output values and return the result
704712
for x in range(length):
@@ -707,11 +715,11 @@ def PRF(secret, label, seed, length):
707715

708716
def PRF_1_2(secret, label, seed, length):
709717
"""Pseudo Random Function for TLS1.2 ciphers that use SHA256"""
710-
return P_hash(HMAC_SHA256, secret, label + seed, length)
718+
return P_hash("sha256", secret, label + seed, length)
711719

712720
def PRF_1_2_SHA384(secret, label, seed, length):
713721
"""Pseudo Random Function for TLS1.2 ciphers that use SHA384"""
714-
return P_hash(HMAC_SHA384, secret, label + seed, length)
722+
return P_hash("sha384", secret, label + seed, length)
715723

716724
def PRF_SSL(secret, seed, length):
717725
bytes = bytearray(length)

tlslite/utils/asn1parser.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ def getChildCount(self):
8585
while True:
8686
if p.getRemainingLength() == 0:
8787
break
88-
p.get(1) # skip Type
88+
p.skip_bytes(1) # skip Type
8989
length = self._getASN1Length(p)
90-
p.getFixBytes(length) # skip value
90+
p.skip_bytes(length) # skip value
9191
count += 1
9292
return count
9393

@@ -104,9 +104,9 @@ def getChildBytes(self, which):
104104
p = Parser(self.value)
105105
for _ in range(which+1):
106106
markIndex = p.index
107-
p.get(1) #skip Type
107+
p.skip_bytes(1) # skip Type
108108
length = self._getASN1Length(p)
109-
p.getFixBytes(length)
109+
p.skip_bytes(length)
110110
return p.bytes[markIndex : p.index]
111111

112112
@staticmethod

tlslite/utils/codec.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import sys
99
import struct
1010
from struct import pack
11+
from .compat import bytes_to_int
1112

1213

1314
class DecodeError(SyntaxError):
@@ -305,14 +306,8 @@ def get(self, length):
305306
306307
:rtype: int
307308
"""
308-
if self.index + length > len(self.bytes):
309-
raise DecodeError("Read past end of buffer")
310-
x = 0
311-
for _ in range(length):
312-
x <<= 8
313-
x |= self.bytes[self.index]
314-
self.index += 1
315-
return x
309+
ret = self.getFixBytes(length)
310+
return bytes_to_int(ret, 'big')
316311

317312
def getFixBytes(self, lengthBytes):
318313
"""
@@ -323,11 +318,18 @@ def getFixBytes(self, lengthBytes):
323318
324319
:rtype: bytearray
325320
"""
326-
if self.index + lengthBytes > len(self.bytes):
321+
end = self.index + lengthBytes
322+
if end > len(self.bytes):
327323
raise DecodeError("Read past end of buffer")
328-
bytes = self.bytes[self.index : self.index+lengthBytes]
324+
ret = self.bytes[self.index : end]
329325
self.index += lengthBytes
330-
return bytes
326+
return ret
327+
328+
def skip_bytes(self, length):
329+
"""Move the internal pointer ahead length bytes."""
330+
if self.index + length > len(self.bytes):
331+
raise DecodeError("Read past end of buffer")
332+
self.index += length
331333

332334
def getVarBytes(self, lengthLength):
333335
"""

tlslite/utils/compat.py

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,21 @@
1616
if sys.version_info >= (3,0):
1717

1818
def compat26Str(x): return x
19-
20-
# Python 3 requires bytes instead of bytearrays for HMAC
21-
19+
20+
# Python 3.3 requires bytes instead of bytearrays for HMAC
2221
# So, python 2.6 requires strings, python 3 requires 'bytes',
23-
# and python 2.7 can handle bytearrays...
24-
def compatHMAC(x): return bytes(x)
22+
# and python 2.7 and 3.5 can handle bytearrays...
23+
# pylint: disable=invalid-name
24+
# we need to keep compatHMAC and `x` for API compatibility
25+
if sys.version_info < (3, 4):
26+
def compatHMAC(x):
27+
"""Convert bytes-like input to format acceptable for HMAC."""
28+
return bytes(x)
29+
else:
30+
def compatHMAC(x):
31+
"""Convert bytes-like input to format acceptable for HMAC."""
32+
return x
33+
# pylint: enable=invalid-name
2534

2635
def compatAscii2Bytes(val):
2736
"""Convert ASCII string to bytes."""
@@ -80,6 +89,25 @@ def remove_whitespace(text):
8089
"""Removes all whitespace from passed in string"""
8190
return re.sub(r"\s+", "", text, flags=re.UNICODE)
8291

92+
# pylint: disable=invalid-name
93+
# pylint is stupid here and deson't notice it's a function, not
94+
# constant
95+
bytes_to_int = int.from_bytes
96+
# pylint: enable=invalid-name
97+
98+
def bit_length(val):
99+
"""Return number of bits necessary to represent an integer."""
100+
return val.bit_length()
101+
102+
def int_to_bytes(val, length=None, byteorder="big"):
103+
"""Return number converted to bytes"""
104+
if length is None:
105+
length = byte_length(val)
106+
# for gmpy we need to convert back to native int
107+
if type(val) != int:
108+
val = int(val)
109+
return bytearray(val.to_bytes(length=length, byteorder=byteorder))
110+
83111
else:
84112
# Python 2.6 requires strings instead of bytearrays in a couple places,
85113
# so we define this function so it does the conversion if needed.
@@ -92,13 +120,23 @@ def compat26Str(x): return str(x)
92120
def remove_whitespace(text):
93121
"""Removes all whitespace from passed in string"""
94122
return re.sub(r"\s+", "", text)
123+
124+
def bit_length(val):
125+
"""Return number of bits necessary to represent an integer."""
126+
if val == 0:
127+
return 0
128+
return len(bin(val))-2
95129
else:
96130
def compat26Str(x): return x
97131

98132
def remove_whitespace(text):
99133
"""Removes all whitespace from passed in string"""
100134
return re.sub(r"\s+", "", text, flags=re.UNICODE)
101135

136+
def bit_length(val):
137+
"""Return number of bits necessary to represent an integer."""
138+
return val.bit_length()
139+
102140
def compatAscii2Bytes(val):
103141
"""Convert ASCII string to bytes."""
104142
return val
@@ -147,6 +185,35 @@ def time_stamp():
147185
"""Returns system time as a float"""
148186
return time.clock()
149187

188+
def bytes_to_int(val, byteorder):
189+
"""Convert bytes to an int."""
190+
if not val:
191+
return 0
192+
if byteorder == "big":
193+
return int(b2a_hex(val), 16)
194+
if byteorder == "little":
195+
return int(b2a_hex(val[::-1]), 16)
196+
raise ValueError("Only 'big' and 'little' endian supported")
197+
198+
def int_to_bytes(val, length=None, byteorder="big"):
199+
"""Return number converted to bytes"""
200+
if length is None:
201+
length = byte_length(val)
202+
if byteorder == "big":
203+
return bytearray((val >> i) & 0xff
204+
for i in reversed(range(0, length*8, 8)))
205+
if byteorder == "little":
206+
return bytearray((val >> i) & 0xff
207+
for i in range(0, length*8, 8))
208+
raise ValueError("Only 'big' or 'little' endian supported")
209+
210+
211+
def byte_length(val):
212+
"""Return number of bytes necessary to represent an integer."""
213+
length = bit_length(val)
214+
return (length + 7) // 8
215+
216+
150217
try:
151218
# Fedora and Red Hat Enterprise Linux versions have small curves removed
152219
getattr(ecdsa, 'NIST192p')

tlslite/utils/cryptomath.py

Lines changed: 20 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
import math
1414
import base64
1515
import binascii
16-
import sys
1716

18-
from .compat import compat26Str, compatHMAC, compatLong, b2a_hex
17+
from .compat import compat26Str, compatHMAC, compatLong, \
18+
bytes_to_int, int_to_bytes, bit_length, byte_length
1919
from .codec import Writer
2020

2121
from . import tlshashlib as hashlib
@@ -204,18 +204,8 @@ def bytesToNumber(b, endian="big"):
204204
205205
By default assumes big-endian encoding of the number.
206206
"""
207-
# if string is empty, consider it to be representation of zero
208-
# while it may be a bit unorthodox, it is the inverse of numberToByteArray
209-
# with default parameters
210-
if not b:
211-
return 0
207+
return bytes_to_int(b, endian)
212208

213-
if endian == "big":
214-
return int(b2a_hex(b), 16)
215-
elif endian == "little":
216-
return int(b2a_hex(b[::-1]), 16)
217-
else:
218-
raise ValueError("Only 'big' and 'little' endian supported")
219209

220210
def numberToByteArray(n, howManyBytes=None, endian="big"):
221211
"""
@@ -225,16 +215,14 @@ def numberToByteArray(n, howManyBytes=None, endian="big"):
225215
not be larger. The returned bytearray will contain a big- or little-endian
226216
encoding of the input integer (n). Big endian encoding is used by default.
227217
"""
228-
if howManyBytes == None:
229-
howManyBytes = numBytes(n)
230-
if endian == "big":
231-
return bytearray((n >> i) & 0xff
232-
for i in reversed(range(0, howManyBytes*8, 8)))
233-
elif endian == "little":
234-
return bytearray((n >> i) & 0xff
235-
for i in range(0, howManyBytes*8, 8))
236-
else:
237-
raise ValueError("Only 'big' and 'little' endian supported")
218+
if howManyBytes is not None:
219+
length = byte_length(n)
220+
if howManyBytes < length:
221+
ret = int_to_bytes(n, length, endian)
222+
if endian == "big":
223+
return ret[length-howManyBytes:length]
224+
return ret[:howManyBytes]
225+
return int_to_bytes(n, howManyBytes, endian)
238226

239227

240228
def mpiToNumber(mpi):
@@ -265,23 +253,16 @@ def numberToMPI(n):
265253
# Misc. Utility Functions
266254
# **************************************************************************
267255

268-
def numBits(n):
269-
"""Return number of bits necessary to represent the integer in binary"""
270-
if n==0:
271-
return 0
272-
if sys.version_info < (2, 7):
273-
# bit_length() was introduced in 2.7, and it is an order of magnitude
274-
# faster than the below code
275-
return len(bin(n))-2
276-
else:
277-
return n.bit_length()
278256

279-
def numBytes(n):
280-
"""Return number of bytes necessary to represent the integer in bytes"""
281-
if n==0:
282-
return 0
283-
bits = numBits(n)
284-
return (bits + 7) // 8
257+
# pylint: disable=invalid-name
258+
# pylint recognises them as constants, not function names, also
259+
# we can't change their names without API change
260+
numBits = bit_length
261+
262+
263+
numBytes = byte_length
264+
# pylint: enable=invalid-name
265+
285266

286267
# **************************************************************************
287268
# Big Number Math

0 commit comments

Comments
 (0)