Revise receiver sorting

This finishes the series of patches revising sorting support. Some
useless routines are removed from db layer.

Change-Id: I2d0b6c388bb44295da2ef661f02b5baee0c8a129
This commit is contained in:
tengqm 2016-01-08 07:27:48 -05:00
parent 1fd48ae3d6
commit d784e21e0f
12 changed files with 61 additions and 194 deletions

View File

@ -47,8 +47,6 @@ MIDDLE PRIORITY
API
---
- Revise the API for sorting, based on the following guideline:
https://github.com/openstack/api-wg/blob/master/guidelines/pagination_filter_sort.rst
- According to the guidelines from API WG, we need to support `page_reverse`
as a pagination parameter. https://review.openstack.org/190743
- According to the proposal (https://review.openstack.org/#/c/234994/),

View File

@ -0,0 +1,3 @@
---
features:
- Use 'sort' instead of 'sort_keys' and 'sort_dir' for object sorting.

View File

@ -85,8 +85,7 @@ class ReceiverController(object):
param_whitelist = {
'limit': 'single',
'marker': 'single',
'sort_keys': 'multi',
'sort_dir': 'single',
'sort': 'single',
'global_project': 'single',
}
params = util.get_allowed_params(req.params, param_whitelist)

View File

@ -394,11 +394,11 @@ def receiver_get_by_short_id(context, short_id, project_safe=True):
project_safe=project_safe)
def receiver_get_all(context, limit=None, marker=None, filters=None,
sort_keys=None, sort_dir=None, project_safe=True):
def receiver_get_all(context, limit=None, marker=None, filters=None, sort=None,
project_safe=True):
return IMPL.receiver_get_all(context, limit=limit, marker=marker,
sort_keys=sort_keys, sort_dir=sort_dir,
filters=filters, project_safe=project_safe)
sort=sort, filters=filters,
project_safe=project_safe)
def receiver_delete(context, receiver_id, force=False):

View File

