diff --git a/tests/test_websocket.py b/tests/test_websocket.py index acd7699..9d538db 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -343,3 +343,88 @@ class WebSocketServerTestCase(unittest.TestCase): socket.TCP_KEEPIDLE), keepidle) self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl) + + +class HyBiEncodeDecodeTestCase(unittest.TestCase): + def test_decode_hybi_text(self): + buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58' + res = websocket.WebSocketRequestHandler.decode_hybi(buf) + + self.assertEqual(res['fin'], 1) + self.assertEqual(res['opcode'], 0x1) + self.assertEqual(res['masked'], True) + self.assertEqual(res['length'], 5) + self.assertEqual(res['payload'], b'Hello') + self.assertEqual(res['left'], 0) + + def test_decode_hybi_binary(self): + buf = b'\x82\x04\x01\x02\x03\x04' + res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False) + + self.assertEqual(res['fin'], 1) + self.assertEqual(res['opcode'], 0x2) + self.assertEqual(res['length'], 4) + self.assertEqual(res['payload'], b'\x01\x02\x03\x04') + self.assertEqual(res['left'], 0) + + def test_decode_hybi_extended_16bit_binary(self): + data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260 + buf = b'\x82\x7e\x01\x04' + data + res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False) + + self.assertEqual(res['fin'], 1) + self.assertEqual(res['opcode'], 0x2) + self.assertEqual(res['length'], 260) + self.assertEqual(res['payload'], data) + self.assertEqual(res['left'], 0) + + def test_decode_hybi_extended_64bit_binary(self): + data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260 + buf = b'\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04' + data + res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False) + + self.assertEqual(res['fin'], 1) + self.assertEqual(res['opcode'], 0x2) + self.assertEqual(res['length'], 260) + self.assertEqual(res['payload'], data) + self.assertEqual(res['left'], 0) + + def test_decode_hybi_multi(self): + buf1 = b'\x01\x03\x48\x65\x6c' + buf2 = b'\x80\x02\x6c\x6f' + + res1 = websocket.WebSocketRequestHandler.decode_hybi(buf1, strict=False) + self.assertEqual(res1['fin'], 0) + self.assertEqual(res1['opcode'], 0x1) + self.assertEqual(res1['length'], 3) + self.assertEqual(res1['payload'], b'Hel') + self.assertEqual(res1['left'], 0) + + res2 = websocket.WebSocketRequestHandler.decode_hybi(buf2, strict=False) + self.assertEqual(res2['fin'], 1) + self.assertEqual(res2['opcode'], 0x0) + self.assertEqual(res2['length'], 2) + self.assertEqual(res2['payload'], b'lo') + self.assertEqual(res2['left'], 0) + + def test_encode_hybi_basic(self): + res = websocket.WebSocketRequestHandler.encode_hybi(b'Hello', 0x1) + expected = (b'\x81\x05\x48\x65\x6c\x6c\x6f', 2, 0) + + self.assertEqual(res, expected) + + def test_strict_mode_refuses_unmasked_client_frames(self): + buf = b'\x81\x05\x48\x65\x6c\x6c\x6f' + self.assertRaises(websocket.WebSocketRequestHandler.CClose, + websocket.WebSocketRequestHandler.decode_hybi, + buf) + + def test_no_strict_mode_accepts_unmasked_client_frames(self): + buf = b'\x81\x05\x48\x65\x6c\x6c\x6f' + res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False) + + self.assertEqual(res['fin'], 1) + self.assertEqual(res['opcode'], 0x1) + self.assertEqual(res['masked'], False) + self.assertEqual(res['length'], 5) + self.assertEqual(res['payload'], b'Hello') diff --git a/websockify/websocket.py b/websockify/websocket.py index 727413a..1cbf583 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -105,6 +105,7 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler): self.file_only = getattr(server, "file_only", False) self.traffic = getattr(server, "traffic", False) self.auto_pong = getattr(server, "auto_pong", False) + self.strict_mode = getattr(server, "strict_mode", True) self.logger = getattr(server, "logger", None) if self.logger is None: @@ -177,7 +178,7 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler): return header + buf, len(header), 0 @staticmethod - def decode_hybi(buf, base64=False, logger=None): + def decode_hybi(buf, base64=False, logger=None, strict=True): """ Decode HyBi style WebSocket packets. Returns: {'fin' : 0_or_1, @@ -243,6 +244,10 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler): f['length']) else: logger.debug("Unmasked frame: %s" % repr(buf)) + + if strict: + raise WebSocketRequestHandler.CClose(1002, "The client sent an unmasked frame.") + f['payload'] = buf[(f['hlen'] + f['masked'] * 4):full_len] if base64 and f['opcode'] in [1, 2]: @@ -351,7 +356,8 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler): while buf: frame = self.decode_hybi(buf, base64=self.base64, - logger=self.logger) + logger=self.logger, + strict=self.strict_mode) #self.msg("Received buf: %s, frame: %s", repr(buf), frame) if frame['payload'] == None: @@ -591,7 +597,7 @@ class WebSocketServer(object): file_only=False, run_once=False, timeout=0, idle_timeout=0, traffic=False, tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None, - tcp_keepintvl=None, auto_pong=False): + tcp_keepintvl=None, auto_pong=False, strict_mode=True): # settings self.RequestHandlerClass = RequestHandlerClass @@ -606,6 +612,7 @@ class WebSocketServer(object): self.idle_timeout = idle_timeout self.traffic = traffic self.file_only = file_only + self.strict_mode = strict_mode self.launch_time = time.time() self.ws_connection = False