diff --git a/cloudbaseinit/plugins/common/fileexecutils.py b/cloudbaseinit/plugins/common/fileexecutils.py index 2ea605f8..bf8880b0 100644 --- a/cloudbaseinit/plugins/common/fileexecutils.py +++ b/cloudbaseinit/plugins/common/fileexecutils.py @@ -12,35 +12,23 @@ # License for the specific language governing permissions and limitations # under the License. -import os - from oslo_log import log as oslo_logging -from cloudbaseinit.plugins.common import execcmd - +from cloudbaseinit.plugins.common import userdatautils LOG = oslo_logging.getLogger(__name__) -FORMATS = { - "cmd": execcmd.Shell, - "exe": execcmd.Shell, - "sh": execcmd.Bash, - "py": execcmd.Python, - "ps1": execcmd.PowershellSysnative, -} - def exec_file(file_path): ret_val = 0 - ext = os.path.splitext(file_path)[1][1:].lower() - command = FORMATS.get(ext) + command = userdatautils.get_command_from_path(file_path) if not command: - # Unsupported - LOG.warning('Unsupported script file type: %s', ext) + # File format not provided or not recognized + LOG.debug('No valid extension or header found in the ' + 'userdata: %s' % file_path) return ret_val - try: - out, err, ret_val = command(file_path).execute() + out, err, ret_val = command.execute() except Exception as ex: LOG.warning('An error occurred during file execution: \'%s\'', ex) else: diff --git a/cloudbaseinit/plugins/common/userdata.py b/cloudbaseinit/plugins/common/userdata.py index a2f0e81a..23346ae6 100644 --- a/cloudbaseinit/plugins/common/userdata.py +++ b/cloudbaseinit/plugins/common/userdata.py @@ -83,11 +83,29 @@ class UserDataPlugin(base.BasePlugin): LOG.debug('User data content:\n%s', user_data_str) return email.message_from_string(user_data_str).walk() + @staticmethod + def _get_headers(user_data): + """Returns the header of the given user data. + + :param user_data: Represents the content of the user data. + :rtype: A string chunk containing the header or None. + .. note :: In case the content type is not valid, + None will be returned. + """ + content = encoding.get_as_string(user_data) + if content: + return content.split("\n\n")[0] + else: + raise exception.CloudbaseInitException("No header could be found." + "The user data content is " + "either invalid or empty.") + def _process_user_data(self, user_data): plugin_status = base.PLUGIN_EXECUTION_DONE reboot = False - LOG.debug("Processing userdata") - if user_data.startswith(b'Content-Type: multipart'): + headers = self._get_headers(user_data) + if 'Content-Type: multipart' in headers: + LOG.debug("Processing userdata") user_data_plugins = factory.load_plugins() user_handlers = {} diff --git a/cloudbaseinit/plugins/common/userdatautils.py b/cloudbaseinit/plugins/common/userdatautils.py index 847093fb..d0cf2bba 100644 --- a/cloudbaseinit/plugins/common/userdatautils.py +++ b/cloudbaseinit/plugins/common/userdatautils.py @@ -12,8 +12,8 @@ # License for the specific language governing permissions and limitations # under the License. - -import functools +import collections +import os import re from oslo_log import log as oslo_logging @@ -23,37 +23,69 @@ from cloudbaseinit.plugins.common import execcmd LOG = oslo_logging.getLogger(__name__) -# Avoid 80+ length by using a local variable, which -# is deleted afterwards. -_compile = functools.partial(re.compile, flags=re.I) -FORMATS = ( - (_compile(br'^rem\s+cmd\s'), execcmd.Shell), - (_compile(br'^#!\s*/usr/bin/env\s+python\s'), execcmd.Python), - (_compile(br'^#!'), execcmd.Bash), - (_compile(br'^#(ps1|ps1_sysnative)\s'), execcmd.PowershellSysnative), - (_compile(br'^#ps1_x86\s'), execcmd.Powershell), - (_compile(br''), execcmd.EC2Config), -) -del _compile +_Script = collections.namedtuple('Script', ['extension', 'script_type', + 'executor']) +_SCRIPTS = ( + _Script(extension='cmd', executor=execcmd.Shell, + script_type=re.compile(br'^rem\s+cmd\s')), + _Script(script_type=re.compile(br'^#!\s*/usr/bin/env\s+python\s'), + extension='py', executor=execcmd.Python), + _Script(extension='exe', script_type=None, executor=execcmd.Shell), + _Script(extension='sh', script_type=re.compile(br'^#!'), + executor=execcmd.Bash), + _Script(extension='ps1', executor=execcmd.PowershellSysnative, + script_type=re.compile(br'^#(ps1|ps1_sysnative)\s')), + _Script(extension=None, executor=execcmd.Powershell, + script_type=re.compile(br'^#ps1_x86\s')), + _Script(extension=None, executor=execcmd.EC2Config, + script_type=re.compile(br''))) -def _get_command(data): - # Get the command which should process the given data. - for pattern, command_class in FORMATS: - if pattern.search(data): - return command_class.from_data(data) +def _get_command(data, is_path=False): + """Returns a specific command executor if the data type is found. + + :param data: It can be either a file or content of user_data type. + :param is_path: Determines whether :data: is a file path or it + contains the user_data content. + :rtype: An `execcmd` command type or `None`. + .. note :: In case the data doesn't have a valid extension or + header, it will return `None`. + """ + if is_path: + extension = os.path.splitext(data)[1][1:].lower() + for script in _SCRIPTS: + if extension == script.extension: + return script.executor(data) + with open(data, 'rb') as file_handler: + file_handler.seek(0) + user_data = file_handler.read() + else: + user_data = data + + for script in _SCRIPTS: + if script.script_type and script.script_type.search(user_data): + return script.executor.from_data(user_data) + return None + + +def get_command(data): + return _get_command(data) + + +def get_command_from_path(path): + return _get_command(path, is_path=True) def execute_user_data_script(user_data): ret_val = 0 out = err = None - command = _get_command(user_data) + command = get_command(user_data) if not command: LOG.warning('Unsupported user_data format') return ret_val try: - out, err, ret_val = command() + out, err, ret_val = command.execute() except Exception as exc: LOG.warning('An error occurred during user_data execution: \'%s\'', exc) diff --git a/cloudbaseinit/tests/plugins/common/test_fileexecutils.py b/cloudbaseinit/tests/plugins/common/test_fileexecutils.py index ee9cc241..668c539f 100644 --- a/cloudbaseinit/tests/plugins/common/test_fileexecutils.py +++ b/cloudbaseinit/tests/plugins/common/test_fileexecutils.py @@ -19,7 +19,6 @@ try: except ImportError: import mock -from cloudbaseinit.plugins.common import execcmd from cloudbaseinit.plugins.common import fileexecutils from cloudbaseinit.tests import testutils @@ -27,27 +26,25 @@ from cloudbaseinit.tests import testutils @mock.patch('cloudbaseinit.osutils.factory.get_os_utils') class TestFileExecutilsPlugin(unittest.TestCase): - def test_exec_file_no_executor(self, _): - with testutils.LogSnatcher('cloudbaseinit.plugins.common.' - 'fileexecutils') as snatcher: - retval = fileexecutils.exec_file("fake.fake") + @mock.patch('cloudbaseinit.plugins.common.userdatautils.' + 'get_command_from_path') + @mock.patch('cloudbaseinit.plugins.common.userdatautils.' + 'execute_user_data_script') + def test_exec_file_no_executor(self, mock_execute_user_data_script, + mock_get_command, _): + mock_get_command.return_value = None + with testutils.create_tempfile() as temp: + with mock.patch('cloudbaseinit.plugins.common.userdatautils' + '.open', create=True): + with testutils.LogSnatcher('cloudbaseinit.plugins.common.' + 'fileexecutils') as snatcher: + retval = fileexecutils.exec_file(temp) - expected_logging = ['Unsupported script file type: fake'] + expected_logging = ['No valid extension or header found' + ' in the userdata: %s' % temp] self.assertEqual(0, retval) self.assertEqual(expected_logging, snatcher.output) - def test_executors_mapping(self, _): - self.assertEqual(fileexecutils.FORMATS["cmd"], - execcmd.Shell) - self.assertEqual(fileexecutils.FORMATS["exe"], - execcmd.Shell) - self.assertEqual(fileexecutils.FORMATS["sh"], - execcmd.Bash) - self.assertEqual(fileexecutils.FORMATS["py"], - execcmd.Python) - self.assertEqual(fileexecutils.FORMATS["ps1"], - execcmd.PowershellSysnative) - @mock.patch('cloudbaseinit.plugins.common.execcmd.' 'BaseCommand.execute') def test_exec_file_fails(self, mock_execute, _): @@ -67,12 +64,7 @@ class TestFileExecutilsPlugin(unittest.TestCase): @mock.patch('cloudbaseinit.plugins.common.execcmd.' 'BaseCommand.execute') def test_exec_file_(self, mock_execute, _): - mock_execute.return_value = ( - mock.sentinel.out, - mock.sentinel.error, - 0, - ) - + mock_execute.return_value = (mock.sentinel.out, mock.sentinel.error, 0) retval = fileexecutils.exec_file("fake.py") mock_execute.assert_called_once_with() self.assertEqual(0, retval) diff --git a/cloudbaseinit/tests/plugins/common/test_userdata.py b/cloudbaseinit/tests/plugins/common/test_userdata.py index 31a2817e..eb87c003 100644 --- a/cloudbaseinit/tests/plugins/common/test_userdata.py +++ b/cloudbaseinit/tests/plugins/common/test_userdata.py @@ -138,6 +138,13 @@ class UserDataPluginTest(unittest.TestCase): self.assertEqual(response, mock_message_from_string().walk()) self.assertEqual(expected_logging, snatcher.output) + def test_get_header(self): + fake_data = "fake-user-data" + self.assertEqual(fake_data, self._userdata._get_headers(fake_data)) + fake_data = None + with self.assertRaises(exception.CloudbaseInitException): + self._userdata._get_headers(fake_data) + @mock.patch('cloudbaseinit.plugins.common.userdataplugins.factory.' 'load_plugins') @mock.patch('cloudbaseinit.plugins.common.userdata.UserDataPlugin' diff --git a/cloudbaseinit/tests/plugins/common/test_userdatautils.py b/cloudbaseinit/tests/plugins/common/test_userdatautils.py index e4d8101c..c1b5c1ff 100644 --- a/cloudbaseinit/tests/plugins/common/test_userdatautils.py +++ b/cloudbaseinit/tests/plugins/common/test_userdatautils.py @@ -44,12 +44,12 @@ class UserDataUtilsTest(unittest.TestCase): If a command was obtained, then a cleanup will be added in order to remove the underlying target path of the command. """ - command = userdatautils._get_command(data) + command = userdatautils.get_command(data) if command and not isinstance(command, execcmd.CommandExecutor): self.addCleanup(_safe_remove, command._target_path) return command - def test__get_command(self, _): + def test_get_command(self, _): command = self._get_command(b'rem cmd test') self.assertIsInstance(command, execcmd.Shell) @@ -83,10 +83,11 @@ class UserDataUtilsTest(unittest.TestCase): self.assertEqual(expected_logging, snatcher.output) @mock.patch('cloudbaseinit.plugins.common.userdatautils.' - '_get_command') + 'get_command') def test_execute_user_data_script_fails(self, mock_get_command, _): - mock_get_command.return_value.side_effect = ValueError - + mock_command = mock.Mock() + mock_command.execute.side_effect = ValueError + mock_get_command.return_value = mock_command with testutils.LogSnatcher('cloudbaseinit.plugins.common.' 'userdatautils') as snatcher: retval = userdatautils.execute_user_data_script( @@ -94,17 +95,18 @@ class UserDataUtilsTest(unittest.TestCase): expected_logging = [ "An error occurred during user_data execution: ''", - 'User_data script ended with return code: 0' - ] + 'User_data script ended with return code: 0'] self.assertEqual(0, retval) self.assertEqual(expected_logging, snatcher.output) @mock.patch('cloudbaseinit.plugins.common.userdatautils.' - '_get_command') + 'get_command') def test_execute_user_data_script(self, mock_get_command, _): - mock_get_command.return_value.return_value = ( + mock_command = mock.Mock() + mock_command.execute.return_value = ( mock.sentinel.output, mock.sentinel.error, -1 ) + mock_get_command.return_value = mock_command retval = userdatautils.execute_user_data_script( mock.sentinel.user_data) self.assertEqual(-1, retval)