diff --git a/.pylintrc b/.pylintrc index 1c942d426..1f935c524 100644 --- a/.pylintrc +++ b/.pylintrc @@ -15,26 +15,22 @@ ignore=.git,tests,openstack disable= # "F" Fatal errors that prevent further processing import-error, + import-untyped, # "I" Informational noise locally-disabled, # "E" Error for important programming issues (likely bugs) access-member-before-definition, - bad-super-call, maybe-no-member, - no-member, - no-method-argument, - no-self-argument, + no-name-in-module, not-callable, - no-value-for-parameter, - super-on-old-class, too-few-format-args, + unsubscriptable-object, # "W" Warnings for stylistic problems or minor programming issues - abstract-method, anomalous-backslash-in-string, anomalous-unicode-escape-in-string, arguments-differ, attribute-defined-outside-init, - bad-builtin, + # bad-builtin, bad-indentation, broad-except, dangerous-default-value, @@ -57,7 +53,6 @@ disable= unnecessary-lambda, unnecessary-pass, unpacking-non-sequence, - unreachable, unused-argument, unused-import, unused-variable, @@ -67,17 +62,13 @@ disable= bad-continuation, invalid-name, missing-docstring, - old-style-class, superfluous-parens, # "R" Refactor recommendations - abstract-class-little-used, - abstract-class-not-used, - consider-using-set-comprehension, + cyclic-import, duplicate-code, inconsistent-return-statements, interface-not-implemented, no-else-raise, - no-else-return, no-self-use, too-few-public-methods, too-many-ancestors, @@ -89,7 +80,6 @@ disable= too-many-public-methods, too-many-return-statements, too-many-statements, - useless-object-inheritance [BASIC] # Variable names can be 1 to 31 characters long, with lowercase and underscores diff --git a/neutron_vpnaas/agent/ovn/vpn/agent.py b/neutron_vpnaas/agent/ovn/vpn/agent.py index 7374fb220..5d6466793 100644 --- a/neutron_vpnaas/agent/ovn/vpn/agent.py +++ b/neutron_vpnaas/agent/ovn/vpn/agent.py @@ -13,11 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing as ty import uuid from neutron.agent.linux import external_process from neutron.common.ovn import utils as ovn_utils from neutron.conf.plugins.ml2.drivers.ovn import ovn_conf as config +from oslo_config import cfg from oslo_log import log as logging from oslo_service import service from ovsdbapp.backend.ovs_idl import event as row_event @@ -32,44 +34,6 @@ LOG = logging.getLogger(__name__) OVN_VPNAGENT_UUID_NAMESPACE = uuid.UUID('e1ce3b12-b1e0-4c81-ba27-07c0fec9c12b') -class ChassisCreateEventBase(row_event.RowEvent): - """Row create event - Chassis name == our_chassis. - - On connection, we get a dump of all chassis so if we catch a creation - of our own chassis it has to be a reconnection. In this case, we need - to do a full sync to make sure that we capture all changes while the - connection to OVSDB was down. - """ - table = None - - def __init__(self, vpn_agent): - self.agent = vpn_agent - self.first_time = True - events = (self.ROW_CREATE,) - super().__init__( - events, self.table, (('name', '=', self.agent.chassis),)) - self.event_name = self.__class__.__name__ - - def run(self, event, row, old): - if self.first_time: - self.first_time = False - else: - # NOTE(lucasagomes): Re-register the ovn vpn agent - # with the local chassis in case its entry was re-created - # (happens when restarting the ovn-controller) - self.agent.register_vpn_agent() - LOG.info("Connection to OVSDB established, doing a full sync") - self.agent.sync() - - -class ChassisCreateEvent(ChassisCreateEventBase): - table = 'Chassis' - - -class ChassisPrivateCreateEvent(ChassisCreateEventBase): - table = 'Chassis_Private' - - class SbGlobalUpdateEvent(row_event.RowEvent): """Row update event on SB_Global table.""" @@ -90,7 +54,7 @@ class SbGlobalUpdateEvent(row_event.RowEvent): class OvnVpnAgent(service.Service): - def __init__(self, conf): + def __init__(self, conf: cfg.ConfigOpts): super().__init__() self.conf = conf vlog.use_python_logger(max_level=config.get_ovn_ovsdb_log_level()) @@ -102,13 +66,13 @@ class OvnVpnAgent(service.Service): self.device_drivers = self.service.load_device_drivers(self.conf.host) def _load_config(self): - self.chassis = self._get_own_chassis_name() + self.chassis: ty.Optional[str] = self._get_own_chassis_name() try: self.chassis_id = uuid.UUID(self.chassis) except ValueError: # OVS system-id could be a non UUID formatted string. self.chassis_id = uuid.uuid5(OVN_VPNAGENT_UUID_NAMESPACE, - self.chassis) + self.chassis if self.chassis else '') LOG.debug("Loaded chassis name %s (UUID: %s).", self.chassis, self.chassis_id) @@ -156,12 +120,51 @@ class OvnVpnAgent(service.Service): self.sb_idl.db_add(table, self.chassis, 'external_ids', ext_ids).execute(check_error=True) - def _get_own_chassis_name(self): + def _get_own_chassis_name(self) -> ty.Optional[str]: """Return the external_ids:system-id value of the Open_vSwitch table. As long as ovn-controller is running on this node, the key is guaranteed to exist and will include the chassis name. """ - ext_ids = self.ovs_idl.db_get( + ext_ids: ty.Optional[ty.Dict[str, str]] = self.ovs_idl.db_get( 'Open_vSwitch', '.', 'external_ids').execute() - return ext_ids['system-id'] + return ext_ids['system-id'] if ext_ids else None + + +class ChassisCreateEventBase(row_event.RowEvent): + """Row create event - Chassis name == our_chassis. + + On connection, we get a dump of all chassis so if we catch a creation + of our own chassis it has to be a reconnection. In this case, we need + to do a full sync to make sure that we capture all changes while the + connection to OVSDB was down. + """ + + table: ty.Optional[str] = None + + def __init__(self, vpn_agent: OvnVpnAgent): + self.agent = vpn_agent + self.first_time: bool = True + events: ty.Tuple[str] = (self.ROW_CREATE,) + super().__init__(events, self.table, + (('name', '=', self.agent.chassis),)) + self.event_name = self.__class__.__name__ + + def run(self, event, row, old): + if self.first_time: + self.first_time = False + else: + # NOTE(lucasagomes): Re-register the ovn vpn agent + # with the local chassis in case its entry was re-created + # (happens when restarting the ovn-controller) + self.agent.register_vpn_agent() + LOG.info("Connection to OVSDB established, doing a full sync") + self.agent.sync() + + +class ChassisCreateEvent(ChassisCreateEventBase): + table = "Chassis" + + +class ChassisPrivateCreateEvent(ChassisCreateEventBase): + table = "Chassis_Private" diff --git a/neutron_vpnaas/agent/ovn/vpn/ovsdb.py b/neutron_vpnaas/agent/ovn/vpn/ovsdb.py index 765865769..2c2da1d8e 100644 --- a/neutron_vpnaas/agent/ovn/vpn/ovsdb.py +++ b/neutron_vpnaas/agent/ovn/vpn/ovsdb.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import typing as ty from neutron.conf.plugins.ml2.drivers.ovn import ovn_conf as config from neutron.plugins.ml2.drivers.ovn.mech_driver.ovsdb import impl_idl_ovn @@ -28,7 +29,7 @@ LOG = logging.getLogger(__name__) class VPNAgentOvnSbIdl(ovsdb_monitor.OvnIdl): - SCHEMA = 'OVN_Southbound' + SCHEMA: str = 'OVN_Southbound' def __init__(self, chassis=None, events=None, tables=None): connection_string = config.get_ovn_sb_connection() @@ -39,8 +40,8 @@ class VPNAgentOvnSbIdl(ovsdb_monitor.OvnIdl): for table in tables: helper.register_table(table) try: - super().__init__( - None, connection_string, helper, leader_only=False) + super().__init__(None, connection_string, + helper, leader_only=False) except TypeError: # TODO(bpetermann) We can remove this when we require ovs>=2.12.0 super().__init__(None, connection_string, helper) @@ -54,7 +55,7 @@ class VPNAgentOvnSbIdl(ovsdb_monitor.OvnIdl): @tenacity.retry( wait=tenacity.wait_exponential(max=180), reraise=True) - def _get_ovsdb_helper(self, connection_string): + def _get_ovsdb_helper(self, connection_string: str): return idlutils.get_schema_helper(connection_string, self.SCHEMA) def start(self): @@ -62,14 +63,18 @@ class VPNAgentOvnSbIdl(ovsdb_monitor.OvnIdl): self, timeout=config.get_ovn_ovsdb_timeout()) return impl_idl_ovn.OvsdbSbOvnIdl(conn) + def post_connect(self): + pass -class VPNAgentOvsIdl(object): + +class VPNAgentOvsIdl: def start(self): - connection_string = config.cfg.CONF.ovs.ovsdb_connection + connection_string: str = config.cfg.CONF.ovs.ovsdb_connection helper = idlutils.get_schema_helper(connection_string, 'Open_vSwitch') - tables = ('Open_vSwitch', 'Bridge', 'Port', 'Interface') + tables: ty.Tuple[str, str, str, str] = ('Open_vSwitch', 'Bridge', + 'Port', 'Interface') for table in tables: helper.register_table(table) ovs_idl = idl.Idl( diff --git a/neutron_vpnaas/api/rpc/agentnotifiers/vpn_rpc_agent_api.py b/neutron_vpnaas/api/rpc/agentnotifiers/vpn_rpc_agent_api.py index 3e6fefd47..ac55bb794 100644 --- a/neutron_vpnaas/api/rpc/agentnotifiers/vpn_rpc_agent_api.py +++ b/neutron_vpnaas/api/rpc/agentnotifiers/vpn_rpc_agent_api.py @@ -12,7 +12,10 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import typing as ty + from neutron.api.rpc.agentnotifiers import utils as ag_utils +from neutron_lib import context from neutron_lib import rpc as n_rpc import oslo_messaging @@ -23,25 +26,31 @@ from neutron_vpnaas.services.vpn.common import topics AGENT_NOTIFY_MAX_ATTEMPTS = 2 -class VPNAgentNotifyAPI(object): +class VPNAgentNotifyAPI: """API for plugin to notify VPN agent.""" def __init__(self, topic=topics.IPSEC_AGENT_TOPIC): target = oslo_messaging.Target(topic=topic, version='1.0') self.client = n_rpc.get_client(target) - def agent_updated(self, context, admin_state_up, host): + def agent_updated( + self, context: context.ContextBase, + admin_state_up: bool, host: str): cctxt = self.client.prepare(server=host) cctxt.cast(context, 'agent_updated', payload={'admin_state_up': admin_state_up}) - def vpnservice_removed_from_agent(self, context, router_id, host): + def vpnservice_removed_from_agent( + self, context: context.ContextBase, + router_id: str, host: str): """Notify agent about removed VPN service(s) of a router.""" cctxt = self.client.prepare(server=host) cctxt.cast(context, 'vpnservice_removed_from_agent', router_id=router_id) - def vpnservice_added_to_agent(self, context, router_ids, host): + def vpnservice_added_to_agent( + self, context: context.ContextBase, + router_ids: ty.List[str], host: str): """Notify agent about added VPN service(s) of router(s).""" # need to use call here as we want to be sure agent received # notification and router will not be "lost". However using call() diff --git a/neutron_vpnaas/db/migration/alembic_migrations/env.py b/neutron_vpnaas/db/migration/alembic_migrations/env.py index eb2d7e0be..2c641cc03 100644 --- a/neutron_vpnaas/db/migration/alembic_migrations/env.py +++ b/neutron_vpnaas/db/migration/alembic_migrations/env.py @@ -27,8 +27,8 @@ from neutron_vpnaas.db.migration import alembic_migrations MYSQL_ENGINE = None config = context.config -neutron_config = config.neutron_config -logging_config.fileConfig(config.config_file_name) +neutron_config = config.neutron_config # type: ignore +logging_config.fileConfig(config.config_file_name) # type: ignore target_metadata = model_base.BASEV2.metadata @@ -46,7 +46,7 @@ def set_mysql_engine(): def run_migrations_offline(): set_mysql_engine() - kwargs = dict() + kwargs = {} if neutron_config.database.connection: kwargs['url'] = neutron_config.database.connection else: diff --git a/neutron_vpnaas/db/vpn/vpn_agentschedulers_db.py b/neutron_vpnaas/db/vpn/vpn_agentschedulers_db.py index a0dca64cb..7e96cf1ba 100644 --- a/neutron_vpnaas/db/vpn/vpn_agentschedulers_db.py +++ b/neutron_vpnaas/db/vpn/vpn_agentschedulers_db.py @@ -13,9 +13,12 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import typing as ty import random +from neutron.extensions import agent as nagent +from neutron.extensions import l3 from neutron.extensions import router_availability_zone as router_az from neutron import worker as neutron_worker from neutron_lib import context as ncontext @@ -31,8 +34,10 @@ import sqlalchemy as sa from sqlalchemy import func from neutron_vpnaas._i18n import _ +from neutron_vpnaas.api.rpc.agentnotifiers import vpn_rpc_agent_api as nfy_api from neutron_vpnaas.db.vpn import vpn_models from neutron_vpnaas.extensions import vpn_agentschedulers +from neutron_vpnaas.scheduler.vpn_agent_scheduler import VPNScheduler from neutron_vpnaas.services.vpn.common.constants import AGENT_TYPE_VPN @@ -71,15 +76,15 @@ class VPNAgentSchedulerDbMixin( using the VPN agent. """ - vpn_scheduler = None - agent_notifiers = {} + vpn_scheduler: ty.Optional[VPNScheduler] = None + agent_notifiers: ty.Dict[str, nfy_api.VPNAgentNotifyAPI] = {} @property - def l3_plugin(self): + def l3_plugin(self) -> l3.RouterPluginBase: return directory.get_plugin(plugin_const.L3) @property - def core_plugin(self): + def core_plugin(self) -> nagent.AgentPluginBase: return directory.get_plugin() def add_periodic_vpn_agent_status_check(self): @@ -96,7 +101,7 @@ class VPNAgentSchedulerDbMixin( check_worker = neutron_worker.PeriodicWorker( self.reschedule_vpnservices_from_down_agents, interval, initial_delay) - self.add_worker(check_worker) + self.add_worker(check_worker) # type: ignore def reschedule_vpnservices_from_down_agents(self): """Reschedule VPN services from down VPN agents. @@ -111,9 +116,10 @@ class VPNAgentSchedulerDbMixin( for binding in down_bindings: if binding.vpn_agent_id in agents_back_online: continue - agent = self.core_plugin.get_agent(context, + agent: ty.Optional[nagent.Agent] = self.core_plugin.get_agent( + context, binding.vpn_agent_id) - if agent['alive']: + if agent and agent['alive']: agents_back_online.add(binding.vpn_agent_id) continue @@ -137,7 +143,8 @@ class VPNAgentSchedulerDbMixin( "rescheduling.") @db_api.CONTEXT_READER - def get_down_router_bindings(self, context): + def get_down_router_bindings(self, + context: ncontext.Context) -> ty.List[RouterVPNAgentBinding]: vpn_agents = self.get_vpn_agents(context, active=False) if not vpn_agents: return [] @@ -148,7 +155,9 @@ class VPNAgentSchedulerDbMixin( RouterVPNAgentBinding.vpn_agent_id.in_(vpn_agent_ids)) return query.all() - def validate_agent_router_combination(self, context, agent, router): + def validate_agent_router_combination(self, context: ncontext.ContextBase, + agent: nagent.Agent, + router: ty.Dict[str, ty.Any]): """Validate if the router can be correctly assigned to the agent. :raises: InvalidVPNAgent if attempting to assign router to an @@ -158,7 +167,9 @@ class VPNAgentSchedulerDbMixin( raise vpn_agentschedulers.InvalidVPNAgent(id=agent['id']) @db_api.CONTEXT_READER - def check_agent_router_scheduling_needed(self, context, agent, router): + def check_agent_router_scheduling_needed(self, context: ncontext.Context, + agent: ty.Dict[str, ty.Any], + router: ty.Dict[str, ty.Any]): """Check if the scheduling of router's VPN services is needed. :raises: RouterHostedByVPNAgent if router is already assigned @@ -180,7 +191,8 @@ class VPNAgentSchedulerDbMixin( router_id=router_id, agent_id=bindings[0].vpn_agent_id) - def create_router_to_agent_binding(self, context, router_id, agent_id): + def create_router_to_agent_binding(self, context: ncontext.Context, + router_id: str, agent_id: str): """Create router to VPN agent binding.""" try: with db_api.CONTEXT_WRITER.using(context): @@ -203,11 +215,12 @@ class VPNAgentSchedulerDbMixin( {'router_id': router_id, 'agent_id': agent_id}) return True - def add_router_to_vpn_agent(self, context, agent_id, router_id): + def add_router_to_vpn_agent(self, context: ncontext.Context, + agent_id: str, router_id: str): """Add a VPN agent to host VPN services of a router.""" with db_api.CONTEXT_WRITER.using(context): router = self.l3_plugin.get_router(context, router_id) - agent = self.core_plugin.get_agent(context, agent_id) + agent: nagent.Agent = self.core_plugin.get_agent(context, agent_id) self.validate_agent_router_combination(context, agent, router) if not self.check_agent_router_scheduling_needed( context, agent, router): @@ -232,7 +245,8 @@ class VPNAgentSchedulerDbMixin( self.vpn_router_agent_binding_changed( context, router_id, agent['host']) - def remove_router_from_vpn_agent(self, context, agent_id, router_id): + def remove_router_from_vpn_agent(self, context: ncontext.Context, + agent_id: str, router_id: str): """Remove the router from VPN agent. After removal, the VPN service(s) of the router will be non-hosted @@ -248,7 +262,8 @@ class VPNAgentSchedulerDbMixin( vpn_notifier.vpnservice_removed_from_agent( context, router_id, agent['host']) - def _unbind_router(self, context, router_id, agent_id): + def _unbind_router(self, context: ncontext.Context, + router_id: str, agent_id: str): with db_api.CONTEXT_WRITER.using(context): query = context.session.query(RouterVPNAgentBinding) query = query.filter( @@ -256,7 +271,8 @@ class VPNAgentSchedulerDbMixin( RouterVPNAgentBinding.vpn_agent_id == agent_id) return query.delete() - def reschedule_router(self, context, router_id, cur_agent): + def reschedule_router(self, context: ncontext.Context, router_id: str, + cur_agent: nagent.Agent): """Reschedule router to a new VPN agent Remove the router from the agent currently hosting it and @@ -282,8 +298,10 @@ class VPNAgentSchedulerDbMixin( self.vpn_router_agent_binding_changed( context, router_id, new_agent['host']) - def _notify_agents_router_rescheduled(self, context, router_id, - old_agent, new_agent): + def _notify_agents_router_rescheduled(self, context: ncontext.Context, + router_id: str, + old_agent: ty.Dict[str, ty.Any], + new_agent: ty.Dict[str, ty.Any]): vpn_notifier = self.agent_notifiers.get(AGENT_TYPE_VPN) if not vpn_notifier: return @@ -303,7 +321,8 @@ class VPNAgentSchedulerDbMixin( router_id=router_id) @db_api.CONTEXT_READER - def list_routers_on_vpn_agent(self, context, agent_id): + def list_routers_on_vpn_agent(self, context: ncontext.Context, + agent_id: str): query = context.session.query(RouterVPNAgentBinding.router_id) query = query.filter(RouterVPNAgentBinding.vpn_agent_id == agent_id) @@ -312,13 +331,15 @@ class VPNAgentSchedulerDbMixin( return {'routers': self.l3_plugin.get_routers(context, filters={'id': router_ids})} - else: - # Exception will be thrown if the requested agent does not exist. - self.core_plugin.get_agent(context, agent_id) - return {'routers': []} + + # Exception will be thrown if the requested agent does not exist. + self.core_plugin.get_agent(context, agent_id) + return {'routers': []} @db_api.CONTEXT_READER - def get_vpn_agents_hosting_routers(self, context, router_ids, active=None): + def get_vpn_agents_hosting_routers(self, context: ncontext.Context, + router_ids: ty.Optional[ty.List[str]], + active: ty.Optional[bool] = None): if not router_ids: return [] query = context.session.query(RouterVPNAgentBinding) @@ -332,29 +353,39 @@ class VPNAgentSchedulerDbMixin( if agent['alive'] == active] return vpn_agents - def list_vpn_agents_hosting_router(self, context, router_id): + def list_vpn_agents_hosting_router(self, context: ncontext.Context, + router_id: str): vpn_agents = self.get_vpn_agents_hosting_routers(context, [router_id]) return {'agents': vpn_agents} - def get_vpn_agents(self, context, active=None, host=None): + def get_vpn_agents(self, context: ncontext.Context, + active: ty.Optional[bool] = None, + host: ty.Optional[str] = None) -> \ + ty.Optional[ty.List[nagent.Agent]]: filters = {'agent_type': [AGENT_TYPE_VPN]} if host is not None: filters['host'] = [host] - vpn_agents = self.core_plugin.get_agents(context, filters=filters) + vpn_agents: ty.Optional[ty.List[nagent.Agent]] = \ + self.core_plugin.get_agents(context, filters=filters) if active is None: return vpn_agents - else: - return [vpn_agent - for vpn_agent in vpn_agents - if vpn_agent['alive'] == active] - def get_vpn_agent_on_host(self, context, host, active=None): + if not vpn_agents: + return None + + return [vpn_agent + for vpn_agent in vpn_agents + if vpn_agent['alive'] == active] + + def get_vpn_agent_on_host(self, context: ncontext.Context, + host: str, active: ty.Optional[bool] = None): agents = self.get_vpn_agents(context, active=active, host=host) if agents: return agents[0] @db_api.CONTEXT_READER - def get_unscheduled_vpn_routers(self, context, router_ids=None): + def get_unscheduled_vpn_routers(self, context: ncontext.Context, + router_ids: ty.Optional[ty.List[str]] = None): """Get IDs of routers which have unscheduled VPN services.""" query = context.session.query(vpn_models.VPNService.router_id) query = query.outerjoin( @@ -366,12 +397,14 @@ class VPNAgentSchedulerDbMixin( vpn_models.VPNService.router_id.in_(router_ids)) return [router_id for router_id, in query.all()] - def auto_schedule_routers(self, context, vpn_agent): + def auto_schedule_routers(self, context: ncontext.Context, vpn_agent): if self.vpn_scheduler: return self.vpn_scheduler.auto_schedule_routers( self, context, vpn_agent) - def schedule_router(self, context, router, candidates=None): + def schedule_router(self, context: ncontext.Context, + router, candidates: ty.Optional[ty.List] = None) -> \ + ty.Optional[nagent.Agent]: """Schedule VPN services of a router to a VPN agent. Returns the chosen agent; None if another server scheduled the @@ -381,9 +414,11 @@ class VPNAgentSchedulerDbMixin( if self.vpn_scheduler: return self.vpn_scheduler.schedule( self, context, router, candidates=candidates) + return None @db_api.CONTEXT_READER - def get_vpn_agent_with_min_routers(self, context, agent_ids): + def get_vpn_agent_with_min_routers(self, context: ncontext.Context, + agent_ids: ty.Optional[ty.List[str]]): """Return VPN agent with the least number of routers.""" if not agent_ids: return None @@ -397,15 +432,18 @@ class VPNAgentSchedulerDbMixin( unused_agent_ids = set(agent_ids) - set(used_agent_ids) if unused_agent_ids: return unused_agent_ids.pop() - else: - return used_agent_ids[0] + return used_agent_ids[0] - def get_hosts_to_notify(self, context, router_id): + def get_hosts_to_notify(self, context: ncontext.Context, router_id): """Returns all hosts to send notification about router update""" agents = self.get_vpn_agents_hosting_routers(context, [router_id], active=True) return [a['host'] for a in agents] + def vpn_router_agent_binding_changed(self, context: ncontext.Context, + router_id: str, host: str): + pass + class AZVPNAgentSchedulerDbMixin(VPNAgentSchedulerDbMixin, router_az.RouterAvailabilityZonePluginBase): diff --git a/neutron_vpnaas/db/vpn/vpn_db.py b/neutron_vpnaas/db/vpn/vpn_db.py index f7f3257a5..d383c53d9 100644 --- a/neutron_vpnaas/db/vpn/vpn_db.py +++ b/neutron_vpnaas/db/vpn/vpn_db.py @@ -13,12 +13,14 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import typing as ty from neutron.db import models_v2 from neutron_lib.callbacks import events from neutron_lib.callbacks import registry from neutron_lib.callbacks import resources from neutron_lib import constants as lib_constants +from neutron_lib import context from neutron_lib.db import api as db_api from neutron_lib.db import model_query from neutron_lib.db import utils as db_utils @@ -58,12 +60,13 @@ class VPNPluginDb(vpnaas.VPNPluginBase, """ return vpn_validator.VpnReferenceValidator() - def update_status(self, context, model, v_id, status): + def update_status(self, context: context.Context, model, v_id: str, + status: str): with db_api.CONTEXT_WRITER.using(context): v_db = self._get_resource(context, model, v_id) v_db.update({'status': status}) - def _get_resource(self, context, model, v_id): + def _get_resource(self, context: context.Context, model, v_id): try: r = model_query.get_by_id(context, model, v_id) except exc.NoResultFound: @@ -91,7 +94,9 @@ class VPNPluginDb(vpnaas.VPNPluginBase, if utils.in_pending_status(status): raise vpn_exception.VPNStateInvalidToUpdate(id=_id, state=status) - def _make_ipsec_site_connection_dict(self, ipsec_site_conn, fields=None): + def _make_ipsec_site_connection_dict(self, + ipsec_site_conn: ty.Dict[str, ty.Any], + fields=None): res = {'id': ipsec_site_conn['id'], 'tenant_id': ipsec_site_conn['tenant_id'], @@ -123,15 +128,17 @@ class VPNPluginDb(vpnaas.VPNPluginBase, return db_utils.resource_fields(res, fields) - def get_endpoint_info(self, context, ipsec_sitecon): + def get_endpoint_info(self, context: context.Context, ipsec_sitecon): """Obtain all endpoint info, and store in connection for validation.""" ipsec_sitecon['local_epg_subnets'] = self.get_endpoint_group( context, ipsec_sitecon['local_ep_group_id']) ipsec_sitecon['peer_epg_cidrs'] = self.get_endpoint_group( context, ipsec_sitecon['peer_ep_group_id']) - def validate_connection_info(self, context, validator, ipsec_sitecon, - vpnservice): + def validate_connection_info( + self, context: context.Context, + validator: vpn_validator.VpnReferenceValidator, ipsec_sitecon, + vpnservice): """Collect info and validate connection. If endpoint groups used (default), collect the group info and @@ -150,7 +157,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, validator.validate_ipsec_site_connection(context, ipsec_sitecon, ip_version, vpnservice) - def create_ipsec_site_connection(self, context, ipsec_site_connection): + def create_ipsec_site_connection(self, context: context.Context, + ipsec_site_connection: ty.Dict[str, ty.Dict]): ipsec_sitecon = ipsec_site_connection['ipsec_site_connection'] validator = self._get_validator() validator.assign_sensible_ipsec_sitecon_defaults(ipsec_sitecon) @@ -203,7 +211,7 @@ class VPNPluginDb(vpnaas.VPNPluginBase, return self._make_ipsec_site_connection_dict(ipsec_site_conn_db) def update_ipsec_site_connection( - self, context, + self, context: context.Context, ipsec_site_conn_id, ipsec_site_connection): ipsec_sitecon = ipsec_site_connection['ipsec_site_connection'] changed_peer_cidrs = False @@ -252,36 +260,41 @@ class VPNPluginDb(vpnaas.VPNPluginBase, result['peer_cidrs'] = new_peer_cidrs return result - def delete_ipsec_site_connection(self, context, ipsec_site_conn_id): + def delete_ipsec_site_connection(self, context: context.Context, + ipsec_site_conn_id: str): with db_api.CONTEXT_WRITER.using(context): ipsec_site_conn_db = self._get_resource( context, vpn_models.IPsecSiteConnection, ipsec_site_conn_id) context.session.delete(ipsec_site_conn_db) - def _get_ipsec_site_connection( - self, context, ipsec_site_conn_id): + def _get_ipsec_site_connection(self, context: context.Context, + ipsec_site_conn_id: str) -> vpn_models.IPsecSiteConnection: return self._get_resource( context, vpn_models.IPsecSiteConnection, ipsec_site_conn_id) - def get_ipsec_site_connection(self, context, - ipsec_site_conn_id, fields=None): + def get_ipsec_site_connection(self, context: context.Context, + ipsec_site_conn_id: str, fields=None): ipsec_site_conn_db = self._get_ipsec_site_connection( context, ipsec_site_conn_id) return self._make_ipsec_site_connection_dict( ipsec_site_conn_db, fields) - def get_ipsec_site_connections(self, context, filters=None, fields=None): + def get_ipsec_site_connections(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, + fields=None) -> ty.List[vpn_models.IPsecSiteConnection]: return model_query.get_collection( context, vpn_models.IPsecSiteConnection, self._make_ipsec_site_connection_dict, filters=filters, fields=fields) - def update_ipsec_site_conn_status(self, context, conn_id, new_status): + def update_ipsec_site_conn_status(self, context: context.Context, + conn_id: str, new_status: str): with db_api.CONTEXT_WRITER.using(context): self._update_connection_status(context, conn_id, new_status, True) - def _update_connection_status(self, context, conn_id, new_status, - updated_pending): + def _update_connection_status(self, context: context.Context, + conn_id: str, new_status: str, + updated_pending: bool): """Update the connection status, if changed. If the connection is not in a pending state, unconditionally update @@ -295,7 +308,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, if not utils.in_pending_status(conn_db.status) or updated_pending: conn_db.status = new_status - def _make_ikepolicy_dict(self, ikepolicy, fields=None): + def _make_ikepolicy_dict(self, ikepolicy: ty.Dict[str, ty.Any], + fields=None) -> ty.Dict[str, ty.Any]: res = {'id': ikepolicy['id'], 'tenant_id': ikepolicy['tenant_id'], 'name': ikepolicy['name'], @@ -313,10 +327,11 @@ class VPNPluginDb(vpnaas.VPNPluginBase, return db_utils.resource_fields(res, fields) - def create_ikepolicy(self, context, ikepolicy): + def create_ikepolicy(self, context: context.Context, + ikepolicy: ty.Dict[str, ty.Dict[str, ty.Any]]): ike = ikepolicy['ikepolicy'] validator = self._get_validator() - lifetime_info = ike['lifetime'] + lifetime_info: ty.Dict[str, ty.Any] = ike['lifetime'] lifetime_units = lifetime_info.get('units', 'seconds') lifetime_value = lifetime_info.get('value', 3600) @@ -339,7 +354,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.add(ike_db) return self._make_ikepolicy_dict(ike_db) - def update_ikepolicy(self, context, ikepolicy_id, ikepolicy): + def update_ikepolicy(self, context: context.Context, ikepolicy_id: str, + ikepolicy: ty.Dict[str, ty.Dict]): ike = ikepolicy['ikepolicy'] validator = self._get_validator() with db_api.CONTEXT_WRITER.using(context): @@ -359,7 +375,7 @@ class VPNPluginDb(vpnaas.VPNPluginBase, ike_db.update(ike) return self._make_ikepolicy_dict(ike_db) - def delete_ikepolicy(self, context, ikepolicy_id): + def delete_ikepolicy(self, context: context.Context, ikepolicy_id: str): with db_api.CONTEXT_WRITER.using(context): if context.session.query(vpn_models.IPsecSiteConnection).filter_by( ikepolicy_id=ikepolicy_id).first(): @@ -369,19 +385,21 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.delete(ike_db) @db_api.CONTEXT_READER - def get_ikepolicy(self, context, ikepolicy_id, fields=None): + def get_ikepolicy(self, context: context.Context, ikepolicy_id: str, + fields=None): ike_db = self._get_resource( context, vpn_models.IKEPolicy, ikepolicy_id) return self._make_ikepolicy_dict(ike_db, fields) @db_api.CONTEXT_READER - def get_ikepolicies(self, context, filters=None, fields=None): + def get_ikepolicies(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, fields=None): return model_query.get_collection(context, vpn_models.IKEPolicy, self._make_ikepolicy_dict, filters=filters, fields=fields) - def _make_ipsecpolicy_dict(self, ipsecpolicy, fields=None): - + def _make_ipsecpolicy_dict(self, ipsecpolicy: ty.Dict[str, ty.Any], + fields=None) -> ty.Dict[str, ty.Any]: res = {'id': ipsecpolicy['id'], 'tenant_id': ipsecpolicy['tenant_id'], 'name': ipsecpolicy['name'], @@ -399,10 +417,11 @@ class VPNPluginDb(vpnaas.VPNPluginBase, return db_utils.resource_fields(res, fields) - def create_ipsecpolicy(self, context, ipsecpolicy): + def create_ipsecpolicy(self, context: context.Context, + ipsecpolicy: ty.Dict[str, ty.Dict[str, ty.Any]]): ipsecp = ipsecpolicy['ipsecpolicy'] validator = self._get_validator() - lifetime_info = ipsecp['lifetime'] + lifetime_info: ty.Dict[str, ty.Any] = ipsecp['lifetime'] lifetime_units = lifetime_info.get('units', 'seconds') lifetime_value = lifetime_info.get('value', 3600) @@ -423,7 +442,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.add(ipsecp_db) return self._make_ipsecpolicy_dict(ipsecp_db) - def update_ipsecpolicy(self, context, ipsecpolicy_id, ipsecpolicy): + def update_ipsecpolicy(self, context: context.Context, ipsecpolicy_id: str, + ipsecpolicy: ty.Dict[str, ty.Dict[str, ty.Any]]): ipsecp = ipsecpolicy['ipsecpolicy'] validator = self._get_validator() with db_api.CONTEXT_WRITER.using(context): @@ -444,7 +464,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, ipsecp_db.update(ipsecp) return self._make_ipsecpolicy_dict(ipsecp_db) - def delete_ipsecpolicy(self, context, ipsecpolicy_id): + def delete_ipsecpolicy(self, context: context.Context, + ipsecpolicy_id: str): with db_api.CONTEXT_WRITER.using(context): if context.session.query(vpn_models.IPsecSiteConnection).filter_by( ipsecpolicy_id=ipsecpolicy_id).first(): @@ -455,18 +476,21 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.delete(ipsec_db) @db_api.CONTEXT_READER - def get_ipsecpolicy(self, context, ipsecpolicy_id, fields=None): + def get_ipsecpolicy(self, context: context.Context, + ipsecpolicy_id: str, fields=None): ipsec_db = self._get_resource( context, vpn_models.IPsecPolicy, ipsecpolicy_id) return self._make_ipsecpolicy_dict(ipsec_db, fields) @db_api.CONTEXT_READER - def get_ipsecpolicies(self, context, filters=None, fields=None): + def get_ipsecpolicies(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, fields=None): return model_query.get_collection(context, vpn_models.IPsecPolicy, self._make_ipsecpolicy_dict, filters=filters, fields=fields) - def _make_vpnservice_dict(self, vpnservice, fields=None): + def _make_vpnservice_dict(self, vpnservice: ty.Dict[str, ty.Any], + fields=None) -> ty.Dict[str, ty.Any]: res = {'id': vpnservice['id'], 'name': vpnservice['name'], 'description': vpnservice['description'], @@ -480,7 +504,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, 'status': vpnservice['status']} return db_utils.resource_fields(res, fields) - def create_vpnservice(self, context, vpnservice): + def create_vpnservice(self, context: context.Context, + vpnservice: ty.Dict[str, ty.Dict[str, ty.Any]]): vpns = vpnservice['vpnservice'] flavor_id = vpns.get('flavor_id', None) validator = self._get_validator() @@ -499,8 +524,10 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.add(vpnservice_db) return self._make_vpnservice_dict(vpnservice_db) - def set_external_tunnel_ips(self, context, vpnservice_id, v4_ip=None, - v6_ip=None): + def set_external_tunnel_ips(self, context: context.Context, + vpnservice_id: str, + v4_ip: ty.Optional[str] = None, + v6_ip: ty.Optional[str] = None): """Update the external tunnel IP(s) for service.""" vpns = {'external_v4_ip': v4_ip, 'external_v6_ip': v6_ip} with db_api.CONTEXT_WRITER.using(context): @@ -509,20 +536,22 @@ class VPNPluginDb(vpnaas.VPNPluginBase, vpns_db.update(vpns) return self._make_vpnservice_dict(vpns_db) - def set_vpnservice_status(self, context, vpnservice_id, status, - updated_pending_status=False): + def set_vpnservice_status(self, context: context.Context, + vpnservice_id: str, status: str, + updated_pending_status: bool = False): vpns = {'status': status} with db_api.CONTEXT_WRITER.using(context): vpns_db = self._get_resource(context, vpn_models.VPNService, vpnservice_id) if (utils.in_pending_status(vpns_db.status) and not updated_pending_status): - raise vpnaas.VPNStateInvalidToUpdate( + raise vpnaas.VPNStateInvalidToUpdate( # type: ignore id=vpnservice_id, state=vpns_db.status) vpns_db.update(vpns) return self._make_vpnservice_dict(vpns_db) - def update_vpnservice(self, context, vpnservice_id, vpnservice): + def update_vpnservice(self, context: context.Context, vpnservice_id: str, + vpnservice: ty.Dict[str, ty.Optional[ty.Dict]]): vpns = vpnservice['vpnservice'] with db_api.CONTEXT_WRITER.using(context): vpns_db = self._get_resource(context, vpn_models.VPNService, @@ -532,7 +561,7 @@ class VPNPluginDb(vpnaas.VPNPluginBase, vpns_db.update(vpns) return self._make_vpnservice_dict(vpns_db) - def delete_vpnservice(self, context, vpnservice_id): + def delete_vpnservice(self, context: context.Context, vpnservice_id: str): with db_api.CONTEXT_WRITER.using(context): if context.session.query(vpn_models.IPsecSiteConnection).filter_by( vpnservice_id=vpnservice_id @@ -544,23 +573,26 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.delete(vpns_db) @db_api.CONTEXT_READER - def _get_vpnservice(self, context, vpnservice_id): + def _get_vpnservice(self, context: context.Context, vpnservice_id: str): return self._get_resource(context, vpn_models.VPNService, vpnservice_id) @db_api.CONTEXT_READER - def get_vpnservice(self, context, vpnservice_id, fields=None): + def get_vpnservice(self, context: context.Context, vpnservice_id: str, + fields=None): vpns_db = self._get_resource(context, vpn_models.VPNService, vpnservice_id) return self._make_vpnservice_dict(vpns_db, fields) @db_api.CONTEXT_READER - def get_vpnservices(self, context, filters=None, fields=None): + def get_vpnservices(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, + fields=None) -> ty.List: return model_query.get_collection(context, vpn_models.VPNService, self._make_vpnservice_dict, filters=filters, fields=fields) - def check_router_in_use(self, context, router_id): + def check_router_in_use(self, context: context.Context, router_id: str): vpnservices = self.get_vpnservices( context, filters={'router_id': [router_id]}) if vpnservices: @@ -572,7 +604,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, "(%(services)s)" % {'plural': plural, 'services': services}) - def check_subnet_in_use(self, context, subnet_id, router_id): + def check_subnet_in_use(self, context: context.Context, subnet_id: str, + router_id: str): with db_api.CONTEXT_READER.using(context): vpnservices = context.session.query( vpn_models.VPNService).filter_by( @@ -605,7 +638,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, subnet_id=subnet_id, ipsec_site_connection_id=connection['id']) - def check_subnet_in_use_by_endpoint_group(self, context, subnet_id): + def check_subnet_in_use_by_endpoint_group(self, context: context.Context, + subnet_id: str): with db_api.CONTEXT_READER.using(context): query = context.session.query(vpn_models.VPNEndpointGroup) query = query.filter(vpn_models.VPNEndpointGroup.endpoint_type == @@ -620,7 +654,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, raise vpn_exception.SubnetInUseByEndpointGroup( subnet_id=subnet_id, group_id=group['id']) - def _make_endpoint_group_dict(self, endpoint_group, fields=None): + def _make_endpoint_group_dict(self, endpoint_group: ty.Dict[str, ty.Any], + fields: ty.Optional[ty.Dict] = None) -> ty.Dict: res = {'id': endpoint_group['id'], 'tenant_id': endpoint_group['tenant_id'], 'name': endpoint_group['name'], @@ -630,7 +665,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, for ep in endpoint_group['endpoints']]} return db_utils.resource_fields(res, fields) - def create_endpoint_group(self, context, endpoint_group): + def create_endpoint_group(self, context: context.Context, + endpoint_group: ty.Dict[str, ty.Dict[str, ty.Any]]): group = endpoint_group['endpoint_group'] validator = self._get_validator() with db_api.CONTEXT_WRITER.using(context): @@ -650,8 +686,9 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.add(endpoint_db) return self._make_endpoint_group_dict(endpoint_group_db) - def update_endpoint_group(self, context, endpoint_group_id, - endpoint_group): + def update_endpoint_group(self, context: context.Context, + endpoint_group_id: str, + endpoint_group: ty.Dict[str, ty.Dict[str, ty.Any]]): group_changes = endpoint_group['endpoint_group'] # Note: Endpoints cannot be changed, so will not do validation with db_api.CONTEXT_WRITER.using(context): @@ -661,7 +698,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, endpoint_group_db.update(group_changes) return self._make_endpoint_group_dict(endpoint_group_db) - def delete_endpoint_group(self, context, endpoint_group_id): + def delete_endpoint_group(self, context: context.Context, + endpoint_group_id: str): with db_api.CONTEXT_WRITER.using(context): self.check_endpoint_group_not_in_use(context, endpoint_group_id) endpoint_group_db = self._get_resource( @@ -669,18 +707,21 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.delete(endpoint_group_db) @db_api.CONTEXT_READER - def get_endpoint_group(self, context, endpoint_group_id, fields=None): + def get_endpoint_group(self, context: context.Context, + endpoint_group_id: str, fields=None): endpoint_group_db = self._get_resource( context, vpn_models.VPNEndpointGroup, endpoint_group_id) return self._make_endpoint_group_dict(endpoint_group_db, fields) @db_api.CONTEXT_READER - def get_endpoint_groups(self, context, filters=None, fields=None): + def get_endpoint_groups(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, fields=None): return model_query.get_collection(context, vpn_models.VPNEndpointGroup, self._make_endpoint_group_dict, filters=filters, fields=fields) - def check_endpoint_group_not_in_use(self, context, group_id): + def check_endpoint_group_not_in_use(self, context: context.Context, + group_id: str): query = context.session.query(vpn_models.IPsecSiteConnection) query = query.filter( sa.or_( @@ -690,13 +731,15 @@ class VPNPluginDb(vpnaas.VPNPluginBase, if query.first(): raise vpn_exception.EndpointGroupInUse(group_id=group_id) - def get_vpnservice_router_id(self, context, vpnservice_id): + def get_vpnservice_router_id(self, context: context.Context, + vpnservice_id: str): with db_api.CONTEXT_READER.using(context): vpnservice = self._get_vpnservice(context, vpnservice_id) return vpnservice['router_id'] @db_api.CONTEXT_READER - def get_peer_cidrs_for_router(self, context, router_id): + def get_peer_cidrs_for_router(self, context: context.Context, + router_id: str): filters = {'router_id': [router_id]} vpnservices = model_query.get_collection_query( context, vpn_models.VPNService, filters=filters).all() @@ -712,8 +755,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, return cidrs -class VPNPluginRpcDbMixin(object): - def _build_local_subnet_cidr_map(self, context): +class VPNPluginRpcDbMixin(VPNPluginDb): + def _build_local_subnet_cidr_map(self, context: context.Context): """Build a dict of all local endpoint subnets, with list of CIDRs.""" query = context.session.query(models_v2.Subnet.id, models_v2.Subnet.cidr) @@ -728,7 +771,8 @@ class VPNPluginRpcDbMixin(object): vpn_models.VPNEndpointGroup.id) return {sn.id: sn.cidr for sn in query.all()} - def update_status_by_agent(self, context, service_status_info_list): + def update_status_by_agent(self, context: context.Context, + service_status_info_list): """Updating vpnservice and vpnconnection status. :param context: context variable @@ -768,7 +812,8 @@ class VPNPluginRpcDbMixin(object): def vpn_router_gateway_callback(resource, event, trigger, payload=None): # the event payload objects - vpn_plugin = directory.get_plugin(p_constants.VPN) + vpn_plugin: ty.Optional[VPNPluginDb] = \ + directory.get_plugin(p_constants.VPN) if vpn_plugin: context = payload.context router_id = payload.resource_id @@ -782,7 +827,8 @@ def vpn_router_gateway_callback(resource, event, trigger, payload=None): def migration_callback(resource, event, trigger, payload): context = payload.context router = payload.latest_state - vpn_plugin = directory.get_plugin(p_constants.VPN) + vpn_plugin: ty.Optional[VPNPluginDb] = \ + directory.get_plugin(p_constants.VPN) if vpn_plugin: vpn_plugin.check_router_in_use(context, router['id']) return True @@ -792,7 +838,8 @@ def subnet_callback(resource, event, trigger, payload=None): """Respond to subnet based notifications - see if subnet in use.""" context = payload.context subnet_id = payload.resource_id - vpn_plugin = directory.get_plugin(p_constants.VPN) + vpn_plugin: ty.Optional[VPNPluginDb] = \ + directory.get_plugin(p_constants.VPN) if vpn_plugin: vpn_plugin.check_subnet_in_use_by_endpoint_group(context, subnet_id) diff --git a/neutron_vpnaas/db/vpn/vpn_ext_gw_db.py b/neutron_vpnaas/db/vpn/vpn_ext_gw_db.py index 3b3e26342..4a6715c47 100644 --- a/neutron_vpnaas/db/vpn/vpn_ext_gw_db.py +++ b/neutron_vpnaas/db/vpn/vpn_ext_gw_db.py @@ -13,6 +13,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import typing as ty from neutron.db.models import l3 as l3_models from neutron.db import models_v2 @@ -20,6 +21,7 @@ from neutron_lib.callbacks import events from neutron_lib.callbacks import registry from neutron_lib.callbacks import resources from neutron_lib import constants as lib_constants +from neutron_lib import context from neutron_lib.db import api as db_api from neutron_lib.db import model_base from neutron_lib.db import model_query @@ -33,8 +35,15 @@ from sqlalchemy import orm from sqlalchemy.orm import exc from neutron_vpnaas._i18n import _ +from neutron_vpnaas.db.vpn import vpn_db from neutron_vpnaas.services.vpn.common import constants as v_constants +#pylint: disable=ungrouped-imports +# Additional import for typechecking. Importing these without typechecking +# would resolve in a cyclic dependency +if ty.TYPE_CHECKING: + from neutron.db import db_base_plugin_v2 as db_plugin +#pylint: enable=ungrouped-imports LOG = logging.getLogger(__name__) @@ -80,11 +89,11 @@ class VPNExtGW(model_base.BASEV2, model_base.HasId, model_base.HasProject): @registry.has_registry_receivers -class VPNExtGWPlugin_db(object): +class VPNExtGWPlugin_db(vpn_db.VPNPluginDb): """DB class to support vpn external ports configuration.""" @property - def _core_plugin(self): + def _core_plugin(self) -> 'db_plugin.NeutronDbPluginV2': return directory.get_plugin() @property @@ -95,13 +104,13 @@ class VPNExtGWPlugin_db(object): @registry.receives(resources.PORT, [events.BEFORE_DELETE]) def _prevent_vpn_port_delete_callback(resource, event, trigger, payload=None): - vpn_plugin = directory.get_plugin(plugin_const.VPN) + vpn_plugin: VPNExtGWPlugin_db = directory.get_plugin(plugin_const.VPN) if vpn_plugin: vpn_plugin.prevent_vpn_port_deletion(payload.context, payload.resource_id) @db_api.CONTEXT_READER - def _id_used(self, context, id_column, resource_id): + def _id_used(self, context: context.Context, id_column, resource_id): return context.session.query(VPNExtGW).filter( sa.and_( id_column == resource_id, @@ -109,7 +118,8 @@ class VPNExtGWPlugin_db(object): ) ).count() > 0 - def prevent_vpn_port_deletion(self, context, port_id): + def prevent_vpn_port_deletion(self, context: context.Context, + port_id: str): """Checks to make sure a port is allowed to be deleted. Raises an exception if this is not the case. This should be called by @@ -124,7 +134,7 @@ class VPNExtGWPlugin_db(object): # non-existent ports don't need to be protected from deletion return - port_id_column = { + port_id_column: ty.Optional[str] = { v_constants.DEVICE_OWNER_VPN_ROUTER_GW: VPNExtGW.gw_port_id, v_constants.DEVICE_OWNER_TRANSIT_NETWORK: VPNExtGW.transit_port_id, @@ -142,12 +152,13 @@ class VPNExtGWPlugin_db(object): @registry.receives(resources.SUBNET, [events.BEFORE_DELETE]) def _prevent_vpn_subnet_delete_callback(resource, event, trigger, payload=None): - vpn_plugin = directory.get_plugin(plugin_const.VPN) + vpn_plugin: VPNExtGWPlugin_db = directory.get_plugin(plugin_const.VPN) if vpn_plugin: vpn_plugin.prevent_vpn_subnet_deletion(payload.context, payload.resource_id) - def prevent_vpn_subnet_deletion(self, context, subnet_id): + def prevent_vpn_subnet_deletion(self, context: context.Context, + subnet_id: str): if self._id_used(context, VPNExtGW.transit_subnet_id, subnet_id): reason = _('Subnet is used by VPN service') raise n_exc.SubnetInUse(subnet_id=subnet_id, reason=reason) @@ -156,16 +167,18 @@ class VPNExtGWPlugin_db(object): @registry.receives(resources.NETWORK, [events.BEFORE_DELETE]) def _prevent_vpn_network_delete_callback(resource, event, trigger, payload=None): - vpn_plugin = directory.get_plugin(plugin_const.VPN) + vpn_plugin: VPNExtGWPlugin_db = directory.get_plugin(plugin_const.VPN) if vpn_plugin: vpn_plugin.prevent_vpn_network_deletion(payload.context, payload.resource_id) - def prevent_vpn_network_deletion(self, context, network_id): + def prevent_vpn_network_deletion(self, context: context.Context, + network_id: str): if self._id_used(context, VPNExtGW.transit_network_id, network_id): raise VPNNetworkInUse(network_id=network_id) - def _make_vpn_ext_gw_dict(self, gateway_db): + def _make_vpn_ext_gw_dict(self, gateway_db: ty.Optional[VPNExtGW]) -> \ + ty.Optional[ty.Dict[str, ty.Any]]: if not gateway_db: return None gateway = { @@ -187,7 +200,8 @@ class VPNExtGWPlugin_db(object): gateway[key] = value return gateway - def _get_vpn_gw_by_router_id(self, context, router_id): + def _get_vpn_gw_by_router_id(self, context: context.Context, + router_id: str) -> ty.Optional[VPNExtGW]: try: gateway_db = context.session.query(VPNExtGW).filter( VPNExtGW.router_id == router_id).one() @@ -196,17 +210,20 @@ class VPNExtGWPlugin_db(object): return gateway_db @db_api.CONTEXT_READER - def get_vpn_gw_by_router_id(self, context, router_id): + def get_vpn_gw_by_router_id( + self, context: context.Context, router_id: str): return self._get_vpn_gw_by_router_id(context, router_id) @db_api.CONTEXT_READER - def get_vpn_gw_dict_by_router_id(self, context, router_id, refresh=False): + def get_vpn_gw_dict_by_router_id(self, context: context.Context, + router_id: str, refresh: bool = False): gateway_db = self._get_vpn_gw_by_router_id(context, router_id) if gateway_db and refresh: context.session.refresh(gateway_db) return self._make_vpn_ext_gw_dict(gateway_db) - def create_gateway(self, context, gateway): + def create_gateway(self, context: context.Context, + gateway: ty.Dict[str, ty.Dict[str, ty.Any]]): info = gateway['gateway'] with db_api.CONTEXT_WRITER.using(context): @@ -223,14 +240,15 @@ class VPNExtGWPlugin_db(object): return self._make_vpn_ext_gw_dict(gateway_db) - def update_gateway(self, context, gateway_id, gateway): + def update_gateway(self, context: context.Context, gateway_id: str, + gateway: ty.Dict[str, ty.Dict[str, ty.Any]]): info = gateway['gateway'] with db_api.CONTEXT_WRITER.using(context): gateway_db = model_query.get_by_id(context, VPNExtGW, gateway_id) gateway_db.update(info) return self._make_vpn_ext_gw_dict(gateway_db) - def delete_gateway(self, context, gateway_id): + def delete_gateway(self, context: context.Context, gateway_id: str): with db_api.CONTEXT_WRITER.using(context): query = context.session.query(VPNExtGW) return query.filter(VPNExtGW.id == gateway_id).delete() diff --git a/neutron_vpnaas/db/vpn/vpn_validator.py b/neutron_vpnaas/db/vpn/vpn_validator.py index 3366be2fb..0c10f604d 100644 --- a/neutron_vpnaas/db/vpn/vpn_validator.py +++ b/neutron_vpnaas/db/vpn/vpn_validator.py @@ -12,12 +12,18 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty + import socket import netaddr + from neutron.db import l3_db from neutron.db import models_v2 +from neutron import neutron_plugin_base_v2 +from neutron.services.l3_router import l3_router_plugin from neutron_lib.api import validators +from neutron_lib import context from neutron_lib import exceptions as nexception from neutron_lib.exceptions import vpn as vpn_exception from neutron_lib.plugins import constants as plugin_const @@ -27,7 +33,7 @@ from neutron_vpnaas._i18n import _ from neutron_vpnaas.services.vpn.common import constants -class VpnReferenceValidator(object): +class VpnReferenceValidator: """ Baseline validation routines for VPN resources. @@ -38,28 +44,30 @@ class VpnReferenceValidator(object): IP_MIN_MTU = {4: 68, 6: 1280} @property - def l3_plugin(self): + def l3_plugin(self) -> l3_router_plugin.L3RouterPlugin: try: return self._l3_plugin except AttributeError: - self._l3_plugin = directory.get_plugin(plugin_const.L3) + self._l3_plugin: l3_router_plugin.L3RouterPlugin = \ + directory.get_plugin(plugin_const.L3) return self._l3_plugin @property - def core_plugin(self): + def core_plugin(self) -> neutron_plugin_base_v2.NeutronPluginBaseV2: try: return self._core_plugin except AttributeError: - self._core_plugin = directory.get_plugin() + self._core_plugin: neutron_plugin_base_v2.NeutronPluginBaseV2 = \ + directory.get_plugin() return self._core_plugin - def _check_dpd(self, ipsec_sitecon): + def _check_dpd(self, ipsec_sitecon: ty.Dict[str, ty.Union[int, ty.Any]]): """Ensure that DPD timeout is greater than DPD interval.""" if ipsec_sitecon['dpd_timeout'] <= ipsec_sitecon['dpd_interval']: raise vpn_exception.IPsecSiteConnectionDpdIntervalValueError( attr='dpd_timeout') - def _check_mtu(self, context, mtu, ip_version): + def _check_mtu(self, context: context.Context, mtu, ip_version): if mtu < VpnReferenceValidator.IP_MIN_MTU[ip_version]: raise vpn_exception.IPsecSiteConnectionMtuError( mtu=mtu, version=ip_version) @@ -95,7 +103,7 @@ class VpnReferenceValidator(object): ip_version = netaddr.IPAddress(ipsec_sitecon['peer_address']).version self._validate_peer_address(ip_version, router) - def _get_local_subnets(self, context, endpoint_group): + def _get_local_subnets(self, context: context.Context, endpoint_group): if endpoint_group['type'] != constants.SUBNET_ENDPOINT: raise vpn_exception.WrongEndpointGroupType( group_type=endpoint_group['type'], which=endpoint_group['id'], @@ -118,7 +126,7 @@ class VpnReferenceValidator(object): """ if len(local_subnets) == 1: return local_subnets[0]['ip_version'] - ip_versions = set([subnet['ip_version'] for subnet in local_subnets]) + ip_versions = {subnet['ip_version'] for subnet in local_subnets} if len(ip_versions) > 1: raise vpn_exception.MixedIPVersionsForIPSecEndpoints( group=group_id) @@ -131,7 +139,7 @@ class VpnReferenceValidator(object): """ if len(peer_cidrs) == 1: return netaddr.IPNetwork(peer_cidrs[0]).version - ip_versions = set([netaddr.IPNetwork(pc).version for pc in peer_cidrs]) + ip_versions = {netaddr.IPNetwork(pc).version for pc in peer_cidrs} if len(ip_versions) > 1: raise vpn_exception.MixedIPVersionsForIPSecEndpoints( group=group_id) @@ -149,12 +157,13 @@ class VpnReferenceValidator(object): """Ensure all CIDRs have the same IP version.""" if len(peer_cidrs) == 1: return netaddr.IPNetwork(peer_cidrs[0]).version - ip_versions = set([netaddr.IPNetwork(pc).version for pc in peer_cidrs]) + ip_versions = {netaddr.IPNetwork(pc).version for pc in peer_cidrs} if len(ip_versions) > 1: raise vpn_exception.MixedIPVersionsForPeerCidrs() return ip_versions.pop() - def _check_local_subnets_on_router(self, context, router, local_subnets): + def _check_local_subnets_on_router(self, context: context.Context, + router, local_subnets): for subnet in local_subnets: self._check_subnet_id(context, router, subnet['id']) @@ -163,7 +172,9 @@ class VpnReferenceValidator(object): if local_ip_version != peer_ip_version: raise vpn_exception.MixedIPVersionsForIPSecConnection() - def validate_ipsec_conn_optional_args(self, ipsec_sitecon, subnet): + def validate_ipsec_conn_optional_args(self, + ipsec_sitecon: ty.Dict[str, ty.Any], + subnet): """Ensure that proper combinations of optional args are used. When VPN service has a subnet, then we must have peer_cidrs, and @@ -176,10 +187,10 @@ class VpnReferenceValidator(object): local_epg_id = ipsec_sitecon.get('local_ep_group_id') peer_epg_id = ipsec_sitecon.get('peer_ep_group_id') peer_cidrs = ipsec_sitecon.get('peer_cidrs') + epgs: ty.List[str] = [] if subnet: if not peer_cidrs: raise vpn_exception.MissingPeerCidrs() - epgs = [] if local_epg_id: epgs.append('local') if peer_epg_id: @@ -192,7 +203,6 @@ class VpnReferenceValidator(object): else: if peer_cidrs: raise vpn_exception.PeerCidrsInvalid() - epgs = [] if not local_epg_id: epgs.append('local') if not peer_epg_id: @@ -203,8 +213,9 @@ class VpnReferenceValidator(object): raise vpn_exception.MissingRequiredEndpointGroup( which=which, suffix=suffix) - def assign_sensible_ipsec_sitecon_defaults(self, ipsec_sitecon, - prev_conn=None): + def assign_sensible_ipsec_sitecon_defaults(self, + ipsec_sitecon: ty.Dict[str, ty.Any], + prev_conn: ty.Optional[ty.Dict] = None): """Provide defaults for optional items, if missing. With endpoint groups capabilities, the peer_cidr (legacy mode) @@ -229,7 +240,7 @@ class VpnReferenceValidator(object): prev_conn = {'dpd_action': 'hold', 'dpd_interval': 30, 'dpd_timeout': 120} - dpd = ipsec_sitecon.get('dpd', {}) + dpd: ty.Dict[str, ty.Any] = ipsec_sitecon.get('dpd', {}) ipsec_sitecon['dpd_action'] = dpd.get('action', prev_conn['dpd_action']) ipsec_sitecon['dpd_interval'] = dpd.get('interval', @@ -237,8 +248,10 @@ class VpnReferenceValidator(object): ipsec_sitecon['dpd_timeout'] = dpd.get('timeout', prev_conn['dpd_timeout']) - def validate_ipsec_site_connection(self, context, ipsec_sitecon, - local_ip_version, vpnservice=None): + def validate_ipsec_site_connection( + self, context: context.Context, ipsec_sitecon: ty.Dict[str, ty.Any], + local_ip_version, + vpnservice: ty.Optional[ty.Dict[str, ty.Any]] = None): """Reference implementation of validation for IPSec connection. This makes sure that IP versions are the same. For endpoint groups, @@ -254,7 +267,9 @@ class VpnReferenceValidator(object): local_subnets = self._get_local_subnets( context, ipsec_sitecon['local_epg_subnets']) self._check_local_subnets_on_router( - context, vpnservice['router_id'], local_subnets) + context, + vpnservice['router_id'] if vpnservice else '', + local_subnets) local_ip_version = self._check_local_endpoint_ip_versions( ipsec_sitecon['local_ep_group_id'], local_subnets) peer_cidrs = self._get_peer_cidrs(ipsec_sitecon['peer_epg_cidrs']) @@ -272,12 +287,13 @@ class VpnReferenceValidator(object): if mtu: self._check_mtu(context, mtu, local_ip_version) - def _check_router(self, context, router_id): + def _check_router(self, context: context.Context, router_id: str): router = self.l3_plugin.get_router(context, router_id) if not router.get(l3_db.EXTERNAL_GW_INFO): raise vpn_exception.RouterIsNotExternal(router_id=router_id) - def _check_subnet_id(self, context, router_id, subnet_id): + def _check_subnet_id(self, context: context.Context, + router_id: str, subnet_id: str): ports = self.core_plugin.get_ports( context, filters={ @@ -288,13 +304,14 @@ class VpnReferenceValidator(object): subnet_id=subnet_id, router_id=router_id) - def validate_vpnservice(self, context, vpnservice): + def validate_vpnservice(self, context: context.Context, + vpnservice: ty.Dict[str, ty.Any]): self._check_router(context, vpnservice['router_id']) if vpnservice['subnet_id'] is not None: self._check_subnet_id(context, vpnservice['router_id'], vpnservice['subnet_id']) - def validate_ipsec_policy(self, context, ipsec_policy): + def validate_ipsec_policy(self, context: context.Context, ipsec_policy): """Reference implementation of validation for IPSec Policy. Service driver can override and implement specific logic @@ -311,7 +328,8 @@ class VpnReferenceValidator(object): group_type=constants.CIDR_ENDPOINT, endpoint=cidr, why=_("Invalid CIDR")) - def _validate_subnets(self, context, subnet_ids): + def _validate_subnets(self, context: context.Context, + subnet_ids: ty.List[str]): """Ensure UUIDs OK and subnets exist.""" for subnet_id in subnet_ids: msg = validators.validate_uuid(subnet_id) @@ -325,7 +343,8 @@ class VpnReferenceValidator(object): raise vpn_exception.NonExistingSubnetInEndpointGroup( subnet=subnet_id) - def validate_endpoint_group(self, context, endpoint_group): + def validate_endpoint_group(self, context: context.Context, + endpoint_group: ty.Dict[str, ty.Any]): """Reference validator for endpoint group. Ensures that there is at least one endpoint, all the endpoints in the @@ -342,7 +361,7 @@ class VpnReferenceValidator(object): elif group_type == constants.SUBNET_ENDPOINT: self._validate_subnets(context, endpoints) - def validate_ike_policy(self, context, ike_policy): + def validate_ike_policy(self, context: context.Context, ike_policy): """Reference implementation of validation for IKE Policy. Service driver can override and implement specific logic diff --git a/neutron_vpnaas/extensions/vpn_agentschedulers.py b/neutron_vpnaas/extensions/vpn_agentschedulers.py index 0a09bb30a..1d1c9e7a3 100644 --- a/neutron_vpnaas/extensions/vpn_agentschedulers.py +++ b/neutron_vpnaas/extensions/vpn_agentschedulers.py @@ -13,6 +13,7 @@ # under the License. import abc +import typing as ty from neutron.api import extensions from neutron.api.v2 import resource @@ -20,6 +21,7 @@ from neutron import policy from neutron import wsgi from neutron_lib.api import extensions as lib_extensions from neutron_lib.api import faults as base +from neutron_lib import context from neutron_lib import exceptions from neutron_lib.plugins import constants as plugin_const from neutron_lib.plugins import directory @@ -37,9 +39,36 @@ VPN_AGENT = 'vpn-agent' VPN_AGENTS = VPN_AGENT + 's' +class VPNAgentSchedulerPluginBase(metaclass=abc.ABCMeta): + """REST API to operate the VPN agent scheduler. + + All methods must be in an admin context. + """ + + @abc.abstractmethod + def add_router_to_vpn_agent(self, context: context.ContextBase, + id: str, router_id: str): + pass + + @abc.abstractmethod + def remove_router_from_vpn_agent(self, context: context.ContextBase, + id: str, router_id: str): + pass + + @abc.abstractmethod + def list_routers_on_vpn_agent(self, context: context.ContextBase, id: str): + pass + + @abc.abstractmethod + def list_vpn_agents_hosting_router(self, context: context.ContextBase, + router_id: str): + pass + + class VPNRouterSchedulerController(wsgi.Controller): - def get_plugin(self): - plugin = directory.get_plugin(plugin_const.VPN) + def get_plugin(self) -> VPNAgentSchedulerPluginBase: + plugin: VPNAgentSchedulerPluginBase = \ + directory.get_plugin(plugin_const.VPN) if not plugin: LOG.error('No plugin for VPN registered to handle VPN ' 'router scheduling') @@ -47,18 +76,18 @@ class VPNRouterSchedulerController(wsgi.Controller): raise webob.exc.HTTPNotFound(msg) return plugin - def index(self, request, **kwargs): + def index(self, request: wsgi.Request, **kwargs): plugin = self.get_plugin() policy.enforce(request.context, - "get_%s" % VPN_ROUTERS, + f"get_{VPN_ROUTERS}", {}) return plugin.list_routers_on_vpn_agent( request.context, kwargs['agent_id']) - def create(self, request, body, **kwargs): + def create(self, request: wsgi.Request, body, **kwargs): plugin = self.get_plugin() policy.enforce(request.context, - "create_%s" % VPN_ROUTER, + f"create_{VPN_ROUTER}", {}) agent_id = kwargs['agent_id'] router_id = body['router_id'] @@ -67,10 +96,10 @@ class VPNRouterSchedulerController(wsgi.Controller): notify(request.context, 'vpn_agent.router.add', router_id, agent_id) return result - def delete(self, request, id, **kwargs): + def delete(self, request: wsgi.Request, id, **kwargs): plugin = self.get_plugin() policy.enforce(request.context, - "delete_%s" % VPN_ROUTER, + f"delete_{VPN_ROUTER}", {}) agent_id = kwargs['agent_id'] result = plugin.remove_router_from_vpn_agent(request.context, agent_id, @@ -80,18 +109,19 @@ class VPNRouterSchedulerController(wsgi.Controller): class VPNAgentsHostingRouterController(wsgi.Controller): - def get_plugin(self): - plugin = directory.get_plugin(plugin_const.VPN) + def get_plugin(self) -> VPNAgentSchedulerPluginBase: + plugin: VPNAgentSchedulerPluginBase = \ + directory.get_plugin(plugin_const.VPN) if not plugin: LOG.error('VPN plugin not registered to handle agent scheduling') msg = 'The resource could not be found.' raise webob.exc.HTTPNotFound(msg) return plugin - def index(self, request, **kwargs): + def index(self, request: wsgi.Request, **kwargs): plugin = self.get_plugin() policy.enforce(request.context, - "get_%s" % VPN_AGENTS, + f"get_{VPN_AGENTS}", {}) return plugin.list_vpn_agents_hosting_router( request.context, kwargs['router_id']) @@ -102,35 +132,35 @@ class Vpn_agentschedulers(lib_extensions.ExtensionDescriptor): """ @classmethod - def get_name(cls): + def get_name(cls) -> str: return "VPN Agent Scheduler" @classmethod - def get_alias(cls): + def get_alias(cls) -> str: return "vpn-agent-scheduler" @classmethod - def get_description(cls): + def get_description(cls) -> str: return "Schedule VPN services of routers among VPN agents" @classmethod - def get_updated(cls): + def get_updated(cls) -> str: return "2016-08-15T10:00:00-00:00" @classmethod - def get_resources(cls): + def get_resources(cls) -> ty.List[extensions.ResourceExtension]: """Returns Ext Resources.""" - exts = [] - parent = dict(member_name="agent", - collection_name="agents") + exts: ty.List[extensions.ResourceExtension] = [] + parent = {'member_name': "agent", + 'collection_name': "agents"} controller = resource.Resource(VPNRouterSchedulerController(), base.FAULT_MAP) exts.append(extensions.ResourceExtension( VPN_ROUTERS, controller, parent)) - parent = dict(member_name="router", - collection_name="routers") + parent = {'member_name': "router", + 'collection_name': "routers"} controller = resource.Resource(VPNAgentsHostingRouterController(), base.FAULT_MAP) @@ -138,7 +168,7 @@ class Vpn_agentschedulers(lib_extensions.ExtensionDescriptor): VPN_AGENTS, controller, parent)) return exts - def get_extended_resources(self, version): + def get_extended_resources(self, version) -> ty.Dict: return {} @@ -161,30 +191,7 @@ class RouterReschedulingFailed(exceptions.Conflict): "No eligible VPN agent found.") -class VPNAgentSchedulerPluginBase(object, metaclass=abc.ABCMeta): - """REST API to operate the VPN agent scheduler. - - All methods must be in an admin context. - """ - - @abc.abstractmethod - def add_router_to_vpn_agent(self, context, id, router_id): - pass - - @abc.abstractmethod - def remove_router_from_vpn_agent(self, context, id, router_id): - pass - - @abc.abstractmethod - def list_routers_on_vpn_agent(self, context, id): - pass - - @abc.abstractmethod - def list_vpn_agents_hosting_router(self, context, router_id): - pass - - -def notify(context, action, router_id, agent_id): +def notify(context: context.ContextBase, action, router_id, agent_id): info = {'id': agent_id, 'router_id': router_id} notifier = n_rpc.get_notifier('router') notifier.info(context, action, {'agent': info}) diff --git a/neutron_vpnaas/extensions/vpn_endpoint_groups.py b/neutron_vpnaas/extensions/vpn_endpoint_groups.py index 97afcff00..cefda6c1d 100644 --- a/neutron_vpnaas/extensions/vpn_endpoint_groups.py +++ b/neutron_vpnaas/extensions/vpn_endpoint_groups.py @@ -14,10 +14,14 @@ import abc +import typing as ty + from neutron_lib.api.definitions import vpn_endpoint_groups from neutron_lib.api import extensions +from neutron_lib import context from neutron_lib.plugins import constants as nconstants +from neutron.api import extensions as nextensions from neutron.api.v2 import resource_helper @@ -25,7 +29,7 @@ class Vpn_endpoint_groups(extensions.APIExtensionDescriptor): api_definition = vpn_endpoint_groups @classmethod - def get_resources(cls): + def get_resources(cls) -> ty.List[nextensions.ResourceExtension]: plural_mappings = resource_helper.build_plural_mappings( {}, vpn_endpoint_groups.RESOURCE_ATTRIBUTE_MAP) return resource_helper.build_resource_info( @@ -36,25 +40,28 @@ class Vpn_endpoint_groups(extensions.APIExtensionDescriptor): translate_name=True) -class VPNEndpointGroupsPluginBase(object, metaclass=abc.ABCMeta): +class VPNEndpointGroupsPluginBase(metaclass=abc.ABCMeta): @abc.abstractmethod - def create_endpoint_group(self, context, endpoint_group): + def create_endpoint_group(self, context: context.Context, endpoint_group): pass @abc.abstractmethod - def update_endpoint_group(self, context, endpoint_group_id, - endpoint_group): + def update_endpoint_group(self, context: context.Context, + endpoint_group_id: str, endpoint_group): pass @abc.abstractmethod - def delete_endpoint_group(self, context, endpoint_group_id): + def delete_endpoint_group(self, context: context.Context, + endpoint_group_id: str): pass @abc.abstractmethod - def get_endpoint_group(self, context, endpoint_group_id, fields=None): + def get_endpoint_group(self, context: context.Context, + endpoint_group_id: str, fields=None): pass @abc.abstractmethod - def get_endpoint_groups(self, context, filters=None, fields=None): + def get_endpoint_groups(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, fields=None): pass diff --git a/neutron_vpnaas/extensions/vpnaas.py b/neutron_vpnaas/extensions/vpnaas.py index c2ffc6c84..6335b502e 100644 --- a/neutron_vpnaas/extensions/vpnaas.py +++ b/neutron_vpnaas/extensions/vpnaas.py @@ -14,13 +14,16 @@ # under the License. import abc +import typing as ty from neutron_lib.api.definitions import vpn from neutron_lib.api import extensions +from neutron_lib import context from neutron_lib import exceptions as nexception from neutron_lib.plugins import constants as nconstants from neutron_lib.services import base as service_base +from neutron.api import extensions as nextensions from neutron.api.v2 import resource_helper from neutron_vpnaas._i18n import _ @@ -51,7 +54,7 @@ class Vpnaas(extensions.APIExtensionDescriptor): api_definition = vpn @classmethod - def get_resources(cls): + def get_resources(cls) -> ty.List[nextensions.ResourceExtension]: special_mappings = {'ikepolicies': 'ikepolicy', 'ipsecpolicies': 'ipsecpolicy'} plural_mappings = resource_helper.build_plural_mappings( @@ -71,90 +74,105 @@ class Vpnaas(extensions.APIExtensionDescriptor): class VPNPluginBase(service_base.ServicePluginBase, metaclass=abc.ABCMeta): - def get_plugin_type(self): + def get_plugin_type(self) -> str: return nconstants.VPN - def get_plugin_description(self): + def get_plugin_description(self) -> str: return 'VPN service plugin' @abc.abstractmethod - def get_vpnservices(self, context, filters=None, fields=None): + def get_vpnservices(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, fields=None): pass @abc.abstractmethod - def get_vpnservice(self, context, vpnservice_id, fields=None): + def get_vpnservice(self, context: context.Context, vpnservice_id: str, + fields=None): pass @abc.abstractmethod - def create_vpnservice(self, context, vpnservice): + def create_vpnservice(self, context: context.Context, vpnservice): pass @abc.abstractmethod - def update_vpnservice(self, context, vpnservice_id, vpnservice): + def update_vpnservice(self, context: context.Context, vpnservice_id: str, + vpnservice): pass @abc.abstractmethod - def delete_vpnservice(self, context, vpnservice_id): + def delete_vpnservice(self, context: context.Context, vpnservice_id: str): pass @abc.abstractmethod - def get_ipsec_site_connections(self, context, filters=None, fields=None): + def get_ipsec_site_connections(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, + fields=None): pass @abc.abstractmethod - def get_ipsec_site_connection(self, context, - ipsecsite_conn_id, fields=None): + def get_ipsec_site_connection(self, context: context.Context, + ipsecsite_conn_id: str, fields=None): pass @abc.abstractmethod - def create_ipsec_site_connection(self, context, ipsec_site_connection): + def create_ipsec_site_connection(self, context: context.Context, + ipsec_site_connection): pass @abc.abstractmethod - def update_ipsec_site_connection(self, context, - ipsecsite_conn_id, ipsec_site_connection): + def update_ipsec_site_connection(self, context: context.Context, + ipsecsite_conn_id: str, + ipsec_site_connection): pass @abc.abstractmethod - def delete_ipsec_site_connection(self, context, ipsecsite_conn_id): + def delete_ipsec_site_connection(self, context: context.Context, + ipsecsite_conn_id: str): pass @abc.abstractmethod - def get_ikepolicy(self, context, ikepolicy_id, fields=None): + def get_ikepolicy(self, context: context.Context, ikepolicy_id: str, + fields=None): pass @abc.abstractmethod - def get_ikepolicies(self, context, filters=None, fields=None): + def get_ikepolicies(self, context: context.Context, + filters: ty.Optional[ty.Dict], fields=None): pass @abc.abstractmethod - def create_ikepolicy(self, context, ikepolicy): + def create_ikepolicy(self, context: context.Context, ikepolicy): pass @abc.abstractmethod - def update_ikepolicy(self, context, ikepolicy_id, ikepolicy): + def update_ikepolicy(self, context: context.Context, ikepolicy_id: str, + ikepolicy): pass @abc.abstractmethod - def delete_ikepolicy(self, context, ikepolicy_id): + def delete_ikepolicy(self, context: context.Context, ikepolicy_id: str): pass @abc.abstractmethod - def get_ipsecpolicies(self, context, filters=None, fields=None): + def get_ipsecpolicies(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, fields=None): pass @abc.abstractmethod - def get_ipsecpolicy(self, context, ipsecpolicy_id, fields=None): + def get_ipsecpolicy(self, context: context.Context, ipsecpolicy_id: str, + fields=None): pass @abc.abstractmethod - def create_ipsecpolicy(self, context, ipsecpolicy): + def create_ipsecpolicy(self, context: context.Context, ipsecpolicy): pass @abc.abstractmethod - def update_ipsecpolicy(self, context, ipsecpolicy_id, ipsecpolicy): + def update_ipsecpolicy(self, context: context.Context, ipsecpolicy_id: str, + ipsecpolicy): pass @abc.abstractmethod - def delete_ipsecpolicy(self, context, ipsecpolicy_id): + def delete_ipsecpolicy(self, context: context.Context, + ipsecpolicy_id: str): pass diff --git a/neutron_vpnaas/opts.py b/neutron_vpnaas/opts.py index ef7d3bbd7..6ef11940a 100644 --- a/neutron_vpnaas/opts.py +++ b/neutron_vpnaas/opts.py @@ -9,6 +9,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import typing as ty import neutron.conf.plugins.ml2.drivers.ovn.ovn_conf import neutron.services.provider_configuration @@ -19,7 +20,7 @@ import neutron_vpnaas.services.vpn.device_drivers.strongswan_ipsec import neutron_vpnaas.services.vpn.ovn_agent -def list_agent_opts(): +def list_agent_opts() -> ty.List[ty.Tuple[str, ty.List]]: return [ ('vpnagent', neutron_vpnaas.services.vpn.agent.vpn_agent_opts), @@ -33,7 +34,7 @@ def list_agent_opts(): ] -def list_ovn_agent_opts(): +def list_ovn_agent_opts() -> ty.List[ty.Tuple[str, ty.List]]: return [ ('vpnagent', neutron_vpnaas.services.vpn.ovn_agent.VPN_AGENT_OPTS), @@ -51,7 +52,7 @@ def list_ovn_agent_opts(): ] -def list_opts(): +def list_opts() -> ty.List[ty.Tuple[str, ty.List]]: return [ ('service_providers', neutron.services.provider_configuration.serviceprovider_opts) diff --git a/neutron_vpnaas/scheduler/vpn_agent_scheduler.py b/neutron_vpnaas/scheduler/vpn_agent_scheduler.py index bea9bfe84..2434466af 100644 --- a/neutron_vpnaas/scheduler/vpn_agent_scheduler.py +++ b/neutron_vpnaas/scheduler/vpn_agent_scheduler.py @@ -14,33 +14,42 @@ import abc import random +import typing as ty +from neutron.extensions import agent as nagent from neutron.extensions import availability_zone as az_ext +from neutron.extensions import l3 +from neutron_lib import context from neutron_lib.plugins import constants as plugin_constants from neutron_lib.plugins import directory from oslo_config import cfg from oslo_log import log as logging +if ty.TYPE_CHECKING: + from neutron_vpnaas.db.vpn import vpn_agentschedulers_db as scheduler_db from neutron_vpnaas.extensions import vpn_agentschedulers LOG = logging.getLogger(__name__) -class VPNScheduler(object, metaclass=abc.ABCMeta): +class VPNScheduler(metaclass=abc.ABCMeta): @property - def l3_plugin(self): + def l3_plugin(self) -> l3.RouterPluginBase: return directory.get_plugin(plugin_constants.L3) @abc.abstractmethod - def schedule(self, plugin, context, router_id, - candidates=None, hints=None): + def schedule(self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, router_id, candidates=None, hints=None) -> \ + ty.Optional[nagent.Agent]: """Schedule the router to an active VPN agent. Schedule the router only if it is not already scheduled. """ pass - def _get_unscheduled_routers(self, context, plugin, router_ids=None): + def _get_unscheduled_routers(self, context: context.Context, + plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + router_ids=None) -> ty.List: """Get the list of routers with VPN services to be scheduled. If router IDs are omitted, look for all unscheduled routers. @@ -57,7 +66,10 @@ class VPNScheduler(object, metaclass=abc.ABCMeta): context, filters={'id': unscheduled_router_ids}) return [] - def _get_routers_can_schedule(self, context, plugin, routers, vpn_agent): + def _get_routers_can_schedule( + self, context: context.Context, + plugin: vpn_agentschedulers.VPNAgentSchedulerPluginBase, + routers, vpn_agent,): """Get the subset of routers whose VPN services can be scheduled on the VPN agent. """ @@ -65,7 +77,9 @@ class VPNScheduler(object, metaclass=abc.ABCMeta): # all routers can be scheduled to it return routers - def auto_schedule_routers(self, plugin, context, vpn_agent): + def auto_schedule_routers( + self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, vpn_agent) -> ty.List[str]: """Schedule non-hosted routers to a VPN agent. :returns: True if routers have been successfully assigned to the agent @@ -83,20 +97,31 @@ class VPNScheduler(object, metaclass=abc.ABCMeta): self._bind_routers(context, plugin, target_routers, vpn_agent) return [router['id'] for router in target_routers] - def _get_candidates(self, plugin, context, sync_router): + def _get_candidates( + self, + plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, + sync_router) -> ty.Optional[ty.List[nagent.Agent]]: """Return VPN agents where a router could be scheduled.""" active_vpn_agents = plugin.get_vpn_agents(context, active=True) if not active_vpn_agents: LOG.warning('No active VPN agents') return active_vpn_agents - def _bind_routers(self, context, plugin, routers, vpn_agent): + def _bind_routers( + self, context: context.Context, + plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + routers, vpn_agent): for router in routers: plugin.create_router_to_agent_binding( context, router['id'], vpn_agent['id']) - def _schedule_router(self, plugin, context, router_id, - candidates=None): + def _schedule_router( + self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, + router_id, + candidates: ty.Optional[ty.List[nagent.Agent]] = None) -> \ + ty.Optional[nagent.Agent]: current_vpn_agents = plugin.get_vpn_agents_hosting_routers( context, [router_id]) if current_vpn_agents: @@ -118,9 +143,13 @@ class VPNScheduler(object, metaclass=abc.ABCMeta): if plugin.create_router_to_agent_binding(context, router_id, chosen_agent['id']): return chosen_agent + return None @abc.abstractmethod - def _choose_vpn_agent(self, plugin, context, candidates): + def _choose_vpn_agent( + self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, + candidates: ty.List[nagent.Agent]) -> nagent.Agent: """Choose an agent from candidates based on a specific policy.""" pass @@ -128,24 +157,31 @@ class VPNScheduler(object, metaclass=abc.ABCMeta): class ChanceScheduler(VPNScheduler): """Randomly allocate an VPN agent for a router.""" - def schedule(self, plugin, context, router_id, - candidates=None): + def schedule( + self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, router_id, candidates=None, hints=None): return self._schedule_router( plugin, context, router_id, candidates=candidates) - def _choose_vpn_agent(self, plugin, context, candidates): + def _choose_vpn_agent( + self, plugin: vpn_agentschedulers.VPNAgentSchedulerPluginBase, + context: context.Context, + candidates: ty.List[nagent.Agent]) -> nagent.Agent: return random.choice(candidates) class LeastRoutersScheduler(VPNScheduler): """Allocate to an VPN agent with the least number of routers bound.""" - def schedule(self, plugin, context, router_id, - candidates=None): + def schedule(self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, router_id, candidates=None, hints=None): return self._schedule_router( plugin, context, router_id, candidates=candidates) - def _choose_vpn_agent(self, plugin, context, candidates): + def _choose_vpn_agent( + self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, + candidates: ty.List[nagent.Agent]) -> nagent.Agent: candidates_dict = {c['id']: c for c in candidates} chosen_agent_id = plugin.get_vpn_agent_with_min_routers( context, candidates_dict.keys()) @@ -156,9 +192,12 @@ class AZLeastRoutersScheduler(LeastRoutersScheduler): """Availability zone aware scheduler.""" def _get_az_hints(self, router): return (router.get(az_ext.AZ_HINTS) or - cfg.CONF.default_availability_zones) + cfg.CONF.default_availability_zones) - def _get_routers_can_schedule(self, context, plugin, routers, vpn_agent): + def _get_routers_can_schedule( + self, context: context.Context, + plugin: vpn_agentschedulers.VPNAgentSchedulerPluginBase, + routers, vpn_agent): """Overwrite VPNScheduler's method to filter by availability zone.""" target_routers = [] for r in routers: @@ -172,14 +211,20 @@ class AZLeastRoutersScheduler(LeastRoutersScheduler): return super()._get_routers_can_schedule( context, plugin, target_routers, vpn_agent) - def _get_candidates(self, plugin, context, sync_router): + def _get_candidates( + self, + plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, + sync_router) -> ty.Optional[ty.List[nagent.Agent]]: """Overwrite VPNScheduler's method to filter by availability zone.""" all_candidates = super()._get_candidates(plugin, context, sync_router) - candidates = [] - az_hints = self._get_az_hints(sync_router) - for agent in all_candidates: - if not az_hints or agent['availability_zone'] in az_hints: - candidates.append(agent) + if all_candidates: + candidates = [] + az_hints = self._get_az_hints(sync_router) + for agent in all_candidates: + if not az_hints or agent['availability_zone'] in az_hints: + candidates.append(agent) - return candidates + return candidates + return None diff --git a/neutron_vpnaas/services/vpn/agent.py b/neutron_vpnaas/services/vpn/agent.py index 86810aeaa..8b55dec33 100644 --- a/neutron_vpnaas/services/vpn/agent.py +++ b/neutron_vpnaas/services/vpn/agent.py @@ -14,12 +14,16 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty +from neutron.agent.l3 import l3_agent_extension_api from neutron_lib.agent import l3_extension +from neutron_lib import context from oslo_config import cfg from oslo_log import log as logging from neutron_vpnaas._i18n import _ +from neutron_vpnaas.services.vpn import device_drivers from neutron_vpnaas.services.vpn import vpn_service LOG = logging.getLogger(__name__) @@ -43,10 +47,11 @@ cfg.CONF.register_opts(vpn_agent_opts, 'vpnagent') class VPNAgent(l3_extension.L3AgentExtension): """VPNaaS Agent support to be used by Neutron L3 agent.""" - def initialize(self, connection, driver_type): + def initialize(self, connection, driver_type: device_drivers.DeviceDriver): LOG.debug("Loading VPNaaS") - def consume_api(self, agent_api): + def consume_api(self, + agent_api: l3_agent_extension_api.L3AgentExtensionAPI): LOG.debug("Loading consume_api for VPNaaS") self.agent_api = agent_api @@ -58,28 +63,34 @@ class VPNAgent(l3_extension.L3AgentExtension): self.service = vpn_service.VPNService(self) self.device_drivers = self.service.load_device_drivers(self.host) - def add_router(self, context, data): + def add_router(self, context: context.Context, data): """Handles router add event""" ri = self.agent_api.get_router_info(data['id']) if ri is not None: for device_driver in self.device_drivers: device_driver.create_router(ri) - device_driver.sync(context, [ri.router]) + device_driver.sync(context, [ri]) else: LOG.debug("Router %s was concurrently deleted while " "creating VPN for it", data['id']) - def update_router(self, context, data): + def update_router(self, context: context.Context, data): """Handles router update event""" - for device_driver in self.device_drivers: - device_driver.sync(context, [data]) + ri = self.agent_api.get_router_info(data['id']) + if ri is not None: + for device_driver in self.device_drivers: + device_driver.sync(context, [ri]) + else: + LOG.debug("Router %s was concurrently deleted while " + "updating VPN for it", data['id']) - def delete_router(self, context, data): + def delete_router(self, context: context.Context, data): """Handles router delete event""" for device_driver in self.device_drivers: device_driver.destroy_router(data['id']) - def ha_state_change(self, context, data): + def ha_state_change(self, context: context.Context, + data: ty.Dict[str, str]): """Enable the vpn process when router transitioned to master. And disable vpn process for backup router. @@ -98,7 +109,7 @@ class VPNAgent(l3_extension.L3AgentExtension): else: process.disable() - def update_network(self, context, data): + def update_network(self, context: context.Context, data): pass @@ -109,5 +120,4 @@ class L3WithVPNaaS(VPNAgent): self.conf = conf else: self.conf = cfg.CONF - super(L3WithVPNaaS, self).__init__( - host=self.conf.host, conf=self.conf) + super().__init__(host=self.conf.host, conf=self.conf) diff --git a/neutron_vpnaas/services/vpn/common/netns_wrapper.py b/neutron_vpnaas/services/vpn/common/netns_wrapper.py index 49a4b7184..93832788d 100644 --- a/neutron_vpnaas/services/vpn/common/netns_wrapper.py +++ b/neutron_vpnaas/services/vpn/common/netns_wrapper.py @@ -13,6 +13,8 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty + import configparser as ConfigParser import errno import os @@ -31,7 +33,7 @@ from neutron_vpnaas._i18n import _ LOG = logging.getLogger(__name__) -def setup_conf(): +def setup_conf() -> cfg.ConfigOpts: cli_opts = [ cfg.DictOpt('mount_paths', required=True, @@ -50,9 +52,9 @@ def setup_conf(): return conf -def execute(cmd): +def execute(cmd) -> ty.Optional[int]: if not cmd: - return + return None cmd = list(map(str, cmd)) LOG.debug("Running command: %s", cmd) env = os.environ.copy() @@ -106,12 +108,12 @@ def filter_command(command, rootwrap_config): 'name': exc.match.name}) sys.exit(errno.EINVAL) except wrapper.NoFilterMatched: - LOG.error('Unauthorized command: %(cmd)s (no filter matched)', - {'cmd': command}) + LOG.error("Unauthorized command: %(cmd)s (no filter matched)", + {"cmd": command}) sys.exit(errno.EPERM) -def execute_with_mount(): +def execute_with_mount() -> ty.Optional[int]: config.register_common_config_options() conf = setup_conf() conf() diff --git a/neutron_vpnaas/services/vpn/device_drivers/__init__.py b/neutron_vpnaas/services/vpn/device_drivers/__init__.py index 60bf101eb..06638e56b 100644 --- a/neutron_vpnaas/services/vpn/device_drivers/__init__.py +++ b/neutron_vpnaas/services/vpn/device_drivers/__init__.py @@ -14,14 +14,16 @@ # under the License. import abc +from neutron_lib import context -class DeviceDriver(object, metaclass=abc.ABCMeta): + +class DeviceDriver(metaclass=abc.ABCMeta): def __init__(self, agent, host): pass @abc.abstractmethod - def sync(self, context, processes): + def sync(self, context: context.ContextBase, processes): pass @abc.abstractmethod diff --git a/neutron_vpnaas/services/vpn/device_drivers/ipsec.py b/neutron_vpnaas/services/vpn/device_drivers/ipsec.py index 7ffcaac71..e51840a30 100644 --- a/neutron_vpnaas/services/vpn/device_drivers/ipsec.py +++ b/neutron_vpnaas/services/vpn/device_drivers/ipsec.py @@ -12,6 +12,9 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + +import typing as ty + import abc import base64 import copy @@ -21,10 +24,12 @@ import re import shutil import socket import sys +from typing import List import eventlet import jinja2 import netaddr +from neutron.agent.l3.router_info import RouterInfo from neutron.agent.linux import ip_lib from neutron.agent.linux import utils as agent_utils from neutron_lib.api import validators @@ -126,7 +131,8 @@ IPSEC_CONNS = 'ipsec_site_connections' PSK_BASE64_PREFIX = '0s' -def _get_template(template_file): +def _get_template( + template_file: ty.Union[str, jinja2.Template]) -> jinja2.Template: global JINJA_ENV if not JINJA_ENV: templateLoader = jinja2.FileSystemLoader(searchpath="/") @@ -134,7 +140,7 @@ def _get_template(template_file): return JINJA_ENV.get_template(template_file) -class BaseSwanProcess(object, metaclass=abc.ABCMeta): +class BaseSwanProcess(metaclass=abc.ABCMeta): """Swan Family Process Manager This class manages start/restart/stop ipsec process. @@ -189,12 +195,12 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): STATUS_IPSEC_SA_ESTABLISHED_RE2 = ( r'\d{3} #\d+: "([a-f0-9\-\/x]+).*established.*newest IPSEC') - def __init__(self, conf, process_id, vpnservice, namespace): + def __init__(self, conf, process_id: str, vpnservice, namespace: str): self.conf = conf self.id = process_id - self.updated_pending_status = False + self.updated_pending_status: bool = False self.namespace = namespace - self.connection_status = {} + self.connection_status: ty.Dict[str, ty.Dict[str, ty.Any]] = {} self.config_dir = os.path.join( self.conf.ipsec.config_base_dir, self.id) self.etc_dir = os.path.join(self.config_dir, 'etc') @@ -212,32 +218,31 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): def translate_dialect(self): if not self.vpnservice: return - for ipsec_site_conn in self.vpnservice['ipsec_site_connections']: - self._dialect(ipsec_site_conn, 'initiator') - self._dialect(ipsec_site_conn['ikepolicy'], 'ike_version') - for key in ['encryption_algorithm', - 'auth_algorithm', - 'pfs']: - self._dialect(ipsec_site_conn['ikepolicy'], key) - self._dialect(ipsec_site_conn['ipsecpolicy'], key) - if (('local_id' not in ipsec_site_conn.keys()) or - (not ipsec_site_conn['local_id'])): - ipsec_site_conn['local_id'] = ipsec_site_conn['external_ip'] + for ipsec_site_conn in self.vpnservice["ipsec_site_connections"]: + self._dialect(ipsec_site_conn, "initiator") + self._dialect(ipsec_site_conn["ikepolicy"], "ike_version") + for key in ["encryption_algorithm", "auth_algorithm", "pfs"]: + self._dialect(ipsec_site_conn["ikepolicy"], key) + self._dialect(ipsec_site_conn["ipsecpolicy"], key) + if ("local_id" not in ipsec_site_conn.keys()) or ( + not ipsec_site_conn["local_id"] + ): + ipsec_site_conn["local_id"] = ipsec_site_conn["external_ip"] def base64_encode_psk(self): if not self.vpnservice: return - for ipsec_site_conn in self.vpnservice['ipsec_site_connections']: - psk = ipsec_site_conn['psk'] + for ipsec_site_conn in self.vpnservice["ipsec_site_connections"]: + psk = ipsec_site_conn["psk"] encoded_psk = base64.b64encode(encodeutils.safe_encode(psk)) # NOTE(huntxu): base64.b64encode returns an instance of 'bytes' # in Python 3, convert it to a str. For Python 2, after calling # safe_decode, psk is converted into a unicode not containing any # non-ASCII characters so it doesn't matter. - psk = encodeutils.safe_decode(encoded_psk, incoming='utf_8') - ipsec_site_conn['psk'] = PSK_BASE64_PREFIX + psk + psk = encodeutils.safe_decode(encoded_psk, incoming="utf_8") + ipsec_site_conn["psk"] = PSK_BASE64_PREFIX + psk - def get_ns_wrapper(self): + def get_ns_wrapper(self) -> str: """ Check if we're inside a virtualenv. If we are, then we should respect this and launch wrapper from venv as well. @@ -249,7 +254,7 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): ns_wrapper = self.NS_WRAPPER return ns_wrapper - def update_vpnservice(self, vpnservice): + def update_vpnservice(self, vpnservice: ty.Dict[str, ty.Any]): self.vpnservice = vpnservice self.translate_dialect() self.base64_encode_psk() @@ -275,7 +280,7 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): agent_utils.execute( cmd=["rm", "-rf", self.config_dir], run_as_root=True) - def _get_config_filename(self, kind): + def _get_config_filename(self, kind) -> str: config_dir = self.etc_dir return os.path.join(config_dir, kind) @@ -286,15 +291,15 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): dir_path = os.path.join(self.config_dir, subdir) fileutils.ensure_tree(dir_path, 0o755) - def _gen_config_content(self, template_file, vpnservice): + def _gen_config_content(self, template_file, vpnservice) -> str: template = _get_template(template_file) return template.render( {'vpnservice': vpnservice, 'state_path': self.conf.state_path}) - def _get_rootwrap_config(self): + def _get_rootwrap_config(self) -> ty.Optional[str]: if 'neutron-rootwrap' in cfg.CONF.AGENT.root_helper: - rh_tokens = cfg.CONF.AGENT.root_helper.split(' ') + rh_tokens: ty.List[str] = cfg.CONF.AGENT.root_helper.split(' ') if len(rh_tokens) == 3 and os.path.exists(rh_tokens[2]): return rh_tokens[2] return None @@ -304,13 +309,13 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): pass @property - def status(self): + def status(self) -> str: if self.active: return constants.ACTIVE return constants.DOWN @property - def active(self): + def active(self) -> bool: """Check if the process is active or not.""" if not self.namespace: return False @@ -329,8 +334,9 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): # Disable the process if a vpnservice is disabled or it has no # enabled IPSec site connections. vpnservice_has_active_ipsec_site_conns = any( - [ipsec_site_conn['admin_state_up'] - for ipsec_site_conn in self.vpnservice['ipsec_site_connections']]) + ipsec_site_conn['admin_state_up'] + for ipsec_site_conn in + self.vpnservice['ipsec_site_connections']) if (not self.vpnservice['admin_state_up'] or not vpnservice_has_active_ipsec_site_conns): self.disable() @@ -386,7 +392,8 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): def stop(self): """Stop process.""" - def _check_status_line(self, line): + def _check_status_line(self, + line: str) -> ty.Tuple[ty.Optional[str], ty.Optional[str]]: """Parse a line and search for status information. If a connection has an established Security Association, @@ -404,14 +411,15 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): if m: connection_id = m.group(1) return connection_id, constants.ACTIVE - else: - m = self.STATUS_PATTERN.search(line) - if m: - connection_id = m.group(1) - return connection_id, constants.DOWN + + m = self.STATUS_PATTERN.search(line) + if m: + connection_id = m.group(1) + return connection_id, constants.DOWN return None, None - def _extract_and_record_connection_status(self, status_output): + def _extract_and_record_connection_status(self, + status_output: ty.Optional[str]): if not status_output: self.connection_status = {} return @@ -423,8 +431,8 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): if conn_id: self._record_connection_status(conn_id, conn_status) - def _record_connection_status(self, connection_id, status, - force_status_update=False): + def _record_connection_status(self, connection_id: str, status, + force_status_update: bool = False): conn_info = self.connection_status.get(connection_id) if not conn_info: self.connection_status[connection_id] = { @@ -445,18 +453,20 @@ class OpenSwanProcess(BaseSwanProcess): (2) ipsec addconn: Adds new ipsec addconn (3) ipsec whack: control interface for IPSEC keying daemon """ - def __init__(self, conf, process_id, vpnservice, namespace): - super(OpenSwanProcess, self).__init__(conf, process_id, - vpnservice, namespace) - self.secrets_file = os.path.join( - self.etc_dir, 'ipsec.secrets') - self.config_file = os.path.join( - self.etc_dir, 'ipsec.conf') - self.pid_path = os.path.join( - self.config_dir, 'var', 'run', 'pluto') - self.pid_file = '%s.pid' % self.pid_path - def _execute(self, cmd, check_exit_code=True, extra_ok_codes=None): + def __init__(self, conf, process_id: str, vpnservice, namespace: str): + super().__init__(conf, process_id, vpnservice, namespace) + self.secrets_file: str = os.path.join( + self.etc_dir, 'ipsec.secrets') + self.config_file: str = os.path.join( + self.etc_dir, 'ipsec.conf') + self.pid_path: str = os.path.join( + self.config_dir, 'var', 'run', 'pluto') + self.pid_file: str = f'{self.pid_path}.pid' + + def _execute(self, cmd, check_exit_code: bool = True, + extra_ok_codes: ty.Optional[ty.List[int]] = None + ) -> ty.Optional[str]: """Execute command on namespace.""" ip_wrapper = ip_lib.IPWrapper(namespace=self.namespace) return ip_wrapper.netns.execute(cmd, check_exit_code=check_exit_code, @@ -490,7 +500,7 @@ class OpenSwanProcess(BaseSwanProcess): shutil.copyfile(config_file_name, config_file_name + '.old') os.chmod(config_file_name + '.old', 0o600) - def _process_running(self): + def _process_running(self) -> bool: """Checks if process is still running.""" # If no PID file, we assume the process is not running. @@ -502,9 +512,10 @@ class OpenSwanProcess(BaseSwanProcess): # on throwing to tell us something. If the pid file exists, # delve into the process information and check if it matches # our expected command line. - with open(self.pid_file, 'r') as f: + with open(self.pid_file, 'r', encoding="C") as f: pid = f.readline().strip() - with open('/proc/%s/cmdline' % pid) as cmd_line_file: + with open(f'/proc/{pid}/cmdline', + encoding="C") as cmd_line_file: cmd_line = cmd_line_file.readline() if self.pid_path in cmd_line and 'pluto' in cmd_line: # Okay the process is probably a pluto process @@ -529,7 +540,7 @@ class OpenSwanProcess(BaseSwanProcess): def _cleanup_control_files(self): try: - ctl_file = '%s.ctl' % self.pid_path + ctl_file = f'{self.pid_path}.ctl' LOG.debug('Removing %(pidfile)s and %(ctlfile)s', {'pidfile': self.pid_file, 'ctlfile': ctl_file}) @@ -545,14 +556,14 @@ class OpenSwanProcess(BaseSwanProcess): 'files for router %(router)s. %(msg)s', {'router': self.id, 'msg': e}) - def get_status(self): + def get_status(self) -> ty.Optional[str]: return self._execute([self.binary, 'whack', '--ctlbase', self.pid_path, '--status'], extra_ok_codes=[1, 3]) - def _config_changed(self): + def _config_changed(self) -> bool: secrets_file = os.path.join( self.etc_dir, 'ipsec.secrets') config_file = os.path.join( @@ -590,19 +601,25 @@ class OpenSwanProcess(BaseSwanProcess): LOG.warning('Server appears to still be running, restart ' 'of router %s may fail', self.id) self.start() - return - def _resolve_fqdn(self, fqdn): + def _resolve_fqdn(self, fqdn) -> ty.Optional[str]: # The first addrinfo member from the list returned by # socket.getaddrinfo is used for the address resolution. # The code doesn't filter for ipv4 or ipv6 address. try: - addrinfo = socket.getaddrinfo(fqdn, None)[0] + addrinfo: ty.Tuple[ + socket.AddressFamily, + socket.SocketKind, + int, + str, + ty.Union[ty.Tuple[str, int], ty.Tuple[str, int, int, int]], + ] = socket.getaddrinfo(fqdn, None)[0] return addrinfo[-1][0] except socket.gaierror: LOG.exception("Peer address %s cannot be resolved", fqdn) + return None - def _get_nexthop(self, address, connection_id): + def _get_nexthop(self, address: str, connection_id: str) -> str: # check if address is an ip address or fqdn invalid_ip_address = validators.validate_ip_address(address) if invalid_ip_address: @@ -615,28 +632,31 @@ class OpenSwanProcess(BaseSwanProcess): else: ip_addr = address routes = self._execute(['ip', 'route', 'get', ip_addr]) - if routes.find('via') >= 0: + if routes and routes.find('via') >= 0: return routes.split(' ')[2] return address - def _virtual_privates(self, vpnservice): + def _virtual_privates(self, + vpnservice: ty.Dict[str, ty.List[ty.Dict[str, ty.Any]]]) -> str: """Returns line of virtual_privates. virtual_private contains the networks that are allowed as subnet for the remote client. """ - virtual_privates = [] - nets = [] + virtual_privates: ty.List[str] = [] + nets: ty.List = [] for ipsec_site_conn in vpnservice['ipsec_site_connections']: nets += ipsec_site_conn['local_cidrs'] nets += ipsec_site_conn['peer_cidrs'] for net in nets: version = netaddr.IPNetwork(net).version - virtual_privates.append('%%v%s:%s' % (version, net)) + virtual_privates.append(f'%v{version}:{net}') virtual_privates.sort() return ','.join(virtual_privates) - def _gen_config_content(self, template_file, vpnservice): + def _gen_config_content(self, + template_file: ty.Union[str, jinja2.Template], + vpnservice) -> str: template = _get_template(template_file) virtual_privates = self._virtual_privates(vpnservice) return template.render( @@ -660,7 +680,7 @@ class OpenSwanProcess(BaseSwanProcess): def add_ipsec_connection(self, nexthop, conn_id): self._execute([self.binary, 'addconn', - '--ctlbase', '%s.ctl' % self.pid_path, + '--ctlbase', f'{self.pid_path}.ctl', '--defaultroutenexthop', nexthop, '--config', self.config_file, conn_id ]) @@ -745,9 +765,9 @@ class OpenSwanProcess(BaseSwanProcess): self.initiate_connection(ipsec_site_conn['id']) self._copy_configs() - def get_established_connections(self): - connections = [] - status_output = self.get_status() + def get_established_connections(self) -> ty.List[str]: + connections: ty.List[str] = [] + status_output: ty.Optional[str] = self.get_status() if not status_output: return connections @@ -781,7 +801,7 @@ class OpenSwanProcess(BaseSwanProcess): self.connection_status = {} -class IPsecVpnDriverApi(object): +class IPsecVpnDriverApi: """IPSecVpnDriver RPC api.""" @log_helpers.log_method_call @@ -790,7 +810,7 @@ class IPsecVpnDriverApi(object): self.client = n_rpc.get_client(target) @log_helpers.log_method_call - def get_vpn_services_on_host(self, context, host): + def get_vpn_services_on_host(self, context: context.ContextBase, host): """Get list of vpnservices. The vpnservices including related ipsec_site_connection, @@ -800,7 +820,7 @@ class IPsecVpnDriverApi(object): return cctxt.call(context, 'get_vpn_services_on_host', host=host) @log_helpers.log_method_call - def update_status(self, context, status): + def update_status(self, context: context.ContextBase, status): """Update local status. This method call updates status attribute of @@ -823,19 +843,20 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): # 1.0 Initial version target = oslo_messaging.Target(version='1.0') - def __init__(self, vpn_service, host): + def __init__(self, vpn_service, host: str): # TODO(pc_m) Replace vpn_service with config arg, once all driver # implementations no longer need vpn_service. self.conf = vpn_service.conf self.host = host self.conn = n_rpc.Connection() - self.context = context.get_admin_context_without_session() - self.topic = topics.IPSEC_AGENT_TOPIC - node_topic = '%s.%s' % (self.topic, self.host) + self.context: context.ContextBase = \ + context.get_admin_context_without_session() + self.topic: str = topics.IPSEC_AGENT_TOPIC + node_topic: str = f'{self.topic}.{self.host}' - self.processes = {} - self.routers = {} - self.process_status_cache = {} + self.processes: ty.Dict[str, BaseSwanProcess] = {} + self.routers: ty.Dict[str, ty.Any] = {} + self.process_status_cache: ty.Dict[str, ty.Dict[str, ty.Any]] = {} self.endpoints = [self] self.conn.create_consumer(node_topic, self.endpoints, fanout=False) @@ -846,7 +867,7 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): self.process_status_cache_check.start( interval=self.conf.ipsec.ipsec_status_check_interval) - def get_namespace(self, router_id): + def get_namespace(self, router_id: str) -> ty.Optional[str]: """Get namespace of router. :router_id: router_id @@ -856,14 +877,14 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): """ router = self.routers.get(router_id) if not router: - return + return None # For DVR, use SNAT namespace # TODO(pcm): Use router object method to tell if DVR, when available if router.router['distributed']: return router.snat_namespace.name return router.ns_name - def get_router_based_iptables_manager(self, router): + def get_router_based_iptables_manager(self, ri): """Returns router based iptables manager In DVR routers the IPsec VPN service should run inside @@ -877,61 +898,28 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): return the legacy iptables_manager. """ # TODO(pcm): Use router object method to tell if DVR, when available - if router.router['distributed']: - return router.snat_iptables_manager - return router.iptables_manager + if ri.router['distributed']: + return ri.snat_iptables_manager + return ri.iptables_manager - def add_nat_rule(self, router_id, chain, rule, top=False): - """Add nat rule in namespace. + def ensure_nat_rules(self, vpnservice): + """Ensure the required nat rules for ipsec exist in iptables. - :param router_id: router_id - :param chain: a string of chain name - :param rule: a string of rule - :param top: if top is true, the rule - will be placed on the top of chain - Note if there is no router, this method does nothing - """ - router = self.routers.get(router_id) - if not router: - return - iptables_manager = self.get_router_based_iptables_manager(router) - iptables_manager.ipv4['nat'].add_rule(chain, rule, top=top) - - def remove_nat_rule(self, router_id, chain, rule, top=False): - """Remove nat rule in namespace. - - :param router_id: router_id - :param chain: a string of chain name - :param rule: a string of rule - :param top: unused - needed to have same argument with add_nat_rule - """ - router = self.routers.get(router_id) - if not router: - return - iptables_manager = self.get_router_based_iptables_manager(router) - iptables_manager.ipv4['nat'].remove_rule(chain, rule, top=top) - - def iptables_apply(self, router_id): - """Apply IPtables. - - :param router_id: router_id - This method do nothing if there is no router - """ - router = self.routers.get(router_id) - if not router: - return - iptables_manager = self.get_router_based_iptables_manager(router) - iptables_manager.apply() - - def _update_nat(self, vpnservice, func): - """Setting up nat rule in iptables. - - We need to setup nat rule for ipsec packet. :param vpnservice: vpnservices - :param func: self.add_nat_rule or self.remove_nat_rule """ - router_id = vpnservice['router_id'] + LOG.debug("ensure_nat_rules called for router %s", + vpnservice['router_id']) + ri = self.routers.get(vpnservice['router_id']) + if not ri: + LOG.debug("No router info for router %s", vpnservice['router_id']) + return + + iptables_manager = self.get_router_based_iptables_manager(ri) + # clear all existing rules first + LOG.debug("Clearing vpnaas tagged NAT rules for router %s", + ri.router_id) + iptables_manager.ipv4['nat'].clear_rules_by_tag('vpnaas') + for ipsec_site_connection in vpnservice['ipsec_site_connections']: for local_cidr in ipsec_site_connection['local_cidrs']: # This ipsec rule is not needed for ipv6. @@ -939,16 +927,36 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): continue for peer_cidr in ipsec_site_connection['peer_cidrs']: - func(router_id, - 'POSTROUTING', - '-s %s -d %s -m policy ' - '--dir out --pol ipsec ' - '-j ACCEPT ' % (local_cidr, peer_cidr), - top=True) - self.iptables_apply(router_id) + LOG.debug("Adding an ipsec policy NAT rule" + "%s <-> %s to router id %s", + peer_cidr, local_cidr, vpnservice['router_id']) + + iptables_manager.ipv4['nat'].add_rule( + 'POSTROUTING', + '-s %s -d %s -m policy ' + '--dir out --pol ipsec ' + '-j ACCEPT ' % (local_cidr, peer_cidr), + top=True, tag='vpnaas') + + LOG.debug("Applying iptables for router id %s", + vpnservice['router_id']) + iptables_manager.apply() @log_helpers.log_method_call - def vpnservice_updated(self, context, **kwargs): + def remove_nat_rules(self, router_id): + """Remove all our iptables rules in namespace. + + :param router_id: router_id + """ + router = self.routers.get(router_id) + if not router: + return + iptables_manager = self.get_router_based_iptables_manager(router) + iptables_manager.ipv4['nat'].clear_rules_by_tag('vpnaas') + iptables_manager.apply() + + @log_helpers.log_method_call + def vpnservice_updated(self, context: context.ContextBase, **kwargs): """Vpnservice updated rpc handler VPN Service Driver will call this method @@ -959,16 +967,18 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): self.sync(context, [router] if router else []) @abc.abstractmethod - def create_process(self, process_id, vpnservice, namespace): + def create_process(self, process_id: str, vpnservice, + namespace) -> BaseSwanProcess: pass - def ensure_process(self, process_id, vpnservice=None): + def ensure_process(self, process_id: str, + vpnservice=None) -> BaseSwanProcess: """Ensuring process. If the process doesn't exist, it will create process and store it in self.process """ - process = self.processes.get(process_id) + process = self.processes.get(process_id, None) if not process or not process.namespace: namespace = self.get_namespace(process_id) process = self.create_process( @@ -980,28 +990,28 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): process.update_vpnservice(vpnservice) return process - def create_router(self, router): + def create_router(self, router_info: RouterInfo): """Handling create router event. Agent calls this method, when the process namespace is ready. Note: process_id == router_id == vpnservice_id """ - process_id = router.router_id - self.routers[process_id] = router + process_id = router_info.router_id + self.routers[process_id] = router_info if process_id in self.processes: # In case of vpnservice is created # before router's namespace process = self.processes[process_id] - self._update_nat(process.vpnservice, self.add_nat_rule) + self.ensure_nat_rules(process.vpnservice) # Don't run ipsec process for backup HA router - if router.router['ha'] and router.ha_state == 'backup': + if router_info.router['ha'] and router_info.ha_state == 'backup': return process.enable() def destroy_process(self, process_id): """Destroy process. - Disable the process, remove the nat rule, and remove the process + Disable the process, remove the iptables rules, and remove the process manager for the processes that no longer are running vpn service. """ if process_id in self.processes: @@ -1009,7 +1019,7 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): process.disable() vpnservice = process.vpnservice if vpnservice: - self._update_nat(vpnservice, self.remove_nat_rule) + self.remove_nat_rules(process_id) del self.processes[process_id] def destroy_router(self, process_id): @@ -1022,7 +1032,8 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): if process_id in self.routers: del self.routers[process_id] - def get_process_status_cache(self, process): + def get_process_status_cache(self, + process: BaseSwanProcess) -> ty.Dict[str, ty.Any]: if not self.process_status_cache.get(process.id): self.process_status_cache[process.id] = { 'status': None, @@ -1031,7 +1042,8 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): 'ipsec_site_connections': {}} return self.process_status_cache[process.id] - def is_status_updated(self, process, previous_status): + def is_status_updated(self, process: BaseSwanProcess, + previous_status: ty.Dict[str, ty.Any]) -> bool: if process.updated_pending_status: return True if process.status != previous_status['status']: @@ -1039,13 +1051,15 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): if (process.connection_status != previous_status['ipsec_site_connections']): return True + return False def unset_updated_pending_status(self, process): process.updated_pending_status = False for connection_status in process.connection_status.values(): connection_status['updated_pending_status'] = False - def copy_process_status(self, process): + def copy_process_status(self, + process: BaseSwanProcess) -> ty.Dict[str, ty.Any]: return { 'id': process.vpnservice['id'], 'status': process.status, @@ -1053,7 +1067,7 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): 'ipsec_site_connections': copy.deepcopy(process.connection_status) } - def update_downed_connections(self, process_id, new_status): + def update_downed_connections(self, process_id: str, new_status): """Update info to be reported, if connections just went down. If there is no longer any information for a connection, because it @@ -1069,13 +1083,15 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): 'updated_pending_status': True } - def should_be_reported(self, context, process): + def should_be_reported(self, context: context.ContextBase, + process: BaseSwanProcess) -> bool: if (context.is_admin or process.vpnservice["tenant_id"] == context.tenant_id): return True + return False @log_helpers.log_method_call - def report_status(self, context): + def report_status(self, context: context.ContextBase): status_changed_vpn_services = [] for process_id, process in list(self.processes.items()): # NOTE(mnaser): It's not necessary to check status for processes @@ -1104,11 +1120,11 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): @log_helpers.log_method_call @lockutils.synchronized('vpn-agent', 'neutron-') - def sync(self, context, routers): + def sync(self, context: context.ContextBase, router_information: ty.List[RouterInfo]): """Sync status with server side. :param context: context object for RPC call - :param routers: Router objects which is created in this sync event + :param router_information: RouterInfo objects with updated state There could be many failure cases should be considered including the followings. @@ -1123,7 +1139,12 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): vpnservices = self.agent_rpc.get_vpn_services_on_host( context, self.host) router_ids = [vpnservice['router_id'] for vpnservice in vpnservices] - sync_router_ids = [router['id'] for router in routers] + sync_router_ids = [] + + for ri in router_information: + # Update our router_info with the updated one + self.routers[ri.router_id] = ri + sync_router_ids.append(ri.router_id) self._sync_vpn_processes(vpnservices, sync_router_ids) self._delete_vpn_processes(sync_router_ids, router_ids) @@ -1138,15 +1159,16 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): for vpnservice in vpnservices: if vpnservice['router_id'] not in self.processes or ( vpnservice['router_id'] in sync_router_ids): - process = self.ensure_process(vpnservice['router_id'], - vpnservice=vpnservice) - self._update_nat(vpnservice, self.add_nat_rule) - router = self.routers.get(vpnservice['router_id']) - if not router: + process: ty.Optional[BaseSwanProcess] = self.ensure_process( + vpnservice['router_id'], + vpnservice=vpnservice) + self.ensure_nat_rules(vpnservice) + ri = self.routers.get(vpnservice['router_id']) + if not ri: continue # For HA router, spawn vpn process on master router # and terminate vpn process on backup router - if router.router['ha'] and router.ha_state == 'backup': + if ri.router['ha'] and ri.ha_state == 'backup': process.disable() else: process.update() @@ -1168,7 +1190,8 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): class OpenSwanDriver(IPsecDriver): - def create_process(self, process_id, vpnservice, namespace): + def create_process(self, process_id: str, vpnservice, + namespace: str) -> BaseSwanProcess: return OpenSwanProcess( self.conf, process_id, diff --git a/neutron_vpnaas/services/vpn/device_drivers/libreswan_ipsec.py b/neutron_vpnaas/services/vpn/device_drivers/libreswan_ipsec.py index 90731f7a4..e55700850 100644 --- a/neutron_vpnaas/services/vpn/device_drivers/libreswan_ipsec.py +++ b/neutron_vpnaas/services/vpn/device_drivers/libreswan_ipsec.py @@ -12,11 +12,11 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import typing as ty + import os -import os.path from neutron.agent.linux import ip_lib - from neutron_vpnaas.services.vpn.device_drivers import ipsec @@ -26,39 +26,38 @@ class LibreSwanProcess(ipsec.OpenSwanProcess): Libreswan needs nssdb initialised before running pluto daemon. """ # pylint: disable=useless-super-delegation - def __init__(self, conf, process_id, vpnservice, namespace): + def __init__(self, conf, process_id: str, vpnservice, namespace: str): self._rootwrap_cfg = self._get_rootwrap_config() - super(LibreSwanProcess, self).__init__(conf, process_id, - vpnservice, namespace) + super().__init__(conf, process_id, vpnservice, namespace) - def _ipsec_execute(self, cmd, check_exit_code=True, extra_ok_codes=None): + def _ipsec_execute(self, cmd: ty.List[str], check_exit_code: bool = True, + extra_ok_codes: ty.Optional[ty.List[int]] = None): """Execute ipsec command on namespace. This execute is wrapped by namespace wrapper. The namespace wrapper will bind /etc and /var/run """ ip_wrapper = ip_lib.IPWrapper(namespace=self.namespace) - mount_paths = {'/etc': '%s/etc' % self.config_dir, - '/var/run': '%s/var/run' % self.config_dir} + mount_paths = {'/etc': f'{self.config_dir}/etc', + '/var/run': f'{self.config_dir}/var/run'} mount_paths_str = ','.join( - "%s:%s" % (source, target) - for source, target in mount_paths.items()) + f"{source}:{target}" for source, target in mount_paths.items()) ns_wrapper = self.get_ns_wrapper() return ip_wrapper.netns.execute( [ns_wrapper, - '--mount_paths=%s' % mount_paths_str, - ('--rootwrap_config=%s' % self._rootwrap_cfg - if self._rootwrap_cfg else ''), - '--cmd=%s,%s' % (self.binary, ','.join(cmd))], + f'--mount_paths={mount_paths_str}', + '--rootwrap_config={0}'.format( + self._rootwrap_cfg if self._rootwrap_cfg else ""), + f'--cmd={self.binary},{",".join(cmd)}'], check_exit_code=check_exit_code, extra_ok_codes=extra_ok_codes) def _ensure_needed_files(self): # addconn reads from /etc/hosts and /etc/resolv.conf. As /etc would be # bind-mounted, create these two empty files in the target directory. - with open('%s/etc/hosts' % self.config_dir, 'a'): + with open(f'{self.config_dir}/etc/hosts', 'a', encoding="utf8"): pass - with open('%s/etc/resolv.conf' % self.config_dir, 'a'): + with open(f'{self.config_dir}/etc/resolv.conf', 'a', encoding="utf8"): pass def ensure_configs(self): @@ -75,17 +74,17 @@ class LibreSwanProcess(ipsec.OpenSwanProcess): if os.path.exists(secrets_file): self._execute(['rm', '-f', secrets_file]) - super(LibreSwanProcess, self).ensure_configs() + super().ensure_configs() # LibreSwan uses the capabilities library to restrict access to # ipsec.secrets to users that have explicit access. Since pluto is # running as root and the file has 0600 perms, we must set the # owner of the file to root. - self._execute(['chown', '--from=%s' % os.getuid(), 'root:root', + self._execute(['chown', f'--from={os.getuid()}', 'root:root', secrets_file]) # Libreswan needs to write logs to this directory. - self._execute(['chown', '--from=%s' % os.getuid(), 'root:root', + self._execute(['chown', f'--from={os.getuid()}', 'root:root', self.log_dir]) self._ensure_needed_files() @@ -131,11 +130,12 @@ class LibreSwanProcess(ipsec.OpenSwanProcess): ['whack', '--name', conn_name, '--asynchronous', '--initiate']) def terminate_connection(self, conn_name): - self._ipsec_execute(['whack', '--name', conn_name, '--terminate']) + self._ipsec_execute(["whack", "--name", conn_name, "--terminate"]) class LibreSwanDriver(ipsec.IPsecDriver): - def create_process(self, process_id, vpnservice, namespace): + def create_process(self, process_id: str, vpnservice, + namespace: str) -> ipsec.BaseSwanProcess: return LibreSwanProcess( self.conf, process_id, diff --git a/neutron_vpnaas/services/vpn/device_drivers/ovn_ipsec.py b/neutron_vpnaas/services/vpn/device_drivers/ovn_ipsec.py index aaaaac0ec..0785775e3 100644 --- a/neutron_vpnaas/services/vpn/device_drivers/ovn_ipsec.py +++ b/neutron_vpnaas/services/vpn/device_drivers/ovn_ipsec.py @@ -14,9 +14,15 @@ # License for the specific language governing permissions and limitations # under the License. + +import typing as ty + +import abc + import netaddr from neutron.agent.common import utils as agent_common_utils +from neutron.agent.linux import interface from neutron.agent.linux import ip_lib from neutron_lib import constants as lib_constants from neutron_lib import context as nctx @@ -30,7 +36,7 @@ from neutron_vpnaas.services.vpn.device_drivers import strongswan_ipsec PORT_PREFIX_INTERNAL = 'vr' PORT_PREFIX_EXTERNAL = 'vg' -PORT_PREFIXES = { +PORT_PREFIXES: ty.Dict[str, str] = { 'internal': PORT_PREFIX_INTERNAL, 'external': PORT_PREFIX_EXTERNAL, } @@ -38,7 +44,7 @@ PORT_PREFIXES = { LOG = logging.getLogger(__name__) -class DeviceManager(object): +class DeviceManager: """Device Manager for ports in qvpn-xx namespace. It is a veth pair, one side in qvpn and the other side is attached to ovs. @@ -51,16 +57,17 @@ class DeviceManager(object): self.host = host self.plugin = plugin self.context = context - self.driver = agent_common_utils.load_interface_driver(conf) + self.driver: interface.LinuxInterfaceDriver = \ + agent_common_utils.load_interface_driver(conf) - def get_interface_name(self, port, ptype): + def get_interface_name(self, port: ty.Dict[str, str], ptype: str) -> str: suffix = port['id'] return (PORT_PREFIXES[ptype] + suffix)[:self.driver.DEV_NAME_LEN] - def get_namespace_name(self, process_id): + def get_namespace_name(self, process_id: str): return self.OVN_NS_PREFIX + process_id - def get_existing_process_ids(self): + def get_existing_process_ids(self) -> ty.List: """Return the process IDs derived from the existing VPN namespaces.""" return [ns[len(self.OVN_NS_PREFIX):] for ns in ip_lib.list_network_namespaces() @@ -87,7 +94,8 @@ class DeviceManager(object): device.route.delete_route(cidr, via=via, metric=100, proto='static') - def list_routes(self, namespace, via=None): + def list_routes(self, namespace, + via=None) -> ty.List[ty.Dict[str, ty.Any]]: device = ip_lib.IPDevice(None, namespace=namespace) return device.route.list_routes( lib_constants.IP_VERSION_4, proto='static', via=via) @@ -100,24 +108,26 @@ class DeviceManager(object): for r in routes: device.route.delete_route(r['cidr'], via=r['via']) - def _del_port(self, process_id, ptype): + def _del_port(self, process_id: str, ptype: str): namespace = self.get_namespace_name(process_id) prefix = PORT_PREFIXES[ptype] device = ip_lib.IPDevice(None, namespace=namespace) - ports = device.addr.list() + ports: ty.List[ty.Dict[str, ty.Union[str, ty.Any]]] = \ + device.addr.list() for p in ports: if not p['name'].startswith(prefix): continue interface_name = p['name'] self.driver.unplug(interface_name, namespace=namespace) - def del_internal_port(self, process_id): + def del_internal_port(self, process_id: str): self._del_port(process_id, 'internal') - def del_external_port(self, process_id): + def del_external_port(self, process_id: str): self._del_port(process_id, 'external') - def setup_external(self, process_id, network_details): + def setup_external(self, process_id: str, + network_details) -> ty.Optional[str]: network = network_details["external_network"] vpn_port = network_details['gw_port'] ns_name = self.get_namespace_name(process_id) @@ -143,7 +153,7 @@ class DeviceManager(object): subnet_id = fixed_ip['subnet_id'] subnet = self.plugin.get_subnet_info(subnet_id) net = netaddr.IPNetwork(subnet['cidr']) - ip_cidr = '%s/%s' % (fixed_ip['ip_address'], net.prefixlen) + ip_cidr = f'{fixed_ip["ip_address"]}/{net.prefixlen}' ip_cidrs.append(ip_cidr) subnets.append(subnet) self.driver.init_l3(interface_name, ip_cidrs, @@ -152,7 +162,7 @@ class DeviceManager(object): self.set_default_route(ns_name, subnet, interface_name) return interface_name - def setup_internal(self, process_id, network_details): + def setup_internal(self, process_id, network_details) -> ty.Optional[str]: vpn_port = network_details["transit_port"] ns_name = self.get_namespace_name(process_id) interface_name = self.get_interface_name(vpn_port, 'internal') @@ -172,19 +182,19 @@ class DeviceManager(object): ip_cidrs = [] for fixed_ip in vpn_port['fixed_ips']: - ip_cidr = '%s/%s' % (fixed_ip['ip_address'], 28) + ip_cidr = f'{fixed_ip["ip_address"]}/28' ip_cidrs.append(ip_cidr) self.driver.init_l3(interface_name, ip_cidrs, namespace=ns_name) return interface_name -class NamespaceManager(object): +class NamespaceManager: def __init__(self, use_ipv6=False): self.ip_wrapper_root = ip_lib.IPWrapper() self.use_ipv6 = use_ipv6 - def exists(self, name): + def exists(self, name) -> bool: return ip_lib.network_namespace_exists(name) def create(self, name): @@ -231,9 +241,9 @@ class IPsecOvnDriverApi(ipsec.IPsecVpnDriverApi): subnet_id=subnet_id) -class OvnIPsecDriver(ipsec.IPsecDriver): +class OvnIPsecDriver(ipsec.IPsecDriver, metaclass=abc.ABCMeta): - def __init__(self, vpn_service, host): + def __init__(self, vpn_service, host: str): self.nsmgr = NamespaceManager() super().__init__(vpn_service, host) self.agent_rpc = IPsecOvnDriverApi(topics.IPSEC_DRIVER_TOPIC) @@ -242,7 +252,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver): get_router_based_iptables_manager = None - def get_namespace(self, router_id): + def get_namespace(self, router_id) -> str: """Get namespace for VPN services of router. :router_id: router_id @@ -250,7 +260,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver): """ return self.devmgr.get_namespace_name(router_id) - def _cleanup_namespace(self, router_id): + def _cleanup_namespace(self, router_id: str): ns_name = self.devmgr.get_namespace_name(router_id) if not self.nsmgr.exists(ns_name): return @@ -259,7 +269,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver): self.devmgr.del_external_port(router_id) self.nsmgr.delete(ns_name) - def _ensure_namespace(self, router_id, network_details): + def _ensure_namespace(self, router_id: str, network_details) -> str: ns_name = self.get_namespace(router_id) if not self.nsmgr.exists(ns_name): self.nsmgr.create(ns_name) @@ -272,7 +282,12 @@ class OvnIPsecDriver(ipsec.IPsecDriver): return ns_name - def destroy_process(self, process_id): + @abc.abstractmethod + def create_process(self, process_id: str, + vpnservice, namespace) -> ipsec.BaseSwanProcess: + pass + + def destroy_process(self, process_id: str): LOG.info('process %s is destroyed', process_id) namespace = self.devmgr.get_namespace_name(process_id) @@ -316,7 +331,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver): new_local_cidrs - old_local_cidrs, gateway_ip) - def _sync_vpn_processes(self, vpnservices, sync_router_ids): + def _sync_vpn_processes(self, vpnservices, sync_router_ids: ty.List[str]): # Ensure the ipsec process is enabled only for # - the vpn services which are not yet in self.processes # - vpn services whose router id is in 'sync_router_ids' @@ -331,7 +346,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver): process = self.ensure_process(router_id, vpnservice=vpnservice) process.update() - def _cleanup_stale_vpn_processes(self, vpn_router_ids): + def _cleanup_stale_vpn_processes(self, vpn_router_ids: ty.List[str]): super()._cleanup_stale_vpn_processes(vpn_router_ids) # Look for additional namespaces on this node that we don't know # and that should be deleted @@ -340,17 +355,20 @@ class OvnIPsecDriver(ipsec.IPsecDriver): self.destroy_process(router_id) @lockutils.synchronized('vpn-agent', 'neutron-') - def vpnservice_removed_from_agent(self, context, router_id): + def vpnservice_removed_from_agent(self, context: nctx.Context, + router_id: str): # must run under the same lock as sync() self.destroy_process(router_id) - def vpnservice_added_to_agent(self, context, router_ids): + def vpnservice_added_to_agent(self, context: nctx.Context, + router_ids: ty.List[str]): routers = [{'id': router_id} for router_id in router_ids] self.sync(context, routers) class OvnStrongSwanDriver(OvnIPsecDriver): - def create_process(self, process_id, vpnservice, namespace): + def create_process(self, process_id: str, vpnservice, + namespace: str) -> ipsec.BaseSwanProcess: return OvnStrongSwanProcess( self.conf, process_id, @@ -359,7 +377,8 @@ class OvnStrongSwanDriver(OvnIPsecDriver): class OvnOpenSwanDriver(OvnIPsecDriver): - def create_process(self, process_id, vpnservice, namespace): + def create_process(self, process_id: str, vpnservice, + namespace: str) -> ipsec.BaseSwanProcess: return OvnOpenSwanProcess( self.conf, process_id, @@ -368,7 +387,8 @@ class OvnOpenSwanDriver(OvnIPsecDriver): class OvnLibreSwanDriver(OvnIPsecDriver): - def create_process(self, process_id, vpnservice, namespace): + def create_process(self, process_id: str, vpnservice, + namespace: str) -> ipsec.BaseSwanProcess: return OvnLibreSwanProcess( self.conf, process_id, diff --git a/neutron_vpnaas/services/vpn/device_drivers/strongswan_ipsec.py b/neutron_vpnaas/services/vpn/device_drivers/strongswan_ipsec.py index 708952a1f..17c3dffc6 100644 --- a/neutron_vpnaas/services/vpn/device_drivers/strongswan_ipsec.py +++ b/neutron_vpnaas/services/vpn/device_drivers/strongswan_ipsec.py @@ -12,8 +12,8 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. - import os +import typing as ty from oslo_config import cfg from oslo_log import log as logging @@ -75,15 +75,14 @@ class StrongSwanProcess(ipsec.BaseSwanProcess): STATUS_RE = r'([a-f0-9\-]+).* (ROUTED|CONNECTING|INSTALLED)' STATUS_NOT_RUNNING_RE = 'Command:.*ipsec.*status.*Exit code: [1|3] ' - def __init__(self, conf, process_id, vpnservice, namespace): + def __init__(self, conf, process_id: str, vpnservice, namespace: str): self.DIALECT_MAP['v1'] = 'ikev1' self.DIALECT_MAP['v2'] = 'ikev2' self.DIALECT_MAP['sha256'] = 'sha256' self._strongswan_piddir = self._get_strongswan_piddir() self._rootwrap_cfg = self._get_rootwrap_config() LOG.debug("strongswan piddir is '%s'", (self._strongswan_piddir)) - super(StrongSwanProcess, self).__init__(conf, process_id, - vpnservice, namespace) + super().__init__(conf, process_id, vpnservice, namespace) def _get_strongswan_piddir(self): return utils.execute( @@ -103,7 +102,8 @@ class StrongSwanProcess(ipsec.BaseSwanProcess): return connection_id, status return None, None - def _execute(self, cmd, check_exit_code=True, extra_ok_codes=None): + def _execute(self, cmd: ty.List[str], check_exit_code: bool = True, + extra_ok_codes: ty.Optional[ty.List[int]] = None): """Execute command on namespace. This execute is wrapped by namespace wrapper. @@ -113,11 +113,11 @@ class StrongSwanProcess(ipsec.BaseSwanProcess): ns_wrapper = self.get_ns_wrapper() return ip_wrapper.netns.execute( [ns_wrapper, - '--mount_paths=/etc:%s/etc,%s:%s/var/run' % ( + '--mount_paths=/etc:{0}/etc,{1}:{2}/var/run'.format( self.config_dir, self._strongswan_piddir, self.config_dir), - ('--rootwrap_config=%s' % self._rootwrap_cfg - if self._rootwrap_cfg else ''), - '--cmd=%s' % ','.join(cmd)], + '--rootwrap_config={0}'.format( + self._rootwrap_cfg if self._rootwrap_cfg else ""), + f'--cmd={",".join(cmd)}'], check_exit_code=check_exit_code, extra_ok_codes=extra_ok_codes) diff --git a/neutron_vpnaas/services/vpn/ovn/agent_monitor.py b/neutron_vpnaas/services/vpn/ovn/agent_monitor.py index 6597f29b0..5e5dc925b 100644 --- a/neutron_vpnaas/services/vpn/ovn/agent_monitor.py +++ b/neutron_vpnaas/services/vpn/ovn/agent_monitor.py @@ -14,6 +14,7 @@ from neutron.plugins.ml2.drivers.ovn.agent import neutron_agent from neutron.plugins.ml2.drivers.ovn.mech_driver.ovsdb import ovsdb_monitor +from neutron.services.ovn_l3 import plugin from neutron_lib.plugins import constants as plugin_constants from neutron_lib.plugins import directory @@ -73,9 +74,10 @@ class ChassisVPNAgentWriteEvent(ovsdb_monitor.ChassisAgentEvent): clear_down=True) -class OVNVPNAgentMonitor(object): +class OVNVPNAgentMonitor: def watch_agent_events(self): - l3_plugin = directory.get_plugin(plugin_constants.L3) + l3_plugin: plugin.OVNL3RouterPlugin = \ + directory.get_plugin(plugin_constants.L3) sb_ovn = l3_plugin._sb_ovn if sb_ovn: idl = sb_ovn.ovsdb_connection.idl diff --git a/neutron_vpnaas/services/vpn/ovn_agent.py b/neutron_vpnaas/services/vpn/ovn_agent.py index 1ec331a39..d7de90506 100644 --- a/neutron_vpnaas/services/vpn/ovn_agent.py +++ b/neutron_vpnaas/services/vpn/ovn_agent.py @@ -51,7 +51,7 @@ OVS_OPTS = [ ] -def register_opts(conf): +def register_opts(conf: cfg.ConfigOpts): common_config.register_common_config_options() agent_config.register_interface_driver_opts_helper(conf) agent_config.register_interface_opts(conf) diff --git a/neutron_vpnaas/services/vpn/ovn_plugin.py b/neutron_vpnaas/services/vpn/ovn_plugin.py index 5b2a6ea23..d3f34bafd 100644 --- a/neutron_vpnaas/services/vpn/ovn_plugin.py +++ b/neutron_vpnaas/services/vpn/ovn_plugin.py @@ -13,24 +13,25 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import typing as ty from neutron_lib.callbacks import events from neutron_lib.callbacks import registry from neutron_lib.callbacks import resources +from neutron_lib import context from oslo_config import cfg from oslo_utils import importutils from neutron_vpnaas.api.rpc.agentnotifiers import vpn_rpc_agent_api as nfy_api from neutron_vpnaas.db.vpn import vpn_agentschedulers_db as agent_db -from neutron_vpnaas.db.vpn.vpn_db import VPNPluginDb from neutron_vpnaas.db.vpn import vpn_ext_gw_db from neutron_vpnaas.services.vpn.common import constants from neutron_vpnaas.services.vpn.ovn import agent_monitor -from neutron_vpnaas.services.vpn.plugin import VPNDriverPlugin +from neutron_vpnaas.services.vpn import plugin as vpn_plugin +from neutron_vpnaas.services.vpn.service_drivers import ovn_ipsec -class VPNOVNPlugin(VPNPluginDb, - vpn_ext_gw_db.VPNExtGWPlugin_db, +class VPNOVNPlugin(vpn_ext_gw_db.VPNExtGWPlugin_db, agent_db.AZVPNAgentSchedulerDbMixin, agent_monitor.OVNVPNAgentMonitor): """Implementation of the VPN Service Plugin. @@ -50,13 +51,14 @@ class VPNOVNPlugin(VPNPluginDb, resources.PROCESS, events.AFTER_INIT) - def check_router_in_use(self, context, router_id): + def check_router_in_use(self, context: context.Context, router_id): pass def post_fork_initialize(self, resource, event, trigger, payload=None): self.watch_agent_events() - def vpn_router_agent_binding_changed(self, context, router_id, host): + def vpn_router_agent_binding_changed(self, context: context.Context, + router_id: str, host: str): pass supported_extension_aliases = ["vpnaas", @@ -66,11 +68,15 @@ class VPNOVNPlugin(VPNPluginDb, path_prefix = "/vpn" -class VPNOVNDriverPlugin(VPNOVNPlugin, VPNDriverPlugin): - def vpn_router_agent_binding_changed(self, context, router_id, host): +class VPNOVNDriverPlugin(VPNOVNPlugin, vpn_plugin.VPNDriverPlugin): + def vpn_router_agent_binding_changed(self, context: context.Context, + router_id: str, host: str): super().vpn_router_agent_binding_changed(context, router_id, host) filters = {'router_id': [router_id]} vpnservices = self.get_vpnservices(context, filters=filters) for vpnservice in vpnservices: - driver = self._get_driver_for_vpnservice(context, vpnservice) - driver.update_port_bindings(context, router_id, host) + driver: ty.Optional[ovn_ipsec.BaseOvnIPsecVPNDriver] = \ + self._get_driver_for_vpnservice( # type: ignore + context, vpnservice) + if driver: + driver.update_port_bindings(context, router_id, host) diff --git a/neutron_vpnaas/services/vpn/plugin.py b/neutron_vpnaas/services/vpn/plugin.py index 6fc9acc84..415b6554e 100644 --- a/neutron_vpnaas/services/vpn/plugin.py +++ b/neutron_vpnaas/services/vpn/plugin.py @@ -1,4 +1,3 @@ - # (c) Copyright 2013 Hewlett-Packard Development Company, L.P. # All Rights Reserved. # @@ -14,6 +13,9 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty + +from neutron.db import flavors_db from neutron.db import servicetype_db as st_db from neutron.services import provider_configuration as pconf from neutron.services import service_base @@ -26,6 +28,7 @@ from oslo_log import log as logging from neutron_vpnaas.db.vpn import vpn_db from neutron_vpnaas.extensions import vpn_flavors +from neutron_vpnaas.services.vpn import service_drivers LOG = logging.getLogger(__name__) @@ -54,8 +57,11 @@ class VPNPlugin(vpn_db.VPNPluginDb): class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin): """VpnPlugin which supports VPN Service Drivers.""" #TODO(nati) handle ikepolicy and ipsecpolicy update usecase + drivers: ty.Dict[str, service_drivers.VpnDriver] + default_provider: ty.Optional[str] + def __init__(self): - super(VPNDriverPlugin, self).__init__() + super().__init__() self.service_type_manager = st_db.ServiceTypeManager.get_instance() add_provider_configuration(self.service_type_manager, constants.VPN) # Load the service driver from neutron.conf. @@ -72,14 +78,16 @@ class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin): vpn_db.subscribe() @property - def _flavors_plugin(self): + def _flavors_plugin(self) -> flavors_db.FlavorsDbMixin: return directory.get_plugin(constants.FLAVORS) def start_rpc_listeners(self): servers = [] for driver_name, driver in self.drivers.items(): - if hasattr(driver, 'start_rpc_listeners'): - servers.extend(driver.start_rpc_listeners()) + start_rpc_listeners: ty.Optional[ty.Callable[..., ty.List]] = \ + getattr(driver, 'start_rpc_listeners', None) + if start_rpc_listeners and callable(start_rpc_listeners): + servers.extend(start_rpc_listeners()) return servers def _check_orphan_vpnservice_associations(self): @@ -124,7 +132,9 @@ class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin): context, constants.VPN, self.default_provider, vpnservice_id) - def _get_provider_for_flavor(self, context, flavor_id): + def _get_provider_for_flavor( + self, context: ncontext.Context, + flavor_id: ty.Optional[str]) -> ty.Optional[str]: if flavor_id: if self._flavors_plugin is None: raise vpn_flavors.FlavorsPluginNotLoaded() @@ -137,7 +147,7 @@ class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin): raise flav_exc.FlavorDisabled() providers = self._flavors_plugin.get_flavor_next_provider( context, fl_db['id']) - provider = providers[0].get('provider') + provider: ty.Optional[str] = providers[0].get('provider', None) if provider not in self.drivers: raise vpn_flavors.NoProviderFoundForFlavor(flavor_id=flavor_id) else: @@ -147,84 +157,98 @@ class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin): LOG.debug("Selected provider %s", provider) return provider - def _get_driver_for_vpnservice(self, context, vpnservice): + def _get_driver_for_vpnservice(self, context: ncontext.Context, + vpnservice) -> \ + ty.Optional[service_drivers.VpnDriver]: stm = self.service_type_manager - provider_names = stm.get_provider_names_by_resource_ids( - context, [vpnservice['id']]) + provider_names: ty.Dict[str, str] = \ + stm.get_provider_names_by_resource_ids(context, [vpnservice['id']]) provider = provider_names.get(vpnservice['id']) - return self.drivers[provider] + return self.drivers.get(provider) if provider else None - def _get_driver_for_ipsec_site_connection(self, context, + def _get_driver_for_ipsec_site_connection(self, context: ncontext.Context, ipsec_site_connection): # Only vpnservice_id is required as the vpnservice should be already # associated with a provider after its creation. vpnservice = {'id': ipsec_site_connection['vpnservice_id']} return self._get_driver_for_vpnservice(context, vpnservice) - def create_ipsec_site_connection(self, context, ipsec_site_connection): + def create_ipsec_site_connection(self, + context: ncontext.Context, + ipsec_site_connection) -> ty.Optional[ty.Dict[ty.Any, ty.Any]]: driver = self._get_driver_for_ipsec_site_connection( context, ipsec_site_connection['ipsec_site_connection']) - driver.validator.validate_ipsec_site_connection( - context, - ipsec_site_connection['ipsec_site_connection']) - ipsec_site_connection = super( - VPNDriverPlugin, self).create_ipsec_site_connection( - context, ipsec_site_connection) - driver.create_ipsec_site_connection(context, ipsec_site_connection) - return ipsec_site_connection + if driver: + driver.validator.validate_ipsec_site_connection( + context, + ipsec_site_connection['ipsec_site_connection']) + ipsec_site_connection = super().create_ipsec_site_connection( + context, ipsec_site_connection) + driver.create_ipsec_site_connection(context, ipsec_site_connection) + return ipsec_site_connection + return None - def delete_ipsec_site_connection(self, context, ipsec_conn_id): + def delete_ipsec_site_connection(self, context: ncontext.Context, + ipsec_site_conn_id: str): ipsec_site_connection = self.get_ipsec_site_connection( - context, ipsec_conn_id) - super(VPNDriverPlugin, self).delete_ipsec_site_connection( - context, ipsec_conn_id) + context, ipsec_site_conn_id) + super().delete_ipsec_site_connection( + context, ipsec_site_conn_id) driver = self._get_driver_for_ipsec_site_connection( context, ipsec_site_connection) - driver.delete_ipsec_site_connection(context, ipsec_site_connection) + if driver: + driver.delete_ipsec_site_connection(context, ipsec_site_connection) def update_ipsec_site_connection( - self, context, - ipsec_conn_id, ipsec_site_connection): + self, context: ncontext.Context, + ipsec_site_conn_id: str, + ipsec_site_connection) -> ty.Optional[ty.Dict[ty.Any, ty.Any]]: old_ipsec_site_connection = self.get_ipsec_site_connection( - context, ipsec_conn_id) + context, ipsec_site_conn_id) driver = self._get_driver_for_ipsec_site_connection( context, old_ipsec_site_connection) - driver.validator.validate_ipsec_site_connection( - context, - ipsec_site_connection['ipsec_site_connection']) - ipsec_site_connection = super( - VPNDriverPlugin, self).update_ipsec_site_connection( + if driver: + driver.validator.validate_ipsec_site_connection( context, - ipsec_conn_id, - ipsec_site_connection) - driver.update_ipsec_site_connection( - context, old_ipsec_site_connection, ipsec_site_connection) - return ipsec_site_connection + ipsec_site_connection['ipsec_site_connection']) + ipsec_site_connection = super().update_ipsec_site_connection( + context, + ipsec_site_conn_id, + ipsec_site_connection) + driver.update_ipsec_site_connection( + context, old_ipsec_site_connection, ipsec_site_connection) + return ipsec_site_connection + return None - def create_vpnservice(self, context, vpnservice): + def create_vpnservice(self, context: ncontext.Context, + vpnservice: ty.Dict[str, ty.Dict[str, ty.Any]]) -> \ + ty.Optional[ty.Dict[str, ty.Any]]: provider = self._get_provider_for_flavor( context, vpnservice['vpnservice'].get('flavor_id')) - vpnservice = super( - VPNDriverPlugin, self).create_vpnservice(context, vpnservice) - self.service_type_manager.add_resource_association( - context, constants.VPN, provider, vpnservice['id']) - driver = self.drivers[provider] - driver.create_vpnservice(context, vpnservice) - return vpnservice + if provider: + vpnservice = super().create_vpnservice(context, vpnservice) + self.service_type_manager.add_resource_association( + context, constants.VPN, provider, vpnservice['id']) + driver = self.drivers[provider] + driver.create_vpnservice(context, vpnservice) + return vpnservice + return None - def update_vpnservice(self, context, vpnservice_id, vpnservice): + def update_vpnservice(self, context: ncontext.Context, vpnservice_id: str, + vpnservice) -> ty.Dict[str, ty.Any]: old_vpn_service = self.get_vpnservice(context, vpnservice_id) - new_vpn_service = super( - VPNDriverPlugin, self).update_vpnservice(context, vpnservice_id, - vpnservice) + new_vpn_service = super().update_vpnservice(context, vpnservice_id, + vpnservice) driver = self._get_driver_for_vpnservice(context, old_vpn_service) - driver.update_vpnservice(context, old_vpn_service, new_vpn_service) + if driver: + driver.update_vpnservice(context, old_vpn_service, new_vpn_service) return new_vpn_service - def delete_vpnservice(self, context, vpnservice_id): + def delete_vpnservice(self, context: ncontext.Context, vpnservice_id: str): vpnservice = self._get_vpnservice(context, vpnservice_id) - super(VPNDriverPlugin, self).delete_vpnservice(context, vpnservice_id) + super().delete_vpnservice(context, vpnservice_id) driver = self._get_driver_for_vpnservice(context, vpnservice) - self.service_type_manager.del_resource_associations( - context, [vpnservice_id]) - driver.delete_vpnservice(context, vpnservice) + if driver: + self.service_type_manager.del_resource_associations( + context, [vpnservice_id]) + driver.delete_vpnservice(context, vpnservice) diff --git a/neutron_vpnaas/services/vpn/service_drivers/__init__.py b/neutron_vpnaas/services/vpn/service_drivers/__init__.py index f9c90012e..ec4b991b8 100644 --- a/neutron_vpnaas/services/vpn/service_drivers/__init__.py +++ b/neutron_vpnaas/services/vpn/service_drivers/__init__.py @@ -14,21 +14,25 @@ # under the License. import abc +import typing as ty +from neutron.extensions import l3 +from neutron_lib import context from neutron_lib.plugins import constants from neutron_lib.plugins import directory from neutron_lib import rpc as n_rpc from oslo_log import log as logging import oslo_messaging +from neutron_vpnaas.db.vpn import vpn_db from neutron_vpnaas.services.vpn.service_drivers import driver_validator LOG = logging.getLogger(__name__) -class VpnDriver(object, metaclass=abc.ABCMeta): +class VpnDriver(metaclass=abc.ABCMeta): - def __init__(self, service_plugin, validator=None): + def __init__(self, service_plugin: vpn_db.VPNPluginDb, validator=None): self.service_plugin = service_plugin if validator is None: validator = driver_validator.VpnDriverValidator(self) @@ -36,7 +40,7 @@ class VpnDriver(object, metaclass=abc.ABCMeta): self.name = '' @property - def l3_plugin(self): + def l3_plugin(self) -> l3.RouterPluginBase: return directory.get_plugin(constants.L3) @property @@ -44,43 +48,49 @@ class VpnDriver(object, metaclass=abc.ABCMeta): pass @abc.abstractmethod - def create_vpnservice(self, context, vpnservice): + def create_vpnservice(self, context: context.ContextBase, vpnservice): pass @abc.abstractmethod def update_vpnservice( - self, context, old_vpnservice, vpnservice): + self, context: context.ContextBase, old_vpnservice, vpnservice): pass @abc.abstractmethod - def delete_vpnservice(self, context, vpnservice): + def delete_vpnservice(self, context: context.ContextBase, vpnservice): pass @abc.abstractmethod - def create_ipsec_site_connection(self, context, ipsec_site_connection): - pass - - @abc.abstractmethod - def update_ipsec_site_connection(self, context, old_ipsec_site_connection, + def create_ipsec_site_connection(self, context: context.ContextBase, ipsec_site_connection): pass @abc.abstractmethod - def delete_ipsec_site_connection(self, context, ipsec_site_connection): + def update_ipsec_site_connection(self, context: context.ContextBase, + old_ipsec_site_connection, + ipsec_site_connection): + pass + + @abc.abstractmethod + def delete_ipsec_site_connection(self, context: context.ContextBase, + ipsec_site_connection): pass -class BaseIPsecVpnAgentApi(object): +class BaseIPsecVpnAgentApi: """Base class for IPSec API to agent.""" - def __init__(self, topic, default_version, driver): + def __init__(self, topic: str, default_version: str, + driver: VpnDriver): self.topic = topic self.driver = driver - target = oslo_messaging.Target(topic=topic, version=default_version) - self.client = n_rpc.get_client(target) + self.target = oslo_messaging.Target(topic=topic, + version=default_version) + self.client = n_rpc.get_client(self.target) - def _agent_notification(self, context, method, router_id, - version=None, **kwargs): + def _agent_notification(self, context: context.ContextBase, method, + router_id: str, + version: ty.Optional[str] = None, **kwargs): """Notify update for the agent. This method will find where is the router, and @@ -103,7 +113,8 @@ class BaseIPsecVpnAgentApi(object): cctxt = self.client.prepare(server=l3_agent.host, version=version) cctxt.cast(context, method, **kwargs) - def vpnservice_updated(self, context, router_id, **kwargs): + def vpnservice_updated(self, context: context.ContextBase, router_id: str, + **kwargs): """Send update event of vpnservices.""" kwargs['router'] = {'id': router_id} self._agent_notification(context, 'vpnservice_updated', router_id, diff --git a/neutron_vpnaas/services/vpn/service_drivers/base_ipsec.py b/neutron_vpnaas/services/vpn/service_drivers/base_ipsec.py index 13ef9910d..a96e3880c 100644 --- a/neutron_vpnaas/services/vpn/service_drivers/base_ipsec.py +++ b/neutron_vpnaas/services/vpn/service_drivers/base_ipsec.py @@ -12,17 +12,23 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + +import typing as ty + import abc import netaddr -import oslo_messaging - +from neutron.db import agents_db from neutron.db.models import l3agent from neutron.db.models import servicetype +from neutron.objects import agent as nagent from neutron_lib import constants as lib_constants +from neutron_lib import context from neutron_lib.db import api as db_api from neutron_lib.plugins import directory +import oslo_messaging +from neutron_vpnaas.db.vpn import vpn_db from neutron_vpnaas.db.vpn import vpn_models from neutron_vpnaas.services.vpn import service_drivers @@ -31,7 +37,7 @@ IPSEC = 'ipsec' BASE_IPSEC_VERSION = '1.0' -class IPsecVpnDriverCallBack(object): +class IPsecVpnDriverCallBack: """Callback for IPSecVpnDriver rpc.""" # history @@ -40,12 +46,13 @@ class IPsecVpnDriverCallBack(object): target = oslo_messaging.Target(version=BASE_IPSEC_VERSION) def __init__(self, driver): - super(IPsecVpnDriverCallBack, self).__init__() + super().__init__() self.driver = driver - def _get_agent_hosting_vpn_services(self, context, host): - plugin = directory.get_plugin() - agent = plugin._get_agent_by_type_and_host( + def _get_agent_hosting_vpn_services(self, context: context.Context, + host: ty.Optional[str]): + plugin: agents_db.AgentDbMixin = directory.get_plugin() + agent: ty.Optional[nagent.Agent] = plugin._get_agent_by_type_and_host( context, lib_constants.AGENT_TYPE_L3, host) agent_conf = plugin.get_configuration_dict(agent) # Retrieve the agent_mode to check if this is the @@ -53,7 +60,9 @@ class IPsecVpnDriverCallBack(object): # case of distributed the vpn service should reside # only on a dvr_snat node. agent_mode = agent_conf.get('agent_mode', 'legacy') - if not agent.admin_state_up or agent_mode == 'dvr': + if (not agent and + not agent.admin_state_up or # type: ignore + agent_mode == 'dvr'): return [] query = context.session.query(vpn_models.VPNService) query = query.join(vpn_models.IPsecSiteConnection) @@ -65,25 +74,27 @@ class IPsecVpnDriverCallBack(object): servicetype.ProviderResourceAssociation.resource_id == vpn_models.VPNService.id) query = query.filter( - l3agent.RouterL3AgentBinding.l3_agent_id == agent.id) + l3agent.RouterL3AgentBinding.l3_agent_id == + agent.id) # type: ignore query = query.filter( servicetype.ProviderResourceAssociation.provider_name == self.driver.name) return query @db_api.CONTEXT_READER - def get_vpn_services_on_host(self, context, host=None): + def get_vpn_services_on_host(self, context: context.Context, + host: ty.Optional[str] = None): """Returns the vpnservices on the host.""" vpnservices = self._get_agent_hosting_vpn_services( context, host) - plugin = self.driver.service_plugin + plugin: vpn_db.VPNPluginRpcDbMixin = self.driver.service_plugin local_cidr_map = plugin._build_local_subnet_cidr_map(context) return [self.driver.make_vpnservice_dict(vpnservice, local_cidr_map) - for vpnservice in vpnservices] + for vpnservice in vpnservices] - def update_status(self, context, status): + def update_status(self, context: context.Context, status): """Update status of vpnservices.""" - plugin = self.driver.service_plugin + plugin: vpn_db.VPNPluginRpcDbMixin = self.driver.service_plugin plugin.update_status_by_agent(context, status) @@ -94,57 +105,63 @@ class IPsecVpnAgentApi(service_drivers.BaseIPsecVpnAgentApi): # pylint: disable=useless-super-delegation def __init__(self, topic, default_version, driver): - super(IPsecVpnAgentApi, self).__init__( - topic, default_version, driver) + super().__init__(topic, default_version, driver) class BaseIPsecVPNDriver(service_drivers.VpnDriver, metaclass=abc.ABCMeta): """Base VPN Service Driver class.""" - def __init__(self, service_plugin, validator=None): - super(BaseIPsecVPNDriver, self).__init__(service_plugin, validator) + agent_rpc: service_drivers.BaseIPsecVpnAgentApi + + def __init__(self, service_plugin: vpn_db.VPNPluginDb, validator=None): + super().__init__(service_plugin, validator) self.create_rpc_conn() @property - def service_type(self): + def service_type(self) -> str: return IPSEC @abc.abstractmethod - def create_rpc_conn(self): + def create_rpc_conn(self) -> None: pass - def create_ipsec_site_connection(self, context, ipsec_site_connection): + def create_ipsec_site_connection(self, context: context.Context, + ipsec_site_connection): router_id = self.service_plugin.get_vpnservice_router_id( context, ipsec_site_connection['vpnservice_id']) self.agent_rpc.vpnservice_updated(context, router_id) - def update_ipsec_site_connection( - self, context, old_ipsec_site_connection, ipsec_site_connection): + def update_ipsec_site_connection(self, context: context.Context, + old_ipsec_site_connection, + ipsec_site_connection): router_id = self.service_plugin.get_vpnservice_router_id( context, ipsec_site_connection['vpnservice_id']) self.agent_rpc.vpnservice_updated(context, router_id) - def delete_ipsec_site_connection(self, context, ipsec_site_connection): + def delete_ipsec_site_connection(self, context: context.Context, + ipsec_site_connection): router_id = self.service_plugin.get_vpnservice_router_id( context, ipsec_site_connection['vpnservice_id']) self.agent_rpc.vpnservice_updated(context, router_id) - def create_ikepolicy(self, context, ikepolicy): + def create_ikepolicy(self, context: context.Context, ikepolicy): pass - def delete_ikepolicy(self, context, ikepolicy): + def delete_ikepolicy(self, context: context.Context, ikepolicy): pass - def update_ikepolicy(self, context, old_ikepolicy, ikepolicy): + def update_ikepolicy(self, context: context.Context, + old_ikepolicy, ikepolicy): pass - def create_ipsecpolicy(self, context, ipsecpolicy): + def create_ipsecpolicy(self, context: context.Context, ipsecpolicy): pass - def delete_ipsecpolicy(self, context, ipsecpolicy): + def delete_ipsecpolicy(self, context: context.Context, ipsecpolicy): pass - def update_ipsecpolicy(self, context, old_ipsec_policy, ipsecpolicy): + def update_ipsecpolicy(self, context: context.Context, + old_ipsec_policy, ipsecpolicy): pass def _get_gateway_ips(self, router): @@ -164,7 +181,7 @@ class BaseIPsecVPNDriver(service_drivers.VpnDriver, metaclass=abc.ABCMeta): return v4_ip, v6_ip @db_api.CONTEXT_WRITER - def create_vpnservice(self, context, vpnservice_dict): + def create_vpnservice(self, context: context.Context, vpnservice_dict): """Get the gateway IP(s) and save for later use. For the reference implementation, this side's tunnel IP (external_ip) @@ -181,10 +198,11 @@ class BaseIPsecVPNDriver(service_drivers.VpnDriver, metaclass=abc.ABCMeta): vpnservice_dict['id'], v4_ip=v4_ip, v6_ip=v6_ip) - def update_vpnservice(self, context, old_vpnservice, vpnservice): + def update_vpnservice(self, context: context.Context, + old_vpnservice, vpnservice): self.agent_rpc.vpnservice_updated(context, vpnservice['router_id']) - def delete_vpnservice(self, context, vpnservice): + def delete_vpnservice(self, context: context.Context, vpnservice): self.agent_rpc.vpnservice_updated(context, vpnservice['router_id']) def get_external_ip_based_on_peer(self, vpnservice, ipsec_site_con): @@ -204,7 +222,7 @@ class BaseIPsecVPNDriver(service_drivers.VpnDriver, metaclass=abc.ABCMeta): also converting parameter name for vpn agent driver """ - vpnservice_dict = dict(vpnservice) + vpnservice_dict: ty.Dict[str, ty.Any] = dict(vpnservice) # Populate tenant_id for RPC compat vpnservice_dict['tenant_id'] = vpnservice_dict['project_id'] vpnservice_dict['ipsec_site_connections'] = [] diff --git a/neutron_vpnaas/services/vpn/service_drivers/driver_validator.py b/neutron_vpnaas/services/vpn/service_drivers/driver_validator.py index 0486d197f..f87636669 100644 --- a/neutron_vpnaas/services/vpn/service_drivers/driver_validator.py +++ b/neutron_vpnaas/services/vpn/service_drivers/driver_validator.py @@ -12,18 +12,25 @@ # License for the specific language governing permissions and limitations # under the License. # +import typing as ty + +from neutron.extensions import l3 +from neutron_lib import context +if ty.TYPE_CHECKING: + from neutron_vpnaas.services.vpn import service_drivers -class VpnDriverValidator(object): +class VpnDriverValidator: """Driver-specific validation routines for VPN resources.""" - def __init__(self, driver): + def __init__(self, driver: 'service_drivers.VpnDriver'): self.driver = driver @property - def l3_plugin(self): + def l3_plugin(self) -> l3.RouterPluginBase: return self.driver.l3_plugin - def validate_ipsec_site_connection(self, context, ipsec_sitecon): + def validate_ipsec_site_connection(self, context: context.ContextBase, + ipsec_sitecon): """Driver can override this for its additional validations.""" pass diff --git a/neutron_vpnaas/services/vpn/service_drivers/ipsec.py b/neutron_vpnaas/services/vpn/service_drivers/ipsec.py index 4e98cb52c..d26a0e8b8 100644 --- a/neutron_vpnaas/services/vpn/service_drivers/ipsec.py +++ b/neutron_vpnaas/services/vpn/service_drivers/ipsec.py @@ -15,6 +15,7 @@ from neutron_lib import rpc as n_rpc +from neutron_vpnaas.db.vpn import vpn_db from neutron_vpnaas.services.vpn.common import topics from neutron_vpnaas.services.vpn.service_drivers import base_ipsec from neutron_vpnaas.services.vpn.service_drivers import ipsec_validator @@ -27,10 +28,9 @@ BASE_IPSEC_VERSION = '1.0' class IPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): """VPN Service Driver class for IPsec.""" - def __init__(self, service_plugin): - super(IPsecVPNDriver, self).__init__( - service_plugin, - ipsec_validator.IpsecVpnValidator(self)) + def __init__(self, service_plugin: vpn_db.VPNPluginDb): + super().__init__(service_plugin, + ipsec_validator.IpsecVpnValidator(self)) def create_rpc_conn(self): self.endpoints = [base_ipsec.IPsecVpnDriverCallBack(self)] diff --git a/neutron_vpnaas/services/vpn/service_drivers/ipsec_validator.py b/neutron_vpnaas/services/vpn/service_drivers/ipsec_validator.py index 49fc45d5b..5c8bb9a37 100644 --- a/neutron_vpnaas/services/vpn/service_drivers/ipsec_validator.py +++ b/neutron_vpnaas/services/vpn/service_drivers/ipsec_validator.py @@ -12,6 +12,9 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty + +from neutron_lib import context from neutron_lib import exceptions as nexception from neutron_vpnaas._i18n import _ @@ -29,7 +32,8 @@ class IpsecVpnValidator(driver_validator.VpnDriverValidator): and Libreswan. """ - def _check_transform_protocol(self, context, transform_protocol): + def _check_transform_protocol(self, context: context.ContextBase, + transform_protocol: ty.Optional[str]): """Restrict selecting ah-esp as IPSec Policy transform protocol. For those *Swan implementations, the 'ah-esp' transform protocol @@ -41,12 +45,15 @@ class IpsecVpnValidator(driver_validator.VpnDriverValidator): key='transform_protocol', value=transform_protocol) - def validate_ipsec_policy(self, context, ipsec_policy): - transform_protocol = ipsec_policy.get('transform_protocol') + def validate_ipsec_policy(self, context: context.ContextBase, + ipsec_policy: ty.Dict[str, ty.Union[ty.Any, str]]): + transform_protocol: ty.Optional[str] = \ + ipsec_policy.get('transform_protocol', None) self._check_transform_protocol(context, transform_protocol) - def validate_ipsec_site_connection(self, context, ipsec_sitecon): - if 'ipsecpolicy_id' in ipsec_sitecon: + def validate_ipsec_site_connection(self, context: context.ContextBase, + ipsec_sitecon): + if "ipsecpolicy_id" in ipsec_sitecon: ipsec_policy = self.driver.service_plugin.get_ipsecpolicy( context, ipsec_sitecon['ipsecpolicy_id']) self.validate_ipsec_policy(context, ipsec_policy) diff --git a/neutron_vpnaas/services/vpn/service_drivers/ovn_ipsec.py b/neutron_vpnaas/services/vpn/service_drivers/ovn_ipsec.py index 47760f064..abce7ce1e 100644 --- a/neutron_vpnaas/services/vpn/service_drivers/ovn_ipsec.py +++ b/neutron_vpnaas/services/vpn/service_drivers/ovn_ipsec.py @@ -14,8 +14,14 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty + +import abc + import netaddr +from neutron.db import extraroute_db +from neutron.plugins.ml2 import plugin from neutron_lib.api.definitions import portbindings from neutron_lib.callbacks import events from neutron_lib.callbacks import registry @@ -33,14 +39,22 @@ from oslo_config import cfg from oslo_db import exception as o_exc from oslo_log import log as logging + from neutron_vpnaas.db.vpn import vpn_agentschedulers_db as agent_db -from neutron_vpnaas.db.vpn.vpn_ext_gw_db import RouterIsNotVPNExternal +from neutron_vpnaas.db.vpn import vpn_ext_gw_db as ext_gw from neutron_vpnaas.db.vpn import vpn_models from neutron_vpnaas.extensions import vpnaas from neutron_vpnaas.services.vpn.common import constants as v_constants from neutron_vpnaas.services.vpn.common import topics +from neutron_vpnaas.services.vpn import ovn_plugin from neutron_vpnaas.services.vpn.service_drivers import base_ipsec +#pylint: disable=ungrouped-imports +# Additional import for typechecking. Importing these without typechecking +# would resolve in a cyclic dependency +if ty.TYPE_CHECKING: + from neutron.db import db_base_plugin_v2 as db_plugin +#pylint: enable=ungrouped-imports LOG = logging.getLogger(__name__) @@ -63,17 +77,18 @@ class IPsecVpnOvnDriverCallBack(base_ipsec.IPsecVpnDriverCallBack): self.admin_ctx = nctx.get_admin_context() @property - def core_plugin(self): + def core_plugin(self) -> 'db_plugin.NeutronDbPluginV2': return self.driver.core_plugin @property - def service_plugin(self): + def service_plugin(self) -> ext_gw.VPNExtGWPlugin_db: return self.driver.service_plugin - def _get_vpn_gateway(self, context, router_id): + def _get_vpn_gateway(self, context: nctx.ContextBase, router_id: str): return self.service_plugin.get_vpn_gw_by_router_id(context, router_id) - def get_vpn_transit_network_details(self, context, router_id): + def get_vpn_transit_network_details(self, context: nctx.ContextBase, + router_id: str): vpn_gw = self._get_vpn_gateway(context, router_id) network_id = vpn_gw.gw_port['network_id'] external_network = self.core_plugin.get_network(context, network_id) @@ -86,13 +101,14 @@ class IPsecVpnOvnDriverCallBack(base_ipsec.IPsecVpnDriverCallBack): } return details - def get_subnet_info(self, context, subnet_id=None): + def get_subnet_info(self, context: nctx.ContextBase, + subnet_id: ty.Optional[str] = None): try: return self.core_plugin.get_subnet(context, subnet_id) except n_exc.SubnetNotFound: return None - def _get_agent_hosting_vpn_services(self, context, host): + def _get_agent_hosting_vpn_services(self, context: nctx.Context, host): agent = self.service_plugin.get_vpn_agent_on_host(context, host) if not agent: return [] @@ -114,20 +130,23 @@ class IPsecVpnOvnDriverCallBack(base_ipsec.IPsecVpnDriverCallBack): @registry.has_registry_receivers -class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): - def __init__(self, service_plugin): - self._l3_plugin = None - self._core_plugin = None +class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver, + metaclass=abc.ABCMeta): + def __init__(self, service_plugin: ovn_plugin.VPNOVNPlugin): + self._l3_plugin: \ + ty.Optional[extraroute_db.ExtraRoute_dbonly_mixin] = None + self._core_plugin: ty.Optional[plugin.Ml2Plugin] = None + self.service_plugin = service_plugin super().__init__(service_plugin) @property - def l3_plugin(self): + def l3_plugin(self) -> extraroute_db.ExtraRoute_dbonly_mixin: if self._l3_plugin is None: self._l3_plugin = directory.get_plugin(plugin_constants.L3) return self._l3_plugin @property - def core_plugin(self): + def core_plugin(self) -> plugin.Ml2Plugin: if self._core_plugin is None: self._core_plugin = directory.get_plugin() return self._core_plugin @@ -155,20 +174,25 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): raise vpnaas.RouteInUseByVPN( destinations=", ".join(conflict_cidrs)) - def get_vpn_gw_port_name(self, router_id): + @abc.abstractmethod + def create_rpc_conn(self) -> None: + pass + + def get_vpn_gw_port_name(self, router_id: str) -> str: return VPN_GW_PORT_PREFIX + router_id - def get_vpn_namespace_port_name(self, router_id): + def get_vpn_namespace_port_name(self, router_id: str) -> str: return TRANSIT_PORT_PREFIX + router_id - def get_transit_network_name(self, router_id): + def get_transit_network_name(self, router_id: str) -> str: return TRANSIT_NETWORK_PREFIX + router_id - def get_transit_subnet_name(self, router_id): + def get_transit_subnet_name(self, router_id: str) -> str: return TRANSIT_SUBNET_PREFIX + router_id - def make_transit_network(self, router_id, tenant_id, agent_host, - gateway_update): + def make_transit_network(self, router_id: str, tenant_id: str, + agent_host: str, + gateway_update: ty.Dict[str, ty.Any]): context = nctx.get_admin_context() network_data = { 'tenant_id': HIDDEN_PROJECT_ID, @@ -213,13 +237,14 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): {"port": port_data}) gateway_update['transit_port_id'] = port['id'] - def _del_port(self, context, port_id): + def _del_port(self, context: nctx.ContextBase, port_id: str): try: self.core_plugin.delete_port(context, port_id, l3_port_check=False) except n_exc.PortNotFound: pass - def _remove_router_interface(self, context, router_id, subnet_id): + def _remove_router_interface(self, context: nctx.ContextBase, + router_id: str, subnet_id: str): try: self.l3_plugin.remove_router_interface( context, router_id, {'subnet_id': subnet_id}) @@ -227,27 +252,27 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): n_exc.SubnetNotFound): pass - def _del_subnet(self, context, subnet_id): + def _del_subnet(self, context: nctx.ContextBase, subnet_id: str): try: self.core_plugin.delete_subnet(context, subnet_id) except n_exc.SubnetNotFound: pass - def _del_network(self, context, network_id): + def _del_network(self, context: nctx.ContextBase, network_id: str): try: self.core_plugin.delete_network(context, network_id) except n_exc.NetworkNotFound: pass - def del_transit_network(self, gw): + def del_transit_network(self, gw: ext_gw.VPNExtGW): context = nctx.get_admin_context() - router_id = gw['router_id'] + router_id: str = gw['router_id'] - port_id = gw.get('transit_port_id') + port_id: ty.Optional[str] = gw.get('transit_port_id') if port_id: self._del_port(context, port_id) - subnet_id = gw.get('transit_subnet_id') + subnet_id: ty.Optional[str] = gw.get('transit_subnet_id') if subnet_id: self._remove_router_interface(context, router_id, subnet_id) self._del_subnet(context, subnet_id) @@ -256,7 +281,8 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): if network_id: self._del_network(context, network_id) - def make_gw_port(self, router_id, network_id, agent_host, gateway_update): + def make_gw_port(self, router_id: str, network_id: str, + agent_host: str, gateway_update: ty.Dict[str, ty.Any]): context = nctx.get_admin_context() port_data = {'tenant_id': HIDDEN_PROJECT_ID, 'network_id': network_id, @@ -273,7 +299,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): LOG.debug('No IPs available for external network %s', network_id) gateway_update['gw_port_id'] = gw_port['id'] - def del_gw_port(self, gateway): + def del_gw_port(self, gateway: ext_gw.VPNExtGW): context = nctx.get_admin_context() port_id = gateway.get('gw_port_id') if port_id: @@ -290,24 +316,26 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): cidrs.append(ep.endpoint) return cidrs - def _routes_update(self, cidrs, nexthop): + def _routes_update(self, cidrs: ty.Set, nexthop): routes = [{'destination': cidr, 'nexthop': nexthop} for cidr in cidrs] return {'router': {'routes': routes}} - def _update_static_routes(self, context, ipsec_site_connection): + def _update_static_routes(self, context: nctx.ContextBase, + ipsec_site_connection): vpnservice = self.service_plugin.get_vpnservice( context, ipsec_site_connection['vpnservice_id']) router_id = vpnservice['router_id'] - gw = self.service_plugin.get_vpn_gw_by_router_id(context, router_id) + gw: ext_gw.VPNExtGW = self.service_plugin.get_vpn_gw_by_router_id( + context, router_id) nexthop = gw.transit_port['fixed_ips'][0]['ip_address'] router = self.l3_plugin.get_router(context, router_id) old_routes = router.get('routes', []) - old_cidrs = set([r['destination'] for r in old_routes - if r['nexthop'] == nexthop]) + old_cidrs = {r['destination'] for r in old_routes + if r['nexthop'] == nexthop} new_cidrs = set( self.service_plugin.get_peer_cidrs_for_router(context, router_id)) @@ -330,7 +358,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): nctx.get_admin_context(), router['id']) if gateway is None or gateway['external_fixed_ips'] is None: - raise RouterIsNotVPNExternal(router_id=router['id']) + raise ext_gw.RouterIsNotVPNExternal(router_id=router['id']) v4_ip = v6_ip = None for fixed_ip in gateway['external_fixed_ips']: @@ -343,12 +371,14 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): v6_ip = addr return v4_ip, v6_ip - def _update_gateway(self, context, gateway_id, **kwargs): + def _update_gateway(self, context: nctx.Context, + gateway_id: str, **kwargs): gateway = {'gateway': kwargs} return self.service_plugin.update_gateway(context, gateway_id, gateway) @db_api.retry_if_session_inactive() - def _ensure_gateway(self, context, vpnservice): + def _ensure_gateway(self, context: nctx.Context, vpnservice) -> \ + ty.Dict[str, ty.Any]: gw = self.service_plugin.get_vpn_gw_dict_by_router_id( context, vpnservice['router_id'], refresh=True) if not gw: @@ -371,23 +401,26 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): return gw @db_api.CONTEXT_WRITER - def _setup(self, context, vpnservice_dict): + def _setup(self, context: nctx.Context, + vpnservice_dict: ty.Dict[str, ty.Any]): router_id = vpnservice_dict['router_id'] agent = self.service_plugin.schedule_router(context, router_id) if not agent: raise vpnaas.NoVPNAgentAvailable agent_host = agent['host'] - gateway = self._ensure_gateway(context, vpnservice_dict) + gateway: ty.Optional[ty.Dict[str, ty.Any]] = self._ensure_gateway( + context, vpnservice_dict) # If the gateway status is ACTIVE the ports have been created already - if gateway['status'] == lib_constants.ACTIVE: + if gateway and gateway['status'] == lib_constants.ACTIVE: return - vpnservice = self.service_plugin._get_vpnservice(context, - vpnservice_dict['id']) + vpnservice = self.service_plugin._get_vpnservice( + context, vpnservice_dict["id"]) network_id = vpnservice.router.gw_port.network_id - gateway_update = {} # keeps track of already-created IDs + # keeps track of already-created IDs + gateway_update: ty.Dict[str, ty.Any] = {} try: self.make_gw_port(router_id, network_id, agent_host, gateway_update) @@ -396,16 +429,16 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): agent_host, gateway_update) except Exception: - self._update_gateway(context, gateway['id'], + self._update_gateway(context, gateway['id'], # type: ignore status=lib_constants.ERROR, **gateway_update) raise - self._update_gateway(context, gateway['id'], + self._update_gateway(context, gateway['id'], # type: ignore status=lib_constants.ACTIVE, **gateway_update) - def _cleanup(self, context, router_id): + def _cleanup(self, context: nctx.Context, router_id: str): gw = self.service_plugin.get_vpn_gw_dict_by_router_id(context, router_id) if not gw: @@ -423,7 +456,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): status=lib_constants.ERROR) raise - def create_vpnservice(self, context, vpnservice_dict): + def create_vpnservice(self, context: nctx.Context, vpnservice_dict): try: self._setup(context, vpnservice_dict) except Exception: @@ -435,7 +468,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): raise super().create_vpnservice(context, vpnservice_dict) - def delete_vpnservice(self, context, vpnservice): + def delete_vpnservice(self, context: nctx.Context, vpnservice): router_id = vpnservice['router_id'] super().delete_vpnservice(context, vpnservice) services = self.service_plugin.get_vpnservices(context) @@ -443,11 +476,13 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): if router_id not in router_ids: self._cleanup(context, router_id) - def create_ipsec_site_connection(self, context, ipsec_site_connection): + def create_ipsec_site_connection(self, context: nctx.Context, + ipsec_site_connection): self._update_static_routes(context, ipsec_site_connection) super().create_ipsec_site_connection(context, ipsec_site_connection) - def delete_ipsec_site_connection(self, context, ipsec_site_connection): + def delete_ipsec_site_connection(self, context: nctx.Context, + ipsec_site_connection): self._update_static_routes(context, ipsec_site_connection) super().delete_ipsec_site_connection(context, ipsec_site_connection) @@ -457,11 +492,11 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): super().update_ipsec_site_connection( context, old_ipsec_site_connection, ipsec_site_connection) - def _update_port_binding(self, context, port_id, host): + def _update_port_binding(self, context: nctx.Context, port_id, host): port_data = {'binding:host_id': host} self.core_plugin.update_port(context, port_id, {'port': port_data}) - def update_port_bindings(self, context, router_id, host): + def update_port_bindings(self, context: nctx.Context, router_id, host): gw = self.service_plugin.get_vpn_gw_dict_by_router_id(context, router_id) if not gw: @@ -475,7 +510,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): class IPsecOvnVpnAgentApi(base_ipsec.IPsecVpnAgentApi): - def _agent_notification(self, context, method, router_id, + def _agent_notification(self, context: nctx.Context, method, router_id, version=None, **kwargs): """Notify update for the agent. @@ -508,7 +543,7 @@ class IPsecOvnVPNDriver(BaseOvnIPsecVPNDriver): self.agent_rpc = IPsecOvnVpnAgentApi( topics.IPSEC_AGENT_TOPIC, BASE_IPSEC_VERSION, self) - def start_rpc_listeners(self): + def start_rpc_listeners(self) -> ty.List: self.endpoints = [IPsecVpnOvnDriverCallBack(self)] self.conn = n_rpc.Connection() self.conn.create_consumer( diff --git a/neutron_vpnaas/services/vpn/vpn_service.py b/neutron_vpnaas/services/vpn/vpn_service.py index e2b463d66..9d3fb4e1c 100644 --- a/neutron_vpnaas/services/vpn/vpn_service.py +++ b/neutron_vpnaas/services/vpn/vpn_service.py @@ -13,25 +13,30 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty + from neutron.services import provider_configuration as provconfig from neutron_lib.exceptions import vpn as vpn_exception +from oslo_config import cfg from oslo_log import log as logging from oslo_utils import importutils +from neutron_vpnaas.services.vpn.device_drivers import ipsec + LOG = logging.getLogger(__name__) DEVICE_DRIVERS = 'device_drivers' -class VPNService(object): +class VPNService: """VPN Service observer.""" def __init__(self, l3_agent): - self.conf = l3_agent.conf + self.conf: cfg.ConfigOpts = l3_agent.conf - def load_device_drivers(self, host): + def load_device_drivers(self, host) -> ty.List[ipsec.IPsecDriver]: """Loads one or more device drivers for VPNaaS.""" - drivers = [] + drivers: ty.List[ipsec.IPsecDriver] = [] for device_driver in self.conf.vpnagent.vpn_device_driver: device_driver = provconfig.get_provider_driver_class( device_driver, DEVICE_DRIVERS) diff --git a/neutron_vpnaas/tests/functional/common/test_scenario.py b/neutron_vpnaas/tests/functional/common/test_scenario.py index dd2f2f14b..94074a750 100644 --- a/neutron_vpnaas/tests/functional/common/test_scenario.py +++ b/neutron_vpnaas/tests/functional/common/test_scenario.py @@ -178,7 +178,7 @@ def get_ovs_bridge(br_name): Vm = collections.namedtuple('Vm', ['namespace', 'port_ip']) -class SiteInfo(object): +class SiteInfo: """Holds info on the router, ports, service, and connection.""" @@ -523,8 +523,8 @@ class TestIPSecBase(framework.L3AgentTestFramework): local_router_id = site1.router.router_id peer_router_id = site2.router.router_id - self.driver.sync(mock.Mock(), [{'id': local_router_id}, - {'id': peer_router_id}]) + self.driver.sync(mock.Mock(), [site1.router, + site2.router]) self.agent._process_updated_router(site1.router.router) self.agent._process_updated_router(site2.router.router) self.addCleanup(self.driver._delete_vpn_processes, @@ -534,8 +534,7 @@ class TestIPSecBase(framework.L3AgentTestFramework): """Perform a sync on failover agent associated w/backup router.""" self.failover_driver.agent_rpc.get_vpn_services_on_host = mock.Mock( return_value=[site.vpn_service]) - self.failover_driver.sync(mock.Mock(), - [{'id': site.backup_router.router_id}]) + self.failover_driver.sync(mock.Mock(), [site.router]) def check_ping(self, from_site, to_site, instance=0, success=True): if success: diff --git a/neutron_vpnaas/tests/unit/db/vpn/test_vpn_agentschedulers_db.py b/neutron_vpnaas/tests/unit/db/vpn/test_vpn_agentschedulers_db.py index cac246d19..7bf885cdc 100644 --- a/neutron_vpnaas/tests/unit/db/vpn/test_vpn_agentschedulers_db.py +++ b/neutron_vpnaas/tests/unit/db/vpn/test_vpn_agentschedulers_db.py @@ -44,7 +44,7 @@ VPN_HOSTA = "host-1" VPN_HOSTB = "host-2" -class VPNAgentSchedulerTestMixIn(object): +class VPNAgentSchedulerTestMixIn: def _request_list(self, path, admin_context=True, expected_code=exc.HTTPOk.code): req = self._path_req(path, admin_context=admin_context) diff --git a/neutron_vpnaas/tests/unit/db/vpn/test_vpn_db.py b/neutron_vpnaas/tests/unit/db/vpn/test_vpn_db.py index 54860ade1..e401012b5 100644 --- a/neutron_vpnaas/tests/unit/db/vpn/test_vpn_db.py +++ b/neutron_vpnaas/tests/unit/db/vpn/test_vpn_db.py @@ -69,7 +69,7 @@ class TestVpnCorePlugin(test_l3_plugin.TestL3NatIntPlugin, self.router_scheduler = l3_agent_scheduler.ChanceScheduler() -class VPNTestMixin(object): +class VPNTestMixin: resource_prefix_map = dict( (k.replace('_', '-'), "/vpn") @@ -1718,7 +1718,7 @@ class TestVpnaas(VPNPluginDbTestCase): # tests. # TODO(pcm): Put helpers in another module for sharing -class NeutronResourcesMixin(object): +class NeutronResourcesMixin: def create_network(self, overrides=None): """Create database entry for network.""" diff --git a/neutron_vpnaas/tests/unit/services/vpn/device_drivers/test_ipsec.py b/neutron_vpnaas/tests/unit/services/vpn/device_drivers/test_ipsec.py index eab6ff2b8..1c7106b2b 100644 --- a/neutron_vpnaas/tests/unit/services/vpn/device_drivers/test_ipsec.py +++ b/neutron_vpnaas/tests/unit/services/vpn/device_drivers/test_ipsec.py @@ -35,12 +35,16 @@ from neutron_vpnaas.services.vpn.device_drivers import libreswan_ipsec from neutron_vpnaas.services.vpn.device_drivers import strongswan_ipsec from neutron_vpnaas.tests import base +# Note: process_id == router_id == vpnservice_id + _uuid = uuidutils.generate_uuid +FAKE_UUID = _uuid() FAKE_HOST = 'fake_host' -FAKE_ROUTER_ID = _uuid() +FAKE_ROUTER_ID = FAKE_UUID +FAKE_VPNSERVICE_ID = FAKE_UUID +FAKE_PROCESS_ID = FAKE_UUID FAKE_IPSEC_SITE_CONNECTION1_ID = _uuid() FAKE_IPSEC_SITE_CONNECTION2_ID = _uuid() -FAKE_VPNSERVICE_ID = _uuid() FAKE_IKE_POLICY = { 'ike_version': 'v1', 'encryption_algorithm': 'aes-128', @@ -431,13 +435,13 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): self._make_router_info_for_test() def _make_router_info_for_test(self): - self.router = legacy_router.LegacyRouter(router_id=FAKE_ROUTER_ID, + self.router_info = legacy_router.LegacyRouter(router_id=FAKE_ROUTER_ID, agent=self.agent, **self.ri_kwargs) - self.router.router['distributed'] = False - self.router.iptables_manager.ipv4['nat'] = self.iptables - self.router.iptables_manager.apply = self.apply_mock - self.driver.routers[FAKE_ROUTER_ID] = self.router + self.router_info.router['distributed'] = False + self.router_info.iptables_manager.ipv4['nat'] = self.iptables + self.router_info.iptables_manager.apply = self.apply_mock + self.driver.routers[FAKE_ROUTER_ID] = self.router_info def _test_vpnservice_updated(self, expected_param, **kwargs): with mock.patch.object(self.driver, 'sync') as sync: @@ -449,17 +453,16 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): self._test_vpnservice_updated([]) def test_vpnservice_updated_with_router_info(self): - router_info = {'id': FAKE_ROUTER_ID, 'ha': False} - kwargs = {'router': router_info} - self._test_vpnservice_updated([router_info], **kwargs) + kwargs = {'router': self.router_info} + self._test_vpnservice_updated([self.router_info], **kwargs) def test_create_router(self): process = mock.Mock(openswan_ipsec.OpenSwanProcess) process.vpnservice = self.vpnservice self.driver.processes = { FAKE_ROUTER_ID: process} - self.driver.create_router(self.router) - self._test_add_nat_rule() + self.driver.create_router(self.router_info) + self._test_ensure_nat_rules() process.enable.assert_called_once_with() def test_destroy_router(self): @@ -472,75 +475,89 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): process.disable.assert_called_once_with() self.assertNotIn(process_id, self.driver.processes) - def _test_add_nat_rule(self): - self.router.iptables_manager.ipv4['nat'].assert_has_calls([ + def _test_ensure_nat_rules(self): + self.router_info.iptables_manager.ipv4['nat'].assert_has_calls([ + mock.call.clear_rules_by_tag('vpnaas'), mock.call.add_rule( 'POSTROUTING', '-s 10.0.0.0/24 -d 20.0.0.0/24 -m policy ' '--dir out --pol ipsec -j ACCEPT ', - top=True), + top=True, + tag='vpnaas'), mock.call.add_rule( 'POSTROUTING', '-s 10.0.0.0/24 -d 30.0.0.0/24 -m policy ' '--dir out --pol ipsec -j ACCEPT ', - top=True), + top=True, + tag='vpnaas'), mock.call.add_rule( 'POSTROUTING', '-s 11.0.0.0/24 -d 40.0.0.0/24 -m policy ' '--dir out --pol ipsec -j ACCEPT ', - top=True), + top=True, + tag='vpnaas'), mock.call.add_rule( 'POSTROUTING', '-s 11.0.0.0/24 -d 50.0.0.0/24 -m policy ' '--dir out --pol ipsec -j ACCEPT ', - top=True) + top=True, + tag='vpnaas') ]) - self.router.iptables_manager.apply.assert_called_once_with() + self.router_info.iptables_manager.apply.assert_called_once_with() - def _test_add_nat_rule_with_multiple_locals(self): - self.router.iptables_manager.ipv4['nat'].assert_has_calls([ + def _test_ensure_nat_rules_with_multiple_locals(self): + self.router_info.iptables_manager.ipv4['nat'].assert_has_calls([ + mock.call.clear_rules_by_tag('vpnaas'), mock.call.add_rule( 'POSTROUTING', '-s 10.0.0.0/24 -d 20.0.0.0/24 -m policy ' '--dir out --pol ipsec -j ACCEPT ', - top=True), + top=True, + tag='vpnaas'), mock.call.add_rule( 'POSTROUTING', '-s 10.0.0.0/24 -d 30.0.0.0/24 -m policy ' '--dir out --pol ipsec -j ACCEPT ', - top=True), + top=True, + tag='vpnaas'), mock.call.add_rule( 'POSTROUTING', '-s 11.0.0.0/24 -d 20.0.0.0/24 -m policy ' '--dir out --pol ipsec -j ACCEPT ', - top=True), + top=True, + tag='vpnaas'), mock.call.add_rule( 'POSTROUTING', '-s 11.0.0.0/24 -d 30.0.0.0/24 -m policy ' '--dir out --pol ipsec -j ACCEPT ', - top=True), + top=True, + tag='vpnaas'), mock.call.add_rule( 'POSTROUTING', '-s 12.0.0.0/24 -d 40.0.0.0/24 -m policy ' '--dir out --pol ipsec -j ACCEPT ', - top=True), + top=True, + tag='vpnaas'), mock.call.add_rule( 'POSTROUTING', '-s 12.0.0.0/24 -d 50.0.0.0/24 -m policy ' '--dir out --pol ipsec -j ACCEPT ', - top=True), + top=True, + tag='vpnaas'), mock.call.add_rule( 'POSTROUTING', '-s 13.0.0.0/24 -d 40.0.0.0/24 -m policy ' '--dir out --pol ipsec -j ACCEPT ', - top=True), + top=True, + tag='vpnaas'), mock.call.add_rule( 'POSTROUTING', '-s 13.0.0.0/24 -d 50.0.0.0/24 -m policy ' '--dir out --pol ipsec -j ACCEPT ', - top=True) + top=True, + tag='vpnaas') ]) - self.router.iptables_manager.apply.assert_called_once_with() + self.router_info.iptables_manager.apply.assert_called_once_with() def test_sync(self): fake_vpn_service = FAKE_VPN_SERVICE @@ -550,9 +567,8 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): self.driver._sync_vpn_processes = mock.Mock() self.driver._delete_vpn_processes = mock.Mock() self.driver._cleanup_stale_vpn_processes = mock.Mock() - sync_routers = [{'id': fake_vpn_service['router_id']}] sync_router_ids = [fake_vpn_service['router_id']] - self.driver.sync(context, sync_routers) + self.driver.sync(context, [self.router_info]) self.driver._sync_vpn_processes.assert_called_once_with( [fake_vpn_service], sync_router_ids) self.driver._delete_vpn_processes.assert_called_once_with( @@ -567,16 +583,16 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): with mock.patch.object(self.driver, 'ensure_process') as ensure_p: ensure_p.side_effect = self.fake_ensure_process self.driver._sync_vpn_processes([new_vpnservice], router_id) - self._test_add_nat_rule() + self._test_ensure_nat_rules() self.driver.processes[router_id].update.assert_called_once_with() - def test_add_nat_rules_with_multiple_local_subnets(self): + def test_ensure_nat_rules_with_multiple_local_subnets(self): """Ensure that add nat rule combinations are correct.""" overrides = {'local_cidrs': [['10.0.0.0/24', '11.0.0.0/24'], ['12.0.0.0/24', '13.0.0.0/24']]} self.modify_config_for_test(overrides) - self.driver._update_nat(self.vpnservice, self.driver.add_nat_rule) - self._test_add_nat_rule_with_multiple_locals() + self.driver.ensure_nat_rules(self.vpnservice) + self._test_ensure_nat_rules_with_multiple_locals() def test__sync_vpn_processes_router_with_no_vpn(self): """Test _sync_vpn_processes with a router not hosting vpnservice. @@ -613,14 +629,18 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): is updated, _sync_vpn_processes restart/update the existing vpnservices which are not yet stored in driver.processes. """ - router_id = FAKE_ROUTER_ID self.driver.process_status_cache = {} self.driver.processes = {} with mock.patch.object(self.driver, 'ensure_process') as ensure_p: ensure_p.side_effect = self.fake_ensure_process - self.driver._sync_vpn_processes([self.vpnservice], [router_id]) - self._test_add_nat_rule() - self.driver.processes[router_id].update.assert_called_once_with() + self.driver._sync_vpn_processes( + [self.vpnservice], + [FAKE_ROUTER_ID] + ) + self._test_ensure_nat_rules() + self.driver.processes[ + FAKE_ROUTER_ID + ].update.assert_called_once_with() def test_delete_vpn_processes(self): router_id_no_vpn = _uuid() @@ -671,6 +691,8 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): if process: del self.driver.processes[process_id] + # TODO(crohmann): Add test cases for HARouter and different ha_states + # @ddt [(False, None),(True, 'primary'), (True, 'standby')] def test_sync_update_vpnservice(self): with mock.patch.object(self.driver, 'ensure_process') as ensure_process: @@ -683,12 +705,12 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): self.driver.process_status_cache = {} self.driver.agent_rpc.get_vpn_services_on_host.return_value = [ new_vpn_service] - self.driver.sync(context, [{'id': FAKE_ROUTER_ID}]) + self.driver.sync(context, [self.router_info]) process = self.driver.processes[FAKE_ROUTER_ID] self.assertEqual(new_vpn_service, process.vpnservice) self.driver.agent_rpc.get_vpn_services_on_host.return_value = [ updated_vpn_service] - self.driver.sync(context, [{'id': FAKE_ROUTER_ID}]) + self.driver.sync(context, [self.router_info]) process = self.driver.processes[FAKE_ROUTER_ID] process.update_vpnservice.assert_called_once_with( updated_vpn_service) @@ -710,7 +732,10 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): self.driver.agent_rpc.get_vpn_services_on_host.return_value = [] context = mock.Mock() process_id = _uuid() - self.driver.sync(context, [{'id': process_id}]) + ri = self.router_info + ri.router_id = process_id + ri.router['id'] = process_id + self.driver.sync(context, [self.router_info]) self.assertNotIn(process_id, self.driver.processes) def test_status_updated_on_connection_admin_down(self): @@ -781,7 +806,7 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): def _test_status_handling_for_downed_connection(self, down_status): """Test status handling for downed connection.""" - router_id = self.router.router_id + router_id = self.router_info.router_id connection_id = FAKE_IPSEC_SITE_CONNECTION2_ID self.driver.ensure_process(router_id, self.vpnservice) self._execute.return_value = down_status @@ -794,7 +819,7 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): def _test_status_handling_for_active_connection(self, active_status): """Test status handling for active connection.""" - router_id = self.router.router_id + router_id = self.router_info.router_id connection_id = FAKE_IPSEC_SITE_CONNECTION2_ID self.driver.ensure_process(router_id, self.vpnservice) self._execute.return_value = active_status @@ -809,7 +834,7 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): def _test_status_handling_for_ike_v2_active_connection(self, active_status): """Test status handling for active connection.""" - router_id = self.router.router_id + router_id = self.router_info.router_id connection_id = FAKE_IPSEC_SITE_CONNECTION2_ID ike_policy = {'ike_version': 'v2', 'encryption_algorithm': 'aes-128', @@ -832,7 +857,7 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): def _test_connection_names_handling_for_multiple_subnets(self, active_status): """Test connection names handling for multiple subnets.""" - router_id = self.router.router_id + router_id = self.router_info.router_id process = self.driver.ensure_process(router_id, self.vpnservice) self._execute.return_value = active_status names = process.get_established_connections() @@ -841,7 +866,7 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): def _test_status_handling_for_deleted_connection(self, not_running_status): """Test status handling for deleted connection.""" - router_id = self.router.router_id + router_id = self.router_info.router_id self.driver.ensure_process(router_id, self.vpnservice) self._execute.return_value = not_running_status self.driver.report_status(mock.Mock()) @@ -853,7 +878,7 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): def _test_parse_connection_status(self, not_running_status, active_status, down_status): """Test the status of ipsec-site-connection is parsed correctly.""" - router_id = self.router.router_id + router_id = self.router_info.router_id process = self.driver.ensure_process(router_id, self.vpnservice) self._execute.return_value = not_running_status self.assertFalse(process.active) @@ -873,41 +898,6 @@ class IPSecDeviceLegacy(BaseIPsecDeviceDriver): def test_fail_getting_namespace_for_unknown_router(self): self.assertFalse(self.driver.get_namespace('bogus_id')) - def test_add_nat_rule(self): - self.driver.add_nat_rule(FAKE_ROUTER_ID, 'fake_chain', - 'fake_rule', True) - self.iptables.add_rule.assert_called_once_with( - 'fake_chain', 'fake_rule', top=True) - - def test_add_nat_rule_with_no_router(self): - self.driver.add_nat_rule( - 'bogus_router_id', - 'fake_chain', - 'fake_rule', - True) - self.assertFalse(self.iptables.add_rule.called) - - def test_remove_rule(self): - self.driver.remove_nat_rule(FAKE_ROUTER_ID, 'fake_chain', - 'fake_rule', True) - self.iptables.remove_rule.assert_called_once_with( - 'fake_chain', 'fake_rule', top=True) - - def test_remove_rule_with_no_router(self): - self.driver.remove_nat_rule( - 'bogus_router_id', - 'fake_chain', - 'fake_rule') - self.assertFalse(self.iptables.remove_rule.called) - - def test_iptables_apply(self): - self.driver.iptables_apply(FAKE_ROUTER_ID) - self.apply_mock.assert_called_once_with() - - def test_iptables_apply_with_no_router(self): - self.driver.iptables_apply('bogus_router_id') - self.assertFalse(self.apply_mock.called) - class IPSecDeviceDVR(BaseIPsecDeviceDriver): @@ -918,21 +908,23 @@ class IPSecDeviceDVR(BaseIPsecDeviceDriver): self._make_dvr_edge_router_info_for_test() def _make_dvr_edge_router_info_for_test(self): - router = dvr_edge_router.DvrEdgeRouter(mock.sentinel.agent, + router_info = dvr_edge_router.DvrEdgeRouter(mock.sentinel.agent, mock.sentinel.myhost, FAKE_ROUTER_ID, **self.ri_kwargs) - router.router['distributed'] = True - router.snat_namespace = dvr_snat_ns.SnatNamespace(router.router['id'], - mock.sentinel.agent, - self.driver, - mock.ANY) - router.snat_namespace.create() - router.snat_iptables_manager = iptables_manager.IptablesManager( + router_info.router['distributed'] = True + router_info.snat_namespace = dvr_snat_ns.SnatNamespace( + router_info.router['id'], + mock.sentinel.agent, + self.driver, + mock.ANY + ) + router_info.snat_namespace.create() + router_info.snat_iptables_manager = iptables_manager.IptablesManager( namespace='snat-' + FAKE_ROUTER_ID, use_ipv6=mock.ANY) - router.snat_iptables_manager.ipv4['nat'] = self.iptables - router.snat_iptables_manager.apply = self.apply_mock - self.driver.routers[FAKE_ROUTER_ID] = router + router_info.snat_iptables_manager.ipv4['nat'] = self.iptables + router_info.snat_iptables_manager.apply = self.apply_mock + self.driver.routers[FAKE_ROUTER_ID] = router_info def test_sync_dvr(self): fake_vpn_service = FAKE_VPN_SERVICE @@ -942,11 +934,10 @@ class IPSecDeviceDVR(BaseIPsecDeviceDriver): self.driver._sync_vpn_processes = mock.Mock() self.driver._delete_vpn_processes = mock.Mock() self.driver._cleanup_stale_vpn_processes = mock.Mock() - sync_routers = [{'id': fake_vpn_service['router_id']}] sync_router_ids = [fake_vpn_service['router_id']] with mock.patch.object(self.driver, 'get_process_status_cache') as process_status: - self.driver.sync(context, sync_routers) + self.driver.sync(context, [self.driver.routers[FAKE_ROUTER_ID]]) self.driver._sync_vpn_processes.assert_called_once_with( [fake_vpn_service], sync_router_ids) self.driver._delete_vpn_processes.assert_called_once_with( @@ -959,22 +950,10 @@ class IPSecDeviceDVR(BaseIPsecDeviceDriver): namespace = self.driver.get_namespace(FAKE_ROUTER_ID) self.assertEqual('snat-' + FAKE_ROUTER_ID, namespace) - def test_add_nat_rule_with_dvr_edge_router(self): - self.driver.add_nat_rule(FAKE_ROUTER_ID, 'fake_chain', - 'fake_rule', True) - self.iptables.add_rule.assert_called_once_with( - 'fake_chain', 'fake_rule', top=True) - - def test_iptables_apply_with_dvr_edge_router(self): - self.driver.iptables_apply(FAKE_ROUTER_ID) + def test_ensure_nat_rules_with_dvr_edge_router(self): + self.driver.ensure_nat_rules(FAKE_VPN_SERVICE) self.apply_mock.assert_called_once_with() - def test_remove_rule_with_dvr_edge_router(self): - self.driver.remove_nat_rule(FAKE_ROUTER_ID, 'fake_chain', - 'fake_rule', True) - self.iptables.remove_rule.assert_called_once_with( - 'fake_chain', 'fake_rule', top=True) - class TestOpenSwanConfigGeneration(BaseIPsecDeviceDriver): @@ -1293,7 +1272,7 @@ class TestOpenSwanProcess(IPSecDeviceLegacy): 'updated_pending_status': True}}) self.assertRaises(vpn_exception.VPNPeerAddressNotResolved, - self.process._get_nexthop, 'foo.peer.addr', + self.process._get_nexthop, 'foo.peer.addr.', 'fake-conn-id') self.assertEqual(expected_connection_status_dict, self.process.connection_status) @@ -1303,7 +1282,7 @@ class TestOpenSwanProcess(IPSecDeviceLegacy): 'updated_pending_status': False}}) self.assertRaises(vpn_exception.VPNPeerAddressNotResolved, - self.process._get_nexthop, 'foo.peer.addr', + self.process._get_nexthop, 'foo.peer.addr.', 'fake-conn-id') self.assertEqual(expected_connection_status_dict, self.process.connection_status) diff --git a/neutron_vpnaas/tests/unit/services/vpn/device_drivers/test_ovn_ipsec.py b/neutron_vpnaas/tests/unit/services/vpn/device_drivers/test_ovn_ipsec.py index 43fa0982d..3f1a1a2d5 100644 --- a/neutron_vpnaas/tests/unit/services/vpn/device_drivers/test_ovn_ipsec.py +++ b/neutron_vpnaas/tests/unit/services/vpn/device_drivers/test_ovn_ipsec.py @@ -197,7 +197,7 @@ class TestOvnStrongSwanDriver(test_ipsec.IPSecDeviceLegacy): 'transit_gateway_ip': '192.168.1.1', } - def test_iptables_apply(self): + def test_ensure_nat_rules(self): """Not applicable for OvnIPsecDriver""" pass @@ -218,19 +218,11 @@ class TestOvnStrongSwanDriver(test_ipsec.IPSecDeviceLegacy): """Not applicable for OvnIPsecDriver""" pass - def test_remove_rule(self): + def test_ensure_nat_rules_with_multiple_local_subnets(self): """Not applicable for OvnIPsecDriver""" pass - def test_add_nat_rules_with_multiple_local_subnets(self): - """Not applicable for OvnIPsecDriver""" - pass - - def _test_add_nat_rule(self): - """Not applicable for OvnIPsecDriver""" - pass - - def test_add_nat_rule(self): + def _test_ensure_nat_rules(self): """Not applicable for OvnIPsecDriver""" pass diff --git a/neutron_vpnaas/tests/unit/services/vpn/service_drivers/test_ovn_ipsec.py b/neutron_vpnaas/tests/unit/services/vpn/service_drivers/test_ovn_ipsec.py index c698a2b3a..e4f62bc0d 100644 --- a/neutron_vpnaas/tests/unit/services/vpn/service_drivers/test_ovn_ipsec.py +++ b/neutron_vpnaas/tests/unit/services/vpn/service_drivers/test_ovn_ipsec.py @@ -57,7 +57,7 @@ class FakeSqlQueryObject(dict): super(FakeSqlQueryObject, self).__init__(**entries) -class FakeGatewayDB(object): +class FakeGatewayDB: def __init__(self): self.gateways_by_router = {} self.gateways_by_id = {} diff --git a/releasenotes/notes/bug1943449-899ba4711ff3586e.yaml b/releasenotes/notes/bug1943449-899ba4711ff3586e.yaml new file mode 100644 index 000000000..6f6714ffd --- /dev/null +++ b/releasenotes/notes/bug1943449-899ba4711ff3586e.yaml @@ -0,0 +1,11 @@ +--- +prelude: > + Due to an change in the IPtables NAT rule format, with the tag "vpnaas" + upgrading to this release requires either a machine reboot or a move of + all routers from this agent to ensure there is rules of the old format left. +fixes: + - | + Reconciling via the sync method has been improved to ensure no + `ha_state_change` event was missed. + Also all IPtables NAT rules are now tagged "vpnaas" and refreshed on sync + to ensure they are current and there are no duplicates. \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 44d70768d..b452ceb32 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,3 +51,7 @@ oslo.policy.policies = neutron-vpnaas = neutron_vpnaas.policies:list_rules neutron.policies = neutron-vpnaas = neutron_vpnaas.policies:list_rules + +[mypy] +files = neutron_vpnaas/*.py,neutron_vpnaas/agent/**/*.py,neutron_vpnaas/api/**/*.py,neutron_vpnaas/cmd/**/*.py,neutron_vpnaas/db/**/*.py,neutron_vpnaas/extensions/**/*.py,neutron_vpnaas/policies/**/*.py,neutron_vpnaas/scheduler/**/*.py,neutron_vpnaas/services/**/*.py +ignore_missing_imports = true \ No newline at end of file diff --git a/test-requirements.txt b/test-requirements.txt index bddbc8133..71a011fc3 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -11,3 +11,4 @@ stestr>=1.0.0 # Apache-2.0 # see https://review.opendev.org/c/openstack/neutron/+/848706 WebTest>=2.0.27 # MIT +mypy>=1.7.0 # MIT diff --git a/tools/check_i18n.py b/tools/check_i18n.py index 9c49f941e..be81bc964 100644 --- a/tools/check_i18n.py +++ b/tools/check_i18n.py @@ -35,7 +35,7 @@ class ASTWalker(compiler.visitor.ASTVisitor): compiler.visitor.ASTVisitor.default(self, node, *args) -class Visitor(object): +class Visitor: def __init__(self, filename, i18n_msg_predicates, msg_format_checkers, debug): diff --git a/tox.ini b/tox.ini index 3e939b52d..e6eb57422 100644 --- a/tox.ini +++ b/tox.ini @@ -1,14 +1,16 @@ [tox] +minversion = 4.0.0 +ignore_basepython_conflict=true +requires = virtualenv>=20.17.1 envlist = py39,py38,pep8,docs -minversion = 3.18.0 [testenv] +usedevelop = True setenv = VIRTUAL_ENV={envdir} OS_LOG_CAPTURE={env:OS_LOG_CAPTURE:true} OS_STDOUT_CAPTURE={env:OS_STDOUT_CAPTURE:true} OS_STDERR_CAPTURE={env:OS_STDERR_CAPTURE:true} PYTHONWARNINGS=default::DeprecationWarning -usedevelop = True deps = -c{env:TOX_CONSTRAINTS_FILE:https://opendev.org/openstack/requirements/raw/branch/master/upper-constraints.txt} -r{toxinidir}/requirements.txt -r{toxinidir}/test-requirements.txt @@ -83,6 +85,7 @@ commands = neutron-db-manage --subproject neutron-vpnaas --database-connection sqlite:// check_migration {[testenv:genconfig]commands} {[testenv:genpolicy]commands} + {[testenv:mypy]commands} allowlist_externals = bash @@ -153,3 +156,10 @@ commands = bash {toxinidir}/tools/generate_config_file_samples.sh [testenv:genpolicy] commands = oslopolicy-sample-generator --config-file=etc/oslo-policy-generator/policy.conf + +[testenv:mypy] +description = + Run type checks. +deps = {[testenv]deps} +commands = + mypy --install-types --non-interactive --check-untyped-defs