diff --git a/os_refresh_config/os_refresh_config.py b/os_refresh_config/os_refresh_config.py index 4704319..e50776a 100755 --- a/os_refresh_config/os_refresh_config.py +++ b/os_refresh_config/os_refresh_config.py @@ -18,10 +18,13 @@ import argparse import fcntl import logging import os +import signal import subprocess import sys import time +import psutil + OLD_BASE_DIR = '/opt/stack/os-config-refresh' DEFAULT_BASE_DIR = '/usr/libexec/os-refresh-config' @@ -55,6 +58,21 @@ PHASES = ['pre-configure', 'migration'] +def timeout(): + p = psutil.Process() + children = list(p.get_children(recursive=True)) + for child in children: + child.kill() + + +def exit(lock, statuscode=0): + signal.alarm(0) + if lock: + lock.truncate(0) + lock.close() + return statuscode + + def main(argv=sys.argv): parser = argparse.ArgumentParser( description="""Runs through all of the phases to ensure @@ -72,6 +90,10 @@ def main(argv=sys.argv): parser.add_argument('--lockfile', default='/var/run/os-refresh-config.lock', help='Lock file to prevent multiple running copies.') + parser.add_argument('--timeout', + type=int, + help='Seconds until the current run will be ' + 'terminated.') options = parser.parse_args(argv[1:]) if options.print_base: @@ -101,6 +123,15 @@ def main(argv=sys.argv): lock.truncate(0) lock.write("Locked by pid==%d at %s\n" % (os.getpid(), time.localtime())) + def timeout_handler(signum, frame): + log.error('Timeout reached: %ss. Sending SIGKILL to all children' % + options.timeout) + timeout() + + if options.timeout: + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(options.timeout) + for phase in PHASES: phase_dir = os.path.join(BASE_DIR, '%s.d' % phase) log.debug('Checking %s' % phase_dir) @@ -124,13 +155,11 @@ def main(argv=sys.argv): except OSError: pass log.error("Aborting...") - return 1 + return exit(lock, 1) else: log.debug('No dir for phase %s' % phase) - lock.truncate(0) - lock.close() - return 0 + return exit(lock) if __name__ == '__main__': diff --git a/os_refresh_config/tests/test_cmd.py b/os_refresh_config/tests/test_cmd.py index cfc7f04..49038a8 100644 --- a/os_refresh_config/tests/test_cmd.py +++ b/os_refresh_config/tests/test_cmd.py @@ -116,7 +116,6 @@ exit %(returncode)s self._write_script('pre-configure', '20-pre-second', 99) self._write_script('configure', '10-conf-first', 0) returncode, stdout, stderr = self._run_orc() - print(stderr) self.assertEqual('\n'.join([ '10-pre-first starting', '10-pre-first done', @@ -126,6 +125,22 @@ exit %(returncode)s ]), stdout) self.assertEqual(1, returncode) + def test_cmd_with_timeout(self): + self._write_script('pre-configure', '10-pre-first', 0, 5) + self._write_script('pre-configure', '20-pre-second', 0, 5) + self._write_script('configure', '10-conf-first', 0, 5) + + now = time.time() + returncode, stdout, stderr = self._run_orc('--timeout', '2', + '--log-level', 'DEBUG') + # check run time accounts for the 2 seconds timeout + self.assertTrue(time.time() - now >= 2.0) + self.assertEqual('\n'.join([ + '10-pre-first starting', + '', + ]), stdout) + self.assertEqual(1, returncode) + def test_debug(self): returncode, stdout, stderr = self._run_orc('--log-level', 'DEBUG') self.assertEqual('', stdout) diff --git a/requirements.txt b/requirements.txt index d3eb0ec..c44400e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ # process, which may cause wedges in the gate later. pbr>=1.6 # Apache-2.0 dib-utils # Apache-2.0 +psutil>=1.1.1,<2.0.0 # BSD