Add Hypermedia Pagination Links

This changes all list response bodies to be of the structure defined in
the specification. For example, it changes responses from /v1/hosts to
look like:

    {
      "hosts": [
        {"id": 1, ...},
        {"id": 2, ...}
      ],
      "links": [
        {
          "rel": "first",
          "href": "http://.../v1/hosts"
        },
        {
          "rel": "prev",
          "href": "http://.../v1/hosts?limit=...&marker=...&..."
        },
        {
          "rel": "self",
          "href": "http://.../v1/hosts?limit=...&marker=...&..."
        },
        {
          "rel": "next",
          "href": "http://.../v1/hosts?limit=...&marker=...&..."
        }
      ]
    }

Closes-bug: #1658737
Change-Id: I11c5c6053e5f0873ee53df5b7dcb9b2ed25a0b64
This commit is contained in:
Ian Cordasco 2017-01-26 13:19:02 -06:00
parent b42b39dff1
commit 656535624e
19 changed files with 599 additions and 293 deletions

View File

@ -1,5 +1,6 @@
import functools
import inspect
import urllib.parse as urllib
import decorator
@ -70,3 +71,32 @@ def limit_from(filters, minimum=10, default=30, maximum=100):
# isn't too small, then it must be too big. In that case, let's just
# return the maximum.
return maximum
def links_from(link_params):
"""Generate the list of hypermedia link relations from their parameters.
This uses the request thread-local to determine the endpoint and generate
URLs from that.
:param dict link_params:
A dictionary mapping the relation name to the query parameters.
:returns:
List of dictionaries to represent hypermedia link relations.
:rtype:
list
"""
links = []
relations = ["first", "prev", "self", "next"]
base_url = flask.request.base_url
for relation in relations:
query_params = link_params.get(relation)
if not query_params:
continue
link_rel = {
"rel": relation,
"href": base_url + "?" + urllib.urlencode(query_params),
}
links.append(link_rel)
return links

View File

