diff --git a/cloudbaseinit/conf/azure.py b/cloudbaseinit/conf/azure.py new file mode 100644 index 00000000..f4faa6da --- /dev/null +++ b/cloudbaseinit/conf/azure.py @@ -0,0 +1,43 @@ +# Copyright 2017 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Config options available for the Azure metadata service.""" + +from oslo_config import cfg + +from cloudbaseinit.conf import base as conf_base + + +class AzureOptions(conf_base.Options): + + """Config options available for the Azure metadata service.""" + + def __init__(self, config): + super(AzureOptions, self).__init__(config, group="azure") + self._options = [ + cfg.StrOpt( + "transport_cert_store_name", + default="Windows Azure Environment", + help="Certificate store name for metadata certificates"), + ] + + def register(self): + """Register the current options to the global ConfigOpts object.""" + group = cfg.OptGroup(self.group_name, title='Azure Options') + self._config.register_group(group) + self._config.register_opts(self._options, group=group) + + def list(self): + """Return a list which contains all the available options.""" + return self._options diff --git a/cloudbaseinit/conf/factory.py b/cloudbaseinit/conf/factory.py index d573d1c2..c0800204 100644 --- a/cloudbaseinit/conf/factory.py +++ b/cloudbaseinit/conf/factory.py @@ -21,6 +21,7 @@ _OPT_PATHS = ( 'cloudbaseinit.conf.ec2.EC2Options', 'cloudbaseinit.conf.maas.MAASOptions', 'cloudbaseinit.conf.openstack.OpenStackOptions', + 'cloudbaseinit.conf.azure.AzureOptions', ) diff --git a/cloudbaseinit/metadata/services/azureservice.py b/cloudbaseinit/metadata/services/azureservice.py new file mode 100644 index 00000000..c33584ea --- /dev/null +++ b/cloudbaseinit/metadata/services/azureservice.py @@ -0,0 +1,467 @@ +# Copyright 2017 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import contextlib +import os +import socket +import time +from xml.etree import ElementTree + +from oslo_log import log as oslo_logging +import six +import untangle + +from cloudbaseinit import conf as cloudbaseinit_conf +from cloudbaseinit import constant +from cloudbaseinit import exception +from cloudbaseinit.metadata.services import base +from cloudbaseinit.osutils import factory as osutils_factory +from cloudbaseinit.utils import dhcp +from cloudbaseinit.utils import encoding +from cloudbaseinit.utils.windows import x509 + +CONF = cloudbaseinit_conf.CONF +LOG = oslo_logging.getLogger(__name__) + +WIRESERVER_DHCP_OPTION = 245 +WIRE_SERVER_VERSION = '2015-04-05' + +GOAL_STATE_STARTED = "Started" + +HEALTH_STATE_READY = "Ready" +HEALTH_STATE_NOT_READY = "NotReady" +HEALTH_SUBSTATE_PROVISIONING = "Provisioning" +HEALTH_SUBSTATE_PROVISIONING_FAILED = "ProvisioningFailed" + +ROLE_PROPERTY_CERT_THUMB = "CertificateThumbprint" + +OVF_ENV_DRIVE_TAG = "E6DA6616-8EC4-48E0-BE93-58CE6ACE3CFB.tag" +OVF_ENV_FILENAME = "ovf-env.xml" +CUSTOM_DATA_FILENAME = "CustomData.bin" +DATALOSS_WARNING_PATH = '$$\\OEM\\DATALOSS_WARNING_README.txt' + +DEFAULT_KMS_HOST = "kms.core.windows.net" + + +class AzureService(base.BaseHTTPMetadataService): + + def __init__(self): + super(AzureService, self).__init__(base_url=None) + self._enable_retry = True + self._goal_state = None + self._config_set_drive_path = None + self._ovf_env = None + self._headers = {"x-ms-guest-agent-name": "cloudbase-init"} + self._osutils = osutils_factory.get_os_utils() + + def _get_wire_server_endpoint_address(self): + total_time = 300 + poll_time = 5 + retries = total_time / poll_time + + while True: + try: + options = dhcp.get_dhcp_options() + endpoint = (options or {}).get(WIRESERVER_DHCP_OPTION) + if not endpoint: + raise exception.MetadaNotFoundException( + "Cannot find Azure WireServer endpoint address") + return socket.inet_ntoa(endpoint) + except Exception: + if not retries: + raise + time.sleep(poll_time) + retries -= 1 + + def _check_version_header(self): + if "x-ms-version" not in self._headers: + versions = self._get_versions() + if WIRE_SERVER_VERSION not in versions.Versions.Supported.Version: + raise exception.MetadaNotFoundException( + "Unsupported Azure WireServer version: %s" % + WIRE_SERVER_VERSION) + self._headers["x-ms-version"] = WIRE_SERVER_VERSION + + def _get_versions(self): + return self._wire_server_request("?comp=Versions") + + def _wire_server_request(self, path, data_xml=None, headers=None, + parse_xml=True): + if not self._base_url: + raise exception.CloudbaseInitException( + "Azure WireServer base url not set") + + all_headers = self._headers.copy() + if data_xml: + all_headers["Content-Type"] = "text/xml; charset=utf-8" + if headers: + all_headers.update(headers) + + data = self._exec_with_retry( + lambda: super(AzureService, self)._http_request( + path, data_xml, headers=all_headers)) + + if parse_xml: + return untangle.parse(six.StringIO(encoding.get_as_string(data))) + else: + return data + + @staticmethod + def _encode_xml(xml_root): + bio = six.BytesIO() + ElementTree.ElementTree(xml_root).write( + bio, encoding='utf-8', xml_declaration=True) + return bio.getvalue() + + def _get_health_report_xml(self, state, sub_status=None, description=None): + xml_root = ElementTree.Element('Health') + xml_goal_state_incarnation = ElementTree.SubElement( + xml_root, 'GoalStateIncarnation') + xml_goal_state_incarnation.text = str(self._get_incarnation()) + xml_container = ElementTree.SubElement(xml_root, 'Container') + xml_container_id = ElementTree.SubElement(xml_container, 'ContainerId') + xml_container_id.text = self._get_container_id() + xml_role_instance_list = ElementTree.SubElement( + xml_container, 'RoleInstanceList') + xml_role = ElementTree.SubElement(xml_role_instance_list, 'Role') + xml_role_instance_id = ElementTree.SubElement(xml_role, 'InstanceId') + xml_role_instance_id.text = self._get_role_instance_id() + xml_health = ElementTree.SubElement(xml_role, 'Health') + xml_state = ElementTree.SubElement(xml_health, 'State') + xml_state.text = state + + if sub_status: + xml_details = ElementTree.SubElement(xml_health, 'Details') + xml_sub_status = ElementTree.SubElement(xml_details, 'SubStatus') + xml_sub_status.text = sub_status + xml_description = ElementTree.SubElement( + xml_details, 'Description') + xml_description.text = description + + return self._encode_xml(xml_root) + + def _get_role_properties_xml(self, properties): + xml_root = ElementTree.Element('RoleProperties') + xml_container = ElementTree.SubElement(xml_root, 'Container') + xml_container_id = ElementTree.SubElement(xml_container, 'ContainerId') + xml_container_id.text = self._get_container_id() + xml_role_instances = ElementTree.SubElement( + xml_container, 'RoleInstances') + xml_role_instance = ElementTree.SubElement( + xml_role_instances, 'RoleInstance') + xml_role_instance_id = ElementTree.SubElement( + xml_role_instance, 'Id') + xml_role_instance_id.text = self._get_role_instance_id() + xml_role_properties = ElementTree.SubElement( + xml_role_instance, 'Properties') + + for name, value in properties.items(): + ElementTree.SubElement( + xml_role_properties, 'Property', name=name, value=value) + + return self._encode_xml(xml_root) + + def _get_goal_state(self, force_update=False): + if not self._goal_state or force_update: + self._goal_state = self._wire_server_request( + "machine?comp=goalstate").GoalState + + expected_state = self._goal_state.Machine.ExpectedState + if expected_state != GOAL_STATE_STARTED: + raise exception.CloudbaseInitException( + "Invalid machine expected state: %s" % expected_state) + + return self._goal_state + + def _get_incarnation(self): + goal_state = self._get_goal_state() + return goal_state.Incarnation.cdata + + def _get_container_id(self): + goal_state = self._get_goal_state() + return goal_state.Container.ContainerId.cdata + + def _get_role_instance_config(self): + goal_state = self._get_goal_state() + role_instance = goal_state.Container.RoleInstanceList.RoleInstance + return role_instance.Configuration + + def _get_role_instance_id(self): + goal_state = self._get_goal_state() + role_instance = goal_state.Container.RoleInstanceList.RoleInstance + return role_instance.InstanceId.cdata + + def _post_health_status(self, state, sub_status=None, description=None): + health_report_xml = self._get_health_report_xml( + state, sub_status, description) + LOG.debug("Health data: %s", health_report_xml) + self._wire_server_request( + "machine?comp=health", health_report_xml, parse_xml=False) + + def provisioning_started(self): + self._post_health_status( + HEALTH_STATE_NOT_READY, HEALTH_SUBSTATE_PROVISIONING, + "Cloudbase-Init is preparing your computer for first use...") + + def provisioning_completed(self): + self._post_health_status(HEALTH_STATE_READY) + + def provisioning_failed(self): + self._post_health_status( + HEALTH_STATE_NOT_READY, HEALTH_SUBSTATE_PROVISIONING_FAILED, + "Provisioning failed") + + def _post_role_properties(self, properties): + role_properties_xml = self._get_role_properties_xml(properties) + LOG.debug("Role properties data: %s", role_properties_xml) + self._wire_server_request( + "machine?comp=roleProperties", role_properties_xml, + parse_xml=False) + + @property + def can_post_rdp_cert_thumbprint(self): + return True + + def post_rdp_cert_thumbprint(self, thumbprint): + properties = {ROLE_PROPERTY_CERT_THUMB: thumbprint} + self._post_role_properties(properties) + + def _get_hosting_environment(self): + config = self._get_role_instance_config() + return self._wire_server_request(config.HostingEnvironmentConfig.cdata) + + def _get_shared_config(self): + config = self._get_role_instance_config() + return self._wire_server_request(config.SharedConfig.cdata) + + def _get_extensions_config(self): + config = self._get_role_instance_config() + return self._wire_server_request(config.ExtensionsConfig.cdata) + + def _get_full_config(self): + config = self._get_role_instance_config() + return self._wire_server_request(config.FullConfig.cdata) + + @contextlib.contextmanager + def _create_transport_cert(self, cert_mgr): + x509_thumbprint, x509_cert = cert_mgr.create_self_signed_cert( + "CN=Cloudbase-Init AzureService Transport", machine_keyset=True, + store_name=CONF.azure.transport_cert_store_name) + + try: + yield (x509_thumbprint, x509_cert) + finally: + cert_mgr.delete_certificate_from_store( + x509_thumbprint, machine_keyset=True, + store_name=CONF.azure.transport_cert_store_name) + + def _get_encoded_cert(self, cert_url, transport_cert): + cert_config = self._wire_server_request( + cert_url, headers={"x-ms-guest-agent-public-x509-cert": + transport_cert.replace("\r\n", "")}) + + cert_data = cert_config.CertificateFile.Data.cdata + cert_format = cert_config.CertificateFile.Format.cdata + return cert_data, cert_format + + def get_server_certs(self): + def _get_store_location(store_location): + if store_location == u"System": + return constant.CERT_LOCATION_LOCAL_MACHINE + else: + return store_location + + certs_info = [] + config = self._get_role_instance_config() + if not hasattr(config, 'Certificates'): + return certs_info + + cert_mgr = x509.CryptoAPICertManager() + with self._create_transport_cert(cert_mgr) as ( + transport_cert_thumbprint, transport_cert): + + cert_url = config.Certificates.cdata + cert_data, cert_format = self._get_encoded_cert( + cert_url, transport_cert) + pfx_data = cert_mgr.decode_pkcs7_base64_blob( + cert_data, transport_cert_thumbprint, machine_keyset=True, + store_name=CONF.azure.transport_cert_store_name) + + host_env = self._get_hosting_environment() + host_env_config = host_env.HostingEnvironmentConfig + for cert in host_env_config.StoredCertificates.StoredCertificate: + certs_info.append({ + "store_name": cert["storeName"], + "store_location": _get_store_location( + cert["configurationLevel"]), + "certificate_id": cert["certificateId"], + "name": cert["name"], + "pfx_data": pfx_data, + }) + return certs_info + + def get_instance_id(self): + return self._get_role_instance_id() + + def _get_config_set_drive_path(self): + if not self._config_set_drive_path: + base_paths = self._osutils.get_logical_drives() + for base_path in base_paths: + tag_path = os.path.join(base_path, OVF_ENV_DRIVE_TAG) + if os.path.exists(tag_path): + self._config_set_drive_path = base_path + + if not self._config_set_drive_path: + raise exception.ItemNotFoundException( + "No drive containing file %s could be found" % + OVF_ENV_DRIVE_TAG) + return self._config_set_drive_path + + def _get_ovf_env_path(self): + base_path = self._get_config_set_drive_path() + ovf_env_path = os.path.join(base_path, OVF_ENV_FILENAME) + + if not os.path.exists(ovf_env_path): + raise exception.ItemNotFoundException( + "ovf-env path does not exist: %s" % ovf_env_path) + + LOG.debug("ovs-env path: %s", ovf_env_path) + return ovf_env_path + + def _get_ovf_env(self): + if not self._ovf_env: + ovf_env_path = self._get_ovf_env_path() + self._ovf_env = untangle.parse(ovf_env_path) + return self._ovf_env + + def get_admin_username(self): + ovf_env = self._get_ovf_env() + prov_section = ovf_env.Environment.wa_ProvisioningSection + win_prov_conf_set = prov_section.WindowsProvisioningConfigurationSet + return win_prov_conf_set.AdminUsername.cdata + + def get_admin_password(self): + ovf_env = self._get_ovf_env() + prov_section = ovf_env.Environment.wa_ProvisioningSection + win_prov_conf_set = prov_section.WindowsProvisioningConfigurationSet + return win_prov_conf_set.AdminPassword.cdata + + def get_host_name(self): + ovf_env = self._get_ovf_env() + prov_section = ovf_env.Environment.wa_ProvisioningSection + win_prov_conf_set = prov_section.WindowsProvisioningConfigurationSet + return win_prov_conf_set.ComputerName.cdata + + def get_enable_automatic_updates(self): + ovf_env = self._get_ovf_env() + prov_section = ovf_env.Environment.wa_ProvisioningSection + win_prov_conf_set = prov_section.WindowsProvisioningConfigurationSet + if hasattr(win_prov_conf_set, "EnableAutomaticUpdates"): + auto_updates = win_prov_conf_set.EnableAutomaticUpdates.cdata + return auto_updates.lower() == "true" + return False + + def get_winrm_listeners_configuration(self): + listeners_config = [] + ovf_env = self._get_ovf_env() + prov_section = ovf_env.Environment.wa_ProvisioningSection + win_prov_conf_set = prov_section.WindowsProvisioningConfigurationSet + if hasattr(win_prov_conf_set, "WinRM"): + for listener in win_prov_conf_set.WinRM.Listeners.Listener: + protocol = listener.Protocol.cdata + config = {"protocol": protocol} + if hasattr(listener, "CertificateThumbprint"): + cert_thumbprint = listener.CertificateThumbprint.cdata + config["certificate_thumbprint"] = cert_thumbprint + listeners_config.append(config) + return listeners_config + + def get_vm_agent_package_provisioning_data(self): + ovf_env = self._get_ovf_env() + plat_sett_section = ovf_env.Environment.wa_PlatformSettingsSection + plat_sett = plat_sett_section.PlatformSettings + prov_ga = False + ga_package_name = None + if hasattr(plat_sett, "ProvisionGuestAgent"): + prov_ga = plat_sett.ProvisionGuestAgent.cdata.lower() == "true" + if hasattr(plat_sett, "GuestAgentPackageName"): + ga_package_name = plat_sett.GuestAgentPackageName.cdata + return {"provision": prov_ga, + "package_name": ga_package_name} + + def get_kms_host(self): + ovf_env = self._get_ovf_env() + plat_sett_section = ovf_env.Environment.wa_PlatformSettingsSection + host = None + if hasattr(plat_sett_section.PlatformSettings, "KmsServerHostname"): + host = plat_sett_section.PlatformSettings.KmsServerHostname.cdata + return host or DEFAULT_KMS_HOST + + def get_use_avma_licensing(self): + ovf_env = self._get_ovf_env() + plat_sett_section = ovf_env.Environment.wa_PlatformSettingsSection + if hasattr(plat_sett_section.PlatformSettings, "UseAVMA"): + use_avma = plat_sett_section.PlatformSettings.UseAVMA.cdata + return use_avma.lower() == "true" + return False + + def _check_ovf_env_custom_data(self): + # If the custom data file is missing, ensure the configuration matches + ovf_env = self._get_ovf_env() + prov_section = ovf_env.Environment.wa_ProvisioningSection + win_prov_conf_set = prov_section.WindowsProvisioningConfigurationSet + if hasattr(win_prov_conf_set, "CustomData"): + return True + + def get_user_data(self): + try: + return self.get_content(CUSTOM_DATA_FILENAME) + except base.NotExistingMetadataException: + if self._check_ovf_env_custom_data(): + raise exception.ItemNotFoundException( + "Custom data configuration exists, but the custom data " + "file is not present") + raise + + def get_decoded_user_data(self): + # Don't decode to retain compability + return self.get_user_data() + + def get_content(self, name): + base_path = self._get_config_set_drive_path() + content_path = os.path.join(base_path, name) + if not os.path.exists(content_path): + raise base.NotExistingMetadataException() + with open(content_path, 'rb') as f: + return f.read() + + def get_ephemeral_disk_data_loss_warning(self): + return self.get_content(DATALOSS_WARNING_PATH) + + def load(self): + try: + wire_server_endpoint = self._get_wire_server_endpoint_address() + self._base_url = "http://%s" % wire_server_endpoint + except Exception: + LOG.debug("Azure WireServer endpoint not found") + return False + + try: + super(AzureService, self).load() + self._check_version_header() + self._get_ovf_env() + return True + except Exception as ex: + LOG.exception(ex) + return False diff --git a/cloudbaseinit/tests/metadata/services/test_azureservice.py b/cloudbaseinit/tests/metadata/services/test_azureservice.py new file mode 100644 index 00000000..57865196 --- /dev/null +++ b/cloudbaseinit/tests/metadata/services/test_azureservice.py @@ -0,0 +1,788 @@ +# Copyright 2017 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import importlib +import os +import unittest +try: + import unittest.mock as mock +except ImportError: + import mock + +from cloudbaseinit import conf as cloudbaseinit_conf +from cloudbaseinit import exception +from cloudbaseinit.tests import testutils +from cloudbaseinit.utils import encoding + +CONF = cloudbaseinit_conf.CONF +MODPATH = "cloudbaseinit.metadata.services.azureservice.AzureService" + + +class AzureServiceTest(unittest.TestCase): + + @mock.patch('cloudbaseinit.osutils.factory.get_os_utils') + def setUp(self, mock_osutils): + self._mock_osutils = mock_osutils + self._mock_untangle = mock.MagicMock() + self._mock_ctypes = mock.MagicMock() + self._mock_wintypes = mock.MagicMock() + self._moves_mock = mock.MagicMock() + + self._module_patcher = mock.patch.dict( + 'sys.modules', + {'untangle': self._mock_untangle, + 'ctypes': self._mock_ctypes, + 'ctypes.wintypes': self._mock_wintypes, + 'six.moves': self._moves_mock + }) + self._module_patcher.start() + self._azureservice_module = importlib.import_module( + 'cloudbaseinit.metadata.services.azureservice') + + self._azureservice = self._azureservice_module.AzureService() + self._logsnatcher = testutils.LogSnatcher( + 'cloudbaseinit.metadata.services.azureservice') + + def tearDown(self): + self._module_patcher.stop() + + @mock.patch('time.sleep') + @mock.patch('socket.inet_ntoa') + @mock.patch('cloudbaseinit.utils.dhcp.get_dhcp_options') + def _test_get_wire_server_endpoint_address(self, mock_dhcp, + mock_inet_ntoa, + mock_time_sleep, + dhcp_option=None): + mock_dhcp.return_value = dhcp_option + if not dhcp_option: + self.assertRaises(exception.MetadaNotFoundException, + (self._azureservice. + _get_wire_server_endpoint_address)) + else: + mock_inet_ntoa.return_value = mock.sentinel.endpoint + res = self._azureservice._get_wire_server_endpoint_address() + self.assertEqual(res, mock.sentinel.endpoint) + + def test_get_wire_server_endpoint_address_no_endpoint(self): + self._test_get_wire_server_endpoint_address() + + def test_get_wire_server_endpoint_address(self): + dhcp_option = { + self._azureservice_module.WIRESERVER_DHCP_OPTION: + 'mock.sentinel.endpoint'} + self._test_get_wire_server_endpoint_address(dhcp_option=dhcp_option) + + @mock.patch('cloudbaseinit.metadata.services.base.' + 'BaseHTTPMetadataService._http_request') + def _test_wire_server_request(self, + mock_http_request, mock_base_url=None, + path=None, data_xml=None, headers=None, + parse_xml=True): + self._azureservice._base_url = mock_base_url + if not mock_base_url: + self.assertRaises(exception.CloudbaseInitException, + self._azureservice._wire_server_request, path) + return + if headers and data_xml: + expected_headers = self._azureservice._headers.copy() + expected_headers["Content-Type"] = "text/xml; charset=utf-8" + expected_headers.update(headers) + self._azureservice._wire_server_request(path, data_xml, headers, + parse_xml) + mock_http_request.assert_called_once_with(path, data_xml, + headers=expected_headers) + return + mock_http_request.return_value = str(mock.sentinel.data) + res = self._azureservice._wire_server_request(path, data_xml, + headers, parse_xml) + self.assertEqual(mock_http_request.call_count, 1) + + if parse_xml: + self.assertEqual(self._mock_untangle.parse.call_count, 1) + self.assertEqual(res, self._mock_untangle.parse.return_value) + else: + self.assertEqual(res, str(mock.sentinel.data)) + + def test_wire_server_request_url_not_set(self): + self._test_wire_server_request() + + def test_wire_server_request_url_set_no_parse(self): + mock_base_url = "fake-url" + self._test_wire_server_request(mock_base_url=mock_base_url, + parse_xml=False) + + def test_wire_server_request_url_set_with_headers(self): + mock_base_url = "fake-url" + self._test_wire_server_request(mock_base_url=mock_base_url, + parse_xml=False, + headers={"fake-header": "fake-value"}, + data_xml="fake-data") + + def test_wire_server_request_parse_xml(self): + mock_base_url = "fake-url" + self._test_wire_server_request(mock_base_url=mock_base_url) + + def test_encode_xml(self): + fake_root_xml = self._azureservice_module.ElementTree.Element( + "faketag") + expected_encoded_xml = ("" + "\n").encode() + self.assertEqual(self._azureservice._encode_xml(fake_root_xml), + expected_encoded_xml) + + @mock.patch(MODPATH + "._get_role_instance_id") + @mock.patch(MODPATH + "._get_container_id") + @mock.patch(MODPATH + "._get_incarnation") + def test__get_health_report_xml(self, mock_get_incarnation, + mock_get_container_id, + mock_get_role_instance_id): + mock_state = 'FakeState' + mock_substatus = 'FakeStatus' + mock_description = 'FakeDescription' + mock_get_incarnation.return_value = "fake" + mock_get_container_id.return_value = "fakeid" + mock_get_role_instance_id.return_value = "fakeroleid" + res = self._azureservice._get_health_report_xml(mock_state, + mock_substatus, + mock_description) + + expected_result = "\n" \ + "{}" \ + "{}" \ + "{}" \ + "{}
{}" \ + "{}
" \ + "
" + self.assertEqual(encoding.get_as_string(res), + expected_result.format( + mock_get_incarnation.return_value, + mock_get_container_id.return_value, + mock_get_role_instance_id.return_value, + mock_state, + mock_substatus, + mock_description)) + + @mock.patch(MODPATH + "._wire_server_request") + def _test_get_goal_state(self, mock_wire_server_request, + goal_state=True, invalid_state=False): + mock_goalstate = mock.Mock() + mock_goalstate.GoalState = mock.Mock() + mock_goalstate.GoalState.Machine = mock.Mock() + mock_wire_server_request.return_value = mock_goalstate + if goal_state: + self._azureservice._goal_state = mock_goalstate + else: + self._azureservice._goal_state = False + if invalid_state: + mock_goalstate.GoalState.Machine.ExpectedState = \ + not self._azureservice_module.GOAL_STATE_STARTED + self.assertRaises(exception.CloudbaseInitException, + self._azureservice._get_goal_state) + else: + if not goal_state: + mock_goalstate.GoalState.Machine.ExpectedState = \ + self._azureservice_module.GOAL_STATE_STARTED + else: + self._azureservice._goal_state.Machine.ExpectedState = \ + self._azureservice_module.GOAL_STATE_STARTED + res = self._azureservice._get_goal_state() + self.assertEqual(res, mock_goalstate.GoalState) + + if not goal_state: + mock_wire_server_request.assert_called_once_with( + "machine?comp=goalstate") + + def test_get_goal_state_exception(self): + self._test_get_goal_state(invalid_state=True) + + def test_get_goal_state(self): + self._test_get_goal_state(goal_state=False) + + @mock.patch(MODPATH + "._get_goal_state") + def test__get_incarnation(self, mock_get_goal_state): + mock_goal_state = mock.Mock() + mock_get_goal_state.return_value = mock_goal_state + mock_goal_state.Incarnation.cdata = mock.sentinel.cdata + + res = self._azureservice._get_incarnation() + mock_get_goal_state.assert_called_once_with() + self.assertEqual(res, mock.sentinel.cdata) + + @mock.patch(MODPATH + "._get_goal_state") + def test__get_container_id(self, mock_get_goal_state): + mock_goal_state = mock.Mock() + mock_get_goal_state.return_value = mock_goal_state + mock_goal_state.Container.ContainerId.cdata = mock.sentinel.cdata + + res = self._azureservice._get_container_id() + mock_get_goal_state.assert_called_once_with() + self.assertEqual(res, mock.sentinel.cdata) + + @mock.patch(MODPATH + "._get_goal_state") + def test__get_role_instance_config(self, mock_get_goal_state): + mock_goal_state = mock.Mock() + mock_role = mock.Mock() + mock_get_goal_state.return_value = mock_goal_state + mock_goal_state.Container.RoleInstanceList.RoleInstance = mock_role + mock_role.Configuration = mock.sentinel.config_role + + res = self._azureservice._get_role_instance_config() + mock_get_goal_state.assert_called_once_with() + self.assertEqual(res, mock.sentinel.config_role) + + @mock.patch(MODPATH + "._get_goal_state") + def test__get_role_instance_id(self, mock_get_goal_state): + mock_goal_state = mock.Mock() + mock_role = mock.Mock() + mock_get_goal_state.return_value = mock_goal_state + mock_goal_state.Container.RoleInstanceList.RoleInstance = mock_role + mock_role.InstanceId.cdata = mock.sentinel.config_role + + res = self._azureservice._get_role_instance_id() + mock_get_goal_state.assert_called_once_with() + self.assertEqual(res, mock.sentinel.config_role) + + @mock.patch(MODPATH + "._wire_server_request") + @mock.patch(MODPATH + "._get_health_report_xml") + def test__post_health_status(self, mock_get_health_report_xml, + mock_wire_server_request): + mock_get_health_report_xml.return_value = mock.sentinel.report_xml + mock_state = mock.sentinel.state + expected_logging = ["Health data: %s" % mock.sentinel.report_xml] + with self._logsnatcher: + self._azureservice._post_health_status(state=mock_state) + self.assertEqual(self._logsnatcher.output, expected_logging) + mock_get_health_report_xml.assert_called_once_with(mock_state, + None, None) + mock_wire_server_request.assert_called_once_with( + "machine?comp=health", mock.sentinel.report_xml, parse_xml=False) + + @mock.patch(MODPATH + "._post_health_status") + def test_provisioning_started(self, mock_post_health_status): + self._azureservice.provisioning_started() + mock_post_health_status.assert_called_once_with( + self._azureservice_module.HEALTH_STATE_NOT_READY, + self._azureservice_module.HEALTH_SUBSTATE_PROVISIONING, + "Cloudbase-Init is preparing your computer for first use...") + + @mock.patch(MODPATH + "._post_health_status") + def test_provisioning_completed(self, mock_post_health_status): + self._azureservice.provisioning_completed() + mock_post_health_status.assert_called_once_with( + self._azureservice_module.HEALTH_STATE_READY) + + @mock.patch(MODPATH + "._post_health_status") + def test_provisioning_failed(self, mock_post_health_status): + self._azureservice.provisioning_failed() + mock_post_health_status.assert_called_once_with( + self._azureservice_module.HEALTH_STATE_NOT_READY, + self._azureservice_module.HEALTH_SUBSTATE_PROVISIONING_FAILED, + "Provisioning failed") + + @mock.patch(MODPATH + "._wire_server_request") + @mock.patch(MODPATH + "._get_role_properties_xml") + def test__post_role_properties(self, mock_get_role_properties_xml, + mock_wire_server_request): + mock_properties = mock.sentinel.properties + mock_get_role_properties_xml.return_value = mock_properties + expected_logging = ["Role properties data: %s" % mock_properties] + with self._logsnatcher: + self._azureservice._post_role_properties(mock_properties) + self.assertEqual(self._logsnatcher.output, expected_logging) + mock_get_role_properties_xml.assert_called_once_with(mock_properties) + mock_wire_server_request.assert_called_once_with( + "machine?comp=roleProperties", mock_properties, parse_xml=False) + + def test_can_post_rdp_cert_thumbprint(self): + self.assertTrue(self._azureservice.can_post_rdp_cert_thumbprint) + + @mock.patch(MODPATH + "._post_role_properties") + def test_post_rdp_cert_thumbprint(self, mock_post_role_properties): + mock_thumbprint = mock.sentinel.thumbprint + self._azureservice.post_rdp_cert_thumbprint(mock_thumbprint) + expected_props = { + self._azureservice_module.ROLE_PROPERTY_CERT_THUMB: + mock_thumbprint} + mock_post_role_properties.assert_called_once_with(expected_props) + + @mock.patch(MODPATH + "._wire_server_request") + @mock.patch(MODPATH + "._get_role_instance_config") + def test__get_hosting_environment(self, mock_get_role_instance_config, + mock_wire_server_request): + mock_config = mock.Mock() + mock_get_role_instance_config.return_value = mock_config + mock_config.HostingEnvironmentConfig.cdata = mock.sentinel.data + + self._azureservice._get_hosting_environment() + mock_get_role_instance_config.assert_called_once_with() + mock_wire_server_request.assert_called_once_with(mock.sentinel.data) + + @mock.patch(MODPATH + "._wire_server_request") + @mock.patch(MODPATH + "._get_role_instance_config") + def test__get_shared_config(self, mock_get_role_instance_config, + mock_wire_server_request): + mock_config = mock.Mock() + mock_get_role_instance_config.return_value = mock_config + mock_config.SharedConfig.cdata = mock.sentinel.data + + self._azureservice._get_shared_config() + mock_get_role_instance_config.assert_called_once_with() + mock_wire_server_request.assert_called_once_with(mock.sentinel.data) + + @mock.patch(MODPATH + "._wire_server_request") + @mock.patch(MODPATH + "._get_role_instance_config") + def test__get_extensions_config(self, mock_get_role_instance_config, + mock_wire_server_request): + mock_config = mock.Mock() + mock_get_role_instance_config.return_value = mock_config + mock_config.ExtensionsConfig.cdata = mock.sentinel.data + + self._azureservice._get_extensions_config() + mock_get_role_instance_config.assert_called_once_with() + mock_wire_server_request.assert_called_once_with(mock.sentinel.data) + + @mock.patch(MODPATH + "._wire_server_request") + @mock.patch(MODPATH + "._get_role_instance_config") + def test__get_full_config(self, mock_get_role_instance_config, + mock_wire_server_request): + mock_config = mock.Mock() + mock_get_role_instance_config.return_value = mock_config + mock_config.FullConfig.cdata = mock.sentinel.data + + self._azureservice._get_full_config() + mock_get_role_instance_config.assert_called_once_with() + mock_wire_server_request.assert_called_once_with(mock.sentinel.data) + + def test__create_transport_cert(self): + mock_cert_mgr = mock.Mock() + expected_certs = (mock.sentinel.thumbprint, mock.sentinel.cert) + + mock_cert_mgr.create_self_signed_cert.return_value = ( + mock.sentinel.thumbprint, mock.sentinel.cert) + with testutils.ConfPatcher(key='transport_cert_store_name', + value='fake_name', group="azure"): + with self._azureservice._create_transport_cert(mock_cert_mgr) as r: + self.assertEqual(r, expected_certs) + (mock_cert_mgr.create_self_signed_cert. + assert_called_once_with("CN=Cloudbase-Init AzureService Transport", + machine_keyset=True, + store_name="fake_name")) + (mock_cert_mgr.delete_certificate_from_store. + assert_called_once_with(mock.sentinel.thumbprint, + machine_keyset=True, + store_name="fake_name")) + + @mock.patch(MODPATH + "._wire_server_request") + def test__get_encoded_cert(self, mock_wire_server_request): + mock_cert_config = mock.Mock() + mock_transport_cert = mock.Mock() + mock_cert_url = mock.sentinel.cert_url + + mock_transport_cert.replace.return_value = mock.sentinel.transport_cert + mock_wire_server_request.return_value = mock_cert_config + mock_cert_config.CertificateFile.Data.cdata = mock.sentinel.cert_data + mock_cert_config.CertificateFile.Format.cdata = mock.sentinel.cert_fmt + + expected_headers = { + "x-ms-guest-agent-public-x509-cert": mock.sentinel.transport_cert} + expected_result = (mock.sentinel.cert_data, mock.sentinel.cert_fmt) + res = self._azureservice._get_encoded_cert(mock_cert_url, + mock_transport_cert) + (mock_wire_server_request. + assert_called_once_with(mock_cert_url, headers=expected_headers)) + self.assertEqual(res, expected_result) + + @mock.patch(MODPATH + "._get_versions") + def _test__check_version_header(self, mock_get_versions, version): + mock_version = mock.Mock() + mock_get_versions.return_value = mock_version + mock_version.Versions.Supported.Version = [version] + if self._azureservice_module.WIRE_SERVER_VERSION is not version: + self.assertRaises(exception.MetadaNotFoundException, + self._azureservice._check_version_header) + else: + self._azureservice._check_version_header() + self.assertEqual(self._azureservice._headers["x-ms-version"], + version) + + def test_check_version_header_unsupported_version(self): + version = "fake-version" + self._test__check_version_header(version=version) + + def test_check_version_header_supported(self): + version = self._azureservice_module.WIRE_SERVER_VERSION + self._test__check_version_header(version=version) + + @mock.patch(MODPATH + "._wire_server_request") + def test__get_versions(self, mock_server_request): + mock_server_request.return_value = mock.sentinel.version + res = self._azureservice._get_versions() + mock_server_request.assert_called_once_with("?comp=Versions") + self.assertEqual(res, mock.sentinel.version) + + @mock.patch(MODPATH + "._get_role_instance_id") + def test_get_instance_id(self, mock_get_role_instance_id): + mock_get_role_instance_id.return_value = mock.sentinel.id + self.assertEqual(self._azureservice.get_instance_id(), + mock.sentinel.id) + + @mock.patch("os.path.exists") + @mock.patch(MODPATH + "._get_config_set_drive_path") + def _test__get_ovf_env_path(self, mock_get_drives, mock_path_exists, + path_exists=True): + mock_get_drives.return_value = 'fake path' + mock_path_exists.return_value = path_exists + self._azureservice._osutils.get_logical_drives = mock_get_drives + if not path_exists: + self.assertRaises(exception.ItemNotFoundException, + self._azureservice._get_ovf_env_path) + else: + res = self._azureservice._get_ovf_env_path() + ovf_env_path = os.path.join( + "fake path", self._azureservice_module.OVF_ENV_FILENAME) + self.assertEqual(res, ovf_env_path) + mock_path_exists.assert_called_once_with(ovf_env_path) + mock_get_drives.assert_called_once_with() + + def test_get_ovf_env_path_exists(self): + self._test__get_ovf_env_path() + + def test_get_ovf_env_path_not_exists(self): + self._test__get_ovf_env_path(path_exists=False) + + @mock.patch(MODPATH + "._get_ovf_env_path") + def test_get_ovf_env(self, mock_get_ovf_env_path): + fake_xml = '' + mock_get_ovf_env_path.return_value = fake_xml + res = self._azureservice._get_ovf_env() + self.assertIsNotNone(res) + mock_get_ovf_env_path.assert_called_once_with() + + @mock.patch(MODPATH + "._get_ovf_env") + def test_get_admin_username(self, mock_get_ovf_env): + mock_ovf_env = mock.Mock() + mock_prov_section = mock.Mock() + mock_win_prov = mock.Mock() + mock_get_ovf_env.return_value = mock_ovf_env + mock_ovf_env.Environment.wa_ProvisioningSection = mock_prov_section + mock_prov_section.WindowsProvisioningConfigurationSet = mock_win_prov + mock_win_prov.AdminUsername.cdata = mock.sentinel.cdata + res = self._azureservice.get_admin_username() + mock_get_ovf_env.assert_called_once_with() + self.assertEqual(res, mock.sentinel.cdata) + + @mock.patch(MODPATH + "._get_ovf_env") + def test_get_admin_password(self, mock_get_ovf_env): + mock_ovf_env = mock.Mock() + mock_prov_section = mock.Mock() + mock_win_prov = mock.Mock() + mock_get_ovf_env.return_value = mock_ovf_env + mock_ovf_env.Environment.wa_ProvisioningSection = mock_prov_section + mock_prov_section.WindowsProvisioningConfigurationSet = mock_win_prov + mock_win_prov.AdminPassword.cdata = mock.sentinel.cdata + res = self._azureservice.get_admin_password() + mock_get_ovf_env.assert_called_once_with() + self.assertEqual(res, mock.sentinel.cdata) + + @mock.patch(MODPATH + "._get_ovf_env") + def test_get_host_name(self, mock_get_ovf_env): + mock_ovf_env = mock.Mock() + mock_prov_section = mock.Mock() + mock_win_prov = mock.Mock() + mock_get_ovf_env.return_value = mock_ovf_env + mock_ovf_env.Environment.wa_ProvisioningSection = mock_prov_section + mock_prov_section.WindowsProvisioningConfigurationSet = mock_win_prov + mock_win_prov.ComputerName.cdata = mock.sentinel.cdata + res = self._azureservice.get_host_name() + mock_get_ovf_env.assert_called_once_with() + self.assertEqual(res, mock.sentinel.cdata) + + @mock.patch(MODPATH + "._get_ovf_env") + def test_get_enable_automatic_updates(self, mock_get_ovf_env, + enable_updates=True): + mock_ovf_env = mock.Mock() + mock_prov_section = mock.Mock() + mock_win_prov = mock.Mock() + if not enable_updates: + mock_win_prov = mock.MagicMock(spec="") + mock_get_ovf_env.return_value = mock_ovf_env + mock_ovf_env.Environment.wa_ProvisioningSection = mock_prov_section + mock_prov_section.WindowsProvisioningConfigurationSet = mock_win_prov + res = self._azureservice.get_enable_automatic_updates() + mock_get_ovf_env.assert_called_once_with() + self.assertFalse(res) + + def test_get_enable_automatic_updates_no_updates(self): + self.test_get_enable_automatic_updates(enable_updates=False) + + @mock.patch(MODPATH + "._get_ovf_env") + def test_get_winrm_listeners_configuration(self, mock_get_ovf_env): + mock_ovf_env = mock.Mock() + mock_prov_section = mock.Mock() + mock_win_prov = mock.Mock() + mock_listener = mock.Mock() + mock_get_ovf_env.return_value = mock_ovf_env + mock_ovf_env.Environment.wa_ProvisioningSection = mock_prov_section + mock_prov_section.WindowsProvisioningConfigurationSet = mock_win_prov + mock_win_prov.WinRM.Listeners.Listener = [mock_listener] + mock_listener.Protocol.cdata = mock.sentinel.fake_protocol + (mock_listener.CertificateThumbprint. + cdata) = mock.sentinel.fake_thumbprint + + expected_result = [ + { + 'certificate_thumbprint': mock.sentinel.fake_thumbprint, + 'protocol': mock.sentinel.fake_protocol, + }] + res = self._azureservice.get_winrm_listeners_configuration() + mock_get_ovf_env.assert_called_once_with() + self.assertEqual(res, expected_result) + + @mock.patch(MODPATH + "._get_ovf_env") + def test_get_vm_agent_package_provisioning_data(self, mock_get_ovf_env): + mock_ovf_env = mock.Mock() + mock_package_name = mock.sentinel.package_name + mock_get_ovf_env.return_value = mock_ovf_env + (mock_ovf_env.Environment.wa_PlatformSettingsSection. + PlatformSettings.GuestAgentPackageName.cdata) = mock_package_name + res = self._azureservice.get_vm_agent_package_provisioning_data() + expected_provisioning_data = { + 'provision': False, 'package_name': mock_package_name} + self.assertEqual(res, expected_provisioning_data) + + @mock.patch(MODPATH + "._get_ovf_env") + def test_get_kms_host(self, mock_get_ovf_env): + mock_ovf_env = mock.Mock() + mock_get_ovf_env.return_value = mock_ovf_env + self.assertTrue(self._azureservice.get_kms_host()) + + @mock.patch(MODPATH + "._get_ovf_env") + def _test_get_use_avma_licensing(self, mock_get_ovf_env, use_avma): + mock_ovf_env = mock.Mock() + mock_ovf_env.Environment = mock.Mock() + plat_sett_section = mock.Mock() + if not use_avma: + plat_sett_section.PlatformSettings = mock.MagicMock(spec="") + mock_ovf_env.Environment.wa_PlatformSettingsSection = plat_sett_section + mock_get_ovf_env.return_value = mock_ovf_env + + self.assertFalse(self._azureservice.get_use_avma_licensing()) + mock_get_ovf_env.assert_called_once_with() + + def test_get_use_avma_licensing(self): + self._test_get_use_avma_licensing(use_avma=True) + + def test_get_use_avma_licensing_no_use_avma(self): + self._test_get_use_avma_licensing(use_avma=False) + + @mock.patch(MODPATH + "._get_ovf_env") + @mock.patch(MODPATH + "._check_version_header") + @mock.patch(MODPATH + "._get_wire_server_endpoint_address") + def _test_load(self, mock_get_endpoint_address, + mock_check_version_header, mock_get_ovf_env, + endpoint_side_effect=None, load_side_effect=None): + if endpoint_side_effect: + mock_get_endpoint_address.side_effect = endpoint_side_effect + expected_logging = ["Azure WireServer endpoint not found"] + with self._logsnatcher: + res = self._azureservice.load() + self.assertFalse(res) + self.assertEqual(self._logsnatcher.output, expected_logging) + mock_get_endpoint_address.assert_called_once_with() + return + + mock_endpoint = mock.sentinel.endpoint + mock_get_endpoint_address.return_value = mock_endpoint + if load_side_effect: + mock_check_version_header.side_effect = load_side_effect + res = self._azureservice.load() + self.assertFalse(res) + return + else: + res = self._azureservice.load() + self.assertTrue(res) + self.assertIn(str(mock_endpoint), self._azureservice._base_url) + mock_check_version_header.assert_called_once_with() + mock_get_ovf_env.assert_called_once_with() + return + + def test_load_no_endpoint(self): + self._test_load(endpoint_side_effect=Exception) + + def test_load_exception(self): + exc = Exception("Fake exception") + self._test_load(load_side_effect=exc) + + def test_load(self): + self._test_load() + + @mock.patch('os.path.exists') + def _test_get_config_set_drive_path(self, mock_path_exists, + path_exists=True): + self._azureservice._set_config_drive_path = None + mock_osutils = mock.Mock() + mock_osutils.get_logical_drives.return_value = ['fake path'] * 3 + self._azureservice._osutils = mock_osutils + mock_path_exists.side_effect = [False] * 2 + [path_exists] + if path_exists: + result = self._azureservice._get_config_set_drive_path() + self.assertEqual(result, 'fake path') + else: + self.assertRaises(exception.ItemNotFoundException, + self._azureservice._get_config_set_drive_path) + self.assertEqual(mock_path_exists.call_count, 3) + + def test_get_config_set_drive_path(self): + self._test_get_config_set_drive_path() + + def test_get_config_set_drive_path_not_exists(self): + self._test_get_config_set_drive_path(path_exists=False) + + @mock.patch(MODPATH + '._get_ovf_env') + def test_check_ovf_env_custom_data(self, mock_get_ovf_env, + custom_data=True): + mock_ovf_env = mock.Mock() + mock_prov_section = mock.Mock() + mock_win_prov = mock.Mock() + mock_win_prov.PlatformSettings = mock.Mock() + if not custom_data: + mock_win_prov.PlatformSettings = mock.MagicMock(spec="") + mock_get_ovf_env.return_value = mock_ovf_env + mock_ovf_env.Environment.wa_ProvisioningSection = mock_prov_section + mock_prov_section.WindowsProvisioningConfigurationSet = mock_win_prov + res = self._azureservice._check_ovf_env_custom_data() + mock_get_ovf_env.assert_called_once_with() + self.assertTrue(res) + + @mock.patch(MODPATH + '._check_ovf_env_custom_data') + def test_get_user_data_ItemNotFound(self, mock_check_custom_data): + mock_check_custom_data.return_value = True + self._azureservice._config_set_drive_path = "fake path" + self.assertRaises(exception.ItemNotFoundException, + self._azureservice.get_user_data) + + @mock.patch(MODPATH + '._check_ovf_env_custom_data') + def test_get_user_data_NoMetadataException(self, mock_check_custom_data): + mock_check_custom_data.return_value = False + self._azureservice._config_set_drive_path = "fake path" + self.assertRaises( + self._azureservice_module.base.NotExistingMetadataException, + self._azureservice.get_user_data) + + @mock.patch(MODPATH + '._get_role_instance_config') + def test_get_server_certs_no_certs(self, mock_get_instance_config): + mock_get_instance_config.return_value = mock.MagicMock(spec="") + res = self._azureservice.get_server_certs() + self.assertEqual(res, []) + + @mock.patch(MODPATH + '._get_hosting_environment') + @mock.patch(MODPATH + '._get_encoded_cert') + @mock.patch(MODPATH + '._get_role_instance_config') + @mock.patch(MODPATH + '._create_transport_cert') + @mock.patch('cloudbaseinit.utils.windows.x509.CryptoAPICertManager') + def test_get_server_certs(self, mock_cert_manager, mock_create_cert, + mock_get_config, mock_get_encoded_cert, + mock_get_hosting_env): + cert_model = { + "storeName": mock.sentinel.storeName, + "configurationLevel": mock.sentinel.configurationLevel, + "certificateId": mock.sentinel.certificateId, + "name": mock.sentinel.name + } + mock_cert_mgr = mock.Mock() + mock_cert_mgr.decode_pkcs7_base64_blob.return_value = \ + mock.sentinel.pfx_data + mock_cert_manager.return_value = mock_cert_mgr + mock_create_cert.return_value.__enter__.return_value = \ + (mock.sentinel.thumbprint, mock.sentinel.cert) + mock_get_encoded_cert.return_value = \ + (mock.sentinel.cert_data, mock.sentinel.cert_format) + mock_host_env = mock.Mock() + mock_host_env_config = mock_host_env.HostingEnvironmentConfig + mock_host_env_config.StoredCertificates.StoredCertificate = \ + [cert_model] + mock_get_hosting_env.return_value = mock_host_env + + res = self._azureservice.get_server_certs() + expected_result = [{ + "store_name": mock.sentinel.storeName, + "store_location": mock.sentinel.configurationLevel, + "certificate_id": mock.sentinel.certificateId, + "name": mock.sentinel.name, + "pfx_data": mock.sentinel.pfx_data, + }] + self.assertEqual(res, expected_result) + self.assertEqual(mock_create_cert.call_count, 1) + self.assertEqual(mock_get_encoded_cert.call_count, 1) + self.assertEqual(mock_cert_mgr.decode_pkcs7_base64_blob.call_count, 1) + mock_cert_manager.assert_called_once_with() + mock_get_config.assert_called_once_with() + mock_get_hosting_env.assert_called_once_with() + + @mock.patch(MODPATH + '.get_user_data') + def test_get_decoded_user_data(self, mock_get_user_data): + mock_get_user_data.return_value = mock.sentinel.user_data + res = self._azureservice.get_decoded_user_data() + self.assertEqual(res, mock.sentinel.user_data) + + @mock.patch(MODPATH + '.get_content') + def test_get_ephemeral_disk_data_loss_warning(self, mock_get_content): + mock_get_content.return_value = mock.sentinel.content + res = self._azureservice.get_ephemeral_disk_data_loss_warning() + self.assertEqual(res, mock.sentinel.content) + mock_get_content.assert_called_once_with( + self._azureservice_module.DATALOSS_WARNING_PATH) + + @mock.patch(MODPATH + "._get_role_instance_id") + @mock.patch(MODPATH + "._get_container_id") + def _test_get_role_properties_xml(self, mock_get_container_id, + mock_get_role_instance_id, + properties): + mock_get_container_id.return_value = "fake container id" + mock_get_role_instance_id.return_value = "fake instance id" + + res = self._azureservice._get_role_properties_xml(properties) + expected_properties = "" + property_template = '' + result_template = ("\n" + "" + "{container_id}" + "{instance_id}" + "{properties}" + "") + if properties: + expected_properties = "" + for name, value in properties.items(): + expected_properties += property_template.format( + property_name=name, + value=value) + expected_properties += "" + else: + expected_properties = "" + expected_result = result_template.format( + container_id=mock_get_container_id.return_value, + instance_id=mock_get_role_instance_id.return_value, + properties=expected_properties) + self.assertEqual(encoding.get_as_string(res), expected_result) + + def test_get_role_properties_xml_no_properties(self): + self._test_get_role_properties_xml(properties={}) + + def test_get_role_properties_xml(self): + properties = { + "fake property 1": "fake value 1", + "fake property 2": "fake value 2" + } + self._test_get_role_properties_xml(properties=properties) diff --git a/cloudbaseinit/tests/utils/test_dhcp.py b/cloudbaseinit/tests/utils/test_dhcp.py index 6c271d99..5a911309 100644 --- a/cloudbaseinit/tests/utils/test_dhcp.py +++ b/cloudbaseinit/tests/utils/test_dhcp.py @@ -50,12 +50,12 @@ class DHCPUtilsTests(unittest.TestCase): data += b'\x00' * 128 data += dhcp._DHCP_COOKIE data += b'\x35\x01\x01' - data += b'\x3c' + struct.pack('b', len('fake id')) + 'fake id'.encode( + data += b'\x3c' + struct.pack('B', len('fake id')) + 'fake id'.encode( 'ascii') data += b'\x3d\x07\x01' data += fake_mac_address_b - data += b'\x37' + struct.pack('b', len([100])) - data += struct.pack('b', 100) + data += b'\x37' + struct.pack('B', len([100])) + data += struct.pack('B', 100) data += dhcp._OPTION_END response = dhcp._get_dhcp_request_data( @@ -154,7 +154,6 @@ class DHCPUtilsTests(unittest.TestCase): socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) mock_socket().bind.assert_called_once_with(('', 68)) mock_socket().settimeout.assert_called_once_with(5) - mock_socket().connect.assert_called_once_with(('fake host', 67)) mock_socket().getsockname.assert_called_once_with() mock_get_mac_address_by_local_ip.assert_called_once_with( 'fake local ip') @@ -162,7 +161,8 @@ class DHCPUtilsTests(unittest.TestCase): 'fake mac', ['fake option'], 'cloudbase-init') - mock_socket().send.assert_called_once_with('fake data') + mock_socket().sendto.assert_called_once_with( + 'fake data', ('fake host', 67)) mock_socket().recv.assert_called_once_with(1024) mock_parse_dhcp_reply.assert_called_once_with(mock_socket().recv(), 'fake int') diff --git a/cloudbaseinit/tests/utils/test_network.py b/cloudbaseinit/tests/utils/test_network.py index 68ebd941..d685dca8 100644 --- a/cloudbaseinit/tests/utils/test_network.py +++ b/cloudbaseinit/tests/utils/test_network.py @@ -88,3 +88,11 @@ class NetworkUtilsTest(unittest.TestCase): } for v6, v4 in netmask_map.items(): self.assertEqual(v4, network.netmask6_to_4_truncate(v6)) + + @mock.patch('socket.socket') + def test_get_local_ip(self, mock_socket): + mock_socket.return_value = mock.Mock() + mock_socket().getsockname.return_value = ["fake name"] + res = network.get_local_ip("fake address") + self.assertEqual(res, "fake name") + mock_socket().connect.assert_called_with(("fake address", 8000)) diff --git a/cloudbaseinit/utils/dhcp.py b/cloudbaseinit/utils/dhcp.py index 6cc5a8e1..b6d9032d 100644 --- a/cloudbaseinit/utils/dhcp.py +++ b/cloudbaseinit/utils/dhcp.py @@ -21,6 +21,7 @@ import time from oslo_log import log as oslo_logging +from cloudbaseinit.utils import network _DHCP_COOKIE = b'\x63\x82\x53\x63' _OPTION_END = b'\xff' @@ -56,20 +57,20 @@ def _get_dhcp_request_data(id_req, mac_address, requested_options, if vendor_id: vendor_id_b = vendor_id.encode('ascii') - data += b'\x3c' + struct.pack('b', len(vendor_id_b)) + vendor_id_b + data += b'\x3c' + struct.pack('B', len(vendor_id_b)) + vendor_id_b data += b'\x3d\x07\x01' + mac_address_b - data += b'\x37' + struct.pack('b', len(requested_options)) + data += b'\x37' + struct.pack('B', len(requested_options)) for option in requested_options: - data += struct.pack('b', option) + data += struct.pack('B', option) data += _OPTION_END return data def _parse_dhcp_reply(data, id_req): - message_type = struct.unpack('b', data[0:1])[0] + message_type = struct.unpack('B', data[0:1])[0] if message_type != 2: return False, {} @@ -86,8 +87,8 @@ def _parse_dhcp_reply(data, id_req): i = 240 data_len = len(data) while i < data_len and data[i:i + 1] != _OPTION_END: - id_option = struct.unpack('b', data[i:i + 1])[0] - option_data_len = struct.unpack('b', data[i + 1:i + 2])[0] + id_option = struct.unpack('B', data[i:i + 1])[0] + option_data_len = struct.unpack('B', data[i + 1:i + 2])[0] i += 2 options[id_option] = data[i: i + option_data_len] i += option_data_len @@ -120,7 +121,7 @@ def _bind_dhcp_client_socket(s, max_bind_attempts, bind_retry_interval): time.sleep(bind_retry_interval) -def get_dhcp_options(dhcp_host, requested_options=[], timeout=5.0, +def get_dhcp_options(dhcp_host=None, requested_options=[], timeout=5.0, vendor_id='cloudbase-init', max_bind_attempts=10, bind_retry_interval=3): id_req = random.randint(0, 2 ** 32 - 1) @@ -128,18 +129,21 @@ def get_dhcp_options(dhcp_host, requested_options=[], timeout=5.0, s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if not dhcp_host: + s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + try: _bind_dhcp_client_socket(s, max_bind_attempts, bind_retry_interval) s.settimeout(timeout) - s.connect((dhcp_host, 67)) - local_ip_addr = s.getsockname()[0] + local_ip_addr = network.get_local_ip(dhcp_host) mac_address = _get_mac_address_by_local_ip(local_ip_addr) data = _get_dhcp_request_data(id_req, mac_address, requested_options, vendor_id) - s.send(data) + + s.sendto(data, (dhcp_host or "", 67)) start = datetime.datetime.now() now = start diff --git a/cloudbaseinit/utils/network.py b/cloudbaseinit/utils/network.py index 08889b32..81fca7b4 100644 --- a/cloudbaseinit/utils/network.py +++ b/cloudbaseinit/utils/network.py @@ -29,6 +29,12 @@ LOG = oslo_logging.getLogger(__name__) MAX_URL_CHECK_RETRIES = 3 +def get_local_ip(address=None): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect((address or "", 8000)) + return s.getsockname()[0] + + def check_url(url, retries_count=MAX_URL_CHECK_RETRIES): for i in range(0, MAX_URL_CHECK_RETRIES): try: diff --git a/requirements.txt b/requirements.txt index 0a5c7f6d..d340bf08 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,7 @@ oauthlib netifaces PyYAML requests +untangle==1.1.1 pywin32;sys_platform=="win32" comtypes;sys_platform=="win32" wmi;sys_platform=="win32"