diff --git a/os_traits/__init__.py b/os_traits/__init__.py index b5b6f4f..bb15a1f 100644 --- a/os_traits/__init__.py +++ b/os_traits/__init__.py @@ -82,14 +82,15 @@ def get_traits(prefix=None): ] -def check_traits(traits): +def check_traits(traits, prefix=None): """Returns a tuple of two trait string sets, the first set contains valid traits, and the second contains others. :param traits: An iterable contains trait strings. + :param prefix: Optional string prefix to filter by. e.g. 'HW_' """ trait_set = set(traits) - valid_trait_set = set(get_traits()) + valid_trait_set = set(get_traits(prefix)) valid_traits = trait_set & valid_trait_set diff --git a/os_traits/tests/test_os_traits.py b/os_traits/tests/test_os_traits.py index 2842fb6..66f1d3b 100644 --- a/os_traits/tests/test_os_traits.py +++ b/os_traits/tests/test_os_traits.py @@ -57,6 +57,18 @@ class TestSymbols(base.TestCase): self.assertEqual((traits, not_traits), ot.check_traits(check_traits)) + def test_check_traits_filter_by_prefix(self): + hw_trait = "HW_CPU_X86_SSE42" + storage_trait = "STORAGE_DISK_SSD" + + check_traits = [hw_trait, storage_trait] + self.assertEqual((set([hw_trait]), set([storage_trait])), + ot.check_traits(check_traits, "HW")) + self.assertEqual((set([storage_trait]), set([hw_trait])), + ot.check_traits(check_traits, "STORAGE")) + self.assertEqual((set(), set([hw_trait, storage_trait])), + ot.check_traits(check_traits, "MISC")) + def test_is_custom(self): self.assertTrue(ot.is_custom('CUSTOM_FOO')) self.assertFalse(ot.is_custom('HW_CPU_X86_SSE42'))