Add typing information to octavia.amphorae package

... and fix issues reported by mypy.

Apart from the benefit of finding possible bugs using mypy,
type annotations also have other benefits like for instance:

- IDEs like PyCharm and VS Code can use annotations to provide
  better
  code completion, highlight errors and to make coding more efficient
- It serves as documentation of the code and helps to understand
  code
  better. The annotation syntax is very easy to read.
- Sphinx generated module documentation can use this information
  (without the need for additional type info in comments/docstrings)

The syntax used here should be supported by Python 3.6, so it should
not create issues when backporting patches in the future.

Partial-Bug: #2017974
Change-Id: I5b0f084785f477c855218c54a21caecdecefb62f
This commit is contained in:
Tom Weininger 2024-01-10 09:30:49 +01:00
parent bc259c0bf0
commit 233554a8ee
13 changed files with 63 additions and 34 deletions

View File

@ -133,14 +133,14 @@ class AmphoraInfo:
return extend_data
def _get_meminfo(self):
re_parser = re.compile(r'^(?P<key>\S*):\s*(?P<value>\d*)\s*kB')
re_parser = re.compile(r'^(\S*):\s*(\d*)\s*kB')
result = {}
with open('/proc/meminfo', encoding='utf-8') as meminfo:
for line in meminfo:
match = re_parser.match(line)
if not match:
continue # skip lines that don't parse
key, value = match.groups(['key', 'value'])
key, value = match.groups()
result[key] = int(value)
return result

View File

@ -33,8 +33,12 @@ def get_haproxy_versions():
version_re = re.search(r'.*version (.+?)\.(.+?)(\.|-dev).*',
version.decode('utf-8'))
major_version = int(version_re.group(1))
minor_version = int(version_re.group(2))
if version_re:
major_version = int(version_re.group(1))
minor_version = int(version_re.group(2))
else:
LOG.warning("Could not read haproxy version")
major_version, minor_version = 0, 0
return major_version, minor_version

View File

