diff --git a/greentulip/socket.py b/greentulip/socket.py index 78a7668..d364042 100644 --- a/greentulip/socket.py +++ b/greentulip/socket.py @@ -13,6 +13,7 @@ from socket import error, SOCK_STREAM from socket import socket as std_socket from . import yield_from +from . import GreenUnixSelectorLoop class socket: @@ -24,6 +25,8 @@ class socket: self._sock = std_socket(*args, **kwargs) self._sock.setblocking(False) self._loop = tulip.get_event_loop() + assert isinstance(self._loop, GreenUnixSelectorLoop), \ + 'GreenUnixSelectorLoop event loop is required' @classmethod def from_socket(cls, sock): @@ -93,12 +96,11 @@ class socket: return self.__class__.from_socket(sock), addr @_copydoc - def makefile(self, *args, **kwargs): - if args: - if args[0] == 'rb': - return ReadFile(self._loop, self._sock) - elif args[0] == 'wb': - return WriteFile(self._loop, self._sock) + def makefile(self, mode, *args, **kwargs): + if mode == 'rb': + return ReadFile(self._loop, self._sock) + elif mode == 'wb': + return WriteFile(self._loop, self._sock) raise NotImplementedError bind = _proxy('bind') @@ -135,6 +137,15 @@ class ReadFile: res = fut.result() self._buf.extend(res) + if size <= len(self._buf): + data = self._buf[:size] + del self._buf[:size] + return data + else: + data = self._buf[:] + del self._buf[:] + return data + def close(self): pass @@ -144,7 +155,6 @@ class WriteFile: def __init__(self, loop, sock): self._loop = loop self._sock = sock - self._buf = bytearray() def write(self, data): fut = self._loop.sock_sendall(self._sock, data) diff --git a/tests/test_socket.py b/tests/test_socket.py index 7c635a5..29761e7 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -2,12 +2,12 @@ # Copyright (c) 2013 Yury Selivanov # License: Apache 2.0 ## -import greentulip -import greentulip.socket as greensocket - import tulip import unittest +import greentulip +import greentulip.socket as greensocket + class SocketTests(unittest.TestCase): @@ -20,6 +20,11 @@ class SocketTests(unittest.TestCase): self.loop.close() tulip.set_event_loop_policy(None) + def test_socket_wrong_event_loop(self): + loop = tulip.DefaultEventLoopPolicy().new_event_loop() + tulip.set_event_loop(loop) + self.assertRaises(AssertionError, greensocket.socket) + def test_socket_docs(self): self.assertIn('accept connections', greensocket.socket.listen.__doc__) self.assertIn('Receive', greensocket.socket.recv.__doc__) @@ -109,3 +114,74 @@ class SocketTests(unittest.TestCase): greentulip.task(client)(greensocket.socket)) thread.join(1) self.assertEqual(check, 2) + + def test_files_socket_echo(self): + import socket as std_socket + import threading + import time + + check = 0 + ev = threading.Event() + + def server(sock_factory): + socket = sock_factory() + socket.bind(('127.0.0.1', 0)) + + assert socket.fileno() is not None + + nonlocal addr + addr = socket.getsockname() + socket.listen(1) + + ev.set() + + sock, client_addrs = socket.accept() + assert isinstance(sock, sock_factory) + + rfile = sock.makefile('rb') + data = rfile.read(1024) + while not data.endswith(b'\r'): + data += rfile.read(1024) + + wfile = sock.makefile('wb') + wfile.write(data) + + ev.wait() + ev.clear() + + sock.close() + socket.close() + + def client(sock_factory): + ev.wait() + ev.clear() + time.sleep(0.1) + + assert addr + sock = sock_factory() + sock.connect(addr) + + data = b'hello greenlets\r' + sock.sendall(data) + + rep = b'' + while not rep.endswith(b'\r'): + rep += sock.recv(1024) + + self.assertEqual(data, rep) + ev.set() + + nonlocal check + check += 1 + + sock.close() + + addr = None + ev.clear() + thread = threading.Thread(target=client, args=(std_socket.socket,)) + thread.setDaemon(True) + thread.start() + self.loop.run_until_complete( + greentulip.task(server)(greensocket.socket)) + thread.join(1) + self.assertEqual(check, 1)