Merge pull request #671 from dpkp/disconnects

Improve socket disconnect handling
This commit is contained in:
Dana Powers 2016-04-25 22:33:19 -07:00
commit 5b393ac2b5
3 changed files with 87 additions and 24 deletions

View File

@ -142,6 +142,7 @@ class KafkaClient(object):
# Exponential backoff if bootstrap fails
backoff_ms = self.config['reconnect_backoff_ms'] * 2 ** self._bootstrap_fails
next_at = self._last_bootstrap + backoff_ms / 1000.0
self._refresh_on_disconnects = False
now = time.time()
if next_at > now:
log.debug("Sleeping %0.4f before bootstrapping again", next_at - now)
@ -180,6 +181,7 @@ class KafkaClient(object):
log.error('Unable to bootstrap from %s', hosts)
# Max exponential backoff is 2^12, x4000 (50ms -> 200s)
self._bootstrap_fails = min(self._bootstrap_fails + 1, 12)
self._refresh_on_disconnects = True
def _can_connect(self, node_id):
if node_id not in self._conns:
@ -223,7 +225,7 @@ class KafkaClient(object):
except KeyError:
pass
if self._refresh_on_disconnects:
log.warning("Node %s connect failed -- refreshing metadata", node_id)
log.warning("Node %s connection failed -- refreshing metadata", node_id)
self.cluster.request_update()
def _maybe_connect(self, node_id):

View File

@ -381,9 +381,17 @@ class BrokerConnection(object):
# Not receiving is the state of reading the payload header
if not self._receiving:
try:
# An extremely small, but non-zero, probability that there are
# more than 0 but not yet 4 bytes available to read
self._rbuffer.write(self._sock.recv(4 - self._rbuffer.tell()))
bytes_to_read = 4 - self._rbuffer.tell()
data = self._sock.recv(bytes_to_read)
# We expect socket.recv to raise an exception if there is not
# enough data to read the full bytes_to_read. If the socket is
# disconnected, however, recv returns empty data without raising
# an exception.
if not data:
log.error('%s: socket disconnected', self)
self.close(error=Errors.ConnectionError('socket disconnected'))
return None
self._rbuffer.write(data)
except ssl.SSLWantReadError:
return None
except ConnectionError as e:
@ -411,7 +419,17 @@ class BrokerConnection(object):
if self._receiving:
staged_bytes = self._rbuffer.tell()
try:
self._rbuffer.write(self._sock.recv(self._next_payload_bytes - staged_bytes))
bytes_to_read = self._next_payload_bytes - staged_bytes
data = self._sock.recv(bytes_to_read)
# We expect socket.recv to raise an exception if there is not
# enough data to read the full bytes_to_read. If the socket is
# disconnected, however, recv returns empty data without raising
# an exception.
if not data:
log.error('%s: socket disconnected', self)
self.close(error=Errors.ConnectionError('socket disconnected'))
return None
self._rbuffer.write(data)
except ssl.SSLWantReadError:
return None
except ConnectionError as e:

View File

