Separate NIC validation
Change-Id: Icfc15cb4780d4682ed20aa2e5796578e81496c87
This commit is contained in:
parent
4ce8342ea5
commit
4e16c2f6f7
|
@ -13,7 +13,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import random
|
||||
import sys
|
||||
|
@ -258,15 +257,10 @@ class Provisioner(object):
|
|||
|
||||
def _get_nics(self, nics):
|
||||
"""Validate and get the NICs."""
|
||||
_utils.validate_nics(nics)
|
||||
|
||||
result = []
|
||||
if not isinstance(nics, collections.Sequence):
|
||||
raise TypeError("NICs must be a list of dicts")
|
||||
|
||||
for nic in nics:
|
||||
if not isinstance(nic, collections.Mapping) or len(nic) != 1:
|
||||
raise TypeError("Each NIC must be a dict with one item, "
|
||||
"got %s" % nic)
|
||||
|
||||
nic_type, nic_id = next(iter(nic.items()))
|
||||
if nic_type == 'network':
|
||||
try:
|
||||
|
@ -277,7 +271,7 @@ class Provisioner(object):
|
|||
{'net': nic_id, 'error': exc})
|
||||
else:
|
||||
result.append((nic_type, network))
|
||||
elif nic_type == 'port':
|
||||
else:
|
||||
try:
|
||||
port = self._api.get_port(nic_id)
|
||||
except Exception as exc:
|
||||
|
@ -286,9 +280,6 @@ class Provisioner(object):
|
|||
{'port': nic_id, 'error': exc})
|
||||
else:
|
||||
result.append((nic_type, port))
|
||||
else:
|
||||
raise ValueError("Unexpected NIC type %s, supported values: "
|
||||
"'port', 'network'" % nic_type)
|
||||
|
||||
return result
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
|
@ -121,3 +122,23 @@ def is_hostname_safe(hostname):
|
|||
return False
|
||||
|
||||
return _HOSTNAME_RE.match(hostname) is not None
|
||||
|
||||
|
||||
def validate_nics(nics):
|
||||
"""Validate NICs."""
|
||||
if not isinstance(nics, collections.Sequence):
|
||||
raise TypeError("NICs must be a list of dicts")
|
||||
|
||||
unknown_nic_types = set()
|
||||
for nic in nics:
|
||||
if not isinstance(nic, collections.Mapping) or len(nic) != 1:
|
||||
raise TypeError("Each NIC must be a dict with one item, "
|
||||
"got %s" % nic)
|
||||
|
||||
nic_type = next(iter(nic))
|
||||
if nic_type not in ('port', 'network'):
|
||||
unknown_nic_types.add(nic_type)
|
||||
|
||||
if unknown_nic_types:
|
||||
raise ValueError("Unexpected NIC type(s) %s, supported values are "
|
||||
"'port' and 'network'" % ', '.join(unknown_nic_types))
|
||||
|
|
|
@ -488,7 +488,7 @@ class TestProvisionNode(Base):
|
|||
self.api.release_node.assert_called_with(self.node)
|
||||
|
||||
def test_invalid_nic_type(self):
|
||||
self.assertRaisesRegex(ValueError, 'Unexpected NIC type foo',
|
||||
self.assertRaisesRegex(ValueError, r'Unexpected NIC type\(s\) foo',
|
||||
self.pr.provision_node,
|
||||
self.node, 'image', [{'foo': 'bar'}])
|
||||
self.assertFalse(self.api.create_port.called)
|
||||
|
|
Loading…
Reference in New Issue