Pass correct context to get_by_compute_node()

A recent change to the PciDevTracker class allowed for the passing of
the compute node ID to the __init__() method, where
PciDevsList.get_by_compute_node() was called with a 'context' parameter,
which wasn't defined, resulting in the imported context module being
passed instead.

This change requires that the context be passed in to the __init__() for
the PciDevTracker class, and that that be used to create the
PciDevsList. The existing import of the context module is no longer
needed in the pci/manager.py file, so the conflict is no longer a
problem. The only place in the code that currently instantiates a
PciDevTracker object is in the ResourceTracker, so that has been updated
to pass in the context. A unit test to check for context has also been
added.

Change-Id: Id136eabacb00e4381c03f12d8484fc90a5eb48b1
Closes-Bug: #1408480
This commit is contained in:
EdLeafe 2015-01-12 15:56:21 +00:00
parent 5b3bfdbc14
commit 50ee9dd76e
5 changed files with 52 additions and 24 deletions

View File

@ -345,9 +345,6 @@ class ResourceTracker(object):
@utils.synchronized(COMPUTE_RESOURCE_SEMAPHORE)
def _update_available_resource(self, context, resources):
if 'pci_passthrough_devices' in resources:
if not self.pci_tracker:
self.pci_tracker = pci_manager.PciDevTracker()
devs = []
for dev in jsonutils.loads(resources.pop(
'pci_passthrough_devices')):
@ -357,6 +354,10 @@ class ResourceTracker(object):
if self.pci_filter.device_assignable(dev):
devs.append(dev)
if not self.pci_tracker:
n_id = self.compute_node['id'] if self.compute_node else None
self.pci_tracker = pci_manager.PciDevTracker(context,
node_id=n_id)
self.pci_tracker.set_hvdevs(devs)
# Grab all instances assigned to this node:

View File

@ -20,7 +20,6 @@ from oslo_log import log as logging
from nova.compute import task_states
from nova.compute import vm_states
from nova import context
from nova import exception
from nova.i18n import _LW
from nova import objects
@ -43,7 +42,7 @@ class PciDevTracker(object):
information is updated to DB when devices information is changed.
"""
def __init__(self, node_id=None):
def __init__(self, context, node_id=None):
"""Create a pci device tracker.
If a node_id is passed in, it will fetch pci devices information

View File

