diff --git a/oslo_messaging/_drivers/impl_rabbit.py b/oslo_messaging/_drivers/impl_rabbit.py index 60461d728..65032c236 100644 --- a/oslo_messaging/_drivers/impl_rabbit.py +++ b/oslo_messaging/_drivers/impl_rabbit.py @@ -912,7 +912,7 @@ class Connection(object): def _heartbeat_start(self): if self._heartbeat_supported_and_enabled(): - self._heartbeat_exit_event = threading.Event() + self._heartbeat_exit_event = _utils.Event() self._heartbeat_thread = threading.Thread( target=self._heartbeat_thread_job) self._heartbeat_thread.daemon = True diff --git a/oslo_messaging/_utils.py b/oslo_messaging/_utils.py index e0025e688..919a44bda 100644 --- a/oslo_messaging/_utils.py +++ b/oslo_messaging/_utils.py @@ -15,6 +15,11 @@ import threading +from oslo_utils import importutils + +_eventlet = importutils.try_import('eventlet') +_patcher = importutils.try_import('eventlet.patcher') + def version_is_compatible(imp_version, version): """Determine whether versions are compatible. @@ -74,3 +79,49 @@ class DummyLock(object): def __exit__(self, type, value, traceback): self.release() + + +class _Event(object): + """A class that provides consistent eventlet/threading Event API. + + This wraps the eventlet.event.Event class to have the same API as + the standard threading.Event object. + """ + def __init__(self, *args, **kwargs): + self.clear() + + def clear(self): + self._set = False + self._event = _eventlet.event.Event() + + def is_set(self): + return self._set + + isSet = is_set + + def set(self): + self._set = True + self._event.send(True) + + def wait(self, timeout=None): + with _eventlet.timeout.Timeout(timeout, False): + self._event.wait() + return self.is_set() + + +def _is_monkey_patched(module): + """Determines safely is eventlet patching for module enabled or not + :param module: String, module name + :return Bool, True if module is patched, False otherwise + """ + + if _patcher is None: + return False + return _patcher.is_monkey_patched(module) + + +def Event(): + if _is_monkey_patched("thread"): + return _Event() + else: + return threading.Event() diff --git a/oslo_messaging/tests/test_utils.py b/oslo_messaging/tests/test_utils.py index 908c25fbf..256a69439 100644 --- a/oslo_messaging/tests/test_utils.py +++ b/oslo_messaging/tests/test_utils.py @@ -13,9 +13,13 @@ # License for the specific language governing permissions and limitations # under the License. +import threading + from oslo_messaging._drivers import common from oslo_messaging import _utils as utils from oslo_messaging.tests import utils as test_utils + +import six from six.moves import mock @@ -97,3 +101,27 @@ class TimerTestCase(test_utils.BaseTestCase): remaining = t.check_return(callback, 1, a='b') self.assertEqual(0, remaining) callback.assert_called_once_with(1, a='b') + + +class EventCompatTestCase(test_utils.BaseTestCase): + @mock.patch('oslo_messaging._utils._Event.clear') + def test_event_api_compat(self, mock_clear): + with mock.patch('oslo_messaging._utils._is_monkey_patched', + return_value=True): + e_event = utils.Event() + self.assertIsInstance(e_event, utils._Event) + + with mock.patch('oslo_messaging._utils._is_monkey_patched', + return_value=False): + t_event = utils.Event() + if six.PY3: + t_event_cls = threading.Event + else: + t_event_cls = threading._Event + self.assertIsInstance(t_event, t_event_cls) + + public_methods = [m for m in dir(t_event) if not m.startswith("_") and + callable(getattr(t_event, m))] + + for method in public_methods: + self.assertTrue(hasattr(e_event, method))