oslo.messaging/oslo_messaging/_drivers/zmq_driver/client/zmq_receivers.py

194 lines
7.1 KiB
Python

# Copyright 2016 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import abc
import logging
import threading
import futurist
import six
from oslo_messaging._drivers.zmq_driver.client import zmq_response
from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_names
from oslo_messaging._drivers.zmq_driver import zmq_version
from oslo_messaging._i18n import _LE
LOG = logging.getLogger(__name__)
zmq = zmq_async.import_zmq()
def suppress_errors(func):
@six.wraps(func)
def silent_func(self, socket):
try:
return func(self, socket)
except Exception as e:
LOG.error(_LE("Receiving message failed: %r"), e)
# NOTE(gdavoian): drop the left parts of a broken message, since
# they most likely will lead to additional exceptions
if socket.getsockopt(zmq.RCVMORE):
socket.recv_multipart()
return silent_func
@six.add_metaclass(abc.ABCMeta)
class ReceiverBase(object):
"""Base response receiving interface."""
def __init__(self, conf):
self.conf = conf
self._lock = threading.Lock()
self._requests = {}
self._poller = zmq_async.get_poller()
self._receive_response_versions = \
zmq_version.get_method_versions(self, 'receive_response')
self._executor = zmq_async.get_executor(self._run_loop)
self._executor.execute()
def register_socket(self, socket):
"""Register a socket for receiving data."""
self._poller.register(socket, self.receive_response)
def unregister_socket(self, socket):
"""Unregister a socket from receiving data."""
self._poller.unregister(socket)
@abc.abstractmethod
def receive_response(self, socket):
"""Receive a response (ack or reply) and return it."""
def track_request(self, request):
"""Track a request via already registered sockets and return
a pair of ack and reply futures for monitoring all possible
types of responses for the given request.
"""
message_id = request.message_id
futures = self._get_futures(message_id)
if futures is None:
ack_future = reply_future = None
if self.conf.oslo_messaging_zmq.rpc_use_acks:
ack_future = futurist.Future()
if request.msg_type == zmq_names.CALL_TYPE:
reply_future = futurist.Future()
futures = (ack_future, reply_future)
self._set_futures(message_id, futures)
return futures
def untrack_request(self, request):
"""Untrack a request and stop monitoring any responses."""
self._pop_futures(request.message_id)
def stop(self):
self._poller.close()
self._executor.stop()
def _get_futures(self, message_id):
with self._lock:
return self._requests.get(message_id)
def _set_futures(self, message_id, futures):
with self._lock:
self._requests[message_id] = futures
def _pop_futures(self, message_id):
with self._lock:
return self._requests.pop(message_id, None)
def _run_loop(self):
response, socket = \
self._poller.poll(self.conf.oslo_messaging_zmq.rpc_poll_timeout)
if response is None:
return
message_type, message_id = response.msg_type, response.message_id
futures = self._get_futures(message_id)
if futures is not None:
ack_future, reply_future = futures
if message_type == zmq_names.REPLY_TYPE:
reply_future.set_result(response)
else:
ack_future.set_result(response)
LOG.debug("Received %(msg_type)s for %(msg_id)s",
{"msg_type": zmq_names.message_type_str(message_type),
"msg_id": message_id})
def _get_receive_response_version(self, version):
receive_response_version = self._receive_response_versions.get(version)
if receive_response_version is None:
raise zmq_version.UnsupportedMessageVersionError(version)
return receive_response_version
class ReceiverProxy(ReceiverBase):
@suppress_errors
def receive_response(self, socket):
empty = socket.recv()
assert empty == b'', "Empty delimiter expected!"
message_version = socket.recv_string()
assert message_version != b'', "Valid message version expected!"
receive_response_version = \
self._get_receive_response_version(message_version)
return receive_response_version(socket)
def _receive_response_v_1_0(self, socket):
reply_id = socket.recv()
assert reply_id != b'', "Valid reply id expected!"
message_type = int(socket.recv())
assert message_type in zmq_names.RESPONSE_TYPES, "Response expected!"
message_id = socket.recv_string()
assert message_id != '', "Valid message id expected!"
if message_type == zmq_names.REPLY_TYPE:
reply_body, failure = socket.recv_loaded()
reply = zmq_response.Reply(message_id=message_id,
reply_id=reply_id,
reply_body=reply_body,
failure=failure)
return reply
else:
ack = zmq_response.Ack(message_id=message_id,
reply_id=reply_id)
return ack
class ReceiverDirect(ReceiverBase):
@suppress_errors
def receive_response(self, socket):
empty = socket.recv()
assert empty == b'', "Empty delimiter expected!"
message_version = socket.recv_string()
assert message_version != b'', "Valid message version expected!"
receive_response_version = \
self._get_receive_response_version(message_version)
return receive_response_version(socket)
def _receive_response_v_1_0(self, socket):
message_type = int(socket.recv())
assert message_type in zmq_names.RESPONSE_TYPES, "Response expected!"
message_id = socket.recv_string()
assert message_id != '', "Valid message id expected!"
if message_type == zmq_names.REPLY_TYPE:
reply_body, failure = socket.recv_loaded()
reply = zmq_response.Reply(message_id=message_id,
reply_body=reply_body,
failure=failure)
return reply
else:
ack = zmq_response.Ack(message_id=message_id)
return ack