diff --git a/keystone/common/serializer.py b/keystone/common/serializer.py index 91a16be402..e5f21a7364 100644 --- a/keystone/common/serializer.py +++ b/keystone/common/serializer.py @@ -71,7 +71,35 @@ class XmlDeserializer(object): def __call__(self, xml_str): """Returns a dictionary populated by decoding the given xml string.""" dom = etree.fromstring(xml_str.strip(), PARSER) - return self.walk_element(dom, True) + links_json = self._find_and_remove_links_from_root(dom, True) + obj_json = self.walk_element(dom, True) + if links_json: + obj_json['links'] = links_json['links'] + return obj_json + + def _deserialize_links(self, links, links_json): + for link in links: + links_json['links'][link.attrib['rel']] = link.attrib['href'] + + def _find_and_remove_links_from_root(self, dom, namespace): + """Special-case links element + + If "links" is in the elements, convert it and remove it from root + element. "links" will be placed back into the root of the converted + JSON object. + + """ + for element in dom: + decoded_tag = XmlDeserializer._tag_name(element.tag, namespace) + if decoded_tag == 'links': + links_json = {'links': {}} + self._deserialize_links(element, links_json) + dom.remove(element) + # TODO(gyee): are 'next' and 'previous' mandatory? If so, + # setting them to None if they don't exist? + links_json['links'].setdefault('previous') + links_json['links'].setdefault('next') + return links_json @staticmethod def _tag_name(tag, namespace): @@ -133,6 +161,12 @@ class XmlDeserializer(object): else: list_item_tag = decoded_tag[:-1] + # links is a special dict + if decoded_tag == 'links': + links_json = {'links': {}} + self._deserialize_links(element, links_json) + return links_json + for child in [self.walk_element(x) for x in element if not isinstance(x, ENTITY_TYPE)]: if list_item_tag: @@ -153,10 +187,17 @@ class XmlSerializer(object): Optionally, namespace the etree by specifying an ``xmlns``. """ + links = None # FIXME(dolph): skipping links for now for key in d.keys(): if '_links' in key: d.pop(key) + # FIXME(gyee): special-case links in collections + if 'links' == key: + if links: + # we have multiple links + raise Exception('Multiple links found') + links = d.pop(key) assert len(d.keys()) == 1, ('Cannot encode more than one root ' 'element: %s' % d.keys()) @@ -175,9 +216,23 @@ class XmlSerializer(object): self.populate_element(root, d[name]) + # FIXME(gyee): special-case links for now + if links: + self._populate_links(root, links) + # TODO(dolph): you can get a doctype from lxml, using ElementTrees return '%s\n%s' % (DOCTYPE, etree.tostring(root, pretty_print=True)) + def _populate_links(self, element, links_json): + links = etree.Element('links') + for k, v in links_json.iteritems(): + if v: + link = etree.Element('link') + link.set('rel', unicode(k)) + link.set('href', unicode(v)) + links.append(link) + element.append(links) + def _populate_list(self, element, k, v): """Populates an element with a key & list value.""" # spec has a lot of inconsistency here! @@ -219,9 +274,13 @@ class XmlSerializer(object): def _populate_dict(self, element, k, v): """Populates an element with a key & dictionary value.""" - child = etree.Element(k) - self.populate_element(child, v) - element.append(child) + if k == 'links': + # links is a special dict + self._populate_links(element, v) + else: + child = etree.Element(k) + self.populate_element(child, v) + element.append(child) def _populate_bool(self, element, k, v): """Populates an element with a key & boolean value.""" diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 816bad45d1..8ea4901597 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -14,6 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. +import copy import re from keystone.common import serializer @@ -176,3 +177,51 @@ class XmlSerializerTestCase(test.TestCase): """ self.assertEqualIgnoreWhitespace(serializer.to_xml(d), xml) + + def test_collection_list(self): + d = { + "links": { + "next": "http://localhost:5000/v3/objects?page=3", + "previous": None, + "self": "http://localhost:5000/v3/objects" + }, + "objects": [{ + "attribute": "value1", + "links": { + "self": "http://localhost:5000/v3/objects/abc123def", + "anotherobj": "http://localhost:5000/v3/anotherobjs/123" + } + }, { + "attribute": "value2", + "links": { + "self": "http://localhost:5000/v3/objects/abc456" + } + }]} + xml = """ + + + + + + + + + + + + + + + + + + + """ + self.assertEqualIgnoreWhitespace( + serializer.to_xml(copy.deepcopy(d)), xml) + self.assertDictEqual(serializer.from_xml(xml), d) diff --git a/tests/test_v3.py b/tests/test_v3.py index 8c6b7d17ed..2facfe370c 100644 --- a/tests/test_v3.py +++ b/tests/test_v3.py @@ -298,7 +298,10 @@ class RestfulTestCase(test_content_types.RestfulTestCase): response, and asserted to be equal. """ - entities = resp.body.get(key) + resp_body = resp.body + if resp.getheader('Content-Type') == 'application/xml': + resp_body = serializer.from_xml(etree.tostring(resp_body)) + entities = resp_body.get(key) self.assertIsNotNone(entities) if expected_length is not None: @@ -308,7 +311,7 @@ class RestfulTestCase(test_content_types.RestfulTestCase): self.assertTrue(len(entities)) # collections should have relational links - self.assertValidListLinks(resp.body.get('links')) + self.assertValidListLinks(resp_body.get('links')) for entity in entities: self.assertIsNotNone(entity) diff --git a/tests/test_v3_catalog.py b/tests/test_v3_catalog.py index 2d161db5ef..67cbd34025 100644 --- a/tests/test_v3_catalog.py +++ b/tests/test_v3_catalog.py @@ -38,6 +38,11 @@ class CatalogTestCase(test_v3.RestfulTestCase): r = self.get('/services') self.assertValidServiceListResponse(r, ref=self.service) + def test_list_services_xml(self): + """GET /services (xml data)""" + r = self.get('/services', content_type='xml') + self.assertValidServiceListResponse(r, ref=self.service) + def test_get_service(self): """GET /services/{service_id}""" r = self.get('/services/%(service_id)s' % { @@ -65,6 +70,11 @@ class CatalogTestCase(test_v3.RestfulTestCase): r = self.get('/endpoints') self.assertValidEndpointListResponse(r, ref=self.endpoint) + def test_list_endpoints_xml(self): + """GET /endpoints (xml data)""" + r = self.get('/endpoints', content_type='xml') + self.assertValidEndpointListResponse(r, ref=self.endpoint) + def test_create_endpoint(self): """POST /endpoints""" ref = self.new_endpoint_ref(service_id=self.service_id) diff --git a/tests/test_v3_identity.py b/tests/test_v3_identity.py index 9ef487c318..162c41e3a0 100644 --- a/tests/test_v3_identity.py +++ b/tests/test_v3_identity.py @@ -55,6 +55,11 @@ class IdentityTestCase(test_v3.RestfulTestCase): r = self.get('/domains') self.assertValidDomainListResponse(r, ref=self.domain) + def test_list_domains_xml(self): + """GET /domains (xml data)""" + r = self.get('/domains', content_type='xml') + self.assertValidDomainListResponse(r, ref=self.domain) + def test_get_domain(self): """GET /domains/{domain_id}""" r = self.get('/domains/%(domain_id)s' % { @@ -105,6 +110,11 @@ class IdentityTestCase(test_v3.RestfulTestCase): r = self.get('/projects') self.assertValidProjectListResponse(r, ref=self.project) + def test_list_projects_xml(self): + """GET /projects (xml data)""" + r = self.get('/projects', content_type='xml') + self.assertValidProjectListResponse(r, ref=self.project) + def test_create_project(self): """POST /projects""" ref = self.new_project_ref(domain_id=self.domain_id) @@ -151,6 +161,11 @@ class IdentityTestCase(test_v3.RestfulTestCase): r = self.get('/users') self.assertValidUserListResponse(r, ref=self.user) + def test_list_users_xml(self): + """GET /users (xml data)""" + r = self.get('/users', content_type='xml') + self.assertValidUserListResponse(r, ref=self.user) + def test_get_user(self): """GET /users/{user_id}""" r = self.get('/users/%(user_id)s' % { @@ -215,6 +230,11 @@ class IdentityTestCase(test_v3.RestfulTestCase): r = self.get('/groups') self.assertValidGroupListResponse(r, ref=self.group) + def test_list_groups_xml(self): + """GET /groups (xml data)""" + r = self.get('/groups', content_type='xml') + self.assertValidGroupListResponse(r, ref=self.group) + def test_get_group(self): """GET /groups/{group_id}""" r = self.get('/groups/%(group_id)s' % { @@ -242,6 +262,11 @@ class IdentityTestCase(test_v3.RestfulTestCase): r = self.get('/credentials') self.assertValidCredentialListResponse(r, ref=self.credential) + def test_list_credentials_xml(self): + """GET /credentials (xml data)""" + r = self.get('/credentials', content_type='xml') + self.assertValidCredentialListResponse(r, ref=self.credential) + def test_create_credential(self): """POST /credentials""" ref = self.new_credential_ref(user_id=self.user['id']) @@ -290,6 +315,11 @@ class IdentityTestCase(test_v3.RestfulTestCase): r = self.get('/roles') self.assertValidRoleListResponse(r, ref=self.role) + def test_list_roles_xml(self): + """GET /roles (xml data)""" + r = self.get('/roles', content_type='xml') + self.assertValidRoleListResponse(r, ref=self.role) + def test_get_role(self): """GET /roles/{role_id}""" r = self.get('/roles/%(role_id)s' % { diff --git a/tests/test_v3_policy.py b/tests/test_v3_policy.py index 1af68b6e87..e8855ea35e 100644 --- a/tests/test_v3_policy.py +++ b/tests/test_v3_policy.py @@ -30,6 +30,11 @@ class PolicyTestCase(test_v3.RestfulTestCase): r = self.get('/policies') self.assertValidPolicyListResponse(r, ref=self.policy) + def test_list_policies_xml(self): + """GET /policies (xml data)""" + r = self.get('/policies', content_type='xml') + self.assertValidPolicyListResponse(r, ref=self.policy) + def test_get_policy(self): """GET /policies/{policy_id}""" r = self.get(