diff --git a/oslo/messaging/notify/dispatcher.py b/oslo/messaging/notify/dispatcher.py index 549e2829e..de04a8708 100644 --- a/oslo/messaging/notify/dispatcher.py +++ b/oslo/messaging/notify/dispatcher.py @@ -67,25 +67,26 @@ class NotificationDispatcher(object): pool=self.pool) @contextlib.contextmanager - def __call__(self, incoming): + def __call__(self, incoming, executor_callback=None): result_wrapper = [] yield lambda: result_wrapper.append( - self._dispatch_and_handle_error(incoming)) + self._dispatch_and_handle_error(incoming, executor_callback)) if result_wrapper[0] == NotificationResult.HANDLED: incoming.acknowledge() else: incoming.requeue() - def _dispatch_and_handle_error(self, incoming): + def _dispatch_and_handle_error(self, incoming, executor_callback): """Dispatch a notification message to the appropriate endpoint method. :param incoming: the incoming notification message :type ctxt: IncomingMessage """ try: - return self._dispatch(incoming.ctxt, incoming.message) + return self._dispatch(incoming.ctxt, incoming.message, + executor_callback) except Exception: # sys.exc_info() is deleted by LOG.exception(). exc_info = sys.exc_info() @@ -93,7 +94,7 @@ class NotificationDispatcher(object): exc_info=exc_info) return NotificationResult.HANDLED - def _dispatch(self, ctxt, message): + def _dispatch(self, ctxt, message, executor_callback=None): """Dispatch an RPC message to the appropriate endpoint method. :param ctxt: the request context @@ -120,8 +121,12 @@ class NotificationDispatcher(object): for callback in self._callbacks_by_priority.get(priority, []): localcontext.set_local_context(ctxt) try: - ret = callback(ctxt, publisher_id, event_type, payload, - metadata) + if executor_callback: + ret = executor_callback(callback, ctxt, publisher_id, + event_type, payload, metadata) + else: + ret = callback(ctxt, publisher_id, event_type, payload, + metadata) ret = NotificationResult.HANDLED if ret is None else ret if self.allow_requeue and ret == NotificationResult.REQUEUE: return ret diff --git a/oslo/messaging/rpc/dispatcher.py b/oslo/messaging/rpc/dispatcher.py index 4b4eeafc2..3b2941aa7 100644 --- a/oslo/messaging/rpc/dispatcher.py +++ b/oslo/messaging/rpc/dispatcher.py @@ -118,23 +118,28 @@ class RPCDispatcher(object): endpoint_version = target.version or '1.0' return utils.version_is_compatible(endpoint_version, version) - def _do_dispatch(self, endpoint, method, ctxt, args): + def _do_dispatch(self, endpoint, method, ctxt, args, executor_callback): ctxt = self.serializer.deserialize_context(ctxt) new_args = dict() for argname, arg in six.iteritems(args): new_args[argname] = self.serializer.deserialize_entity(ctxt, arg) - result = getattr(endpoint, method)(ctxt, **new_args) + func = getattr(endpoint, method) + if executor_callback: + result = executor_callback(func, ctxt, **new_args) + else: + result = func(ctxt, **new_args) return self.serializer.serialize_entity(ctxt, result) @contextlib.contextmanager - def __call__(self, incoming): + def __call__(self, incoming, executor_callback=None): incoming.acknowledge() - yield lambda: self._dispatch_and_reply(incoming) + yield lambda: self._dispatch_and_reply(incoming, executor_callback) - def _dispatch_and_reply(self, incoming): + def _dispatch_and_reply(self, incoming, executor_callback): try: incoming.reply(self._dispatch(incoming.ctxt, - incoming.message)) + incoming.message, + executor_callback)) except ExpectedException as e: LOG.debug(u'Expected exception during message handling (%s)', e.exc_info[1]) @@ -150,7 +155,7 @@ class RPCDispatcher(object): # exc_info. del exc_info - def _dispatch(self, ctxt, message): + def _dispatch(self, ctxt, message, executor_callback=None): """Dispatch an RPC message to the appropriate endpoint method. :param ctxt: the request context @@ -177,7 +182,8 @@ class RPCDispatcher(object): if hasattr(endpoint, method): localcontext.set_local_context(ctxt) try: - return self._do_dispatch(endpoint, method, ctxt, args) + return self._do_dispatch(endpoint, method, ctxt, args, + executor_callback) finally: localcontext.clear_local_context() diff --git a/tests/notify/test_dispatcher.py b/tests/notify/test_dispatcher.py index dacc6dde5..791794887 100644 --- a/tests/notify/test_dispatcher.py +++ b/tests/notify/test_dispatcher.py @@ -35,7 +35,7 @@ notification_msg = dict( ) -class TestDispatcher(test_utils.BaseTestCase): +class TestDispatcherScenario(test_utils.BaseTestCase): scenarios = [ ('no_endpoints', @@ -137,6 +137,9 @@ class TestDispatcher(test_utils.BaseTestCase): self.assertEqual(0, incoming.acknowledge.call_count) self.assertEqual(1, incoming.requeue.call_count) + +class TestDispatcher(test_utils.BaseTestCase): + @mock.patch('oslo.messaging.notify.dispatcher.LOG') def test_dispatcher_unknown_prio(self, mylog): msg = notification_msg.copy() @@ -147,3 +150,22 @@ class TestDispatcher(test_utils.BaseTestCase): callback() mylog.warning.assert_called_once_with('Unknown priority "%s"', 'what???') + + def test_dispatcher_executor_callback(self): + endpoint = mock.Mock(spec=['warn']) + endpoint_method = endpoint.warn + endpoint_method.return_value = messaging.NotificationResult.HANDLED + + targets = [messaging.Target(topic='notifications')] + dispatcher = notify_dispatcher.NotificationDispatcher( + targets, [endpoint], None, allow_requeue=True) + + msg = notification_msg.copy() + msg['priority'] = 'warn' + + incoming = mock.Mock(ctxt={}, message=msg) + executor_callback = mock.Mock() + with dispatcher(incoming, executor_callback) as callback: + callback() + self.assertTrue(executor_callback.called) + self.assertEqual(executor_callback.call_args[0][0], endpoint_method)