diff --git a/glare/db/sqlalchemy/api.py b/glare/db/sqlalchemy/api.py index 674aed0..93c4488 100644 --- a/glare/db/sqlalchemy/api.py +++ b/glare/db/sqlalchemy/api.py @@ -421,13 +421,7 @@ def _do_query_filters(filters): tags = utils.split_filter_value_for_quotes(value) tag_conds.append([models.ArtifactTag.value.in_(tags)]) elif field_name in BASE_ARTIFACT_PROPERTIES: - if op != 'in': - fn = op_mappings[op] - if field_name == 'version': - value = semver_db.parse(value) - basic_conds.append([fn(getattr(models.Artifact, field_name), - value)]) - else: + if op == 'in': if field_name == 'version': value = [semver_db.parse(val) for val in value] basic_conds.append( @@ -436,6 +430,15 @@ def _do_query_filters(filters): else: basic_conds.append( [getattr(models.Artifact, field_name).in_(value)]) + elif op == 'like': + basic_conds.append( + [getattr(models.Artifact, field_name).like(value)]) + else: + fn = op_mappings[op] + if field_name == 'version': + value = semver_db.parse(value) + basic_conds.append([fn(getattr(models.Artifact, field_name), + value)]) else: conds = [models.ArtifactProperty.name == field_name] if key_name is not None: @@ -446,13 +449,16 @@ def _do_query_filters(filters): conds.extend( [models.ArtifactProperty.key_name.in_(key_name)]) if value is not None: - if op != 'in': + if op == 'in': + conds.extend([getattr(models.ArtifactProperty, + field_type + '_value').in_(value)]) + elif op == 'like': + conds.extend( + [models.ArtifactProperty.string_value.like(value)]) + else: fn = op_mappings[op] conds.extend([fn(getattr(models.ArtifactProperty, field_type + '_value'), value)]) - else: - conds.extend([getattr(models.ArtifactProperty, - field_type + '_value').in_(value)]) prop_conds.append(conds) diff --git a/glare/objects/meta/wrappers.py b/glare/objects/meta/wrappers.py index 59463a2..0069d1a 100644 --- a/glare/objects/meta/wrappers.py +++ b/glare/objects/meta/wrappers.py @@ -22,7 +22,8 @@ from glare.objects.meta import validators as val_lib FILTERS = ( FILTER_EQ, FILTER_NEQ, FILTER_IN, FILTER_GT, FILTER_GTE, FILTER_LT, - FILTER_LTE) = ('eq', 'neq', 'in', 'gt', 'gte', 'lt', 'lte') + FILTER_LTE, FILTER_LIKE) = ('eq', 'neq', 'in', 'gt', 'gte', 'lt', 'lte', + 'like') DEFAULT_MAX_BLOB_SIZE = 10485760 # 10 Megabytes DEFAULT_MAX_FOLDER_SIZE = 2673868800 # 2550 Megabytes @@ -83,15 +84,17 @@ class Field(object): self.sortable = sortable try: - default_ops = self.get_allowed_filter_ops(self.element_type) + default_ops = self.get_default_filter_ops(self.element_type) + allowed_ops = self.get_allowed_filter_ops(self.element_type) except AttributeError: - default_ops = self.get_allowed_filter_ops(field_class) + default_ops = self.get_default_filter_ops(field_class) + allowed_ops = self.get_allowed_filter_ops(self.field_class) if filter_ops is None: self.filter_ops = default_ops else: for op in filter_ops: - if op not in default_ops: + if op not in allowed_ops: raise exc.IncorrectArtifactType( "Incorrect filter operator '%s'. " "Only %s are allowed" % (op, ', '.join(default_ops))) @@ -104,7 +107,7 @@ class Field(object): @staticmethod def get_allowed_filter_ops(field): if field in (fields.StringField, fields.String): - return [FILTER_EQ, FILTER_NEQ, FILTER_IN] + return [FILTER_EQ, FILTER_NEQ, FILTER_IN, FILTER_LIKE] elif field in (fields.IntegerField, fields.Integer, fields.FloatField, fields.Float, glare_fields.VersionField): return FILTERS @@ -116,6 +119,22 @@ class Field(object): elif field is fields.DateTimeField: return [FILTER_LT, FILTER_GT] + @staticmethod + def get_default_filter_ops(field): + if field in (fields.StringField, fields.String): + return [FILTER_EQ, FILTER_NEQ, FILTER_IN] + elif field in (fields.IntegerField, fields.Integer, fields.FloatField, + fields.Float, glare_fields.VersionField): + return [FILTER_EQ, FILTER_NEQ, FILTER_IN, FILTER_GT, FILTER_GTE, + FILTER_LT, FILTER_LTE] + elif field in (fields.FlexibleBooleanField, fields.FlexibleBoolean, + glare_fields.Link, glare_fields.LinkFieldType): + return [FILTER_EQ, FILTER_NEQ] + elif field in (glare_fields.BlobField, glare_fields.BlobFieldType): + return [] + elif field is fields.DateTimeField: + return [FILTER_LT, FILTER_GT] + def get_default_validators(self): default = [] if issubclass(self.field_class, fields.StringField): diff --git a/glare/tests/functional/test_schemas.py b/glare/tests/functional/test_schemas.py index 0db35ff..3eea734 100644 --- a/glare/tests/functional/test_schemas.py +++ b/glare/tests/functional/test_schemas.py @@ -422,7 +422,8 @@ fixtures = { u'required_on_activate': False, u'type': [u'object', u'null']}, - u'str1': {u'filter_ops': [u'eq', + u'str1': {u'filter_ops': [u'like', + u'eq', u'neq', u'in'], u'glareType': u'String', diff --git a/glare/tests/sample_artifact.py b/glare/tests/sample_artifact.py index 33cbd38..28bdb15 100644 --- a/glare/tests/sample_artifact.py +++ b/glare/tests/sample_artifact.py @@ -61,6 +61,10 @@ class SampleArtifact(base_artifact.BaseArtifact): sortable=True, required_on_activate=False), 'str1': Field(fields.StringField, + filter_ops=(wrappers.FILTER_LIKE, + wrappers.FILTER_EQ, + wrappers.FILTER_NEQ, + wrappers.FILTER_IN), sortable=True, required_on_activate=False), 'list_of_str': List(fields.String, diff --git a/glare/tests/unit/api/test_list.py b/glare/tests/unit/api/test_list.py index a5f087a..fda93e7 100644 --- a/glare/tests/unit/api/test_list.py +++ b/glare/tests/unit/api/test_list.py @@ -584,3 +584,30 @@ class TestArtifactList(base.BaseTestArtifactAPI): exc.BadRequest, self.controller.list, self.req, 'sample_artifact', [], sort=[(name, sort_dir)]) + + def test_list_like_filter(self): + val = {'name': '0', 'str1': 'banana'} + art0 = self.controller.create(self.req, 'sample_artifact', val) + val = {'name': '1', 'str1': 'nan'} + art1 = self.controller.create(self.req, 'sample_artifact', val) + val = {'name': '2', 'str1': 'anab'} + self.controller.create(self.req, 'sample_artifact', val) + + filters = [('str1', 'like:%banana%')] + res = self.controller.list(self.req, 'sample_artifact', filters) + self.assertEqual(1, len(res['artifacts'])) + self.assertIn(art0, res['artifacts']) + + filters = [('str1', 'like:%nan%')] + res = self.controller.list(self.req, 'sample_artifact', filters) + self.assertEqual(2, len(res['artifacts'])) + self.assertIn(art0, res['artifacts']) + self.assertIn(art1, res['artifacts']) + + filters = [('str1', 'like:%na%')] + res = self.controller.list(self.req, 'sample_artifact', filters) + self.assertEqual(3, len(res['artifacts'])) + + filters = [('str1', 'like:%haha%')] + res = self.controller.list(self.req, 'sample_artifact', filters) + self.assertEqual(0, len(res['artifacts']))