@ -21,6 +21,7 @@ import mock
from oslo_serialization import jsonutils
from nova.compute import claims
from nova import context
from nova import db
from nova import exception
from nova import objects
@ -43,9 +44,11 @@ class FakeResourceHandler(object):
class DummyTracker(object):
icalled = False
rcalled = False
pci_tracker = pci_manager.PciDevTracker()
ext_resources_handler = FakeResourceHandler()
def __init__(self):
self.new_pci_tracker()
def abort_instance_claim(self, *args, **kwargs):
self.icalled = True
@ -53,7 +56,8 @@ class DummyTracker(object):
self.rcalled = True
def new_pci_tracker(self):
self.pci_tracker = pci_manager.PciDevTracker()
ctxt = context.RequestContext('testuser', 'testproject')
self.pci_tracker = pci_manager.PciDevTracker(ctxt)
@mock.patch('nova.objects.InstancePCIRequests.get_by_instance_uuid',

View File

@ -233,6 +233,10 @@ class BaseTestCase(test.TestCase):
'flavor_get', self._fake_flavor_get)
self.host = 'fakehost'
self.compute = self._create_compute_node()
self.updated = False
self.deleted = False
self.update_call_count = 0
def _create_compute_node(self, values=None):
compute = {
@ -427,6 +431,13 @@ class BaseTestCase(test.TestCase):
# only used in the subsequent notification:
return (instance, instance)
def _fake_compute_node_update(self, ctx, compute_node_id, values,
prune_stats=False):
self.update_call_count += 1
self.updated = True
self.compute.update(values)
return self.compute
def _driver(self):
return FakeVirtDriver()
@ -440,6 +451,7 @@ class BaseTestCase(test.TestCase):
driver = self._driver()
tracker = resource_tracker.ResourceTracker(host, driver, node)
tracker.compute_node = self._create_compute_node()
tracker.ext_resources_handler = \
resources.ResourceHandler(RESOURCE_NAMES, True)
return tracker
@ -512,6 +524,8 @@ class MissingServiceTestCase(BaseTestCase):
self.tracker = self._tracker()
def test_missing_service(self):
self.tracker.compute_node = None
self.tracker._get_service = mock.Mock(return_value=None)
self.tracker.update_available_resource(self.context)
self.assertTrue(self.tracker.disabled)
@ -543,6 +557,7 @@ class MissingComputeNodeTestCase(BaseTestCase):
raise exception.ComputeHostNotFound(host=host)
def test_create_compute_node(self):
self.tracker.compute_node = None
self.tracker.update_available_resource(self.context)
self.assertTrue(self.created)
@ -558,10 +573,6 @@ class BaseTrackerTestCase(BaseTestCase):
# database models and a compatible compute driver:
super(BaseTrackerTestCase, self).setUp()
self.updated = False
self.deleted = False
self.update_call_count = 0
self.tracker = self._tracker()
self._migrations = {}
@ -582,11 +593,12 @@ class BaseTrackerTestCase(BaseTestCase):
patcher = pci_fakes.fake_pci_whitelist()
self.addCleanup(patcher.stop)
self.stubs.Set(self.tracker.scheduler_client, 'update_resource_stats',
self._fake_compute_node_update)
self._init_tracker()
self.limits = self._limits()
def _fake_service_get_by_compute_host(self, ctx, host):
self.compute = self._create_compute_node()
self.service = self._create_service(host, compute=self.compute)
return self.service
@ -721,7 +733,8 @@ class SchedulerClientTrackerTestCase(BaseTrackerTestCase):
def setUp(self):
super(SchedulerClientTrackerTestCase, self).setUp()
self.tracker.scheduler_client.update_resource_stats = mock.Mock()
self.tracker.scheduler_client.update_resource_stats = mock.Mock(
side_effect=self._fake_compute_node_update)
def test_create_resource(self):
self.tracker._write_ext_resources = mock.Mock()

View File

@ -17,6 +17,7 @@ import copy
import mock
import nova
from nova.compute import task_states
from nova.compute import vm_states
from nova import context
@ -109,6 +110,7 @@ class PciDevTrackerTestCase(test.TestCase):
def setUp(self):
super(PciDevTrackerTestCase, self).setUp()
self.fake_context = context.get_admin_context()
self.stubs.Set(db, 'pci_device_get_all_by_node',
self._fake_get_pci_devices)
# The fake_pci_whitelist must be called before creating the fake
@ -116,7 +118,7 @@ class PciDevTrackerTestCase(test.TestCase):
patcher = pci_fakes.fake_pci_whitelist()
self.addCleanup(patcher.stop)
self._create_fake_instance()
self.tracker = manager.PciDevTracker(1)
self.tracker = manager.PciDevTracker(self.fake_context, 1)
def test_pcidev_tracker_create(self):
self.assertEqual(len(self.tracker.pci_devs), 3)
@ -126,9 +128,16 @@ class PciDevTrackerTestCase(test.TestCase):
self.assertEqual(len(self.tracker.stats.pools), 3)
self.assertEqual(self.tracker.node_id, 1)
def test_pcidev_tracker_create_no_nodeid(self):
self.tracker = manager.PciDevTracker()
@mock.patch.object(nova.objects.PciDeviceList, 'get_by_compute_node')
def test_pcidev_tracker_create_no_nodeid(self, mock_get_cn):
self.tracker = manager.PciDevTracker(self.fake_context)
self.assertEqual(len(self.tracker.pci_devs), 0)
self.assertFalse(mock_get_cn.called)
@mock.patch.object(nova.objects.PciDeviceList, 'get_by_compute_node')
def test_pcidev_tracker_create_with_nodeid(self, mock_get_cn):
self.tracker = manager.PciDevTracker(self.fake_context, node_id=1)
mock_get_cn.assert_called_once_with(self.fake_context, 1)
def test_set_hvdev_new_dev(self):
fake_pci_3 = dict(fake_pci, address='0000:00:00.4', vendor_id='v2')
@ -278,30 +287,28 @@ class PciDevTrackerTestCase(test.TestCase):
def test_save(self):
self.stubs.Set(db, "pci_device_update", self._fake_pci_device_update)
ctxt = context.get_admin_context()
fake_pci_v3 = dict(fake_pci, address='0000:00:00.2', vendor_id='v3')
fake_pci_devs = [copy.deepcopy(fake_pci), copy.deepcopy(fake_pci_2),
copy.deepcopy(fake_pci_v3)]
self.tracker.set_hvdevs(fake_pci_devs)
self.update_called = 0
self.tracker.save(ctxt)
self.tracker.save(self.fake_context)
self.assertEqual(self.update_called, 3)
def test_save_removed(self):
self.stubs.Set(db, "pci_device_update", self._fake_pci_device_update)
self.stubs.Set(db, "pci_device_destroy", self._fake_pci_device_destroy)
self.destroy_called = 0
ctxt = context.get_admin_context()
self.assertEqual(len(self.tracker.pci_devs), 3)
dev = self.tracker.pci_devs[0]
self.update_called = 0
device.remove(dev)
self.tracker.save(ctxt)
self.tracker.save(self.fake_context)
self.assertEqual(len(self.tracker.pci_devs), 2)
self.assertEqual(self.destroy_called, 1)
def test_set_compute_node_id(self):
self.tracker = manager.PciDevTracker()
self.tracker = manager.PciDevTracker(self.fake_context)
fake_pci_devs = [copy.deepcopy(fake_pci), copy.deepcopy(fake_pci_1),
copy.deepcopy(fake_pci_2)]
self.tracker.set_hvdevs(fake_pci_devs)
@ -379,6 +386,10 @@ class PciDevTrackerTestCase(test.TestCase):
class PciGetInstanceDevs(test.TestCase):
def setUp(self):
super(PciGetInstanceDevs, self).setUp()
self.fake_context = context.get_admin_context()
def test_get_devs_object(self):
def _fake_obj_load_attr(foo, attrname):
if attrname == 'pci_devices':
@ -386,12 +397,12 @@ class PciGetInstanceDevs(test.TestCase):
foo.pci_devices = objects.PciDeviceList()
inst = fakes.stub_instance(id='1')
ctxt = context.get_admin_context()
self.mox.StubOutWithMock(db, 'instance_get')
db.instance_get(ctxt, '1', columns_to_join=[]
db.instance_get(self.fake_context, '1', columns_to_join=[]
).AndReturn(inst)
self.mox.ReplayAll()
inst = objects.Instance.get_by_id(ctxt, '1', expected_attrs=[])
inst = objects.Instance.get_by_id(self.fake_context, '1',
expected_attrs=[])
self.stubs.Set(objects.Instance, 'obj_load_attr', _fake_obj_load_attr)
self.load_attr_called = False