diff --git a/nodepool/driver/__init__.py b/nodepool/driver/__init__.py index 216bfe8b5..dde69c168 100644 --- a/nodepool/driver/__init__.py +++ b/nodepool/driver/__init__.py @@ -824,7 +824,11 @@ class ConfigValue(object, metaclass=abc.ABCMeta): return not self.__eq__(other) -class ConfigPool(ConfigValue): +class ConfigPool(ConfigValue, metaclass=abc.ABCMeta): + ''' + Base class for a single pool as defined in the configuration file. + ''' + def __init__(self): self.labels = {} self.max_servers = math.inf @@ -837,6 +841,40 @@ class ConfigPool(ConfigValue): self.node_attributes == other.node_attributes) return False + @classmethod + def getCommonSchemaDict(self): + ''' + Return the schema dict for common pool attributes. + + When a driver validates its own configuration schema, it should call + this class method to get and include the common pool attributes in + the schema. + + The `labels` attribute, though common, can vary its type across + drivers so it is not returned in the schema. + ''' + return { + 'max-servers': int, + 'node-attributes': dict, + } + + @abc.abstractmethod + def load(self, pool_config): + ''' + Load pool config options from the parsed configuration file. + + Subclasses are expected to call the parent method so that common + configuration values are loaded properly. + + Although `labels` is a common attribute, each driver may + define it differently, so we cannot parse that attribute here. + + :param dict pool_config: A single pool config section from which we + will load the values. + ''' + self.max_servers = pool_config.get('max-servers', math.inf) + self.node_attributes = pool_config.get('node-attributes') + class DriverConfig(ConfigValue): def __init__(self): diff --git a/nodepool/driver/kubernetes/config.py b/nodepool/driver/kubernetes/config.py index a3b60a97d..db303a850 100644 --- a/nodepool/driver/kubernetes/config.py +++ b/nodepool/driver/kubernetes/config.py @@ -45,6 +45,20 @@ class KubernetesPool(ConfigPool): def __repr__(self): return "" % self.name + def load(self, pool_config, full_config): + super().load(pool_config) + self.name = pool_config['name'] + self.labels = {} + for label in pool_config.get('labels', []): + pl = KubernetesLabel() + pl.name = label['name'] + pl.type = label['type'] + pl.image = label.get('image') + pl.image_pull = label.get('image-pull', 'IfNotPresent') + pl.pool = self + self.labels[pl.name] = pl + full_config.labels[label['name']].pools.append(self) + class KubernetesProviderConfig(ProviderConfig): def __init__(self, driver, provider): @@ -72,19 +86,9 @@ class KubernetesProviderConfig(ProviderConfig): self.context = self.provider['context'] for pool in self.provider.get('pools', []): pp = KubernetesPool() - pp.name = pool['name'] + pp.load(pool, config) pp.provider = self self.pools[pp.name] = pp - pp.labels = {} - for label in pool.get('labels', []): - pl = KubernetesLabel() - pl.name = label['name'] - pl.type = label['type'] - pl.image = label.get('image') - pl.image_pull = label.get('image-pull', 'IfNotPresent') - pl.pool = pp - pp.labels[pl.name] = pl - config.labels[label['name']].pools.append(pp) def getSchema(self): k8s_label = { @@ -94,10 +98,11 @@ class KubernetesProviderConfig(ProviderConfig): 'image-pull': str, } - pool = { + pool = ConfigPool.getCommonSchemaDict() + pool.update({ v.Required('name'): str, v.Required('labels'): [k8s_label], - } + }) provider = { v.Required('pools'): [pool], diff --git a/nodepool/driver/openstack/config.py b/nodepool/driver/openstack/config.py index ea680c181..065ac9fcc 100644 --- a/nodepool/driver/openstack/config.py +++ b/nodepool/driver/openstack/config.py @@ -149,6 +149,64 @@ class ProviderPool(ConfigPool): def __repr__(self): return "" % self.name + def load(self, pool_config, full_config, provider): + ''' + Load pool configuration options. + + :param dict pool_config: A single pool config section from which we + will load the values. + :param dict full_config: The full nodepool config. + :param OpenStackProviderConfig: The calling provider object. + ''' + super().load(pool_config) + + self.provider = provider + self.name = pool_config['name'] + self.max_cores = pool_config.get('max-cores', math.inf) + self.max_ram = pool_config.get('max-ram', math.inf) + self.ignore_provider_quota = pool_config.get('ignore-provider-quota', + False) + self.azs = pool_config.get('availability-zones') + self.networks = pool_config.get('networks', []) + self.security_groups = pool_config.get('security-groups', []) + self.auto_floating_ip = bool(pool_config.get('auto-floating-ip', True)) + self.host_key_checking = bool(pool_config.get('host-key-checking', + True)) + + for label in pool_config.get('labels', []): + pl = ProviderLabel() + pl.name = label['name'] + pl.pool = self + self.labels[pl.name] = pl + diskimage = label.get('diskimage', None) + if diskimage: + pl.diskimage = full_config.diskimages[diskimage] + else: + pl.diskimage = None + cloud_image_name = label.get('cloud-image', None) + if cloud_image_name: + cloud_image = provider.cloud_images.get(cloud_image_name, None) + if not cloud_image: + raise ValueError( + "cloud-image %s does not exist in provider %s" + " but is referenced in label %s" % + (cloud_image_name, self.name, pl.name)) + else: + cloud_image = None + pl.cloud_image = cloud_image + pl.min_ram = label.get('min-ram', 0) + pl.flavor_name = label.get('flavor-name', None) + pl.key_name = label.get('key-name') + pl.console_log = label.get('console-log', False) + pl.boot_from_volume = bool(label.get('boot-from-volume', + False)) + pl.volume_size = label.get('volume-size', 50) + pl.instance_properties = label.get('instance-properties', + None) + + top_label = full_config.labels[pl.name] + top_label.pools.append(self) + class OpenStackProviderConfig(ProviderConfig): def __init__(self, driver, provider): @@ -263,53 +321,8 @@ class OpenStackProviderConfig(ProviderConfig): for pool in self.provider.get('pools', []): pp = ProviderPool() - pp.name = pool['name'] - pp.provider = self + pp.load(pool, config, self) self.pools[pp.name] = pp - pp.max_cores = pool.get('max-cores', math.inf) - pp.max_servers = pool.get('max-servers', math.inf) - pp.max_ram = pool.get('max-ram', math.inf) - pp.ignore_provider_quota = pool.get('ignore-provider-quota', False) - pp.azs = pool.get('availability-zones') - pp.networks = pool.get('networks', []) - pp.security_groups = pool.get('security-groups', []) - pp.auto_floating_ip = bool(pool.get('auto-floating-ip', True)) - pp.host_key_checking = bool(pool.get('host-key-checking', True)) - pp.node_attributes = pool.get('node-attributes') - - for label in pool.get('labels', []): - pl = ProviderLabel() - pl.name = label['name'] - pl.pool = pp - pp.labels[pl.name] = pl - diskimage = label.get('diskimage', None) - if diskimage: - pl.diskimage = config.diskimages[diskimage] - else: - pl.diskimage = None - cloud_image_name = label.get('cloud-image', None) - if cloud_image_name: - cloud_image = self.cloud_images.get(cloud_image_name, None) - if not cloud_image: - raise ValueError( - "cloud-image %s does not exist in provider %s" - " but is referenced in label %s" % - (cloud_image_name, self.name, pl.name)) - else: - cloud_image = None - pl.cloud_image = cloud_image - pl.min_ram = label.get('min-ram', 0) - pl.flavor_name = label.get('flavor-name', None) - pl.key_name = label.get('key-name') - pl.console_log = label.get('console-log', False) - pl.boot_from_volume = bool(label.get('boot-from-volume', - False)) - pl.volume_size = label.get('volume-size', 50) - pl.instance_properties = label.get('instance-properties', - None) - - top_label = config.labels[pl.name] - top_label.pools.append(pp) def getSchema(self): provider_diskimage = { @@ -358,20 +371,19 @@ class OpenStackProviderConfig(ProviderConfig): v.Any(label_min_ram, label_flavor_name), v.Any(label_diskimage, label_cloud_image)) - pool = { + pool = ConfigPool.getCommonSchemaDict() + pool.update({ 'name': str, 'networks': [str], 'auto-floating-ip': bool, 'host-key-checking': bool, 'ignore-provider-quota': bool, 'max-cores': int, - 'max-servers': int, 'max-ram': int, 'labels': [pool_label], - 'node-attributes': dict, 'availability-zones': [str], 'security-groups': [str] - } + }) return v.Schema({ 'region-name': str, diff --git a/nodepool/driver/static/config.py b/nodepool/driver/static/config.py index 48cb74941..cef6a0994 100644 --- a/nodepool/driver/static/config.py +++ b/nodepool/driver/static/config.py @@ -41,6 +41,33 @@ class StaticPool(ConfigPool): def __repr__(self): return "" % self.name + def load(self, pool_config, full_config): + super().load(pool_config) + self.name = pool_config['name'] + # WARNING: This intentionally changes the type! + self.labels = set() + for node in pool_config.get('nodes', []): + self.nodes.append({ + 'name': node['name'], + 'labels': as_list(node['labels']), + 'host-key': as_list(node.get('host-key', [])), + 'timeout': int(node.get('timeout', 5)), + # Read ssh-port values for backward compat, but prefer port + 'connection-port': int( + node.get('connection-port', node.get('ssh-port', 22))), + 'connection-type': node.get('connection-type', 'ssh'), + 'username': node.get('username', 'zuul'), + 'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)), + }) + if isinstance(node['labels'], str): + for label in node['labels'].split(): + self.labels.add(label) + full_config.labels[label].pools.append(self) + elif isinstance(node['labels'], list): + for label in node['labels']: + self.labels.add(label) + full_config.labels[label].pools.append(self) + class StaticProviderConfig(ProviderConfig): def __init__(self, *args, **kwargs): @@ -65,32 +92,9 @@ class StaticProviderConfig(ProviderConfig): def load(self, config): for pool in self.provider.get('pools', []): pp = StaticPool() - pp.name = pool['name'] + pp.load(pool, config) pp.provider = self self.pools[pp.name] = pp - # WARNING: This intentionally changes the type! - pp.labels = set() - for node in pool.get('nodes', []): - pp.nodes.append({ - 'name': node['name'], - 'labels': as_list(node['labels']), - 'host-key': as_list(node.get('host-key', [])), - 'timeout': int(node.get('timeout', 5)), - # Read ssh-port values for backward compat, but prefer port - 'connection-port': int( - node.get('connection-port', node.get('ssh-port', 22))), - 'connection-type': node.get('connection-type', 'ssh'), - 'username': node.get('username', 'zuul'), - 'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)), - }) - if isinstance(node['labels'], str): - for label in node['labels'].split(): - pp.labels.add(label) - config.labels[label].pools.append(pp) - elif isinstance(node['labels'], list): - for label in node['labels']: - pp.labels.add(label) - config.labels[label].pools.append(pp) def getSchema(self): pool_node = { @@ -103,10 +107,11 @@ class StaticProviderConfig(ProviderConfig): 'connection-type': str, 'max-parallel-jobs': int, } - pool = { + pool = ConfigPool.getCommonSchemaDict() + pool.update({ 'name': str, 'nodes': [pool_node], - } + }) return v.Schema({'pools': [pool]}) def getSupportedLabels(self, pool_name=None): diff --git a/nodepool/driver/test/config.py b/nodepool/driver/test/config.py index 067c2934c..a16480485 100644 --- a/nodepool/driver/test/config.py +++ b/nodepool/driver/test/config.py @@ -12,7 +12,6 @@ # License for the specific language governing permissions and limitations # under the License. -import math import voluptuous as v from nodepool.driver import ConfigPool @@ -20,7 +19,10 @@ from nodepool.driver import ProviderConfig class TestPool(ConfigPool): - pass + def load(self, pool_config): + super().load(pool_config) + self.name = pool_config['name'] + self.labels = pool_config['labels'] class TestConfig(ProviderConfig): @@ -43,18 +45,19 @@ class TestConfig(ProviderConfig): self.labels = set() for pool in self.provider.get('pools', []): testpool = TestPool() - testpool.name = pool['name'] + testpool.load(pool) testpool.provider = self - testpool.max_servers = pool.get('max-servers', math.inf) - testpool.labels = pool['labels'] for label in pool['labels']: self.labels.add(label) newconfig.labels[label].pools.append(testpool) self.pools[pool['name']] = testpool def getSchema(self): - pool = {'name': str, - 'labels': [str]} + pool = ConfigPool.getCommonSchemaDict() + pool.update({ + 'name': str, + 'labels': [str] + }) return v.Schema({'pools': [pool]}) def getSupportedLabels(self, pool_name=None): diff --git a/nodepool/tests/unit/test_config_comparisons.py b/nodepool/tests/unit/test_config_comparisons.py index d7724e818..674df51b6 100644 --- a/nodepool/tests/unit/test_config_comparisons.py +++ b/nodepool/tests/unit/test_config_comparisons.py @@ -28,11 +28,17 @@ from nodepool.driver.static.config import StaticPool from nodepool.driver.static.config import StaticProviderConfig +class TempConfigPool(ConfigPool): + def load(self): + pass + + class TestConfigComparisons(tests.BaseTestCase): def test_ConfigPool(self): - a = ConfigPool() - b = ConfigPool() + + a = TempConfigPool() + b = TempConfigPool() self.assertEqual(a, b) a.max_servers = 5 self.assertNotEqual(a, b) @@ -94,9 +100,9 @@ class TestConfigComparisons(tests.BaseTestCase): a.max_servers = 5 self.assertNotEqual(a, b) - c = ConfigPool() + c = TempConfigPool() d = ProviderPool() - self.assertNotEqual(c, d) + self.assertNotEqual(d, c) def test_OpenStackProviderConfig(self): provider = {'name': 'foo'} @@ -114,7 +120,7 @@ class TestConfigComparisons(tests.BaseTestCase): # intentionally change an attribute of the base class a.max_servers = 5 self.assertNotEqual(a, b) - c = ConfigPool() + c = TempConfigPool() self.assertNotEqual(b, c) def test_StaticProviderConfig(self):