@ -2,6 +2,7 @@
from __future__ import absolute_import
from errno import EALREADY, EINPROGRESS, EISCONN, ECONNRESET
import socket
import time
import pytest
@ -14,7 +15,7 @@ import kafka.common as Errors
@pytest.fixture
def socket(mocker):
def _socket(mocker):
socket = mocker.MagicMock()
socket.connect_ex.return_value = 0
mocker.patch('socket.socket', return_value=socket)
@ -22,9 +23,8 @@ def socket(mocker):
@pytest.fixture
def conn(socket):
from socket import AF_INET
conn = BrokerConnection('localhost', 9092, AF_INET)
def conn(_socket):
conn = BrokerConnection('localhost', 9092, socket.AF_INET)
return conn
@ -38,23 +38,23 @@ def conn(socket):
([EALREADY], ConnectionStates.CONNECTING),
([EISCONN], ConnectionStates.CONNECTED)),
])
def test_connect(socket, conn, states):
def test_connect(_socket, conn, states):
assert conn.state is ConnectionStates.DISCONNECTED
for errno, state in states:
socket.connect_ex.side_effect = errno
_socket.connect_ex.side_effect = errno
conn.connect()
assert conn.state is state
def test_connect_timeout(socket, conn):
def test_connect_timeout(_socket, conn):
assert conn.state is ConnectionStates.DISCONNECTED
# Initial connect returns EINPROGRESS
# immediate inline connect returns EALREADY
# second explicit connect returns EALREADY
# third explicit connect returns EALREADY and times out via last_attempt
socket.connect_ex.side_effect = [EINPROGRESS, EALREADY, EALREADY, EALREADY]
_socket.connect_ex.side_effect = [EINPROGRESS, EALREADY, EALREADY, EALREADY]
conn.connect()
assert conn.state is ConnectionStates.CONNECTING
conn.connect()
@ -108,7 +108,7 @@ def test_send_max_ifr(conn):
assert isinstance(f.exception, Errors.TooManyInFlightRequests)
def test_send_no_response(socket, conn):
def test_send_no_response(_socket, conn):
conn.connect()
assert conn.state is ConnectionStates.CONNECTED
req = MetadataRequest[0]([])
@ -116,7 +116,7 @@ def test_send_no_response(socket, conn):
payload_bytes = len(header.encode()) + len(req.encode())
third = payload_bytes // 3
remainder = payload_bytes % 3
socket.send.side_effect = [4, third, third, third, remainder]
_socket.send.side_effect = [4, third, third, third, remainder]
assert len(conn.in_flight_requests) == 0
f = conn.send(req, expect_response=False)
@ -125,7 +125,7 @@ def test_send_no_response(socket, conn):
assert len(conn.in_flight_requests) == 0
def test_send_response(socket, conn):
def test_send_response(_socket, conn):
conn.connect()
assert conn.state is ConnectionStates.CONNECTED
req = MetadataRequest[0]([])
@ -133,7 +133,7 @@ def test_send_response(socket, conn):
payload_bytes = len(header.encode()) + len(req.encode())
third = payload_bytes // 3
remainder = payload_bytes % 3
socket.send.side_effect = [4, third, third, third, remainder]
_socket.send.side_effect = [4, third, third, third, remainder]
assert len(conn.in_flight_requests) == 0
f = conn.send(req)
@ -141,20 +141,18 @@ def test_send_response(socket, conn):
assert len(conn.in_flight_requests) == 1
def test_send_error(socket, conn):
def test_send_error(_socket, conn):
conn.connect()
assert conn.state is ConnectionStates.CONNECTED
req = MetadataRequest[0]([])
header = RequestHeader(req, client_id=conn.config['client_id'])
try:
error = ConnectionError
_socket.send.side_effect = ConnectionError
except NameError:
from socket import error
socket.send.side_effect = error
_socket.send.side_effect = socket.error
f = conn.send(req)
assert f.failed() is True
assert isinstance(f.exception, Errors.ConnectionError)
assert socket.close.call_count == 1
assert _socket.close.call_count == 1
assert conn.state is ConnectionStates.DISCONNECTED
@ -167,7 +165,52 @@ def test_can_send_more(conn):
assert conn.can_send_more() is False
def test_recv(socket, conn):
def test_recv_disconnected():
    """Check that recv() marks the connection disconnected once the peer is gone.

    A real listening socket is bound to an OS-assigned port on 127.0.0.1 and
    a BrokerConnection is connected to it. The listening socket is then
    closed (no connection is ever accepted by this test) and a subsequent
    recv() on the client side is expected to leave the connection in the
    disconnected state.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Port 0 lets the OS pick a free port; read it back for the client
        sock.bind(('127.0.0.1', 0))
        port = sock.getsockname()[1]
        sock.listen(5)

        conn = BrokerConnection('127.0.0.1', port, socket.AF_INET)

        # connect() may need several calls before the connection is
        # established; retry for at most one second
        timeout = time.time() + 1
        while time.time() < timeout:
            conn.connect()
            if conn.connected():
                break
        else:
            # Loop ran out of time without ever breaking
            assert False, 'Connection attempt to local socket timed-out ?'

        conn.send(MetadataRequest[0]([]))

        # Disconnect server socket
        sock.close()

        # Attempt to receive should mark connection as disconnected
        assert conn.connected()
        conn.recv()
        assert conn.disconnected()
    finally:
        # Release the listening socket even when the connect loop times out
        # or an assertion fails; closing an already-closed socket is a no-op
        sock.close()
def test_recv_disconnected_too(_socket, conn):
    """Check that an empty read from the mocked socket disconnects the connection.

    Uses the mocked socket fixture: after a request is sent, recv() on the
    mock is configured to return empty bytes, and conn.recv() is expected to
    move the connection to the disconnected state.
    """
    conn.connect()
    assert conn.connected()

    request = MetadataRequest[0]([])
    request_header = RequestHeader(request, client_id=conn.config['client_id'])
    encoded_size = len(request_header.encode()) + len(request.encode())

    # The mocked send() reports 4 bytes written, then the full payload size
    # (presumably the 4-byte size prefix followed by the message -- confirm
    # against BrokerConnection.send)
    _socket.send.side_effect = [4, encoded_size]
    conn.send(request)

    # Empty data on recv means the socket is disconnected
    _socket.recv.return_value = b''

    # Still connected until a receive is actually attempted
    assert conn.connected()
    conn.recv()
    assert conn.disconnected()
def test_recv(_socket, conn):
pass # TODO