diff --git a/cloudbaseinit/osutils/windows.py b/cloudbaseinit/osutils/windows.py index 605ffa1f..290c2e5a 100644 --- a/cloudbaseinit/osutils/windows.py +++ b/cloudbaseinit/osutils/windows.py @@ -912,19 +912,21 @@ class WindowsUtils(base.BaseOSUtils): return reboot_required @staticmethod - def _fix_network_adapter_dhcp(interface_name, enable_dhcp, address_family): - interface_id = WindowsUtils._get_network_adapter(interface_name).GUID - tcpip_key = "Tcpip6" if address_family == AF_INET6 else "Tcpip" + def _fix_network_adapter_dhcp(interface_name, + enable_dhcp, + address_family): + enable_dhcp_value = 1 if enable_dhcp else 0 - with winreg.OpenKey( - winreg.HKEY_LOCAL_MACHINE, - "SYSTEM\\CurrentControlSet\\services\\%(tcpip_key)s\\" - "Parameters\\Interfaces\\%(interface_id)s" % - {"tcpip_key": tcpip_key, "interface_id": interface_id}, - 0, winreg.KEY_SET_VALUE) as key: - winreg.SetValueEx( - key, 'EnableDHCP', 0, winreg.REG_DWORD, - 1 if enable_dhcp else 0) + conn = wmi.WMI(moniker='//./root/standardcimv2') + net_interface = conn.MSFT_NetIPInterface( + InterfaceAlias=interface_name, AddressFamily=address_family) + if not len(net_interface): + raise exception.ItemNotFoundException( + 'Network interface with name "%s" not found' % + interface_name) + net_interface = net_interface[0] + net_interface.Dhcp = enable_dhcp_value + net_interface.put() @staticmethod def _set_interface_dns(interface_name, dnsnameservers): diff --git a/cloudbaseinit/tests/osutils/test_windows.py b/cloudbaseinit/tests/osutils/test_windows.py index f7a99847..ce3a0cf7 100644 --- a/cloudbaseinit/tests/osutils/test_windows.py +++ b/cloudbaseinit/tests/osutils/test_windows.py @@ -2000,6 +2000,40 @@ class TestWindowsUtils(testutils.CloudbaseInitTestBase): mock.sentinel.mac_address, mock.sentinel.dhcp_server)], response) + def test_fix_network_adapter_dhcp(self): + self._test_fix_network_adapter_dhcp(True) + + def test_fix_network_adapter_dhcp_no_network_adapter(self): + self._test_fix_network_adapter_dhcp(False) + + def _test_fix_network_adapter_dhcp(self, no_net_interface_found): + mock_interface_name = "eth12" + mock_enable_dhcp = True + mock_address_family = self.windows_utils.AF_INET + + conn = self._wmi_mock.WMI.return_value + existing_net_interface = mock.Mock() + existing_net_interface.Dhcp = 0 + + if not no_net_interface_found: + conn.MSFT_NetIPInterface.return_value = [existing_net_interface] + + if no_net_interface_found: + with self.assertRaises(exception.ItemNotFoundException): + self._winutils._fix_network_adapter_dhcp( + mock_interface_name, mock_enable_dhcp, + mock_address_family) + else: + self._winutils._fix_network_adapter_dhcp( + mock_interface_name, mock_enable_dhcp, + mock_address_family) + + conn.MSFT_NetIPInterface.assert_called_once_with( + InterfaceAlias=mock_interface_name, + AddressFamily=mock_address_family) + self.assertEqual(existing_net_interface.Dhcp, 1) + existing_net_interface.put.assert_called_once() + @mock.patch('cloudbaseinit.osutils.windows.WindowsUtils' '.check_sysnative_dir_exists') @mock.patch('cloudbaseinit.osutils.windows.WindowsUtils'