From 4e16c2f6f78c87b98ce0b25c13b7d5b63b0e7320 Mon Sep 17 00:00:00 2001 From: Dmitry Tantsur Date: Tue, 19 Jun 2018 10:28:31 +0200 Subject: [PATCH] Separate NIC validation Change-Id: Icfc15cb4780d4682ed20aa2e5796578e81496c87 --- metalsmith/_provisioner.py | 15 +++------------ metalsmith/_utils.py | 21 +++++++++++++++++++++ metalsmith/test/test_provisioner.py | 2 +- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/metalsmith/_provisioner.py b/metalsmith/_provisioner.py index db93bc9..d6b3398 100644 --- a/metalsmith/_provisioner.py +++ b/metalsmith/_provisioner.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import logging import random import sys @@ -258,15 +257,10 @@ class Provisioner(object): def _get_nics(self, nics): """Validate and get the NICs.""" + _utils.validate_nics(nics) + result = [] - if not isinstance(nics, collections.Sequence): - raise TypeError("NICs must be a list of dicts") - for nic in nics: - if not isinstance(nic, collections.Mapping) or len(nic) != 1: - raise TypeError("Each NIC must be a dict with one item, " - "got %s" % nic) - nic_type, nic_id = next(iter(nic.items())) if nic_type == 'network': try: @@ -277,7 +271,7 @@ class Provisioner(object): {'net': nic_id, 'error': exc}) else: result.append((nic_type, network)) - elif nic_type == 'port': + else: try: port = self._api.get_port(nic_id) except Exception as exc: @@ -286,9 +280,6 @@ class Provisioner(object): {'port': nic_id, 'error': exc}) else: result.append((nic_type, port)) - else: - raise ValueError("Unexpected NIC type %s, supported values: " - "'port', 'network'" % nic_type) return result diff --git a/metalsmith/_utils.py b/metalsmith/_utils.py index 6dde641..d485ffa 100644 --- a/metalsmith/_utils.py +++ b/metalsmith/_utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections import contextlib import json import os @@ -121,3 +122,23 @@ def is_hostname_safe(hostname): return False return _HOSTNAME_RE.match(hostname) is not None + + +def validate_nics(nics): + """Validate NICs.""" + if not isinstance(nics, collections.Sequence): + raise TypeError("NICs must be a list of dicts") + + unknown_nic_types = set() + for nic in nics: + if not isinstance(nic, collections.Mapping) or len(nic) != 1: + raise TypeError("Each NIC must be a dict with one item, " + "got %s" % nic) + + nic_type = next(iter(nic)) + if nic_type not in ('port', 'network'): + unknown_nic_types.add(nic_type) + + if unknown_nic_types: + raise ValueError("Unexpected NIC type(s) %s, supported values are " + "'port' and 'network'" % ', '.join(unknown_nic_types)) diff --git a/metalsmith/test/test_provisioner.py b/metalsmith/test/test_provisioner.py index 96c9311..11da438 100644 --- a/metalsmith/test/test_provisioner.py +++ b/metalsmith/test/test_provisioner.py @@ -488,7 +488,7 @@ class TestProvisionNode(Base): self.api.release_node.assert_called_with(self.node) def test_invalid_nic_type(self): - self.assertRaisesRegex(ValueError, 'Unexpected NIC type foo', + self.assertRaisesRegex(ValueError, r'Unexpected NIC type\(s\) foo', self.pr.provision_node, self.node, 'image', [{'foo': 'bar'}]) self.assertFalse(self.api.create_port.called)