From fc37d8303dced9dcbebe0e65c491b4cd416324f4 Mon Sep 17 00:00:00 2001 From: Michael Johnson Date: Sat, 24 Feb 2024 23:43:59 +0000 Subject: [PATCH] Enable nftables rules for SR-IOV VIPs This patch enables setting the nftables rules in Amphora using SR-IOV VIPs. Change-Id: I554aac422371abafb4bb04e2d0df3fce3fa169d4 --- .../backends/agent/api_server/rules_schema.py | 52 +++++++++ .../backends/agent/api_server/server.py | 26 +++++ octavia/amphorae/backends/utils/interface.py | 10 +- .../amphorae/backends/utils/nftable_utils.py | 78 ++++++++++++-- octavia/amphorae/drivers/driver_base.py | 11 ++ .../drivers/haproxy/rest_api_driver.py | 22 ++++ .../amphorae/drivers/noop_driver/driver.py | 3 + octavia/common/constants.py | 11 +- octavia/common/utils.py | 6 ++ .../controller/worker/v2/controller_worker.py | 56 ++++++++-- .../worker/v2/flows/amphora_flows.py | 58 +++++++++- .../controller/worker/v2/flows/flow_utils.py | 17 +-- .../worker/v2/flows/listener_flows.py | 100 ++++++++++++++++- .../worker/v2/flows/load_balancer_flows.py | 44 ++++++-- .../worker/v2/tasks/amphora_driver_tasks.py | 23 ++++ .../worker/v2/tasks/database_tasks.py | 46 ++++++++ .../controller/worker/v2/tasks/shim_tasks.py | 28 +++++ octavia/db/repositories.py | 18 ++++ .../backend/agent/api_server/test_server.py | 37 +++++++ .../tests/functional/db/test_repositories.py | 8 ++ .../amphorae/backends/utils/test_interface.py | 58 +++++++++- .../backends/utils/test_nftable_utils.py | 98 ++++++++++++++++- .../drivers/haproxy/test_rest_api_driver.py | 26 ++++- .../haproxy/test_rest_api_driver_1_0.py | 10 ++ octavia/tests/unit/common/test_utils.py | 39 +++++++ .../worker/v2/flows/test_listener_flows.py | 65 +++++++++-- .../v2/flows/test_load_balancer_flows.py | 11 +- .../v2/tasks/test_amphora_driver_tasks.py | 28 +++++ .../worker/v2/tasks/test_database_tasks.py | 100 +++++++++++++++++ .../worker/v2/tasks/test_shim_tasks.py | 33 ++++++ .../worker/v2/test_controller_worker.py | 101 ++++++++++++------ ...port-for-SR-IOV-VIPs-862858ec61e9955b.yaml | 9 ++ 32 files changed, 1121 insertions(+), 111 deletions(-) create mode 100644 octavia/amphorae/backends/agent/api_server/rules_schema.py create mode 100644 octavia/controller/worker/v2/tasks/shim_tasks.py create mode 100644 octavia/tests/unit/controller/worker/v2/tasks/test_shim_tasks.py create mode 100644 releasenotes/notes/Add-support-for-SR-IOV-VIPs-862858ec61e9955b.yaml diff --git a/octavia/amphorae/backends/agent/api_server/rules_schema.py b/octavia/amphorae/backends/agent/api_server/rules_schema.py new file mode 100644 index 0000000000..57ad8fe02d --- /dev/null +++ b/octavia/amphorae/backends/agent/api_server/rules_schema.py @@ -0,0 +1,52 @@ +# Copyright 2024 Red Hat, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from octavia_lib.common import constants as lib_consts + +from octavia.common import constants as consts + +# This is a JSON schema validation dictionary +# https://json-schema.org/latest/json-schema-validation.html + +SUPPORTED_RULES_SCHEMA = { + '$schema': 'http://json-schema.org/draft-07/schema#', + 'title': 'Octavia Amphora NFTables Rules Schema', + 'description': 'This schema is used to validate an nftables rules JSON ' + 'document sent from a controller.', + 'type': 'array', + 'items': { + 'additionalProperties': False, + 'properties': { + consts.PROTOCOL: { + 'type': 'string', + 'description': 'The protocol for the rule. One of: ' + 'TCP, UDP, VRRP, SCTP', + 'enum': list((lib_consts.PROTOCOL_SCTP, + lib_consts.PROTOCOL_TCP, + lib_consts.PROTOCOL_UDP, + consts.VRRP)) + }, + consts.CIDR: { + 'type': ['string', 'null'], + 'description': 'The allowed source CIDR.' + }, + consts.PORT: { + 'type': 'number', + 'description': 'The protocol port number.', + 'minimum': 1, + 'maximum': 65535 + } + }, + 'required': [consts.PROTOCOL, consts.CIDR, consts.PORT] + } +} diff --git a/octavia/amphorae/backends/agent/api_server/server.py b/octavia/amphorae/backends/agent/api_server/server.py index c4c3d521f4..1d8686783c 100644 --- a/octavia/amphorae/backends/agent/api_server/server.py +++ b/octavia/amphorae/backends/agent/api_server/server.py @@ -16,6 +16,7 @@ import os import stat import flask +from jsonschema import validate from oslo_config import cfg from oslo_log import log as logging import webob @@ -29,7 +30,9 @@ from octavia.amphorae.backends.agent.api_server import keepalivedlvs from octavia.amphorae.backends.agent.api_server import loadbalancer from octavia.amphorae.backends.agent.api_server import osutils from octavia.amphorae.backends.agent.api_server import plug +from octavia.amphorae.backends.agent.api_server import rules_schema from octavia.amphorae.backends.agent.api_server import util +from octavia.amphorae.backends.utils import nftable_utils from octavia.common import constants as consts @@ -137,6 +140,9 @@ class Server(object): self.app.add_url_rule(rule=PATH_PREFIX + '/interface/', view_func=self.get_interface, methods=['GET']) + self.app.add_url_rule(rule=PATH_PREFIX + '/interface//rules', + view_func=self.set_interface_rules, + methods=['PUT']) def upload_haproxy_config(self, amphora_id, lb_id): return self._loadbalancer.upload_haproxy_config(amphora_id, lb_id) @@ -257,3 +263,23 @@ class Server(object): def version_discovery(self): return webob.Response(json={'api_version': api_server.VERSION}) + + def set_interface_rules(self, ip_addr): + interface_webob = self._amphora_info.get_interface(ip_addr) + + if interface_webob.status_code != 200: + return interface_webob + interface = interface_webob.json['interface'] + + try: + rules_info = flask.request.get_json() + validate(rules_info, rules_schema.SUPPORTED_RULES_SCHEMA) + except Exception as e: + raise exceptions.BadRequest( + description='Invalid rules information') from e + + nftable_utils.write_nftable_vip_rules_file(interface, rules_info) + + nftable_utils.load_nftables_file() + + return webob.Response(json={'message': 'OK'}, status=200) diff --git a/octavia/amphorae/backends/utils/interface.py b/octavia/amphorae/backends/utils/interface.py index 057ce545fb..f2fbf8d828 100644 --- a/octavia/amphorae/backends/utils/interface.py +++ b/octavia/amphorae/backends/utils/interface.py @@ -210,15 +210,7 @@ class InterfaceController(object): nftable_utils.write_nftable_vip_rules_file(interface.name, []) - cmd = [consts.NFT_CMD, '-o', '-f', consts.NFT_VIP_RULES_FILE] - try: - subprocess.check_output(cmd, stderr=subprocess.STDOUT) - except Exception as e: - if hasattr(e, 'output'): - LOG.error(e.output) - else: - LOG.error(e) - raise + nftable_utils.load_nftables_file() def up(self, interface): LOG.info("Setting interface %s up", interface.name) diff --git a/octavia/amphorae/backends/utils/nftable_utils.py b/octavia/amphorae/backends/utils/nftable_utils.py index 8a84eb55a1..384f7bc0c5 100644 --- a/octavia/amphorae/backends/utils/nftable_utils.py +++ b/octavia/amphorae/backends/utils/nftable_utils.py @@ -13,8 +13,17 @@ # under the License. import os import stat +import subprocess +from octavia_lib.common import constants as lib_consts +from oslo_log import log as logging +from webob import exc + +from octavia.amphorae.backends.utils import network_namespace from octavia.common import constants as consts +from octavia.common import utils + +LOG = logging.getLogger(__name__) def write_nftable_vip_rules_file(interface_name, rules): @@ -28,7 +37,17 @@ def write_nftable_vip_rules_file(interface_name, rules): hook_string = (f' type filter hook ingress device {interface_name} ' f'priority {consts.NFT_SRIOV_PRIORITY}; policy drop;\n') - # Check if an existing rules file exists or we if need to create an + # Allow ICMP destination unreachable for PMTUD + icmp_string = ' icmp type destination-unreachable accept\n' + # Allow the required neighbor solicitation/discovery PMTUD ICMPV6 + icmpv6_string = (' icmpv6 type { nd-neighbor-solicit, ' + 'nd-router-advert, nd-neighbor-advert, packet-too-big, ' + 'destination-unreachable } accept\n') + # Allow DHCP responses + dhcp_string = ' udp sport 67 udp dport 68 accept\n' + dhcpv6_string = ' udp sport 547 udp dport 546 accept\n' + + # Check if an existing rules file exists or we be need to create an # "drop all" file with no rules except for VRRP. If it exists, we should # not overwrite it here as it could be a reboot unless we were passed new # rules. @@ -40,15 +59,21 @@ def write_nftable_vip_rules_file(interface_name, rules): # Clear the existing rules in the kernel # Note: The "nft -f" method is atomic, so clearing the rules will # not leave the amphora exposed. - file.write(f'flush chain {consts.NFT_FAMILY} ' - f'{consts.NFT_VIP_TABLE} {consts.NFT_VIP_CHAIN}\n') + # Create and delete the table to not get errors if the table does + # not exist yet. + file.write(f'table {consts.NFT_FAMILY} {consts.NFT_VIP_TABLE} ' + '{}\n') + file.write(f'delete table {consts.NFT_FAMILY} ' + f'{consts.NFT_VIP_TABLE}\n') file.write(table_string) file.write(chain_string) file.write(hook_string) - # TODO(johnsom) Add peer ports here consts.HAPROXY_BASE_PEER_PORT - # and ip protocol 112 for VRRP. Need the peer address + file.write(icmp_string) + file.write(icmpv6_string) + file.write(dhcp_string) + file.write(dhcpv6_string) for rule in rules: - file.write(f' {rule}\n') + file.write(f' {_build_rule_cmd(rule)}\n') file.write(' }\n') # close the chain file.write('}\n') # close the table else: # No existing rules, create the "drop all" base rules @@ -57,7 +82,44 @@ def write_nftable_vip_rules_file(interface_name, rules): file.write(table_string) file.write(chain_string) file.write(hook_string) - # TODO(johnsom) Add peer ports here consts.HAPROXY_BASE_PEER_PORT - # and ip protocol 112 for VRRP. Need the peer address + file.write(icmp_string) + file.write(icmpv6_string) + file.write(dhcp_string) + file.write(dhcpv6_string) file.write(' }\n') # close the chain file.write('}\n') # close the table + + +def _build_rule_cmd(rule): + prefix_saddr = '' + if rule[consts.CIDR] and rule[consts.CIDR] != '0.0.0.0/0': + cidr_ip_version = utils.ip_version(rule[consts.CIDR].split('/')[0]) + if cidr_ip_version == 4: + prefix_saddr = f'ip saddr {rule[consts.CIDR]} ' + elif cidr_ip_version == 6: + prefix_saddr = f'ip6 saddr {rule[consts.CIDR]} ' + else: + raise exc.HTTPBadRequest(explanation='Unknown ip version') + + if rule[consts.PROTOCOL] == lib_consts.PROTOCOL_SCTP: + return f'{prefix_saddr}sctp dport {rule[consts.PORT]} accept' + if rule[consts.PROTOCOL] == lib_consts.PROTOCOL_TCP: + return f'{prefix_saddr}tcp dport {rule[consts.PORT]} accept' + if rule[consts.PROTOCOL] == lib_consts.PROTOCOL_UDP: + return f'{prefix_saddr}udp dport {rule[consts.PORT]} accept' + if rule[consts.PROTOCOL] == consts.VRRP: + return f'{prefix_saddr}ip protocol 112 accept' + raise exc.HTTPBadRequest(explanation='Unknown protocol used in rules') + + +def load_nftables_file(): + cmd = [consts.NFT_CMD, '-o', '-f', consts.NFT_VIP_RULES_FILE] + try: + with network_namespace.NetworkNamespace(consts.AMPHORA_NAMESPACE): + subprocess.check_output(cmd, stderr=subprocess.STDOUT) + except Exception as e: + if hasattr(e, 'output'): + LOG.error(e.output) + else: + LOG.error(e) + raise diff --git a/octavia/amphorae/drivers/driver_base.py b/octavia/amphorae/drivers/driver_base.py index c4ac3f6d6b..6f6b4253e8 100644 --- a/octavia/amphorae/drivers/driver_base.py +++ b/octavia/amphorae/drivers/driver_base.py @@ -252,6 +252,17 @@ class AmphoraLoadBalancerDriver(object, metaclass=abc.ABCMeta): :raises TimeOutException: The amphora didn't reply """ + @abc.abstractmethod + def set_interface_rules(self, amphora: db_models.Amphora, ip_address, + rules): + """Sets interface firewall rules in the amphora + + :param amphora: The amphora to query. + :param ip_address: The IP address assigned to the interface the rules + will be applied on. + :param rules: The l1st of allow rules to apply. + """ + class VRRPDriverMixin(object, metaclass=abc.ABCMeta): """Abstract mixin class for VRRP support in loadbalancer amphorae diff --git a/octavia/amphorae/drivers/haproxy/rest_api_driver.py b/octavia/amphorae/drivers/haproxy/rest_api_driver.py index 676f2f1807..d0722d3b2c 100644 --- a/octavia/amphorae/drivers/haproxy/rest_api_driver.py +++ b/octavia/amphorae/drivers/haproxy/rest_api_driver.py @@ -598,6 +598,24 @@ class HaproxyAmphoraLoadBalancerDriver( amphora, ip_address, timeout_dict, log_error=False) return response_json.get('interface', None) + def set_interface_rules(self, amphora: db_models.Amphora, + ip_address, rules): + """Sets interface firewall rules in the amphora + + :param amphora: The amphora to query. + :param ip_address: The IP address assigned to the interface the rules + will be applied on. + :param rules: The l1st of allow rules to apply. + """ + try: + self._populate_amphora_api_version(amphora) + self.clients[amphora.api_version].set_interface_rules( + amphora, ip_address, rules) + except exc.NotFound as e: + LOG.debug('Amphora %s does not support the set_interface_rules ' + 'API.', amphora.id) + raise driver_except.AmpDriverNotImplementedError() from e + # Check a custom hostname class CustomHostNameCheckingAdapter(requests.adapters.HTTPAdapter): @@ -867,3 +885,7 @@ class AmphoraAPIClient1_0(AmphoraAPIClientBase): def update_agent_config(self, amp, agent_config, timeout_dict=None): r = self.put(amp, 'config', timeout_dict, data=agent_config) return exc.check_exception(r) + + def set_interface_rules(self, amp, ip_address, rules): + r = self.put(amp, f'interface/{ip_address}/rules', json=rules) + return exc.check_exception(r) diff --git a/octavia/amphorae/drivers/noop_driver/driver.py b/octavia/amphorae/drivers/noop_driver/driver.py index df3648d38a..8b3cdff5b0 100644 --- a/octavia/amphorae/drivers/noop_driver/driver.py +++ b/octavia/amphorae/drivers/noop_driver/driver.py @@ -218,3 +218,6 @@ class NoopAmphoraLoadBalancerDriver( def check(self, amphora, timeout_dict=None): pass + + def set_interface_rules(self, amphora, ip_address, rules): + pass diff --git a/octavia/common/constants.py b/octavia/common/constants.py index 426377b8f4..ce87a8cc53 100644 --- a/octavia/common/constants.py +++ b/octavia/common/constants.py @@ -308,6 +308,7 @@ AMP_DATA = 'amp_data' AMP_VRRP_INT = 'amp_vrrp_int' AMPHORA = 'amphora' AMPHORA_DICT = 'amphora_dict' +AMPHORA_FIREWALL_RULES = 'amphora_firewall_rules' AMPHORA_ID = 'amphora_id' AMPHORA_INDEX = 'amphora_index' AMPHORA_NETWORK_CONFIG = 'amphora_network_config' @@ -460,6 +461,7 @@ VIP_VNIC_TYPE = 'vip_vnic_type' VNIC_TYPE = 'vnic_type' VNIC_TYPE_DIRECT = 'direct' VNIC_TYPE_NORMAL = 'normal' +VRRP = 'vrrp' VRRP_ID = 'vrrp_id' VRRP_IP = 'vrrp_ip' VRRP_GROUP = 'vrrp_group' @@ -468,6 +470,7 @@ VRRP_PORT_ID = 'vrrp_port_id' VRRP_PRIORITY = 'vrrp_priority' # Taskflow flow and task names +AMP_UPDATE_FW_SUBFLOW = 'amphora-update-firewall-subflow' CERT_ROTATE_AMPHORA_FLOW = 'octavia-cert-rotate-amphora-flow' CREATE_AMPHORA_FLOW = 'octavia-create-amphora-flow' CREATE_AMPHORA_RETRY_SUBFLOW = 'octavia-create-amphora-retry-subflow' @@ -496,6 +499,7 @@ DELETE_L7RULE_FLOW = 'octavia-delete-l7policy-flow' FAILOVER_AMPHORA_FLOW = 'octavia-failover-amphora-flow' FAILOVER_LOADBALANCER_FLOW = 'octavia-failover-loadbalancer-flow' FINALIZE_AMPHORA_FLOW = 'octavia-finalize-amphora-flow' +FIREWALL_RULES_SUBFLOW = 'firewall-rules-subflow' LOADBALANCER_NETWORKING_SUBFLOW = 'octavia-new-loadbalancer-net-subflow' UPDATE_HEALTH_MONITOR_FLOW = 'octavia-update-health-monitor-flow' UPDATE_LISTENER_FLOW = 'octavia-update-listener-flow' @@ -583,6 +587,7 @@ CREATE_VIP_BASE_PORT = 'create-vip-base-port' DELETE_AMPHORA = 'delete-amphora' DELETE_PORT = 'delete-port' DISABLE_AMP_HEALTH_MONITORING = 'disable-amphora-health-monitoring' +GET_AMPHORA_FIREWALL_RULES = 'get-amphora-firewall-rules' GET_AMPHORA_NETWORK_CONFIGS_BY_ID = 'get-amphora-network-configs-by-id' GET_AMPHORAE_FROM_LB = 'get-amphorae-from-lb' GET_SUBNET_FROM_VIP = 'get-subnet-from-vip' @@ -595,6 +600,7 @@ RELOAD_LB_AFTER_AMP_ASSOC = 'reload-lb-after-amp-assoc' RELOAD_LB_AFTER_AMP_ASSOC_FULL_GRAPH = 'reload-lb-after-amp-assoc-full-graph' RELOAD_LB_AFTER_PLUG_VIP = 'reload-lb-after-plug-vip' RELOAD_LB_BEFOR_ALLOCATE_VIP = 'reload-lb-before-allocate-vip' +SET_AMPHORA_FIREWALL_RULES = 'set-amphora-firewall-rules' UPDATE_AMP_FAILOVER_DETAILS = 'update-amp-failover-details' @@ -974,6 +980,7 @@ NFT_ADD = 'add' NFT_CMD = '/usr/sbin/nft' NFT_FAMILY = 'inet' NFT_VIP_RULES_FILE = '/var/lib/octavia/nftables-vip.rules' -NFT_VIP_TABLE = 'amphora-vip' -NFT_VIP_CHAIN = 'amphora-vip-chain' +NFT_VIP_TABLE = 'amphora_vip' +NFT_VIP_CHAIN = 'amphora_vip_chain' NFT_SRIOV_PRIORITY = '-310' +PROTOCOL = 'protocol' diff --git a/octavia/common/utils.py b/octavia/common/utils.py index 284e70ed1e..a81ee5284d 100644 --- a/octavia/common/utils.py +++ b/octavia/common/utils.py @@ -191,3 +191,9 @@ class exception_logger(object): self.logger(e) return None return call + + +def map_protocol_to_nftable_protocol(rule_dict): + rule_dict[constants.PROTOCOL] = ( + constants.L4_PROTOCOL_MAP[rule_dict[constants.PROTOCOL]]) + return rule_dict diff --git a/octavia/controller/worker/v2/controller_worker.py b/octavia/controller/worker/v2/controller_worker.py index 97c8937f7c..575f1c38c8 100644 --- a/octavia/controller/worker/v2/controller_worker.py +++ b/octavia/controller/worker/v2/controller_worker.py @@ -270,6 +270,14 @@ class ControllerWorker(object): raise db_exceptions.NoResultFound load_balancer = db_listener.load_balancer + flavor_dict = {} + if load_balancer.flavor_id: + with session.begin(): + flavor_dict = ( + self._flavor_repo.get_flavor_metadata_dict( + session, load_balancer.flavor_id)) + flavor_dict[constants.LOADBALANCER_TOPOLOGY] = load_balancer.topology + provider_lb = provider_utils.db_loadbalancer_to_provider_loadbalancer( load_balancer).to_dict(recurse=True) @@ -279,7 +287,7 @@ class ControllerWorker(object): self.run_flow( flow_utils.get_create_listener_flow, - store=store) + flavor_dict=flavor_dict, store=store) def delete_listener(self, listener): """Deletes a listener. @@ -288,12 +296,32 @@ class ControllerWorker(object): :returns: None :raises ListenerNotFound: The referenced listener was not found """ + try: + db_lb = self._get_db_obj_until_pending_update( + self._lb_repo, listener[constants.LOADBALANCER_ID]) + except tenacity.RetryError as e: + LOG.warning('Loadbalancer did not go into %s in 60 seconds. ' + 'This either due to an in-progress Octavia upgrade ' + 'or an overloaded and failing database. Assuming ' + 'an upgrade is in progress and continuing.', + constants.PENDING_UPDATE) + db_lb = e.last_attempt.result() + + flavor_dict = {} + if db_lb.flavor_id: + session = db_apis.get_session() + with session.begin(): + flavor_dict = ( + self._flavor_repo.get_flavor_metadata_dict( + session, db_lb.flavor_id)) + flavor_dict[constants.LOADBALANCER_TOPOLOGY] = db_lb.topology + store = {constants.LISTENER: listener, constants.LOADBALANCER_ID: listener[constants.LOADBALANCER_ID], constants.PROJECT_ID: listener[constants.PROJECT_ID]} self.run_flow( - flow_utils.get_delete_listener_flow, + flow_utils.get_delete_listener_flow, flavor_dict=flavor_dict, store=store) def update_listener(self, listener, listener_updates): @@ -315,12 +343,21 @@ class ControllerWorker(object): constants.PENDING_UPDATE) db_lb = e.last_attempt.result() + session = db_apis.get_session() + flavor_dict = {} + if db_lb.flavor_id: + with session.begin(): + flavor_dict = ( + self._flavor_repo.get_flavor_metadata_dict( + session, db_lb.flavor_id)) + flavor_dict[constants.LOADBALANCER_TOPOLOGY] = db_lb.topology + store = {constants.LISTENER: listener, constants.UPDATE_DICT: listener_updates, constants.LOADBALANCER_ID: db_lb.id, constants.LISTENERS: [listener]} self.run_flow( - flow_utils.get_update_listener_flow, + flow_utils.get_update_listener_flow, flavor_dict=flavor_dict, store=store) @tenacity.retry( @@ -998,16 +1035,14 @@ class ControllerWorker(object): lb_id = loadbalancer.id # Even if the LB doesn't have a flavor, create one and # pass through the topology. + flavor_dict = {} if loadbalancer.flavor_id: with session.begin(): flavor_dict = ( self._flavor_repo.get_flavor_metadata_dict( session, loadbalancer.flavor_id)) - flavor_dict[constants.LOADBALANCER_TOPOLOGY] = ( - loadbalancer.topology) - else: - flavor_dict = {constants.LOADBALANCER_TOPOLOGY: - loadbalancer.topology} + flavor_dict[constants.LOADBALANCER_TOPOLOGY] = ( + loadbalancer.topology) if loadbalancer.availability_zone: with session.begin(): az_metadata = ( @@ -1162,13 +1197,12 @@ class ControllerWorker(object): # We must provide a topology in the flavor definition # here for the amphora to be created with the correct # configuration. + flavor = {} if lb.flavor_id: with session.begin(): flavor = self._flavor_repo.get_flavor_metadata_dict( session, lb.flavor_id) - flavor[constants.LOADBALANCER_TOPOLOGY] = lb.topology - else: - flavor = {constants.LOADBALANCER_TOPOLOGY: lb.topology} + flavor[constants.LOADBALANCER_TOPOLOGY] = lb.topology if lb: provider_lb_dict = ( diff --git a/octavia/controller/worker/v2/flows/amphora_flows.py b/octavia/controller/worker/v2/flows/amphora_flows.py index 1e8149fbf4..8b8902da33 100644 --- a/octavia/controller/worker/v2/flows/amphora_flows.py +++ b/octavia/controller/worker/v2/flows/amphora_flows.py @@ -28,6 +28,7 @@ from octavia.controller.worker.v2.tasks import database_tasks from octavia.controller.worker.v2.tasks import lifecycle_tasks from octavia.controller.worker.v2.tasks import network_tasks from octavia.controller.worker.v2.tasks import retry_tasks +from octavia.controller.worker.v2.tasks import shim_tasks CONF = cfg.CONF LOG = logging.getLogger(__name__) @@ -227,7 +228,7 @@ class AmphoraFlows(object): def get_vrrp_subflow(self, prefix, timeout_dict=None, create_vrrp_group=True, - get_amphorae_status=True): + get_amphorae_status=True, flavor_dict=None): sf_name = prefix + '-' + constants.GET_VRRP_SUBFLOW vrrp_subflow = linear_flow.Flow(sf_name) @@ -259,7 +260,7 @@ class AmphoraFlows(object): # unordered subflow. update_amps_subflow = unordered_flow.Flow('VRRP-update-subflow') - # We have three tasks to run in order, per amphora + # We have tasks to run in order, per amphora amp_0_subflow = linear_flow.Flow('VRRP-amp-0-update-subflow') amp_0_subflow.add(amphora_driver_tasks.AmphoraIndexUpdateVRRPInterface( @@ -279,6 +280,20 @@ class AmphoraFlows(object): inject={constants.AMPHORA_INDEX: 0, constants.TIMEOUT_DICT: timeout_dict})) + if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False): + amp_0_subflow.add(database_tasks.GetAmphoraFirewallRules( + name=sf_name + '-0-' + constants.GET_AMPHORA_FIREWALL_RULES, + requires=(constants.AMPHORAE, + constants.AMPHORAE_NETWORK_CONFIG), + provides=constants.AMPHORA_FIREWALL_RULES, + inject={constants.AMPHORA_INDEX: 0})) + + amp_0_subflow.add(amphora_driver_tasks.SetAmphoraFirewallRules( + name=sf_name + '-0-' + constants.SET_AMPHORA_FIREWALL_RULES, + requires=(constants.AMPHORAE, + constants.AMPHORA_FIREWALL_RULES), + inject={constants.AMPHORA_INDEX: 0})) + amp_0_subflow.add(amphora_driver_tasks.AmphoraIndexVRRPStart( name=sf_name + '-0-' + constants.AMP_VRRP_START, requires=(constants.AMPHORAE, constants.AMPHORAE_STATUS), @@ -304,6 +319,21 @@ class AmphoraFlows(object): rebind={constants.NEW_AMPHORA_ID: constants.AMPHORA_ID}, inject={constants.AMPHORA_INDEX: 1, constants.TIMEOUT_DICT: timeout_dict})) + + if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False): + amp_1_subflow.add(database_tasks.GetAmphoraFirewallRules( + name=sf_name + '-1-' + constants.GET_AMPHORA_FIREWALL_RULES, + requires=(constants.AMPHORAE, + constants.AMPHORAE_NETWORK_CONFIG), + provides=constants.AMPHORA_FIREWALL_RULES, + inject={constants.AMPHORA_INDEX: 1})) + + amp_1_subflow.add(amphora_driver_tasks.SetAmphoraFirewallRules( + name=sf_name + '-1-' + constants.SET_AMPHORA_FIREWALL_RULES, + requires=(constants.AMPHORAE, + constants.AMPHORA_FIREWALL_RULES), + inject={constants.AMPHORA_INDEX: 1})) + amp_1_subflow.add(amphora_driver_tasks.AmphoraIndexVRRPStart( name=sf_name + '-1-' + constants.AMP_VRRP_START, requires=(constants.AMPHORAE, constants.AMPHORAE_STATUS), @@ -443,6 +473,27 @@ class AmphoraFlows(object): requires=(constants.AMPHORA, constants.LOADBALANCER, constants.AMPHORAE_NETWORK_CONFIG))) + if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False): + amp_for_failover_flow.add( + shim_tasks.AmphoraToAmphoraeWithVRRPIP( + name=prefix + '-' + constants.AMPHORA_TO_AMPHORAE_VRRP_IP, + requires=(constants.AMPHORA, constants.BASE_PORT), + provides=constants.NEW_AMPHORAE)) + amp_for_failover_flow.add(database_tasks.GetAmphoraFirewallRules( + name=prefix + '-' + constants.GET_AMPHORA_FIREWALL_RULES, + requires=(constants.AMPHORAE, + constants.AMPHORAE_NETWORK_CONFIG), + rebind={constants.AMPHORAE: constants.NEW_AMPHORAE}, + provides=constants.AMPHORA_FIREWALL_RULES, + inject={constants.AMPHORA_INDEX: 0})) + amp_for_failover_flow.add( + amphora_driver_tasks.SetAmphoraFirewallRules( + name=prefix + '-' + constants.SET_AMPHORA_FIREWALL_RULES, + requires=(constants.AMPHORAE, + constants.AMPHORA_FIREWALL_RULES), + rebind={constants.AMPHORAE: constants.NEW_AMPHORAE}, + inject={constants.AMPHORA_INDEX: 0})) + # Plug member ports amp_for_failover_flow.add(network_tasks.CalculateAmphoraDelta( name=prefix + '-' + constants.CALCULATE_AMPHORA_DELTA, @@ -601,7 +652,8 @@ class AmphoraFlows(object): failover_amp_flow.add( self.get_vrrp_subflow(constants.GET_VRRP_SUBFLOW, timeout_dict, create_vrrp_group=False, - get_amphorae_status=False)) + get_amphorae_status=False, + flavor_dict=flavor_dict)) # Reload the listener. This needs to be done here because # it will create the required haproxy check scripts for diff --git a/octavia/controller/worker/v2/flows/flow_utils.py b/octavia/controller/worker/v2/flows/flow_utils.py index 97d1c41754..bd5d56b3e0 100644 --- a/octavia/controller/worker/v2/flows/flow_utils.py +++ b/octavia/controller/worker/v2/flows/flow_utils.py @@ -139,20 +139,21 @@ def get_update_l7rule_flow(): return L7_RULES_FLOWS.get_update_l7rule_flow() -def get_create_listener_flow(): - return LISTENER_FLOWS.get_create_listener_flow() +def get_create_listener_flow(flavor_dict=None): + return LISTENER_FLOWS.get_create_listener_flow(flavor_dict=flavor_dict) -def get_create_all_listeners_flow(): - return LISTENER_FLOWS.get_create_all_listeners_flow() +def get_create_all_listeners_flow(flavor_dict=None): + return LISTENER_FLOWS.get_create_all_listeners_flow( + flavor_dict=flavor_dict) -def get_delete_listener_flow(): - return LISTENER_FLOWS.get_delete_listener_flow() +def get_delete_listener_flow(flavor_dict=None): + return LISTENER_FLOWS.get_delete_listener_flow(flavor_dict=flavor_dict) -def get_update_listener_flow(): - return LISTENER_FLOWS.get_update_listener_flow() +def get_update_listener_flow(flavor_dict=None): + return LISTENER_FLOWS.get_update_listener_flow(flavor_dict=flavor_dict) def get_create_member_flow(): diff --git a/octavia/controller/worker/v2/flows/listener_flows.py b/octavia/controller/worker/v2/flows/listener_flows.py index cc80d8484b..70af587e7a 100644 --- a/octavia/controller/worker/v2/flows/listener_flows.py +++ b/octavia/controller/worker/v2/flows/listener_flows.py @@ -14,6 +14,7 @@ # from taskflow.patterns import linear_flow +from taskflow.patterns import unordered_flow from octavia.common import constants from octavia.controller.worker.v2.tasks import amphora_driver_tasks @@ -24,7 +25,7 @@ from octavia.controller.worker.v2.tasks import network_tasks class ListenerFlows(object): - def get_create_listener_flow(self): + def get_create_listener_flow(self, flavor_dict=None): """Create a flow to create a listener :returns: The flow for creating a listener @@ -36,13 +37,18 @@ class ListenerFlows(object): requires=constants.LOADBALANCER_ID)) create_listener_flow.add(network_tasks.UpdateVIP( requires=constants.LISTENERS)) + + if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False): + create_listener_flow.add(*self._get_firewall_rules_subflow( + flavor_dict)) + create_listener_flow.add(database_tasks. MarkLBAndListenersActiveInDB( requires=(constants.LOADBALANCER_ID, constants.LISTENERS))) return create_listener_flow - def get_create_all_listeners_flow(self): + def get_create_all_listeners_flow(self, flavor_dict=None): """Create a flow to create all listeners :returns: The flow for creating all listeners @@ -60,12 +66,17 @@ class ListenerFlows(object): requires=constants.LOADBALANCER_ID)) create_all_listeners_flow.add(network_tasks.UpdateVIP( requires=constants.LISTENERS)) + + if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False): + create_all_listeners_flow.add(*self._get_firewall_rules_subflow( + flavor_dict)) + create_all_listeners_flow.add( database_tasks.MarkHealthMonitorsOnlineInDB( requires=constants.LOADBALANCER)) return create_all_listeners_flow - def get_delete_listener_flow(self): + def get_delete_listener_flow(self, flavor_dict=None): """Create a flow to delete a listener :returns: The flow for deleting a listener @@ -79,6 +90,11 @@ class ListenerFlows(object): requires=constants.LOADBALANCER_ID)) delete_listener_flow.add(database_tasks.DeleteListenerInDB( requires=constants.LISTENER)) + + if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False): + delete_listener_flow.add(*self._get_firewall_rules_subflow( + flavor_dict)) + delete_listener_flow.add(database_tasks.DecrementListenerQuota( requires=constants.PROJECT_ID)) delete_listener_flow.add(database_tasks.MarkLBActiveInDBByListener( @@ -86,7 +102,7 @@ class ListenerFlows(object): return delete_listener_flow - def get_delete_listener_internal_flow(self, listener): + def get_delete_listener_internal_flow(self, listener, flavor_dict=None): """Create a flow to delete a listener and l7policies internally (will skip deletion on the amp and marking LB active) @@ -104,13 +120,22 @@ class ListenerFlows(object): name='delete_listener_in_db_' + listener_id, requires=constants.LISTENER, inject={constants.LISTENER: listener})) + + # Currently the flavor_dict will always be None since there is + # no point updating the firewall rules when deleting the LB. + # However, this may be used for additional flows in the future, so + # adding this code for completeness. + if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False): + delete_listener_flow.add(*self._get_firewall_rules_subflow( + flavor_dict)) + delete_listener_flow.add(database_tasks.DecrementListenerQuota( name='decrement_listener_quota_' + listener_id, requires=constants.PROJECT_ID)) return delete_listener_flow - def get_update_listener_flow(self): + def get_update_listener_flow(self, flavor_dict=None): """Create a flow to update a listener :returns: The flow for updating a listener @@ -122,6 +147,11 @@ class ListenerFlows(object): requires=constants.LOADBALANCER_ID)) update_listener_flow.add(network_tasks.UpdateVIP( requires=constants.LISTENERS)) + + if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False): + update_listener_flow.add(*self._get_firewall_rules_subflow( + flavor_dict)) + update_listener_flow.add(database_tasks.UpdateListenerInDB( requires=[constants.LISTENER, constants.UPDATE_DICT])) update_listener_flow.add(database_tasks. @@ -130,3 +160,63 @@ class ListenerFlows(object): constants.LISTENERS))) return update_listener_flow + + def _get_firewall_rules_subflow(self, flavor_dict): + """Creates a subflow that updates the firewall rules in the amphorae. + + :returns: The subflow for updating firewall rules in the amphorae. + """ + sf_name = constants.FIREWALL_RULES_SUBFLOW + fw_rules_subflow = linear_flow.Flow(sf_name) + + fw_rules_subflow.add(database_tasks.GetAmphoraeFromLoadbalancer( + name=sf_name + '-' + constants.GET_AMPHORAE_FROM_LB, + requires=constants.LOADBALANCER_ID, + provides=constants.AMPHORAE)) + + fw_rules_subflow.add(network_tasks.GetAmphoraeNetworkConfigs( + name=sf_name + '-' + constants.GET_AMP_NETWORK_CONFIG, + requires=constants.LOADBALANCER_ID, + provides=constants.AMPHORAE_NETWORK_CONFIG)) + + update_amps_subflow = unordered_flow.Flow( + constants.AMP_UPDATE_FW_SUBFLOW) + + amp_0_subflow = linear_flow.Flow('amp-0-fw-update') + + amp_0_subflow.add(database_tasks.GetAmphoraFirewallRules( + name=sf_name + '-0-' + constants.GET_AMPHORA_FIREWALL_RULES, + requires=(constants.AMPHORAE, constants.AMPHORAE_NETWORK_CONFIG), + provides=constants.AMPHORA_FIREWALL_RULES, + inject={constants.AMPHORA_INDEX: 0})) + + amp_0_subflow.add(amphora_driver_tasks.SetAmphoraFirewallRules( + name=sf_name + '-0-' + constants.SET_AMPHORA_FIREWALL_RULES, + requires=(constants.AMPHORAE, constants.AMPHORA_FIREWALL_RULES), + inject={constants.AMPHORA_INDEX: 0})) + + update_amps_subflow.add(amp_0_subflow) + + if (flavor_dict[constants.LOADBALANCER_TOPOLOGY] == + constants.TOPOLOGY_ACTIVE_STANDBY): + + amp_1_subflow = linear_flow.Flow('amp-1-fw-update') + + amp_1_subflow.add(database_tasks.GetAmphoraFirewallRules( + name=sf_name + '-1-' + constants.GET_AMPHORA_FIREWALL_RULES, + requires=(constants.AMPHORAE, + constants.AMPHORAE_NETWORK_CONFIG), + provides=constants.AMPHORA_FIREWALL_RULES, + inject={constants.AMPHORA_INDEX: 1})) + + amp_1_subflow.add(amphora_driver_tasks.SetAmphoraFirewallRules( + name=sf_name + '-1-' + constants.SET_AMPHORA_FIREWALL_RULES, + requires=(constants.AMPHORAE, + constants.AMPHORA_FIREWALL_RULES), + inject={constants.AMPHORA_INDEX: 1})) + + update_amps_subflow.add(amp_1_subflow) + + fw_rules_subflow.add(update_amps_subflow) + + return fw_rules_subflow diff --git a/octavia/controller/worker/v2/flows/load_balancer_flows.py b/octavia/controller/worker/v2/flows/load_balancer_flows.py index 29984ed9c8..cc857f6daf 100644 --- a/octavia/controller/worker/v2/flows/load_balancer_flows.py +++ b/octavia/controller/worker/v2/flows/load_balancer_flows.py @@ -93,10 +93,12 @@ class LoadBalancerFlows(object): post_amp_prefix = constants.POST_LB_AMP_ASSOCIATION_SUBFLOW lb_create_flow.add( - self.get_post_lb_amp_association_flow(post_amp_prefix, topology)) + self.get_post_lb_amp_association_flow(post_amp_prefix, topology, + flavor_dict=flavor_dict)) if listeners: - lb_create_flow.add(*self._create_listeners_flow()) + lb_create_flow.add( + *self._create_listeners_flow(flavor_dict=flavor_dict)) lb_create_flow.add( database_tasks.MarkLBActiveInDB( @@ -177,6 +179,7 @@ class LoadBalancerFlows(object): def _get_amp_net_subflow(self, sf_name, flavor_dict=None): flows = [] + # If we have an SRIOV VIP, we need to setup a firewall in the amp if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False): flows.append(network_tasks.CreateSRIOVBasePort( name=sf_name + '-' + constants.PLUG_VIP_AMPHORA, @@ -192,7 +195,25 @@ class LoadBalancerFlows(object): requires=(constants.LOADBALANCER, constants.AMPHORA, constants.PORT_DATA), provides=constants.AMP_DATA)) - # TODO(johnsom) nftables need to be handled here in the SG patch + flows.append(network_tasks.ApplyQosAmphora( + name=sf_name + '-' + constants.APPLY_QOS_AMP, + requires=(constants.LOADBALANCER, constants.AMP_DATA, + constants.UPDATE_DICT))) + flows.append(database_tasks.UpdateAmphoraVIPData( + name=sf_name + '-' + constants.UPDATE_AMPHORA_VIP_DATA, + requires=constants.AMP_DATA)) + flows.append(network_tasks.GetAmphoraNetworkConfigs( + name=sf_name + '-' + constants.GET_AMP_NETWORK_CONFIG, + requires=(constants.LOADBALANCER, constants.AMPHORA), + provides=constants.AMPHORA_NETWORK_CONFIG)) + # SR-IOV firewall rules are handled in AmphoraPostVIPPlug + # interface.py up + flows.append(amphora_driver_tasks.AmphoraPostVIPPlug( + name=sf_name + '-' + constants.AMP_POST_VIP_PLUG, + rebind={constants.AMPHORAE_NETWORK_CONFIG: + constants.AMPHORA_NETWORK_CONFIG}, + requires=(constants.LOADBALANCER, + constants.AMPHORAE_NETWORK_CONFIG))) else: flows.append(network_tasks.PlugVIPAmphora( name=sf_name + '-' + constants.PLUG_VIP_AMPHORA, @@ -219,7 +240,7 @@ class LoadBalancerFlows(object): constants.AMPHORAE_NETWORK_CONFIG))) return flows - def _create_listeners_flow(self): + def _create_listeners_flow(self, flavor_dict=None): flows = [] flows.append( database_tasks.ReloadLoadBalancer( @@ -252,11 +273,13 @@ class LoadBalancerFlows(object): ) ) flows.append( - self.listener_flows.get_create_all_listeners_flow() + self.listener_flows.get_create_all_listeners_flow( + flavor_dict=flavor_dict) ) return flows - def get_post_lb_amp_association_flow(self, prefix, topology): + def get_post_lb_amp_association_flow(self, prefix, topology, + flavor_dict=None): """Reload the loadbalancer and create networking subflows for created/allocated amphorae. @@ -274,14 +297,15 @@ class LoadBalancerFlows(object): post_create_LB_flow.add(database_tasks.GetAmphoraeFromLoadbalancer( requires=constants.LOADBALANCER_ID, provides=constants.AMPHORAE)) - vrrp_subflow = self.amp_flows.get_vrrp_subflow(prefix) + vrrp_subflow = self.amp_flows.get_vrrp_subflow( + prefix, flavor_dict=flavor_dict) post_create_LB_flow.add(vrrp_subflow) post_create_LB_flow.add(database_tasks.UpdateLoadbalancerInDB( requires=[constants.LOADBALANCER, constants.UPDATE_DICT])) return post_create_LB_flow - def _get_delete_listeners_flow(self, listeners): + def _get_delete_listeners_flow(self, listeners, flavor_dict=None): """Sets up an internal delete flow :param listeners: A list of listener dicts @@ -291,7 +315,7 @@ class LoadBalancerFlows(object): for listener in listeners: listeners_delete_flow.add( self.listener_flows.get_delete_listener_internal_flow( - listener)) + listener, flavor_dict=flavor_dict)) return listeners_delete_flow def get_delete_load_balancer_flow(self, lb): @@ -705,7 +729,7 @@ class LoadBalancerFlows(object): failover_LB_flow.add(self.amp_flows.get_vrrp_subflow( new_amp_role + '-' + constants.GET_VRRP_SUBFLOW, timeout_dict, create_vrrp_group=False, - get_amphorae_status=False)) + get_amphorae_status=False, flavor_dict=lb[constants.FLAVOR])) # #### End of standby #### diff --git a/octavia/controller/worker/v2/tasks/amphora_driver_tasks.py b/octavia/controller/worker/v2/tasks/amphora_driver_tasks.py index 62d6805162..ce820deb23 100644 --- a/octavia/controller/worker/v2/tasks/amphora_driver_tasks.py +++ b/octavia/controller/worker/v2/tasks/amphora_driver_tasks.py @@ -760,3 +760,26 @@ class AmphoraeGetConnectivityStatus(BaseAmphoraTask): amphorae_status[amphora_id][constants.UNREACHABLE] = False return amphorae_status + + +class SetAmphoraFirewallRules(BaseAmphoraTask): + """Task to push updated firewall ruls to an amphora.""" + + def execute(self, amphorae: List[dict], amphora_index: int, + amphora_firewall_rules: List[dict]): + + if (amphora_firewall_rules and + amphora_firewall_rules[0].get('non-sriov-vip', False)): + # Not an SRIOV VIP, so skip setting firewall rules. + # This is already logged in GetAmphoraFirewallRules. + return + + session = db_apis.get_session() + with session.begin(): + db_amp = self.amphora_repo.get( + session, id=amphorae[amphora_index][constants.ID]) + + self.amphora_driver.set_interface_rules( + db_amp, + amphorae[amphora_index][constants.VRRP_IP], + amphora_firewall_rules) diff --git a/octavia/controller/worker/v2/tasks/database_tasks.py b/octavia/controller/worker/v2/tasks/database_tasks.py index 720be4f4ff..e78ef41d7f 100644 --- a/octavia/controller/worker/v2/tasks/database_tasks.py +++ b/octavia/controller/worker/v2/tasks/database_tasks.py @@ -14,6 +14,7 @@ # from cryptography import fernet +from octavia_lib.common import constants as lib_consts from oslo_config import cfg from oslo_db import exception as odb_exceptions from oslo_log import log as logging @@ -27,6 +28,7 @@ from taskflow.types import failure from octavia.api.drivers import utils as provider_utils from octavia.common import constants from octavia.common import data_models +from octavia.common import exceptions from octavia.common.tls_utils import cert_parser from octavia.common import utils from octavia.controller.worker import task_utils as task_utilities @@ -3073,3 +3075,47 @@ class UpdatePoolMembersOperatingStatusInDB(BaseDatabaseTask): with db_apis.session().begin() as session: self.member_repo.update_pool_members( session, pool_id, operating_status=operating_status) + + +class GetAmphoraFirewallRules(BaseDatabaseTask): + """Task to build firewall rules for the amphora.""" + + def execute(self, amphorae, amphora_index, amphorae_network_config): + this_amp_id = amphorae[amphora_index][constants.ID] + amp_net_config = amphorae_network_config[this_amp_id] + + lb_dict = amp_net_config[constants.AMPHORA]['load_balancer'] + vip_dict = lb_dict[constants.VIP] + + if vip_dict[constants.VNIC_TYPE] != constants.VNIC_TYPE_DIRECT: + LOG.debug('Load balancer VIP port is not SR-IOV enabled. Skipping ' + 'firewall rules update.') + return [{'non-sriov-vip': True}] + + session = db_apis.get_session() + with session.begin(): + rules = self.listener_repo.get_port_protocol_cidr_for_lb( + session, + amp_net_config[constants.AMPHORA][constants.LOAD_BALANCER_ID]) + + # If we are act/stdby, inject the VRRP firewall rule(s) + if lb_dict[constants.TOPOLOGY] == constants.TOPOLOGY_ACTIVE_STANDBY: + for amp_cfg in lb_dict[constants.AMPHORAE]: + if (amp_cfg[constants.ID] != this_amp_id and + amp_cfg[constants.STATUS] == + lib_consts.AMPHORA_ALLOCATED): + vrrp_ip = amp_cfg[constants.VRRP_IP] + vrrp_ip_ver = utils.ip_version(vrrp_ip) + + if vrrp_ip_ver == 4: + vrrp_ip_cidr = f'{vrrp_ip}/32' + elif vrrp_ip_ver == 6: + vrrp_ip_cidr = f'{vrrp_ip}/128' + else: + raise exceptions.InvalidIPAddress(ip_addr=vrrp_ip) + + rules.append({constants.PROTOCOL: constants.VRRP, + constants.CIDR: vrrp_ip_cidr, + constants.PORT: 112}) + LOG.debug('Amphora %s SR-IOV firewall rules: %s', this_amp_id, rules) + return rules diff --git a/octavia/controller/worker/v2/tasks/shim_tasks.py b/octavia/controller/worker/v2/tasks/shim_tasks.py new file mode 100644 index 0000000000..b6b587fd71 --- /dev/null +++ b/octavia/controller/worker/v2/tasks/shim_tasks.py @@ -0,0 +1,28 @@ +# Copyright 2024 Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +from taskflow import task + +from octavia.common import constants + + +class AmphoraToAmphoraeWithVRRPIP(task.Task): + """A shim class to convert a single Amphora instance to a list.""" + + def execute(self, amphora: dict, base_port: dict): + # The VRRP_IP has not been stamped on the Amphora at this point in the + # flow, so inject it from our port create call in a previous task. + amphora[constants.VRRP_IP] = ( + base_port[constants.FIXED_IPS][0][constants.IP_ADDRESS]) + return [amphora] diff --git a/octavia/db/repositories.py b/octavia/db/repositories.py index a53b533b37..b0ed8a6ab0 100644 --- a/octavia/db/repositories.py +++ b/octavia/db/repositories.py @@ -39,6 +39,7 @@ from sqlalchemy import update from octavia.common import constants as consts from octavia.common import data_models from octavia.common import exceptions +from octavia.common import utils from octavia.common import validate from octavia.db import api as db_api from octavia.db import models @@ -1085,6 +1086,23 @@ class ListenerRepository(BaseRepository): update({self.model_class.provisioning_status: consts.ACTIVE}, synchronize_session='fetch')) + def get_port_protocol_cidr_for_lb(self, session, loadbalancer_id): + # readability variables + Listener = self.model_class + ListenerCidr = models.ListenerCidr + + stmt = (select(Listener.protocol, + ListenerCidr.cidr, + Listener.protocol_port.label(consts.PORT)) + .select_from(Listener) + .join(models.ListenerCidr, + Listener.id == ListenerCidr.listener_id, isouter=True) + .where(Listener.load_balancer_id == loadbalancer_id)) + rows = session.execute(stmt) + + return [utils.map_protocol_to_nftable_protocol(u._asdict()) for u + in rows.all()] + class ListenerStatisticsRepository(BaseRepository): model_class = models.ListenerStatistics diff --git a/octavia/tests/functional/amphorae/backend/agent/api_server/test_server.py b/octavia/tests/functional/amphorae/backend/agent/api_server/test_server.py index 95c78f3bb7..23f63b62e8 100644 --- a/octavia/tests/functional/amphorae/backend/agent/api_server/test_server.py +++ b/octavia/tests/functional/amphorae/backend/agent/api_server/test_server.py @@ -24,6 +24,7 @@ from oslo_config import fixture as oslo_fixture from oslo_serialization import jsonutils from oslo_utils.secretutils import md5 from oslo_utils import uuidutils +import webob from octavia.amphorae.backends.agent import api_server from octavia.amphorae.backends.agent.api_server import certificate_update @@ -3055,3 +3056,39 @@ class TestServerTestCase(base.TestCase): self.assertEqual(200, rv.status_code) self.assertEqual(expected_dict, jsonutils.loads(rv.data.decode('utf-8'))) + + @mock.patch('octavia.amphorae.backends.utils.nftable_utils.' + 'load_nftables_file') + @mock.patch('octavia.amphorae.backends.utils.nftable_utils.' + 'write_nftable_vip_rules_file') + @mock.patch('octavia.amphorae.backends.agent.api_server.amphora_info.' + 'AmphoraInfo.get_interface') + def test_set_interface_rules(self, mock_get_int, mock_write_rules, + mock_load_rules): + mock_get_int.side_effect = [ + webob.Response(status=400), + webob.Response(status=200, json={'interface': 'fake1'}), + webob.Response(status=200, json={'interface': 'fake1'})] + + # Test can't find interface + rv = self.ubuntu_app.put('/' + api_server.VERSION + + '/interface/192.0.2.10/rules', data='fake') + self.assertEqual(400, rv.status_code) + mock_write_rules.assert_not_called() + + # Test schema validation failure + rv = self.ubuntu_app.put('/' + api_server.VERSION + + '/interface/192.0.2.10/rules', data='fake') + self.assertEqual('400 Bad Request', rv.status) + + # Test successful path + rules_json = ('[{"protocol":"TCP","cidr":"192.0.2.0/24","port":8080},' + '{"protocol":"UDP","cidr":null,"port":80}]') + rv = self.ubuntu_app.put('/' + api_server.VERSION + + '/interface/192.0.2.10/rules', + data=rules_json, + content_type='application/json') + self.assertEqual('200 OK', rv.status) + mock_write_rules.assert_called_once_with('fake1', + jsonutils.loads(rules_json)) + mock_load_rules.assert_called_once() diff --git a/octavia/tests/functional/db/test_repositories.py b/octavia/tests/functional/db/test_repositories.py index 080b37cdb0..15d3063a59 100644 --- a/octavia/tests/functional/db/test_repositories.py +++ b/octavia/tests/functional/db/test_repositories.py @@ -2762,6 +2762,14 @@ class TestListenerRepositoryTest(BaseRepositoryTest): self.assertEqual(constants.PENDING_UPDATE, new_listener.provisioning_status) + def test_get_port_protocol_cidr_for_lb(self): + self.create_listener(self.FAKE_UUID_1, 80, + provisioning_status=constants.ACTIVE) + rules = self.listener_repo.get_port_protocol_cidr_for_lb( + self.session, self.FAKE_UUID_1) + self.assertEqual([{'protocol': 'TCP', 'cidr': None, 'port': 80}], + rules) + class ListenerStatisticsRepositoryTest(BaseRepositoryTest): diff --git a/octavia/tests/unit/amphorae/backends/utils/test_interface.py b/octavia/tests/unit/amphorae/backends/utils/test_interface.py index 831826e65b..8e230e7e4e 100644 --- a/octavia/tests/unit/amphorae/backends/utils/test_interface.py +++ b/octavia/tests/unit/amphorae/backends/utils/test_interface.py @@ -15,6 +15,7 @@ import errno import os import socket +import subprocess from unittest import mock import pyroute2 @@ -448,6 +449,8 @@ class TestInterface(base.TestCase): mock.call(["post-up", "eth1"]) ]) + @mock.patch('octavia.amphorae.backends.utils.network_namespace.' + 'NetworkNamespace') @mock.patch('octavia.amphorae.backends.utils.nftable_utils.' 'write_nftable_vip_rules_file') @mock.patch('pyroute2.IPRoute.rule') @@ -459,7 +462,7 @@ class TestInterface(base.TestCase): @mock.patch('subprocess.check_output') def test_up_sriov(self, mock_check_output, mock_link_lookup, mock_get_links, mock_link, mock_addr, mock_route, - mock_rule, mock_nftable): + mock_rule, mock_nftable, mock_netns): iface = interface_file.InterfaceFile( name="fake-eth1", if_type="vip", @@ -1441,3 +1444,56 @@ class TestInterface(base.TestCase): addr = controller._normalize_ip_network(None) self.assertIsNone(addr) + + @mock.patch('octavia.amphorae.backends.utils.nftable_utils.' + 'load_nftables_file') + @mock.patch('octavia.amphorae.backends.utils.nftable_utils.' + 'write_nftable_vip_rules_file') + @mock.patch('subprocess.check_output') + def test__setup_nftables_chain(self, mock_check_output, mock_write_rules, + mock_load_rules): + + controller = interface.InterfaceController() + + mock_check_output.side_effect = [ + mock.DEFAULT, mock.DEFAULT, + subprocess.CalledProcessError(cmd=consts.NFT_CMD, returncode=-1), + mock.DEFAULT, + subprocess.CalledProcessError(cmd=consts.NFT_CMD, returncode=-1)] + + interface_mock = mock.MagicMock() + interface_mock.name = 'fake2' + + # Test succeessful path + controller._setup_nftables_chain(interface_mock) + + mock_write_rules.assert_called_once_with('fake2', []) + mock_load_rules.assert_called_once_with() + mock_check_output.assert_has_calls([ + mock.call([consts.NFT_CMD, 'add', 'table', consts.NFT_FAMILY, + consts.NFT_VIP_TABLE], stderr=subprocess.STDOUT), + mock.call([consts.NFT_CMD, 'add', 'chain', consts.NFT_FAMILY, + consts.NFT_VIP_TABLE, consts.NFT_VIP_CHAIN, '{', + 'type', 'filter', 'hook', 'ingress', 'device', + 'fake2', 'priority', consts.NFT_SRIOV_PRIORITY, ';', + 'policy', 'drop', ';', '}'], stderr=subprocess.STDOUT)]) + + # Test first nft call fails + mock_write_rules.reset_mock() + mock_load_rules.reset_mock() + mock_check_output.reset_mock() + + self.assertRaises(subprocess.CalledProcessError, + controller._setup_nftables_chain, interface_mock) + mock_check_output.assert_called_once() + mock_write_rules.assert_not_called() + + # Test second nft call fails + mock_write_rules.reset_mock() + mock_load_rules.reset_mock() + mock_check_output.reset_mock() + + self.assertRaises(subprocess.CalledProcessError, + controller._setup_nftables_chain, interface_mock) + self.assertEqual(2, mock_check_output.call_count) + mock_write_rules.assert_not_called() diff --git a/octavia/tests/unit/amphorae/backends/utils/test_nftable_utils.py b/octavia/tests/unit/amphorae/backends/utils/test_nftable_utils.py index f4fdaecd6f..141a9cbe98 100644 --- a/octavia/tests/unit/amphorae/backends/utils/test_nftable_utils.py +++ b/octavia/tests/unit/amphorae/backends/utils/test_nftable_utils.py @@ -13,10 +13,15 @@ # under the License. import os import stat +import subprocess from unittest import mock +from octavia_lib.common import constants as lib_consts +from webob import exc + from octavia.amphorae.backends.utils import nftable_utils from octavia.common import constants as consts +from octavia.common import exceptions import octavia.tests.unit.base as base @@ -47,10 +52,17 @@ class TestNFTableUtils(base.TestCase): mock_isfile.return_value = True mock_open.return_value = 'fake-fd' + test_rule_1 = {consts.CIDR: None, + consts.PROTOCOL: lib_consts.PROTOCOL_TCP, + consts.PORT: 1234} + test_rule_2 = {consts.CIDR: '192.0.2.0/24', + consts.PROTOCOL: consts.VRRP, + consts.PORT: 4321} + mocked_open = mock.mock_open() with mock.patch.object(os, 'fdopen', mocked_open): nftable_utils.write_nftable_vip_rules_file( - 'fake-eth2', ['test rule 1', 'test rule 2']) + 'fake-eth2', [test_rule_1, test_rule_2]) mocked_open.assert_called_once_with('fake-fd', 'w') mock_open.assert_called_once_with( @@ -60,15 +72,23 @@ class TestNFTableUtils(base.TestCase): handle = mocked_open() handle.write.assert_has_calls([ - mock.call(f'flush chain {consts.NFT_FAMILY} ' - f'{consts.NFT_VIP_TABLE} {consts.NFT_VIP_CHAIN}\n'), + mock.call(f'table {consts.NFT_FAMILY} {consts.NFT_VIP_TABLE} ' + '{}\n'), + mock.call(f'delete table {consts.NFT_FAMILY} ' + f'{consts.NFT_VIP_TABLE}\n'), mock.call(f'table {consts.NFT_FAMILY} {consts.NFT_VIP_TABLE} ' '{\n'), mock.call(f' chain {consts.NFT_VIP_CHAIN} {{\n'), mock.call(' type filter hook ingress device fake-eth2 ' f'priority {consts.NFT_SRIOV_PRIORITY}; policy drop;\n'), - mock.call(' test rule 1\n'), - mock.call(' test rule 2\n'), + mock.call(' icmp type destination-unreachable accept\n'), + mock.call(' icmpv6 type { nd-neighbor-solicit, ' + 'nd-router-advert, nd-neighbor-advert, packet-too-big, ' + 'destination-unreachable } accept\n'), + mock.call(' udp sport 67 udp dport 68 accept\n'), + mock.call(' udp sport 547 udp dport 546 accept\n'), + mock.call(' tcp dport 1234 accept\n'), + mock.call(' ip saddr 192.0.2.0/24 ip protocol 112 accept\n'), mock.call(' }\n'), mock.call('}\n') ]) @@ -101,6 +121,74 @@ class TestNFTableUtils(base.TestCase): mock.call(f' chain {consts.NFT_VIP_CHAIN} {{\n'), mock.call(' type filter hook ingress device fake-eth2 ' f'priority {consts.NFT_SRIOV_PRIORITY}; policy drop;\n'), + mock.call(' icmp type destination-unreachable accept\n'), + mock.call(' icmpv6 type { nd-neighbor-solicit, ' + 'nd-router-advert, nd-neighbor-advert, packet-too-big, ' + 'destination-unreachable } accept\n'), + mock.call(' udp sport 67 udp dport 68 accept\n'), + mock.call(' udp sport 547 udp dport 546 accept\n'), mock.call(' }\n'), mock.call('}\n') ]) + + @mock.patch('octavia.common.utils.ip_version') + def test__build_rule_cmd(self, mock_ip_version): + + mock_ip_version.side_effect = [4, 6, 99] + + cmd = nftable_utils._build_rule_cmd({ + consts.CIDR: '192.0.2.0/24', + consts.PROTOCOL: lib_consts.PROTOCOL_SCTP, + consts.PORT: 1234}) + self.assertEqual('ip saddr 192.0.2.0/24 sctp dport 1234 accept', cmd) + + cmd = nftable_utils._build_rule_cmd({ + consts.CIDR: '2001:db8::/32', + consts.PROTOCOL: lib_consts.PROTOCOL_TCP, + consts.PORT: 1235}) + self.assertEqual('ip6 saddr 2001:db8::/32 tcp dport 1235 accept', cmd) + + self.assertRaises(exc.HTTPBadRequest, nftable_utils._build_rule_cmd, + {consts.CIDR: '192/32', + consts.PROTOCOL: lib_consts.PROTOCOL_TCP, + consts.PORT: 1237}) + + cmd = nftable_utils._build_rule_cmd({ + consts.CIDR: None, + consts.PROTOCOL: lib_consts.PROTOCOL_UDP, + consts.PORT: 1236}) + self.assertEqual('udp dport 1236 accept', cmd) + + cmd = nftable_utils._build_rule_cmd({ + consts.CIDR: None, + consts.PROTOCOL: consts.VRRP, + consts.PORT: 1237}) + self.assertEqual('ip protocol 112 accept', cmd) + + self.assertRaises(exc.HTTPBadRequest, nftable_utils._build_rule_cmd, + {consts.CIDR: None, + consts.PROTOCOL: 'bad-protocol', + consts.PORT: 1237}) + + @mock.patch('octavia.amphorae.backends.utils.network_namespace.' + 'NetworkNamespace') + @mock.patch('subprocess.check_output') + def test_load_nftables_file(self, mock_check_output, mock_netns): + + mock_netns.side_effect = [ + mock.DEFAULT, + subprocess.CalledProcessError(cmd=consts.NFT_CMD, returncode=-1), + exceptions.AmphoraNetworkConfigException] + + nftable_utils.load_nftables_file() + + mock_netns.assert_called_once_with(consts.AMPHORA_NAMESPACE) + mock_check_output.assert_called_once_with([ + consts.NFT_CMD, '-o', '-f', consts.NFT_VIP_RULES_FILE], + stderr=subprocess.STDOUT) + + self.assertRaises(subprocess.CalledProcessError, + nftable_utils.load_nftables_file) + + self.assertRaises(exceptions.AmphoraNetworkConfigException, + nftable_utils.load_nftables_file) diff --git a/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver.py b/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver.py index 6cdd1360c0..1498c4685b 100644 --- a/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver.py +++ b/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver.py @@ -13,7 +13,7 @@ # under the License. from unittest import mock -from octavia.amphorae.driver_exceptions.exceptions import AmpVersionUnsupported +from octavia.amphorae.driver_exceptions import exceptions as driver_except from octavia.amphorae.drivers.haproxy import exceptions as exc from octavia.amphorae.drivers.haproxy import rest_api_driver import octavia.tests.unit.base as base @@ -87,6 +87,28 @@ class TestHAProxyAmphoraDriver(base.TestCase): mock_amp = mock.MagicMock() mock_amp.api_version = "0.5" - self.assertRaises(AmpVersionUnsupported, + self.assertRaises(driver_except.AmpVersionUnsupported, self.driver._populate_amphora_api_version, mock_amp) + + @mock.patch('octavia.amphorae.drivers.haproxy.rest_api_driver.' + 'HaproxyAmphoraLoadBalancerDriver.' + '_populate_amphora_api_version') + def test_set_interface_rules(self, mock_api_version): + + IP_ADDRESS = '203.0.113.44' + amphora_mock = mock.MagicMock() + amphora_mock.api_version = '0' + client_mock = mock.MagicMock() + client_mock.set_interface_rules.side_effect = [mock.DEFAULT, + exc.NotFound] + self.driver.clients['0'] = client_mock + + self.driver.set_interface_rules(amphora_mock, IP_ADDRESS, 'fake_rules') + mock_api_version.assert_called_once_with(amphora_mock) + client_mock.set_interface_rules.assert_called_once_with( + amphora_mock, IP_ADDRESS, 'fake_rules') + + self.assertRaises(driver_except.AmpDriverNotImplementedError, + self.driver.set_interface_rules, amphora_mock, + IP_ADDRESS, 'fake_rules') diff --git a/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver_1_0.py b/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver_1_0.py index 6613b6b837..13ab0f4bb8 100644 --- a/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver_1_0.py +++ b/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver_1_0.py @@ -1548,3 +1548,13 @@ class TestAmphoraAPIClientTest(base.TestCase): self.assertRaises(exc.InternalServerError, self.driver.update_agent_config, self.amp, "some_file") + + @requests_mock.mock() + def test_set_interface_rules(self, m): + ip_addr = '192.0.2.44' + rules = ('[{"protocol":"TCP","cidr":"192.0.2.0/24","port":8080},' + '{"protocol":"UDP","cidr":null,"port":80}]') + m.put(f'{self.base_url_ver}/interface/{ip_addr}/rules') + + self.driver.set_interface_rules(self.amp, ip_addr, rules) + self.assertTrue(m.called) diff --git a/octavia/tests/unit/common/test_utils.py b/octavia/tests/unit/common/test_utils.py index 45a9da243e..247ab336ad 100644 --- a/octavia/tests/unit/common/test_utils.py +++ b/octavia/tests/unit/common/test_utils.py @@ -13,6 +13,7 @@ # under the License. from unittest import mock +from octavia_lib.common import constants as lib_consts from oslo_utils import uuidutils from octavia.common import constants @@ -139,3 +140,41 @@ class TestConfig(base.TestCase): expected_sg_name = constants.VIP_SECURITY_GROUP_PREFIX + FAKE_LB_ID self.assertEqual(expected_sg_name, utils.get_vip_security_group_name(FAKE_LB_ID)) + + def test_map_protocol_to_nftable_protocol(self): + result = utils.map_protocol_to_nftable_protocol( + {constants.PROTOCOL: lib_consts.PROTOCOL_TCP}) + self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_TCP}, result) + + result = utils.map_protocol_to_nftable_protocol( + {constants.PROTOCOL: lib_consts.PROTOCOL_HTTP}) + self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_TCP}, result) + + result = utils.map_protocol_to_nftable_protocol( + {constants.PROTOCOL: lib_consts.PROTOCOL_HTTPS}) + self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_TCP}, result) + + result = utils.map_protocol_to_nftable_protocol( + {constants.PROTOCOL: lib_consts.PROTOCOL_TERMINATED_HTTPS}) + self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_TCP}, result) + + result = utils.map_protocol_to_nftable_protocol( + {constants.PROTOCOL: lib_consts.PROTOCOL_PROXY}) + self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_TCP}, result) + + result = utils.map_protocol_to_nftable_protocol( + {constants.PROTOCOL: lib_consts.PROTOCOL_PROXYV2}) + self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_TCP}, result) + + result = utils.map_protocol_to_nftable_protocol( + {constants.PROTOCOL: lib_consts.PROTOCOL_UDP}) + self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_UDP}, result) + + result = utils.map_protocol_to_nftable_protocol( + {constants.PROTOCOL: lib_consts.PROTOCOL_SCTP}) + self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_SCTP}, + result) + + result = utils.map_protocol_to_nftable_protocol( + {constants.PROTOCOL: lib_consts.PROTOCOL_PROMETHEUS}) + self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_TCP}, result) diff --git a/octavia/tests/unit/controller/worker/v2/flows/test_listener_flows.py b/octavia/tests/unit/controller/worker/v2/flows/test_listener_flows.py index 45ae52f378..68b74c3737 100644 --- a/octavia/tests/unit/controller/worker/v2/flows/test_listener_flows.py +++ b/octavia/tests/unit/controller/worker/v2/flows/test_listener_flows.py @@ -34,19 +34,31 @@ class TestListenerFlows(base.TestCase): def test_get_create_listener_flow(self, mock_get_net_driver): - listener_flow = self.ListenerFlow.get_create_listener_flow() + flavor_dict = { + constants.SRIOV_VIP: True, + constants.LOADBALANCER_TOPOLOGY: constants.TOPOLOGY_SINGLE} + listener_flow = self.ListenerFlow.get_create_listener_flow( + flavor_dict=flavor_dict) self.assertIsInstance(listener_flow, flow.Flow) self.assertIn(constants.LISTENERS, listener_flow.requires) self.assertIn(constants.LOADBALANCER_ID, listener_flow.requires) + self.assertIn(constants.AMPHORAE_NETWORK_CONFIG, + listener_flow.provides) + self.assertIn(constants.AMPHORAE, listener_flow.provides) + self.assertIn(constants.AMPHORA_FIREWALL_RULES, listener_flow.provides) + self.assertEqual(2, len(listener_flow.requires)) - self.assertEqual(0, len(listener_flow.provides)) + self.assertEqual(3, len(listener_flow.provides)) def test_get_delete_listener_flow(self, mock_get_net_driver): - - listener_flow = self.ListenerFlow.get_delete_listener_flow() + flavor_dict = { + constants.SRIOV_VIP: True, + constants.LOADBALANCER_TOPOLOGY: constants.TOPOLOGY_SINGLE} + listener_flow = self.ListenerFlow.get_delete_listener_flow( + flavor_dict=flavor_dict) self.assertIsInstance(listener_flow, flow.Flow) @@ -54,25 +66,42 @@ class TestListenerFlows(base.TestCase): self.assertIn(constants.LOADBALANCER_ID, listener_flow.requires) self.assertIn(constants.PROJECT_ID, listener_flow.requires) + self.assertIn(constants.AMPHORAE_NETWORK_CONFIG, + listener_flow.provides) + self.assertIn(constants.AMPHORAE, listener_flow.provides) + self.assertIn(constants.AMPHORA_FIREWALL_RULES, listener_flow.provides) + self.assertEqual(3, len(listener_flow.requires)) - self.assertEqual(0, len(listener_flow.provides)) + self.assertEqual(3, len(listener_flow.provides)) def test_get_delete_listener_internal_flow(self, mock_get_net_driver): + flavor_dict = { + constants.SRIOV_VIP: True, + constants.LOADBALANCER_TOPOLOGY: constants.TOPOLOGY_SINGLE} fake_listener = {constants.LISTENER_ID: uuidutils.generate_uuid()} listener_flow = self.ListenerFlow.get_delete_listener_internal_flow( - fake_listener) + fake_listener, flavor_dict=flavor_dict) self.assertIsInstance(listener_flow, flow.Flow) self.assertIn(constants.LOADBALANCER_ID, listener_flow.requires) self.assertIn(constants.PROJECT_ID, listener_flow.requires) + self.assertIn(constants.AMPHORAE_NETWORK_CONFIG, + listener_flow.provides) + self.assertIn(constants.AMPHORAE, listener_flow.provides) + self.assertIn(constants.AMPHORA_FIREWALL_RULES, listener_flow.provides) + self.assertEqual(2, len(listener_flow.requires)) - self.assertEqual(0, len(listener_flow.provides)) + self.assertEqual(3, len(listener_flow.provides)) def test_get_update_listener_flow(self, mock_get_net_driver): + flavor_dict = { + constants.SRIOV_VIP: True, + constants.LOADBALANCER_TOPOLOGY: constants.TOPOLOGY_SINGLE} - listener_flow = self.ListenerFlow.get_update_listener_flow() + listener_flow = self.ListenerFlow.get_update_listener_flow( + flavor_dict=flavor_dict) self.assertIsInstance(listener_flow, flow.Flow) @@ -81,14 +110,28 @@ class TestListenerFlows(base.TestCase): self.assertIn(constants.LISTENERS, listener_flow.requires) self.assertIn(constants.LOADBALANCER_ID, listener_flow.requires) + self.assertIn(constants.AMPHORAE_NETWORK_CONFIG, + listener_flow.provides) + self.assertIn(constants.AMPHORAE, listener_flow.provides) + self.assertIn(constants.AMPHORA_FIREWALL_RULES, listener_flow.provides) + self.assertEqual(4, len(listener_flow.requires)) - self.assertEqual(0, len(listener_flow.provides)) + self.assertEqual(3, len(listener_flow.provides)) def test_get_create_all_listeners_flow(self, mock_get_net_driver): - listeners_flow = self.ListenerFlow.get_create_all_listeners_flow() + flavor_dict = { + constants.SRIOV_VIP: True, + constants.LOADBALANCER_TOPOLOGY: constants.TOPOLOGY_ACTIVE_STANDBY} + listeners_flow = self.ListenerFlow.get_create_all_listeners_flow( + flavor_dict=flavor_dict) self.assertIsInstance(listeners_flow, flow.Flow) self.assertIn(constants.LOADBALANCER, listeners_flow.requires) self.assertIn(constants.LOADBALANCER_ID, listeners_flow.requires) self.assertIn(constants.LOADBALANCER, listeners_flow.provides) + self.assertIn(constants.AMPHORAE_NETWORK_CONFIG, + listeners_flow.provides) + self.assertIn(constants.AMPHORAE, listeners_flow.provides) + self.assertIn(constants.AMPHORA_FIREWALL_RULES, + listeners_flow.provides) self.assertEqual(2, len(listeners_flow.requires)) - self.assertEqual(2, len(listeners_flow.provides)) + self.assertEqual(5, len(listeners_flow.provides)) diff --git a/octavia/tests/unit/controller/worker/v2/flows/test_load_balancer_flows.py b/octavia/tests/unit/controller/worker/v2/flows/test_load_balancer_flows.py index 4b4c603174..7cd5678ee0 100644 --- a/octavia/tests/unit/controller/worker/v2/flows/test_load_balancer_flows.py +++ b/octavia/tests/unit/controller/worker/v2/flows/test_load_balancer_flows.py @@ -359,10 +359,13 @@ class TestLoadBalancerFlows(base.TestCase): self.assertIn(constants.VIP, failover_flow.provides) self.assertIn(constants.ADDITIONAL_VIPS, failover_flow.provides) self.assertIn(constants.VIP_SG_ID, failover_flow.provides) + self.assertIn(constants.AMPHORA_FIREWALL_RULES, failover_flow.provides) + self.assertIn(constants.SUBNET, failover_flow.provides) + self.assertIn(constants.NEW_AMPHORAE, failover_flow.provides) self.assertEqual(6, len(failover_flow.requires), failover_flow.requires) - self.assertEqual(14, len(failover_flow.provides), + self.assertEqual(16, len(failover_flow.provides), failover_flow.provides) @mock.patch('octavia.common.rpc.NOTIFIER', @@ -435,10 +438,14 @@ class TestLoadBalancerFlows(base.TestCase): self.assertIn(constants.VIP, failover_flow.provides) self.assertIn(constants.ADDITIONAL_VIPS, failover_flow.provides) self.assertIn(constants.VIP_SG_ID, failover_flow.provides) + self.assertIn(constants.SUBNET, failover_flow.provides) + self.assertIn(constants.AMPHORA_FIREWALL_RULES, failover_flow.provides) + self.assertIn(constants.SUBNET, failover_flow.provides) + self.assertIn(constants.NEW_AMPHORAE, failover_flow.provides) self.assertEqual(6, len(failover_flow.requires), failover_flow.requires) - self.assertEqual(14, len(failover_flow.provides), + self.assertEqual(16, len(failover_flow.provides), failover_flow.provides) @mock.patch('octavia.common.rpc.NOTIFIER', diff --git a/octavia/tests/unit/controller/worker/v2/tasks/test_amphora_driver_tasks.py b/octavia/tests/unit/controller/worker/v2/tasks/test_amphora_driver_tasks.py index 21c7e1a941..147ed643ac 100644 --- a/octavia/tests/unit/controller/worker/v2/tasks/test_amphora_driver_tasks.py +++ b/octavia/tests/unit/controller/worker/v2/tasks/test_amphora_driver_tasks.py @@ -1246,3 +1246,31 @@ class TestAmphoraDriverTasks(base.TestCase): ret[amphora1_mock[constants.ID]][constants.UNREACHABLE]) self.assertTrue( ret[amphora2_mock[constants.ID]][constants.UNREACHABLE]) + + def test_set_amphora_firewall_rules(self, + mock_driver, + mock_generate_uuid, + mock_log, + mock_get_session, + mock_listener_repo_get, + mock_listener_repo_update, + mock_amphora_repo_get, + mock_amphora_repo_update): + amphora = {constants.ID: AMP_ID, constants.VRRP_IP: '192.0.2.88'} + mock_amphora_repo_get.return_value = _db_amphora_mock + + set_amp_fw_rules = amphora_driver_tasks.SetAmphoraFirewallRules() + + # Test non-SRIOV VIP path + set_amp_fw_rules.execute([amphora], 0, [{'non-sriov-vip': True}]) + + mock_get_session.assert_not_called() + mock_driver.set_interface_rules.assert_not_called() + + # Test SRIOV VIP path + set_amp_fw_rules.execute([amphora], 0, [{'fake_rule': True}]) + + mock_amphora_repo_get.assert_called_once_with(_session_mock, id=AMP_ID) + + mock_driver.set_interface_rules.assert_called_once_with( + _db_amphora_mock, '192.0.2.88', [{'fake_rule': True}]) diff --git a/octavia/tests/unit/controller/worker/v2/tasks/test_database_tasks.py b/octavia/tests/unit/controller/worker/v2/tasks/test_database_tasks.py index 48000554e4..8b471b8bcb 100644 --- a/octavia/tests/unit/controller/worker/v2/tasks/test_database_tasks.py +++ b/octavia/tests/unit/controller/worker/v2/tasks/test_database_tasks.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. # +import copy import random from unittest import mock @@ -24,6 +25,7 @@ from taskflow.types import failure from octavia.api.drivers import utils as provider_utils from octavia.common import constants from octavia.common import data_models +from octavia.common import exceptions from octavia.common import utils from octavia.controller.worker.v2.tasks import database_tasks from octavia.db import repositories as repo @@ -31,6 +33,7 @@ import octavia.tests.unit.base as base AMP_ID = uuidutils.generate_uuid() +AMP2_ID = uuidutils.generate_uuid() COMPUTE_ID = uuidutils.generate_uuid() LB_ID = uuidutils.generate_uuid() SERVER_GROUP_ID = uuidutils.generate_uuid() @@ -2987,3 +2990,100 @@ class TestDatabaseTasks(base.TestCase): mock_session, POOL_ID, operating_status=constants.ONLINE) + + @mock.patch('octavia.common.utils.ip_version') + @mock.patch('octavia.db.api.get_session') + @mock.patch('octavia.db.repositories.ListenerRepository.' + 'get_port_protocol_cidr_for_lb') + def test_get_amphora_firewall_rules(self, + mock_get_port_for_lb, + mock_db_get_session, + mock_ip_version, + mock_generate_uuid, + mock_LOG, + mock_get_session, + mock_loadbalancer_repo_update, + mock_listener_repo_update, + mock_amphora_repo_update, + mock_amphora_repo_delete): + + amphora_dict = {constants.ID: AMP_ID} + rules = [{'protocol': 'TCP', 'cidr': '192.0.2.0/24', 'port': 80}, + {'protocol': 'TCP', 'cidr': '198.51.100.0/24', 'port': 80}] + vrrp_rules = [ + {'protocol': 'TCP', 'cidr': '192.0.2.0/24', 'port': 80}, + {'protocol': 'TCP', 'cidr': '198.51.100.0/24', 'port': 80}, + {'cidr': '203.0.113.5/32', 'port': 112, 'protocol': 'vrrp'}] + mock_get_port_for_lb.side_effect = [ + copy.deepcopy(rules), copy.deepcopy(rules), copy.deepcopy(rules), + copy.deepcopy(rules)] + mock_ip_version.side_effect = [4, 6, 55] + + get_amp_fw_rules = database_tasks.GetAmphoraFirewallRules() + + # Test non-SRIOV VIP + amphora_net_cfg_dict = { + AMP_ID: {constants.AMPHORA: { + 'load_balancer': {constants.VIP: { + constants.VNIC_TYPE: constants.VNIC_TYPE_NORMAL}}}}} + result = get_amp_fw_rules.execute([amphora_dict], 0, + amphora_net_cfg_dict) + self.assertEqual([{'non-sriov-vip': True}], result) + + # Test SRIOV VIP - Single + amphora_net_cfg_dict = { + AMP_ID: {constants.AMPHORA: { + 'load_balancer': {constants.VIP: { + constants.VNIC_TYPE: constants.VNIC_TYPE_DIRECT}, + constants.TOPOLOGY: constants.TOPOLOGY_SINGLE}, + constants.LOAD_BALANCER_ID: LB_ID}}} + result = get_amp_fw_rules.execute([amphora_dict], 0, + amphora_net_cfg_dict) + mock_get_port_for_lb.assert_called_once_with(mock_db_get_session(), + LB_ID) + self.assertEqual(rules, result) + + mock_get_port_for_lb.reset_mock() + + # Test SRIOV VIP - Active/Standby + amphora_net_cfg_dict = { + AMP_ID: {constants.AMPHORA: { + 'load_balancer': {constants.VIP: { + constants.VNIC_TYPE: constants.VNIC_TYPE_DIRECT}, + constants.TOPOLOGY: constants.TOPOLOGY_ACTIVE_STANDBY, + constants.AMPHORAE: [{ + constants.ID: AMP_ID, + constants.STATUS: constants.AMPHORA_ALLOCATED}, + {constants.ID: AMP2_ID, + constants.STATUS: constants.AMPHORA_ALLOCATED, + constants.VRRP_IP: '203.0.113.5'}]}, + constants.LOAD_BALANCER_ID: LB_ID}}} + + # IPv4 path + mock_get_port_for_lb.reset_mock() + vrrp_rules = [ + {'protocol': 'TCP', 'cidr': '192.0.2.0/24', 'port': 80}, + {'protocol': 'TCP', 'cidr': '198.51.100.0/24', 'port': 80}, + {'cidr': '203.0.113.5/32', 'port': 112, 'protocol': 'vrrp'}] + result = get_amp_fw_rules.execute([amphora_dict], 0, + amphora_net_cfg_dict) + mock_get_port_for_lb.assert_called_once_with(mock_db_get_session(), + LB_ID) + self.assertEqual(vrrp_rules, result) + + # IPv6 path + mock_get_port_for_lb.reset_mock() + vrrp_rules = [ + {'protocol': 'TCP', 'cidr': '192.0.2.0/24', 'port': 80}, + {'protocol': 'TCP', 'cidr': '198.51.100.0/24', 'port': 80}, + {'cidr': '203.0.113.5/128', 'port': 112, 'protocol': 'vrrp'}] + result = get_amp_fw_rules.execute([amphora_dict], 0, + amphora_net_cfg_dict) + mock_get_port_for_lb.assert_called_once_with(mock_db_get_session(), + LB_ID) + self.assertEqual(vrrp_rules, result) + + # Bogus IP version path + self.assertRaises(exceptions.InvalidIPAddress, + get_amp_fw_rules.execute, [amphora_dict], 0, + amphora_net_cfg_dict) diff --git a/octavia/tests/unit/controller/worker/v2/tasks/test_shim_tasks.py b/octavia/tests/unit/controller/worker/v2/tasks/test_shim_tasks.py new file mode 100644 index 0000000000..58c446f83d --- /dev/null +++ b/octavia/tests/unit/controller/worker/v2/tasks/test_shim_tasks.py @@ -0,0 +1,33 @@ +# Copyright 2024 Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +from octavia.common import constants +from octavia.controller.worker.v2.tasks import shim_tasks +import octavia.tests.unit.base as base + + +class TestShimTasks(base.TestCase): + + def test_amphora_to_amphorae_with_vrrp_ip(self): + + amp_to_amps = shim_tasks.AmphoraToAmphoraeWithVRRPIP() + + base_port = {constants.FIXED_IPS: + [{constants.IP_ADDRESS: '192.0.2.43'}]} + amphora = {constants.ID: '123456'} + expected_amphora = [{constants.ID: '123456', + constants.VRRP_IP: '192.0.2.43'}] + + self.assertEqual(expected_amphora, + amp_to_amps.execute(amphora, base_port)) diff --git a/octavia/tests/unit/controller/worker/v2/test_controller_worker.py b/octavia/tests/unit/controller/worker/v2/test_controller_worker.py index 345c98cbfb..0b6c6d83b1 100644 --- a/octavia/tests/unit/controller/worker/v2/test_controller_worker.py +++ b/octavia/tests/unit/controller/worker/v2/test_controller_worker.py @@ -70,7 +70,7 @@ _db_load_balancer_mock = mock.MagicMock() _load_balancer_mock = { constants.LOADBALANCER_ID: LB_ID, constants.TOPOLOGY: constants.TOPOLOGY_SINGLE, - constants.FLAVOR_ID: None, + constants.FLAVOR_ID: 1, constants.AVAILABILITY_ZONE: None, constants.SERVER_GROUP_ID: None } @@ -133,7 +133,7 @@ class TestControllerWorker(base.TestCase): _db_load_balancer_mock.amphorae = _db_amphora_mock _db_load_balancer_mock.vip = _vip_mock _db_load_balancer_mock.id = LB_ID - _db_load_balancer_mock.flavor_id = None + _db_load_balancer_mock.flavor_id = 1 _db_load_balancer_mock.availability_zone = None _db_load_balancer_mock.server_group_id = None _db_load_balancer_mock.project_id = PROJECT_ID @@ -331,7 +331,10 @@ class TestControllerWorker(base.TestCase): cw.update_health_monitor(_health_mon_mock, HEALTH_UPDATE_DICT) + @mock.patch('octavia.db.repositories.FlavorRepository.' + 'get_flavor_metadata_dict', return_value={}) def test_create_listener(self, + mock_get_flavor_dict, mock_api_get_session, mock_dyn_log_listener, mock_taskflow_load, @@ -355,42 +358,19 @@ class TestControllerWorker(base.TestCase): provider_lb = provider_utils.db_loadbalancer_to_provider_loadbalancer( _db_load_balancer_mock).to_dict(recurse=True) + flavor_dict = {constants.LOADBALANCER_TOPOLOGY: + constants.TOPOLOGY_SINGLE} (cw.services_controller.run_poster. assert_called_once_with( - flow_utils.get_create_listener_flow, store={ - constants.LOADBALANCER: provider_lb, - constants.LOADBALANCER_ID: LB_ID, - constants.LISTENERS: [listener_dict]})) + flow_utils.get_create_listener_flow, flavor_dict=flavor_dict, + store={constants.LOADBALANCER: provider_lb, + constants.LOADBALANCER_ID: LB_ID, + constants.LISTENERS: [listener_dict]})) + @mock.patch('octavia.db.repositories.FlavorRepository.' + 'get_flavor_metadata_dict', return_value={}) def test_delete_listener(self, - mock_api_get_session, - mock_dyn_log_listener, - mock_taskflow_load, - mock_pool_repo_get, - mock_member_repo_get, - mock_l7rule_repo_get, - mock_l7policy_repo_get, - mock_listener_repo_get, - mock_lb_repo_get, - mock_health_mon_repo_get, - mock_amp_repo_get): - - _flow_mock.reset_mock() - - listener_dict = {constants.LISTENER_ID: LISTENER_ID, - constants.LOADBALANCER_ID: LB_ID, - constants.PROJECT_ID: PROJECT_ID} - cw = controller_worker.ControllerWorker() - cw.delete_listener(listener_dict) - - (cw.services_controller.run_poster. - assert_called_once_with( - flow_utils.get_delete_listener_flow, - store={constants.LISTENER: self.ref_listener_dict, - constants.LOADBALANCER_ID: LB_ID, - constants.PROJECT_ID: PROJECT_ID})) - - def test_update_listener(self, + mock_get_flavor_dict, mock_api_get_session, mock_dyn_log_listener, mock_taskflow_load, @@ -406,6 +386,48 @@ class TestControllerWorker(base.TestCase): load_balancer_mock = mock.MagicMock() load_balancer_mock.provisioning_status = constants.PENDING_UPDATE load_balancer_mock.id = LB_ID + load_balancer_mock.flavor_id = 1 + load_balancer_mock.topology = constants.TOPOLOGY_SINGLE + mock_lb_repo_get.return_value = load_balancer_mock + + _flow_mock.reset_mock() + + listener_dict = {constants.LISTENER_ID: LISTENER_ID, + constants.LOADBALANCER_ID: LB_ID, + constants.PROJECT_ID: PROJECT_ID} + cw = controller_worker.ControllerWorker() + cw.delete_listener(listener_dict) + + flavor_dict = {constants.LOADBALANCER_TOPOLOGY: + constants.TOPOLOGY_SINGLE} + (cw.services_controller.run_poster. + assert_called_once_with( + flow_utils.get_delete_listener_flow, flavor_dict=flavor_dict, + store={constants.LISTENER: self.ref_listener_dict, + constants.LOADBALANCER_ID: LB_ID, + constants.PROJECT_ID: PROJECT_ID})) + + @mock.patch('octavia.db.repositories.FlavorRepository.' + 'get_flavor_metadata_dict', return_value={}) + def test_update_listener(self, + mock_get_flavor_dict, + mock_api_get_session, + mock_dyn_log_listener, + mock_taskflow_load, + mock_pool_repo_get, + mock_member_repo_get, + mock_l7rule_repo_get, + mock_l7policy_repo_get, + mock_listener_repo_get, + mock_lb_repo_get, + mock_health_mon_repo_get, + mock_amp_repo_get): + + load_balancer_mock = mock.MagicMock() + load_balancer_mock.provisioning_status = constants.PENDING_UPDATE + load_balancer_mock.id = LB_ID + load_balancer_mock.flavor_id = None + load_balancer_mock.topology = constants.TOPOLOGY_SINGLE mock_lb_repo_get.return_value = load_balancer_mock _flow_mock.reset_mock() @@ -416,8 +438,11 @@ class TestControllerWorker(base.TestCase): cw = controller_worker.ControllerWorker() cw.update_listener(listener_dict, LISTENER_UPDATE_DICT) + flavor_dict = {constants.LOADBALANCER_TOPOLOGY: + constants.TOPOLOGY_SINGLE} (cw.services_controller.run_poster. assert_called_once_with(flow_utils.get_update_listener_flow, + flavor_dict=flavor_dict, store={constants.LISTENER: listener_dict, constants.UPDATE_DICT: LISTENER_UPDATE_DICT, @@ -425,10 +450,13 @@ class TestControllerWorker(base.TestCase): constants.LISTENERS: [listener_dict]})) + @mock.patch('octavia.db.repositories.FlavorRepository.' + 'get_flavor_metadata_dict', return_value={}) @mock.patch("octavia.controller.worker.v2.controller_worker." "ControllerWorker._get_db_obj_until_pending_update") def test_update_listener_timeout(self, mock__get_db_obj_until_pending, + mock_get_flavor_dict, mock_api_get_session, mock_dyn_log_listener, mock_taskflow_load, @@ -443,6 +471,7 @@ class TestControllerWorker(base.TestCase): load_balancer_mock = mock.MagicMock() load_balancer_mock.provisioning_status = constants.PENDING_UPDATE load_balancer_mock.id = LB_ID + load_balancer_mock.flavor_id = 1 _flow_mock.reset_mock() _listener_mock.provisioning_status = constants.PENDING_UPDATE last_attempt_mock = mock.MagicMock() @@ -2095,10 +2124,13 @@ class TestControllerWorker(base.TestCase): cw._get_amphorae_for_failover, load_balancer_mock) + @mock.patch('octavia.db.repositories.FlavorRepository.' + 'get_flavor_metadata_dict') @mock.patch('octavia.controller.worker.v2.controller_worker.' 'ControllerWorker._get_amphorae_for_failover') def test_failover_loadbalancer_single(self, mock_get_amps_for_failover, + mock_get_flavor_dict, mock_api_get_session, mock_dyn_log_listener, mock_taskflow_load, @@ -2113,6 +2145,7 @@ class TestControllerWorker(base.TestCase): _flow_mock.reset_mock() mock_lb_repo_get.return_value = _db_load_balancer_mock mock_get_amps_for_failover.return_value = [_amphora_mock] + mock_get_flavor_dict.return_value = {} provider_lb = provider_utils.db_loadbalancer_to_provider_loadbalancer( _db_load_balancer_mock).to_dict() diff --git a/releasenotes/notes/Add-support-for-SR-IOV-VIPs-862858ec61e9955b.yaml b/releasenotes/notes/Add-support-for-SR-IOV-VIPs-862858ec61e9955b.yaml new file mode 100644 index 0000000000..7a0c78c889 --- /dev/null +++ b/releasenotes/notes/Add-support-for-SR-IOV-VIPs-862858ec61e9955b.yaml @@ -0,0 +1,9 @@ +--- +features: + - | + Octavia Amphora based load balancers now support using SR-IOV virtual + functions (VF) on the VIP port(s) of the load balancer. This is enabled + by using an Octavia Flavor that includes the 'sriov_vip': True setting. +upgrade: + - | + You must update the amphora image to support the SR-IOV VIP feature.