diff --git a/cloudbaseinit/conf/default.py b/cloudbaseinit/conf/default.py index 55850d89..e2f416e3 100644 --- a/cloudbaseinit/conf/default.py +++ b/cloudbaseinit/conf/default.py @@ -189,6 +189,9 @@ class GlobalOptions(conf_base.Options): 'cloud_config_plugins', default=[], help='List which contains the name of the cloud config ' 'plugins ordered by priority.'), + cfg.BoolOpt( + 'rdp_set_keepalive', default=True, + help='Sets the RDP KeepAlive policy'), ] self._cli_options = [ diff --git a/cloudbaseinit/metadata/services/base.py b/cloudbaseinit/metadata/services/base.py index 734b8d10..601a6d04 100644 --- a/cloudbaseinit/metadata/services/base.py +++ b/cloudbaseinit/metadata/services/base.py @@ -183,6 +183,13 @@ class BaseMetadataService(object): """ return False + @property + def can_post_rdp_cert_thumbprint(self): + return False + + def post_rdp_cert_thumbprint(self, thumbprint): + pass + class BaseHTTPMetadataService(BaseMetadataService): diff --git a/cloudbaseinit/plugins/windows/rdp.py b/cloudbaseinit/plugins/windows/rdp.py new file mode 100644 index 00000000..a1139ae9 --- /dev/null +++ b/cloudbaseinit/plugins/windows/rdp.py @@ -0,0 +1,50 @@ +# Copyright (c) 2017 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo_log import log as oslo_logging + +from cloudbaseinit import conf as cloudbaseinit_conf +from cloudbaseinit.plugins.common import base +from cloudbaseinit.utils.windows import rdp + +CONF = cloudbaseinit_conf.CONF +LOG = oslo_logging.getLogger(__name__) + + +class RDPSettingsPlugin(base.BasePlugin): + + def execute(self, service, shared_data): + LOG.info("Setting RDP KeepAlive: %s", CONF.rdp_set_keepalive) + rdp.set_rdp_keepalive(CONF.rdp_set_keepalive) + return base.PLUGIN_EXECUTION_DONE, False + + def get_os_requirements(self): + return 'win32', (5, 2) + + +class RDPPostCertificateThumbprintPlugin(base.BasePlugin): + + def execute(self, service, shared_data): + if not service.can_post_rdp_cert_thumbprint: + LOG.info("The service does not provide the capability to post " + "the RDP certificate thumbprint") + else: + cert_thumb = rdp.get_rdp_certificate_thumbprint() + LOG.info("Posting the RDP certificate thumbprint: %s", cert_thumb) + service.post_rdp_cert_thumbprint(cert_thumb) + + return base.PLUGIN_EXECUTION_DONE, False + + def get_os_requirements(self): + return 'win32', (5, 2) diff --git a/cloudbaseinit/tests/plugins/windows/test_rdp.py b/cloudbaseinit/tests/plugins/windows/test_rdp.py new file mode 100644 index 00000000..27ff8664 --- /dev/null +++ b/cloudbaseinit/tests/plugins/windows/test_rdp.py @@ -0,0 +1,103 @@ +# Copyright (c) 2017 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import importlib +import unittest + +try: + import unittest.mock as mock +except ImportError: + import mock + +from cloudbaseinit import conf as cloudbaseinit_conf +from cloudbaseinit.plugins.common import base +from cloudbaseinit.tests import testutils + +CONF = cloudbaseinit_conf.CONF +MODPATH = "cloudbaseinit.plugins.windows.rdp" + + +class RDPPluginTest(unittest.TestCase): + + def setUp(self): + self.mock_wmi = mock.MagicMock() + self._moves_mock = mock.MagicMock() + patcher = mock.patch.dict( + "sys.modules", + { + "wmi": self.mock_wmi, + "six.moves": self._moves_mock + } + ) + patcher.start() + self.addCleanup(patcher.stop) + rdp = importlib.import_module( + "cloudbaseinit.plugins.windows.rdp") + self.rdp_settings = rdp.RDPSettingsPlugin() + self.rdp_post = rdp.RDPPostCertificateThumbprintPlugin() + self.snatcher = testutils.LogSnatcher(MODPATH) + + @mock.patch("cloudbaseinit.utils.windows.rdp." + "get_rdp_certificate_thumbprint") + def _test_execute_post(self, mock_get_rdp, mock_service=None, + mock_shared_data=None): + expected_res = (base.PLUGIN_EXECUTION_DONE, False) + expected_logs = [] + mock_get_rdp.return_value = mock.sentinel.cert + with self.snatcher: + res = self.rdp_post.execute(mock_service, mock_shared_data) + if not mock_service.can_post_rdp_cert_thumbprint: + expected_logs.append("The service does not provide the capability" + " to post the RDP certificate thumbprint") + else: + expected_logs.append("Posting the RDP certificate thumbprint: %s" + % mock.sentinel.cert) + mock_get_rdp.assert_called_once_with() + mock_service.post_rdp_cert_thumbprint.assert_called_once_with( + mock.sentinel.cert) + + self.assertEqual(res, expected_res) + self.assertEqual(self.snatcher.output, expected_logs) + + @mock.patch("cloudbaseinit.utils.windows.rdp.set_rdp_keepalive") + def _test_execute_settings(self, mock_set_rdp, mock_service=None, + mock_shared_data=None): + expected_res = (base.PLUGIN_EXECUTION_DONE, False) + expected_logs = ["Setting RDP KeepAlive: %s" % CONF.rdp_set_keepalive] + with self.snatcher: + res = self.rdp_settings.execute(mock_service, mock_shared_data) + self.assertEqual(res, expected_res) + self.assertEqual(self.snatcher.output, expected_logs) + mock_set_rdp.assert_called_once_with(CONF.rdp_set_keepalive) + + def test_execute_set_rdp(self): + mock_service = mock.Mock() + self._test_execute_settings(mock_service=mock_service) + + def test_execute_can_not_post(self): + mock_service = mock.Mock() + mock_service.can_post_rdp_cert_thumbprint = False + self._test_execute_post(mock_service=mock_service) + + def test_execute_can_post(self): + mock_service = mock.Mock() + mock_service.can_post_rdp_cert_thumbprint = True + self._test_execute_post(mock_service=mock_service) + + def test_get_os_requirements(self): + expected_res = ('win32', (5, 2)) + res_settings = self.rdp_settings.get_os_requirements() + res_post = self.rdp_post.get_os_requirements() + for res in (res_settings, res_post): + self.assertEqual(res, expected_res) diff --git a/cloudbaseinit/tests/utils/windows/test_rdp.py b/cloudbaseinit/tests/utils/windows/test_rdp.py new file mode 100644 index 00000000..7be255b3 --- /dev/null +++ b/cloudbaseinit/tests/utils/windows/test_rdp.py @@ -0,0 +1,85 @@ +# Copyright (c) 2017 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import importlib +import unittest + +try: + import unittest.mock as mock +except ImportError: + import mock + +from cloudbaseinit import exception +from cloudbaseinit.tests import testutils + + +MODPATH = "cloudbaseinit.utils.windows.rdp" + + +class RdpTest(unittest.TestCase): + + def setUp(self): + self._wmi_mock = mock.MagicMock() + self._moves_mock = mock.MagicMock() + self._module_patcher = mock.patch.dict( + 'sys.modules', { + 'wmi': self._wmi_mock, + 'six.moves': self._moves_mock}) + self._winreg_mock = self._moves_mock.winreg + self.snatcher = testutils.LogSnatcher(MODPATH) + self._module_patcher.start() + self.rdp = importlib.import_module(MODPATH) + + def tearDown(self): + self._module_patcher.stop() + + def _test_get_rdp_certificate_thumbprint(self, mock_cert=None): + + conn = self._wmi_mock.WMI + mock_win32ts = mock.Mock() + conn.return_value = mock_win32ts + mock_win32ts.Win32_TSGeneralSetting.return_value = mock_cert + if not mock_cert: + self.assertRaises(exception.ItemNotFoundException, + self.rdp.get_rdp_certificate_thumbprint) + else: + res = self.rdp.get_rdp_certificate_thumbprint() + self.assertEqual(res, mock.sentinel.cert) + mock_win32ts.Win32_TSGeneralSetting.assert_called_once_with() + conn.assert_called_once_with(moniker='//./root/cimv2/TerminalServices') + + def test_get_rdp_certificate_thumbprint_no_cert(self): + self._test_get_rdp_certificate_thumbprint() + + def test_get_rdp_certificate_thumbprint(self): + mock_c = mock.MagicMock() + mock_c.SSLCertificateSHA1Hash = mock.sentinel.cert + mock_cert = mock.MagicMock() + mock_cert.__getitem__.return_value = mock_c + self._test_get_rdp_certificate_thumbprint(mock_cert=mock_cert) + + def test_set_rdp_keepalive(self): + enable_value = True + expected_logs = [ + "Setting RDP KeepAliveEnabled: %s" % enable_value, + "Setting RDP keepAliveInterval (minutes): %s" % 1] + with self.snatcher: + self.rdp.set_rdp_keepalive(enable_value) + self.assertEqual(self.snatcher.output, expected_logs) + self._winreg_mock.OpenKey.assert_called_once_with( + self._winreg_mock.HKEY_LOCAL_MACHINE, + 'SOFTWARE\\Policies\\Microsoft\\' + 'Windows NT\\Terminal Services', + 0, self._winreg_mock.KEY_ALL_ACCESS) + self.assertEqual(self._winreg_mock.SetValueEx.call_count, 2) diff --git a/cloudbaseinit/utils/windows/rdp.py b/cloudbaseinit/utils/windows/rdp.py new file mode 100644 index 00000000..ff2069f6 --- /dev/null +++ b/cloudbaseinit/utils/windows/rdp.py @@ -0,0 +1,43 @@ +# Copyright (c) 2017 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import wmi + +from oslo_log import log as oslo_logging +from six.moves import winreg + +from cloudbaseinit import exception + +LOG = oslo_logging.getLogger(__name__) + + +def get_rdp_certificate_thumbprint(): + conn = wmi.WMI(moniker='//./root/cimv2/TerminalServices') + tsSettings = conn.Win32_TSGeneralSetting() + if not tsSettings: + raise exception.ItemNotFoundException("No RDP certificate found") + return tsSettings[0].SSLCertificateSHA1Hash + + +def set_rdp_keepalive(enable, interval=1): + with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, + 'SOFTWARE\\Policies\\Microsoft\\' + 'Windows NT\\Terminal Services', + 0, winreg.KEY_ALL_ACCESS) as key: + LOG.debug("Setting RDP KeepAliveEnabled: %s", enable) + winreg.SetValueEx( + key, 'KeepAliveEnable', 0, winreg.REG_DWORD, 1 if enable else 0) + LOG.debug("Setting RDP keepAliveInterval (minutes): %s", interval) + winreg.SetValueEx( + key, 'keepAliveInterval', 0, winreg.REG_DWORD, interval)