diff --git a/ironic/drivers/modules/snmp.py b/ironic/drivers/modules/snmp.py index 08320b26c5..e09807ab42 100644 --- a/ironic/drivers/modules/snmp.py +++ b/ironic/drivers/modules/snmp.py @@ -368,6 +368,38 @@ def _get_client(snmp_info): snmp_info.get("context_name")) +_memoized = {} + + +def memoize(f): + def memoized(self, node_info): + hashable_node_info = frozenset((key, val) + for key, val in node_info.items() + if key is not 'outlet') + if hashable_node_info not in _memoized: + _memoized[hashable_node_info] = f(self) + return _memoized[hashable_node_info] + return memoized + + +def retry_on_outdated_cache(f): + def wrapper(self): + try: + return f(self) + + except exception.SNMPFailure: + hashable_node_info = ( + frozenset((key, val) + for key, val in self.snmp_info.items() + if key is not 'outlet') + ) + del _memoized[hashable_node_info] + self.driver = self._get_pdu_driver(self.snmp_info) + return f(self) + + return wrapper + + @six.add_metaclass(abc.ABCMeta) class SNMPDriverBase(object): """SNMP power driver base class. @@ -737,7 +769,9 @@ class SNMPDriverAuto(SNMPDriverBase): def __init__(self, *args, **kwargs): super(SNMPDriverAuto, self).__init__(*args, **kwargs) + self.driver = self._get_pdu_driver(*args, **kwargs) + def _get_pdu_driver(self, *args, **kwargs): drivers_map = {} for name, obj in DRIVER_CLASSES.items(): @@ -756,7 +790,7 @@ class SNMPDriverAuto(SNMPDriverBase): LOG.debug("SNMP driver mapping %(system_id)s -> %(name)s", {'system_id': system_id, 'name': obj.__name__}) - system_id = self.client.get(self.SYS_OBJ_OID) + system_id = self._fetch_driver(*args, **kwargs) LOG.debug("SNMP device reports sysObjectID %(system_id)s", {'system_id': system_id}) @@ -770,8 +804,7 @@ class SNMPDriverAuto(SNMPDriverBase): LOG.debug("Chosen SNMP driver %(name)s based on sysObjectID " "prefix %(system_id_prefix)s", {Driver.__name__, system_id_prefix}) - self.driver = Driver(*args, **kwargs) - return + return Driver(*args, **kwargs) except KeyError: system_id_prefix = system_id_prefix[:-1] @@ -780,16 +813,22 @@ class SNMPDriverAuto(SNMPDriverBase): "SNMPDriverAuto: no driver matching %(system_id)s") % {'system_id': system_id}) + @retry_on_outdated_cache def _snmp_power_state(self): current_power_state = self.driver._snmp_power_state() return current_power_state + @retry_on_outdated_cache def _snmp_power_on(self): return self.driver._snmp_power_on() + @retry_on_outdated_cache def _snmp_power_off(self): return self.driver._snmp_power_off() + @memoize + def _fetch_driver(self): + return self.client.get(self.SYS_OBJ_OID) # A dictionary of supported drivers keyed by snmp_driver attribute DRIVER_CLASSES = { diff --git a/ironic/tests/unit/drivers/modules/test_snmp.py b/ironic/tests/unit/drivers/modules/test_snmp.py index 382af6505a..7a1e404e48 100644 --- a/ironic/tests/unit/drivers/modules/test_snmp.py +++ b/ironic/tests/unit/drivers/modules/test_snmp.py @@ -27,6 +27,7 @@ from ironic.common import exception from ironic.common import states from ironic.conductor import task_manager from ironic.drivers.modules import snmp +from ironic.drivers.modules.snmp import SNMPDriverAuto from ironic.tests import base from ironic.tests.unit.db import base as db_base from ironic.tests.unit.db import utils as db_utils @@ -664,6 +665,7 @@ class SNMPDeviceDriverTestCase(db_base.DbTestCase): def setUp(self): super(SNMPDeviceDriverTestCase, self).setUp() self.config(enabled_power_interfaces=['fake', 'snmp']) + snmp._memoized = {} self.node = obj_utils.get_test_node( self.context, power_interface='snmp', @@ -1341,6 +1343,7 @@ class SNMPDeviceDriverTestCase(db_base.DbTestCase): mock_client = mock_get_client.return_value mock_client.reset_mock() mock_client.get.return_value = sys_obj_oid + snmp._memoized.clear() self._update_driver_info(snmp_driver="auto") driver = snmp._get_driver(self.node) @@ -1361,7 +1364,8 @@ class SNMPDeviceDriverTestCase(db_base.DbTestCase): mock_client = mock_get_client.return_value mock_client.reset_mock() mock_client.get.return_value = sys_obj_oid - self._update_driver_info(snmp_driver="auto") + snmp._memoized.clear() + self._update_driver_info(snmp_driver="auto",) driver = snmp._get_driver(self.node) second_node = obj_utils.get_test_node( @@ -1381,7 +1385,8 @@ class SNMPDeviceDriverTestCase(db_base.DbTestCase): mock_client = mock_get_client.return_value mock_client.reset_mock() mock_client.get.return_value = sys_obj_oid - self._update_driver_info(snmp_driver="auto") + snmp._memoized.clear() + self._update_driver_info(snmp_driver="auto",) driver = snmp._get_driver(self.node) second_node = obj_utils.get_test_node( @@ -1403,6 +1408,7 @@ class SNMPDeviceDriverTestCase(db_base.DbTestCase): mock_client = mock_get_client.return_value mock_client.reset_mock() mock_client.get.return_value = sys_obj_oid + snmp._memoized.clear() self._update_driver_info(snmp_driver="auto") driver = snmp._get_driver(self.node) @@ -1425,6 +1431,7 @@ class SNMPDeviceDriverTestCase(db_base.DbTestCase): mock_client = mock_get_client.return_value mock_client.reset_mock() mock_client.get.side_effect = [sys_obj_oid, sys_obj_oid] + snmp._memoized.clear() self._update_driver_info(snmp_driver="auto") driver = snmp._get_driver(self.node) @@ -1576,6 +1583,49 @@ class SNMPDeviceDriverTestCase(db_base.DbTestCase): def test_baytech_mrp27_power_reset(self, mock_get_client): self._test_simple_device_power_reset('baytech_mrp27', mock_get_client) + def test_auto_power_on_cached_driver(self, mock_get_client): + mock_client = mock_get_client.return_value + mock_client.reset_mock() + mock_client.get.return_value = (1, 3, 6, 1, 4, 1, 318, 1, 1, 4) + self._update_driver_info(snmp_driver="auto") + + for i in range(5): + snmp._get_driver(self.node) + + mock_client.get.assert_called_once_with(SNMPDriverAuto.SYS_OBJ_OID) + + @mock.patch.object(snmp.SNMPDriverAPCRackPDU, "_snmp_power_on") + def test_snmp_auto_cache_supports_pdu_replacement( + self, broken_pdu_power_on_mock, mock_get_client): + + broken_pdu_exception = exception.SNMPFailure(operation=1, error=2) + broken_pdu_power_on_mock.side_effect = broken_pdu_exception + + broken_pdu_oid = (1, 3, 6, 1, 4, 1, 318, 1, 1, 12) + hashable_node_info = frozenset( + {('address', '1.2.3.4'), ('port', 161), ('community', 'public'), + ('version', '1'), ('driver', 'auto')}) + snmp._memoized = {hashable_node_info: broken_pdu_oid} + + self._update_driver_info(snmp_driver="auto") + + mock_client = mock_get_client.return_value + mock_client.get.return_value = broken_pdu_oid + + driver = snmp._get_driver(self.node) + + mock_client.reset_mock() + replacement_pdu_oid = (1, 3, 6, 1, 4, 1, 318, 1, 1, 4) + mock_client.get.side_effect = [replacement_pdu_oid, + driver.driver.value_power_on] + + pstate = driver.power_on() + + mock_client.set.assert_called_once_with( + driver.driver.oid, driver.driver.value_power_on) + + self.assertEqual(states.POWER_ON, pstate) + @mock.patch.object(snmp, '_get_driver', autospec=True) class SNMPDriverTestCase(db_base.DbTestCase):