diff --git a/nova/api/openstack/common.py b/nova/api/openstack/common.py index b1e31b0c743e..ccc70cd1fb25 100644 --- a/nova/api/openstack/common.py +++ b/nova/api/openstack/common.py @@ -185,12 +185,19 @@ def limited(items, request, max_limit=FLAGS.osapi_max_limit): return items[offset:range_end] +def get_limit_and_marker(request, max_limit=FLAGS.osapi_max_limit): + """get limited parameter from request""" + params = get_pagination_params(request) + limit = params.get('limit', max_limit) + limit = min(max_limit, limit) + marker = params.get('marker') + + return limit, marker + + def limited_by_marker(items, request, max_limit=FLAGS.osapi_max_limit): """Return a slice of items according to the requested marker and limit.""" - params = get_pagination_params(request) - - limit = params.get('limit', max_limit) - marker = params.get('marker') + limit, marker = get_limit_and_marker(request, max_limit) limit = min(max_limit, limit) start_index = 0 diff --git a/nova/api/openstack/compute/servers.py b/nova/api/openstack/compute/servers.py index 229c3b5aa2b2..9f462c565353 100644 --- a/nova/api/openstack/compute/servers.py +++ b/nova/api/openstack/compute/servers.py @@ -446,16 +446,17 @@ class Controller(wsgi.Controller): else: search_opts['user_id'] = context.user_id + limit, marker = common.get_limit_and_marker(req) instance_list = self.compute_api.get_all(context, - search_opts=search_opts) + search_opts=search_opts, + limit=limit, marker=marker) - limited_list = self._limit_items(instance_list, req) if is_detail: - self._add_instance_faults(context, limited_list) - response = self._view_builder.detail(req, limited_list) + self._add_instance_faults(context, instance_list) + response = self._view_builder.detail(req, instance_list) else: - response = self._view_builder.index(req, limited_list) - req.cache_db_instances(limited_list) + response = self._view_builder.index(req, instance_list) + req.cache_db_instances(instance_list) return response def _get_server(self, context, req, instance_uuid): @@ -1021,9 +1022,6 @@ class Controller(wsgi.Controller): self.compute_api.set_admin_password(context, server, password) return webob.Response(status_int=202) - def _limit_items(self, items, req): - return common.limited_by_marker(items, req) - def _validate_metadata(self, metadata): """Ensure that we can work with the metadata given.""" try: diff --git a/nova/common/sqlalchemyutils.py b/nova/common/sqlalchemyutils.py new file mode 100644 index 000000000000..a186948ac96c --- /dev/null +++ b/nova/common/sqlalchemyutils.py @@ -0,0 +1,128 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2010-2011 OpenStack LLC. +# Copyright 2012 Justin Santa Barbara +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Implementation of paginate query.""" + +import sqlalchemy + +from nova import exception +from nova.openstack.common import log as logging + + +LOG = logging.getLogger(__name__) + + +# copy from glance/db/sqlalchemy/api.py +def paginate_query(query, model, limit, sort_keys, marker=None, + sort_dir=None, sort_dirs=None): + """Returns a query with sorting / pagination criteria added. + + Pagination works by requiring a unique sort_key, specified by sort_keys. + (If sort_keys is not unique, then we risk looping through values.) + We use the last row in the previous page as the 'marker' for pagination. + So we must return values that follow the passed marker in the order. + With a single-valued sort_key, this would be easy: sort_key > X. + With a compound-values sort_key, (k1, k2, k3) we must do this to repeat + the lexicographical ordering: + (k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3) + + We also have to cope with different sort_directions. + + Typically, the id of the last row is used as the client-facing pagination + marker, then the actual marker object must be fetched from the db and + passed in to us as marker. + + :param query: the query object to which we should add paging/sorting + :param model: the ORM model class + :param limit: maximum number of items to return + :param sort_keys: array of attributes by which results should be sorted + :param marker: the last item of the previous page; we returns the next + results after this value. + :param sort_dir: direction in which results should be sorted (asc, desc) + :param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys + + :rtype: sqlalchemy.orm.query.Query + :return: The query with sorting/pagination added. + """ + + if 'id' not in sort_keys: + # TODO(justinsb): If this ever gives a false-positive, check + # the actual primary key, rather than assuming its id + LOG.warn(_('Id not in sort_keys; is sort_keys unique?')) + + assert(not (sort_dir and sort_dirs)) + + # Default the sort direction to ascending + if sort_dirs is None and sort_dir is None: + sort_dir = 'asc' + + # Ensure a per-column sort direction + if sort_dirs is None: + sort_dirs = [sort_dir for _sort_key in sort_keys] + + assert(len(sort_dirs) == len(sort_keys)) + + # Add sorting + for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs): + sort_dir_func = { + 'asc': sqlalchemy.asc, + 'desc': sqlalchemy.desc, + }[current_sort_dir] + + try: + sort_key_attr = getattr(model, current_sort_key) + except AttributeError: + raise exception.InvalidSortKey() + query = query.order_by(sort_dir_func(sort_key_attr)) + + # Add pagination + if marker is not None: + marker_values = [] + for sort_key in sort_keys: + v = getattr(marker, sort_key) + marker_values.append(v) + + # Build up an array of sort criteria as in the docstring + criteria_list = [] + for i in xrange(0, len(sort_keys)): + crit_attrs = [] + for j in xrange(0, i): + model_attr = getattr(model, sort_keys[j]) + crit_attrs.append((model_attr == marker_values[j])) + + model_attr = getattr(model, sort_keys[i]) + if sort_dirs[i] == 'desc': + crit_attrs.append((model_attr < marker_values[i])) + elif sort_dirs[i] == 'asc': + crit_attrs.append((model_attr > marker_values[i])) + else: + raise ValueError(_("Unknown sort direction, " + "must be 'desc' or 'asc'")) + + criteria = sqlalchemy.sql.and_(*crit_attrs) + criteria_list.append(criteria) + + f = sqlalchemy.sql.or_(*criteria_list) + query = query.filter(f) + + if limit is not None: + query = query.limit(limit) + + return query diff --git a/nova/compute/api.py b/nova/compute/api.py index 0c47d879a980..eb2f0520fbfd 100644 --- a/nova/compute/api.py +++ b/nova/compute/api.py @@ -1013,7 +1013,7 @@ class API(base.Base): return inst def get_all(self, context, search_opts=None, sort_key='created_at', - sort_dir='desc'): + sort_dir='desc', limit=None, marker=None): """Get all instances filtered by one of the given parameters. If there is no filter and the context is an admin, it will retrieve @@ -1090,7 +1090,9 @@ class API(base.Base): return [] inst_models = self._get_instances_by_filters(context, filters, - sort_key, sort_dir) + sort_key, sort_dir, + limit=limit, + marker=marker) # Convert the models to dictionaries instances = [] @@ -1102,7 +1104,10 @@ class API(base.Base): return instances - def _get_instances_by_filters(self, context, filters, sort_key, sort_dir): + def _get_instances_by_filters(self, context, filters, + sort_key, sort_dir, + limit=None, + marker=None): if 'ip6' in filters or 'ip' in filters: res = self.network_api.get_instance_uuids_by_ip_filter(context, filters) @@ -1111,8 +1116,9 @@ class API(base.Base): uuids = set([r['instance_uuid'] for r in res]) filters['uuid'] = uuids - return self.db.instance_get_all_by_filters(context, filters, sort_key, - sort_dir) + return self.db.instance_get_all_by_filters(context, filters, + sort_key, sort_dir, + limit=limit, marker=marker) @wrap_check_policy @check_instance_state(vm_state=[vm_states.ACTIVE, vm_states.STOPPED]) diff --git a/nova/db/api.py b/nova/db/api.py index f718f00471bd..46d0305ef4a3 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -582,10 +582,11 @@ def instance_get_all(context, columns_to_join=None): def instance_get_all_by_filters(context, filters, sort_key='created_at', - sort_dir='desc'): + sort_dir='desc', limit=None, marker=None): """Get all instances that match all filters.""" return IMPL.instance_get_all_by_filters(context, filters, sort_key, - sort_dir) + sort_dir, limit=limit, + marker=marker) def instance_get_active_by_window(context, begin, end=None, project_id=None, diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 797516ac9cb0..341b0d3321a2 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -26,6 +26,7 @@ import functools import warnings from nova import block_device +from nova.common.sqlalchemyutils import paginate_query from nova.compute import vm_states from nova import db from nova.db.sqlalchemy import models @@ -1503,7 +1504,8 @@ def instance_get_all(context, columns_to_join=None): @require_context -def instance_get_all_by_filters(context, filters, sort_key, sort_dir): +def instance_get_all_by_filters(context, filters, sort_key, sort_dir, + limit=None, marker=None): """Return instances that match all filters. Deleted instances will be returned by default, unless there's a filter that says otherwise""" @@ -1557,6 +1559,13 @@ def instance_get_all_by_filters(context, filters, sort_key, sort_dir): filters, exact_match_filter_names) query_prefix = regex_filter(query_prefix, models.Instance, filters) + + # paginate query + query_prefix = paginate_query(query_prefix, models.Instance, limit, + [sort_key, 'created_at', 'id'], + marker=marker, + sort_dir=sort_dir) + instances = query_prefix.all() return instances diff --git a/nova/exception.py b/nova/exception.py index cd1eabc9dff0..e4d212ca6b48 100644 --- a/nova/exception.py +++ b/nova/exception.py @@ -294,6 +294,10 @@ class InvalidGroup(Invalid): message = _("Group not valid. Reason: %(reason)s") +class InvalidSortKey(Invalid): + message = _("Sort key supplied was not valid.") + + class InstanceInvalidState(Invalid): message = _("Instance %(instance_uuid)s in %(attr)s %(state)s. Cannot " "%(method)s while the instance is in this state.") diff --git a/nova/tests/api/openstack/compute/test_servers.py b/nova/tests/api/openstack/compute/test_servers.py index 3edf7a5ab7b2..ae4d30cc18b0 100644 --- a/nova/tests/api/openstack/compute/test_servers.py +++ b/nova/tests/api/openstack/compute/test_servers.py @@ -582,7 +582,8 @@ class ServersControllerTest(test.TestCase): server_uuid = str(utils.gen_uuid()) def fake_get_all(compute_self, context, search_opts=None, - sort_key=None, sort_dir='desc'): + sort_key=None, sort_dir='desc', + limit=None, marker=None): return [fakes.stub_instance(100, uuid=server_uuid)] self.stubs.Set(nova.compute.API, 'get_all', fake_get_all) @@ -597,7 +598,8 @@ class ServersControllerTest(test.TestCase): server_uuid = str(utils.gen_uuid()) def fake_get_all(compute_self, context, search_opts=None, - sort_key=None, sort_dir='desc'): + sort_key=None, sort_dir='desc', + limit=None, marker=None): self.assertNotEqual(search_opts, None) self.assertTrue('image' in search_opts) self.assertEqual(search_opts['image'], '12345') @@ -613,7 +615,7 @@ class ServersControllerTest(test.TestCase): def test_tenant_id_filter_converts_to_project_id_for_admin(self): def fake_get_all(context, filters=None, sort_key=None, - sort_dir='desc'): + sort_dir='desc', limit=None, marker=None): self.assertNotEqual(filters, None) self.assertEqual(filters['project_id'], 'fake') self.assertFalse(filters.get('tenant_id')) @@ -630,7 +632,7 @@ class ServersControllerTest(test.TestCase): def test_admin_restricted_tenant(self): def fake_get_all(context, filters=None, sort_key=None, - sort_dir='desc'): + sort_dir='desc', limit=None, marker=None): self.assertNotEqual(filters, None) self.assertEqual(filters['project_id'], 'fake') return [fakes.stub_instance(100)] @@ -646,7 +648,7 @@ class ServersControllerTest(test.TestCase): def test_admin_all_tenants(self): def fake_get_all(context, filters=None, sort_key=None, - sort_dir='desc'): + sort_dir='desc', limit=None, marker=None): self.assertNotEqual(filters, None) self.assertTrue('project_id' not in filters) return [fakes.stub_instance(100)] @@ -662,7 +664,7 @@ class ServersControllerTest(test.TestCase): def test_all_tenants(self): def fake_get_all(context, filters=None, sort_key=None, - sort_dir='desc'): + sort_dir='desc', limit=None, marker=None): self.assertNotEqual(filters, None) self.assertEqual(filters['project_id'], 'fake') return [fakes.stub_instance(100)] @@ -679,7 +681,8 @@ class ServersControllerTest(test.TestCase): server_uuid = str(utils.gen_uuid()) def fake_get_all(compute_self, context, search_opts=None, - sort_key=None, sort_dir='desc'): + sort_key=None, sort_dir='desc', + limit=None, marker=None): self.assertNotEqual(search_opts, None) self.assertTrue('flavor' in search_opts) # flavor is an integer ID @@ -698,7 +701,8 @@ class ServersControllerTest(test.TestCase): server_uuid = str(utils.gen_uuid()) def fake_get_all(compute_self, context, search_opts=None, - sort_key=None, sort_dir='desc'): + sort_key=None, sort_dir='desc', + limit=None, marker=None): self.assertNotEqual(search_opts, None) self.assertTrue('vm_state' in search_opts) self.assertEqual(search_opts['vm_state'], vm_states.ACTIVE) @@ -728,7 +732,8 @@ class ServersControllerTest(test.TestCase): server_uuid = str(utils.gen_uuid()) def fake_get_all(compute_self, context, search_opts=None, - sort_key=None, sort_dir='desc'): + sort_key=None, sort_dir='desc', + limit=None, marker=None): self.assertTrue('vm_state' in search_opts) self.assertEqual(search_opts['vm_state'], 'deleted') @@ -747,7 +752,8 @@ class ServersControllerTest(test.TestCase): server_uuid = str(utils.gen_uuid()) def fake_get_all(compute_self, context, search_opts=None, - sort_key=None, sort_dir='desc'): + sort_key=None, sort_dir='desc', + limit=None, marker=None): self.assertNotEqual(search_opts, None) self.assertTrue('name' in search_opts) self.assertEqual(search_opts['name'], 'whee.*') @@ -765,7 +771,8 @@ class ServersControllerTest(test.TestCase): server_uuid = str(utils.gen_uuid()) def fake_get_all(compute_self, context, search_opts=None, - sort_key=None, sort_dir='desc'): + sort_key=None, sort_dir='desc', + limit=None, marker=None): self.assertNotEqual(search_opts, None) self.assertTrue('changes-since' in search_opts) changes_since = datetime.datetime(2011, 1, 24, 17, 8, 1, @@ -796,7 +803,8 @@ class ServersControllerTest(test.TestCase): server_uuid = str(utils.gen_uuid()) def fake_get_all(compute_self, context, search_opts=None, - sort_key=None, sort_dir='desc'): + sort_key=None, sort_dir='desc', + limit=None, marker=None): self.assertNotEqual(search_opts, None) # Allowed by user self.assertTrue('name' in search_opts) @@ -824,7 +832,8 @@ class ServersControllerTest(test.TestCase): server_uuid = str(utils.gen_uuid()) def fake_get_all(compute_self, context, search_opts=None, - sort_key=None, sort_dir='desc'): + sort_key=None, sort_dir='desc', + limit=None, marker=None): self.assertNotEqual(search_opts, None) # Allowed by user self.assertTrue('name' in search_opts) @@ -852,7 +861,8 @@ class ServersControllerTest(test.TestCase): server_uuid = str(utils.gen_uuid()) def fake_get_all(compute_self, context, search_opts=None, - sort_key=None, sort_dir='desc'): + sort_key=None, sort_dir='desc', + limit=None, marker=None): self.assertNotEqual(search_opts, None) self.assertTrue('ip' in search_opts) self.assertEqual(search_opts['ip'], '10\..*') @@ -874,7 +884,8 @@ class ServersControllerTest(test.TestCase): server_uuid = str(utils.gen_uuid()) def fake_get_all(compute_self, context, search_opts=None, - sort_key=None, sort_dir='desc'): + sort_key=None, sort_dir='desc', + limit=None, marker=None): self.assertNotEqual(search_opts, None) self.assertTrue('ip6' in search_opts) self.assertEqual(search_opts['ip6'], 'ffff.*') diff --git a/nova/tests/api/openstack/fakes.py b/nova/tests/api/openstack/fakes.py index 81d0f2ae1992..5ad2b094a209 100644 --- a/nova/tests/api/openstack/fakes.py +++ b/nova/tests/api/openstack/fakes.py @@ -384,10 +384,26 @@ def fake_instance_get(**kwargs): def fake_instance_get_all_by_filters(num_servers=5, **kwargs): def _return_servers(context, *args, **kwargs): servers_list = [] + marker = None + limit = None + found_marker = False + if "marker" in kwargs: + marker = kwargs["marker"] + if "limit" in kwargs: + limit = kwargs["limit"] + for i in xrange(num_servers): - server = stub_instance(id=i + 1, uuid=get_fake_uuid(i), + uuid = get_fake_uuid(i) + server = stub_instance(id=i + 1, uuid=uuid, **kwargs) servers_list.append(server) + if not marker is None and uuid == marker: + found_marker = True + servers_list = [] + if not marker is None and not found_marker: + raise webob.exc.HTTPBadRequest + if not limit is None: + servers_list = servers_list[:limit] return servers_list return _return_servers @@ -400,7 +416,7 @@ def stub_instance(id, user_id=None, project_id=None, host=None, auto_disk_config=False, display_name=None, include_fake_metadata=True, config_drive=None, power_state=None, nw_cache=None, metadata=None, - security_groups=None): + security_groups=None, limit=None, marker=None): if user_id is None: user_id = 'fake_user' diff --git a/nova/tests/test_db_api.py b/nova/tests/test_db_api.py index 313de2545151..fbc5908dfe4a 100644 --- a/nova/tests/test_db_api.py +++ b/nova/tests/test_db_api.py @@ -113,6 +113,32 @@ class DbApiTestCase(test.TestCase): else: self.assertTrue(result[1].deleted) + def test_instance_get_all_by_filters_paginate(self): + self.flags(sql_connection="notdb://") + test1 = self.create_instances_with_args(display_name='test1') + test2 = self.create_instances_with_args(display_name='test2') + test3 = self.create_instances_with_args(display_name='test3') + + result = db.instance_get_all_by_filters(self.context, + {'display_name': '%test%'}, + marker=None) + self.assertEqual(3, len(result)) + result = db.instance_get_all_by_filters(self.context, + {'display_name': '%test%'}, + sort_dir="asc", + marker=test1) + self.assertEqual(2, len(result)) + result = db.instance_get_all_by_filters(self.context, + {'display_name': '%test%'}, + sort_dir="asc", + marker=test2) + self.assertEqual(1, len(result)) + result = db.instance_get_all_by_filters(self.context, + {'display_name': '%test%'}, + sort_dir="asc", + marker=test3) + self.assertEqual(0, len(result)) + def test_migration_get_unconfirmed_by_dest_compute(self): ctxt = context.get_admin_context()