From 43275bb1fb1564c7c5f20c97656298531f0cc778 Mon Sep 17 00:00:00 2001 From: Charles Neill Date: Wed, 10 Aug 2016 18:30:26 -0500 Subject: [PATCH] Fixes a bug in "excluded tests" When specifying multiple test types to exclude, the "get_tests" method in runner would only honor the first exclusion even if more than one was specified. This patch allows multiple exclusion flags to be honored, and adds unit tests to ensure it functions as expected. Closes-Bug: #1612338 Change-Id: Ibbbc0ae27a7c5f2ce8e0f3e68f9b04aef72b8cdf --- syntribos/runner.py | 18 ++++++++++-------- tests/unit/test_runner.py | 34 ++++++++++++++++++++++++++++++---- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/syntribos/runner.py b/syntribos/runner.py index 79064fcb..c8b98378 100644 --- a/syntribos/runner.py +++ b/syntribos/runner.py @@ -70,14 +70,16 @@ class Runner(object): cls.load_modules(tests) test_types = test_types or [""] excluded_types = excluded_types or [""] - for k, v in sorted(syntribos.tests.base.test_table.iteritems()): - for e in excluded_types: - if e and e in k: - break - else: - for t in test_types: - if t in k: - yield k, v + items = sorted(syntribos.tests.base.test_table.iteritems()) + included = [] + # Only include tests allowed by value in -t params + for t in test_types: + included += [x for x in items if t in x[0]] + # Exclude any tests that meet the above but are excluded by -e params + for e in excluded_types: + if e: + included = [x for x in included if e not in x[0]] + return (i for i in included) @staticmethod def print_symbol(): diff --git a/tests/unit/test_runner.py b/tests/unit/test_runner.py index a292d9d7..9738f742 100644 --- a/tests/unit/test_runner.py +++ b/tests/unit/test_runner.py @@ -27,10 +27,8 @@ class RunnerUnittest(testtools.TestCase): def _compare_tests(self, expected, loaded): """Compare list of expected test names with those that were loaded.""" - loaded_test_names = [] - for name, test in loaded: - self.assertIn(name, expected) - loaded_test_names.append(name) + # loaded_test_names = [] + loaded_test_names = [x[0] for x in loaded] self.assertEqual(expected, loaded_test_names) def test_get_LDAP_tests(self): @@ -100,3 +98,31 @@ class RunnerUnittest(testtools.TestCase): res1 = self.r.get_log_file_name() res2 = self.r.get_log_file_name() self.assertEqual(res1, res2) + + def test_get_sql_tests_exclude_header(self): + """Check that we get the right SQL tests when "HEADER" is excluded.""" + expected = [ + "SQL_INJECTION_BODY", "SQL_INJECTION_PARAMS", "SQL_INJECTION_URL"] + loaded_tests = self.r.get_tests(["SQL"], ["HEADER"]) + self._compare_tests(expected, loaded_tests) + + def test_get_sql_tests_exclude_header_url(self): + """Check that we get the right SQL tests, excluding HEADER/URL.""" + expected = [ + "SQL_INJECTION_BODY", "SQL_INJECTION_PARAMS"] + loaded_tests = self.r.get_tests(["SQL"], ["HEADER", "URL"]) + self._compare_tests(expected, loaded_tests) + + def test_get_sql_tests_exclude_header_url_body(self): + """Check that we get the right SQL tests, excluding HEADER/URL/BODY.""" + expected = ["SQL_INJECTION_PARAMS"] + loaded_tests = self.r.get_tests(["SQL"], ["HEADER", "URL", "BODY"]) + self._compare_tests(expected, loaded_tests) + + def test_get_rce_sql_tests_exclude_url_body(self): + """Check that we get the right SQL tests, excluding HEADER/URL/BODY.""" + expected = [ + "SQL_INJECTION_HEADERS", "SQL_INJECTION_PARAMS", + "COMMAND_INJECTION_HEADERS", "COMMAND_INJECTION_PARAMS"] + loaded_tests = self.r.get_tests(["SQL", "COMMAND"], ["URL", "BODY"]) + self._compare_tests(expected, loaded_tests)