From 8a8685b62ff3e17e3f3ff4042ac828ae88b0151c Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Sat, 15 Jun 2013 16:13:42 +0100 Subject: [PATCH] Make it possible to call prepare() on call context Found it's quite easy to get into doing e.g. client = RPCClient(...) client = client.prepare(topic='foo') client = client.prepare(server='bar') --- oslo/messaging/rpc/client.py | 68 ++++++++++++++++++++++++------------ tests/test_rpc_client.py | 17 ++++++++- 2 files changed, 61 insertions(+), 24 deletions(-) diff --git a/oslo/messaging/rpc/client.py b/oslo/messaging/rpc/client.py index 0aae956cf..326da69a0 100644 --- a/oslo/messaging/rpc/client.py +++ b/oslo/messaging/rpc/client.py @@ -112,6 +112,46 @@ class _CallContext(object): wait_for_reply=True, timeout=timeout) return self.serializer.deserialize_entity(ctxt, result) + _marker = object() + + @classmethod + def _prepare(cls, base, + exchange=_marker, topic=_marker, namespace=_marker, + version=_marker, server=_marker, fanout=_marker, + timeout=_marker, check_for_lock=_marker, version_cap=_marker): + """Prepare a method invocation context. See RPCClient.prepare().""" + kwargs = dict( + exchange=exchange, + topic=topic, + namespace=namespace, + version=version, + server=server, + fanout=fanout) + kwargs = dict([(k, v) for k, v in kwargs.items() + if v is not cls._marker]) + target = base.target(**kwargs) + + if timeout is cls._marker: + timeout = base.timeout + if check_for_lock is cls._marker: + check_for_lock = base.check_for_lock + if version_cap is cls._marker: + version_cap = base.version_cap + + return _CallContext(base.transport, target, + base.serializer, + timeout, check_for_lock, + version_cap) + + def prepare(self, exchange=_marker, topic=_marker, namespace=_marker, + version=_marker, server=_marker, fanout=_marker, + timeout=_marker, check_for_lock=_marker, version_cap=_marker): + """Prepare a method invocation context. See RPCClient.prepare().""" + return self._prepare(self, + exchange, topic, namespace, + version, server, fanout, + timeout, check_for_lock, version_cap) + class RPCClient(object): @@ -199,7 +239,7 @@ class RPCClient(object): super(RPCClient, self).__init__() - _marker = object() + _marker = _CallContext._marker def prepare(self, exchange=_marker, topic=_marker, namespace=_marker, version=_marker, server=_marker, fanout=_marker, @@ -232,28 +272,10 @@ class RPCClient(object): :param version_cap: raise a RPCVersionCapError version exceeds this cap :type version_cap: str """ - kwargs = dict( - exchange=exchange, - topic=topic, - namespace=namespace, - version=version, - server=server, - fanout=fanout) - kwargs = dict([(k, v) for k, v in kwargs.items() - if v is not self._marker]) - target = self.target(**kwargs) - - if timeout is self._marker: - timeout = self.timeout - if check_for_lock is self._marker: - check_for_lock = self.check_for_lock - if version_cap is self._marker: - version_cap = self.version_cap - - return _CallContext(self.transport, target, - self.serializer, - timeout, check_for_lock, - version_cap) + return _CallContext._prepare(self, + exchange, topic, namespace, + version, server, fanout, + timeout, check_for_lock, version_cap) def cast(self, ctxt, method, **kwargs): """Invoke a method and return immediately. diff --git a/tests/test_rpc_client.py b/tests/test_rpc_client.py index fb204612f..5a32cf9e2 100644 --- a/tests/test_rpc_client.py +++ b/tests/test_rpc_client.py @@ -75,7 +75,7 @@ class TestCastCall(test_utils.BaseTestCase): class TestCastToTarget(test_utils.BaseTestCase): - scenarios = [ + _base = [ ('all_none', dict(ctor={}, prepare={}, expect={})), ('ctor_exchange', dict(ctor=dict(exchange='testexchange'), @@ -175,6 +175,16 @@ class TestCastToTarget(test_utils.BaseTestCase): expect=dict(fanout=False))), ] + _prepare = [ + ('single_prepare', dict(double_prepare=False)), + ('double_prepare', dict(double_prepare=True)), + ] + + @classmethod + def generate_scenarios(cls): + cls.scenarios = testscenarios.multiply_scenarios(cls._base, + cls._prepare) + def setUp(self): super(TestCastToTarget, self).setUp(conf=cfg.ConfigOpts()) @@ -198,9 +208,14 @@ class TestCastToTarget(test_utils.BaseTestCase): if self.prepare: client = client.prepare(**self.prepare) + if self.double_prepare: + client = client.prepare(**self.prepare) client.cast({}, 'foo') +TestCastToTarget.generate_scenarios() + + _notset = object()