Introduce strict mode

This commit introduces strict mode, which is on by default.  Currently
strict mode only enforces client-to-server frame masking.  However,
in the future, it might enforce other parts of the RFC as well.

Closes #164
This commit is contained in:
Solly Ross 2015-05-12 12:56:36 -04:00
parent 461c52ed84
commit 2544dd3aaf
2 changed files with 95 additions and 3 deletions

View File

@ -343,3 +343,88 @@ class WebSocketServerTestCase(unittest.TestCase):
socket.TCP_KEEPIDLE), keepidle)
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPINTVL), keepintvl)
class HyBiEncodeDecodeTestCase(unittest.TestCase):
def test_decode_hybi_text(self):
buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58'
res = websocket.WebSocketRequestHandler.decode_hybi(buf)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x1)
self.assertEqual(res['masked'], True)
self.assertEqual(res['length'], 5)
self.assertEqual(res['payload'], b'Hello')
self.assertEqual(res['left'], 0)
def test_decode_hybi_binary(self):
buf = b'\x82\x04\x01\x02\x03\x04'
res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x2)
self.assertEqual(res['length'], 4)
self.assertEqual(res['payload'], b'\x01\x02\x03\x04')
self.assertEqual(res['left'], 0)
def test_decode_hybi_extended_16bit_binary(self):
data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260
buf = b'\x82\x7e\x01\x04' + data
res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x2)
self.assertEqual(res['length'], 260)
self.assertEqual(res['payload'], data)
self.assertEqual(res['left'], 0)
def test_decode_hybi_extended_64bit_binary(self):
data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260
buf = b'\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04' + data
res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x2)
self.assertEqual(res['length'], 260)
self.assertEqual(res['payload'], data)
self.assertEqual(res['left'], 0)
def test_decode_hybi_multi(self):
buf1 = b'\x01\x03\x48\x65\x6c'
buf2 = b'\x80\x02\x6c\x6f'
res1 = websocket.WebSocketRequestHandler.decode_hybi(buf1, strict=False)
self.assertEqual(res1['fin'], 0)
self.assertEqual(res1['opcode'], 0x1)
self.assertEqual(res1['length'], 3)
self.assertEqual(res1['payload'], b'Hel')
self.assertEqual(res1['left'], 0)
res2 = websocket.WebSocketRequestHandler.decode_hybi(buf2, strict=False)
self.assertEqual(res2['fin'], 1)
self.assertEqual(res2['opcode'], 0x0)
self.assertEqual(res2['length'], 2)
self.assertEqual(res2['payload'], b'lo')
self.assertEqual(res2['left'], 0)
def test_encode_hybi_basic(self):
res = websocket.WebSocketRequestHandler.encode_hybi(b'Hello', 0x1)
expected = (b'\x81\x05\x48\x65\x6c\x6c\x6f', 2, 0)
self.assertEqual(res, expected)
def test_strict_mode_refuses_unmasked_client_frames(self):
buf = b'\x81\x05\x48\x65\x6c\x6c\x6f'
self.assertRaises(websocket.WebSocketRequestHandler.CClose,
websocket.WebSocketRequestHandler.decode_hybi,
buf)
def test_no_strict_mode_accepts_unmasked_client_frames(self):
buf = b'\x81\x05\x48\x65\x6c\x6c\x6f'
res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x1)
self.assertEqual(res['masked'], False)
self.assertEqual(res['length'], 5)
self.assertEqual(res['payload'], b'Hello')

View File

@ -105,6 +105,7 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
self.file_only = getattr(server, "file_only", False)
self.traffic = getattr(server, "traffic", False)
self.auto_pong = getattr(server, "auto_pong", False)
self.strict_mode = getattr(server, "strict_mode", True)
self.logger = getattr(server, "logger", None)
if self.logger is None:
@ -177,7 +178,7 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
return header + buf, len(header), 0
@staticmethod
def decode_hybi(buf, base64=False, logger=None):
def decode_hybi(buf, base64=False, logger=None, strict=True):
""" Decode HyBi style WebSocket packets.
Returns:
{'fin' : 0_or_1,
@ -243,6 +244,10 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
f['length'])
else:
logger.debug("Unmasked frame: %s" % repr(buf))
if strict:
raise WebSocketRequestHandler.CClose(1002, "The client sent an unmasked frame.")
f['payload'] = buf[(f['hlen'] + f['masked'] * 4):full_len]
if base64 and f['opcode'] in [1, 2]:
@ -351,7 +356,8 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
while buf:
frame = self.decode_hybi(buf, base64=self.base64,
logger=self.logger)
logger=self.logger,
strict=self.strict_mode)
#self.msg("Received buf: %s, frame: %s", repr(buf), frame)
if frame['payload'] == None:
@ -591,7 +597,7 @@ class WebSocketServer(object):
file_only=False,
run_once=False, timeout=0, idle_timeout=0, traffic=False,
tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None,
tcp_keepintvl=None, auto_pong=False):
tcp_keepintvl=None, auto_pong=False, strict_mode=True):
# settings
self.RequestHandlerClass = RequestHandlerClass
@ -606,6 +612,7 @@ class WebSocketServer(object):
self.idle_timeout = idle_timeout
self.traffic = traffic
self.file_only = file_only
self.strict_mode = strict_mode
self.launch_time = time.time()
self.ws_connection = False