@ -14,6 +14,7 @@
import abc
import typing as tp
from oslo_config import cfg
@ -25,7 +26,7 @@ class LvsListenerApiServerBase(metaclass=abc.ABCMeta):
"""
_SUBSCRIBED_AMP_COMPILE = []
_SUBSCRIBED_AMP_COMPILE: tp.List[str] = []
def get_subscribed_amp_compile_info(self):
return self._SUBSCRIBED_AMP_COMPILE

View File

@ -297,12 +297,12 @@ def parse_haproxy_file(lb_id):
with open(config_path(lb_id), encoding='utf-8') as file:
cfg = file.read()
listeners = {}
listeners: tp.Dict[str, dict] = {}
m = FRONTEND_BACKEND_PATTERN.split(cfg)
sections = FRONTEND_BACKEND_PATTERN.split(cfg)
last_token = None
last_id = None
for section in m:
for section in sections:
if last_token is None:
# We aren't in a section yet, see if this line starts one
if section == 'frontend':

View File

@ -65,7 +65,8 @@ def get_counters_file():
os.open(stats_file_path, flags, mode), 'r+')
except OSError:
LOG.info("Failed to open `%s`, ignoring...", stats_file_path)
COUNTERS_FILE.seek(0)
if COUNTERS_FILE:
COUNTERS_FILE.seek(0)
return COUNTERS_FILE

View File

@ -14,6 +14,7 @@
import csv
import socket
import typing as tp
from oslo_log import log as logging
@ -119,7 +120,7 @@ class HAProxyQuery:
results = self.show_stat(object_type=6) # servers + pool
final_results = {}
final_results: tp.Dict[str, dict] = {}
for line in results:
# pxname: pool, svname: server_name, status: status

View File

@ -14,6 +14,7 @@ import ipaddress
import os
import re
import subprocess
import typing as tp
from octavia_lib.common import constants as lib_consts
from oslo_log import log as logging
@ -45,17 +46,14 @@ def read_kernel_file(ns_name, file_path):
cmd = ("ip netns exec {ns} cat {lvs_stat_path}".format(
ns=ns_name, lvs_stat_path=file_path))
try:
output = subprocess.check_output(cmd.split(),
stderr=subprocess.STDOUT)
output = subprocess.check_output(
cmd.split(), stderr=subprocess.STDOUT, text=True)
except subprocess.CalledProcessError as e:
LOG.error("Failed to get kernel lvs status in ns %(ns_name)s "
"%(kernel_lvs_path)s: %(err)s %(out)s",
{'ns_name': ns_name, 'kernel_lvs_path': file_path,
'err': e, 'out': e.output})
raise e
# py3 treat the output as bytes type.
if isinstance(output, bytes):
output = output.decode('utf-8')
return output
@ -75,7 +73,7 @@ def get_listener_realserver_mapping(ns_name, listener_ip_ports,
ip_obj = ipaddress.ip_address(listener_ip.strip('[]'))
output = read_kernel_file(ns_name, KERNEL_LVS_PATH).split('\n')
if ip_obj.version == 4:
ip_to_hex_format = "%.8X" % ip_obj._ip
ip_to_hex_format = "%.8X" % ip_obj._ip # type: ignore
else:
ip_to_hex_format = r'\[' + ip_obj.exploded + r'\]'
port_hex_format = "%.4X" % int(listener_port)
@ -145,7 +143,7 @@ def get_lvs_listener_resource_ipports_nsname(listener_id):
# {'id': member-id-2,
# 'ipport': ipport}],
# 'HealthMonitor': {'id': healthmonitor-id}}
resource_ipport_mapping = {}
resource_ipport_mapping: tp.Dict[str, tp.Any] = {}
with open(util.keepalived_lvs_cfg_path(listener_id),
encoding='utf-8') as f:
cfg = f.read()
@ -190,7 +188,7 @@ def get_lvs_listener_resource_ipports_nsname(listener_id):
resource_type_ids = CONFIG_COMMENT_REGEX.findall(cfg)
for resource_type, resource_id in resource_type_ids:
value = {'id': resource_id}
value: tp.Union[dict, list] = {'id': resource_id}
if resource_type == 'Member':
resource_type = '%ss' % resource_type
if resource_type not in resource_ipport_mapping:
@ -320,10 +318,9 @@ def get_ipvsadm_info(ns_name, is_stats_cmd=False):
# use --exact to ensure output is integer only
if is_stats_cmd:
cmd_list += ['--stats', '--exact']
output = subprocess.check_output(cmd_list, stderr=subprocess.STDOUT)
if isinstance(output, bytes):
output = output.decode('utf-8')
output = output.split('\n')
output_str = subprocess.check_output(
cmd_list, stderr=subprocess.STDOUT, text=True)
output = output_str.split('\n')
fields = []
# mapping = {'listeneripport': {'Linstener': vs_values,
# 'members': [rs_values1, rs_values2]}}

View File

@ -459,8 +459,8 @@ class HaproxyAmphoraLoadBalancerDriver(
if certs:
# Build and upload the crt-list file for haproxy
crt_list = "\n".join(cert_filename_list)
crt_list = f'{crt_list}\n'.encode()
crt_list_b = "\n".join(cert_filename_list)
crt_list = f'{crt_list_b}\n'.encode()
md5sum = md5(crt_list,
usedforsecurity=False).hexdigest() # nosec
name = f'{listener.id}.pem'
@ -619,6 +619,10 @@ class HaproxyAmphoraLoadBalancerDriver(
# Check a custom hostname
class CustomHostNameCheckingAdapter(requests.adapters.HTTPAdapter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.uuid: str = ""
def cert_verify(self, conn, url, verify, cert):
conn.assert_hostname = self.uuid
return super().cert_verify(conn, url, verify, cert)

View File

@ -17,6 +17,7 @@ import datetime
import socket
import time
import timeit
import typing as tp
from oslo_config import cfg
from oslo_log import log as logging
@ -91,6 +92,8 @@ class UDPStatusGetter:
:return: Returns the unwrapped payload and addr that sent the
heartbeat.
"""
if self.sock is None:
raise exceptions.NetworkConfig("unable to find suitable socket")
(data, srcaddr) = self.sock.recvfrom(UDP_MAX_SIZE)
LOG.debug('Received packet from %s', srcaddr)
try:
@ -432,8 +435,8 @@ class UpdateHealthDb:
if not db_lb:
return
processed_pools = []
potential_offline_pools = {}
processed_pools: tp.List[dict] = []
potential_offline_pools: tp.Dict[str, str] = {}
# We got a heartbeat so lb is healthy until proven otherwise
if db_lb[constants.ENABLED] is False:
@ -553,7 +556,7 @@ class UpdateHealthDb:
def _process_pool_status(
self, session, pool_id, db_pool_dict, pools, lb_status,
processed_pools, potential_offline_pools):
processed_pools: list, potential_offline_pools: dict):
pool_status = None
if pool_id not in pools:

