diff --git a/sahara/api/v10.py b/sahara/api/v10.py index 3bbc4f1935..1804411ac6 100644 --- a/sahara/api/v10.py +++ b/sahara/api/v10.py @@ -41,7 +41,8 @@ rest = u.Rest('v10', __name__) @rest.get('/clusters') @acl.enforce("data-processing:clusters:get_all") @v.check_exists(api.get_cluster, 'marker') -@v.validate(None, v.validate_pagination_limit) +@v.validate(None, v.validate_pagination_limit, + v.validate_sorting_clusters) def clusters_list(): result = api.get_clusters(**u.get_request_args().to_dict()) return u.render(res=result, name='clusters') @@ -102,7 +103,8 @@ def clusters_delete(cluster_id): @rest.get('/cluster-templates') @acl.enforce("data-processing:cluster-templates:get_all") @v.check_exists(api.get_cluster_template, 'marker') -@v.validate(None, v.validate_pagination_limit) +@v.validate(None, v.validate_pagination_limit, + v.validate_sorting_cluster_templates) def cluster_templates_list(): result = api.get_cluster_templates( **u.get_request_args().to_dict()) @@ -149,7 +151,8 @@ def cluster_templates_delete(cluster_template_id): @rest.get('/node-group-templates') @acl.enforce("data-processing:node-group-templates:get_all") @v.check_exists(api.get_node_group_template, 'marker') -@v.validate(None, v.validate_pagination_limit) +@v.validate(None, v.validate_pagination_limit, + v.validate_sorting_node_group_templates) def node_group_templates_list(): result = api.get_node_group_templates( **u.get_request_args().to_dict()) diff --git a/sahara/api/v11.py b/sahara/api/v11.py index d5cce77b7b..e55c05a0b7 100644 --- a/sahara/api/v11.py +++ b/sahara/api/v11.py @@ -41,7 +41,8 @@ rest = u.Rest('v11', __name__) @rest.get('/job-executions') @acl.enforce("data-processing:job-executions:get_all") @v.check_exists(api.get_job_execution, 'marker') -@v.validate(None, v.validate_pagination_limit) +@v.validate(None, v.validate_pagination_limit, + v.validate_sorting_job_executions) def job_executions_list(): result = api.job_execution_list( **u.get_request_args().to_dict()) @@ -94,7 +95,8 @@ def job_executions_delete(job_execution_id): @rest.get('/data-sources') @acl.enforce("data-processing:data-sources:get_all") @v.check_exists(api.get_data_source, 'marker') -@v.validate(None, v.validate_pagination_limit) +@v.validate(None, v.validate_pagination_limit, + v.validate_sorting_data_sources) def data_sources_list(): result = api.get_data_sources(**u.get_request_args().to_dict()) return u.render(res=result, name='data_sources') @@ -136,7 +138,8 @@ def data_source_update(data_source_id, data): @rest.get('/jobs') @acl.enforce("data-processing:jobs:get_all") @v.check_exists(api.get_job, 'marker') -@v.validate(None, v.validate_pagination_limit) +@v.validate(None, v.validate_pagination_limit, + v.validate_sorting_jobs) def job_list(): result = api.get_jobs(**u.get_request_args().to_dict()) @@ -211,7 +214,8 @@ def job_binary_create(data): @rest.get('/job-binaries') @acl.enforce("data-processing:job-binaries:get_all") @v.check_exists(api.get_job_binaries, 'marker') -@v.validate(None, v.validate_pagination_limit) +@v.validate(None, v.validate_pagination_limit, + v.validate_sorting_job_binaries) def job_binary_list(): result = api.get_job_binaries(**u.get_request_args().to_dict()) @@ -263,7 +267,8 @@ def job_binary_internal_create(**values): @rest.get('/job-binary-internals') @acl.enforce("data-processing:job-binary-internals:get_all") @v.check_exists(api.get_job_binary_internal, 'marker') -@v.validate(None, v.validate_pagination_limit) +@v.validate(None, v.validate_pagination_limit, + v.validate_sorting_job_binary_internals) def job_binary_internal_list(): result = api.get_job_binary_internals(**u.get_request_args().to_dict()) return u.render(res=result, name='binaries') diff --git a/sahara/api/v2/cluster_templates.py b/sahara/api/v2/cluster_templates.py index 988d3da844..694be8b3b0 100644 --- a/sahara/api/v2/cluster_templates.py +++ b/sahara/api/v2/cluster_templates.py @@ -27,7 +27,8 @@ rest = u.RestV2('cluster-templates', __name__) @rest.get('/cluster-templates') @acl.enforce("data-processing:cluster-templates:get_all") @v.check_exists(api.get_cluster_template, 'marker') -@v.validate(None, v.validate_pagination_limit) +@v.validate(None, v.validate_pagination_limit, + v.validate_sorting_cluster_templates) def cluster_templates_list(): result = api.get_cluster_templates(**u.get_request_args().to_dict()) return u.render(res=result, name='cluster_templates') diff --git a/sahara/api/v2/data_sources.py b/sahara/api/v2/data_sources.py index 4f470a38de..eb859b237d 100644 --- a/sahara/api/v2/data_sources.py +++ b/sahara/api/v2/data_sources.py @@ -27,7 +27,8 @@ rest = u.RestV2('data-sources', __name__) @rest.get('/data-sources') @acl.enforce("data-processing:data-sources:get_all") @v.check_exists(api.get_data_source, 'marker') -@v.validate(None, v.validate_pagination_limit) +@v.validate(None, v.validate_pagination_limit, + v.validate_sorting_data_sources) def data_sources_list(): result = api.get_data_sources(**u.get_request_args().to_dict()) return u.render(res=result, name='data_sources') diff --git a/sahara/api/v2/job_binaries.py b/sahara/api/v2/job_binaries.py index 3891f263d4..d4090cdc01 100644 --- a/sahara/api/v2/job_binaries.py +++ b/sahara/api/v2/job_binaries.py @@ -36,7 +36,8 @@ def job_binary_create(data): @rest.get('/job-binaries') @acl.enforce("data-processing:job-binaries:get_all") @v.check_exists(api.get_job_binary, 'marker') -@v.validate(None, v.validate_pagination_limit) +@v.validate(None, v.validate_pagination_limit, + v.validate_sorting_job_binaries) def job_binary_list(): result = api.get_job_binaries(**u.get_request_args().to_dict()) return u.render(res=result, name='binaries') diff --git a/sahara/api/v2/job_executions.py b/sahara/api/v2/job_executions.py index 578e78a7af..b968676519 100644 --- a/sahara/api/v2/job_executions.py +++ b/sahara/api/v2/job_executions.py @@ -28,7 +28,8 @@ rest = u.RestV2('job-executions', __name__) @rest.get('/jobs') @acl.enforce("data-processing:job-executions:get_all") @v.check_exists(api.get_job_execution, 'marker') -@v.validate(None, v.validate_pagination_limit) +@v.validate(None, v.validate_pagination_limit, + v.validate_sorting_job_executions) def job_executions_list(): result = api.job_execution_list(**u.get_request_args().to_dict()) return u.render(res=result, name='jobs') diff --git a/sahara/api/v2/jobs.py b/sahara/api/v2/jobs.py index 4cbab1ddc2..df31b1e2c1 100644 --- a/sahara/api/v2/jobs.py +++ b/sahara/api/v2/jobs.py @@ -30,7 +30,8 @@ rest = u.RestV2('jobs', __name__) @rest.get('/job-templates') @acl.enforce("data-processing:jobs:get_all") @v.check_exists(api.get_job, 'marker') -@v.validate(None, v.validate_pagination_limit) +@v.validate(None, v.validate_pagination_limit, + v.validate_sorting_jobs) def job_list(): result = api.get_jobs(**u.get_request_args().to_dict()) return u.render(res=result, name='job_templates') diff --git a/sahara/api/v2/node_group_templates.py b/sahara/api/v2/node_group_templates.py index 1765c255bd..344eb77bbc 100644 --- a/sahara/api/v2/node_group_templates.py +++ b/sahara/api/v2/node_group_templates.py @@ -28,7 +28,8 @@ rest = u.RestV2('node-group-templates', __name__) @rest.get('/node-group-templates') @acl.enforce("data-processing:node-group-templates:get_all") @v.check_exists(api.get_node_group_template, 'marker') -@v.validate(None, v.validate_pagination_limit) +@v.validate(None, v.validate_pagination_limit, + v.validate_sorting_node_group_templates) def node_group_templates_list(): result = api.get_node_group_templates(**u.get_request_args().to_dict()) return u.render(res=result, name="node_group_templates") diff --git a/sahara/db/sqlalchemy/api.py b/sahara/db/sqlalchemy/api.py index 817f727232..979d6e4bd8 100644 --- a/sahara/db/sqlalchemy/api.py +++ b/sahara/db/sqlalchemy/api.py @@ -64,7 +64,17 @@ def get_session(**kwargs): return facade.get_session(**kwargs) -def _get_prev_and_next_objects(objects, limit, marker): +def _parse_sorting_args(sort_by): + if sort_by is None: + sort_by = "id" + if sort_by[0] == "-": + return sort_by[1:], "desc" + return sort_by, "asc" + + +def _get_prev_and_next_objects(objects, limit, marker, order=None): + if order == 'desc': + objects.reverse() position = None if limit is None: return None, None @@ -281,8 +291,9 @@ def cluster_get(context, cluster_id): def cluster_get_all(context, regex_search=False, - limit=None, marker=None, **kwargs): + limit=None, marker=None, sort_by=None, **kwargs): + sort_by, order = _parse_sorting_args(sort_by) regex_cols = ['name', 'description', 'plugin_name'] query = model_query(m.Cluster, context) @@ -294,10 +305,11 @@ def cluster_get_all(context, regex_search=False, marker = cluster_get(context, marker) prev_marker, next_marker = _get_prev_and_next_objects( - query.filter_by(**kwargs).order_by('name').all(), limit, marker) + query.filter_by(**kwargs).order_by(sort_by).all(), + limit, marker, order=order) result = utils.paginate_query(query.filter_by(**kwargs), m.Cluster, - limit, ['name'], marker) + limit, [sort_by], marker, order) return types.Page(result, prev_marker, next_marker) @@ -492,10 +504,10 @@ def cluster_template_get(context, cluster_template_id): def cluster_template_get_all(context, regex_search=False, - marker=None, limit=None, **kwargs): + marker=None, limit=None, sort_by=None, **kwargs): regex_cols = ['name', 'description', 'plugin_name'] - + sort_by, order = _parse_sorting_args(sort_by) query = model_query(m.ClusterTemplate, context) if regex_search: query, kwargs = regex_filter(query, @@ -506,10 +518,12 @@ def cluster_template_get_all(context, regex_search=False, marker = cluster_template_get(context, marker) prev_marker, next_marker = _get_prev_and_next_objects( - query.filter_by(**kwargs).order_by('name').all(), limit, marker) + query.filter_by(**kwargs).order_by(sort_by).all(), + limit, marker, order=order) result = utils.paginate_query(query.filter_by(**kwargs), - m.ClusterTemplate, limit, ['name'], marker) + m.ClusterTemplate, + limit, [sort_by], marker, order) return types.Page(result, prev_marker, next_marker) @@ -625,9 +639,9 @@ def node_group_template_get(context, node_group_template_id): node_group_template_id) -def node_group_template_get_all(context, regex_search=False, - marker=None, limit=None, **kwargs): - +def node_group_template_get_all(context, regex_search=False, marker=None, + limit=None, sort_by=None, **kwargs): + sort_by, order = _parse_sorting_args(sort_by) regex_cols = ['name', 'description', 'plugin_name'] limit = int(limit) if limit else None query = model_query(m.NodeGroupTemplate, context) @@ -638,11 +652,12 @@ def node_group_template_get_all(context, regex_search=False, marker = node_group_template_get(context, marker) prev_marker, next_marker = _get_prev_and_next_objects( - query.filter_by(**kwargs).order_by('name').all(), limit, marker) + query.filter_by(**kwargs).order_by(sort_by).all(), + limit, marker, order=order) result = utils.paginate_query( query.filter_by(**kwargs), m.NodeGroupTemplate, - limit, ['name'], marker) + limit, [sort_by], marker, order) return types.Page(result, prev_marker, next_marker) @@ -775,10 +790,12 @@ def data_source_count(context, **kwargs): def data_source_get_all(context, regex_search=False, - limit=None, marker=None, **kwargs): + limit=None, marker=None, sort_by=None, **kwargs): regex_cols = ['name', 'description', 'url'] + sort_by, order = _parse_sorting_args(sort_by) + query = model_query(m.DataSource, context) if regex_search: query, kwargs = regex_filter(query, @@ -788,10 +805,11 @@ def data_source_get_all(context, regex_search=False, marker = data_source_get(context, marker) prev_marker, next_marker = _get_prev_and_next_objects( - query.filter_by(**kwargs).order_by('name').all(), limit, marker) + query.filter_by(**kwargs).order_by(sort_by).all(), + limit, marker, order=order) result = utils.paginate_query(query.filter_by(**kwargs), m.DataSource, - limit, ['name'], marker) + limit, [sort_by], marker, order) return types.Page(result, prev_marker, next_marker) @@ -864,7 +882,7 @@ def job_execution_get(context, job_execution_id): def job_execution_get_all(context, regex_search=False, - limit=None, marker=None, **kwargs): + limit=None, marker=None, sort_by=None, **kwargs): """Get all JobExecutions filtered by **kwargs. kwargs key values may be the names of fields in a JobExecution @@ -879,6 +897,8 @@ def job_execution_get_all(context, regex_search=False, 'job.name': 'wordcount'}) """ + sort_by, order = _parse_sorting_args(sort_by) + regex_cols = ['job.name', 'cluster.name'] # Remove the external fields if present, they'll @@ -918,7 +938,10 @@ def job_execution_get_all(context, regex_search=False, m.Job, ['name'], search_opts) query = query.filter_by(**search_opts) - res = query.all() + res = query.order_by(sort_by).all() + + if order == 'desc': + res.reverse() # 'info' is a JsonDictType which is stored as a string. # It would be possible to search for the substring containing @@ -1051,10 +1074,10 @@ def job_get(context, job_id): def job_get_all(context, regex_search=False, - limit=None, marker=None, **kwargs): + limit=None, marker=None, sort_by=None, **kwargs): regex_cols = ['name', 'description'] - + sort_by, order = _parse_sorting_args(sort_by) query = model_query(m.Job, context) if regex_search: query, kwargs = regex_filter(query, @@ -1064,10 +1087,11 @@ def job_get_all(context, regex_search=False, marker = job_get(context, marker) prev_marker, next_marker = _get_prev_and_next_objects( - query.filter_by(**kwargs).order_by('name').all(), limit, marker) + query.filter_by(**kwargs).order_by(sort_by).all(), + limit, marker, order=order) result = utils.paginate_query(query.filter_by(**kwargs), - m.Job, limit, ['name'], marker) + m.Job, limit, [sort_by], marker, order) return types.Page(result, prev_marker, next_marker) @@ -1168,7 +1192,9 @@ def _job_binary_get(context, session, job_binary_id): def job_binary_get_all(context, regex_search=False, - limit=None, marker=None, **kwargs): + limit=None, marker=None, sort_by=None, **kwargs): + + sort_by, order = _parse_sorting_args(sort_by) regex_cols = ['name', 'description', 'url'] query = model_query(m.JobBinary, context) @@ -1180,10 +1206,12 @@ def job_binary_get_all(context, regex_search=False, marker = job_binary_get(context, marker) prev_marker, next_marker = _get_prev_and_next_objects( - query.filter_by(**kwargs).order_by('name').all(), limit, marker) + query.filter_by(**kwargs).order_by(sort_by).all(), + limit, marker, order=order) result = utils.paginate_query(query.filter_by(**kwargs), - m.JobBinary, limit, ['name'], marker) + m.JobBinary, + limit, [sort_by], marker, order) return types.Page(result, prev_marker, next_marker) @@ -1293,12 +1321,13 @@ def _job_binary_internal_get(context, session, job_binary_internal_id): return query.filter_by(id=job_binary_internal_id).first() -def job_binary_internal_get_all(context, regex_search=False, - limit=None, marker=None, **kwargs): +def job_binary_internal_get_all(context, regex_search=False, limit=None, + marker=None, sort_by=None, **kwargs): """Returns JobBinaryInternal objects that do not contain a data field The data column uses deferred loading. """ + sort_by, order = _parse_sorting_args(sort_by) regex_cols = ['name'] @@ -1311,11 +1340,12 @@ def job_binary_internal_get_all(context, regex_search=False, marker = job_binary_internal_get(context, marker) prev_marker, next_marker = _get_prev_and_next_objects( - query.filter_by(**kwargs).order_by('name').all(), limit, marker) + query.filter_by(**kwargs).order_by(sort_by).all(), + limit, marker, order=order) result = utils.paginate_query(query.filter_by(**kwargs), m.JobBinaryInternal, limit, - ['name'], marker) + [sort_by], marker, order) return types.Page(result, prev_marker, next_marker) diff --git a/sahara/service/validation.py b/sahara/service/validation.py index 68c9d05e1c..d0e9a10ffe 100644 --- a/sahara/service/validation.py +++ b/sahara/service/validation.py @@ -51,6 +51,91 @@ def validate_pagination_limit(): _("'limit' must be positive integer"), 400) +def get_sorting_field(): + request_args = u.get_request_args() + if 'sort_by' in request_args: + sort_by = request_args['sort_by'] + if sort_by: + sort_by = sort_by[1:] if sort_by[0] == '-' else sort_by + return sort_by + return None + + +def validate_sorting_clusters(): + field = get_sorting_field() + if field is None: + return + if field not in ['id', 'name', 'plugin_name', 'hadoop_version', + 'status']: + raise ex.SaharaException( + _("Unknown field for sorting %s") % field, 400) + + +def validate_sorting_cluster_templates(): + field = get_sorting_field() + if field is None: + return + if field not in ['id', 'name', 'plugin_name', 'hadoop_version', + 'created_at', 'updated_at']: + raise ex.SaharaException( + _("Unknown field for sorting %s") % field, 400) + + +def validate_sorting_node_group_templates(): + field = get_sorting_field() + if field is None: + return + if field not in ['id', 'name', 'plugin_name', 'hadoop_version', + 'created_at', 'updated_at']: + raise ex.SaharaException( + _("Unknown field for sorting %s") % field, 400) + + +def validate_sorting_job_binaries(): + field = get_sorting_field() + if field is None: + return + if field not in ['id', 'name', 'created_at', 'updated_at']: + raise ex.SaharaException( + _("Unknown field for sorting %s") % field, 400) + + +def validate_sorting_job_binary_internals(): + field = get_sorting_field() + if field is None: + return + if field not in ['id', 'name', 'created_at', 'updated_at']: + raise ex.SaharaException( + _("Unknown field for sorting %s") % field, 400) + + +def validate_sorting_data_sources(): + field = get_sorting_field() + if field is None: + return + if field not in ['id', 'name', 'type', 'created_at', 'updated_at']: + raise ex.SaharaException( + _("Unknown field for sorting %s") % field, 400) + + +def validate_sorting_jobs(): + field = get_sorting_field() + if field is None: + return + if field not in ['id', 'name', 'type', 'created_at', 'updated_at']: + raise ex.SaharaException( + _("Unknown field for sorting %s") % field, 400) + + +def validate_sorting_job_executions(): + field = get_sorting_field() + if field is None: + return + if field not in ['id', 'job_template', 'cluster', 'status']: + raise ex.SaharaException( + _("Unknown field for sorting %s") % field, 400) + + def validate(schema, *validators): def decorator(func): @functools.wraps(func) diff --git a/sahara/tests/unit/db/test_utils.py b/sahara/tests/unit/db/test_utils.py index 2b6400ebcb..545a9c2286 100644 --- a/sahara/tests/unit/db/test_utils.py +++ b/sahara/tests/unit/db/test_utils.py @@ -40,6 +40,10 @@ class TestPaginationUtils(testtools.TestCase): res = api._get_prev_and_next_objects(query, 5, mock.MagicMock(id=4)) self.assertEqual((None, 9), res) + def test_parse_sorting_args(self): + self.assertEqual(("name", "desc"), api._parse_sorting_args("-name")) + self.assertEqual(("name", "asc"), api._parse_sorting_args("name")) + class TestRegex(testtools.TestCase):