diff --git a/openstack/resource.py b/openstack/resource.py index 630ffb738..ea58ee375 100644 --- a/openstack/resource.py +++ b/openstack/resource.py @@ -231,9 +231,32 @@ class QueryParameters(object): self._mapping = {"limit": "limit", "marker": "marker"} self._mapping.update(dict({name: name for name in names}, **mappings)) + def _validate(self, query, base_path=None): + """Check that supplied query keys match known query mappings + + :param dict query: Collection of key-value pairs where each key is the + client-side parameter name or server side name. + :param base_path: Formatted python string of the base url path for + the resource. + """ + expected_params = list(self._mapping.keys()) + expected_params += self._mapping.values() + + if base_path: + expected_params += utils.get_string_format_keys(base_path) + + invalid_keys = set(query.keys()) - set(expected_params) + if invalid_keys: + raise exceptions.InvalidResourceQuery( + message="Invalid query params: %s" % ",".join(invalid_keys), + extra_data=invalid_keys) + def _transpose(self, query): """Transpose the keys in query based on the mapping + If a query is supplied with its server side name, we will still use + it, but take preference to the client-side name when both are supplied. + :param dict query: Collection of key-value pairs where each key is the client-side parameter name to be transposed to its server side name. @@ -242,6 +265,8 @@ class QueryParameters(object): for key, value in self._mapping.items(): if key in query: result[value] = query[key] + elif value in query: + result[value] = query[value] return result @@ -855,15 +880,7 @@ class Resource(object): raise exceptions.MethodNotSupported(cls, "list") session = cls._get_session(session) - expected_params = utils.get_string_format_keys(cls.base_path) - expected_params += cls._query_mapping._mapping.keys() - - invalid_keys = set(params.keys()) - set(expected_params) - if invalid_keys: - raise exceptions.InvalidResourceQuery( - message="Invalid query params: %s" % ",".join(invalid_keys), - extra_data=invalid_keys) - + cls._query_mapping._validate(params, base_path=cls.base_path) query_params = cls._query_mapping._transpose(params) uri = cls.base_path % params diff --git a/openstack/tests/unit/test_resource.py b/openstack/tests/unit/test_resource.py index 4322f3c4e..1d0fab5bb 100644 --- a/openstack/tests/unit/test_resource.py +++ b/openstack/tests/unit/test_resource.py @@ -1285,7 +1285,9 @@ class TestResourceActions(base.TestCase): # Look at the `params` argument to each of the get calls that # were made. - self.session.get.call_args_list[0][1]["params"] = {qp_name: qp} + self.assertEqual( + self.session.get.call_args_list[0][1]["params"], + {qp_name: qp}) self.assertEqual(self.session.get.call_args_list[0][0][0], Test.base_path % {"something": uri_param}) @@ -1314,6 +1316,81 @@ class TestResourceActions(base.TestCase): except exceptions.InvalidResourceQuery as err: self.assertEqual(str(err), 'Invalid query params: something_wrong') + def test_values_as_list_params(self): + id = 1 + qp = "query param!" + qp_name = "query-param" + uri_param = "uri param!" + + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.links = {} + mock_response.json.return_value = {"resources": [{"id": id}]} + + mock_empty = mock.Mock() + mock_empty.status_code = 200 + mock_empty.links = {} + mock_empty.json.return_value = {"resources": []} + + self.session.get.side_effect = [mock_response, mock_empty] + + class Test(self.test_class): + _query_mapping = resource.QueryParameters(query_param=qp_name) + base_path = "/%(something)s/blah" + something = resource.URI("something") + + results = list(Test.list(self.session, paginated=True, + something=uri_param, **{qp_name: qp})) + + self.assertEqual(1, len(results)) + + # Look at the `params` argument to each of the get calls that + # were made. + self.assertEqual( + self.session.get.call_args_list[0][1]["params"], + {qp_name: qp}) + + self.assertEqual(self.session.get.call_args_list[0][0][0], + Test.base_path % {"something": uri_param}) + + def test_values_as_list_params_precedence(self): + id = 1 + qp = "query param!" + qp2 = "query param!!!!!" + qp_name = "query-param" + uri_param = "uri param!" + + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.links = {} + mock_response.json.return_value = {"resources": [{"id": id}]} + + mock_empty = mock.Mock() + mock_empty.status_code = 200 + mock_empty.links = {} + mock_empty.json.return_value = {"resources": []} + + self.session.get.side_effect = [mock_response, mock_empty] + + class Test(self.test_class): + _query_mapping = resource.QueryParameters(query_param=qp_name) + base_path = "/%(something)s/blah" + something = resource.URI("something") + + results = list(Test.list(self.session, paginated=True, query_param=qp2, + something=uri_param, **{qp_name: qp})) + + self.assertEqual(1, len(results)) + + # Look at the `params` argument to each of the get calls that + # were made. + self.assertEqual( + self.session.get.call_args_list[0][1]["params"], + {qp_name: qp2}) + + self.assertEqual(self.session.get.call_args_list[0][0][0], + Test.base_path % {"something": uri_param}) + def test_list_multi_page_response_paginated(self): ids = [1, 2] resp1 = mock.Mock()