From f63f99d822670da792050c2660dd5bf87798ed2a Mon Sep 17 00:00:00 2001 From: Niklas Schwarz Date: Wed, 6 Mar 2024 09:57:55 +0100 Subject: [PATCH] Add type annotations Add type annotations to the different parameters and return values to modernize the python used. This will also introduce mypy as another tool for static code analysis which will currently not run in CI Change-Id: Ic09e47673f916328568c413d0e8485d36c283c24 --- .pylintrc | 20 +- neutron_vpnaas/agent/ovn/vpn/agent.py | 91 +++---- neutron_vpnaas/agent/ovn/vpn/ovsdb.py | 19 +- .../rpc/agentnotifiers/vpn_rpc_agent_api.py | 17 +- .../db/migration/alembic_migrations/env.py | 6 +- .../db/vpn/vpn_agentschedulers_db.py | 116 ++++++--- neutron_vpnaas/db/vpn/vpn_db.py | 173 ++++++++----- neutron_vpnaas/db/vpn/vpn_ext_gw_db.py | 52 ++-- neutron_vpnaas/db/vpn/vpn_validator.py | 75 +++--- .../extensions/vpn_agentschedulers.py | 101 ++++---- .../extensions/vpn_endpoint_groups.py | 23 +- neutron_vpnaas/extensions/vpnaas.py | 68 ++++-- neutron_vpnaas/opts.py | 7 +- .../scheduler/vpn_agent_scheduler.py | 99 ++++++-- neutron_vpnaas/services/vpn/agent.py | 23 +- .../services/vpn/common/netns_wrapper.py | 14 +- .../services/vpn/device_drivers/__init__.py | 6 +- .../services/vpn/device_drivers/ipsec.py | 230 ++++++++++-------- .../vpn/device_drivers/libreswan_ipsec.py | 41 ++-- .../services/vpn/device_drivers/ovn_ipsec.py | 80 +++--- .../vpn/device_drivers/strongswan_ipsec.py | 18 +- .../services/vpn/ovn/agent_monitor.py | 6 +- neutron_vpnaas/services/vpn/ovn_agent.py | 2 +- neutron_vpnaas/services/vpn/ovn_plugin.py | 26 +- neutron_vpnaas/services/vpn/plugin.py | 138 ++++++----- .../services/vpn/service_drivers/__init__.py | 49 ++-- .../vpn/service_drivers/base_ipsec.py | 86 ++++--- .../vpn/service_drivers/driver_validator.py | 15 +- .../services/vpn/service_drivers/ipsec.py | 8 +- .../vpn/service_drivers/ipsec_validator.py | 17 +- .../services/vpn/service_drivers/ovn_ipsec.py | 143 +++++++---- neutron_vpnaas/services/vpn/vpn_service.py | 13 +- .../tests/functional/common/test_scenario.py | 2 +- .../db/vpn/test_vpn_agentschedulers_db.py | 2 +- .../tests/unit/db/vpn/test_vpn_db.py | 4 +- .../vpn/service_drivers/test_ovn_ipsec.py | 2 +- setup.cfg | 4 + test-requirements.txt | 1 + tools/check_i18n.py | 2 +- tox.ini | 14 +- 40 files changed, 1105 insertions(+), 708 deletions(-) diff --git a/.pylintrc b/.pylintrc index 1c942d426..1f935c524 100644 --- a/.pylintrc +++ b/.pylintrc @@ -15,26 +15,22 @@ ignore=.git,tests,openstack disable= # "F" Fatal errors that prevent further processing import-error, + import-untyped, # "I" Informational noise locally-disabled, # "E" Error for important programming issues (likely bugs) access-member-before-definition, - bad-super-call, maybe-no-member, - no-member, - no-method-argument, - no-self-argument, + no-name-in-module, not-callable, - no-value-for-parameter, - super-on-old-class, too-few-format-args, + unsubscriptable-object, # "W" Warnings for stylistic problems or minor programming issues - abstract-method, anomalous-backslash-in-string, anomalous-unicode-escape-in-string, arguments-differ, attribute-defined-outside-init, - bad-builtin, + # bad-builtin, bad-indentation, broad-except, dangerous-default-value, @@ -57,7 +53,6 @@ disable= unnecessary-lambda, unnecessary-pass, unpacking-non-sequence, - unreachable, unused-argument, unused-import, unused-variable, @@ -67,17 +62,13 @@ disable= bad-continuation, invalid-name, missing-docstring, - old-style-class, superfluous-parens, # "R" Refactor recommendations - abstract-class-little-used, - abstract-class-not-used, - consider-using-set-comprehension, + cyclic-import, duplicate-code, inconsistent-return-statements, interface-not-implemented, no-else-raise, - no-else-return, no-self-use, too-few-public-methods, too-many-ancestors, @@ -89,7 +80,6 @@ disable= too-many-public-methods, too-many-return-statements, too-many-statements, - useless-object-inheritance [BASIC] # Variable names can be 1 to 31 characters long, with lowercase and underscores diff --git a/neutron_vpnaas/agent/ovn/vpn/agent.py b/neutron_vpnaas/agent/ovn/vpn/agent.py index 7374fb220..5d6466793 100644 --- a/neutron_vpnaas/agent/ovn/vpn/agent.py +++ b/neutron_vpnaas/agent/ovn/vpn/agent.py @@ -13,11 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing as ty import uuid from neutron.agent.linux import external_process from neutron.common.ovn import utils as ovn_utils from neutron.conf.plugins.ml2.drivers.ovn import ovn_conf as config +from oslo_config import cfg from oslo_log import log as logging from oslo_service import service from ovsdbapp.backend.ovs_idl import event as row_event @@ -32,44 +34,6 @@ LOG = logging.getLogger(__name__) OVN_VPNAGENT_UUID_NAMESPACE = uuid.UUID('e1ce3b12-b1e0-4c81-ba27-07c0fec9c12b') -class ChassisCreateEventBase(row_event.RowEvent): - """Row create event - Chassis name == our_chassis. - - On connection, we get a dump of all chassis so if we catch a creation - of our own chassis it has to be a reconnection. In this case, we need - to do a full sync to make sure that we capture all changes while the - connection to OVSDB was down. - """ - table = None - - def __init__(self, vpn_agent): - self.agent = vpn_agent - self.first_time = True - events = (self.ROW_CREATE,) - super().__init__( - events, self.table, (('name', '=', self.agent.chassis),)) - self.event_name = self.__class__.__name__ - - def run(self, event, row, old): - if self.first_time: - self.first_time = False - else: - # NOTE(lucasagomes): Re-register the ovn vpn agent - # with the local chassis in case its entry was re-created - # (happens when restarting the ovn-controller) - self.agent.register_vpn_agent() - LOG.info("Connection to OVSDB established, doing a full sync") - self.agent.sync() - - -class ChassisCreateEvent(ChassisCreateEventBase): - table = 'Chassis' - - -class ChassisPrivateCreateEvent(ChassisCreateEventBase): - table = 'Chassis_Private' - - class SbGlobalUpdateEvent(row_event.RowEvent): """Row update event on SB_Global table.""" @@ -90,7 +54,7 @@ class SbGlobalUpdateEvent(row_event.RowEvent): class OvnVpnAgent(service.Service): - def __init__(self, conf): + def __init__(self, conf: cfg.ConfigOpts): super().__init__() self.conf = conf vlog.use_python_logger(max_level=config.get_ovn_ovsdb_log_level()) @@ -102,13 +66,13 @@ class OvnVpnAgent(service.Service): self.device_drivers = self.service.load_device_drivers(self.conf.host) def _load_config(self): - self.chassis = self._get_own_chassis_name() + self.chassis: ty.Optional[str] = self._get_own_chassis_name() try: self.chassis_id = uuid.UUID(self.chassis) except ValueError: # OVS system-id could be a non UUID formatted string. self.chassis_id = uuid.uuid5(OVN_VPNAGENT_UUID_NAMESPACE, - self.chassis) + self.chassis if self.chassis else '') LOG.debug("Loaded chassis name %s (UUID: %s).", self.chassis, self.chassis_id) @@ -156,12 +120,51 @@ class OvnVpnAgent(service.Service): self.sb_idl.db_add(table, self.chassis, 'external_ids', ext_ids).execute(check_error=True) - def _get_own_chassis_name(self): + def _get_own_chassis_name(self) -> ty.Optional[str]: """Return the external_ids:system-id value of the Open_vSwitch table. As long as ovn-controller is running on this node, the key is guaranteed to exist and will include the chassis name. """ - ext_ids = self.ovs_idl.db_get( + ext_ids: ty.Optional[ty.Dict[str, str]] = self.ovs_idl.db_get( 'Open_vSwitch', '.', 'external_ids').execute() - return ext_ids['system-id'] + return ext_ids['system-id'] if ext_ids else None + + +class ChassisCreateEventBase(row_event.RowEvent): + """Row create event - Chassis name == our_chassis. + + On connection, we get a dump of all chassis so if we catch a creation + of our own chassis it has to be a reconnection. In this case, we need + to do a full sync to make sure that we capture all changes while the + connection to OVSDB was down. + """ + + table: ty.Optional[str] = None + + def __init__(self, vpn_agent: OvnVpnAgent): + self.agent = vpn_agent + self.first_time: bool = True + events: ty.Tuple[str] = (self.ROW_CREATE,) + super().__init__(events, self.table, + (('name', '=', self.agent.chassis),)) + self.event_name = self.__class__.__name__ + + def run(self, event, row, old): + if self.first_time: + self.first_time = False + else: + # NOTE(lucasagomes): Re-register the ovn vpn agent + # with the local chassis in case its entry was re-created + # (happens when restarting the ovn-controller) + self.agent.register_vpn_agent() + LOG.info("Connection to OVSDB established, doing a full sync") + self.agent.sync() + + +class ChassisCreateEvent(ChassisCreateEventBase): + table = "Chassis" + + +class ChassisPrivateCreateEvent(ChassisCreateEventBase): + table = "Chassis_Private" diff --git a/neutron_vpnaas/agent/ovn/vpn/ovsdb.py b/neutron_vpnaas/agent/ovn/vpn/ovsdb.py index 765865769..2c2da1d8e 100644 --- a/neutron_vpnaas/agent/ovn/vpn/ovsdb.py +++ b/neutron_vpnaas/agent/ovn/vpn/ovsdb.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import typing as ty from neutron.conf.plugins.ml2.drivers.ovn import ovn_conf as config from neutron.plugins.ml2.drivers.ovn.mech_driver.ovsdb import impl_idl_ovn @@ -28,7 +29,7 @@ LOG = logging.getLogger(__name__) class VPNAgentOvnSbIdl(ovsdb_monitor.OvnIdl): - SCHEMA = 'OVN_Southbound' + SCHEMA: str = 'OVN_Southbound' def __init__(self, chassis=None, events=None, tables=None): connection_string = config.get_ovn_sb_connection() @@ -39,8 +40,8 @@ class VPNAgentOvnSbIdl(ovsdb_monitor.OvnIdl): for table in tables: helper.register_table(table) try: - super().__init__( - None, connection_string, helper, leader_only=False) + super().__init__(None, connection_string, + helper, leader_only=False) except TypeError: # TODO(bpetermann) We can remove this when we require ovs>=2.12.0 super().__init__(None, connection_string, helper) @@ -54,7 +55,7 @@ class VPNAgentOvnSbIdl(ovsdb_monitor.OvnIdl): @tenacity.retry( wait=tenacity.wait_exponential(max=180), reraise=True) - def _get_ovsdb_helper(self, connection_string): + def _get_ovsdb_helper(self, connection_string: str): return idlutils.get_schema_helper(connection_string, self.SCHEMA) def start(self): @@ -62,14 +63,18 @@ class VPNAgentOvnSbIdl(ovsdb_monitor.OvnIdl): self, timeout=config.get_ovn_ovsdb_timeout()) return impl_idl_ovn.OvsdbSbOvnIdl(conn) + def post_connect(self): + pass -class VPNAgentOvsIdl(object): + +class VPNAgentOvsIdl: def start(self): - connection_string = config.cfg.CONF.ovs.ovsdb_connection + connection_string: str = config.cfg.CONF.ovs.ovsdb_connection helper = idlutils.get_schema_helper(connection_string, 'Open_vSwitch') - tables = ('Open_vSwitch', 'Bridge', 'Port', 'Interface') + tables: ty.Tuple[str, str, str, str] = ('Open_vSwitch', 'Bridge', + 'Port', 'Interface') for table in tables: helper.register_table(table) ovs_idl = idl.Idl( diff --git a/neutron_vpnaas/api/rpc/agentnotifiers/vpn_rpc_agent_api.py b/neutron_vpnaas/api/rpc/agentnotifiers/vpn_rpc_agent_api.py index 3e6fefd47..ac55bb794 100644 --- a/neutron_vpnaas/api/rpc/agentnotifiers/vpn_rpc_agent_api.py +++ b/neutron_vpnaas/api/rpc/agentnotifiers/vpn_rpc_agent_api.py @@ -12,7 +12,10 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import typing as ty + from neutron.api.rpc.agentnotifiers import utils as ag_utils +from neutron_lib import context from neutron_lib import rpc as n_rpc import oslo_messaging @@ -23,25 +26,31 @@ from neutron_vpnaas.services.vpn.common import topics AGENT_NOTIFY_MAX_ATTEMPTS = 2 -class VPNAgentNotifyAPI(object): +class VPNAgentNotifyAPI: """API for plugin to notify VPN agent.""" def __init__(self, topic=topics.IPSEC_AGENT_TOPIC): target = oslo_messaging.Target(topic=topic, version='1.0') self.client = n_rpc.get_client(target) - def agent_updated(self, context, admin_state_up, host): + def agent_updated( + self, context: context.ContextBase, + admin_state_up: bool, host: str): cctxt = self.client.prepare(server=host) cctxt.cast(context, 'agent_updated', payload={'admin_state_up': admin_state_up}) - def vpnservice_removed_from_agent(self, context, router_id, host): + def vpnservice_removed_from_agent( + self, context: context.ContextBase, + router_id: str, host: str): """Notify agent about removed VPN service(s) of a router.""" cctxt = self.client.prepare(server=host) cctxt.cast(context, 'vpnservice_removed_from_agent', router_id=router_id) - def vpnservice_added_to_agent(self, context, router_ids, host): + def vpnservice_added_to_agent( + self, context: context.ContextBase, + router_ids: ty.List[str], host: str): """Notify agent about added VPN service(s) of router(s).""" # need to use call here as we want to be sure agent received # notification and router will not be "lost". However using call() diff --git a/neutron_vpnaas/db/migration/alembic_migrations/env.py b/neutron_vpnaas/db/migration/alembic_migrations/env.py index eb2d7e0be..2c641cc03 100644 --- a/neutron_vpnaas/db/migration/alembic_migrations/env.py +++ b/neutron_vpnaas/db/migration/alembic_migrations/env.py @@ -27,8 +27,8 @@ from neutron_vpnaas.db.migration import alembic_migrations MYSQL_ENGINE = None config = context.config -neutron_config = config.neutron_config -logging_config.fileConfig(config.config_file_name) +neutron_config = config.neutron_config # type: ignore +logging_config.fileConfig(config.config_file_name) # type: ignore target_metadata = model_base.BASEV2.metadata @@ -46,7 +46,7 @@ def set_mysql_engine(): def run_migrations_offline(): set_mysql_engine() - kwargs = dict() + kwargs = {} if neutron_config.database.connection: kwargs['url'] = neutron_config.database.connection else: diff --git a/neutron_vpnaas/db/vpn/vpn_agentschedulers_db.py b/neutron_vpnaas/db/vpn/vpn_agentschedulers_db.py index a0dca64cb..7e96cf1ba 100644 --- a/neutron_vpnaas/db/vpn/vpn_agentschedulers_db.py +++ b/neutron_vpnaas/db/vpn/vpn_agentschedulers_db.py @@ -13,9 +13,12 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import typing as ty import random +from neutron.extensions import agent as nagent +from neutron.extensions import l3 from neutron.extensions import router_availability_zone as router_az from neutron import worker as neutron_worker from neutron_lib import context as ncontext @@ -31,8 +34,10 @@ import sqlalchemy as sa from sqlalchemy import func from neutron_vpnaas._i18n import _ +from neutron_vpnaas.api.rpc.agentnotifiers import vpn_rpc_agent_api as nfy_api from neutron_vpnaas.db.vpn import vpn_models from neutron_vpnaas.extensions import vpn_agentschedulers +from neutron_vpnaas.scheduler.vpn_agent_scheduler import VPNScheduler from neutron_vpnaas.services.vpn.common.constants import AGENT_TYPE_VPN @@ -71,15 +76,15 @@ class VPNAgentSchedulerDbMixin( using the VPN agent. """ - vpn_scheduler = None - agent_notifiers = {} + vpn_scheduler: ty.Optional[VPNScheduler] = None + agent_notifiers: ty.Dict[str, nfy_api.VPNAgentNotifyAPI] = {} @property - def l3_plugin(self): + def l3_plugin(self) -> l3.RouterPluginBase: return directory.get_plugin(plugin_const.L3) @property - def core_plugin(self): + def core_plugin(self) -> nagent.AgentPluginBase: return directory.get_plugin() def add_periodic_vpn_agent_status_check(self): @@ -96,7 +101,7 @@ class VPNAgentSchedulerDbMixin( check_worker = neutron_worker.PeriodicWorker( self.reschedule_vpnservices_from_down_agents, interval, initial_delay) - self.add_worker(check_worker) + self.add_worker(check_worker) # type: ignore def reschedule_vpnservices_from_down_agents(self): """Reschedule VPN services from down VPN agents. @@ -111,9 +116,10 @@ class VPNAgentSchedulerDbMixin( for binding in down_bindings: if binding.vpn_agent_id in agents_back_online: continue - agent = self.core_plugin.get_agent(context, + agent: ty.Optional[nagent.Agent] = self.core_plugin.get_agent( + context, binding.vpn_agent_id) - if agent['alive']: + if agent and agent['alive']: agents_back_online.add(binding.vpn_agent_id) continue @@ -137,7 +143,8 @@ class VPNAgentSchedulerDbMixin( "rescheduling.") @db_api.CONTEXT_READER - def get_down_router_bindings(self, context): + def get_down_router_bindings(self, + context: ncontext.Context) -> ty.List[RouterVPNAgentBinding]: vpn_agents = self.get_vpn_agents(context, active=False) if not vpn_agents: return [] @@ -148,7 +155,9 @@ class VPNAgentSchedulerDbMixin( RouterVPNAgentBinding.vpn_agent_id.in_(vpn_agent_ids)) return query.all() - def validate_agent_router_combination(self, context, agent, router): + def validate_agent_router_combination(self, context: ncontext.ContextBase, + agent: nagent.Agent, + router: ty.Dict[str, ty.Any]): """Validate if the router can be correctly assigned to the agent. :raises: InvalidVPNAgent if attempting to assign router to an @@ -158,7 +167,9 @@ class VPNAgentSchedulerDbMixin( raise vpn_agentschedulers.InvalidVPNAgent(id=agent['id']) @db_api.CONTEXT_READER - def check_agent_router_scheduling_needed(self, context, agent, router): + def check_agent_router_scheduling_needed(self, context: ncontext.Context, + agent: ty.Dict[str, ty.Any], + router: ty.Dict[str, ty.Any]): """Check if the scheduling of router's VPN services is needed. :raises: RouterHostedByVPNAgent if router is already assigned @@ -180,7 +191,8 @@ class VPNAgentSchedulerDbMixin( router_id=router_id, agent_id=bindings[0].vpn_agent_id) - def create_router_to_agent_binding(self, context, router_id, agent_id): + def create_router_to_agent_binding(self, context: ncontext.Context, + router_id: str, agent_id: str): """Create router to VPN agent binding.""" try: with db_api.CONTEXT_WRITER.using(context): @@ -203,11 +215,12 @@ class VPNAgentSchedulerDbMixin( {'router_id': router_id, 'agent_id': agent_id}) return True - def add_router_to_vpn_agent(self, context, agent_id, router_id): + def add_router_to_vpn_agent(self, context: ncontext.Context, + agent_id: str, router_id: str): """Add a VPN agent to host VPN services of a router.""" with db_api.CONTEXT_WRITER.using(context): router = self.l3_plugin.get_router(context, router_id) - agent = self.core_plugin.get_agent(context, agent_id) + agent: nagent.Agent = self.core_plugin.get_agent(context, agent_id) self.validate_agent_router_combination(context, agent, router) if not self.check_agent_router_scheduling_needed( context, agent, router): @@ -232,7 +245,8 @@ class VPNAgentSchedulerDbMixin( self.vpn_router_agent_binding_changed( context, router_id, agent['host']) - def remove_router_from_vpn_agent(self, context, agent_id, router_id): + def remove_router_from_vpn_agent(self, context: ncontext.Context, + agent_id: str, router_id: str): """Remove the router from VPN agent. After removal, the VPN service(s) of the router will be non-hosted @@ -248,7 +262,8 @@ class VPNAgentSchedulerDbMixin( vpn_notifier.vpnservice_removed_from_agent( context, router_id, agent['host']) - def _unbind_router(self, context, router_id, agent_id): + def _unbind_router(self, context: ncontext.Context, + router_id: str, agent_id: str): with db_api.CONTEXT_WRITER.using(context): query = context.session.query(RouterVPNAgentBinding) query = query.filter( @@ -256,7 +271,8 @@ class VPNAgentSchedulerDbMixin( RouterVPNAgentBinding.vpn_agent_id == agent_id) return query.delete() - def reschedule_router(self, context, router_id, cur_agent): + def reschedule_router(self, context: ncontext.Context, router_id: str, + cur_agent: nagent.Agent): """Reschedule router to a new VPN agent Remove the router from the agent currently hosting it and @@ -282,8 +298,10 @@ class VPNAgentSchedulerDbMixin( self.vpn_router_agent_binding_changed( context, router_id, new_agent['host']) - def _notify_agents_router_rescheduled(self, context, router_id, - old_agent, new_agent): + def _notify_agents_router_rescheduled(self, context: ncontext.Context, + router_id: str, + old_agent: ty.Dict[str, ty.Any], + new_agent: ty.Dict[str, ty.Any]): vpn_notifier = self.agent_notifiers.get(AGENT_TYPE_VPN) if not vpn_notifier: return @@ -303,7 +321,8 @@ class VPNAgentSchedulerDbMixin( router_id=router_id) @db_api.CONTEXT_READER - def list_routers_on_vpn_agent(self, context, agent_id): + def list_routers_on_vpn_agent(self, context: ncontext.Context, + agent_id: str): query = context.session.query(RouterVPNAgentBinding.router_id) query = query.filter(RouterVPNAgentBinding.vpn_agent_id == agent_id) @@ -312,13 +331,15 @@ class VPNAgentSchedulerDbMixin( return {'routers': self.l3_plugin.get_routers(context, filters={'id': router_ids})} - else: - # Exception will be thrown if the requested agent does not exist. - self.core_plugin.get_agent(context, agent_id) - return {'routers': []} + + # Exception will be thrown if the requested agent does not exist. + self.core_plugin.get_agent(context, agent_id) + return {'routers': []} @db_api.CONTEXT_READER - def get_vpn_agents_hosting_routers(self, context, router_ids, active=None): + def get_vpn_agents_hosting_routers(self, context: ncontext.Context, + router_ids: ty.Optional[ty.List[str]], + active: ty.Optional[bool] = None): if not router_ids: return [] query = context.session.query(RouterVPNAgentBinding) @@ -332,29 +353,39 @@ class VPNAgentSchedulerDbMixin( if agent['alive'] == active] return vpn_agents - def list_vpn_agents_hosting_router(self, context, router_id): + def list_vpn_agents_hosting_router(self, context: ncontext.Context, + router_id: str): vpn_agents = self.get_vpn_agents_hosting_routers(context, [router_id]) return {'agents': vpn_agents} - def get_vpn_agents(self, context, active=None, host=None): + def get_vpn_agents(self, context: ncontext.Context, + active: ty.Optional[bool] = None, + host: ty.Optional[str] = None) -> \ + ty.Optional[ty.List[nagent.Agent]]: filters = {'agent_type': [AGENT_TYPE_VPN]} if host is not None: filters['host'] = [host] - vpn_agents = self.core_plugin.get_agents(context, filters=filters) + vpn_agents: ty.Optional[ty.List[nagent.Agent]] = \ + self.core_plugin.get_agents(context, filters=filters) if active is None: return vpn_agents - else: - return [vpn_agent - for vpn_agent in vpn_agents - if vpn_agent['alive'] == active] - def get_vpn_agent_on_host(self, context, host, active=None): + if not vpn_agents: + return None + + return [vpn_agent + for vpn_agent in vpn_agents + if vpn_agent['alive'] == active] + + def get_vpn_agent_on_host(self, context: ncontext.Context, + host: str, active: ty.Optional[bool] = None): agents = self.get_vpn_agents(context, active=active, host=host) if agents: return agents[0] @db_api.CONTEXT_READER - def get_unscheduled_vpn_routers(self, context, router_ids=None): + def get_unscheduled_vpn_routers(self, context: ncontext.Context, + router_ids: ty.Optional[ty.List[str]] = None): """Get IDs of routers which have unscheduled VPN services.""" query = context.session.query(vpn_models.VPNService.router_id) query = query.outerjoin( @@ -366,12 +397,14 @@ class VPNAgentSchedulerDbMixin( vpn_models.VPNService.router_id.in_(router_ids)) return [router_id for router_id, in query.all()] - def auto_schedule_routers(self, context, vpn_agent): + def auto_schedule_routers(self, context: ncontext.Context, vpn_agent): if self.vpn_scheduler: return self.vpn_scheduler.auto_schedule_routers( self, context, vpn_agent) - def schedule_router(self, context, router, candidates=None): + def schedule_router(self, context: ncontext.Context, + router, candidates: ty.Optional[ty.List] = None) -> \ + ty.Optional[nagent.Agent]: """Schedule VPN services of a router to a VPN agent. Returns the chosen agent; None if another server scheduled the @@ -381,9 +414,11 @@ class VPNAgentSchedulerDbMixin( if self.vpn_scheduler: return self.vpn_scheduler.schedule( self, context, router, candidates=candidates) + return None @db_api.CONTEXT_READER - def get_vpn_agent_with_min_routers(self, context, agent_ids): + def get_vpn_agent_with_min_routers(self, context: ncontext.Context, + agent_ids: ty.Optional[ty.List[str]]): """Return VPN agent with the least number of routers.""" if not agent_ids: return None @@ -397,15 +432,18 @@ class VPNAgentSchedulerDbMixin( unused_agent_ids = set(agent_ids) - set(used_agent_ids) if unused_agent_ids: return unused_agent_ids.pop() - else: - return used_agent_ids[0] + return used_agent_ids[0] - def get_hosts_to_notify(self, context, router_id): + def get_hosts_to_notify(self, context: ncontext.Context, router_id): """Returns all hosts to send notification about router update""" agents = self.get_vpn_agents_hosting_routers(context, [router_id], active=True) return [a['host'] for a in agents] + def vpn_router_agent_binding_changed(self, context: ncontext.Context, + router_id: str, host: str): + pass + class AZVPNAgentSchedulerDbMixin(VPNAgentSchedulerDbMixin, router_az.RouterAvailabilityZonePluginBase): diff --git a/neutron_vpnaas/db/vpn/vpn_db.py b/neutron_vpnaas/db/vpn/vpn_db.py index f7f3257a5..d383c53d9 100644 --- a/neutron_vpnaas/db/vpn/vpn_db.py +++ b/neutron_vpnaas/db/vpn/vpn_db.py @@ -13,12 +13,14 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import typing as ty from neutron.db import models_v2 from neutron_lib.callbacks import events from neutron_lib.callbacks import registry from neutron_lib.callbacks import resources from neutron_lib import constants as lib_constants +from neutron_lib import context from neutron_lib.db import api as db_api from neutron_lib.db import model_query from neutron_lib.db import utils as db_utils @@ -58,12 +60,13 @@ class VPNPluginDb(vpnaas.VPNPluginBase, """ return vpn_validator.VpnReferenceValidator() - def update_status(self, context, model, v_id, status): + def update_status(self, context: context.Context, model, v_id: str, + status: str): with db_api.CONTEXT_WRITER.using(context): v_db = self._get_resource(context, model, v_id) v_db.update({'status': status}) - def _get_resource(self, context, model, v_id): + def _get_resource(self, context: context.Context, model, v_id): try: r = model_query.get_by_id(context, model, v_id) except exc.NoResultFound: @@ -91,7 +94,9 @@ class VPNPluginDb(vpnaas.VPNPluginBase, if utils.in_pending_status(status): raise vpn_exception.VPNStateInvalidToUpdate(id=_id, state=status) - def _make_ipsec_site_connection_dict(self, ipsec_site_conn, fields=None): + def _make_ipsec_site_connection_dict(self, + ipsec_site_conn: ty.Dict[str, ty.Any], + fields=None): res = {'id': ipsec_site_conn['id'], 'tenant_id': ipsec_site_conn['tenant_id'], @@ -123,15 +128,17 @@ class VPNPluginDb(vpnaas.VPNPluginBase, return db_utils.resource_fields(res, fields) - def get_endpoint_info(self, context, ipsec_sitecon): + def get_endpoint_info(self, context: context.Context, ipsec_sitecon): """Obtain all endpoint info, and store in connection for validation.""" ipsec_sitecon['local_epg_subnets'] = self.get_endpoint_group( context, ipsec_sitecon['local_ep_group_id']) ipsec_sitecon['peer_epg_cidrs'] = self.get_endpoint_group( context, ipsec_sitecon['peer_ep_group_id']) - def validate_connection_info(self, context, validator, ipsec_sitecon, - vpnservice): + def validate_connection_info( + self, context: context.Context, + validator: vpn_validator.VpnReferenceValidator, ipsec_sitecon, + vpnservice): """Collect info and validate connection. If endpoint groups used (default), collect the group info and @@ -150,7 +157,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, validator.validate_ipsec_site_connection(context, ipsec_sitecon, ip_version, vpnservice) - def create_ipsec_site_connection(self, context, ipsec_site_connection): + def create_ipsec_site_connection(self, context: context.Context, + ipsec_site_connection: ty.Dict[str, ty.Dict]): ipsec_sitecon = ipsec_site_connection['ipsec_site_connection'] validator = self._get_validator() validator.assign_sensible_ipsec_sitecon_defaults(ipsec_sitecon) @@ -203,7 +211,7 @@ class VPNPluginDb(vpnaas.VPNPluginBase, return self._make_ipsec_site_connection_dict(ipsec_site_conn_db) def update_ipsec_site_connection( - self, context, + self, context: context.Context, ipsec_site_conn_id, ipsec_site_connection): ipsec_sitecon = ipsec_site_connection['ipsec_site_connection'] changed_peer_cidrs = False @@ -252,36 +260,41 @@ class VPNPluginDb(vpnaas.VPNPluginBase, result['peer_cidrs'] = new_peer_cidrs return result - def delete_ipsec_site_connection(self, context, ipsec_site_conn_id): + def delete_ipsec_site_connection(self, context: context.Context, + ipsec_site_conn_id: str): with db_api.CONTEXT_WRITER.using(context): ipsec_site_conn_db = self._get_resource( context, vpn_models.IPsecSiteConnection, ipsec_site_conn_id) context.session.delete(ipsec_site_conn_db) - def _get_ipsec_site_connection( - self, context, ipsec_site_conn_id): + def _get_ipsec_site_connection(self, context: context.Context, + ipsec_site_conn_id: str) -> vpn_models.IPsecSiteConnection: return self._get_resource( context, vpn_models.IPsecSiteConnection, ipsec_site_conn_id) - def get_ipsec_site_connection(self, context, - ipsec_site_conn_id, fields=None): + def get_ipsec_site_connection(self, context: context.Context, + ipsec_site_conn_id: str, fields=None): ipsec_site_conn_db = self._get_ipsec_site_connection( context, ipsec_site_conn_id) return self._make_ipsec_site_connection_dict( ipsec_site_conn_db, fields) - def get_ipsec_site_connections(self, context, filters=None, fields=None): + def get_ipsec_site_connections(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, + fields=None) -> ty.List[vpn_models.IPsecSiteConnection]: return model_query.get_collection( context, vpn_models.IPsecSiteConnection, self._make_ipsec_site_connection_dict, filters=filters, fields=fields) - def update_ipsec_site_conn_status(self, context, conn_id, new_status): + def update_ipsec_site_conn_status(self, context: context.Context, + conn_id: str, new_status: str): with db_api.CONTEXT_WRITER.using(context): self._update_connection_status(context, conn_id, new_status, True) - def _update_connection_status(self, context, conn_id, new_status, - updated_pending): + def _update_connection_status(self, context: context.Context, + conn_id: str, new_status: str, + updated_pending: bool): """Update the connection status, if changed. If the connection is not in a pending state, unconditionally update @@ -295,7 +308,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, if not utils.in_pending_status(conn_db.status) or updated_pending: conn_db.status = new_status - def _make_ikepolicy_dict(self, ikepolicy, fields=None): + def _make_ikepolicy_dict(self, ikepolicy: ty.Dict[str, ty.Any], + fields=None) -> ty.Dict[str, ty.Any]: res = {'id': ikepolicy['id'], 'tenant_id': ikepolicy['tenant_id'], 'name': ikepolicy['name'], @@ -313,10 +327,11 @@ class VPNPluginDb(vpnaas.VPNPluginBase, return db_utils.resource_fields(res, fields) - def create_ikepolicy(self, context, ikepolicy): + def create_ikepolicy(self, context: context.Context, + ikepolicy: ty.Dict[str, ty.Dict[str, ty.Any]]): ike = ikepolicy['ikepolicy'] validator = self._get_validator() - lifetime_info = ike['lifetime'] + lifetime_info: ty.Dict[str, ty.Any] = ike['lifetime'] lifetime_units = lifetime_info.get('units', 'seconds') lifetime_value = lifetime_info.get('value', 3600) @@ -339,7 +354,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.add(ike_db) return self._make_ikepolicy_dict(ike_db) - def update_ikepolicy(self, context, ikepolicy_id, ikepolicy): + def update_ikepolicy(self, context: context.Context, ikepolicy_id: str, + ikepolicy: ty.Dict[str, ty.Dict]): ike = ikepolicy['ikepolicy'] validator = self._get_validator() with db_api.CONTEXT_WRITER.using(context): @@ -359,7 +375,7 @@ class VPNPluginDb(vpnaas.VPNPluginBase, ike_db.update(ike) return self._make_ikepolicy_dict(ike_db) - def delete_ikepolicy(self, context, ikepolicy_id): + def delete_ikepolicy(self, context: context.Context, ikepolicy_id: str): with db_api.CONTEXT_WRITER.using(context): if context.session.query(vpn_models.IPsecSiteConnection).filter_by( ikepolicy_id=ikepolicy_id).first(): @@ -369,19 +385,21 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.delete(ike_db) @db_api.CONTEXT_READER - def get_ikepolicy(self, context, ikepolicy_id, fields=None): + def get_ikepolicy(self, context: context.Context, ikepolicy_id: str, + fields=None): ike_db = self._get_resource( context, vpn_models.IKEPolicy, ikepolicy_id) return self._make_ikepolicy_dict(ike_db, fields) @db_api.CONTEXT_READER - def get_ikepolicies(self, context, filters=None, fields=None): + def get_ikepolicies(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, fields=None): return model_query.get_collection(context, vpn_models.IKEPolicy, self._make_ikepolicy_dict, filters=filters, fields=fields) - def _make_ipsecpolicy_dict(self, ipsecpolicy, fields=None): - + def _make_ipsecpolicy_dict(self, ipsecpolicy: ty.Dict[str, ty.Any], + fields=None) -> ty.Dict[str, ty.Any]: res = {'id': ipsecpolicy['id'], 'tenant_id': ipsecpolicy['tenant_id'], 'name': ipsecpolicy['name'], @@ -399,10 +417,11 @@ class VPNPluginDb(vpnaas.VPNPluginBase, return db_utils.resource_fields(res, fields) - def create_ipsecpolicy(self, context, ipsecpolicy): + def create_ipsecpolicy(self, context: context.Context, + ipsecpolicy: ty.Dict[str, ty.Dict[str, ty.Any]]): ipsecp = ipsecpolicy['ipsecpolicy'] validator = self._get_validator() - lifetime_info = ipsecp['lifetime'] + lifetime_info: ty.Dict[str, ty.Any] = ipsecp['lifetime'] lifetime_units = lifetime_info.get('units', 'seconds') lifetime_value = lifetime_info.get('value', 3600) @@ -423,7 +442,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.add(ipsecp_db) return self._make_ipsecpolicy_dict(ipsecp_db) - def update_ipsecpolicy(self, context, ipsecpolicy_id, ipsecpolicy): + def update_ipsecpolicy(self, context: context.Context, ipsecpolicy_id: str, + ipsecpolicy: ty.Dict[str, ty.Dict[str, ty.Any]]): ipsecp = ipsecpolicy['ipsecpolicy'] validator = self._get_validator() with db_api.CONTEXT_WRITER.using(context): @@ -444,7 +464,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, ipsecp_db.update(ipsecp) return self._make_ipsecpolicy_dict(ipsecp_db) - def delete_ipsecpolicy(self, context, ipsecpolicy_id): + def delete_ipsecpolicy(self, context: context.Context, + ipsecpolicy_id: str): with db_api.CONTEXT_WRITER.using(context): if context.session.query(vpn_models.IPsecSiteConnection).filter_by( ipsecpolicy_id=ipsecpolicy_id).first(): @@ -455,18 +476,21 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.delete(ipsec_db) @db_api.CONTEXT_READER - def get_ipsecpolicy(self, context, ipsecpolicy_id, fields=None): + def get_ipsecpolicy(self, context: context.Context, + ipsecpolicy_id: str, fields=None): ipsec_db = self._get_resource( context, vpn_models.IPsecPolicy, ipsecpolicy_id) return self._make_ipsecpolicy_dict(ipsec_db, fields) @db_api.CONTEXT_READER - def get_ipsecpolicies(self, context, filters=None, fields=None): + def get_ipsecpolicies(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, fields=None): return model_query.get_collection(context, vpn_models.IPsecPolicy, self._make_ipsecpolicy_dict, filters=filters, fields=fields) - def _make_vpnservice_dict(self, vpnservice, fields=None): + def _make_vpnservice_dict(self, vpnservice: ty.Dict[str, ty.Any], + fields=None) -> ty.Dict[str, ty.Any]: res = {'id': vpnservice['id'], 'name': vpnservice['name'], 'description': vpnservice['description'], @@ -480,7 +504,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, 'status': vpnservice['status']} return db_utils.resource_fields(res, fields) - def create_vpnservice(self, context, vpnservice): + def create_vpnservice(self, context: context.Context, + vpnservice: ty.Dict[str, ty.Dict[str, ty.Any]]): vpns = vpnservice['vpnservice'] flavor_id = vpns.get('flavor_id', None) validator = self._get_validator() @@ -499,8 +524,10 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.add(vpnservice_db) return self._make_vpnservice_dict(vpnservice_db) - def set_external_tunnel_ips(self, context, vpnservice_id, v4_ip=None, - v6_ip=None): + def set_external_tunnel_ips(self, context: context.Context, + vpnservice_id: str, + v4_ip: ty.Optional[str] = None, + v6_ip: ty.Optional[str] = None): """Update the external tunnel IP(s) for service.""" vpns = {'external_v4_ip': v4_ip, 'external_v6_ip': v6_ip} with db_api.CONTEXT_WRITER.using(context): @@ -509,20 +536,22 @@ class VPNPluginDb(vpnaas.VPNPluginBase, vpns_db.update(vpns) return self._make_vpnservice_dict(vpns_db) - def set_vpnservice_status(self, context, vpnservice_id, status, - updated_pending_status=False): + def set_vpnservice_status(self, context: context.Context, + vpnservice_id: str, status: str, + updated_pending_status: bool = False): vpns = {'status': status} with db_api.CONTEXT_WRITER.using(context): vpns_db = self._get_resource(context, vpn_models.VPNService, vpnservice_id) if (utils.in_pending_status(vpns_db.status) and not updated_pending_status): - raise vpnaas.VPNStateInvalidToUpdate( + raise vpnaas.VPNStateInvalidToUpdate( # type: ignore id=vpnservice_id, state=vpns_db.status) vpns_db.update(vpns) return self._make_vpnservice_dict(vpns_db) - def update_vpnservice(self, context, vpnservice_id, vpnservice): + def update_vpnservice(self, context: context.Context, vpnservice_id: str, + vpnservice: ty.Dict[str, ty.Optional[ty.Dict]]): vpns = vpnservice['vpnservice'] with db_api.CONTEXT_WRITER.using(context): vpns_db = self._get_resource(context, vpn_models.VPNService, @@ -532,7 +561,7 @@ class VPNPluginDb(vpnaas.VPNPluginBase, vpns_db.update(vpns) return self._make_vpnservice_dict(vpns_db) - def delete_vpnservice(self, context, vpnservice_id): + def delete_vpnservice(self, context: context.Context, vpnservice_id: str): with db_api.CONTEXT_WRITER.using(context): if context.session.query(vpn_models.IPsecSiteConnection).filter_by( vpnservice_id=vpnservice_id @@ -544,23 +573,26 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.delete(vpns_db) @db_api.CONTEXT_READER - def _get_vpnservice(self, context, vpnservice_id): + def _get_vpnservice(self, context: context.Context, vpnservice_id: str): return self._get_resource(context, vpn_models.VPNService, vpnservice_id) @db_api.CONTEXT_READER - def get_vpnservice(self, context, vpnservice_id, fields=None): + def get_vpnservice(self, context: context.Context, vpnservice_id: str, + fields=None): vpns_db = self._get_resource(context, vpn_models.VPNService, vpnservice_id) return self._make_vpnservice_dict(vpns_db, fields) @db_api.CONTEXT_READER - def get_vpnservices(self, context, filters=None, fields=None): + def get_vpnservices(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, + fields=None) -> ty.List: return model_query.get_collection(context, vpn_models.VPNService, self._make_vpnservice_dict, filters=filters, fields=fields) - def check_router_in_use(self, context, router_id): + def check_router_in_use(self, context: context.Context, router_id: str): vpnservices = self.get_vpnservices( context, filters={'router_id': [router_id]}) if vpnservices: @@ -572,7 +604,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, "(%(services)s)" % {'plural': plural, 'services': services}) - def check_subnet_in_use(self, context, subnet_id, router_id): + def check_subnet_in_use(self, context: context.Context, subnet_id: str, + router_id: str): with db_api.CONTEXT_READER.using(context): vpnservices = context.session.query( vpn_models.VPNService).filter_by( @@ -605,7 +638,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, subnet_id=subnet_id, ipsec_site_connection_id=connection['id']) - def check_subnet_in_use_by_endpoint_group(self, context, subnet_id): + def check_subnet_in_use_by_endpoint_group(self, context: context.Context, + subnet_id: str): with db_api.CONTEXT_READER.using(context): query = context.session.query(vpn_models.VPNEndpointGroup) query = query.filter(vpn_models.VPNEndpointGroup.endpoint_type == @@ -620,7 +654,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, raise vpn_exception.SubnetInUseByEndpointGroup( subnet_id=subnet_id, group_id=group['id']) - def _make_endpoint_group_dict(self, endpoint_group, fields=None): + def _make_endpoint_group_dict(self, endpoint_group: ty.Dict[str, ty.Any], + fields: ty.Optional[ty.Dict] = None) -> ty.Dict: res = {'id': endpoint_group['id'], 'tenant_id': endpoint_group['tenant_id'], 'name': endpoint_group['name'], @@ -630,7 +665,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, for ep in endpoint_group['endpoints']]} return db_utils.resource_fields(res, fields) - def create_endpoint_group(self, context, endpoint_group): + def create_endpoint_group(self, context: context.Context, + endpoint_group: ty.Dict[str, ty.Dict[str, ty.Any]]): group = endpoint_group['endpoint_group'] validator = self._get_validator() with db_api.CONTEXT_WRITER.using(context): @@ -650,8 +686,9 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.add(endpoint_db) return self._make_endpoint_group_dict(endpoint_group_db) - def update_endpoint_group(self, context, endpoint_group_id, - endpoint_group): + def update_endpoint_group(self, context: context.Context, + endpoint_group_id: str, + endpoint_group: ty.Dict[str, ty.Dict[str, ty.Any]]): group_changes = endpoint_group['endpoint_group'] # Note: Endpoints cannot be changed, so will not do validation with db_api.CONTEXT_WRITER.using(context): @@ -661,7 +698,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, endpoint_group_db.update(group_changes) return self._make_endpoint_group_dict(endpoint_group_db) - def delete_endpoint_group(self, context, endpoint_group_id): + def delete_endpoint_group(self, context: context.Context, + endpoint_group_id: str): with db_api.CONTEXT_WRITER.using(context): self.check_endpoint_group_not_in_use(context, endpoint_group_id) endpoint_group_db = self._get_resource( @@ -669,18 +707,21 @@ class VPNPluginDb(vpnaas.VPNPluginBase, context.session.delete(endpoint_group_db) @db_api.CONTEXT_READER - def get_endpoint_group(self, context, endpoint_group_id, fields=None): + def get_endpoint_group(self, context: context.Context, + endpoint_group_id: str, fields=None): endpoint_group_db = self._get_resource( context, vpn_models.VPNEndpointGroup, endpoint_group_id) return self._make_endpoint_group_dict(endpoint_group_db, fields) @db_api.CONTEXT_READER - def get_endpoint_groups(self, context, filters=None, fields=None): + def get_endpoint_groups(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, fields=None): return model_query.get_collection(context, vpn_models.VPNEndpointGroup, self._make_endpoint_group_dict, filters=filters, fields=fields) - def check_endpoint_group_not_in_use(self, context, group_id): + def check_endpoint_group_not_in_use(self, context: context.Context, + group_id: str): query = context.session.query(vpn_models.IPsecSiteConnection) query = query.filter( sa.or_( @@ -690,13 +731,15 @@ class VPNPluginDb(vpnaas.VPNPluginBase, if query.first(): raise vpn_exception.EndpointGroupInUse(group_id=group_id) - def get_vpnservice_router_id(self, context, vpnservice_id): + def get_vpnservice_router_id(self, context: context.Context, + vpnservice_id: str): with db_api.CONTEXT_READER.using(context): vpnservice = self._get_vpnservice(context, vpnservice_id) return vpnservice['router_id'] @db_api.CONTEXT_READER - def get_peer_cidrs_for_router(self, context, router_id): + def get_peer_cidrs_for_router(self, context: context.Context, + router_id: str): filters = {'router_id': [router_id]} vpnservices = model_query.get_collection_query( context, vpn_models.VPNService, filters=filters).all() @@ -712,8 +755,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase, return cidrs -class VPNPluginRpcDbMixin(object): - def _build_local_subnet_cidr_map(self, context): +class VPNPluginRpcDbMixin(VPNPluginDb): + def _build_local_subnet_cidr_map(self, context: context.Context): """Build a dict of all local endpoint subnets, with list of CIDRs.""" query = context.session.query(models_v2.Subnet.id, models_v2.Subnet.cidr) @@ -728,7 +771,8 @@ class VPNPluginRpcDbMixin(object): vpn_models.VPNEndpointGroup.id) return {sn.id: sn.cidr for sn in query.all()} - def update_status_by_agent(self, context, service_status_info_list): + def update_status_by_agent(self, context: context.Context, + service_status_info_list): """Updating vpnservice and vpnconnection status. :param context: context variable @@ -768,7 +812,8 @@ class VPNPluginRpcDbMixin(object): def vpn_router_gateway_callback(resource, event, trigger, payload=None): # the event payload objects - vpn_plugin = directory.get_plugin(p_constants.VPN) + vpn_plugin: ty.Optional[VPNPluginDb] = \ + directory.get_plugin(p_constants.VPN) if vpn_plugin: context = payload.context router_id = payload.resource_id @@ -782,7 +827,8 @@ def vpn_router_gateway_callback(resource, event, trigger, payload=None): def migration_callback(resource, event, trigger, payload): context = payload.context router = payload.latest_state - vpn_plugin = directory.get_plugin(p_constants.VPN) + vpn_plugin: ty.Optional[VPNPluginDb] = \ + directory.get_plugin(p_constants.VPN) if vpn_plugin: vpn_plugin.check_router_in_use(context, router['id']) return True @@ -792,7 +838,8 @@ def subnet_callback(resource, event, trigger, payload=None): """Respond to subnet based notifications - see if subnet in use.""" context = payload.context subnet_id = payload.resource_id - vpn_plugin = directory.get_plugin(p_constants.VPN) + vpn_plugin: ty.Optional[VPNPluginDb] = \ + directory.get_plugin(p_constants.VPN) if vpn_plugin: vpn_plugin.check_subnet_in_use_by_endpoint_group(context, subnet_id) diff --git a/neutron_vpnaas/db/vpn/vpn_ext_gw_db.py b/neutron_vpnaas/db/vpn/vpn_ext_gw_db.py index 3b3e26342..4a6715c47 100644 --- a/neutron_vpnaas/db/vpn/vpn_ext_gw_db.py +++ b/neutron_vpnaas/db/vpn/vpn_ext_gw_db.py @@ -13,6 +13,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import typing as ty from neutron.db.models import l3 as l3_models from neutron.db import models_v2 @@ -20,6 +21,7 @@ from neutron_lib.callbacks import events from neutron_lib.callbacks import registry from neutron_lib.callbacks import resources from neutron_lib import constants as lib_constants +from neutron_lib import context from neutron_lib.db import api as db_api from neutron_lib.db import model_base from neutron_lib.db import model_query @@ -33,8 +35,15 @@ from sqlalchemy import orm from sqlalchemy.orm import exc from neutron_vpnaas._i18n import _ +from neutron_vpnaas.db.vpn import vpn_db from neutron_vpnaas.services.vpn.common import constants as v_constants +#pylint: disable=ungrouped-imports +# Additional import for typechecking. Importing these without typechecking +# would resolve in a cyclic dependency +if ty.TYPE_CHECKING: + from neutron.db import db_base_plugin_v2 as db_plugin +#pylint: enable=ungrouped-imports LOG = logging.getLogger(__name__) @@ -80,11 +89,11 @@ class VPNExtGW(model_base.BASEV2, model_base.HasId, model_base.HasProject): @registry.has_registry_receivers -class VPNExtGWPlugin_db(object): +class VPNExtGWPlugin_db(vpn_db.VPNPluginDb): """DB class to support vpn external ports configuration.""" @property - def _core_plugin(self): + def _core_plugin(self) -> 'db_plugin.NeutronDbPluginV2': return directory.get_plugin() @property @@ -95,13 +104,13 @@ class VPNExtGWPlugin_db(object): @registry.receives(resources.PORT, [events.BEFORE_DELETE]) def _prevent_vpn_port_delete_callback(resource, event, trigger, payload=None): - vpn_plugin = directory.get_plugin(plugin_const.VPN) + vpn_plugin: VPNExtGWPlugin_db = directory.get_plugin(plugin_const.VPN) if vpn_plugin: vpn_plugin.prevent_vpn_port_deletion(payload.context, payload.resource_id) @db_api.CONTEXT_READER - def _id_used(self, context, id_column, resource_id): + def _id_used(self, context: context.Context, id_column, resource_id): return context.session.query(VPNExtGW).filter( sa.and_( id_column == resource_id, @@ -109,7 +118,8 @@ class VPNExtGWPlugin_db(object): ) ).count() > 0 - def prevent_vpn_port_deletion(self, context, port_id): + def prevent_vpn_port_deletion(self, context: context.Context, + port_id: str): """Checks to make sure a port is allowed to be deleted. Raises an exception if this is not the case. This should be called by @@ -124,7 +134,7 @@ class VPNExtGWPlugin_db(object): # non-existent ports don't need to be protected from deletion return - port_id_column = { + port_id_column: ty.Optional[str] = { v_constants.DEVICE_OWNER_VPN_ROUTER_GW: VPNExtGW.gw_port_id, v_constants.DEVICE_OWNER_TRANSIT_NETWORK: VPNExtGW.transit_port_id, @@ -142,12 +152,13 @@ class VPNExtGWPlugin_db(object): @registry.receives(resources.SUBNET, [events.BEFORE_DELETE]) def _prevent_vpn_subnet_delete_callback(resource, event, trigger, payload=None): - vpn_plugin = directory.get_plugin(plugin_const.VPN) + vpn_plugin: VPNExtGWPlugin_db = directory.get_plugin(plugin_const.VPN) if vpn_plugin: vpn_plugin.prevent_vpn_subnet_deletion(payload.context, payload.resource_id) - def prevent_vpn_subnet_deletion(self, context, subnet_id): + def prevent_vpn_subnet_deletion(self, context: context.Context, + subnet_id: str): if self._id_used(context, VPNExtGW.transit_subnet_id, subnet_id): reason = _('Subnet is used by VPN service') raise n_exc.SubnetInUse(subnet_id=subnet_id, reason=reason) @@ -156,16 +167,18 @@ class VPNExtGWPlugin_db(object): @registry.receives(resources.NETWORK, [events.BEFORE_DELETE]) def _prevent_vpn_network_delete_callback(resource, event, trigger, payload=None): - vpn_plugin = directory.get_plugin(plugin_const.VPN) + vpn_plugin: VPNExtGWPlugin_db = directory.get_plugin(plugin_const.VPN) if vpn_plugin: vpn_plugin.prevent_vpn_network_deletion(payload.context, payload.resource_id) - def prevent_vpn_network_deletion(self, context, network_id): + def prevent_vpn_network_deletion(self, context: context.Context, + network_id: str): if self._id_used(context, VPNExtGW.transit_network_id, network_id): raise VPNNetworkInUse(network_id=network_id) - def _make_vpn_ext_gw_dict(self, gateway_db): + def _make_vpn_ext_gw_dict(self, gateway_db: ty.Optional[VPNExtGW]) -> \ + ty.Optional[ty.Dict[str, ty.Any]]: if not gateway_db: return None gateway = { @@ -187,7 +200,8 @@ class VPNExtGWPlugin_db(object): gateway[key] = value return gateway - def _get_vpn_gw_by_router_id(self, context, router_id): + def _get_vpn_gw_by_router_id(self, context: context.Context, + router_id: str) -> ty.Optional[VPNExtGW]: try: gateway_db = context.session.query(VPNExtGW).filter( VPNExtGW.router_id == router_id).one() @@ -196,17 +210,20 @@ class VPNExtGWPlugin_db(object): return gateway_db @db_api.CONTEXT_READER - def get_vpn_gw_by_router_id(self, context, router_id): + def get_vpn_gw_by_router_id( + self, context: context.Context, router_id: str): return self._get_vpn_gw_by_router_id(context, router_id) @db_api.CONTEXT_READER - def get_vpn_gw_dict_by_router_id(self, context, router_id, refresh=False): + def get_vpn_gw_dict_by_router_id(self, context: context.Context, + router_id: str, refresh: bool = False): gateway_db = self._get_vpn_gw_by_router_id(context, router_id) if gateway_db and refresh: context.session.refresh(gateway_db) return self._make_vpn_ext_gw_dict(gateway_db) - def create_gateway(self, context, gateway): + def create_gateway(self, context: context.Context, + gateway: ty.Dict[str, ty.Dict[str, ty.Any]]): info = gateway['gateway'] with db_api.CONTEXT_WRITER.using(context): @@ -223,14 +240,15 @@ class VPNExtGWPlugin_db(object): return self._make_vpn_ext_gw_dict(gateway_db) - def update_gateway(self, context, gateway_id, gateway): + def update_gateway(self, context: context.Context, gateway_id: str, + gateway: ty.Dict[str, ty.Dict[str, ty.Any]]): info = gateway['gateway'] with db_api.CONTEXT_WRITER.using(context): gateway_db = model_query.get_by_id(context, VPNExtGW, gateway_id) gateway_db.update(info) return self._make_vpn_ext_gw_dict(gateway_db) - def delete_gateway(self, context, gateway_id): + def delete_gateway(self, context: context.Context, gateway_id: str): with db_api.CONTEXT_WRITER.using(context): query = context.session.query(VPNExtGW) return query.filter(VPNExtGW.id == gateway_id).delete() diff --git a/neutron_vpnaas/db/vpn/vpn_validator.py b/neutron_vpnaas/db/vpn/vpn_validator.py index 3366be2fb..0c10f604d 100644 --- a/neutron_vpnaas/db/vpn/vpn_validator.py +++ b/neutron_vpnaas/db/vpn/vpn_validator.py @@ -12,12 +12,18 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty + import socket import netaddr + from neutron.db import l3_db from neutron.db import models_v2 +from neutron import neutron_plugin_base_v2 +from neutron.services.l3_router import l3_router_plugin from neutron_lib.api import validators +from neutron_lib import context from neutron_lib import exceptions as nexception from neutron_lib.exceptions import vpn as vpn_exception from neutron_lib.plugins import constants as plugin_const @@ -27,7 +33,7 @@ from neutron_vpnaas._i18n import _ from neutron_vpnaas.services.vpn.common import constants -class VpnReferenceValidator(object): +class VpnReferenceValidator: """ Baseline validation routines for VPN resources. @@ -38,28 +44,30 @@ class VpnReferenceValidator(object): IP_MIN_MTU = {4: 68, 6: 1280} @property - def l3_plugin(self): + def l3_plugin(self) -> l3_router_plugin.L3RouterPlugin: try: return self._l3_plugin except AttributeError: - self._l3_plugin = directory.get_plugin(plugin_const.L3) + self._l3_plugin: l3_router_plugin.L3RouterPlugin = \ + directory.get_plugin(plugin_const.L3) return self._l3_plugin @property - def core_plugin(self): + def core_plugin(self) -> neutron_plugin_base_v2.NeutronPluginBaseV2: try: return self._core_plugin except AttributeError: - self._core_plugin = directory.get_plugin() + self._core_plugin: neutron_plugin_base_v2.NeutronPluginBaseV2 = \ + directory.get_plugin() return self._core_plugin - def _check_dpd(self, ipsec_sitecon): + def _check_dpd(self, ipsec_sitecon: ty.Dict[str, ty.Union[int, ty.Any]]): """Ensure that DPD timeout is greater than DPD interval.""" if ipsec_sitecon['dpd_timeout'] <= ipsec_sitecon['dpd_interval']: raise vpn_exception.IPsecSiteConnectionDpdIntervalValueError( attr='dpd_timeout') - def _check_mtu(self, context, mtu, ip_version): + def _check_mtu(self, context: context.Context, mtu, ip_version): if mtu < VpnReferenceValidator.IP_MIN_MTU[ip_version]: raise vpn_exception.IPsecSiteConnectionMtuError( mtu=mtu, version=ip_version) @@ -95,7 +103,7 @@ class VpnReferenceValidator(object): ip_version = netaddr.IPAddress(ipsec_sitecon['peer_address']).version self._validate_peer_address(ip_version, router) - def _get_local_subnets(self, context, endpoint_group): + def _get_local_subnets(self, context: context.Context, endpoint_group): if endpoint_group['type'] != constants.SUBNET_ENDPOINT: raise vpn_exception.WrongEndpointGroupType( group_type=endpoint_group['type'], which=endpoint_group['id'], @@ -118,7 +126,7 @@ class VpnReferenceValidator(object): """ if len(local_subnets) == 1: return local_subnets[0]['ip_version'] - ip_versions = set([subnet['ip_version'] for subnet in local_subnets]) + ip_versions = {subnet['ip_version'] for subnet in local_subnets} if len(ip_versions) > 1: raise vpn_exception.MixedIPVersionsForIPSecEndpoints( group=group_id) @@ -131,7 +139,7 @@ class VpnReferenceValidator(object): """ if len(peer_cidrs) == 1: return netaddr.IPNetwork(peer_cidrs[0]).version - ip_versions = set([netaddr.IPNetwork(pc).version for pc in peer_cidrs]) + ip_versions = {netaddr.IPNetwork(pc).version for pc in peer_cidrs} if len(ip_versions) > 1: raise vpn_exception.MixedIPVersionsForIPSecEndpoints( group=group_id) @@ -149,12 +157,13 @@ class VpnReferenceValidator(object): """Ensure all CIDRs have the same IP version.""" if len(peer_cidrs) == 1: return netaddr.IPNetwork(peer_cidrs[0]).version - ip_versions = set([netaddr.IPNetwork(pc).version for pc in peer_cidrs]) + ip_versions = {netaddr.IPNetwork(pc).version for pc in peer_cidrs} if len(ip_versions) > 1: raise vpn_exception.MixedIPVersionsForPeerCidrs() return ip_versions.pop() - def _check_local_subnets_on_router(self, context, router, local_subnets): + def _check_local_subnets_on_router(self, context: context.Context, + router, local_subnets): for subnet in local_subnets: self._check_subnet_id(context, router, subnet['id']) @@ -163,7 +172,9 @@ class VpnReferenceValidator(object): if local_ip_version != peer_ip_version: raise vpn_exception.MixedIPVersionsForIPSecConnection() - def validate_ipsec_conn_optional_args(self, ipsec_sitecon, subnet): + def validate_ipsec_conn_optional_args(self, + ipsec_sitecon: ty.Dict[str, ty.Any], + subnet): """Ensure that proper combinations of optional args are used. When VPN service has a subnet, then we must have peer_cidrs, and @@ -176,10 +187,10 @@ class VpnReferenceValidator(object): local_epg_id = ipsec_sitecon.get('local_ep_group_id') peer_epg_id = ipsec_sitecon.get('peer_ep_group_id') peer_cidrs = ipsec_sitecon.get('peer_cidrs') + epgs: ty.List[str] = [] if subnet: if not peer_cidrs: raise vpn_exception.MissingPeerCidrs() - epgs = [] if local_epg_id: epgs.append('local') if peer_epg_id: @@ -192,7 +203,6 @@ class VpnReferenceValidator(object): else: if peer_cidrs: raise vpn_exception.PeerCidrsInvalid() - epgs = [] if not local_epg_id: epgs.append('local') if not peer_epg_id: @@ -203,8 +213,9 @@ class VpnReferenceValidator(object): raise vpn_exception.MissingRequiredEndpointGroup( which=which, suffix=suffix) - def assign_sensible_ipsec_sitecon_defaults(self, ipsec_sitecon, - prev_conn=None): + def assign_sensible_ipsec_sitecon_defaults(self, + ipsec_sitecon: ty.Dict[str, ty.Any], + prev_conn: ty.Optional[ty.Dict] = None): """Provide defaults for optional items, if missing. With endpoint groups capabilities, the peer_cidr (legacy mode) @@ -229,7 +240,7 @@ class VpnReferenceValidator(object): prev_conn = {'dpd_action': 'hold', 'dpd_interval': 30, 'dpd_timeout': 120} - dpd = ipsec_sitecon.get('dpd', {}) + dpd: ty.Dict[str, ty.Any] = ipsec_sitecon.get('dpd', {}) ipsec_sitecon['dpd_action'] = dpd.get('action', prev_conn['dpd_action']) ipsec_sitecon['dpd_interval'] = dpd.get('interval', @@ -237,8 +248,10 @@ class VpnReferenceValidator(object): ipsec_sitecon['dpd_timeout'] = dpd.get('timeout', prev_conn['dpd_timeout']) - def validate_ipsec_site_connection(self, context, ipsec_sitecon, - local_ip_version, vpnservice=None): + def validate_ipsec_site_connection( + self, context: context.Context, ipsec_sitecon: ty.Dict[str, ty.Any], + local_ip_version, + vpnservice: ty.Optional[ty.Dict[str, ty.Any]] = None): """Reference implementation of validation for IPSec connection. This makes sure that IP versions are the same. For endpoint groups, @@ -254,7 +267,9 @@ class VpnReferenceValidator(object): local_subnets = self._get_local_subnets( context, ipsec_sitecon['local_epg_subnets']) self._check_local_subnets_on_router( - context, vpnservice['router_id'], local_subnets) + context, + vpnservice['router_id'] if vpnservice else '', + local_subnets) local_ip_version = self._check_local_endpoint_ip_versions( ipsec_sitecon['local_ep_group_id'], local_subnets) peer_cidrs = self._get_peer_cidrs(ipsec_sitecon['peer_epg_cidrs']) @@ -272,12 +287,13 @@ class VpnReferenceValidator(object): if mtu: self._check_mtu(context, mtu, local_ip_version) - def _check_router(self, context, router_id): + def _check_router(self, context: context.Context, router_id: str): router = self.l3_plugin.get_router(context, router_id) if not router.get(l3_db.EXTERNAL_GW_INFO): raise vpn_exception.RouterIsNotExternal(router_id=router_id) - def _check_subnet_id(self, context, router_id, subnet_id): + def _check_subnet_id(self, context: context.Context, + router_id: str, subnet_id: str): ports = self.core_plugin.get_ports( context, filters={ @@ -288,13 +304,14 @@ class VpnReferenceValidator(object): subnet_id=subnet_id, router_id=router_id) - def validate_vpnservice(self, context, vpnservice): + def validate_vpnservice(self, context: context.Context, + vpnservice: ty.Dict[str, ty.Any]): self._check_router(context, vpnservice['router_id']) if vpnservice['subnet_id'] is not None: self._check_subnet_id(context, vpnservice['router_id'], vpnservice['subnet_id']) - def validate_ipsec_policy(self, context, ipsec_policy): + def validate_ipsec_policy(self, context: context.Context, ipsec_policy): """Reference implementation of validation for IPSec Policy. Service driver can override and implement specific logic @@ -311,7 +328,8 @@ class VpnReferenceValidator(object): group_type=constants.CIDR_ENDPOINT, endpoint=cidr, why=_("Invalid CIDR")) - def _validate_subnets(self, context, subnet_ids): + def _validate_subnets(self, context: context.Context, + subnet_ids: ty.List[str]): """Ensure UUIDs OK and subnets exist.""" for subnet_id in subnet_ids: msg = validators.validate_uuid(subnet_id) @@ -325,7 +343,8 @@ class VpnReferenceValidator(object): raise vpn_exception.NonExistingSubnetInEndpointGroup( subnet=subnet_id) - def validate_endpoint_group(self, context, endpoint_group): + def validate_endpoint_group(self, context: context.Context, + endpoint_group: ty.Dict[str, ty.Any]): """Reference validator for endpoint group. Ensures that there is at least one endpoint, all the endpoints in the @@ -342,7 +361,7 @@ class VpnReferenceValidator(object): elif group_type == constants.SUBNET_ENDPOINT: self._validate_subnets(context, endpoints) - def validate_ike_policy(self, context, ike_policy): + def validate_ike_policy(self, context: context.Context, ike_policy): """Reference implementation of validation for IKE Policy. Service driver can override and implement specific logic diff --git a/neutron_vpnaas/extensions/vpn_agentschedulers.py b/neutron_vpnaas/extensions/vpn_agentschedulers.py index 0a09bb30a..1d1c9e7a3 100644 --- a/neutron_vpnaas/extensions/vpn_agentschedulers.py +++ b/neutron_vpnaas/extensions/vpn_agentschedulers.py @@ -13,6 +13,7 @@ # under the License. import abc +import typing as ty from neutron.api import extensions from neutron.api.v2 import resource @@ -20,6 +21,7 @@ from neutron import policy from neutron import wsgi from neutron_lib.api import extensions as lib_extensions from neutron_lib.api import faults as base +from neutron_lib import context from neutron_lib import exceptions from neutron_lib.plugins import constants as plugin_const from neutron_lib.plugins import directory @@ -37,9 +39,36 @@ VPN_AGENT = 'vpn-agent' VPN_AGENTS = VPN_AGENT + 's' +class VPNAgentSchedulerPluginBase(metaclass=abc.ABCMeta): + """REST API to operate the VPN agent scheduler. + + All methods must be in an admin context. + """ + + @abc.abstractmethod + def add_router_to_vpn_agent(self, context: context.ContextBase, + id: str, router_id: str): + pass + + @abc.abstractmethod + def remove_router_from_vpn_agent(self, context: context.ContextBase, + id: str, router_id: str): + pass + + @abc.abstractmethod + def list_routers_on_vpn_agent(self, context: context.ContextBase, id: str): + pass + + @abc.abstractmethod + def list_vpn_agents_hosting_router(self, context: context.ContextBase, + router_id: str): + pass + + class VPNRouterSchedulerController(wsgi.Controller): - def get_plugin(self): - plugin = directory.get_plugin(plugin_const.VPN) + def get_plugin(self) -> VPNAgentSchedulerPluginBase: + plugin: VPNAgentSchedulerPluginBase = \ + directory.get_plugin(plugin_const.VPN) if not plugin: LOG.error('No plugin for VPN registered to handle VPN ' 'router scheduling') @@ -47,18 +76,18 @@ class VPNRouterSchedulerController(wsgi.Controller): raise webob.exc.HTTPNotFound(msg) return plugin - def index(self, request, **kwargs): + def index(self, request: wsgi.Request, **kwargs): plugin = self.get_plugin() policy.enforce(request.context, - "get_%s" % VPN_ROUTERS, + f"get_{VPN_ROUTERS}", {}) return plugin.list_routers_on_vpn_agent( request.context, kwargs['agent_id']) - def create(self, request, body, **kwargs): + def create(self, request: wsgi.Request, body, **kwargs): plugin = self.get_plugin() policy.enforce(request.context, - "create_%s" % VPN_ROUTER, + f"create_{VPN_ROUTER}", {}) agent_id = kwargs['agent_id'] router_id = body['router_id'] @@ -67,10 +96,10 @@ class VPNRouterSchedulerController(wsgi.Controller): notify(request.context, 'vpn_agent.router.add', router_id, agent_id) return result - def delete(self, request, id, **kwargs): + def delete(self, request: wsgi.Request, id, **kwargs): plugin = self.get_plugin() policy.enforce(request.context, - "delete_%s" % VPN_ROUTER, + f"delete_{VPN_ROUTER}", {}) agent_id = kwargs['agent_id'] result = plugin.remove_router_from_vpn_agent(request.context, agent_id, @@ -80,18 +109,19 @@ class VPNRouterSchedulerController(wsgi.Controller): class VPNAgentsHostingRouterController(wsgi.Controller): - def get_plugin(self): - plugin = directory.get_plugin(plugin_const.VPN) + def get_plugin(self) -> VPNAgentSchedulerPluginBase: + plugin: VPNAgentSchedulerPluginBase = \ + directory.get_plugin(plugin_const.VPN) if not plugin: LOG.error('VPN plugin not registered to handle agent scheduling') msg = 'The resource could not be found.' raise webob.exc.HTTPNotFound(msg) return plugin - def index(self, request, **kwargs): + def index(self, request: wsgi.Request, **kwargs): plugin = self.get_plugin() policy.enforce(request.context, - "get_%s" % VPN_AGENTS, + f"get_{VPN_AGENTS}", {}) return plugin.list_vpn_agents_hosting_router( request.context, kwargs['router_id']) @@ -102,35 +132,35 @@ class Vpn_agentschedulers(lib_extensions.ExtensionDescriptor): """ @classmethod - def get_name(cls): + def get_name(cls) -> str: return "VPN Agent Scheduler" @classmethod - def get_alias(cls): + def get_alias(cls) -> str: return "vpn-agent-scheduler" @classmethod - def get_description(cls): + def get_description(cls) -> str: return "Schedule VPN services of routers among VPN agents" @classmethod - def get_updated(cls): + def get_updated(cls) -> str: return "2016-08-15T10:00:00-00:00" @classmethod - def get_resources(cls): + def get_resources(cls) -> ty.List[extensions.ResourceExtension]: """Returns Ext Resources.""" - exts = [] - parent = dict(member_name="agent", - collection_name="agents") + exts: ty.List[extensions.ResourceExtension] = [] + parent = {'member_name': "agent", + 'collection_name': "agents"} controller = resource.Resource(VPNRouterSchedulerController(), base.FAULT_MAP) exts.append(extensions.ResourceExtension( VPN_ROUTERS, controller, parent)) - parent = dict(member_name="router", - collection_name="routers") + parent = {'member_name': "router", + 'collection_name': "routers"} controller = resource.Resource(VPNAgentsHostingRouterController(), base.FAULT_MAP) @@ -138,7 +168,7 @@ class Vpn_agentschedulers(lib_extensions.ExtensionDescriptor): VPN_AGENTS, controller, parent)) return exts - def get_extended_resources(self, version): + def get_extended_resources(self, version) -> ty.Dict: return {} @@ -161,30 +191,7 @@ class RouterReschedulingFailed(exceptions.Conflict): "No eligible VPN agent found.") -class VPNAgentSchedulerPluginBase(object, metaclass=abc.ABCMeta): - """REST API to operate the VPN agent scheduler. - - All methods must be in an admin context. - """ - - @abc.abstractmethod - def add_router_to_vpn_agent(self, context, id, router_id): - pass - - @abc.abstractmethod - def remove_router_from_vpn_agent(self, context, id, router_id): - pass - - @abc.abstractmethod - def list_routers_on_vpn_agent(self, context, id): - pass - - @abc.abstractmethod - def list_vpn_agents_hosting_router(self, context, router_id): - pass - - -def notify(context, action, router_id, agent_id): +def notify(context: context.ContextBase, action, router_id, agent_id): info = {'id': agent_id, 'router_id': router_id} notifier = n_rpc.get_notifier('router') notifier.info(context, action, {'agent': info}) diff --git a/neutron_vpnaas/extensions/vpn_endpoint_groups.py b/neutron_vpnaas/extensions/vpn_endpoint_groups.py index 97afcff00..cefda6c1d 100644 --- a/neutron_vpnaas/extensions/vpn_endpoint_groups.py +++ b/neutron_vpnaas/extensions/vpn_endpoint_groups.py @@ -14,10 +14,14 @@ import abc +import typing as ty + from neutron_lib.api.definitions import vpn_endpoint_groups from neutron_lib.api import extensions +from neutron_lib import context from neutron_lib.plugins import constants as nconstants +from neutron.api import extensions as nextensions from neutron.api.v2 import resource_helper @@ -25,7 +29,7 @@ class Vpn_endpoint_groups(extensions.APIExtensionDescriptor): api_definition = vpn_endpoint_groups @classmethod - def get_resources(cls): + def get_resources(cls) -> ty.List[nextensions.ResourceExtension]: plural_mappings = resource_helper.build_plural_mappings( {}, vpn_endpoint_groups.RESOURCE_ATTRIBUTE_MAP) return resource_helper.build_resource_info( @@ -36,25 +40,28 @@ class Vpn_endpoint_groups(extensions.APIExtensionDescriptor): translate_name=True) -class VPNEndpointGroupsPluginBase(object, metaclass=abc.ABCMeta): +class VPNEndpointGroupsPluginBase(metaclass=abc.ABCMeta): @abc.abstractmethod - def create_endpoint_group(self, context, endpoint_group): + def create_endpoint_group(self, context: context.Context, endpoint_group): pass @abc.abstractmethod - def update_endpoint_group(self, context, endpoint_group_id, - endpoint_group): + def update_endpoint_group(self, context: context.Context, + endpoint_group_id: str, endpoint_group): pass @abc.abstractmethod - def delete_endpoint_group(self, context, endpoint_group_id): + def delete_endpoint_group(self, context: context.Context, + endpoint_group_id: str): pass @abc.abstractmethod - def get_endpoint_group(self, context, endpoint_group_id, fields=None): + def get_endpoint_group(self, context: context.Context, + endpoint_group_id: str, fields=None): pass @abc.abstractmethod - def get_endpoint_groups(self, context, filters=None, fields=None): + def get_endpoint_groups(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, fields=None): pass diff --git a/neutron_vpnaas/extensions/vpnaas.py b/neutron_vpnaas/extensions/vpnaas.py index c2ffc6c84..6335b502e 100644 --- a/neutron_vpnaas/extensions/vpnaas.py +++ b/neutron_vpnaas/extensions/vpnaas.py @@ -14,13 +14,16 @@ # under the License. import abc +import typing as ty from neutron_lib.api.definitions import vpn from neutron_lib.api import extensions +from neutron_lib import context from neutron_lib import exceptions as nexception from neutron_lib.plugins import constants as nconstants from neutron_lib.services import base as service_base +from neutron.api import extensions as nextensions from neutron.api.v2 import resource_helper from neutron_vpnaas._i18n import _ @@ -51,7 +54,7 @@ class Vpnaas(extensions.APIExtensionDescriptor): api_definition = vpn @classmethod - def get_resources(cls): + def get_resources(cls) -> ty.List[nextensions.ResourceExtension]: special_mappings = {'ikepolicies': 'ikepolicy', 'ipsecpolicies': 'ipsecpolicy'} plural_mappings = resource_helper.build_plural_mappings( @@ -71,90 +74,105 @@ class Vpnaas(extensions.APIExtensionDescriptor): class VPNPluginBase(service_base.ServicePluginBase, metaclass=abc.ABCMeta): - def get_plugin_type(self): + def get_plugin_type(self) -> str: return nconstants.VPN - def get_plugin_description(self): + def get_plugin_description(self) -> str: return 'VPN service plugin' @abc.abstractmethod - def get_vpnservices(self, context, filters=None, fields=None): + def get_vpnservices(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, fields=None): pass @abc.abstractmethod - def get_vpnservice(self, context, vpnservice_id, fields=None): + def get_vpnservice(self, context: context.Context, vpnservice_id: str, + fields=None): pass @abc.abstractmethod - def create_vpnservice(self, context, vpnservice): + def create_vpnservice(self, context: context.Context, vpnservice): pass @abc.abstractmethod - def update_vpnservice(self, context, vpnservice_id, vpnservice): + def update_vpnservice(self, context: context.Context, vpnservice_id: str, + vpnservice): pass @abc.abstractmethod - def delete_vpnservice(self, context, vpnservice_id): + def delete_vpnservice(self, context: context.Context, vpnservice_id: str): pass @abc.abstractmethod - def get_ipsec_site_connections(self, context, filters=None, fields=None): + def get_ipsec_site_connections(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, + fields=None): pass @abc.abstractmethod - def get_ipsec_site_connection(self, context, - ipsecsite_conn_id, fields=None): + def get_ipsec_site_connection(self, context: context.Context, + ipsecsite_conn_id: str, fields=None): pass @abc.abstractmethod - def create_ipsec_site_connection(self, context, ipsec_site_connection): + def create_ipsec_site_connection(self, context: context.Context, + ipsec_site_connection): pass @abc.abstractmethod - def update_ipsec_site_connection(self, context, - ipsecsite_conn_id, ipsec_site_connection): + def update_ipsec_site_connection(self, context: context.Context, + ipsecsite_conn_id: str, + ipsec_site_connection): pass @abc.abstractmethod - def delete_ipsec_site_connection(self, context, ipsecsite_conn_id): + def delete_ipsec_site_connection(self, context: context.Context, + ipsecsite_conn_id: str): pass @abc.abstractmethod - def get_ikepolicy(self, context, ikepolicy_id, fields=None): + def get_ikepolicy(self, context: context.Context, ikepolicy_id: str, + fields=None): pass @abc.abstractmethod - def get_ikepolicies(self, context, filters=None, fields=None): + def get_ikepolicies(self, context: context.Context, + filters: ty.Optional[ty.Dict], fields=None): pass @abc.abstractmethod - def create_ikepolicy(self, context, ikepolicy): + def create_ikepolicy(self, context: context.Context, ikepolicy): pass @abc.abstractmethod - def update_ikepolicy(self, context, ikepolicy_id, ikepolicy): + def update_ikepolicy(self, context: context.Context, ikepolicy_id: str, + ikepolicy): pass @abc.abstractmethod - def delete_ikepolicy(self, context, ikepolicy_id): + def delete_ikepolicy(self, context: context.Context, ikepolicy_id: str): pass @abc.abstractmethod - def get_ipsecpolicies(self, context, filters=None, fields=None): + def get_ipsecpolicies(self, context: context.Context, + filters: ty.Optional[ty.Dict] = None, fields=None): pass @abc.abstractmethod - def get_ipsecpolicy(self, context, ipsecpolicy_id, fields=None): + def get_ipsecpolicy(self, context: context.Context, ipsecpolicy_id: str, + fields=None): pass @abc.abstractmethod - def create_ipsecpolicy(self, context, ipsecpolicy): + def create_ipsecpolicy(self, context: context.Context, ipsecpolicy): pass @abc.abstractmethod - def update_ipsecpolicy(self, context, ipsecpolicy_id, ipsecpolicy): + def update_ipsecpolicy(self, context: context.Context, ipsecpolicy_id: str, + ipsecpolicy): pass @abc.abstractmethod - def delete_ipsecpolicy(self, context, ipsecpolicy_id): + def delete_ipsecpolicy(self, context: context.Context, + ipsecpolicy_id: str): pass diff --git a/neutron_vpnaas/opts.py b/neutron_vpnaas/opts.py index ef7d3bbd7..6ef11940a 100644 --- a/neutron_vpnaas/opts.py +++ b/neutron_vpnaas/opts.py @@ -9,6 +9,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import typing as ty import neutron.conf.plugins.ml2.drivers.ovn.ovn_conf import neutron.services.provider_configuration @@ -19,7 +20,7 @@ import neutron_vpnaas.services.vpn.device_drivers.strongswan_ipsec import neutron_vpnaas.services.vpn.ovn_agent -def list_agent_opts(): +def list_agent_opts() -> ty.List[ty.Tuple[str, ty.List]]: return [ ('vpnagent', neutron_vpnaas.services.vpn.agent.vpn_agent_opts), @@ -33,7 +34,7 @@ def list_agent_opts(): ] -def list_ovn_agent_opts(): +def list_ovn_agent_opts() -> ty.List[ty.Tuple[str, ty.List]]: return [ ('vpnagent', neutron_vpnaas.services.vpn.ovn_agent.VPN_AGENT_OPTS), @@ -51,7 +52,7 @@ def list_ovn_agent_opts(): ] -def list_opts(): +def list_opts() -> ty.List[ty.Tuple[str, ty.List]]: return [ ('service_providers', neutron.services.provider_configuration.serviceprovider_opts) diff --git a/neutron_vpnaas/scheduler/vpn_agent_scheduler.py b/neutron_vpnaas/scheduler/vpn_agent_scheduler.py index bea9bfe84..2434466af 100644 --- a/neutron_vpnaas/scheduler/vpn_agent_scheduler.py +++ b/neutron_vpnaas/scheduler/vpn_agent_scheduler.py @@ -14,33 +14,42 @@ import abc import random +import typing as ty +from neutron.extensions import agent as nagent from neutron.extensions import availability_zone as az_ext +from neutron.extensions import l3 +from neutron_lib import context from neutron_lib.plugins import constants as plugin_constants from neutron_lib.plugins import directory from oslo_config import cfg from oslo_log import log as logging +if ty.TYPE_CHECKING: + from neutron_vpnaas.db.vpn import vpn_agentschedulers_db as scheduler_db from neutron_vpnaas.extensions import vpn_agentschedulers LOG = logging.getLogger(__name__) -class VPNScheduler(object, metaclass=abc.ABCMeta): +class VPNScheduler(metaclass=abc.ABCMeta): @property - def l3_plugin(self): + def l3_plugin(self) -> l3.RouterPluginBase: return directory.get_plugin(plugin_constants.L3) @abc.abstractmethod - def schedule(self, plugin, context, router_id, - candidates=None, hints=None): + def schedule(self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, router_id, candidates=None, hints=None) -> \ + ty.Optional[nagent.Agent]: """Schedule the router to an active VPN agent. Schedule the router only if it is not already scheduled. """ pass - def _get_unscheduled_routers(self, context, plugin, router_ids=None): + def _get_unscheduled_routers(self, context: context.Context, + plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + router_ids=None) -> ty.List: """Get the list of routers with VPN services to be scheduled. If router IDs are omitted, look for all unscheduled routers. @@ -57,7 +66,10 @@ class VPNScheduler(object, metaclass=abc.ABCMeta): context, filters={'id': unscheduled_router_ids}) return [] - def _get_routers_can_schedule(self, context, plugin, routers, vpn_agent): + def _get_routers_can_schedule( + self, context: context.Context, + plugin: vpn_agentschedulers.VPNAgentSchedulerPluginBase, + routers, vpn_agent,): """Get the subset of routers whose VPN services can be scheduled on the VPN agent. """ @@ -65,7 +77,9 @@ class VPNScheduler(object, metaclass=abc.ABCMeta): # all routers can be scheduled to it return routers - def auto_schedule_routers(self, plugin, context, vpn_agent): + def auto_schedule_routers( + self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, vpn_agent) -> ty.List[str]: """Schedule non-hosted routers to a VPN agent. :returns: True if routers have been successfully assigned to the agent @@ -83,20 +97,31 @@ class VPNScheduler(object, metaclass=abc.ABCMeta): self._bind_routers(context, plugin, target_routers, vpn_agent) return [router['id'] for router in target_routers] - def _get_candidates(self, plugin, context, sync_router): + def _get_candidates( + self, + plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, + sync_router) -> ty.Optional[ty.List[nagent.Agent]]: """Return VPN agents where a router could be scheduled.""" active_vpn_agents = plugin.get_vpn_agents(context, active=True) if not active_vpn_agents: LOG.warning('No active VPN agents') return active_vpn_agents - def _bind_routers(self, context, plugin, routers, vpn_agent): + def _bind_routers( + self, context: context.Context, + plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + routers, vpn_agent): for router in routers: plugin.create_router_to_agent_binding( context, router['id'], vpn_agent['id']) - def _schedule_router(self, plugin, context, router_id, - candidates=None): + def _schedule_router( + self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, + router_id, + candidates: ty.Optional[ty.List[nagent.Agent]] = None) -> \ + ty.Optional[nagent.Agent]: current_vpn_agents = plugin.get_vpn_agents_hosting_routers( context, [router_id]) if current_vpn_agents: @@ -118,9 +143,13 @@ class VPNScheduler(object, metaclass=abc.ABCMeta): if plugin.create_router_to_agent_binding(context, router_id, chosen_agent['id']): return chosen_agent + return None @abc.abstractmethod - def _choose_vpn_agent(self, plugin, context, candidates): + def _choose_vpn_agent( + self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, + candidates: ty.List[nagent.Agent]) -> nagent.Agent: """Choose an agent from candidates based on a specific policy.""" pass @@ -128,24 +157,31 @@ class VPNScheduler(object, metaclass=abc.ABCMeta): class ChanceScheduler(VPNScheduler): """Randomly allocate an VPN agent for a router.""" - def schedule(self, plugin, context, router_id, - candidates=None): + def schedule( + self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, router_id, candidates=None, hints=None): return self._schedule_router( plugin, context, router_id, candidates=candidates) - def _choose_vpn_agent(self, plugin, context, candidates): + def _choose_vpn_agent( + self, plugin: vpn_agentschedulers.VPNAgentSchedulerPluginBase, + context: context.Context, + candidates: ty.List[nagent.Agent]) -> nagent.Agent: return random.choice(candidates) class LeastRoutersScheduler(VPNScheduler): """Allocate to an VPN agent with the least number of routers bound.""" - def schedule(self, plugin, context, router_id, - candidates=None): + def schedule(self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, router_id, candidates=None, hints=None): return self._schedule_router( plugin, context, router_id, candidates=candidates) - def _choose_vpn_agent(self, plugin, context, candidates): + def _choose_vpn_agent( + self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, + candidates: ty.List[nagent.Agent]) -> nagent.Agent: candidates_dict = {c['id']: c for c in candidates} chosen_agent_id = plugin.get_vpn_agent_with_min_routers( context, candidates_dict.keys()) @@ -156,9 +192,12 @@ class AZLeastRoutersScheduler(LeastRoutersScheduler): """Availability zone aware scheduler.""" def _get_az_hints(self, router): return (router.get(az_ext.AZ_HINTS) or - cfg.CONF.default_availability_zones) + cfg.CONF.default_availability_zones) - def _get_routers_can_schedule(self, context, plugin, routers, vpn_agent): + def _get_routers_can_schedule( + self, context: context.Context, + plugin: vpn_agentschedulers.VPNAgentSchedulerPluginBase, + routers, vpn_agent): """Overwrite VPNScheduler's method to filter by availability zone.""" target_routers = [] for r in routers: @@ -172,14 +211,20 @@ class AZLeastRoutersScheduler(LeastRoutersScheduler): return super()._get_routers_can_schedule( context, plugin, target_routers, vpn_agent) - def _get_candidates(self, plugin, context, sync_router): + def _get_candidates( + self, + plugin: 'scheduler_db.VPNAgentSchedulerDbMixin', + context: context.Context, + sync_router) -> ty.Optional[ty.List[nagent.Agent]]: """Overwrite VPNScheduler's method to filter by availability zone.""" all_candidates = super()._get_candidates(plugin, context, sync_router) - candidates = [] - az_hints = self._get_az_hints(sync_router) - for agent in all_candidates: - if not az_hints or agent['availability_zone'] in az_hints: - candidates.append(agent) + if all_candidates: + candidates = [] + az_hints = self._get_az_hints(sync_router) + for agent in all_candidates: + if not az_hints or agent['availability_zone'] in az_hints: + candidates.append(agent) - return candidates + return candidates + return None diff --git a/neutron_vpnaas/services/vpn/agent.py b/neutron_vpnaas/services/vpn/agent.py index 86810aeaa..b5fe1356e 100644 --- a/neutron_vpnaas/services/vpn/agent.py +++ b/neutron_vpnaas/services/vpn/agent.py @@ -14,12 +14,16 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty +from neutron.agent.l3 import l3_agent_extension_api from neutron_lib.agent import l3_extension +from neutron_lib import context from oslo_config import cfg from oslo_log import log as logging from neutron_vpnaas._i18n import _ +from neutron_vpnaas.services.vpn import device_drivers from neutron_vpnaas.services.vpn import vpn_service LOG = logging.getLogger(__name__) @@ -43,10 +47,11 @@ cfg.CONF.register_opts(vpn_agent_opts, 'vpnagent') class VPNAgent(l3_extension.L3AgentExtension): """VPNaaS Agent support to be used by Neutron L3 agent.""" - def initialize(self, connection, driver_type): + def initialize(self, connection, driver_type: device_drivers.DeviceDriver): LOG.debug("Loading VPNaaS") - def consume_api(self, agent_api): + def consume_api(self, + agent_api: l3_agent_extension_api.L3AgentExtensionAPI): LOG.debug("Loading consume_api for VPNaaS") self.agent_api = agent_api @@ -58,7 +63,7 @@ class VPNAgent(l3_extension.L3AgentExtension): self.service = vpn_service.VPNService(self) self.device_drivers = self.service.load_device_drivers(self.host) - def add_router(self, context, data): + def add_router(self, context: context.Context, data): """Handles router add event""" ri = self.agent_api.get_router_info(data['id']) if ri is not None: @@ -69,17 +74,18 @@ class VPNAgent(l3_extension.L3AgentExtension): LOG.debug("Router %s was concurrently deleted while " "creating VPN for it", data['id']) - def update_router(self, context, data): + def update_router(self, context: context.Context, data): """Handles router update event""" for device_driver in self.device_drivers: device_driver.sync(context, [data]) - def delete_router(self, context, data): + def delete_router(self, context: context.Context, data): """Handles router delete event""" for device_driver in self.device_drivers: device_driver.destroy_router(data['id']) - def ha_state_change(self, context, data): + def ha_state_change(self, context: context.Context, + data: ty.Dict[str, str]): """Enable the vpn process when router transitioned to master. And disable vpn process for backup router. @@ -98,7 +104,7 @@ class VPNAgent(l3_extension.L3AgentExtension): else: process.disable() - def update_network(self, context, data): + def update_network(self, context: context.Context, data): pass @@ -109,5 +115,4 @@ class L3WithVPNaaS(VPNAgent): self.conf = conf else: self.conf = cfg.CONF - super(L3WithVPNaaS, self).__init__( - host=self.conf.host, conf=self.conf) + super().__init__(host=self.conf.host, conf=self.conf) diff --git a/neutron_vpnaas/services/vpn/common/netns_wrapper.py b/neutron_vpnaas/services/vpn/common/netns_wrapper.py index 25a4f6877..36069e4af 100644 --- a/neutron_vpnaas/services/vpn/common/netns_wrapper.py +++ b/neutron_vpnaas/services/vpn/common/netns_wrapper.py @@ -13,6 +13,8 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty + import configparser as ConfigParser import errno import os @@ -31,7 +33,7 @@ from neutron_vpnaas._i18n import _ LOG = logging.getLogger(__name__) -def setup_conf(): +def setup_conf() -> cfg.ConfigOpts: cli_opts = [ cfg.DictOpt('mount_paths', required=True, @@ -50,9 +52,9 @@ def setup_conf(): return conf -def execute(cmd): +def execute(cmd) -> ty.Optional[int]: if not cmd: - return + return None cmd = list(map(str, cmd)) LOG.debug("Running command: %s", cmd) env = os.environ.copy() @@ -106,12 +108,12 @@ def filter_command(command, rootwrap_config): 'name': exc.match.name}) sys.exit(errno.EINVAL) except wrapper.NoFilterMatched: - LOG.error('Unauthorized command: %(cmd)s (no filter matched)', - {'cmd': command}) + LOG.error("Unauthorized command: %(cmd)s (no filter matched)", + {"cmd": command}) sys.exit(errno.EPERM) -def execute_with_mount(): +def execute_with_mount() -> ty.Optional[int]: config.register_common_config_options() conf = setup_conf() conf() diff --git a/neutron_vpnaas/services/vpn/device_drivers/__init__.py b/neutron_vpnaas/services/vpn/device_drivers/__init__.py index 60bf101eb..06638e56b 100644 --- a/neutron_vpnaas/services/vpn/device_drivers/__init__.py +++ b/neutron_vpnaas/services/vpn/device_drivers/__init__.py @@ -14,14 +14,16 @@ # under the License. import abc +from neutron_lib import context -class DeviceDriver(object, metaclass=abc.ABCMeta): + +class DeviceDriver(metaclass=abc.ABCMeta): def __init__(self, agent, host): pass @abc.abstractmethod - def sync(self, context, processes): + def sync(self, context: context.ContextBase, processes): pass @abc.abstractmethod diff --git a/neutron_vpnaas/services/vpn/device_drivers/ipsec.py b/neutron_vpnaas/services/vpn/device_drivers/ipsec.py index 7ffcaac71..e2c68c2e9 100644 --- a/neutron_vpnaas/services/vpn/device_drivers/ipsec.py +++ b/neutron_vpnaas/services/vpn/device_drivers/ipsec.py @@ -12,6 +12,9 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + +import typing as ty + import abc import base64 import copy @@ -126,7 +129,8 @@ IPSEC_CONNS = 'ipsec_site_connections' PSK_BASE64_PREFIX = '0s' -def _get_template(template_file): +def _get_template( + template_file: ty.Union[str, jinja2.Template]) -> jinja2.Template: global JINJA_ENV if not JINJA_ENV: templateLoader = jinja2.FileSystemLoader(searchpath="/") @@ -134,7 +138,7 @@ def _get_template(template_file): return JINJA_ENV.get_template(template_file) -class BaseSwanProcess(object, metaclass=abc.ABCMeta): +class BaseSwanProcess(metaclass=abc.ABCMeta): """Swan Family Process Manager This class manages start/restart/stop ipsec process. @@ -189,12 +193,12 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): STATUS_IPSEC_SA_ESTABLISHED_RE2 = ( r'\d{3} #\d+: "([a-f0-9\-\/x]+).*established.*newest IPSEC') - def __init__(self, conf, process_id, vpnservice, namespace): + def __init__(self, conf, process_id: str, vpnservice, namespace: str): self.conf = conf self.id = process_id - self.updated_pending_status = False + self.updated_pending_status: bool = False self.namespace = namespace - self.connection_status = {} + self.connection_status: ty.Dict[str, ty.Dict[str, ty.Any]] = {} self.config_dir = os.path.join( self.conf.ipsec.config_base_dir, self.id) self.etc_dir = os.path.join(self.config_dir, 'etc') @@ -212,32 +216,31 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): def translate_dialect(self): if not self.vpnservice: return - for ipsec_site_conn in self.vpnservice['ipsec_site_connections']: - self._dialect(ipsec_site_conn, 'initiator') - self._dialect(ipsec_site_conn['ikepolicy'], 'ike_version') - for key in ['encryption_algorithm', - 'auth_algorithm', - 'pfs']: - self._dialect(ipsec_site_conn['ikepolicy'], key) - self._dialect(ipsec_site_conn['ipsecpolicy'], key) - if (('local_id' not in ipsec_site_conn.keys()) or - (not ipsec_site_conn['local_id'])): - ipsec_site_conn['local_id'] = ipsec_site_conn['external_ip'] + for ipsec_site_conn in self.vpnservice["ipsec_site_connections"]: + self._dialect(ipsec_site_conn, "initiator") + self._dialect(ipsec_site_conn["ikepolicy"], "ike_version") + for key in ["encryption_algorithm", "auth_algorithm", "pfs"]: + self._dialect(ipsec_site_conn["ikepolicy"], key) + self._dialect(ipsec_site_conn["ipsecpolicy"], key) + if ("local_id" not in ipsec_site_conn.keys()) or ( + not ipsec_site_conn["local_id"] + ): + ipsec_site_conn["local_id"] = ipsec_site_conn["external_ip"] def base64_encode_psk(self): if not self.vpnservice: return - for ipsec_site_conn in self.vpnservice['ipsec_site_connections']: - psk = ipsec_site_conn['psk'] + for ipsec_site_conn in self.vpnservice["ipsec_site_connections"]: + psk = ipsec_site_conn["psk"] encoded_psk = base64.b64encode(encodeutils.safe_encode(psk)) # NOTE(huntxu): base64.b64encode returns an instance of 'bytes' # in Python 3, convert it to a str. For Python 2, after calling # safe_decode, psk is converted into a unicode not containing any # non-ASCII characters so it doesn't matter. - psk = encodeutils.safe_decode(encoded_psk, incoming='utf_8') - ipsec_site_conn['psk'] = PSK_BASE64_PREFIX + psk + psk = encodeutils.safe_decode(encoded_psk, incoming="utf_8") + ipsec_site_conn["psk"] = PSK_BASE64_PREFIX + psk - def get_ns_wrapper(self): + def get_ns_wrapper(self) -> str: """ Check if we're inside a virtualenv. If we are, then we should respect this and launch wrapper from venv as well. @@ -249,7 +252,7 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): ns_wrapper = self.NS_WRAPPER return ns_wrapper - def update_vpnservice(self, vpnservice): + def update_vpnservice(self, vpnservice: ty.Dict[str, ty.Any]): self.vpnservice = vpnservice self.translate_dialect() self.base64_encode_psk() @@ -275,7 +278,7 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): agent_utils.execute( cmd=["rm", "-rf", self.config_dir], run_as_root=True) - def _get_config_filename(self, kind): + def _get_config_filename(self, kind) -> str: config_dir = self.etc_dir return os.path.join(config_dir, kind) @@ -286,15 +289,15 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): dir_path = os.path.join(self.config_dir, subdir) fileutils.ensure_tree(dir_path, 0o755) - def _gen_config_content(self, template_file, vpnservice): + def _gen_config_content(self, template_file, vpnservice) -> str: template = _get_template(template_file) return template.render( {'vpnservice': vpnservice, 'state_path': self.conf.state_path}) - def _get_rootwrap_config(self): + def _get_rootwrap_config(self) -> ty.Optional[str]: if 'neutron-rootwrap' in cfg.CONF.AGENT.root_helper: - rh_tokens = cfg.CONF.AGENT.root_helper.split(' ') + rh_tokens: ty.List[str] = cfg.CONF.AGENT.root_helper.split(' ') if len(rh_tokens) == 3 and os.path.exists(rh_tokens[2]): return rh_tokens[2] return None @@ -304,13 +307,13 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): pass @property - def status(self): + def status(self) -> str: if self.active: return constants.ACTIVE return constants.DOWN @property - def active(self): + def active(self) -> bool: """Check if the process is active or not.""" if not self.namespace: return False @@ -329,8 +332,9 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): # Disable the process if a vpnservice is disabled or it has no # enabled IPSec site connections. vpnservice_has_active_ipsec_site_conns = any( - [ipsec_site_conn['admin_state_up'] - for ipsec_site_conn in self.vpnservice['ipsec_site_connections']]) + ipsec_site_conn['admin_state_up'] + for ipsec_site_conn in + self.vpnservice['ipsec_site_connections']) if (not self.vpnservice['admin_state_up'] or not vpnservice_has_active_ipsec_site_conns): self.disable() @@ -386,7 +390,8 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): def stop(self): """Stop process.""" - def _check_status_line(self, line): + def _check_status_line(self, + line: str) -> ty.Tuple[ty.Optional[str], ty.Optional[str]]: """Parse a line and search for status information. If a connection has an established Security Association, @@ -404,14 +409,15 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): if m: connection_id = m.group(1) return connection_id, constants.ACTIVE - else: - m = self.STATUS_PATTERN.search(line) - if m: - connection_id = m.group(1) - return connection_id, constants.DOWN + + m = self.STATUS_PATTERN.search(line) + if m: + connection_id = m.group(1) + return connection_id, constants.DOWN return None, None - def _extract_and_record_connection_status(self, status_output): + def _extract_and_record_connection_status(self, + status_output: ty.Optional[str]): if not status_output: self.connection_status = {} return @@ -423,8 +429,8 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta): if conn_id: self._record_connection_status(conn_id, conn_status) - def _record_connection_status(self, connection_id, status, - force_status_update=False): + def _record_connection_status(self, connection_id: str, status, + force_status_update: bool = False): conn_info = self.connection_status.get(connection_id) if not conn_info: self.connection_status[connection_id] = { @@ -445,18 +451,20 @@ class OpenSwanProcess(BaseSwanProcess): (2) ipsec addconn: Adds new ipsec addconn (3) ipsec whack: control interface for IPSEC keying daemon """ - def __init__(self, conf, process_id, vpnservice, namespace): - super(OpenSwanProcess, self).__init__(conf, process_id, - vpnservice, namespace) - self.secrets_file = os.path.join( - self.etc_dir, 'ipsec.secrets') - self.config_file = os.path.join( - self.etc_dir, 'ipsec.conf') - self.pid_path = os.path.join( - self.config_dir, 'var', 'run', 'pluto') - self.pid_file = '%s.pid' % self.pid_path - def _execute(self, cmd, check_exit_code=True, extra_ok_codes=None): + def __init__(self, conf, process_id: str, vpnservice, namespace: str): + super().__init__(conf, process_id, vpnservice, namespace) + self.secrets_file: str = os.path.join( + self.etc_dir, 'ipsec.secrets') + self.config_file: str = os.path.join( + self.etc_dir, 'ipsec.conf') + self.pid_path: str = os.path.join( + self.config_dir, 'var', 'run', 'pluto') + self.pid_file: str = f'{self.pid_path}.pid' + + def _execute(self, cmd, check_exit_code: bool = True, + extra_ok_codes: ty.Optional[ty.List[int]] = None + ) -> ty.Optional[str]: """Execute command on namespace.""" ip_wrapper = ip_lib.IPWrapper(namespace=self.namespace) return ip_wrapper.netns.execute(cmd, check_exit_code=check_exit_code, @@ -490,7 +498,7 @@ class OpenSwanProcess(BaseSwanProcess): shutil.copyfile(config_file_name, config_file_name + '.old') os.chmod(config_file_name + '.old', 0o600) - def _process_running(self): + def _process_running(self) -> bool: """Checks if process is still running.""" # If no PID file, we assume the process is not running. @@ -502,9 +510,10 @@ class OpenSwanProcess(BaseSwanProcess): # on throwing to tell us something. If the pid file exists, # delve into the process information and check if it matches # our expected command line. - with open(self.pid_file, 'r') as f: + with open(self.pid_file, 'r', encoding="C") as f: pid = f.readline().strip() - with open('/proc/%s/cmdline' % pid) as cmd_line_file: + with open(f'/proc/{pid}/cmdline', + encoding="C") as cmd_line_file: cmd_line = cmd_line_file.readline() if self.pid_path in cmd_line and 'pluto' in cmd_line: # Okay the process is probably a pluto process @@ -529,7 +538,7 @@ class OpenSwanProcess(BaseSwanProcess): def _cleanup_control_files(self): try: - ctl_file = '%s.ctl' % self.pid_path + ctl_file = f'{self.pid_path}.ctl' LOG.debug('Removing %(pidfile)s and %(ctlfile)s', {'pidfile': self.pid_file, 'ctlfile': ctl_file}) @@ -545,14 +554,14 @@ class OpenSwanProcess(BaseSwanProcess): 'files for router %(router)s. %(msg)s', {'router': self.id, 'msg': e}) - def get_status(self): + def get_status(self) -> ty.Optional[str]: return self._execute([self.binary, 'whack', '--ctlbase', self.pid_path, '--status'], extra_ok_codes=[1, 3]) - def _config_changed(self): + def _config_changed(self) -> bool: secrets_file = os.path.join( self.etc_dir, 'ipsec.secrets') config_file = os.path.join( @@ -590,19 +599,25 @@ class OpenSwanProcess(BaseSwanProcess): LOG.warning('Server appears to still be running, restart ' 'of router %s may fail', self.id) self.start() - return - def _resolve_fqdn(self, fqdn): + def _resolve_fqdn(self, fqdn) -> ty.Optional[str]: # The first addrinfo member from the list returned by # socket.getaddrinfo is used for the address resolution. # The code doesn't filter for ipv4 or ipv6 address. try: - addrinfo = socket.getaddrinfo(fqdn, None)[0] + addrinfo: ty.Tuple[ + socket.AddressFamily, + socket.SocketKind, + int, + str, + ty.Union[ty.Tuple[str, int], ty.Tuple[str, int, int, int]], + ] = socket.getaddrinfo(fqdn, None)[0] return addrinfo[-1][0] except socket.gaierror: LOG.exception("Peer address %s cannot be resolved", fqdn) + return None - def _get_nexthop(self, address, connection_id): + def _get_nexthop(self, address: str, connection_id: str) -> str: # check if address is an ip address or fqdn invalid_ip_address = validators.validate_ip_address(address) if invalid_ip_address: @@ -615,28 +630,31 @@ class OpenSwanProcess(BaseSwanProcess): else: ip_addr = address routes = self._execute(['ip', 'route', 'get', ip_addr]) - if routes.find('via') >= 0: + if routes and routes.find('via') >= 0: return routes.split(' ')[2] return address - def _virtual_privates(self, vpnservice): + def _virtual_privates(self, + vpnservice: ty.Dict[str, ty.List[ty.Dict[str, ty.Any]]]) -> str: """Returns line of virtual_privates. virtual_private contains the networks that are allowed as subnet for the remote client. """ - virtual_privates = [] - nets = [] + virtual_privates: ty.List[str] = [] + nets: ty.List = [] for ipsec_site_conn in vpnservice['ipsec_site_connections']: nets += ipsec_site_conn['local_cidrs'] nets += ipsec_site_conn['peer_cidrs'] for net in nets: version = netaddr.IPNetwork(net).version - virtual_privates.append('%%v%s:%s' % (version, net)) + virtual_privates.append(f'%v{version}:{net}') virtual_privates.sort() return ','.join(virtual_privates) - def _gen_config_content(self, template_file, vpnservice): + def _gen_config_content(self, + template_file: ty.Union[str, jinja2.Template], + vpnservice) -> str: template = _get_template(template_file) virtual_privates = self._virtual_privates(vpnservice) return template.render( @@ -660,7 +678,7 @@ class OpenSwanProcess(BaseSwanProcess): def add_ipsec_connection(self, nexthop, conn_id): self._execute([self.binary, 'addconn', - '--ctlbase', '%s.ctl' % self.pid_path, + '--ctlbase', f'{self.pid_path}.ctl', '--defaultroutenexthop', nexthop, '--config', self.config_file, conn_id ]) @@ -745,9 +763,9 @@ class OpenSwanProcess(BaseSwanProcess): self.initiate_connection(ipsec_site_conn['id']) self._copy_configs() - def get_established_connections(self): - connections = [] - status_output = self.get_status() + def get_established_connections(self) -> ty.List[str]: + connections: ty.List[str] = [] + status_output: ty.Optional[str] = self.get_status() if not status_output: return connections @@ -781,7 +799,7 @@ class OpenSwanProcess(BaseSwanProcess): self.connection_status = {} -class IPsecVpnDriverApi(object): +class IPsecVpnDriverApi: """IPSecVpnDriver RPC api.""" @log_helpers.log_method_call @@ -790,7 +808,7 @@ class IPsecVpnDriverApi(object): self.client = n_rpc.get_client(target) @log_helpers.log_method_call - def get_vpn_services_on_host(self, context, host): + def get_vpn_services_on_host(self, context: context.ContextBase, host): """Get list of vpnservices. The vpnservices including related ipsec_site_connection, @@ -800,7 +818,7 @@ class IPsecVpnDriverApi(object): return cctxt.call(context, 'get_vpn_services_on_host', host=host) @log_helpers.log_method_call - def update_status(self, context, status): + def update_status(self, context: context.ContextBase, status): """Update local status. This method call updates status attribute of @@ -823,19 +841,20 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): # 1.0 Initial version target = oslo_messaging.Target(version='1.0') - def __init__(self, vpn_service, host): + def __init__(self, vpn_service, host: str): # TODO(pc_m) Replace vpn_service with config arg, once all driver # implementations no longer need vpn_service. self.conf = vpn_service.conf self.host = host self.conn = n_rpc.Connection() - self.context = context.get_admin_context_without_session() - self.topic = topics.IPSEC_AGENT_TOPIC - node_topic = '%s.%s' % (self.topic, self.host) + self.context: context.ContextBase = \ + context.get_admin_context_without_session() + self.topic: str = topics.IPSEC_AGENT_TOPIC + node_topic: str = f'{self.topic}.{self.host}' - self.processes = {} - self.routers = {} - self.process_status_cache = {} + self.processes: ty.Dict[str, BaseSwanProcess] = {} + self.routers: ty.Dict[str, ty.Any] = {} + self.process_status_cache: ty.Dict[str, ty.Dict[str, ty.Any]] = {} self.endpoints = [self] self.conn.create_consumer(node_topic, self.endpoints, fanout=False) @@ -846,7 +865,7 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): self.process_status_cache_check.start( interval=self.conf.ipsec.ipsec_status_check_interval) - def get_namespace(self, router_id): + def get_namespace(self, router_id: str) -> ty.Optional[str]: """Get namespace of router. :router_id: router_id @@ -856,7 +875,7 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): """ router = self.routers.get(router_id) if not router: - return + return None # For DVR, use SNAT namespace # TODO(pcm): Use router object method to tell if DVR, when available if router.router['distributed']: @@ -941,14 +960,14 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): for peer_cidr in ipsec_site_connection['peer_cidrs']: func(router_id, 'POSTROUTING', - '-s %s -d %s -m policy ' + f'-s {local_cidr} -d {peer_cidr} -m policy ' '--dir out --pol ipsec ' - '-j ACCEPT ' % (local_cidr, peer_cidr), + '-j ACCEPT ', top=True) self.iptables_apply(router_id) @log_helpers.log_method_call - def vpnservice_updated(self, context, **kwargs): + def vpnservice_updated(self, context: context.ContextBase, **kwargs): """Vpnservice updated rpc handler VPN Service Driver will call this method @@ -959,16 +978,18 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): self.sync(context, [router] if router else []) @abc.abstractmethod - def create_process(self, process_id, vpnservice, namespace): + def create_process(self, process_id: str, vpnservice, + namespace) -> BaseSwanProcess: pass - def ensure_process(self, process_id, vpnservice=None): + def ensure_process(self, process_id: str, + vpnservice=None) -> BaseSwanProcess: """Ensuring process. If the process doesn't exist, it will create process and store it in self.process """ - process = self.processes.get(process_id) + process = self.processes.get(process_id, None) if not process or not process.namespace: namespace = self.get_namespace(process_id) process = self.create_process( @@ -1022,7 +1043,8 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): if process_id in self.routers: del self.routers[process_id] - def get_process_status_cache(self, process): + def get_process_status_cache(self, + process: BaseSwanProcess) -> ty.Dict[str, ty.Any]: if not self.process_status_cache.get(process.id): self.process_status_cache[process.id] = { 'status': None, @@ -1031,7 +1053,8 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): 'ipsec_site_connections': {}} return self.process_status_cache[process.id] - def is_status_updated(self, process, previous_status): + def is_status_updated(self, process: BaseSwanProcess, + previous_status: ty.Dict[str, ty.Any]) -> bool: if process.updated_pending_status: return True if process.status != previous_status['status']: @@ -1039,13 +1062,15 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): if (process.connection_status != previous_status['ipsec_site_connections']): return True + return False def unset_updated_pending_status(self, process): process.updated_pending_status = False for connection_status in process.connection_status.values(): connection_status['updated_pending_status'] = False - def copy_process_status(self, process): + def copy_process_status(self, + process: BaseSwanProcess) -> ty.Dict[str, ty.Any]: return { 'id': process.vpnservice['id'], 'status': process.status, @@ -1053,7 +1078,7 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): 'ipsec_site_connections': copy.deepcopy(process.connection_status) } - def update_downed_connections(self, process_id, new_status): + def update_downed_connections(self, process_id: str, new_status): """Update info to be reported, if connections just went down. If there is no longer any information for a connection, because it @@ -1069,13 +1094,15 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): 'updated_pending_status': True } - def should_be_reported(self, context, process): + def should_be_reported(self, context: context.ContextBase, + process: BaseSwanProcess) -> bool: if (context.is_admin or process.vpnservice["tenant_id"] == context.tenant_id): return True + return False @log_helpers.log_method_call - def report_status(self, context): + def report_status(self, context: context.ContextBase): status_changed_vpn_services = [] for process_id, process in list(self.processes.items()): # NOTE(mnaser): It's not necessary to check status for processes @@ -1104,7 +1131,7 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): @log_helpers.log_method_call @lockutils.synchronized('vpn-agent', 'neutron-') - def sync(self, context, routers): + def sync(self, context: context.ContextBase, routers): """Sync status with server side. :param context: context object for RPC call @@ -1122,8 +1149,8 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): """ vpnservices = self.agent_rpc.get_vpn_services_on_host( context, self.host) - router_ids = [vpnservice['router_id'] for vpnservice in vpnservices] - sync_router_ids = [router['id'] for router in routers] + router_ids = [vpnservice["router_id"] for vpnservice in vpnservices] + sync_router_ids = [router["id"] for router in routers] self._sync_vpn_processes(vpnservices, sync_router_ids) self._delete_vpn_processes(sync_router_ids, router_ids) @@ -1138,8 +1165,10 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): for vpnservice in vpnservices: if vpnservice['router_id'] not in self.processes or ( vpnservice['router_id'] in sync_router_ids): - process = self.ensure_process(vpnservice['router_id'], - vpnservice=vpnservice) + process: ty.Optional[BaseSwanProcess] = self.ensure_process( + vpnservice['router_id'], vpnservice=vpnservice) + if not process: + return self._update_nat(vpnservice, self.add_nat_rule) router = self.routers.get(vpnservice['router_id']) if not router: @@ -1168,7 +1197,8 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta): class OpenSwanDriver(IPsecDriver): - def create_process(self, process_id, vpnservice, namespace): + def create_process(self, process_id: str, vpnservice, + namespace: str) -> BaseSwanProcess: return OpenSwanProcess( self.conf, process_id, diff --git a/neutron_vpnaas/services/vpn/device_drivers/libreswan_ipsec.py b/neutron_vpnaas/services/vpn/device_drivers/libreswan_ipsec.py index 90731f7a4..a535e2d5f 100644 --- a/neutron_vpnaas/services/vpn/device_drivers/libreswan_ipsec.py +++ b/neutron_vpnaas/services/vpn/device_drivers/libreswan_ipsec.py @@ -12,8 +12,9 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import typing as ty + import os -import os.path from neutron.agent.linux import ip_lib @@ -26,39 +27,38 @@ class LibreSwanProcess(ipsec.OpenSwanProcess): Libreswan needs nssdb initialised before running pluto daemon. """ # pylint: disable=useless-super-delegation - def __init__(self, conf, process_id, vpnservice, namespace): + def __init__(self, conf, process_id: str, vpnservice, namespace: str): self._rootwrap_cfg = self._get_rootwrap_config() - super(LibreSwanProcess, self).__init__(conf, process_id, - vpnservice, namespace) + super().__init__(conf, process_id, vpnservice, namespace) - def _ipsec_execute(self, cmd, check_exit_code=True, extra_ok_codes=None): + def _ipsec_execute(self, cmd: ty.List[str], check_exit_code: bool = True, + extra_ok_codes: ty.Optional[ty.List[int]] = None): """Execute ipsec command on namespace. This execute is wrapped by namespace wrapper. The namespace wrapper will bind /etc and /var/run """ ip_wrapper = ip_lib.IPWrapper(namespace=self.namespace) - mount_paths = {'/etc': '%s/etc' % self.config_dir, - '/var/run': '%s/var/run' % self.config_dir} + mount_paths = {'/etc': f'{self.config_dir}/etc', + '/var/run': f'{self.config_dir}/var/run'} mount_paths_str = ','.join( - "%s:%s" % (source, target) - for source, target in mount_paths.items()) + f"{source}:{target}" for source, target in mount_paths.items()) ns_wrapper = self.get_ns_wrapper() return ip_wrapper.netns.execute( [ns_wrapper, - '--mount_paths=%s' % mount_paths_str, - ('--rootwrap_config=%s' % self._rootwrap_cfg - if self._rootwrap_cfg else ''), - '--cmd=%s,%s' % (self.binary, ','.join(cmd))], + f'--mount_paths={mount_paths_str}', + '--rootwrap_config={0}'.format( + self._rootwrap_cfg if self._rootwrap_cfg else ""), + f'--cmd={self.binary},{",".join(cmd)}'], check_exit_code=check_exit_code, extra_ok_codes=extra_ok_codes) def _ensure_needed_files(self): # addconn reads from /etc/hosts and /etc/resolv.conf. As /etc would be # bind-mounted, create these two empty files in the target directory. - with open('%s/etc/hosts' % self.config_dir, 'a'): + with open(f'{self.config_dir}/etc/hosts', 'a', encoding="utf8"): pass - with open('%s/etc/resolv.conf' % self.config_dir, 'a'): + with open(f'{self.config_dir}/etc/resolv.conf', 'a', encoding="utf8"): pass def ensure_configs(self): @@ -75,17 +75,17 @@ class LibreSwanProcess(ipsec.OpenSwanProcess): if os.path.exists(secrets_file): self._execute(['rm', '-f', secrets_file]) - super(LibreSwanProcess, self).ensure_configs() + super().ensure_configs() # LibreSwan uses the capabilities library to restrict access to # ipsec.secrets to users that have explicit access. Since pluto is # running as root and the file has 0600 perms, we must set the # owner of the file to root. - self._execute(['chown', '--from=%s' % os.getuid(), 'root:root', + self._execute(['chown', f'--from={os.getuid()}', 'root:root', secrets_file]) # Libreswan needs to write logs to this directory. - self._execute(['chown', '--from=%s' % os.getuid(), 'root:root', + self._execute(['chown', f'--from={os.getuid()}', 'root:root', self.log_dir]) self._ensure_needed_files() @@ -131,11 +131,12 @@ class LibreSwanProcess(ipsec.OpenSwanProcess): ['whack', '--name', conn_name, '--asynchronous', '--initiate']) def terminate_connection(self, conn_name): - self._ipsec_execute(['whack', '--name', conn_name, '--terminate']) + self._ipsec_execute(["whack", "--name", conn_name, "--terminate"]) class LibreSwanDriver(ipsec.IPsecDriver): - def create_process(self, process_id, vpnservice, namespace): + def create_process(self, process_id: str, vpnservice, + namespace: str) -> ipsec.BaseSwanProcess: return LibreSwanProcess( self.conf, process_id, diff --git a/neutron_vpnaas/services/vpn/device_drivers/ovn_ipsec.py b/neutron_vpnaas/services/vpn/device_drivers/ovn_ipsec.py index aaaaac0ec..0785775e3 100644 --- a/neutron_vpnaas/services/vpn/device_drivers/ovn_ipsec.py +++ b/neutron_vpnaas/services/vpn/device_drivers/ovn_ipsec.py @@ -14,9 +14,15 @@ # License for the specific language governing permissions and limitations # under the License. + +import typing as ty + +import abc + import netaddr from neutron.agent.common import utils as agent_common_utils +from neutron.agent.linux import interface from neutron.agent.linux import ip_lib from neutron_lib import constants as lib_constants from neutron_lib import context as nctx @@ -30,7 +36,7 @@ from neutron_vpnaas.services.vpn.device_drivers import strongswan_ipsec PORT_PREFIX_INTERNAL = 'vr' PORT_PREFIX_EXTERNAL = 'vg' -PORT_PREFIXES = { +PORT_PREFIXES: ty.Dict[str, str] = { 'internal': PORT_PREFIX_INTERNAL, 'external': PORT_PREFIX_EXTERNAL, } @@ -38,7 +44,7 @@ PORT_PREFIXES = { LOG = logging.getLogger(__name__) -class DeviceManager(object): +class DeviceManager: """Device Manager for ports in qvpn-xx namespace. It is a veth pair, one side in qvpn and the other side is attached to ovs. @@ -51,16 +57,17 @@ class DeviceManager(object): self.host = host self.plugin = plugin self.context = context - self.driver = agent_common_utils.load_interface_driver(conf) + self.driver: interface.LinuxInterfaceDriver = \ + agent_common_utils.load_interface_driver(conf) - def get_interface_name(self, port, ptype): + def get_interface_name(self, port: ty.Dict[str, str], ptype: str) -> str: suffix = port['id'] return (PORT_PREFIXES[ptype] + suffix)[:self.driver.DEV_NAME_LEN] - def get_namespace_name(self, process_id): + def get_namespace_name(self, process_id: str): return self.OVN_NS_PREFIX + process_id - def get_existing_process_ids(self): + def get_existing_process_ids(self) -> ty.List: """Return the process IDs derived from the existing VPN namespaces.""" return [ns[len(self.OVN_NS_PREFIX):] for ns in ip_lib.list_network_namespaces() @@ -87,7 +94,8 @@ class DeviceManager(object): device.route.delete_route(cidr, via=via, metric=100, proto='static') - def list_routes(self, namespace, via=None): + def list_routes(self, namespace, + via=None) -> ty.List[ty.Dict[str, ty.Any]]: device = ip_lib.IPDevice(None, namespace=namespace) return device.route.list_routes( lib_constants.IP_VERSION_4, proto='static', via=via) @@ -100,24 +108,26 @@ class DeviceManager(object): for r in routes: device.route.delete_route(r['cidr'], via=r['via']) - def _del_port(self, process_id, ptype): + def _del_port(self, process_id: str, ptype: str): namespace = self.get_namespace_name(process_id) prefix = PORT_PREFIXES[ptype] device = ip_lib.IPDevice(None, namespace=namespace) - ports = device.addr.list() + ports: ty.List[ty.Dict[str, ty.Union[str, ty.Any]]] = \ + device.addr.list() for p in ports: if not p['name'].startswith(prefix): continue interface_name = p['name'] self.driver.unplug(interface_name, namespace=namespace) - def del_internal_port(self, process_id): + def del_internal_port(self, process_id: str): self._del_port(process_id, 'internal') - def del_external_port(self, process_id): + def del_external_port(self, process_id: str): self._del_port(process_id, 'external') - def setup_external(self, process_id, network_details): + def setup_external(self, process_id: str, + network_details) -> ty.Optional[str]: network = network_details["external_network"] vpn_port = network_details['gw_port'] ns_name = self.get_namespace_name(process_id) @@ -143,7 +153,7 @@ class DeviceManager(object): subnet_id = fixed_ip['subnet_id'] subnet = self.plugin.get_subnet_info(subnet_id) net = netaddr.IPNetwork(subnet['cidr']) - ip_cidr = '%s/%s' % (fixed_ip['ip_address'], net.prefixlen) + ip_cidr = f'{fixed_ip["ip_address"]}/{net.prefixlen}' ip_cidrs.append(ip_cidr) subnets.append(subnet) self.driver.init_l3(interface_name, ip_cidrs, @@ -152,7 +162,7 @@ class DeviceManager(object): self.set_default_route(ns_name, subnet, interface_name) return interface_name - def setup_internal(self, process_id, network_details): + def setup_internal(self, process_id, network_details) -> ty.Optional[str]: vpn_port = network_details["transit_port"] ns_name = self.get_namespace_name(process_id) interface_name = self.get_interface_name(vpn_port, 'internal') @@ -172,19 +182,19 @@ class DeviceManager(object): ip_cidrs = [] for fixed_ip in vpn_port['fixed_ips']: - ip_cidr = '%s/%s' % (fixed_ip['ip_address'], 28) + ip_cidr = f'{fixed_ip["ip_address"]}/28' ip_cidrs.append(ip_cidr) self.driver.init_l3(interface_name, ip_cidrs, namespace=ns_name) return interface_name -class NamespaceManager(object): +class NamespaceManager: def __init__(self, use_ipv6=False): self.ip_wrapper_root = ip_lib.IPWrapper() self.use_ipv6 = use_ipv6 - def exists(self, name): + def exists(self, name) -> bool: return ip_lib.network_namespace_exists(name) def create(self, name): @@ -231,9 +241,9 @@ class IPsecOvnDriverApi(ipsec.IPsecVpnDriverApi): subnet_id=subnet_id) -class OvnIPsecDriver(ipsec.IPsecDriver): +class OvnIPsecDriver(ipsec.IPsecDriver, metaclass=abc.ABCMeta): - def __init__(self, vpn_service, host): + def __init__(self, vpn_service, host: str): self.nsmgr = NamespaceManager() super().__init__(vpn_service, host) self.agent_rpc = IPsecOvnDriverApi(topics.IPSEC_DRIVER_TOPIC) @@ -242,7 +252,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver): get_router_based_iptables_manager = None - def get_namespace(self, router_id): + def get_namespace(self, router_id) -> str: """Get namespace for VPN services of router. :router_id: router_id @@ -250,7 +260,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver): """ return self.devmgr.get_namespace_name(router_id) - def _cleanup_namespace(self, router_id): + def _cleanup_namespace(self, router_id: str): ns_name = self.devmgr.get_namespace_name(router_id) if not self.nsmgr.exists(ns_name): return @@ -259,7 +269,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver): self.devmgr.del_external_port(router_id) self.nsmgr.delete(ns_name) - def _ensure_namespace(self, router_id, network_details): + def _ensure_namespace(self, router_id: str, network_details) -> str: ns_name = self.get_namespace(router_id) if not self.nsmgr.exists(ns_name): self.nsmgr.create(ns_name) @@ -272,7 +282,12 @@ class OvnIPsecDriver(ipsec.IPsecDriver): return ns_name - def destroy_process(self, process_id): + @abc.abstractmethod + def create_process(self, process_id: str, + vpnservice, namespace) -> ipsec.BaseSwanProcess: + pass + + def destroy_process(self, process_id: str): LOG.info('process %s is destroyed', process_id) namespace = self.devmgr.get_namespace_name(process_id) @@ -316,7 +331,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver): new_local_cidrs - old_local_cidrs, gateway_ip) - def _sync_vpn_processes(self, vpnservices, sync_router_ids): + def _sync_vpn_processes(self, vpnservices, sync_router_ids: ty.List[str]): # Ensure the ipsec process is enabled only for # - the vpn services which are not yet in self.processes # - vpn services whose router id is in 'sync_router_ids' @@ -331,7 +346,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver): process = self.ensure_process(router_id, vpnservice=vpnservice) process.update() - def _cleanup_stale_vpn_processes(self, vpn_router_ids): + def _cleanup_stale_vpn_processes(self, vpn_router_ids: ty.List[str]): super()._cleanup_stale_vpn_processes(vpn_router_ids) # Look for additional namespaces on this node that we don't know # and that should be deleted @@ -340,17 +355,20 @@ class OvnIPsecDriver(ipsec.IPsecDriver): self.destroy_process(router_id) @lockutils.synchronized('vpn-agent', 'neutron-') - def vpnservice_removed_from_agent(self, context, router_id): + def vpnservice_removed_from_agent(self, context: nctx.Context, + router_id: str): # must run under the same lock as sync() self.destroy_process(router_id) - def vpnservice_added_to_agent(self, context, router_ids): + def vpnservice_added_to_agent(self, context: nctx.Context, + router_ids: ty.List[str]): routers = [{'id': router_id} for router_id in router_ids] self.sync(context, routers) class OvnStrongSwanDriver(OvnIPsecDriver): - def create_process(self, process_id, vpnservice, namespace): + def create_process(self, process_id: str, vpnservice, + namespace: str) -> ipsec.BaseSwanProcess: return OvnStrongSwanProcess( self.conf, process_id, @@ -359,7 +377,8 @@ class OvnStrongSwanDriver(OvnIPsecDriver): class OvnOpenSwanDriver(OvnIPsecDriver): - def create_process(self, process_id, vpnservice, namespace): + def create_process(self, process_id: str, vpnservice, + namespace: str) -> ipsec.BaseSwanProcess: return OvnOpenSwanProcess( self.conf, process_id, @@ -368,7 +387,8 @@ class OvnOpenSwanDriver(OvnIPsecDriver): class OvnLibreSwanDriver(OvnIPsecDriver): - def create_process(self, process_id, vpnservice, namespace): + def create_process(self, process_id: str, vpnservice, + namespace: str) -> ipsec.BaseSwanProcess: return OvnLibreSwanProcess( self.conf, process_id, diff --git a/neutron_vpnaas/services/vpn/device_drivers/strongswan_ipsec.py b/neutron_vpnaas/services/vpn/device_drivers/strongswan_ipsec.py index 708952a1f..17c3dffc6 100644 --- a/neutron_vpnaas/services/vpn/device_drivers/strongswan_ipsec.py +++ b/neutron_vpnaas/services/vpn/device_drivers/strongswan_ipsec.py @@ -12,8 +12,8 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. - import os +import typing as ty from oslo_config import cfg from oslo_log import log as logging @@ -75,15 +75,14 @@ class StrongSwanProcess(ipsec.BaseSwanProcess): STATUS_RE = r'([a-f0-9\-]+).* (ROUTED|CONNECTING|INSTALLED)' STATUS_NOT_RUNNING_RE = 'Command:.*ipsec.*status.*Exit code: [1|3] ' - def __init__(self, conf, process_id, vpnservice, namespace): + def __init__(self, conf, process_id: str, vpnservice, namespace: str): self.DIALECT_MAP['v1'] = 'ikev1' self.DIALECT_MAP['v2'] = 'ikev2' self.DIALECT_MAP['sha256'] = 'sha256' self._strongswan_piddir = self._get_strongswan_piddir() self._rootwrap_cfg = self._get_rootwrap_config() LOG.debug("strongswan piddir is '%s'", (self._strongswan_piddir)) - super(StrongSwanProcess, self).__init__(conf, process_id, - vpnservice, namespace) + super().__init__(conf, process_id, vpnservice, namespace) def _get_strongswan_piddir(self): return utils.execute( @@ -103,7 +102,8 @@ class StrongSwanProcess(ipsec.BaseSwanProcess): return connection_id, status return None, None - def _execute(self, cmd, check_exit_code=True, extra_ok_codes=None): + def _execute(self, cmd: ty.List[str], check_exit_code: bool = True, + extra_ok_codes: ty.Optional[ty.List[int]] = None): """Execute command on namespace. This execute is wrapped by namespace wrapper. @@ -113,11 +113,11 @@ class StrongSwanProcess(ipsec.BaseSwanProcess): ns_wrapper = self.get_ns_wrapper() return ip_wrapper.netns.execute( [ns_wrapper, - '--mount_paths=/etc:%s/etc,%s:%s/var/run' % ( + '--mount_paths=/etc:{0}/etc,{1}:{2}/var/run'.format( self.config_dir, self._strongswan_piddir, self.config_dir), - ('--rootwrap_config=%s' % self._rootwrap_cfg - if self._rootwrap_cfg else ''), - '--cmd=%s' % ','.join(cmd)], + '--rootwrap_config={0}'.format( + self._rootwrap_cfg if self._rootwrap_cfg else ""), + f'--cmd={",".join(cmd)}'], check_exit_code=check_exit_code, extra_ok_codes=extra_ok_codes) diff --git a/neutron_vpnaas/services/vpn/ovn/agent_monitor.py b/neutron_vpnaas/services/vpn/ovn/agent_monitor.py index 6597f29b0..5e5dc925b 100644 --- a/neutron_vpnaas/services/vpn/ovn/agent_monitor.py +++ b/neutron_vpnaas/services/vpn/ovn/agent_monitor.py @@ -14,6 +14,7 @@ from neutron.plugins.ml2.drivers.ovn.agent import neutron_agent from neutron.plugins.ml2.drivers.ovn.mech_driver.ovsdb import ovsdb_monitor +from neutron.services.ovn_l3 import plugin from neutron_lib.plugins import constants as plugin_constants from neutron_lib.plugins import directory @@ -73,9 +74,10 @@ class ChassisVPNAgentWriteEvent(ovsdb_monitor.ChassisAgentEvent): clear_down=True) -class OVNVPNAgentMonitor(object): +class OVNVPNAgentMonitor: def watch_agent_events(self): - l3_plugin = directory.get_plugin(plugin_constants.L3) + l3_plugin: plugin.OVNL3RouterPlugin = \ + directory.get_plugin(plugin_constants.L3) sb_ovn = l3_plugin._sb_ovn if sb_ovn: idl = sb_ovn.ovsdb_connection.idl diff --git a/neutron_vpnaas/services/vpn/ovn_agent.py b/neutron_vpnaas/services/vpn/ovn_agent.py index 1ec331a39..d7de90506 100644 --- a/neutron_vpnaas/services/vpn/ovn_agent.py +++ b/neutron_vpnaas/services/vpn/ovn_agent.py @@ -51,7 +51,7 @@ OVS_OPTS = [ ] -def register_opts(conf): +def register_opts(conf: cfg.ConfigOpts): common_config.register_common_config_options() agent_config.register_interface_driver_opts_helper(conf) agent_config.register_interface_opts(conf) diff --git a/neutron_vpnaas/services/vpn/ovn_plugin.py b/neutron_vpnaas/services/vpn/ovn_plugin.py index 5b2a6ea23..d3f34bafd 100644 --- a/neutron_vpnaas/services/vpn/ovn_plugin.py +++ b/neutron_vpnaas/services/vpn/ovn_plugin.py @@ -13,24 +13,25 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import typing as ty from neutron_lib.callbacks import events from neutron_lib.callbacks import registry from neutron_lib.callbacks import resources +from neutron_lib import context from oslo_config import cfg from oslo_utils import importutils from neutron_vpnaas.api.rpc.agentnotifiers import vpn_rpc_agent_api as nfy_api from neutron_vpnaas.db.vpn import vpn_agentschedulers_db as agent_db -from neutron_vpnaas.db.vpn.vpn_db import VPNPluginDb from neutron_vpnaas.db.vpn import vpn_ext_gw_db from neutron_vpnaas.services.vpn.common import constants from neutron_vpnaas.services.vpn.ovn import agent_monitor -from neutron_vpnaas.services.vpn.plugin import VPNDriverPlugin +from neutron_vpnaas.services.vpn import plugin as vpn_plugin +from neutron_vpnaas.services.vpn.service_drivers import ovn_ipsec -class VPNOVNPlugin(VPNPluginDb, - vpn_ext_gw_db.VPNExtGWPlugin_db, +class VPNOVNPlugin(vpn_ext_gw_db.VPNExtGWPlugin_db, agent_db.AZVPNAgentSchedulerDbMixin, agent_monitor.OVNVPNAgentMonitor): """Implementation of the VPN Service Plugin. @@ -50,13 +51,14 @@ class VPNOVNPlugin(VPNPluginDb, resources.PROCESS, events.AFTER_INIT) - def check_router_in_use(self, context, router_id): + def check_router_in_use(self, context: context.Context, router_id): pass def post_fork_initialize(self, resource, event, trigger, payload=None): self.watch_agent_events() - def vpn_router_agent_binding_changed(self, context, router_id, host): + def vpn_router_agent_binding_changed(self, context: context.Context, + router_id: str, host: str): pass supported_extension_aliases = ["vpnaas", @@ -66,11 +68,15 @@ class VPNOVNPlugin(VPNPluginDb, path_prefix = "/vpn" -class VPNOVNDriverPlugin(VPNOVNPlugin, VPNDriverPlugin): - def vpn_router_agent_binding_changed(self, context, router_id, host): +class VPNOVNDriverPlugin(VPNOVNPlugin, vpn_plugin.VPNDriverPlugin): + def vpn_router_agent_binding_changed(self, context: context.Context, + router_id: str, host: str): super().vpn_router_agent_binding_changed(context, router_id, host) filters = {'router_id': [router_id]} vpnservices = self.get_vpnservices(context, filters=filters) for vpnservice in vpnservices: - driver = self._get_driver_for_vpnservice(context, vpnservice) - driver.update_port_bindings(context, router_id, host) + driver: ty.Optional[ovn_ipsec.BaseOvnIPsecVPNDriver] = \ + self._get_driver_for_vpnservice( # type: ignore + context, vpnservice) + if driver: + driver.update_port_bindings(context, router_id, host) diff --git a/neutron_vpnaas/services/vpn/plugin.py b/neutron_vpnaas/services/vpn/plugin.py index 6fc9acc84..415b6554e 100644 --- a/neutron_vpnaas/services/vpn/plugin.py +++ b/neutron_vpnaas/services/vpn/plugin.py @@ -1,4 +1,3 @@ - # (c) Copyright 2013 Hewlett-Packard Development Company, L.P. # All Rights Reserved. # @@ -14,6 +13,9 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty + +from neutron.db import flavors_db from neutron.db import servicetype_db as st_db from neutron.services import provider_configuration as pconf from neutron.services import service_base @@ -26,6 +28,7 @@ from oslo_log import log as logging from neutron_vpnaas.db.vpn import vpn_db from neutron_vpnaas.extensions import vpn_flavors +from neutron_vpnaas.services.vpn import service_drivers LOG = logging.getLogger(__name__) @@ -54,8 +57,11 @@ class VPNPlugin(vpn_db.VPNPluginDb): class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin): """VpnPlugin which supports VPN Service Drivers.""" #TODO(nati) handle ikepolicy and ipsecpolicy update usecase + drivers: ty.Dict[str, service_drivers.VpnDriver] + default_provider: ty.Optional[str] + def __init__(self): - super(VPNDriverPlugin, self).__init__() + super().__init__() self.service_type_manager = st_db.ServiceTypeManager.get_instance() add_provider_configuration(self.service_type_manager, constants.VPN) # Load the service driver from neutron.conf. @@ -72,14 +78,16 @@ class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin): vpn_db.subscribe() @property - def _flavors_plugin(self): + def _flavors_plugin(self) -> flavors_db.FlavorsDbMixin: return directory.get_plugin(constants.FLAVORS) def start_rpc_listeners(self): servers = [] for driver_name, driver in self.drivers.items(): - if hasattr(driver, 'start_rpc_listeners'): - servers.extend(driver.start_rpc_listeners()) + start_rpc_listeners: ty.Optional[ty.Callable[..., ty.List]] = \ + getattr(driver, 'start_rpc_listeners', None) + if start_rpc_listeners and callable(start_rpc_listeners): + servers.extend(start_rpc_listeners()) return servers def _check_orphan_vpnservice_associations(self): @@ -124,7 +132,9 @@ class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin): context, constants.VPN, self.default_provider, vpnservice_id) - def _get_provider_for_flavor(self, context, flavor_id): + def _get_provider_for_flavor( + self, context: ncontext.Context, + flavor_id: ty.Optional[str]) -> ty.Optional[str]: if flavor_id: if self._flavors_plugin is None: raise vpn_flavors.FlavorsPluginNotLoaded() @@ -137,7 +147,7 @@ class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin): raise flav_exc.FlavorDisabled() providers = self._flavors_plugin.get_flavor_next_provider( context, fl_db['id']) - provider = providers[0].get('provider') + provider: ty.Optional[str] = providers[0].get('provider', None) if provider not in self.drivers: raise vpn_flavors.NoProviderFoundForFlavor(flavor_id=flavor_id) else: @@ -147,84 +157,98 @@ class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin): LOG.debug("Selected provider %s", provider) return provider - def _get_driver_for_vpnservice(self, context, vpnservice): + def _get_driver_for_vpnservice(self, context: ncontext.Context, + vpnservice) -> \ + ty.Optional[service_drivers.VpnDriver]: stm = self.service_type_manager - provider_names = stm.get_provider_names_by_resource_ids( - context, [vpnservice['id']]) + provider_names: ty.Dict[str, str] = \ + stm.get_provider_names_by_resource_ids(context, [vpnservice['id']]) provider = provider_names.get(vpnservice['id']) - return self.drivers[provider] + return self.drivers.get(provider) if provider else None - def _get_driver_for_ipsec_site_connection(self, context, + def _get_driver_for_ipsec_site_connection(self, context: ncontext.Context, ipsec_site_connection): # Only vpnservice_id is required as the vpnservice should be already # associated with a provider after its creation. vpnservice = {'id': ipsec_site_connection['vpnservice_id']} return self._get_driver_for_vpnservice(context, vpnservice) - def create_ipsec_site_connection(self, context, ipsec_site_connection): + def create_ipsec_site_connection(self, + context: ncontext.Context, + ipsec_site_connection) -> ty.Optional[ty.Dict[ty.Any, ty.Any]]: driver = self._get_driver_for_ipsec_site_connection( context, ipsec_site_connection['ipsec_site_connection']) - driver.validator.validate_ipsec_site_connection( - context, - ipsec_site_connection['ipsec_site_connection']) - ipsec_site_connection = super( - VPNDriverPlugin, self).create_ipsec_site_connection( - context, ipsec_site_connection) - driver.create_ipsec_site_connection(context, ipsec_site_connection) - return ipsec_site_connection + if driver: + driver.validator.validate_ipsec_site_connection( + context, + ipsec_site_connection['ipsec_site_connection']) + ipsec_site_connection = super().create_ipsec_site_connection( + context, ipsec_site_connection) + driver.create_ipsec_site_connection(context, ipsec_site_connection) + return ipsec_site_connection + return None - def delete_ipsec_site_connection(self, context, ipsec_conn_id): + def delete_ipsec_site_connection(self, context: ncontext.Context, + ipsec_site_conn_id: str): ipsec_site_connection = self.get_ipsec_site_connection( - context, ipsec_conn_id) - super(VPNDriverPlugin, self).delete_ipsec_site_connection( - context, ipsec_conn_id) + context, ipsec_site_conn_id) + super().delete_ipsec_site_connection( + context, ipsec_site_conn_id) driver = self._get_driver_for_ipsec_site_connection( context, ipsec_site_connection) - driver.delete_ipsec_site_connection(context, ipsec_site_connection) + if driver: + driver.delete_ipsec_site_connection(context, ipsec_site_connection) def update_ipsec_site_connection( - self, context, - ipsec_conn_id, ipsec_site_connection): + self, context: ncontext.Context, + ipsec_site_conn_id: str, + ipsec_site_connection) -> ty.Optional[ty.Dict[ty.Any, ty.Any]]: old_ipsec_site_connection = self.get_ipsec_site_connection( - context, ipsec_conn_id) + context, ipsec_site_conn_id) driver = self._get_driver_for_ipsec_site_connection( context, old_ipsec_site_connection) - driver.validator.validate_ipsec_site_connection( - context, - ipsec_site_connection['ipsec_site_connection']) - ipsec_site_connection = super( - VPNDriverPlugin, self).update_ipsec_site_connection( + if driver: + driver.validator.validate_ipsec_site_connection( context, - ipsec_conn_id, - ipsec_site_connection) - driver.update_ipsec_site_connection( - context, old_ipsec_site_connection, ipsec_site_connection) - return ipsec_site_connection + ipsec_site_connection['ipsec_site_connection']) + ipsec_site_connection = super().update_ipsec_site_connection( + context, + ipsec_site_conn_id, + ipsec_site_connection) + driver.update_ipsec_site_connection( + context, old_ipsec_site_connection, ipsec_site_connection) + return ipsec_site_connection + return None - def create_vpnservice(self, context, vpnservice): + def create_vpnservice(self, context: ncontext.Context, + vpnservice: ty.Dict[str, ty.Dict[str, ty.Any]]) -> \ + ty.Optional[ty.Dict[str, ty.Any]]: provider = self._get_provider_for_flavor( context, vpnservice['vpnservice'].get('flavor_id')) - vpnservice = super( - VPNDriverPlugin, self).create_vpnservice(context, vpnservice) - self.service_type_manager.add_resource_association( - context, constants.VPN, provider, vpnservice['id']) - driver = self.drivers[provider] - driver.create_vpnservice(context, vpnservice) - return vpnservice + if provider: + vpnservice = super().create_vpnservice(context, vpnservice) + self.service_type_manager.add_resource_association( + context, constants.VPN, provider, vpnservice['id']) + driver = self.drivers[provider] + driver.create_vpnservice(context, vpnservice) + return vpnservice + return None - def update_vpnservice(self, context, vpnservice_id, vpnservice): + def update_vpnservice(self, context: ncontext.Context, vpnservice_id: str, + vpnservice) -> ty.Dict[str, ty.Any]: old_vpn_service = self.get_vpnservice(context, vpnservice_id) - new_vpn_service = super( - VPNDriverPlugin, self).update_vpnservice(context, vpnservice_id, - vpnservice) + new_vpn_service = super().update_vpnservice(context, vpnservice_id, + vpnservice) driver = self._get_driver_for_vpnservice(context, old_vpn_service) - driver.update_vpnservice(context, old_vpn_service, new_vpn_service) + if driver: + driver.update_vpnservice(context, old_vpn_service, new_vpn_service) return new_vpn_service - def delete_vpnservice(self, context, vpnservice_id): + def delete_vpnservice(self, context: ncontext.Context, vpnservice_id: str): vpnservice = self._get_vpnservice(context, vpnservice_id) - super(VPNDriverPlugin, self).delete_vpnservice(context, vpnservice_id) + super().delete_vpnservice(context, vpnservice_id) driver = self._get_driver_for_vpnservice(context, vpnservice) - self.service_type_manager.del_resource_associations( - context, [vpnservice_id]) - driver.delete_vpnservice(context, vpnservice) + if driver: + self.service_type_manager.del_resource_associations( + context, [vpnservice_id]) + driver.delete_vpnservice(context, vpnservice) diff --git a/neutron_vpnaas/services/vpn/service_drivers/__init__.py b/neutron_vpnaas/services/vpn/service_drivers/__init__.py index f9c90012e..ec4b991b8 100644 --- a/neutron_vpnaas/services/vpn/service_drivers/__init__.py +++ b/neutron_vpnaas/services/vpn/service_drivers/__init__.py @@ -14,21 +14,25 @@ # under the License. import abc +import typing as ty +from neutron.extensions import l3 +from neutron_lib import context from neutron_lib.plugins import constants from neutron_lib.plugins import directory from neutron_lib import rpc as n_rpc from oslo_log import log as logging import oslo_messaging +from neutron_vpnaas.db.vpn import vpn_db from neutron_vpnaas.services.vpn.service_drivers import driver_validator LOG = logging.getLogger(__name__) -class VpnDriver(object, metaclass=abc.ABCMeta): +class VpnDriver(metaclass=abc.ABCMeta): - def __init__(self, service_plugin, validator=None): + def __init__(self, service_plugin: vpn_db.VPNPluginDb, validator=None): self.service_plugin = service_plugin if validator is None: validator = driver_validator.VpnDriverValidator(self) @@ -36,7 +40,7 @@ class VpnDriver(object, metaclass=abc.ABCMeta): self.name = '' @property - def l3_plugin(self): + def l3_plugin(self) -> l3.RouterPluginBase: return directory.get_plugin(constants.L3) @property @@ -44,43 +48,49 @@ class VpnDriver(object, metaclass=abc.ABCMeta): pass @abc.abstractmethod - def create_vpnservice(self, context, vpnservice): + def create_vpnservice(self, context: context.ContextBase, vpnservice): pass @abc.abstractmethod def update_vpnservice( - self, context, old_vpnservice, vpnservice): + self, context: context.ContextBase, old_vpnservice, vpnservice): pass @abc.abstractmethod - def delete_vpnservice(self, context, vpnservice): + def delete_vpnservice(self, context: context.ContextBase, vpnservice): pass @abc.abstractmethod - def create_ipsec_site_connection(self, context, ipsec_site_connection): - pass - - @abc.abstractmethod - def update_ipsec_site_connection(self, context, old_ipsec_site_connection, + def create_ipsec_site_connection(self, context: context.ContextBase, ipsec_site_connection): pass @abc.abstractmethod - def delete_ipsec_site_connection(self, context, ipsec_site_connection): + def update_ipsec_site_connection(self, context: context.ContextBase, + old_ipsec_site_connection, + ipsec_site_connection): + pass + + @abc.abstractmethod + def delete_ipsec_site_connection(self, context: context.ContextBase, + ipsec_site_connection): pass -class BaseIPsecVpnAgentApi(object): +class BaseIPsecVpnAgentApi: """Base class for IPSec API to agent.""" - def __init__(self, topic, default_version, driver): + def __init__(self, topic: str, default_version: str, + driver: VpnDriver): self.topic = topic self.driver = driver - target = oslo_messaging.Target(topic=topic, version=default_version) - self.client = n_rpc.get_client(target) + self.target = oslo_messaging.Target(topic=topic, + version=default_version) + self.client = n_rpc.get_client(self.target) - def _agent_notification(self, context, method, router_id, - version=None, **kwargs): + def _agent_notification(self, context: context.ContextBase, method, + router_id: str, + version: ty.Optional[str] = None, **kwargs): """Notify update for the agent. This method will find where is the router, and @@ -103,7 +113,8 @@ class BaseIPsecVpnAgentApi(object): cctxt = self.client.prepare(server=l3_agent.host, version=version) cctxt.cast(context, method, **kwargs) - def vpnservice_updated(self, context, router_id, **kwargs): + def vpnservice_updated(self, context: context.ContextBase, router_id: str, + **kwargs): """Send update event of vpnservices.""" kwargs['router'] = {'id': router_id} self._agent_notification(context, 'vpnservice_updated', router_id, diff --git a/neutron_vpnaas/services/vpn/service_drivers/base_ipsec.py b/neutron_vpnaas/services/vpn/service_drivers/base_ipsec.py index 13ef9910d..a96e3880c 100644 --- a/neutron_vpnaas/services/vpn/service_drivers/base_ipsec.py +++ b/neutron_vpnaas/services/vpn/service_drivers/base_ipsec.py @@ -12,17 +12,23 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + +import typing as ty + import abc import netaddr -import oslo_messaging - +from neutron.db import agents_db from neutron.db.models import l3agent from neutron.db.models import servicetype +from neutron.objects import agent as nagent from neutron_lib import constants as lib_constants +from neutron_lib import context from neutron_lib.db import api as db_api from neutron_lib.plugins import directory +import oslo_messaging +from neutron_vpnaas.db.vpn import vpn_db from neutron_vpnaas.db.vpn import vpn_models from neutron_vpnaas.services.vpn import service_drivers @@ -31,7 +37,7 @@ IPSEC = 'ipsec' BASE_IPSEC_VERSION = '1.0' -class IPsecVpnDriverCallBack(object): +class IPsecVpnDriverCallBack: """Callback for IPSecVpnDriver rpc.""" # history @@ -40,12 +46,13 @@ class IPsecVpnDriverCallBack(object): target = oslo_messaging.Target(version=BASE_IPSEC_VERSION) def __init__(self, driver): - super(IPsecVpnDriverCallBack, self).__init__() + super().__init__() self.driver = driver - def _get_agent_hosting_vpn_services(self, context, host): - plugin = directory.get_plugin() - agent = plugin._get_agent_by_type_and_host( + def _get_agent_hosting_vpn_services(self, context: context.Context, + host: ty.Optional[str]): + plugin: agents_db.AgentDbMixin = directory.get_plugin() + agent: ty.Optional[nagent.Agent] = plugin._get_agent_by_type_and_host( context, lib_constants.AGENT_TYPE_L3, host) agent_conf = plugin.get_configuration_dict(agent) # Retrieve the agent_mode to check if this is the @@ -53,7 +60,9 @@ class IPsecVpnDriverCallBack(object): # case of distributed the vpn service should reside # only on a dvr_snat node. agent_mode = agent_conf.get('agent_mode', 'legacy') - if not agent.admin_state_up or agent_mode == 'dvr': + if (not agent and + not agent.admin_state_up or # type: ignore + agent_mode == 'dvr'): return [] query = context.session.query(vpn_models.VPNService) query = query.join(vpn_models.IPsecSiteConnection) @@ -65,25 +74,27 @@ class IPsecVpnDriverCallBack(object): servicetype.ProviderResourceAssociation.resource_id == vpn_models.VPNService.id) query = query.filter( - l3agent.RouterL3AgentBinding.l3_agent_id == agent.id) + l3agent.RouterL3AgentBinding.l3_agent_id == + agent.id) # type: ignore query = query.filter( servicetype.ProviderResourceAssociation.provider_name == self.driver.name) return query @db_api.CONTEXT_READER - def get_vpn_services_on_host(self, context, host=None): + def get_vpn_services_on_host(self, context: context.Context, + host: ty.Optional[str] = None): """Returns the vpnservices on the host.""" vpnservices = self._get_agent_hosting_vpn_services( context, host) - plugin = self.driver.service_plugin + plugin: vpn_db.VPNPluginRpcDbMixin = self.driver.service_plugin local_cidr_map = plugin._build_local_subnet_cidr_map(context) return [self.driver.make_vpnservice_dict(vpnservice, local_cidr_map) - for vpnservice in vpnservices] + for vpnservice in vpnservices] - def update_status(self, context, status): + def update_status(self, context: context.Context, status): """Update status of vpnservices.""" - plugin = self.driver.service_plugin + plugin: vpn_db.VPNPluginRpcDbMixin = self.driver.service_plugin plugin.update_status_by_agent(context, status) @@ -94,57 +105,63 @@ class IPsecVpnAgentApi(service_drivers.BaseIPsecVpnAgentApi): # pylint: disable=useless-super-delegation def __init__(self, topic, default_version, driver): - super(IPsecVpnAgentApi, self).__init__( - topic, default_version, driver) + super().__init__(topic, default_version, driver) class BaseIPsecVPNDriver(service_drivers.VpnDriver, metaclass=abc.ABCMeta): """Base VPN Service Driver class.""" - def __init__(self, service_plugin, validator=None): - super(BaseIPsecVPNDriver, self).__init__(service_plugin, validator) + agent_rpc: service_drivers.BaseIPsecVpnAgentApi + + def __init__(self, service_plugin: vpn_db.VPNPluginDb, validator=None): + super().__init__(service_plugin, validator) self.create_rpc_conn() @property - def service_type(self): + def service_type(self) -> str: return IPSEC @abc.abstractmethod - def create_rpc_conn(self): + def create_rpc_conn(self) -> None: pass - def create_ipsec_site_connection(self, context, ipsec_site_connection): + def create_ipsec_site_connection(self, context: context.Context, + ipsec_site_connection): router_id = self.service_plugin.get_vpnservice_router_id( context, ipsec_site_connection['vpnservice_id']) self.agent_rpc.vpnservice_updated(context, router_id) - def update_ipsec_site_connection( - self, context, old_ipsec_site_connection, ipsec_site_connection): + def update_ipsec_site_connection(self, context: context.Context, + old_ipsec_site_connection, + ipsec_site_connection): router_id = self.service_plugin.get_vpnservice_router_id( context, ipsec_site_connection['vpnservice_id']) self.agent_rpc.vpnservice_updated(context, router_id) - def delete_ipsec_site_connection(self, context, ipsec_site_connection): + def delete_ipsec_site_connection(self, context: context.Context, + ipsec_site_connection): router_id = self.service_plugin.get_vpnservice_router_id( context, ipsec_site_connection['vpnservice_id']) self.agent_rpc.vpnservice_updated(context, router_id) - def create_ikepolicy(self, context, ikepolicy): + def create_ikepolicy(self, context: context.Context, ikepolicy): pass - def delete_ikepolicy(self, context, ikepolicy): + def delete_ikepolicy(self, context: context.Context, ikepolicy): pass - def update_ikepolicy(self, context, old_ikepolicy, ikepolicy): + def update_ikepolicy(self, context: context.Context, + old_ikepolicy, ikepolicy): pass - def create_ipsecpolicy(self, context, ipsecpolicy): + def create_ipsecpolicy(self, context: context.Context, ipsecpolicy): pass - def delete_ipsecpolicy(self, context, ipsecpolicy): + def delete_ipsecpolicy(self, context: context.Context, ipsecpolicy): pass - def update_ipsecpolicy(self, context, old_ipsec_policy, ipsecpolicy): + def update_ipsecpolicy(self, context: context.Context, + old_ipsec_policy, ipsecpolicy): pass def _get_gateway_ips(self, router): @@ -164,7 +181,7 @@ class BaseIPsecVPNDriver(service_drivers.VpnDriver, metaclass=abc.ABCMeta): return v4_ip, v6_ip @db_api.CONTEXT_WRITER - def create_vpnservice(self, context, vpnservice_dict): + def create_vpnservice(self, context: context.Context, vpnservice_dict): """Get the gateway IP(s) and save for later use. For the reference implementation, this side's tunnel IP (external_ip) @@ -181,10 +198,11 @@ class BaseIPsecVPNDriver(service_drivers.VpnDriver, metaclass=abc.ABCMeta): vpnservice_dict['id'], v4_ip=v4_ip, v6_ip=v6_ip) - def update_vpnservice(self, context, old_vpnservice, vpnservice): + def update_vpnservice(self, context: context.Context, + old_vpnservice, vpnservice): self.agent_rpc.vpnservice_updated(context, vpnservice['router_id']) - def delete_vpnservice(self, context, vpnservice): + def delete_vpnservice(self, context: context.Context, vpnservice): self.agent_rpc.vpnservice_updated(context, vpnservice['router_id']) def get_external_ip_based_on_peer(self, vpnservice, ipsec_site_con): @@ -204,7 +222,7 @@ class BaseIPsecVPNDriver(service_drivers.VpnDriver, metaclass=abc.ABCMeta): also converting parameter name for vpn agent driver """ - vpnservice_dict = dict(vpnservice) + vpnservice_dict: ty.Dict[str, ty.Any] = dict(vpnservice) # Populate tenant_id for RPC compat vpnservice_dict['tenant_id'] = vpnservice_dict['project_id'] vpnservice_dict['ipsec_site_connections'] = [] diff --git a/neutron_vpnaas/services/vpn/service_drivers/driver_validator.py b/neutron_vpnaas/services/vpn/service_drivers/driver_validator.py index 0486d197f..f87636669 100644 --- a/neutron_vpnaas/services/vpn/service_drivers/driver_validator.py +++ b/neutron_vpnaas/services/vpn/service_drivers/driver_validator.py @@ -12,18 +12,25 @@ # License for the specific language governing permissions and limitations # under the License. # +import typing as ty + +from neutron.extensions import l3 +from neutron_lib import context +if ty.TYPE_CHECKING: + from neutron_vpnaas.services.vpn import service_drivers -class VpnDriverValidator(object): +class VpnDriverValidator: """Driver-specific validation routines for VPN resources.""" - def __init__(self, driver): + def __init__(self, driver: 'service_drivers.VpnDriver'): self.driver = driver @property - def l3_plugin(self): + def l3_plugin(self) -> l3.RouterPluginBase: return self.driver.l3_plugin - def validate_ipsec_site_connection(self, context, ipsec_sitecon): + def validate_ipsec_site_connection(self, context: context.ContextBase, + ipsec_sitecon): """Driver can override this for its additional validations.""" pass diff --git a/neutron_vpnaas/services/vpn/service_drivers/ipsec.py b/neutron_vpnaas/services/vpn/service_drivers/ipsec.py index 4e98cb52c..d26a0e8b8 100644 --- a/neutron_vpnaas/services/vpn/service_drivers/ipsec.py +++ b/neutron_vpnaas/services/vpn/service_drivers/ipsec.py @@ -15,6 +15,7 @@ from neutron_lib import rpc as n_rpc +from neutron_vpnaas.db.vpn import vpn_db from neutron_vpnaas.services.vpn.common import topics from neutron_vpnaas.services.vpn.service_drivers import base_ipsec from neutron_vpnaas.services.vpn.service_drivers import ipsec_validator @@ -27,10 +28,9 @@ BASE_IPSEC_VERSION = '1.0' class IPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): """VPN Service Driver class for IPsec.""" - def __init__(self, service_plugin): - super(IPsecVPNDriver, self).__init__( - service_plugin, - ipsec_validator.IpsecVpnValidator(self)) + def __init__(self, service_plugin: vpn_db.VPNPluginDb): + super().__init__(service_plugin, + ipsec_validator.IpsecVpnValidator(self)) def create_rpc_conn(self): self.endpoints = [base_ipsec.IPsecVpnDriverCallBack(self)] diff --git a/neutron_vpnaas/services/vpn/service_drivers/ipsec_validator.py b/neutron_vpnaas/services/vpn/service_drivers/ipsec_validator.py index 49fc45d5b..5c8bb9a37 100644 --- a/neutron_vpnaas/services/vpn/service_drivers/ipsec_validator.py +++ b/neutron_vpnaas/services/vpn/service_drivers/ipsec_validator.py @@ -12,6 +12,9 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty + +from neutron_lib import context from neutron_lib import exceptions as nexception from neutron_vpnaas._i18n import _ @@ -29,7 +32,8 @@ class IpsecVpnValidator(driver_validator.VpnDriverValidator): and Libreswan. """ - def _check_transform_protocol(self, context, transform_protocol): + def _check_transform_protocol(self, context: context.ContextBase, + transform_protocol: ty.Optional[str]): """Restrict selecting ah-esp as IPSec Policy transform protocol. For those *Swan implementations, the 'ah-esp' transform protocol @@ -41,12 +45,15 @@ class IpsecVpnValidator(driver_validator.VpnDriverValidator): key='transform_protocol', value=transform_protocol) - def validate_ipsec_policy(self, context, ipsec_policy): - transform_protocol = ipsec_policy.get('transform_protocol') + def validate_ipsec_policy(self, context: context.ContextBase, + ipsec_policy: ty.Dict[str, ty.Union[ty.Any, str]]): + transform_protocol: ty.Optional[str] = \ + ipsec_policy.get('transform_protocol', None) self._check_transform_protocol(context, transform_protocol) - def validate_ipsec_site_connection(self, context, ipsec_sitecon): - if 'ipsecpolicy_id' in ipsec_sitecon: + def validate_ipsec_site_connection(self, context: context.ContextBase, + ipsec_sitecon): + if "ipsecpolicy_id" in ipsec_sitecon: ipsec_policy = self.driver.service_plugin.get_ipsecpolicy( context, ipsec_sitecon['ipsecpolicy_id']) self.validate_ipsec_policy(context, ipsec_policy) diff --git a/neutron_vpnaas/services/vpn/service_drivers/ovn_ipsec.py b/neutron_vpnaas/services/vpn/service_drivers/ovn_ipsec.py index 47760f064..abce7ce1e 100644 --- a/neutron_vpnaas/services/vpn/service_drivers/ovn_ipsec.py +++ b/neutron_vpnaas/services/vpn/service_drivers/ovn_ipsec.py @@ -14,8 +14,14 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty + +import abc + import netaddr +from neutron.db import extraroute_db +from neutron.plugins.ml2 import plugin from neutron_lib.api.definitions import portbindings from neutron_lib.callbacks import events from neutron_lib.callbacks import registry @@ -33,14 +39,22 @@ from oslo_config import cfg from oslo_db import exception as o_exc from oslo_log import log as logging + from neutron_vpnaas.db.vpn import vpn_agentschedulers_db as agent_db -from neutron_vpnaas.db.vpn.vpn_ext_gw_db import RouterIsNotVPNExternal +from neutron_vpnaas.db.vpn import vpn_ext_gw_db as ext_gw from neutron_vpnaas.db.vpn import vpn_models from neutron_vpnaas.extensions import vpnaas from neutron_vpnaas.services.vpn.common import constants as v_constants from neutron_vpnaas.services.vpn.common import topics +from neutron_vpnaas.services.vpn import ovn_plugin from neutron_vpnaas.services.vpn.service_drivers import base_ipsec +#pylint: disable=ungrouped-imports +# Additional import for typechecking. Importing these without typechecking +# would resolve in a cyclic dependency +if ty.TYPE_CHECKING: + from neutron.db import db_base_plugin_v2 as db_plugin +#pylint: enable=ungrouped-imports LOG = logging.getLogger(__name__) @@ -63,17 +77,18 @@ class IPsecVpnOvnDriverCallBack(base_ipsec.IPsecVpnDriverCallBack): self.admin_ctx = nctx.get_admin_context() @property - def core_plugin(self): + def core_plugin(self) -> 'db_plugin.NeutronDbPluginV2': return self.driver.core_plugin @property - def service_plugin(self): + def service_plugin(self) -> ext_gw.VPNExtGWPlugin_db: return self.driver.service_plugin - def _get_vpn_gateway(self, context, router_id): + def _get_vpn_gateway(self, context: nctx.ContextBase, router_id: str): return self.service_plugin.get_vpn_gw_by_router_id(context, router_id) - def get_vpn_transit_network_details(self, context, router_id): + def get_vpn_transit_network_details(self, context: nctx.ContextBase, + router_id: str): vpn_gw = self._get_vpn_gateway(context, router_id) network_id = vpn_gw.gw_port['network_id'] external_network = self.core_plugin.get_network(context, network_id) @@ -86,13 +101,14 @@ class IPsecVpnOvnDriverCallBack(base_ipsec.IPsecVpnDriverCallBack): } return details - def get_subnet_info(self, context, subnet_id=None): + def get_subnet_info(self, context: nctx.ContextBase, + subnet_id: ty.Optional[str] = None): try: return self.core_plugin.get_subnet(context, subnet_id) except n_exc.SubnetNotFound: return None - def _get_agent_hosting_vpn_services(self, context, host): + def _get_agent_hosting_vpn_services(self, context: nctx.Context, host): agent = self.service_plugin.get_vpn_agent_on_host(context, host) if not agent: return [] @@ -114,20 +130,23 @@ class IPsecVpnOvnDriverCallBack(base_ipsec.IPsecVpnDriverCallBack): @registry.has_registry_receivers -class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): - def __init__(self, service_plugin): - self._l3_plugin = None - self._core_plugin = None +class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver, + metaclass=abc.ABCMeta): + def __init__(self, service_plugin: ovn_plugin.VPNOVNPlugin): + self._l3_plugin: \ + ty.Optional[extraroute_db.ExtraRoute_dbonly_mixin] = None + self._core_plugin: ty.Optional[plugin.Ml2Plugin] = None + self.service_plugin = service_plugin super().__init__(service_plugin) @property - def l3_plugin(self): + def l3_plugin(self) -> extraroute_db.ExtraRoute_dbonly_mixin: if self._l3_plugin is None: self._l3_plugin = directory.get_plugin(plugin_constants.L3) return self._l3_plugin @property - def core_plugin(self): + def core_plugin(self) -> plugin.Ml2Plugin: if self._core_plugin is None: self._core_plugin = directory.get_plugin() return self._core_plugin @@ -155,20 +174,25 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): raise vpnaas.RouteInUseByVPN( destinations=", ".join(conflict_cidrs)) - def get_vpn_gw_port_name(self, router_id): + @abc.abstractmethod + def create_rpc_conn(self) -> None: + pass + + def get_vpn_gw_port_name(self, router_id: str) -> str: return VPN_GW_PORT_PREFIX + router_id - def get_vpn_namespace_port_name(self, router_id): + def get_vpn_namespace_port_name(self, router_id: str) -> str: return TRANSIT_PORT_PREFIX + router_id - def get_transit_network_name(self, router_id): + def get_transit_network_name(self, router_id: str) -> str: return TRANSIT_NETWORK_PREFIX + router_id - def get_transit_subnet_name(self, router_id): + def get_transit_subnet_name(self, router_id: str) -> str: return TRANSIT_SUBNET_PREFIX + router_id - def make_transit_network(self, router_id, tenant_id, agent_host, - gateway_update): + def make_transit_network(self, router_id: str, tenant_id: str, + agent_host: str, + gateway_update: ty.Dict[str, ty.Any]): context = nctx.get_admin_context() network_data = { 'tenant_id': HIDDEN_PROJECT_ID, @@ -213,13 +237,14 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): {"port": port_data}) gateway_update['transit_port_id'] = port['id'] - def _del_port(self, context, port_id): + def _del_port(self, context: nctx.ContextBase, port_id: str): try: self.core_plugin.delete_port(context, port_id, l3_port_check=False) except n_exc.PortNotFound: pass - def _remove_router_interface(self, context, router_id, subnet_id): + def _remove_router_interface(self, context: nctx.ContextBase, + router_id: str, subnet_id: str): try: self.l3_plugin.remove_router_interface( context, router_id, {'subnet_id': subnet_id}) @@ -227,27 +252,27 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): n_exc.SubnetNotFound): pass - def _del_subnet(self, context, subnet_id): + def _del_subnet(self, context: nctx.ContextBase, subnet_id: str): try: self.core_plugin.delete_subnet(context, subnet_id) except n_exc.SubnetNotFound: pass - def _del_network(self, context, network_id): + def _del_network(self, context: nctx.ContextBase, network_id: str): try: self.core_plugin.delete_network(context, network_id) except n_exc.NetworkNotFound: pass - def del_transit_network(self, gw): + def del_transit_network(self, gw: ext_gw.VPNExtGW): context = nctx.get_admin_context() - router_id = gw['router_id'] + router_id: str = gw['router_id'] - port_id = gw.get('transit_port_id') + port_id: ty.Optional[str] = gw.get('transit_port_id') if port_id: self._del_port(context, port_id) - subnet_id = gw.get('transit_subnet_id') + subnet_id: ty.Optional[str] = gw.get('transit_subnet_id') if subnet_id: self._remove_router_interface(context, router_id, subnet_id) self._del_subnet(context, subnet_id) @@ -256,7 +281,8 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): if network_id: self._del_network(context, network_id) - def make_gw_port(self, router_id, network_id, agent_host, gateway_update): + def make_gw_port(self, router_id: str, network_id: str, + agent_host: str, gateway_update: ty.Dict[str, ty.Any]): context = nctx.get_admin_context() port_data = {'tenant_id': HIDDEN_PROJECT_ID, 'network_id': network_id, @@ -273,7 +299,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): LOG.debug('No IPs available for external network %s', network_id) gateway_update['gw_port_id'] = gw_port['id'] - def del_gw_port(self, gateway): + def del_gw_port(self, gateway: ext_gw.VPNExtGW): context = nctx.get_admin_context() port_id = gateway.get('gw_port_id') if port_id: @@ -290,24 +316,26 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): cidrs.append(ep.endpoint) return cidrs - def _routes_update(self, cidrs, nexthop): + def _routes_update(self, cidrs: ty.Set, nexthop): routes = [{'destination': cidr, 'nexthop': nexthop} for cidr in cidrs] return {'router': {'routes': routes}} - def _update_static_routes(self, context, ipsec_site_connection): + def _update_static_routes(self, context: nctx.ContextBase, + ipsec_site_connection): vpnservice = self.service_plugin.get_vpnservice( context, ipsec_site_connection['vpnservice_id']) router_id = vpnservice['router_id'] - gw = self.service_plugin.get_vpn_gw_by_router_id(context, router_id) + gw: ext_gw.VPNExtGW = self.service_plugin.get_vpn_gw_by_router_id( + context, router_id) nexthop = gw.transit_port['fixed_ips'][0]['ip_address'] router = self.l3_plugin.get_router(context, router_id) old_routes = router.get('routes', []) - old_cidrs = set([r['destination'] for r in old_routes - if r['nexthop'] == nexthop]) + old_cidrs = {r['destination'] for r in old_routes + if r['nexthop'] == nexthop} new_cidrs = set( self.service_plugin.get_peer_cidrs_for_router(context, router_id)) @@ -330,7 +358,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): nctx.get_admin_context(), router['id']) if gateway is None or gateway['external_fixed_ips'] is None: - raise RouterIsNotVPNExternal(router_id=router['id']) + raise ext_gw.RouterIsNotVPNExternal(router_id=router['id']) v4_ip = v6_ip = None for fixed_ip in gateway['external_fixed_ips']: @@ -343,12 +371,14 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): v6_ip = addr return v4_ip, v6_ip - def _update_gateway(self, context, gateway_id, **kwargs): + def _update_gateway(self, context: nctx.Context, + gateway_id: str, **kwargs): gateway = {'gateway': kwargs} return self.service_plugin.update_gateway(context, gateway_id, gateway) @db_api.retry_if_session_inactive() - def _ensure_gateway(self, context, vpnservice): + def _ensure_gateway(self, context: nctx.Context, vpnservice) -> \ + ty.Dict[str, ty.Any]: gw = self.service_plugin.get_vpn_gw_dict_by_router_id( context, vpnservice['router_id'], refresh=True) if not gw: @@ -371,23 +401,26 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): return gw @db_api.CONTEXT_WRITER - def _setup(self, context, vpnservice_dict): + def _setup(self, context: nctx.Context, + vpnservice_dict: ty.Dict[str, ty.Any]): router_id = vpnservice_dict['router_id'] agent = self.service_plugin.schedule_router(context, router_id) if not agent: raise vpnaas.NoVPNAgentAvailable agent_host = agent['host'] - gateway = self._ensure_gateway(context, vpnservice_dict) + gateway: ty.Optional[ty.Dict[str, ty.Any]] = self._ensure_gateway( + context, vpnservice_dict) # If the gateway status is ACTIVE the ports have been created already - if gateway['status'] == lib_constants.ACTIVE: + if gateway and gateway['status'] == lib_constants.ACTIVE: return - vpnservice = self.service_plugin._get_vpnservice(context, - vpnservice_dict['id']) + vpnservice = self.service_plugin._get_vpnservice( + context, vpnservice_dict["id"]) network_id = vpnservice.router.gw_port.network_id - gateway_update = {} # keeps track of already-created IDs + # keeps track of already-created IDs + gateway_update: ty.Dict[str, ty.Any] = {} try: self.make_gw_port(router_id, network_id, agent_host, gateway_update) @@ -396,16 +429,16 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): agent_host, gateway_update) except Exception: - self._update_gateway(context, gateway['id'], + self._update_gateway(context, gateway['id'], # type: ignore status=lib_constants.ERROR, **gateway_update) raise - self._update_gateway(context, gateway['id'], + self._update_gateway(context, gateway['id'], # type: ignore status=lib_constants.ACTIVE, **gateway_update) - def _cleanup(self, context, router_id): + def _cleanup(self, context: nctx.Context, router_id: str): gw = self.service_plugin.get_vpn_gw_dict_by_router_id(context, router_id) if not gw: @@ -423,7 +456,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): status=lib_constants.ERROR) raise - def create_vpnservice(self, context, vpnservice_dict): + def create_vpnservice(self, context: nctx.Context, vpnservice_dict): try: self._setup(context, vpnservice_dict) except Exception: @@ -435,7 +468,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): raise super().create_vpnservice(context, vpnservice_dict) - def delete_vpnservice(self, context, vpnservice): + def delete_vpnservice(self, context: nctx.Context, vpnservice): router_id = vpnservice['router_id'] super().delete_vpnservice(context, vpnservice) services = self.service_plugin.get_vpnservices(context) @@ -443,11 +476,13 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): if router_id not in router_ids: self._cleanup(context, router_id) - def create_ipsec_site_connection(self, context, ipsec_site_connection): + def create_ipsec_site_connection(self, context: nctx.Context, + ipsec_site_connection): self._update_static_routes(context, ipsec_site_connection) super().create_ipsec_site_connection(context, ipsec_site_connection) - def delete_ipsec_site_connection(self, context, ipsec_site_connection): + def delete_ipsec_site_connection(self, context: nctx.Context, + ipsec_site_connection): self._update_static_routes(context, ipsec_site_connection) super().delete_ipsec_site_connection(context, ipsec_site_connection) @@ -457,11 +492,11 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): super().update_ipsec_site_connection( context, old_ipsec_site_connection, ipsec_site_connection) - def _update_port_binding(self, context, port_id, host): + def _update_port_binding(self, context: nctx.Context, port_id, host): port_data = {'binding:host_id': host} self.core_plugin.update_port(context, port_id, {'port': port_data}) - def update_port_bindings(self, context, router_id, host): + def update_port_bindings(self, context: nctx.Context, router_id, host): gw = self.service_plugin.get_vpn_gw_dict_by_router_id(context, router_id) if not gw: @@ -475,7 +510,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver): class IPsecOvnVpnAgentApi(base_ipsec.IPsecVpnAgentApi): - def _agent_notification(self, context, method, router_id, + def _agent_notification(self, context: nctx.Context, method, router_id, version=None, **kwargs): """Notify update for the agent. @@ -508,7 +543,7 @@ class IPsecOvnVPNDriver(BaseOvnIPsecVPNDriver): self.agent_rpc = IPsecOvnVpnAgentApi( topics.IPSEC_AGENT_TOPIC, BASE_IPSEC_VERSION, self) - def start_rpc_listeners(self): + def start_rpc_listeners(self) -> ty.List: self.endpoints = [IPsecVpnOvnDriverCallBack(self)] self.conn = n_rpc.Connection() self.conn.create_consumer( diff --git a/neutron_vpnaas/services/vpn/vpn_service.py b/neutron_vpnaas/services/vpn/vpn_service.py index e2b463d66..9d3fb4e1c 100644 --- a/neutron_vpnaas/services/vpn/vpn_service.py +++ b/neutron_vpnaas/services/vpn/vpn_service.py @@ -13,25 +13,30 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty + from neutron.services import provider_configuration as provconfig from neutron_lib.exceptions import vpn as vpn_exception +from oslo_config import cfg from oslo_log import log as logging from oslo_utils import importutils +from neutron_vpnaas.services.vpn.device_drivers import ipsec + LOG = logging.getLogger(__name__) DEVICE_DRIVERS = 'device_drivers' -class VPNService(object): +class VPNService: """VPN Service observer.""" def __init__(self, l3_agent): - self.conf = l3_agent.conf + self.conf: cfg.ConfigOpts = l3_agent.conf - def load_device_drivers(self, host): + def load_device_drivers(self, host) -> ty.List[ipsec.IPsecDriver]: """Loads one or more device drivers for VPNaaS.""" - drivers = [] + drivers: ty.List[ipsec.IPsecDriver] = [] for device_driver in self.conf.vpnagent.vpn_device_driver: device_driver = provconfig.get_provider_driver_class( device_driver, DEVICE_DRIVERS) diff --git a/neutron_vpnaas/tests/functional/common/test_scenario.py b/neutron_vpnaas/tests/functional/common/test_scenario.py index dd2f2f14b..46e42e0ed 100644 --- a/neutron_vpnaas/tests/functional/common/test_scenario.py +++ b/neutron_vpnaas/tests/functional/common/test_scenario.py @@ -178,7 +178,7 @@ def get_ovs_bridge(br_name): Vm = collections.namedtuple('Vm', ['namespace', 'port_ip']) -class SiteInfo(object): +class SiteInfo: """Holds info on the router, ports, service, and connection.""" diff --git a/neutron_vpnaas/tests/unit/db/vpn/test_vpn_agentschedulers_db.py b/neutron_vpnaas/tests/unit/db/vpn/test_vpn_agentschedulers_db.py index cac246d19..7bf885cdc 100644 --- a/neutron_vpnaas/tests/unit/db/vpn/test_vpn_agentschedulers_db.py +++ b/neutron_vpnaas/tests/unit/db/vpn/test_vpn_agentschedulers_db.py @@ -44,7 +44,7 @@ VPN_HOSTA = "host-1" VPN_HOSTB = "host-2" -class VPNAgentSchedulerTestMixIn(object): +class VPNAgentSchedulerTestMixIn: def _request_list(self, path, admin_context=True, expected_code=exc.HTTPOk.code): req = self._path_req(path, admin_context=admin_context) diff --git a/neutron_vpnaas/tests/unit/db/vpn/test_vpn_db.py b/neutron_vpnaas/tests/unit/db/vpn/test_vpn_db.py index 54860ade1..e401012b5 100644 --- a/neutron_vpnaas/tests/unit/db/vpn/test_vpn_db.py +++ b/neutron_vpnaas/tests/unit/db/vpn/test_vpn_db.py @@ -69,7 +69,7 @@ class TestVpnCorePlugin(test_l3_plugin.TestL3NatIntPlugin, self.router_scheduler = l3_agent_scheduler.ChanceScheduler() -class VPNTestMixin(object): +class VPNTestMixin: resource_prefix_map = dict( (k.replace('_', '-'), "/vpn") @@ -1718,7 +1718,7 @@ class TestVpnaas(VPNPluginDbTestCase): # tests. # TODO(pcm): Put helpers in another module for sharing -class NeutronResourcesMixin(object): +class NeutronResourcesMixin: def create_network(self, overrides=None): """Create database entry for network.""" diff --git a/neutron_vpnaas/tests/unit/services/vpn/service_drivers/test_ovn_ipsec.py b/neutron_vpnaas/tests/unit/services/vpn/service_drivers/test_ovn_ipsec.py index c698a2b3a..e4f62bc0d 100644 --- a/neutron_vpnaas/tests/unit/services/vpn/service_drivers/test_ovn_ipsec.py +++ b/neutron_vpnaas/tests/unit/services/vpn/service_drivers/test_ovn_ipsec.py @@ -57,7 +57,7 @@ class FakeSqlQueryObject(dict): super(FakeSqlQueryObject, self).__init__(**entries) -class FakeGatewayDB(object): +class FakeGatewayDB: def __init__(self): self.gateways_by_router = {} self.gateways_by_id = {} diff --git a/setup.cfg b/setup.cfg index 44d70768d..b452ceb32 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,3 +51,7 @@ oslo.policy.policies = neutron-vpnaas = neutron_vpnaas.policies:list_rules neutron.policies = neutron-vpnaas = neutron_vpnaas.policies:list_rules + +[mypy] +files = neutron_vpnaas/*.py,neutron_vpnaas/agent/**/*.py,neutron_vpnaas/api/**/*.py,neutron_vpnaas/cmd/**/*.py,neutron_vpnaas/db/**/*.py,neutron_vpnaas/extensions/**/*.py,neutron_vpnaas/policies/**/*.py,neutron_vpnaas/scheduler/**/*.py,neutron_vpnaas/services/**/*.py +ignore_missing_imports = true \ No newline at end of file diff --git a/test-requirements.txt b/test-requirements.txt index bddbc8133..71a011fc3 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -11,3 +11,4 @@ stestr>=1.0.0 # Apache-2.0 # see https://review.opendev.org/c/openstack/neutron/+/848706 WebTest>=2.0.27 # MIT +mypy>=1.7.0 # MIT diff --git a/tools/check_i18n.py b/tools/check_i18n.py index 9c49f941e..be81bc964 100644 --- a/tools/check_i18n.py +++ b/tools/check_i18n.py @@ -35,7 +35,7 @@ class ASTWalker(compiler.visitor.ASTVisitor): compiler.visitor.ASTVisitor.default(self, node, *args) -class Visitor(object): +class Visitor: def __init__(self, filename, i18n_msg_predicates, msg_format_checkers, debug): diff --git a/tox.ini b/tox.ini index 3e939b52d..e6eb57422 100644 --- a/tox.ini +++ b/tox.ini @@ -1,14 +1,16 @@ [tox] +minversion = 4.0.0 +ignore_basepython_conflict=true +requires = virtualenv>=20.17.1 envlist = py39,py38,pep8,docs -minversion = 3.18.0 [testenv] +usedevelop = True setenv = VIRTUAL_ENV={envdir} OS_LOG_CAPTURE={env:OS_LOG_CAPTURE:true} OS_STDOUT_CAPTURE={env:OS_STDOUT_CAPTURE:true} OS_STDERR_CAPTURE={env:OS_STDERR_CAPTURE:true} PYTHONWARNINGS=default::DeprecationWarning -usedevelop = True deps = -c{env:TOX_CONSTRAINTS_FILE:https://opendev.org/openstack/requirements/raw/branch/master/upper-constraints.txt} -r{toxinidir}/requirements.txt -r{toxinidir}/test-requirements.txt @@ -83,6 +85,7 @@ commands = neutron-db-manage --subproject neutron-vpnaas --database-connection sqlite:// check_migration {[testenv:genconfig]commands} {[testenv:genpolicy]commands} + {[testenv:mypy]commands} allowlist_externals = bash @@ -153,3 +156,10 @@ commands = bash {toxinidir}/tools/generate_config_file_samples.sh [testenv:genpolicy] commands = oslopolicy-sample-generator --config-file=etc/oslo-policy-generator/policy.conf + +[testenv:mypy] +description = + Run type checks. +deps = {[testenv]deps} +commands = + mypy --install-types --non-interactive --check-untyped-defs