Separate NIC validation

Change-Id: Icfc15cb4780d4682ed20aa2e5796578e81496c87
This commit is contained in:
Dmitry Tantsur 2018-06-19 10:28:31 +02:00
parent 4ce8342ea5
commit 4e16c2f6f7
3 changed files with 25 additions and 13 deletions

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
import random
import sys
@ -258,15 +257,10 @@ class Provisioner(object):
def _get_nics(self, nics):
"""Validate and get the NICs."""
_utils.validate_nics(nics)
result = []
if not isinstance(nics, collections.Sequence):
raise TypeError("NICs must be a list of dicts")
for nic in nics:
if not isinstance(nic, collections.Mapping) or len(nic) != 1:
raise TypeError("Each NIC must be a dict with one item, "
"got %s" % nic)
nic_type, nic_id = next(iter(nic.items()))
if nic_type == 'network':
try:
@ -277,7 +271,7 @@ class Provisioner(object):
{'net': nic_id, 'error': exc})
else:
result.append((nic_type, network))
elif nic_type == 'port':
else:
try:
port = self._api.get_port(nic_id)
except Exception as exc:
@ -286,9 +280,6 @@ class Provisioner(object):
{'port': nic_id, 'error': exc})
else:
result.append((nic_type, port))
else:
raise ValueError("Unexpected NIC type %s, supported values: "
"'port', 'network'" % nic_type)
return result

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import contextlib
import json
import os
@ -121,3 +122,23 @@ def is_hostname_safe(hostname):
return False
return _HOSTNAME_RE.match(hostname) is not None
def validate_nics(nics):
"""Validate NICs."""
if not isinstance(nics, collections.Sequence):
raise TypeError("NICs must be a list of dicts")
unknown_nic_types = set()
for nic in nics:
if not isinstance(nic, collections.Mapping) or len(nic) != 1:
raise TypeError("Each NIC must be a dict with one item, "
"got %s" % nic)
nic_type = next(iter(nic))
if nic_type not in ('port', 'network'):
unknown_nic_types.add(nic_type)
if unknown_nic_types:
raise ValueError("Unexpected NIC type(s) %s, supported values are "
"'port' and 'network'" % ', '.join(unknown_nic_types))

View File

@ -488,7 +488,7 @@ class TestProvisionNode(Base):
self.api.release_node.assert_called_with(self.node)
def test_invalid_nic_type(self):
self.assertRaisesRegex(ValueError, 'Unexpected NIC type foo',
self.assertRaisesRegex(ValueError, r'Unexpected NIC type\(s\) foo',
self.pr.provision_node,
self.node, 'image', [{'foo': 'bar'}])
self.assertFalse(self.api.create_port.called)