View File

@ -12,6 +12,8 @@
# License for the specific language governing permissions and limitations
# under the License.
import abc
from oslo_log import log as logging
from octavia.amphorae.drivers import driver_base
@ -21,12 +23,20 @@ from octavia.common import constants
LOG = logging.getLogger(__name__)
class KeepalivedAmphoraDriverMixin(driver_base.VRRPDriverMixin):
class KeepalivedAmphoraDriverMixin(driver_base.VRRPDriverMixin,
metaclass=abc.ABCMeta):
def __init__(self):
super().__init__()
# The Mixed class must define a self.client object for the
# AmphoraApiClient
self.clients: dict
@abc.abstractmethod
def _populate_amphora_api_version(self, amphora,
timeout_dict=None,
raise_retry_exception=False):
"""Populate the amphora object with the api_version"""
def update_vrrp_conf(self, loadbalancer, amphorae_network_config, amphora,
timeout_dict=None):

View File

@ -18,6 +18,7 @@ import os
import signal
import sys
import time
import typing as tp
from oslo_config import cfg
from oslo_log import log as logging
@ -31,7 +32,7 @@ from octavia import version
CONF = cfg.CONF
LOG = logging.getLogger(__name__)
PROVIDER_AGENT_PROCESSES = []
PROVIDER_AGENT_PROCESSES: tp.List[multiprocessing.Process] = []
def _mutate_config(*args, **kwargs):
@ -143,9 +144,7 @@ def main():
try:
proc.join(CONF.driver_agent.provider_agent_shutdown_timeout)
if proc.exitcode is None:
# TODO(johnsom) Change to proc.kill() once
# python 3.7 or newer only
os.kill(proc.pid, signal.SIGKILL)
proc.kill()
LOG.warning(
'Forcefully killed "%s" provider agent because it '
'failed to shutdown in %s seconds.', proc.name,

View File

@ -25,10 +25,19 @@ import octavia.tests.unit.base as base
API_VERSION = '1.0'
class KeepalivedAmpDriverMixinImpl(
vrrp_rest_driver.KeepalivedAmphoraDriverMixin):
"""The base class is abstract"""
def _populate_amphora_api_version(self, amphora,
timeout_dict=None,
raise_retry_exception=False):
pass
class TestVRRPRestDriver(base.TestCase):
def setUp(self):
self.keepalived_mixin = vrrp_rest_driver.KeepalivedAmphoraDriverMixin()
self.keepalived_mixin = KeepalivedAmpDriverMixinImpl()
self.keepalived_mixin.clients = {
'base': mock.MagicMock(),
API_VERSION: mock.MagicMock()}

View File

@ -207,7 +207,7 @@ class TestDriverAgentCMD(base.TestCase):
mock_exit_event.set.assert_called_once()
mock_provider_proc.join.assert_called_once_with(
CONF.driver_agent.provider_agent_shutdown_timeout)
mock_kill.assert_called_once_with('not-valid-pid', signal.SIGKILL)
mock_provider_proc.kill.assert_called_once()
# Test keyboard interrupt with provider agents join exception
mock_exit_event.reset_mock()