Pass correct context to get_by_compute_node()
A recent change to the PciDevTracker class allowed for the passing of the compute node ID to the __init__() method, where PciDevsList.get_by_compute_node() was called with a 'context' parameter, which wasn't defined, resulting in the imported context module being passed instead. This change requires that the context be passed in to the __init__() for the PciDevTracker class, and that that be used to create the PciDevsList. The existing import of the context module is no longer needed in the pci/manager.py file, so the conflict is no longer a problem. The only place in the code that currently instantiates a PciDevTracker object is in the ResourceTracker, so that has been updated to pass in the context. A unit test to check for context has also been added. Change-Id: Id136eabacb00e4381c03f12d8484fc90a5eb48b1 Closes-Bug: #1408480
This commit is contained in:
parent
5b3bfdbc14
commit
50ee9dd76e
|
@ -345,9 +345,6 @@ class ResourceTracker(object):
|
|||
@utils.synchronized(COMPUTE_RESOURCE_SEMAPHORE)
|
||||
def _update_available_resource(self, context, resources):
|
||||
if 'pci_passthrough_devices' in resources:
|
||||
if not self.pci_tracker:
|
||||
self.pci_tracker = pci_manager.PciDevTracker()
|
||||
|
||||
devs = []
|
||||
for dev in jsonutils.loads(resources.pop(
|
||||
'pci_passthrough_devices')):
|
||||
|
@ -357,6 +354,10 @@ class ResourceTracker(object):
|
|||
if self.pci_filter.device_assignable(dev):
|
||||
devs.append(dev)
|
||||
|
||||
if not self.pci_tracker:
|
||||
n_id = self.compute_node['id'] if self.compute_node else None
|
||||
self.pci_tracker = pci_manager.PciDevTracker(context,
|
||||
node_id=n_id)
|
||||
self.pci_tracker.set_hvdevs(devs)
|
||||
|
||||
# Grab all instances assigned to this node:
|
||||
|
|
|
@ -20,7 +20,6 @@ from oslo_log import log as logging
|
|||
|
||||
from nova.compute import task_states
|
||||
from nova.compute import vm_states
|
||||
from nova import context
|
||||
from nova import exception
|
||||
from nova.i18n import _LW
|
||||
from nova import objects
|
||||
|
@ -43,7 +42,7 @@ class PciDevTracker(object):
|
|||
information is updated to DB when devices information is changed.
|
||||
"""
|
||||
|
||||
def __init__(self, node_id=None):
|
||||
def __init__(self, context, node_id=None):
|
||||
"""Create a pci device tracker.
|
||||
|
||||
If a node_id is passed in, it will fetch pci devices information
|
||||
|
|
|
@ -21,6 +21,7 @@ import mock
|
|||
from oslo_serialization import jsonutils
|
||||
|
||||
from nova.compute import claims
|
||||
from nova import context
|
||||
from nova import db
|
||||
from nova import exception
|
||||
from nova import objects
|
||||
|
@ -43,9 +44,11 @@ class FakeResourceHandler(object):
|
|||
class DummyTracker(object):
|
||||
icalled = False
|
||||
rcalled = False
|
||||
pci_tracker = pci_manager.PciDevTracker()
|
||||
ext_resources_handler = FakeResourceHandler()
|
||||
|
||||
def __init__(self):
|
||||
self.new_pci_tracker()
|
||||
|
||||
def abort_instance_claim(self, *args, **kwargs):
|
||||
self.icalled = True
|
||||
|
||||
|
@ -53,7 +56,8 @@ class DummyTracker(object):
|
|||
self.rcalled = True
|
||||
|
||||
def new_pci_tracker(self):
|
||||
self.pci_tracker = pci_manager.PciDevTracker()
|
||||
ctxt = context.RequestContext('testuser', 'testproject')
|
||||
self.pci_tracker = pci_manager.PciDevTracker(ctxt)
|
||||
|
||||
|
||||
@mock.patch('nova.objects.InstancePCIRequests.get_by_instance_uuid',
|
||||
|
|
|
@ -233,6 +233,10 @@ class BaseTestCase(test.TestCase):
|
|||
'flavor_get', self._fake_flavor_get)
|
||||
|
||||
self.host = 'fakehost'
|
||||
self.compute = self._create_compute_node()
|
||||
self.updated = False
|
||||
self.deleted = False
|
||||
self.update_call_count = 0
|
||||
|
||||
def _create_compute_node(self, values=None):
|
||||
compute = {
|
||||
|
@ -427,6 +431,13 @@ class BaseTestCase(test.TestCase):
|
|||
# only used in the subsequent notification:
|
||||
return (instance, instance)
|
||||
|
||||
def _fake_compute_node_update(self, ctx, compute_node_id, values,
|
||||
prune_stats=False):
|
||||
self.update_call_count += 1
|
||||
self.updated = True
|
||||
self.compute.update(values)
|
||||
return self.compute
|
||||
|
||||
def _driver(self):
|
||||
return FakeVirtDriver()
|
||||
|
||||
|
@ -440,6 +451,7 @@ class BaseTestCase(test.TestCase):
|
|||
driver = self._driver()
|
||||
|
||||
tracker = resource_tracker.ResourceTracker(host, driver, node)
|
||||
tracker.compute_node = self._create_compute_node()
|
||||
tracker.ext_resources_handler = \
|
||||
resources.ResourceHandler(RESOURCE_NAMES, True)
|
||||
return tracker
|
||||
|
@ -512,6 +524,8 @@ class MissingServiceTestCase(BaseTestCase):
|
|||
self.tracker = self._tracker()
|
||||
|
||||
def test_missing_service(self):
|
||||
self.tracker.compute_node = None
|
||||
self.tracker._get_service = mock.Mock(return_value=None)
|
||||
self.tracker.update_available_resource(self.context)
|
||||
self.assertTrue(self.tracker.disabled)
|
||||
|
||||
|
@ -543,6 +557,7 @@ class MissingComputeNodeTestCase(BaseTestCase):
|
|||
raise exception.ComputeHostNotFound(host=host)
|
||||
|
||||
def test_create_compute_node(self):
|
||||
self.tracker.compute_node = None
|
||||
self.tracker.update_available_resource(self.context)
|
||||
self.assertTrue(self.created)
|
||||
|
||||
|
@ -558,10 +573,6 @@ class BaseTrackerTestCase(BaseTestCase):
|
|||
# database models and a compatible compute driver:
|
||||
super(BaseTrackerTestCase, self).setUp()
|
||||
|
||||
self.updated = False
|
||||
self.deleted = False
|
||||
self.update_call_count = 0
|
||||
|
||||
self.tracker = self._tracker()
|
||||
self._migrations = {}
|
||||
|
||||
|
@ -582,11 +593,12 @@ class BaseTrackerTestCase(BaseTestCase):
|
|||
patcher = pci_fakes.fake_pci_whitelist()
|
||||
self.addCleanup(patcher.stop)
|
||||
|
||||
self.stubs.Set(self.tracker.scheduler_client, 'update_resource_stats',
|
||||
self._fake_compute_node_update)
|
||||
self._init_tracker()
|
||||
self.limits = self._limits()
|
||||
|
||||
def _fake_service_get_by_compute_host(self, ctx, host):
|
||||
self.compute = self._create_compute_node()
|
||||
self.service = self._create_service(host, compute=self.compute)
|
||||
return self.service
|
||||
|
||||
|
@ -721,7 +733,8 @@ class SchedulerClientTrackerTestCase(BaseTrackerTestCase):
|
|||
|
||||
def setUp(self):
|
||||
super(SchedulerClientTrackerTestCase, self).setUp()
|
||||
self.tracker.scheduler_client.update_resource_stats = mock.Mock()
|
||||
self.tracker.scheduler_client.update_resource_stats = mock.Mock(
|
||||
side_effect=self._fake_compute_node_update)
|
||||
|
||||
def test_create_resource(self):
|
||||
self.tracker._write_ext_resources = mock.Mock()
|
||||
|
|
|
@ -17,6 +17,7 @@ import copy
|
|||
|
||||
import mock
|
||||
|
||||
import nova
|
||||
from nova.compute import task_states
|
||||
from nova.compute import vm_states
|
||||
from nova import context
|
||||
|
@ -109,6 +110,7 @@ class PciDevTrackerTestCase(test.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
super(PciDevTrackerTestCase, self).setUp()
|
||||
self.fake_context = context.get_admin_context()
|
||||
self.stubs.Set(db, 'pci_device_get_all_by_node',
|
||||
self._fake_get_pci_devices)
|
||||
# The fake_pci_whitelist must be called before creating the fake
|
||||
|
@ -116,7 +118,7 @@ class PciDevTrackerTestCase(test.TestCase):
|
|||
patcher = pci_fakes.fake_pci_whitelist()
|
||||
self.addCleanup(patcher.stop)
|
||||
self._create_fake_instance()
|
||||
self.tracker = manager.PciDevTracker(1)
|
||||
self.tracker = manager.PciDevTracker(self.fake_context, 1)
|
||||
|
||||
def test_pcidev_tracker_create(self):
|
||||
self.assertEqual(len(self.tracker.pci_devs), 3)
|
||||
|
@ -126,9 +128,16 @@ class PciDevTrackerTestCase(test.TestCase):
|
|||
self.assertEqual(len(self.tracker.stats.pools), 3)
|
||||
self.assertEqual(self.tracker.node_id, 1)
|
||||
|
||||
def test_pcidev_tracker_create_no_nodeid(self):
|
||||
self.tracker = manager.PciDevTracker()
|
||||
@mock.patch.object(nova.objects.PciDeviceList, 'get_by_compute_node')
|
||||
def test_pcidev_tracker_create_no_nodeid(self, mock_get_cn):
|
||||
self.tracker = manager.PciDevTracker(self.fake_context)
|
||||
self.assertEqual(len(self.tracker.pci_devs), 0)
|
||||
self.assertFalse(mock_get_cn.called)
|
||||
|
||||
@mock.patch.object(nova.objects.PciDeviceList, 'get_by_compute_node')
|
||||
def test_pcidev_tracker_create_with_nodeid(self, mock_get_cn):
|
||||
self.tracker = manager.PciDevTracker(self.fake_context, node_id=1)
|
||||
mock_get_cn.assert_called_once_with(self.fake_context, 1)
|
||||
|
||||
def test_set_hvdev_new_dev(self):
|
||||
fake_pci_3 = dict(fake_pci, address='0000:00:00.4', vendor_id='v2')
|
||||
|
@ -278,30 +287,28 @@ class PciDevTrackerTestCase(test.TestCase):
|
|||
|
||||
def test_save(self):
|
||||
self.stubs.Set(db, "pci_device_update", self._fake_pci_device_update)
|
||||
ctxt = context.get_admin_context()
|
||||
fake_pci_v3 = dict(fake_pci, address='0000:00:00.2', vendor_id='v3')
|
||||
fake_pci_devs = [copy.deepcopy(fake_pci), copy.deepcopy(fake_pci_2),
|
||||
copy.deepcopy(fake_pci_v3)]
|
||||
self.tracker.set_hvdevs(fake_pci_devs)
|
||||
self.update_called = 0
|
||||
self.tracker.save(ctxt)
|
||||
self.tracker.save(self.fake_context)
|
||||
self.assertEqual(self.update_called, 3)
|
||||
|
||||
def test_save_removed(self):
|
||||
self.stubs.Set(db, "pci_device_update", self._fake_pci_device_update)
|
||||
self.stubs.Set(db, "pci_device_destroy", self._fake_pci_device_destroy)
|
||||
self.destroy_called = 0
|
||||
ctxt = context.get_admin_context()
|
||||
self.assertEqual(len(self.tracker.pci_devs), 3)
|
||||
dev = self.tracker.pci_devs[0]
|
||||
self.update_called = 0
|
||||
device.remove(dev)
|
||||
self.tracker.save(ctxt)
|
||||
self.tracker.save(self.fake_context)
|
||||
self.assertEqual(len(self.tracker.pci_devs), 2)
|
||||
self.assertEqual(self.destroy_called, 1)
|
||||
|
||||
def test_set_compute_node_id(self):
|
||||
self.tracker = manager.PciDevTracker()
|
||||
self.tracker = manager.PciDevTracker(self.fake_context)
|
||||
fake_pci_devs = [copy.deepcopy(fake_pci), copy.deepcopy(fake_pci_1),
|
||||
copy.deepcopy(fake_pci_2)]
|
||||
self.tracker.set_hvdevs(fake_pci_devs)
|
||||
|
@ -379,6 +386,10 @@ class PciDevTrackerTestCase(test.TestCase):
|
|||
|
||||
|
||||
class PciGetInstanceDevs(test.TestCase):
|
||||
def setUp(self):
|
||||
super(PciGetInstanceDevs, self).setUp()
|
||||
self.fake_context = context.get_admin_context()
|
||||
|
||||
def test_get_devs_object(self):
|
||||
def _fake_obj_load_attr(foo, attrname):
|
||||
if attrname == 'pci_devices':
|
||||
|
@ -386,12 +397,12 @@ class PciGetInstanceDevs(test.TestCase):
|
|||
foo.pci_devices = objects.PciDeviceList()
|
||||
|
||||
inst = fakes.stub_instance(id='1')
|
||||
ctxt = context.get_admin_context()
|
||||
self.mox.StubOutWithMock(db, 'instance_get')
|
||||
db.instance_get(ctxt, '1', columns_to_join=[]
|
||||
db.instance_get(self.fake_context, '1', columns_to_join=[]
|
||||
).AndReturn(inst)
|
||||
self.mox.ReplayAll()
|
||||
inst = objects.Instance.get_by_id(ctxt, '1', expected_attrs=[])
|
||||
inst = objects.Instance.get_by_id(self.fake_context, '1',
|
||||
expected_attrs=[])
|
||||
self.stubs.Set(objects.Instance, 'obj_load_attr', _fake_obj_load_attr)
|
||||
|
||||
self.load_attr_called = False
|
||||
|
|
Loading…
Reference in New Issue