@ -16,10 +16,12 @@ class Cells(base.Resource):
@base.pagination_context
def get(self, context, request_args, pagination_params):
"""Get all cells, with optional filtering."""
cells_obj = dbapi.cells_get_all(
cells_obj, link_params = dbapi.cells_get_all(
context, request_args, pagination_params,
)
return jsonutils.to_primitive(cells_obj), 200, None
links = base.links_from(link_params)
response_body = {'cells': cells_obj, 'links': links}
return jsonutils.to_primitive(response_body), 200, None
@base.http_codes
def post(self, context, request_data):

View File

@ -17,10 +17,12 @@ class Hosts(base.Resource):
@base.pagination_context
def get(self, context, request_args, pagination_params):
"""Get all hosts for region, with optional filtering."""
hosts_obj = dbapi.hosts_get_all(
hosts_obj, link_params = dbapi.hosts_get_all(
context, request_args, pagination_params,
)
return jsonutils.to_primitive(hosts_obj), 200, None
links = base.links_from(link_params)
response_body = {'hosts': hosts_obj, 'links': links}
return jsonutils.to_primitive(response_body), 200, None
@base.http_codes
def post(self, context, request_data):

View File

@ -18,10 +18,12 @@ class Networks(base.Resource):
@base.pagination_context
def get(self, context, request_args, pagination_params):
"""Get all networks, with optional filtering."""
networks_obj = dbapi.networks_get_all(
networks_obj, link_params = dbapi.networks_get_all(
context, request_args, pagination_params,
)
return jsonutils.to_primitive(networks_obj), 200, None
links = base.links_from(link_params)
response_body = {'networks': networks_obj, 'links': links}
return jsonutils.to_primitive(response_body), 200, None
@base.http_codes
def post(self, context, request_data):
@ -67,10 +69,12 @@ class NetworkDevices(base.Resource):
@base.pagination_context
def get(self, context, request_args, pagination_params):
"""Get all network devices."""
devices_obj = dbapi.network_devices_get_all(
devices_obj, link_params = dbapi.network_devices_get_all(
context, request_args, pagination_params,
)
return jsonutils.to_primitive(devices_obj), 200, None
links = base.links_from(link_params)
response_body = {'network_devices': devices_obj, 'links': links}
return jsonutils.to_primitive(response_body), 200, None
@base.http_codes
def post(self, context, request_data):
@ -142,10 +146,12 @@ class NetworkInterfaces(base.Resource):
@base.pagination_context
def get(self, context, request_args, pagination_params):
"""Get all network interfaces."""
interfaces_obj = dbapi.network_interfaces_get_all(
interfaces_obj, link_params = dbapi.network_interfaces_get_all(
context, request_args, pagination_params,
)
return jsonutils.to_primitive(interfaces_obj), 200, None
links = base.links_from(link_params)
response_body = {'network_interfaces': interfaces_obj, 'links': links}
return jsonutils.to_primitive(response_body), 200, None
@base.http_codes
def post(self, context, request_data):

View File

@ -21,22 +21,25 @@ class Regions(base.Resource):
region_id = request_args.get("id")
region_name = request_args.get("name")
if not region_id and not region_name:
if not (region_id or region_name):
# Get all regions for this tenant
regions_obj = dbapi.regions_get_all(
regions_obj, link_params = dbapi.regions_get_all(
context, request_args, pagination_params,
)
return jsonutils.to_primitive(regions_obj), 200, None
else:
if region_name:
region_obj = dbapi.regions_get_by_name(context, region_name)
region_obj.data = region_obj.variables
if region_name:
region_obj = dbapi.regions_get_by_name(context, region_name)
region_obj.data = region_obj.variables
return jsonutils.to_primitive([region_obj]), 200, None
if region_id:
region_obj = dbapi.regions_get_by_id(context, region_id)
region_obj.data = region_obj.variables
if region_id:
region_obj = dbapi.regions_get_by_id(context, region_id)
region_obj.data = region_obj.variables
return jsonutils.to_primitive([region_obj]), 200, None
regions_obj = [region_obj]
link_params = {}
links = base.links_from(link_params)
response_body = {'regions': regions_obj, 'links': links}
return jsonutils.to_primitive(response_body), 200, None
@base.http_codes
def post(self, context, request_data):

View File

@ -20,17 +20,20 @@ class Projects(base.Resource):
if project_id:
project_obj = dbapi.projects_get_by_id(context, project_id)
return jsonutils.to_primitive([project_obj], 200, None)
projects_obj = [project_obj]
link_params = {}
if project_name:
projects_obj = dbapi.projects_get_by_name(
projects_obj, link_params = dbapi.projects_get_by_name(
context, project_name, request_args, pagination_params,
)
else:
projects_obj = dbapi.projects_get_all(
projects_obj, link_params = dbapi.projects_get_all(
context, request_args, pagination_params,
)
return jsonutils.to_primitive(projects_obj), 200, None
links = base.links_from(link_params)
response_body = {'projects': projects_obj, 'links': links}
return jsonutils.to_primitive(response_body), 200, None
@base.http_codes
def post(self, context, request_data):

View File

@ -23,17 +23,20 @@ class Users(base.Resource):
if user_id:
user_obj = dbapi.users_get_by_id(context, user_id)
user_obj.data = user_obj.variables
return jsonutils.to_primitive([user_obj]), 200, None
users_obj = [user_obj]
link_params = {}
if user_name:
users_obj = dbapi.users_get_by_name(
users_obj, link_params = dbapi.users_get_by_name(
context, user_name, request_args, pagination_params,
)
else:
users_obj = dbapi.users_get_all(
users_obj, link_params = dbapi.users_get_all(
context, request_args, pagination_params,
)
return jsonutils.to_primitive(users_obj), 200, None
links = base.links_from(link_params)
response_body = {'users': users_obj, 'links': links}
return jsonutils.to_primitive(response_body), 200, None
@base.http_codes
def post(self, context, request_data):

View File

@ -718,6 +718,38 @@ DefinitionNoParams = {
"additionalProperties": False,
}
DefinitionsPaginationLinks = {
"type": "array",
"items": {
"type": "object",
"properties": {
"rel": {
"type": "string",
"enum": ["first", "prev", "self", "next"],
"description": ("Relation of the associated URL to the current"
" page"),
},
"href": {
"type": "string",
},
},
},
}
def paginated_resource(list_name, schema):
return {
"type": "object",
"additionalProperties": False,
"properties": {
list_name: {
"type": "array",
"items": schema,
},
"links": DefinitionsPaginationLinks,
},
}
validators = {
("ansible_inventory", "GET"): {
"args": {
@ -1461,10 +1493,7 @@ filters = {
("hosts", "GET"): {
200: {
"headers": None,
"schema": {
"items": DefinitionsHost,
"type": "array",
},
"schema": paginated_resource("hosts", DefinitionsHost),
},
400: {
"headers": None,
@ -1550,10 +1579,7 @@ filters = {
("cells", "GET"): {
200: {
"headers": None,
"schema": {
"items": DefinitionsCell,
"type": "array",
},
"schema": paginated_resource("cells", DefinitionsCell),
},
400: {
"headers": None,
@ -1585,10 +1611,7 @@ filters = {
("regions", "GET"): {
200: {
"headers": None,
"schema": {
"items": DefinitionsRegion,
"type": "array",
},
"schema": paginated_resource("regions", DefinitionsRegion),
},
400: {
"headers": None,
@ -1660,10 +1683,7 @@ filters = {
("projects", "GET"): {
200: {
"headers": None,
"schema": {
"items": DefinitionProject,
"type": "array",
},
"schema": paginated_resource("projects", DefinitionProject),
},
400: {
"headers": None,
@ -1699,10 +1719,7 @@ filters = {
("users", "GET"): {
200: {
"headers": None,
"schema": {
"items": DefinitionUser,
"type": "array",
},
"schema": paginated_resource("users", DefinitionUser),
},
400: {
"headers": None,
@ -1810,10 +1827,8 @@ filters = {
("network_devices", "GET"): {
200: {
"headers": None,
"schema": {
"items": DefinitionNetworkDeviceId,
"type": "array",
},
"schema": paginated_resource("network_devices",
DefinitionNetworkDeviceId),
},
400: {
"headers": None,
@ -1957,10 +1972,7 @@ filters = {
("networks", "GET"): {
200: {
"headers": None,
"schema": {
"items": DefinitionNetwork,
"type": "array",
},
"schema": paginated_resource("networks", DefinitionNetwork),
},
400: {
"headers": None,
@ -2050,10 +2062,8 @@ filters = {
("network_interfaces", "GET"): {
200: {
"headers": None,
"schema": {
"items": DefinitionNetworkInterface,
"type": "array",
},
"schema": paginated_resource("network_interfaces",
DefinitionNetworkInterface),
},
400: {
"headers": None,

View File

@ -53,12 +53,16 @@ def get_backend():
def is_admin_context(context):
"""Check if this request had admin project context."""
return (context.is_admin and context.is_admin_project)
if (context.is_admin and context.is_admin_project):
return True
return False
def is_project_admin_context(context):
"""Check if this request has admin context with in the project."""
return context.is_admin
if context.is_admin:
return True
return False
def require_admin_context(f):
@ -76,7 +80,10 @@ def require_project_admin_context(f):
context = args[0]
if is_project_admin_context(context):
return f(*args, **kwargs)
raise exceptions.AdminRequired()
elif is_project_admin_context(args[0]):
return f(*args, **kwargs)
else:
raise exceptions.AdminRequired()
return wrapper
@ -479,13 +486,12 @@ def projects_get_all(context, filters, pagination_params):
@require_admin_context
def projects_get_by_name(context, project_name, filters, pagination_params):
def projects_get_by_name(context, project_name):
"""Get all projects that match the given name."""
query = model_query(context, models.Project)
query = query.filter(models.Project.name.like(project_name))
try:
return _paginate(context, query, models.Project, session, filters,
pagination_params)
return query.all()
except sa_exc.NoResultFound:
raise exceptions.NotFound()
except Exception as err:
@ -548,8 +554,12 @@ def users_get_all(context, filters, pagination_params):
@require_project_admin_context
def users_get_by_name(context, user_name, filters, pagination_params):
"""Get all users that match the given username."""
query = model_query(context, models.User,
project_only=is_admin_context(context))
session = get_session()
if is_admin_context(context):
query = model_query(context, models.User, session=session)
else:
query = model_query(context, models.User, project_only=True,
session=session)
query = query.filter_by(username=user_name)
return _paginate(context, query, models.User, session, filters,
@ -819,36 +829,116 @@ def network_interfaces_delete(context, interface_id):
query.delete()
def _marker_from(context, session, model, params, project_only):
if params['marker'] is None:
def _marker_from(context, session, model, marker, project_only):
if marker is None:
return None
try:
query = model_query(context, model, session=session,
project_only=project_only)
return query.filter_by(id=params['marker']).one()
except sa_exc.NoResultFound:
raise exceptions.BadRequest(
message='Marker "{}" does not exist'.format(params['marker'])
query = model_query(context, model, session=session,
project_only=project_only)
return query.filter_by(id=marker).one()
def _get_previous(query, model, current_marker, page_size, filters):
# NOTE(sigmavirus24): To get the previous items based on the existing
# filters, we need only reverse the direction that the user requested.
original_sort_dir = filters['sort_dir']
sort_dir = 'desc'
if original_sort_dir == 'desc':
sort_dir = 'asc'
results = db_utils.paginate_query(
query, model,
limit=page_size,
sort_keys=filters['sort_keys'],
sort_dir=sort_dir,
marker=current_marker,
).all()
if not results:
return None
return results[-1].id
def _link_params_for(query, model, filters, pagination_params,
current_marker, current_results):
links = {}
# We can discern our base parameters for our links
base_parameters = {}
for (key, value) in filters.items():
# This takes care of things like sort_keys which may have multiple
# values
if isinstance(value, list):
value = ','.join(value)
base_parameters[key] = value
base_parameters['limit'] = pagination_params['limit']
generate_links = ('first', 'self')
if current_results:
next_marker = current_results[-1]
# If there are results to return, there may be a next link to follow
generate_links += ('next',)
# We start our links dictionary with some basics
for relation in generate_links:
params = base_parameters.copy()
if relation == 'self':
if pagination_params['marker'] is not None:
params['marker'] = pagination_params['marker']
elif relation == 'next':
params['marker'] = next_marker.id
links[relation] = params
params = base_parameters.copy()
previous_marker = None
if current_marker is not None:
previous_marker = _get_previous(
query, model, current_marker, pagination_params['limit'], filters,
)
if previous_marker is not None:
params['marker'] = previous_marker
links['prev'] = params
return links
def _paginate(context, query, model, session, filters, pagination_params,
project_only=False):
# NOTE(sigmavirus24) Retrieve the instance of the model represented by the
# marker.
try:
return db_utils.paginate_query(
marker = _marker_from(context, session, model,
pagination_params['marker'],
project_only)
except sa_exc.NoResultFound:
raise exceptions.BadRequest(
message='Marker "{}" does not exist'.format(
pagination_params['marker']
)
)
except Exception as err:
raise exceptions.UnknownException(message=err)
filters.setdefault('sort_keys', ['created_at', 'id'])
filters.setdefault('sort_dir', 'asc')
# Retrieve the results based on the marker and the limit
try:
results = db_utils.paginate_query(
query, model,
limit=pagination_params['limit'],
sort_keys=filters.get('sort_keys', ['created_at']),
marker=_marker_from(context, session, model, pagination_params,
project_only),
sort_keys=filters['sort_keys'],
sort_dir=filters['sort_dir'],
marker=marker,
).all()
except sa_exc.NoResultFound:
raise exceptions.NotFound()
except exceptions.Base:
# NOTE(sigmavirus24): Here we need to allow for _marker_from's
# exception to bubble up without being rewrapped as an
# UnknownException
raise
except Exception as err:
raise exceptions.UnknownException(message=err)
try:
links = _link_params_for(
query, model, filters, pagination_params, marker, results,
)
except Exception as err:
raise exceptions.UnknownException(message=err)
return results, links

View File

@ -182,6 +182,15 @@ class TestCase(testtools.TestCase):
def tearDown(self):
super(TestCase, self).tearDown()
def assertSuccessOk(self, response):
self.assertEqual(requests.codes.OK, response.status_code)
def assertSuccessCreated(self, response):
self.assertEqual(requests.codes.CREATED, response.status_code)
def assertNoContent(self, response):
self.assertEqual(requests.codes.NO_CONTENT, response.status_code)
def get(self, url, headers=None, **params):
resp = self.session.get(
url, verify=False, headers=headers, params=params,

View File

@ -69,16 +69,16 @@ class APIV1CellTest(APIV1ResourceWithVariablesTestCase):
self.create_cell('cell-1')
url = self.url + '/v1/cells?region_id={}'.format(self.region['id'])
resp = self.get(url)
cells = resp.json()
cells = resp.json()['cells']
self.assertEqual(1, len(cells))
self.assertEqual({'cell-1'}, {i['name'] for i in cells})
self.assertEqual(['cell-1'], [i['name'] for i in cells])
def test_cell_get_all_with_name_filter(self):
self.create_cell('cell1')
self.create_cell('cell2')
url = self.url + '/v1/cells?name=cell2'
resp = self.get(url)
cells = resp.json()
cells = resp.json()['cells']
self.assertEqual(1, len(cells))
self.assertEqual({'cell2'}, {cell['name'] for cell in cells})
@ -106,7 +106,7 @@ class APIV1CellTest(APIV1ResourceWithVariablesTestCase):
url = self.url + '/v1/cells'
resp = self.get(url)
self.assertEqual(200, resp.status_code)
cells = resp.json()
cells = resp.json()['cells']
self.assertEqual(2, len(cells))
self.assertEqual({'cell-1', 'cell-2'},
{cell['name'] for cell in cells})
@ -117,7 +117,7 @@ class APIV1CellTest(APIV1ResourceWithVariablesTestCase):
resp = self.get(url)
self.assertEqual(200, resp.status_code)
cells = resp.json()
cells = resp.json()['cells']
self.assertEqual(1, len(cells))
self.assertEqual({'cell-2'},
{cell['name'] for cell in cells})

View File

@ -1,21 +1,18 @@
import urllib.parse
from craton.tests.functional import TestCase
from craton.tests.functional.test_variable_calls import \
APIV1ResourceWithVariablesTestCase
class APIV1HostTest(APIV1ResourceWithVariablesTestCase):
resource = 'hosts'
class HostTests(TestCase):
def setUp(self):
super(APIV1HostTest, self).setUp()
super(HostTests, self).setUp()
self.region = self.create_region()
def tearDown(self):
super(APIV1HostTest, self).tearDown()
def create_region(self):
def create_region(self, region_name='region-1'):
url = self.url + '/v1/regions'
payload = {'name': 'region-1'}
payload = {'name': region_name}
region = self.post(url, data=payload)
self.assertEqual(201, region.status_code)
self.assertIn('Location', region.headers)
@ -25,12 +22,16 @@ class APIV1HostTest(APIV1ResourceWithVariablesTestCase):
)
return region.json()
def create_host(self, name, hosttype, ip_address, variables=None):
def create_host(self, name, hosttype, ip_address, region=None,
**variables):
if region is None:
region = self.region
url = self.url + '/v1/hosts'
payload = {'name': name, 'device_type': hosttype,
'ip_address': ip_address,
'region_id': self.region['id']}
if variables is not None:
'region_id': region['id']}
if variables:
payload['variables'] = variables
host = self.post(url, data=payload)
@ -42,9 +43,10 @@ class APIV1HostTest(APIV1ResourceWithVariablesTestCase):
)
return host.json()
def test_create_host(self):
host = self.create_host('host1', 'server', '192.168.1.1')
self.assertEqual('host1', host['name'])
class APIV1HostTest(HostTests, APIV1ResourceWithVariablesTestCase):
resource = 'hosts'
def test_create_host_supports_vars_ops(self):
host = self.create_host('host1', 'server', '192.168.1.1')
@ -52,6 +54,32 @@ class APIV1HostTest(APIV1ResourceWithVariablesTestCase):
self.assert_vars_can_be_set(host['id'])
self.assert_vars_can_be_deleted(host['id'])
def test_host_get_by_vars_filter(self):
vars1 = {"a": "b", "host": "one"}
self.create_host('host1', 'server', '192.168.1.1', **vars1)
vars2 = {"a": "b"}
self.create_host('host2', 'server', '192.168.1.2', **vars2)
url = self.url + '/v1/hosts'
resp = self.get(url, vars='a:b')
self.assertEqual(200, resp.status_code)
hosts = resp.json()['hosts']
self.assertEqual(2, len(hosts))
self.assertEqual({'192.168.1.1', '192.168.1.2'},
{host['ip_address'] for host in hosts})
url = self.url + '/v1/hosts'
resp = self.get(url, vars='host:one')
self.assertEqual(200, resp.status_code)
hosts = resp.json()['hosts']
self.assertEqual(1, len(hosts))
self.assertEqual('192.168.1.1', hosts[0]['ip_address'])
self.assert_vars_get_expected(hosts[0]['id'], vars1)
def test_create_host(self):
host = self.create_host('host1', 'server', '192.168.1.1')
self.assertEqual('host1', host['name'])
def test_create_with_missing_name_fails(self):
url = self.url + '/v1/hosts'
payload = {'device_type': 'server', 'ip_address': '192.168.1.1',
@ -73,51 +101,22 @@ class APIV1HostTest(APIV1ResourceWithVariablesTestCase):
host = self.post(url, data=payload)
self.assertEqual(400, host.status_code)
def test_host_get_all_for_region(self):
self.create_host('host1', 'server', '192.168.1.1')
self.create_host('host2', 'server', '192.168.1.2')
url = self.url + '/v1/hosts?region_id={}'.format(self.region['id'])
resp = self.get(url)
self.assertEqual(2, len(resp.json()))
def test_host_get_by_ip_filter(self):
self.create_host('host1', 'server', '192.168.1.1')
self.create_host('host2', 'server', '192.168.1.2')
url = self.url + '/v1/hosts?ip_address=192.168.1.1'
resp = self.get(url)
self.assertEqual(200, resp.status_code)
hosts = resp.json()
hosts = resp.json()['hosts']
self.assertEqual(1, len(hosts))
self.assertEqual('192.168.1.1', hosts[0]['ip_address'])
def test_host_get_by_vars_filter(self):
vars1 = {"a": "b", "host": "one"}
self.create_host('host1', 'server', '192.168.1.1', variables=vars1)
vars2 = {"a": "b"}
self.create_host('host2', 'server', '192.168.1.2', variables=vars2)
url = self.url + '/v1/hosts?vars=a:b'
resp = self.get(url)
self.assertEqual(200, resp.status_code)
hosts = resp.json()
self.assertEqual(2, len(hosts))
self.assertEqual({'192.168.1.1', '192.168.1.2'},
{host['ip_address'] for host in hosts})
url = self.url + '/v1/hosts?vars=host:one'
resp = self.get(url)
self.assertEqual(200, resp.status_code)
hosts = resp.json()
self.assertEqual(1, len(hosts))
self.assertEqual('192.168.1.1', hosts[0]['ip_address'])
self.assert_vars_get_expected(hosts[0]['id'], vars1)
def test_host_by_missing_filter(self):
self.create_host('host1', 'server', '192.168.1.1')
url = self.url + '/v1/hosts?ip_address=192.168.1.2'
resp = self.get(url)
self.assertEqual(200, resp.status_code)
self.assertEqual(0, len(resp.json()))
self.assertEqual(0, len(resp.json()['hosts']))
def test_host_delete(self):
host = self.create_host('host1', 'server', '192.168.1.1')
@ -129,3 +128,68 @@ class APIV1HostTest(APIV1ResourceWithVariablesTestCase):
self.assertEqual(404, resp.status_code)
self.assertEqual({'status': 404, 'message': 'Not Found'},
resp.json())
class TestPagination(HostTests):
def setUp(self):
super(TestPagination, self).setUp()
self.hosts = [
self.create_host('host{}'.format(i), 'server',
'192.168.1.{}'.format(i + 1))
for i in range(0, 61)
]
def test_get_returns_a_default_list_of_thirty_hosts(self):
url = self.url + '/v1/hosts'
response = self.get(url)
self.assertSuccessOk(response)
hosts = response.json()
self.assertIn('hosts', hosts)
self.assertEqual(30, len(hosts['hosts']))
self.assertListEqual([h['id'] for h in self.hosts[:30]],
[h['id'] for h in hosts['hosts']])
def test_get_returns_correct_next_link(self):
url = self.url + '/v1/hosts'
thirtieth_host = self.hosts[29]
response = self.get(url)
self.assertSuccessOk(response)
hosts = response.json()
self.assertIn('links', hosts)
for link_rel in hosts['links']:
if link_rel['rel'] == 'next':
break
else:
self.fail("No 'next' link was returned in response")
parsed_next = urllib.parse.urlparse(link_rel['href'])
self.assertIn('marker={}'.format(thirtieth_host['id']),
parsed_next.query)
def test_get_returns_correct_prev_link(self):
first_host = self.hosts[0]
thirtieth_host = self.hosts[29]
url = self.url + '/v1/hosts?marker={}'.format(thirtieth_host['id'])
response = self.get(url)
self.assertSuccessOk(response)
hosts = response.json()
self.assertIn('links', hosts)
for link_rel in hosts['links']:
if link_rel['rel'] == 'prev':
break
else:
self.fail("No 'prev' link was returned in response")
parsed_prev = urllib.parse.urlparse(link_rel['href'])
self.assertIn('marker={}'.format(first_host['id']), parsed_prev.query)
def test_get_all_for_region(self):
region = self.create_region('region-2')
self.create_host('host1', 'server', '192.168.1.1', region=region)
self.create_host('host2', 'server', '192.168.1.2', region=region)
url = self.url + '/v1/hosts?region_id={}'.format(region['id'])
resp = self.get(url)
self.assertSuccessOk(resp)
hosts = resp.json()
self.assertEqual(2, len(hosts['hosts']))

View File

@ -1,54 +1,66 @@
from craton.tests.functional.test_variable_calls import \
APIV1ResourceWithVariablesTestCase
import urllib.parse
from craton.tests.functional import TestCase
class APIV1RegionTest(APIV1ResourceWithVariablesTestCase):
"""Test cases for /region calls.
One set of data for the test is generated by fake data generation
script during test module setup.
"""
resource = 'regions'
def setUp(self):
super(APIV1RegionTest, self).setUp()
def create_region(self, name, note=None, variables=None):
class RegionTests(TestCase):
def create_region(self, name, variables=None):
url = self.url + '/v1/regions'
values = {'name': name}
if note is not None:
values['note'] = note
if variables is not None:
if variables:
values['variables'] = variables
resp = self.post(url, data=values)
self.assertSuccessCreated(resp)
self.assertIn('Location', resp.headers)
json = resp.json()
self.assertEqual(
resp.headers['Location'],
"{}/{}".format(url, json['id'])
)
return json
def delete_regions(self, regions):
base_url = self.url + '/v1/regions/{}'
for region in regions:
url = base_url.format(region['id'])
resp = self.delete(url)
self.assertNoContent(resp)
class APIV1RegionTest(RegionTests):
"""Test cases for /region calls.
One set of data for the test is generated by fake data generateion
script during test module setup.
"""
def test_create_region_full_data(self):
# Test with full set of allowed parameters
values = {"name": "region-new",
"note": "This is region-new.",
"variables": {"a": "b"}}
url = self.url + '/v1/regions'
resp = self.post(url, data=values)
self.assertEqual(201, resp.status_code)
self.assertIn('Location', resp.headers)
self.assertEqual(
resp.headers['Location'],
"{}/{}".format(url, resp.json()['id'])
)
return resp.json()
def test_create_region_full_data(self):
# Test with full set of allowed parameters
region = self.create_region(
'region-new', 'This is region-new.', {'a': 'b'})
self.assertEqual('region-new', region['name'])
self.assertEqual({'a': 'b'}, region['variables'])
self.assertEqual(values['name'], resp.json()['name'])
def test_create_region_without_variables(self):
region = self.create_region(
'region-new', 'This is region-two')
self.assertEqual('region-new', region['name'])
self.assertEqual({}, region['variables'])
def test_create_region_supports_vars_ops(self):
region = self.create_region(
'region-new', 'This is region-new.', {'a': 'b'})
self.assert_vars_get_expected(region['id'], {'a': 'b'})
self.assert_vars_can_be_set(region['id'])
self.assert_vars_can_be_deleted(region['id'])
values = {"name": "region-two",
"note": "This is region-two"}
url = self.url + '/v1/regions'
resp = self.post(url, data=values)
self.assertEqual(201, resp.status_code)
self.assertIn('Location', resp.headers)
self.assertEqual(
resp.headers['Location'],
"{}/{}".format(url, resp.json()['id'])
)
self.assertEqual("region-two", resp.json()['name'])
def test_create_region_with_no_name_fails(self):
values = {"note": "This is region one."}
@ -72,10 +84,7 @@ class APIV1RegionTest(APIV1ResourceWithVariablesTestCase):
url = self.url + '/v1/regions'
resp = self.get(url)
self.assertEqual(200, resp.status_code)
regions = resp.json()
self.assertEqual(2, len(regions))
self.assertEqual({'ORD1', 'ORD2'},
{region['name'] for region in regions})
self.assertEqual(2, len(resp.json()))
def test_regions_get_all_with_name_filter(self):
self.create_region("ORD1")
@ -83,8 +92,9 @@ class APIV1RegionTest(APIV1ResourceWithVariablesTestCase):
url = self.url + '/v1/regions?name=ORD1'
resp = self.get(url)
self.assertEqual(200, resp.status_code)
self.assertEqual(1, len(resp.json()))
self.assertEqual('ORD1', resp.json()[0]['name'])
regions = resp.json()['regions']
self.assertEqual(1, len(regions))
self.assertEqual('ORD1', regions[0]['name'])
def test_region_with_non_existing_filters(self):
self.create_region("ORD1")
@ -100,3 +110,57 @@ class APIV1RegionTest(APIV1ResourceWithVariablesTestCase):
region = resp.json()
self.assertEqual(region['name'], 'ORD1')
self.assertEqual(regvars, region['variables'])
class TestPagination(RegionTests):
def setUp(self):
super(TestPagination, self).setUp()
self.regions = [self.create_region('region-{}'.format(i))
for i in range(0, 61)]
self.addCleanup(self.delete_regions, self.regions)
def test_list_first_thirty_regions(self):
url = self.url + '/v1/regions'
response = self.get(url)
self.assertSuccessOk(response)
json = response.json()
self.assertIn('regions', json)
self.assertEqual(30, len(json['regions']))
self.assertListEqual([r['id'] for r in self.regions[:30]],
[r['id'] for r in json['regions']])
def test_get_returns_correct_next_link(self):
url = self.url + '/v1/regions'
thirtieth_region = self.regions[29]
response = self.get(url)
self.assertSuccessOk(response)
json = response.json()
self.assertIn('links', json)
for link_rel in json['links']:
if link_rel['rel'] == 'next':
break
else:
self.fail("No 'next' link was returned in response")
parsed_next = urllib.parse.urlparse(link_rel['href'])
self.assertIn('marker={}'.format(thirtieth_region['id']),
parsed_next.query)
def test_get_returns_correct_prev_link(self):
first_region = self.regions[0]
thirtieth_region = self.regions[29]
url = self.url + '/v1/regions?marker={}'.format(thirtieth_region['id'])
response = self.get(url)
self.assertSuccessOk(response)
json = response.json()
self.assertIn('links', json)
for link_rel in json['links']:
if link_rel['rel'] == 'prev':
break
else:
self.fail("No 'prev' link was returned in response")
parsed_prev = urllib.parse.urlparse(link_rel['href'])
self.assertIn('marker={}'.format(first_region['id']),
parsed_prev.query)

View File

@ -32,20 +32,21 @@ class CellsDBTestCase(base.DBTestCase):
filters = {
"region_id": cell1["region_id"],
}
res = dbapi.cells_get_all(self.context, filters, default_pagination)
res, _ = dbapi.cells_get_all(self.context, filters, default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(res[0]['name'], 'cell1')
def test_cells_get_all_filter_name(self):
for cell in cells:
dbapi.cells_create(self.context, cell)
setup_res = dbapi.cells_get_all(self.context, {}, default_pagination)
setup_res, _ = dbapi.cells_get_all(self.context, {},
default_pagination)
self.assertGreater(len(setup_res), 2)
filters = {
"name": cell1["name"],
}
res = dbapi.cells_get_all(self.context, filters, default_pagination)
res, _ = dbapi.cells_get_all(self.context, filters, default_pagination)
self.assertEqual(len(res), 2)
for cell in res:
self.assertEqual(cell['name'], 'cell1')
@ -53,7 +54,8 @@ class CellsDBTestCase(base.DBTestCase):
def test_cells_get_all_filter_id(self):
for cell in cells:
dbapi.cells_create(self.context, cell)
setup_res = dbapi.cells_get_all(self.context, {}, default_pagination)
setup_res, _ = dbapi.cells_get_all(self.context, {},
default_pagination)
self.assertGreater(len(setup_res), 2)
self.assertEqual(
len([cell for cell in setup_res if cell['id'] == 1]), 1
@ -62,7 +64,7 @@ class CellsDBTestCase(base.DBTestCase):
filters = {
"id": 1,
}
res = dbapi.cells_get_all(self.context, filters, default_pagination)
res, _ = dbapi.cells_get_all(self.context, filters, default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(res[0]['id'], 1)
@ -76,7 +78,7 @@ class CellsDBTestCase(base.DBTestCase):
"vars": "key2:value2",
"region_id": cell1["region_id"],
}
res = dbapi.cells_get_all(self.context, filters, default_pagination)
res, _ = dbapi.cells_get_all(self.context, filters, default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(res[0]['name'], 'cell1')
@ -88,7 +90,7 @@ class CellsDBTestCase(base.DBTestCase):
)
filters = {}
filters["vars"] = "key2:value5"
res = dbapi.cells_get_all(self.context, filters, default_pagination)
res, _ = dbapi.cells_get_all(self.context, filters, default_pagination)
self.assertEqual(len(res), 0)
def test_cell_delete(self):

View File

@ -246,7 +246,7 @@ class HostsDBTestCase(base.DBTestCase):
cell_id=cell_id2,
)
all_res = dbapi.hosts_get_all(self.context, {}, default_pagination)
all_res, _ = dbapi.hosts_get_all(self.context, {}, default_pagination)
self.assertEqual(len(all_res), 2)
self.assertEqual(
len([host for host in all_res if host['cell_id'] == cell_id1]), 1
@ -255,7 +255,8 @@ class HostsDBTestCase(base.DBTestCase):
filters = {
"cell_id": cell_id1,
}
res = dbapi.hosts_get_all(self.context, filters, default_pagination)
res, _ = dbapi.hosts_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(res[0].name, 'www.example.xyz')
@ -272,7 +273,8 @@ class HostsDBTestCase(base.DBTestCase):
"region_id": region_id,
"vars": "key2:value2",
}
res = dbapi.hosts_get_all(self.context, filters, default_pagination)
res, _ = dbapi.hosts_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(res[0].name, 'www.example.xyz')
@ -295,12 +297,12 @@ class HostsDBTestCase(base.DBTestCase):
)
filters = {"vars": "key1:example2"}
res = dbapi.hosts_get_all(self.context, filters, default_pagination)
res, _ = dbapi.hosts_get_all(self.context, filters, default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual('www.example2.xyz', res[0].name)
filters = {"vars": "key2:Tom"}
res = dbapi.hosts_get_all(self.context, filters, default_pagination)
res, _ = dbapi.hosts_get_all(self.context, filters, default_pagination)
self.assertEqual(len(res), 2)
def test_hosts_get_all_with_filters_noexist(self):
@ -316,5 +318,6 @@ class HostsDBTestCase(base.DBTestCase):
"region_id": "region_1",
"vars": "key1:value5",
}
res = dbapi.hosts_get_all(self.context, filters, default_pagination)
res, _ = dbapi.hosts_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 0)

View File

@ -73,8 +73,8 @@ class NetworksDBTestCase(base.DBTestCase):
dbapi.networks_create(self.context, network1)
dbapi.networks_create(self.context, network2)
filters = {}
res = dbapi.networks_get_all(self.context, filters,
default_pagination)
res, _ = dbapi.networks_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 2)
def test_networks_get_all_filter_region(self):
@ -83,8 +83,8 @@ class NetworksDBTestCase(base.DBTestCase):
filters = {
'region_id': network1['region_id'],
}
res = dbapi.networks_get_all(self.context, filters,
default_pagination)
res, _ = dbapi.networks_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(res[0]['name'], 'test network')
@ -96,8 +96,8 @@ class NetworksDBTestCase(base.DBTestCase):
def test_networks_get_by_name_filter_no_exit(self):
dbapi.networks_create(self.context, network1)
filters = {"name": "foo", "region_id": network1['region_id']}
res = dbapi.networks_get_all(self.context, filters,
default_pagination)
res, _ = dbapi.networks_get_all(self.context, filters,
default_pagination)
self.assertEqual(res, [])
def test_network_update(self):
@ -137,8 +137,8 @@ class NetworkDevicesDBTestCase(base.DBTestCase):
dbapi.network_devices_create(self.context, device1)
dbapi.network_devices_create(self.context, device2)
filters = {}
res = dbapi.network_devices_get_all(self.context, filters,
default_pagination)
res, _ = dbapi.network_devices_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 2)
def test_network_device_get_all_filter_region(self):
@ -147,8 +147,8 @@ class NetworkDevicesDBTestCase(base.DBTestCase):
filters = {
'region_id': device1['region_id'],
}
res = dbapi.network_devices_get_all(self.context, filters,
default_pagination)
res, _ = dbapi.network_devices_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(res[0]['hostname'], 'switch1')
@ -157,8 +157,8 @@ class NetworkDevicesDBTestCase(base.DBTestCase):
dbapi.network_devices_create(self.context, device2)
name = device1['hostname']
setup_res = dbapi.network_devices_get_all(self.context, {},
default_pagination)
setup_res, _ = dbapi.network_devices_get_all(self.context, {},
default_pagination)
self.assertEqual(len(setup_res), 2)
matches = [dev for dev in setup_res if dev['hostname'] == name]
@ -167,8 +167,8 @@ class NetworkDevicesDBTestCase(base.DBTestCase):
filters = {
'name': name,
}
res = dbapi.network_devices_get_all(self.context, filters,
default_pagination)
res, _ = dbapi.network_devices_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(res[0]['hostname'], name)
@ -197,8 +197,8 @@ class NetworkDevicesDBTestCase(base.DBTestCase):
self.context, dict(cell_id=cell2.id, **device2)
)
setup_res = dbapi.network_devices_get_all(self.context, {},
default_pagination)
setup_res, _ = dbapi.network_devices_get_all(self.context, {},
default_pagination)
self.assertEqual(len(setup_res), 2)
matches = [dev for dev in setup_res if dev['cell_id'] == cell1.id]
@ -207,8 +207,8 @@ class NetworkDevicesDBTestCase(base.DBTestCase):
filters = {
'cell_id': cell1.id,
}
res = dbapi.network_devices_get_all(self.context, filters,
default_pagination)
res, _ = dbapi.network_devices_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(res[0]['cell_id'], cell1.id)
@ -217,8 +217,8 @@ class NetworkDevicesDBTestCase(base.DBTestCase):
dbapi.network_devices_create(self.context, device3)
dev_type = device1['device_type']
setup_res = dbapi.network_devices_get_all(self.context, {},
default_pagination)
setup_res, _ = dbapi.network_devices_get_all(self.context, {},
default_pagination)
self.assertEqual(len(setup_res), 2)
matches = [dev for dev in setup_res if dev['device_type'] == dev_type]
@ -227,8 +227,8 @@ class NetworkDevicesDBTestCase(base.DBTestCase):
filters = {
'device_type': dev_type,
}
res = dbapi.network_devices_get_all(self.context, filters,
default_pagination)
res, _ = dbapi.network_devices_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(res[0]['device_type'], dev_type)
@ -236,8 +236,8 @@ class NetworkDevicesDBTestCase(base.DBTestCase):
dbapi.network_devices_create(self.context, device1)
dbapi.network_devices_create(self.context, device2)
setup_res = dbapi.network_devices_get_all(self.context, {},
default_pagination)
setup_res, _ = dbapi.network_devices_get_all(self.context, {},
default_pagination)
self.assertEqual(len(setup_res), 2)
@ -247,8 +247,8 @@ class NetworkDevicesDBTestCase(base.DBTestCase):
filters = {
'id': dev_id
}
res = dbapi.network_devices_get_all(self.context, filters,
default_pagination)
res, _ = dbapi.network_devices_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(res[0]['id'], dev_id)
@ -257,8 +257,8 @@ class NetworkDevicesDBTestCase(base.DBTestCase):
dbapi.network_devices_create(self.context, device3)
ip = device1['ip_address']
setup_res = dbapi.network_devices_get_all(self.context, {},
default_pagination)
setup_res, _ = dbapi.network_devices_get_all(self.context, {},
default_pagination)
self.assertEqual(len(setup_res), 2)
matches = [dev for dev in setup_res if str(dev['ip_address']) == ip]
@ -267,8 +267,8 @@ class NetworkDevicesDBTestCase(base.DBTestCase):
filters = {
'ip_address': ip,
}
res = dbapi.network_devices_get_all(self.context, filters,
default_pagination)
res, _ = dbapi.network_devices_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(str(res[0]['ip_address']), ip)
@ -280,8 +280,8 @@ class NetworkDevicesDBTestCase(base.DBTestCase):
def test_network_devices_get_by_filter_no_exit(self):
dbapi.network_devices_create(self.context, device1)
filters = {"hostname": "foo"}
res = dbapi.networks_get_all(self.context, filters,
default_pagination)
res, _ = dbapi.networks_get_all(self.context, filters,
default_pagination)
self.assertEqual(res, [])
def test_network_devices_delete(self):
@ -332,8 +332,8 @@ class NetworkInterfacesDBTestCase(base.DBTestCase):
dbapi.network_interfaces_create(self.context, network_interface1)
dbapi.network_interfaces_create(self.context, network_interface2)
filters = {}
res = dbapi.network_interfaces_get_all(self.context, filters,
default_pagination)
res, _ = dbapi.network_interfaces_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 2)
self.assertEqual(
str(res[0]['ip_address']), network_interface1['ip_address']
@ -348,8 +348,8 @@ class NetworkInterfacesDBTestCase(base.DBTestCase):
filters = {
"device_id": 1,
}
res = dbapi.network_interfaces_get_all(self.context, filters,
default_pagination)
res, _ = dbapi.network_interfaces_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(res[0]['name'], 'eth1')

View File

@ -26,7 +26,8 @@ class RegionsDBTestCase(base.DBTestCase):
def test_regions_get_all(self):
dbapi.regions_create(self.context, region1)
filters = {}
res = dbapi.regions_get_all(self.context, filters, default_pagination)
res, _ = dbapi.regions_get_all(self.context, filters,
default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(res[0]['name'], 'region1')
@ -38,7 +39,7 @@ class RegionsDBTestCase(base.DBTestCase):
)
filters = {}
filters["vars"] = "key1:value1"
regions = dbapi.regions_get_all(
regions, _ = dbapi.regions_get_all(
self.context, filters, default_pagination,
)
self.assertEqual(len(regions), 1)
@ -52,7 +53,7 @@ class RegionsDBTestCase(base.DBTestCase):
)
filters = {}
filters["vars"] = "key1:value12"
regions = dbapi.regions_get_all(
regions, _ = dbapi.regions_get_all(
self.context, filters, default_pagination,
)
self.assertEqual(len(regions), 0)

View File

@ -48,7 +48,7 @@ class UsersDBTestCase(base.DBTestCase):
# is not for the same project no user info is given back.
self.make_user(user1)
self.context.tenant = uuid.uuid4().hex
res = dbapi.users_get_all(self.context, {}, default_pagination)
res, _ = dbapi.users_get_all(self.context, {}, default_pagination)
self.assertEqual(len(res), 0)
def test_user_get_no_admin_context_raises(self):
@ -63,8 +63,8 @@ class UsersDBTestCase(base.DBTestCase):
dbapi.users_create(self.context, user1)
dbapi.users_create(self.context, user2)
self.context.tenant = user1['project_id']
res = dbapi.users_get_by_name(self.context, user1['username'], {},
default_pagination)
res, _ = dbapi.users_get_by_name(self.context, user1['username'], {},
default_pagination)
self.assertEqual(len(res), 1)
self.assertEqual(res[0]['username'], user1['username'])

View File

@ -161,35 +161,39 @@ class APIV1CellsTest(APIV1Test):
@mock.patch.object(dbapi, 'cells_get_all')
def test_get_cells_with_name_filters(self, mock_cells):
cell_name = 'cell1'
mock_cells.return_value = fake_resources.CELL_LIST2
mock_cells.return_value = (fake_resources.CELL_LIST2, {})
resp = self.get('v1/cells?name={}'.format(cell_name))
self.assertEqual(len(resp.json), 2)
cells = resp.json['cells']
self.assertEqual(len(cells), 2)
# Ensure we got the right cell
self.assertEqual(resp.json[0]["name"], cell_name)
self.assertEqual(resp.json[1]["name"], cell_name)
self.assertEqual(cells[0]["name"], cell_name)
self.assertEqual(cells[1]["name"], cell_name)
@mock.patch.object(dbapi, 'cells_get_all')
def test_get_cells_with_name_and_region_filters(self, mock_cells):
mock_cells.return_value = [fake_resources.CELL1]
mock_cells.return_value = ([fake_resources.CELL1], {})
resp = self.get('v1/cells?region_id=1&name=cell1')
self.assertEqual(len(resp.json), 1)
self.assertEqual(len(resp.json['cells']), 1)
# Ensure we got the right cell
self.assertEqual(resp.json[0]["name"], fake_resources.CELL1.name)
self.assertEqual(resp.json['cells'][0]["name"],
fake_resources.CELL1.name)
@mock.patch.object(dbapi, 'cells_get_all')
def test_get_cells_with_id_filters(self, mock_cells):
mock_cells.return_value = [fake_resources.CELL1]
mock_cells.return_value = ([fake_resources.CELL1], {})
resp = self.get('v1/cells?region_id=1&id=1')
self.assertEqual(len(resp.json), 1)
cells = resp.json['cells']
self.assertEqual(len(cells), 1)
# Ensure we got the right cell
self.assertEqual(resp.json[0]["name"], fake_resources.CELL1.name)
self.assertEqual(cells[0]["name"], fake_resources.CELL1.name)
@mock.patch.object(dbapi, 'cells_get_all')
def test_get_cells_with_vars_filters(self, mock_cells):
mock_cells.return_value = [fake_resources.CELL1]
mock_cells.return_value = ([fake_resources.CELL1], {})
resp = self.get('v1/cells?region_id=1&vars=somekey:somevalue')
self.assertEqual(len(resp.json), 1)
self.assertEqual(resp.json[0]["name"], fake_resources.CELL1.name)
self.assertEqual(len(resp.json['cells']), 1)
self.assertEqual(resp.json['cells'][0]["name"],
fake_resources.CELL1.name)
@mock.patch.object(dbapi, 'cells_get_all')
def test_get_cell_no_exist_by_name_fails(self, mock_cell):
@ -347,7 +351,7 @@ class APIV1RegionsIDTest(APIV1Test):
class APIV1RegionsTest(APIV1Test):
@mock.patch.object(dbapi, 'regions_get_all')
def test_regions_get_all(self, mock_regions):
mock_regions.return_value = fake_resources.REGIONS_LIST
mock_regions.return_value = (fake_resources.REGIONS_LIST, {})
resp = self.get('v1/regions')
self.assertEqual(len(resp.json), len(fake_resources.REGIONS_LIST))
@ -361,20 +365,23 @@ class APIV1RegionsTest(APIV1Test):
def test_regions_get_by_name_filters(self, mock_regions):
mock_regions.return_value = fake_resources.REGION1
resp = self.get('v1/regions?name=region1')
self.assertEqual(resp.json[0]["name"], fake_resources.REGION1.name)
regions = resp.json['regions']
self.assertEqual(regions[0]["name"], fake_resources.REGION1.name)
@mock.patch.object(dbapi, 'regions_get_by_id')
def test_regions_get_by_id_filters(self, mock_regions):
mock_regions.return_value = fake_resources.REGION1
resp = self.get('v1/regions?id=1')
self.assertEqual(resp.json[0]["name"], fake_resources.REGION1.name)
regions = resp.json['regions']
self.assertEqual(regions[0]["name"], fake_resources.REGION1.name)
@mock.patch.object(dbapi, 'regions_get_all')
def test_regions_get_by_vars_filters(self, mock_regions):
mock_regions.return_value = [fake_resources.REGION1]
mock_regions.return_value = ([fake_resources.REGION1], {})
resp = self.get('v1/regions?vars=somekey:somevalue')
self.assertEqual(len(resp.json), 1)
self.assertEqual(resp.json[0]["name"], fake_resources.REGION1.name)
self.assertEqual(len(resp.json['regions']), 1)
self.assertEqual(resp.json['regions'][0]["name"],
fake_resources.REGION1.name)
@mock.patch.object(dbapi, 'regions_get_by_name')
def test_get_region_no_exist_by_name_fails(self, mock_regions):
@ -583,7 +590,7 @@ class APIV1HostsLabelsTest(APIV1Test):
class APIV1HostsTest(APIV1Test):
@mock.patch.object(dbapi, 'hosts_get_all')
def test_get_hosts_by_region_gets_all_hosts(self, fake_hosts):
fake_hosts.return_value = fake_resources.HOSTS_LIST_R1
fake_hosts.return_value = (fake_resources.HOSTS_LIST_R1, {})
resp = self.get('/v1/hosts?region_id=1')
self.assertEqual(len(resp.json), 2)
@ -601,20 +608,20 @@ class APIV1HostsTest(APIV1Test):
@mock.patch.object(dbapi, 'hosts_get_all')
def test_get_hosts(self, fake_hosts):
fake_hosts.return_value = fake_resources.HOSTS_LIST_R3
fake_hosts.return_value = (fake_resources.HOSTS_LIST_R3, {})
resp = self.get('/v1/hosts')
self.assertEqual(len(resp.json), 3)
self.assertEqual(len(resp.json['hosts']), 3)
fake_hosts.assert_called_once_with(
mock.ANY, {}, {'limit': 30, 'marker': None},
)
@mock.patch.object(dbapi, 'hosts_get_all')
def test_get_host_by_name_filters(self, fake_hosts):
fake_hosts.return_value = fake_resources.HOSTS_LIST_R2
fake_hosts.return_value = (fake_resources.HOSTS_LIST_R2, {})
resp = self.get('/v1/hosts?region_id=1&name=www.example.net')
host_resp = fake_resources.HOSTS_LIST_R2
self.assertEqual(len(resp.json), len(host_resp))
self.assertEqual(resp.json[0]["name"], host_resp[0].name)
self.assertEqual(len(resp.json['hosts']), len(host_resp))
self.assertEqual(resp.json['hosts'][0]["name"], host_resp[0].name)
@mock.patch.object(dbapi, 'hosts_get_all')
def test_get_host_by_ip_address_filter(self, fake_hosts):
@ -626,11 +633,11 @@ class APIV1HostsTest(APIV1Test):
path_query = '/v1/hosts?region_id={}&ip_address={}'.format(
region_id, ip_address
)
fake_hosts.return_value = fake_resources.HOSTS_LIST_R2
fake_hosts.return_value = (fake_resources.HOSTS_LIST_R2, {})
resp = self.get(path_query)
host_resp = fake_resources.HOSTS_LIST_R2
self.assertEqual(len(resp.json), 1)
self.assertEqual(resp.json[0]["name"], host_resp[0].name)
self.assertEqual(len(resp.json['hosts']), 1)
self.assertEqual(resp.json['hosts'][0]["name"], host_resp[0].name)
fake_hosts.assert_called_once_with(
mock.ANY, filters, {'limit': 30, 'marker': None},
@ -638,18 +645,19 @@ class APIV1HostsTest(APIV1Test):
@mock.patch.object(dbapi, 'hosts_get_all')
def test_get_host_by_vars_filters(self, fake_hosts):
fake_hosts.return_value = [fake_resources.HOST1]
fake_hosts.return_value = ([fake_resources.HOST1], {})
resp = self.get('/v1/hosts?region_id=1&vars=somekey:somevalue')
self.assertEqual(len(resp.json), 1)
self.assertEqual(resp.json[0]["name"], fake_resources.HOST1.name)
self.assertEqual(len(resp.json['hosts']), 1)
self.assertEqual(resp.json['hosts'][0]["name"],
fake_resources.HOST1.name)
@mock.patch.object(dbapi, 'hosts_get_all')
def test_get_host_by_label_filters(self, fake_hosts):
fake_hosts.return_value = fake_resources.HOSTS_LIST_R2
fake_hosts.return_value = (fake_resources.HOSTS_LIST_R2, {})
resp = self.get('/v1/hosts?region_id=1&label=somelabel')
host_resp = fake_resources.HOSTS_LIST_R2
self.assertEqual(len(resp.json), len(host_resp))
self.assertEqual(resp.json[0]["name"], host_resp[0].name)
self.assertEqual(len(resp.json['hosts']), len(host_resp))
self.assertEqual(resp.json['hosts'][0]["name"], host_resp[0].name)
@mock.patch.object(dbapi, 'hosts_create')
def test_create_host_with_valid_data(self, mock_host):
@ -789,12 +797,12 @@ class APIV1ProjectsTest(APIV1Test):
def test_project_get_all(self, mock_projects):
proj1 = fake_resources.PROJECT1
proj2 = fake_resources.PROJECT2
return_value = [proj1, proj2]
return_value = ([proj1, proj2], {})
mock_projects.return_value = return_value
resp = self.get('v1/projects')
self.assertEqual(resp.status_code, 200)
self.assertEqual(len(resp.json), 2)
self.assertEqual(len(resp.json['projects']), 2)
@mock.patch.object(dbapi, 'projects_create')
def test_project_post_invalid_property(self, mock_projects):
@ -841,11 +849,12 @@ class APIV1UsersTest(APIV1Test):
@mock.patch.object(dbapi, 'users_get_all')
def test_users_get_all(self, mock_user):
return_values = [fake_resources.USER1, fake_resources.USER2]
return_values = ([fake_resources.USER1, fake_resources.USER2], {})
mock_user.return_value = return_values
resp = self.get('v1/users')
self.assertEqual(resp.status_code, 200)
self.assertEqual(len(resp.json), 2)
self.assertEqual(len(resp.json['users']), 2)
@mock.patch.object(dbapi, 'users_get_all')
def test_users_get_no_admin_fails(self, mock_user):
@ -869,17 +878,17 @@ class APIV1NetworksTest(APIV1Test):
@mock.patch.object(dbapi, 'networks_get_all')
def test_get_networks_by_filters(self, fake_networks):
fake_networks.return_value = [fake_resources.NETWORK1]
fake_networks.return_value = ([fake_resources.NETWORK1], {})
resp = self.get('/v1/networks?region_id=1&name=PrivateNetwork')
net_resp = fake_resources.NETWORK1
self.assertEqual(len(resp.json), 1)
self.assertEqual(resp.json[0]["name"], net_resp.name)
self.assertEqual(len(resp.json['networks']), 1)
self.assertEqual(resp.json['networks'][0]["name"], net_resp.name)
@mock.patch.object(dbapi, 'networks_get_all')
def test_get_networks(self, fake_networks):
fake_networks.return_value = fake_resources.NETWORKS_LIST2
fake_networks.return_value = (fake_resources.NETWORKS_LIST2, {})
resp = self.get('/v1/networks')
self.assertEqual(len(resp.json), 3)
self.assertEqual(len(resp.json['networks']), 3)
fake_networks.assert_called_once_with(
mock.ANY, {}, {'limit': 30, 'marker': None},
)
@ -1074,11 +1083,11 @@ class APIV1NetworkDevicesTest(APIV1Test):
path_query = '/v1/network-devices?region_id={}&ip_address={}'.format(
region_id, ip_address
)
fake_devices.return_value = fake_resources.NETWORK_DEVICE_LIST1
fake_devices.return_value = (fake_resources.NETWORK_DEVICE_LIST1, {})
resp = self.get(path_query)
device_resp = fake_resources.NETWORK_DEVICE_LIST1
self.assertEqual(len(resp.json), 1)
self.assertEqual(resp.json[0]["ip_address"],
self.assertEqual(len(resp.json['network_devices']), 1)
self.assertEqual(resp.json['network_devices'][0]["ip_address"],
device_resp[0].ip_address)
fake_devices.assert_called_once_with(
@ -1102,12 +1111,13 @@ class APIV1NetworkDevicesTest(APIV1Test):
@mock.patch.object(dbapi, 'network_devices_get_all')
def test_network_devices_get_by_region(self, mock_devices):
mock_devices.return_value = fake_resources.NETWORK_DEVICE_LIST1
mock_devices.return_value = (fake_resources.NETWORK_DEVICE_LIST1, {})
resp = self.get('/v1/network-devices?region_id=1')
self.assertEqual(len(resp.json), 1)
network_devices = resp.json['network_devices']
self.assertEqual(len(network_devices), 1)
self.assertEqual(200, resp.status_code)
self.assertEqual(
resp.json[0]["name"],
network_devices[0]["name"],
fake_resources.NETWORK_DEVICE_LIST1[0].name
)
@ -1236,11 +1246,13 @@ class APIV1NetworkInterfacesTest(APIV1Test):
device_id, ip_address
)
)
fake_interfaces.return_value = fake_resources.NETWORK_INTERFACE_LIST1
fake_interfaces.return_value = (fake_resources.NETWORK_INTERFACE_LIST1,
{})
resp = self.get(path_query)
interface_resp = fake_resources.NETWORK_INTERFACE_LIST1
self.assertEqual(len(resp.json), 1)
self.assertEqual(resp.json[0]["name"], interface_resp[0].name)
self.assertEqual(len(resp.json['network_interfaces']), 1)
self.assertEqual(resp.json['network_interfaces'][0]["name"],
interface_resp[0].name)
fake_interfaces.assert_called_once_with(
mock.ANY, filters, {'limit': 30, 'marker': None},
@ -1248,13 +1260,14 @@ class APIV1NetworkInterfacesTest(APIV1Test):
@mock.patch.object(dbapi, 'network_interfaces_get_all')
def test_get_network_interfaces_by_device_id(self, fake_interfaces):
fake_interfaces.return_value = fake_resources.NETWORK_INTERFACE_LIST1
fake_interfaces.return_value = (fake_resources.NETWORK_INTERFACE_LIST1,
{})
resp = self.get('/v1/network-interfaces?device_id=1')
self.assertEqual(200, resp.status_code)
network_interface_resp = fake_resources.NETWORK_INTERFACE1
self.assertEqual(resp.json[0]["name"], network_interface_resp.name)
netifaces = resp.json['network_interfaces']
self.assertEqual(netifaces[0]["name"], network_interface_resp.name)
self.assertEqual(
resp.json[0]['ip_address'], network_interface_resp.ip_address
netifaces[0]['ip_address'], network_interface_resp.ip_address
)
@mock.patch.object(dbapi, 'network_interfaces_create')
@ -1293,10 +1306,11 @@ class APIV1NetworkInterfacesTest(APIV1Test):
@mock.patch.object(dbapi, 'network_interfaces_get_all')
def test_get_network_interfaces(self, fake_interfaces):
fake_interfaces.return_value = fake_resources.NETWORK_INTERFACE_LIST2
fake_interfaces.return_value = (fake_resources.NETWORK_INTERFACE_LIST2,
{})
resp = self.get('/v1/network-interfaces')
self.assertEqual(200, resp.status_code)
self.assertEqual(len(resp.json), 2)
self.assertEqual(len(resp.json['network_interfaces']), 2)
fake_interfaces.assert_called_once_with(
mock.ANY, {}, {'limit': 30, 'marker': None},
)