@ -60,6 +60,10 @@ def get_backend():
return sys.modules[__name__]
def _session(context):
return (context and context.session) or get_session()
def model_query(context, *args):
session = _session(context)
query = session.query(*args)
@ -92,43 +96,6 @@ def _get_sort_params(value, whitelist, default_key=None):
return keys, dirs
def _get_sort_keys(keys, whitelist):
"""Returns an array containing only whitelisted keys
:param keys: an array of strings or a single string
:param whitelist: a list of allowed keys
:returns: filtered list of sort keys
"""
if keys is None:
return None
if isinstance(keys, six.string_types):
keys = [keys]
return [k for k in keys if k in whitelist]
def _paginate_query(context, query, model, limit=None, marker=None,
sort_keys=None, sort_dir=None, default_sort_keys=None):
if not sort_keys:
sort_keys = default_sort_keys or []
if not sort_dir:
sort_dir = 'asc'
# This assures the order of the clusters will always be the same
# even for sort_key values that are not unique in the database
sort_keys = sort_keys + ['id']
model_marker = None
if marker:
model_marker = model_query(context, model).get(marker)
try:
query = utils.paginate_query(query, model, limit, sort_keys,
model_marker, sort_dir)
except utils.InvalidSortKey:
raise exception.InvalidParameter(name='sort_keys', value=sort_keys)
return query
# TODO(Yanyan Hu): Set default value of project_safe to True
def query_by_short_id(context, model, short_id, project_safe=False):
q = model_query(context, model)
@ -161,10 +128,6 @@ def query_by_name(context, model, name, project_safe=False):
raise exception.MultipleChoices(arg=name)
def _session(context):
return (context and context.session) or get_session()
# Clusters
def cluster_create(context, values):
cluster_ref = models.Cluster()
@ -211,11 +174,10 @@ def cluster_get_all(context, limit=None, marker=None, sort=None, filters=None,
project_safe=True, show_nested=False):
query = _query_cluster_get_all(context, project_safe=project_safe,
show_nested=show_nested)
if filters is None:
filters = {}
if filters:
query = db_filters.exact_filter(query, models.Cluster, filters)
keys, dirs = _get_sort_params(sort, consts.CLUSTER_SORT_KEYS, 'init_at')
query = db_filters.exact_filter(query, models.Cluster, filters)
if marker:
marker = model_query(context, models.Cluster).get(marker)
@ -325,11 +287,10 @@ def node_get_all(context, cluster_id=None, limit=None, marker=None, sort=None,
query = _query_node_get_all(context, project_safe=project_safe,
cluster_id=cluster_id)
if filters is None:
filters = {}
if filters:
query = db_filters.exact_filter(query, models.Node, filters)
keys, dirs = _get_sort_params(sort, consts.NODE_SORT_KEYS, 'init_at')
query = db_filters.exact_filter(query, models.Node, filters)
if marker:
marker = model_query(context, models.Node).get(marker)
return utils.paginate_query(query, models.Node, limit, keys,
@ -580,11 +541,10 @@ def policy_get_all(context, limit=None, marker=None, sort=None, filters=None,
if project_safe:
query = query.filter_by(project=context.project)
if filters is None:
filters = {}
if filters:
query = db_filters.exact_filter(query, models.Policy, filters)
keys, dirs = _get_sort_params(sort, consts.POLICY_SORT_KEYS, 'created_at')
query = db_filters.exact_filter(query, models.Policy, filters)
if marker:
marker = model_query(context, models.Policy).get(marker)
return utils.paginate_query(query, models.Policy, limit, keys,
@ -631,13 +591,12 @@ def cluster_policy_get_all(context, cluster_id, filters=None, sort=None):
query = model_query(context, models.ClusterPolicies)
query = query.filter_by(cluster_id=cluster_id)
if filters is None:
filters = {}
if filters:
query = db_filters.exact_filter(query, models.ClusterPolicies, filters)
keys, dirs = _get_sort_params(sort, consts.CLUSTER_POLICY_SORT_KEYS,
'priority')
query = db_filters.exact_filter(query, models.ClusterPolicies, filters)
return utils.paginate_query(query, models.ClusterPolicies, None, keys,
sort_dirs=dirs).all()
@ -714,11 +673,10 @@ def profile_get_all(context, limit=None, marker=None, sort=None, filters=None,
if project_safe:
query = query.filter_by(project=context.project)
if filters is None:
filters = {}
if filters:
query = db_filters.exact_filter(query, models.Profile, filters)
keys, dirs = _get_sort_params(sort, consts.PROFILE_SORT_KEYS, 'created_at')
query = db_filters.exact_filter(query, models.Profile, filters)
if marker:
marker = model_query(context, models.Profile).get(marker)
return utils.paginate_query(query, models.Profile, limit, keys,
@ -837,11 +795,10 @@ def event_get_by_short_id(context, short_id, project_safe=True):
def _event_filter_paginate_query(context, query, filters=None,
limit=None, marker=None, sort=None):
if filters is None:
filters = {}
if filters:
query = db_filters.exact_filter(query, models.Event, filters)
keys, dirs = _get_sort_params(sort, consts.EVENT_SORT_KEYS, 'timestamp')
query = db_filters.exact_filter(query, models.Event, filters)
if marker:
marker = model_query(context, models.Event).get(marker)
return utils.paginate_query(query, models.Event, limit, keys,
@ -929,11 +886,10 @@ def action_get_all(context, filters=None, limit=None, marker=None, sort=None,
# if project_safe:
# query = query.filter_by(project=context.project)
if filters is None:
filters = {}
if filters:
query = db_filters.exact_filter(query, models.Action, filters)
keys, dirs = _get_sort_params(sort, consts.ACTION_SORT_KEYS, 'created_at')
query = db_filters.exact_filter(query, models.Action, filters)
if marker:
marker = model_query(context, models.Action).get(marker)
return utils.paginate_query(query, models.Action, limit, keys,
@ -1195,21 +1151,20 @@ def receiver_get(context, receiver_id, project_safe=True):
return receiver
def receiver_get_all(context, limit=None, marker=None, filters=None,
sort_keys=None, sort_dir=None, project_safe=True):
def receiver_get_all(context, limit=None, marker=None, filters=None, sort=None,
project_safe=True):
query = model_query(context, models.Receiver)
if project_safe:
query = query.filter_by(project=context.project)
if filters is None:
filters = {}
if filters:
query = db_filters.exact_filter(query, models.Receiver, filters)
keys = _get_sort_keys(sort_keys, consts.RECEIVER_SORT_KEYS)
query = db_filters.exact_filter(query, models.Receiver, filters)
return _paginate_query(context, query, models.Receiver,
limit=limit, marker=marker,
sort_keys=keys, sort_dir=sort_dir,
default_sort_keys=['name']).all()
keys, dirs = _get_sort_params(sort, consts.RECEIVER_SORT_KEYS, 'name')
if marker:
marker = model_query(context, models.Receiver).get(marker)
return utils.paginate_query(query, models.Receiver, limit, keys,
marker=marker, sort_dirs=dirs).all()
def receiver_get_by_name(context, name, project_safe=True):

View File

@ -153,14 +153,12 @@ class Receiver(object):
return cls._from_db_record(receiver_obj)
@classmethod
def load_all(cls, context, limit=None, marker=None, sort_keys=None,
sort_dir=None, filters=None, project_safe=True):
def load_all(cls, context, limit=None, marker=None, sort=None,
filters=None, project_safe=True):
"""Retrieve all receivers from database."""
records = db_api.receiver_get_all(context, limit=limit, marker=marker,
sort_keys=sort_keys,
sort_dir=sort_dir,
filters=filters,
sort=sort, filters=filters,
project_safe=project_safe)
for record in records:

View File

@ -1468,8 +1468,8 @@ class EngineService(service.Service):
return receiver
@request_context
def receiver_list(self, context, limit=None, marker=None, sort_keys=None,
sort_dir=None, filters=None, project_safe=True):
def receiver_list(self, context, limit=None, marker=None, sort=None,
filters=None, project_safe=True):
if limit is not None:
limit = utils.parse_int_param('limit', limit)
if project_safe is not None:
@ -1477,9 +1477,7 @@ class EngineService(service.Service):
receivers = receiver_mod.Receiver.load_all(context, limit=limit,
marker=marker,
sort_keys=sort_keys,
sort_dir=sort_dir,
filters=filters,
sort=sort, filters=filters,
project_safe=project_safe)
return [r.to_dict() for r in receivers]

View File

@ -306,13 +306,11 @@ class EngineClient(object):
self.make_msg('action_get', identity=identity))
def receiver_list(self, ctxt, limit=None, marker=None, filters=None,
sort_keys=None, sort_dir=None, project_safe=True):
sort=None, project_safe=True):
return self.call(ctxt,
self.make_msg('receiver_list',
self.make_msg('receiver_list', filters=filters,
limit=limit, marker=marker,
sort_keys=sort_keys, sort_dir=sort_dir,
filters=filters,
project_safe=project_safe))
sort=sort, project_safe=project_safe))
def receiver_create(self, ctxt, name, type_name, cluster_id, action,
actor=None, params=None):

View File

@ -102,8 +102,7 @@ class ReceiverControllerTest(shared.ControllerTest, base.SenlinTestCase):
result = self.controller.index(req)
default_args = {'limit': None, 'marker': None,
'sort_keys': None, 'sort_dir': None,
default_args = {'limit': None, 'marker': None, 'sort': None,
'filters': None, 'project_safe': True}
mock_call.assert_called_with(req.context,
@ -117,8 +116,7 @@ class ReceiverControllerTest(shared.ControllerTest, base.SenlinTestCase):
params = {
'limit': 20,
'marker': 'fake marker',
'sort_keys': 'fake sort keys',
'sort_dir': 'fake sort dir',
'sort': 'fake sorting string',
'project_safe': True,
'filters': None,
'balrog': 'you shall not pass!'
@ -133,11 +131,10 @@ class ReceiverControllerTest(shared.ControllerTest, base.SenlinTestCase):
rpc_call_args, _ = mock_call.call_args
engine_args = rpc_call_args[1][1]
self.assertEqual(6, len(engine_args))
self.assertEqual(5, len(engine_args))
self.assertIn('limit', engine_args)
self.assertIn('marker', engine_args)
self.assertIn('sort_keys', engine_args)
self.assertIn('sort_dir', engine_args)
self.assertIn('sort', engine_args)
self.assertIn('filters', engine_args)
self.assertIn('project_safe', engine_args)
self.assertNotIn('tenant_safe', engine_args)

View File

@ -381,73 +381,6 @@ class DBAPIClusterTest(base.SenlinTestCase):
self.assertRaises(exception.ClusterNotFound,
db_api.cluster_update, self.ctx, UUID2, values)
def test_get_sort_keys_returns_empty_list_if_no_keys(self):
sort_keys = None
valid_keys = ['foo', 'bar']
filtered_keys = db_api._get_sort_keys(sort_keys, valid_keys)
self.assertIsNone(filtered_keys)
def test_get_sort_keys_whitelists_single_key(self):
sort_keys = 'foo'
valid_keys = ['foo', 'bar']
filtered_keys = db_api._get_sort_keys(sort_keys, valid_keys)
self.assertEqual(['foo'], filtered_keys)
def test_get_sort_keys_whitelists_multiple_keys(self):
sort_keys = ['foo', 'bar', 'nope']
valid_keys = ['foo', 'bar']
filtered_keys = db_api._get_sort_keys(sort_keys, valid_keys)
self.assertIn('foo', filtered_keys)
self.assertIn('bar', filtered_keys)
self.assertEqual(2, len(filtered_keys))
@mock.patch.object(db_api.utils, 'paginate_query')
def test_paginate_query_raises_invalid_sort_key(self, mock_paginate_query):
query = mock.Mock()
model = mock.Mock()
mock_paginate_query.side_effect = db_api.utils.InvalidSortKey()
self.assertRaises(exception.InvalidParameter, db_api._paginate_query,
self.ctx, query, model, sort_keys=['foo'])
@mock.patch.object(db_api.utils, 'paginate_query')
@mock.patch.object(db_api, 'model_query')
def test_paginate_query_gets_model_marker(self, mock_query,
mock_paginate_query):
query = mock.Mock()
model = mock.Mock()
marker = mock.Mock()
mock_query_object = mock.Mock()
mock_query_object.get.return_value = 'real_marker'
mock_query.return_value = mock_query_object
db_api._paginate_query(self.ctx, query, model, marker=marker)
mock_query_object.get.assert_called_once_with(marker)
args, _ = mock_paginate_query.call_args
self.assertIn('real_marker', args)
@mock.patch.object(db_api.utils, 'paginate_query')
def test_paginate_query_default_sorts_dir_by_desc(self,
mock_paginate_query):
query = mock.Mock()
model = mock.Mock()
db_api._paginate_query(self.ctx, query, model, sort_dir=None)
args, _ = mock_paginate_query.call_args
self.assertIn('asc', args)
@mock.patch.object(db_api.utils, 'paginate_query')
def test_paginate_query_uses_given_sort_plus_id(self,
mock_paginate_query):
query = mock.Mock()
model = mock.Mock()
db_api._paginate_query(self.ctx, query, model, sort_keys=['name'])
args, _ = mock_paginate_query.call_args
self.assertIn(['name', 'id'], args)
def test_nested_cluster_get_by_name(self):
cluster1 = shared.create_cluster(self.ctx, self.profile,
name='cluster1')

View File

@ -12,6 +12,7 @@
from oslo_utils import timeutils as tu
from senlin.common import consts
from senlin.common import exception
from senlin.db.sqlalchemy import api as db_api
from senlin.tests.unit.common import base
@ -175,21 +176,14 @@ class DBAPIReceiverTest(base.SenlinTestCase):
self._create_receiver(self.ctx, id=v)
mock_paginate = self.patchobject(db_api.utils, 'paginate_query')
sort_keys = ['name', 'type', 'cluster_id', 'action', 'created_at']
sort_keys = consts.RECEIVER_SORT_KEYS
db_api.receiver_get_all(self.ctx, sort_keys=sort_keys)
db_api.receiver_get_all(self.ctx, sort=','.join(sort_keys))
args = mock_paginate.call_args[0]
used_sort_keys = set(args[3])
expected_keys = set(['id', 'name', 'type', 'cluster_id', 'action',
'created_at'])
self.assertEqual(expected_keys, used_sort_keys)
sort_keys.append('id')
self.assertEqual(set(sort_keys), set(args[3]))
def test_receiver_get_all_sort_keys_wont_change(self):
sort_keys = ['id']
db_api.receiver_get_all(self.ctx, sort_keys=sort_keys)
self.assertEqual(['id'], sort_keys)
def test_receiver_get_all_sort_keys_and_dir(self):
def test_receiver_get_all_sorting(self):
values = [{'id': '001', 'name': 'receiver1'},
{'id': '002', 'name': 'receiver3'},
{'id': '003', 'name': 'receiver2'}]
@ -199,18 +193,14 @@ class DBAPIReceiverTest(base.SenlinTestCase):
for v in values:
self._create_receiver(self.ctx, cluster_id=obj_ids[v['name']], **v)
receivers = db_api.receiver_get_all(self.ctx,
sort_keys=['name', 'cluster_id'],
sort_dir='asc')
receivers = db_api.receiver_get_all(self.ctx, sort='name,cluster_id')
self.assertEqual(3, len(receivers))
# Sorted by name (ascending)
self.assertEqual('001', receivers[0].id)
self.assertEqual('003', receivers[1].id)
self.assertEqual('002', receivers[2].id)
receivers = db_api.receiver_get_all(self.ctx,
sort_keys=['cluster_id', 'name'],
sort_dir='asc')
receivers = db_api.receiver_get_all(self.ctx, sort='cluster_id,name')
self.assertEqual(3, len(receivers))
# Sorted by obj_id (ascending)
self.assertEqual('002', receivers[0].id)
@ -218,15 +208,14 @@ class DBAPIReceiverTest(base.SenlinTestCase):
self.assertEqual('001', receivers[2].id)
receivers = db_api.receiver_get_all(self.ctx,
sort_keys=['cluster_id', 'name'],
sort_dir='desc')
sort='cluster_id:desc,name:desc')
self.assertEqual(3, len(receivers))
# Sorted by obj_id (descending)
self.assertEqual('001', receivers[0].id)
self.assertEqual('003', receivers[1].id)
self.assertEqual('002', receivers[2].id)
def test_receiver_get_all_default_sort_dir(self):
def test_receiver_get_all_sorting_default(self):
values = [{'id': '001', 'name': 'receiver1'},
{'id': '002', 'name': 'receiver2'},
{'id': '003', 'name': 'receiver3'}]
@ -236,7 +225,7 @@ class DBAPIReceiverTest(base.SenlinTestCase):
for v in values:
self._create_receiver(self.ctx, cluster_id=obj_ids[v['name']], **v)
receivers = db_api.receiver_get_all(self.ctx, sort_dir='asc')
receivers = db_api.receiver_get_all(self.ctx)
self.assertEqual(3, len(receivers))
self.assertEqual(values[0]['id'], receivers[0].id)
self.assertEqual(values[1]['id'], receivers[1].id)

View File

@ -432,8 +432,7 @@ class EngineRpcAPITestCase(base.SenlinTestCase):
default_args = {
'limit': mock.ANY,
'marker': mock.ANY,
'sort_keys': mock.ANY,
'sort_dir': mock.ANY,
'sort': mock.ANY,
'filters': mock.ANY,
'project_safe': mock.ANY,
}