diff --git a/ec2api/api/route_table.py b/ec2api/api/route_table.py index f8694e71..61071a66 100644 --- a/ec2api/api/route_table.py +++ b/ec2api/api/route_table.py @@ -77,8 +77,7 @@ def delete_route(context, route_table_id, destination_cidr_block): cleaner.addCleanup(db_api.update_item, context, rollback_route_table_state) - _update_routes_in_associated_subnets(context, route_table, cleaner, - rollback_route_table_state) + _update_routes_in_associated_subnets(context, route_table, cleaner) return True @@ -98,15 +97,11 @@ def associate_route_table(context, route_table_id, subnet_id): msg = msg % {'rtb_id': route_table_id} raise exception.ResourceAlreadyAssociated(msg) - vpc = db_api.get_item_by_id(context, subnet['vpc_id']) - main_route_table = db_api.get_item_by_id(context, vpc['route_table_id']) with common.OnCrashCleaner() as cleaner: _associate_subnet_item(context, subnet, route_table['id']) cleaner.addCleanup(_disassociate_subnet_item, context, subnet) - _update_subnet_host_routes( - context, subnet, route_table, - cleaner=cleaner, rollback_route_table_object=main_route_table) + _update_subnet_host_routes(context, subnet, route_table, cleaner) return {'associationId': ec2utils.change_ec2_id_kind(subnet['id'], 'rtbassoc')} @@ -119,27 +114,23 @@ def replace_route_table_association(context, association_id, route_table_id): vpc = db_api.get_item_by_id( context, ec2utils.change_ec2_id_kind(association_id, 'vpc')) if vpc is None: - raise exception.InvalidAssociationIDNotFound( - id=association_id) + raise exception.InvalidAssociationIDNotFound(id=association_id) - rollabck_route_table_object = db_api.get_item_by_id( - context, vpc['route_table_id']) + rollback_route_table_id = vpc['route_table_id'] with common.OnCrashCleaner() as cleaner: _associate_vpc_item(context, vpc, route_table['id']) cleaner.addCleanup(_associate_vpc_item, context, vpc, - rollabck_route_table_object['id']) + rollback_route_table_id) # NOTE(ft): this can cause unnecessary update of subnets, which are # associated with the route table _update_routes_in_associated_subnets( - context, route_table, cleaner, - rollabck_route_table_object, is_main=True) + context, route_table, cleaner, is_main=True) else: subnet = db_api.get_item_by_id( context, ec2utils.change_ec2_id_kind(association_id, 'subnet')) if subnet is None or 'route_table_id' not in subnet: - raise exception.InvalidAssociationIDNotFound( - id=association_id) + raise exception.InvalidAssociationIDNotFound(id=association_id) if subnet['vpc_id'] != route_table['vpc_id']: msg = _('Route table association %(rtbassoc_id)s and route table ' '%(rtb_id)s belong to different networks') @@ -147,16 +138,13 @@ def replace_route_table_association(context, association_id, route_table_id): 'rtb_id': route_table_id} raise exception.InvalidParameterValue(msg) - rollabck_route_table_object = db_api.get_item_by_id( - context, subnet['route_table_id']) + rollback_route_table_id = subnet['route_table_id'] with common.OnCrashCleaner() as cleaner: _associate_subnet_item(context, subnet, route_table['id']) cleaner.addCleanup(_associate_subnet_item, context, subnet, - rollabck_route_table_object['id']) + rollback_route_table_id) - _update_subnet_host_routes( - context, subnet, route_table, cleaner=cleaner, - rollback_route_table_object=rollabck_route_table_object) + _update_subnet_host_routes(context, subnet, route_table, cleaner) return {'newAssociationId': association_id} @@ -168,27 +156,22 @@ def disassociate_route_table(context, association_id): vpc = db_api.get_item_by_id( context, ec2utils.change_ec2_id_kind(association_id, 'vpc')) if vpc is None: - raise exception.InvalidAssociationIDNotFound( - id=association_id) + raise exception.InvalidAssociationIDNotFound(id=association_id) msg = _('Cannot disassociate the main route table association ' '%(rtbassoc_id)s') % {'rtbassoc_id': association_id} raise exception.InvalidParameterValue(msg) if 'route_table_id' not in subnet: - raise exception.InvalidAssociationIDNotFound( - id=association_id) + raise exception.InvalidAssociationIDNotFound(id=association_id) - rollback_route_table_object = db_api.get_item_by_id( - context, subnet['route_table_id']) + rollback_route_table_id = subnet['route_table_id'] vpc = db_api.get_item_by_id(context, subnet['vpc_id']) main_route_table = db_api.get_item_by_id(context, vpc['route_table_id']) with common.OnCrashCleaner() as cleaner: _disassociate_subnet_item(context, subnet) cleaner.addCleanup(_associate_subnet_item, context, subnet, - rollback_route_table_object['id']) + rollback_route_table_id) - _update_subnet_host_routes( - context, subnet, main_route_table, cleaner=cleaner, - rollback_route_table_object=rollback_route_table_object) + _update_subnet_host_routes(context, subnet, main_route_table, cleaner) return True @@ -377,8 +360,7 @@ def _set_route(context, route_table_id, destination_cidr_block, db_api.update_item(context, route_table) cleaner.addCleanup(db_api.update_item, context, rollabck_route_table_state) - _update_routes_in_associated_subnets(context, route_table, cleaner, - rollabck_route_table_state) + _update_routes_in_associated_subnets(context, route_table, cleaner) return True @@ -458,7 +440,6 @@ def _format_route_table(context, route_table, is_main=False, def _update_routes_in_associated_subnets(context, route_table, cleaner, - rollabck_route_table_object, is_main=None): if is_main is None: vpc = db_api.get_item_by_id(context, route_table['vpc_id']) @@ -473,13 +454,11 @@ def _update_routes_in_associated_subnets(context, route_table, cleaner, if (subnet['vpc_id'] == route_table['vpc_id'] and subnet.get('route_table_id') in appropriate_rtb_ids): _update_subnet_host_routes( - context, subnet, route_table, cleaner=cleaner, - rollback_route_table_object=rollabck_route_table_object, + context, subnet, route_table, cleaner, router_objects=router_objects, neutron=neutron) -def _update_subnet_host_routes(context, subnet, route_table, cleaner=None, - rollback_route_table_object=None, +def _update_subnet_host_routes(context, subnet, route_table, cleaner, router_objects=None, neutron=None): neutron = neutron or clients.neutron(context) os_subnet = neutron.show_subnet(subnet['os_id'])['subnet'] @@ -489,38 +468,35 @@ def _update_subnet_host_routes(context, subnet, route_table, cleaner=None, router_objects) neutron.update_subnet(subnet['os_id'], {'subnet': {'host_routes': host_routes}}) - if cleaner and rollback_route_table_object: - cleaner.addCleanup(_update_subnet_host_routes, context, subnet, - rollback_route_table_object) + cleaner.addCleanup( + neutron.update_subnet, subnet['os_id'], + {'subnet': {'host_routes': os_subnet['host_routes']}}) def _get_router_objects(context, route_table): - return dict((route['gateway_id'], - db_api.get_item_by_id(context, route['gateway_id'])) - if route.get('gateway_id') else - (route['network_interface_id'], - db_api.get_item_by_id(context, route['network_interface_id'])) - for route in route_table['routes'] - if route.get('gateway_id') or 'network_interface_id' in route) + object_ids = [route[id_key] + for route in route_table['routes'] + for id_key in ('gateway_id', 'network_interface_id') + if id_key in route and route[id_key]] + return dict((item['id'], item) + for item in db_api.get_items_by_ids(context, object_ids)) def _get_subnet_host_routes(context, route_table, gateway_ip, router_objects=None): + if router_objects is None: + router_objects = _get_router_objects(context, route_table) + def get_nexthop(route): if 'gateway_id' in route: gateway_id = route['gateway_id'] if gateway_id: - gateway = (router_objects[route['gateway_id']] - if router_objects else - db_api.get_item_by_id(context, gateway_id)) + gateway = router_objects.get(route['gateway_id']) if (not gateway or - gateway.get('vpc_id') != route_table['vpc_id']): + gateway['vpc_id'] != route_table['vpc_id']): return '127.0.0.1' return gateway_ip - network_interface = ( - router_objects[route['network_interface_id']] - if router_objects else - db_api.get_item_by_id(context, route['network_interface_id'])) + network_interface = router_objects.get(route['network_interface_id']) if not network_interface: return '127.0.0.1' return network_interface['private_ip_address'] diff --git a/ec2api/tests/unit/test_route_table.py b/ec2api/tests/unit/test_route_table.py index 3a69da62..89322b86 100644 --- a/ec2api/tests/unit/test_route_table.py +++ b/ec2api/tests/unit/test_route_table.py @@ -67,8 +67,7 @@ class RouteTableTestCase(base.ApiTestCase): self.db_api.update_item.assert_called_once_with( mock.ANY, route_table) routes_updater.assert_called_once_with( - mock.ANY, route_table, mock.ANY, - rollback_route_table_state) + mock.ANY, route_table, mock.ANY) self.db_api.update_item.reset_mock() routes_updater.reset_mock() @@ -251,8 +250,7 @@ class RouteTableTestCase(base.ApiTestCase): 'network_interface_id': fakes.ID_EC2_NETWORK_INTERFACE_1, 'destination_cidr_block': '0.0.0.0/0'}) self.db_api.update_item.assert_called_once_with(mock.ANY, route_table) - routes_updater.assert_called_once_with(mock.ANY, route_table, mock.ANY, - rollback_route_table_state) + routes_updater.assert_called_once_with(mock.ANY, route_table, mock.ANY) def test_replace_route_invalid_parameters(self): self.set_mock_db_items(fakes.DB_ROUTE_TABLE_1, @@ -278,7 +276,7 @@ class RouteTableTestCase(base.ApiTestCase): if r['destination_cidr_block'] != fakes.CIDR_EXTERNAL_NETWORK] self.db_api.update_item.assert_called_once_with(mock.ANY, route_table) routes_updater.assert_called_once_with( - mock.ANY, route_table, mock.ANY, fakes.DB_ROUTE_TABLE_2) + mock.ANY, route_table, mock.ANY) def test_delete_route_invalid_parameters(self): self.set_mock_db_items() @@ -327,8 +325,7 @@ class RouteTableTestCase(base.ApiTestCase): self.db_api.update_item.assert_called_once_with( mock.ANY, subnet) routes_updater.assert_called_once_with( - mock.ANY, subnet, fakes.DB_ROUTE_TABLE_1, cleaner=mock.ANY, - rollback_route_table_object=fakes.DB_ROUTE_TABLE_1) + mock.ANY, subnet, fakes.DB_ROUTE_TABLE_1, mock.ANY) def test_associate_route_table_invalid_parameters(self): def do_check(params, error_code): @@ -392,8 +389,7 @@ class RouteTableTestCase(base.ApiTestCase): self.db_api.update_item.assert_called_once_with( mock.ANY, subnet) routes_updater.assert_called_once_with( - mock.ANY, subnet, fakes.DB_ROUTE_TABLE_2, cleaner=mock.ANY, - rollback_route_table_object=fakes.DB_ROUTE_TABLE_3) + mock.ANY, subnet, fakes.DB_ROUTE_TABLE_2, mock.ANY) @mock.patch('ec2api.api.route_table._update_routes_in_associated_subnets') def test_replace_route_table_association_main(self, routes_updater): @@ -411,8 +407,7 @@ class RouteTableTestCase(base.ApiTestCase): self.db_api.update_item.assert_called_once_with( mock.ANY, vpc) routes_updater.assert_called_once_with( - mock.ANY, fakes.DB_ROUTE_TABLE_2, mock.ANY, - fakes.DB_ROUTE_TABLE_1, is_main=True) + mock.ANY, fakes.DB_ROUTE_TABLE_2, mock.ANY, is_main=True) def test_replace_route_table_association_invalid_parameters(self): def do_check(params, error_code): @@ -499,9 +494,7 @@ class RouteTableTestCase(base.ApiTestCase): self.db_api.update_item.assert_called_once_with( mock.ANY, subnet) routes_updater.assert_called_once_with( - mock.ANY, subnet, fakes.DB_ROUTE_TABLE_1, - cleaner=mock.ANY, - rollback_route_table_object=fakes.DB_ROUTE_TABLE_3) + mock.ANY, subnet, fakes.DB_ROUTE_TABLE_1, mock.ANY) def test_disassociate_route_table_invalid_parameter(self): def do_check(params, error_code): @@ -713,7 +706,8 @@ class RouteTableTestCase(base.ApiTestCase): route_table._update_subnet_host_routes( self._create_context(), fakes.DB_SUBNET_1, - fakes.DB_ROUTE_TABLE_1, router_objects={'fake': 'objects'}) + fakes.DB_ROUTE_TABLE_1, common.OnCrashCleaner(), + router_objects={'fake': 'objects'}) self.neutron.show_subnet.assert_called_once_with(fakes.ID_OS_SUBNET_1) routes_getter.assert_called_once_with( @@ -724,32 +718,21 @@ class RouteTableTestCase(base.ApiTestCase): {'subnet': {'host_routes': 'fake_routes'}}) self.neutron.reset_mock() - routes_getter.reset_mock() - - routes_getter.side_effect = ['fake_routes', 'fake_previous_routes'] try: with common.OnCrashCleaner() as cleaner: route_table._update_subnet_host_routes( self._create_context(), fakes.DB_SUBNET_1, fakes.DB_ROUTE_TABLE_1, cleaner, - fakes.DB_ROUTE_TABLE_2, router_objects={'fake': 'objects'}) raise Exception('fake_exception') except Exception as ex: if ex.message != 'fake_exception': raise - self.neutron.show_subnet.assert_any_call(fakes.ID_OS_SUBNET_1) - routes_getter.assert_any_call( - mock.ANY, fakes.DB_ROUTE_TABLE_1, fakes.IP_GATEWAY_SUBNET_1, - {'fake': 'objects'}) - routes_getter.assert_any_call( - mock.ANY, fakes.DB_ROUTE_TABLE_2, fakes.IP_GATEWAY_SUBNET_1, - None) self.neutron.update_subnet.assert_any_call( fakes.ID_OS_SUBNET_1, - {'subnet': {'host_routes': 'fake_previous_routes'}}) + {'subnet': {'host_routes': fakes.OS_SUBNET_1['host_routes']}}) @mock.patch('ec2api.api.route_table._get_router_objects') @mock.patch('ec2api.api.route_table._update_subnet_host_routes') @@ -768,15 +751,12 @@ class RouteTableTestCase(base.ApiTestCase): get_router_objects.return_value = {'fake': 'objects'} route_table._update_routes_in_associated_subnets( - mock.MagicMock(), fakes.DB_ROUTE_TABLE_2, 'fake_cleaner', - {'fake': 'table'}) + mock.MagicMock(), fakes.DB_ROUTE_TABLE_2, 'fake_cleaner') self.db_api.get_item_by_id.assert_called_once_with( mock.ANY, fakes.ID_EC2_VPC_1) routes_updater.assert_called_once_with( - mock.ANY, subnet_rtb_2, fakes.DB_ROUTE_TABLE_2, - cleaner='fake_cleaner', - rollback_route_table_object={'fake': 'table'}, + mock.ANY, subnet_rtb_2, fakes.DB_ROUTE_TABLE_2, 'fake_cleaner', router_objects={'fake': 'objects'}, neutron=mock.ANY) get_router_objects.assert_called_once_with(mock.ANY, fakes.DB_ROUTE_TABLE_2) @@ -787,14 +767,13 @@ class RouteTableTestCase(base.ApiTestCase): route_table._update_routes_in_associated_subnets( mock.MagicMock(), fakes.DB_ROUTE_TABLE_1, 'fake_cleaner', - {'fake': 'table'}, is_main=True) + is_main=True) self.assertEqual(0, self.db_api.get_item_by_id.call_count) routes_updater.assert_called_once_with( mock.ANY, subnet_default_rtb, fakes.DB_ROUTE_TABLE_1, - cleaner='fake_cleaner', - rollback_route_table_object={'fake': 'table'}, - router_objects={'fake': 'objects'}, neutron=mock.ANY) + 'fake_cleaner', router_objects={'fake': 'objects'}, + neutron=mock.ANY) get_router_objects.assert_called_once_with(mock.ANY, fakes.DB_ROUTE_TABLE_1)