From d6ef934b708b1731a248e05f2aaf56024d5a826a Mon Sep 17 00:00:00 2001 From: David Shrewsbury Date: Thu, 13 Dec 2018 15:14:20 -0500 Subject: [PATCH] Extract common config parsing for ProviderConfig Adds a ProviderConfig class method that can be called to get the config schema for the common config options in a Provider. Drivers are modified to call this method. Change-Id: Ib67256dddc06d13eb7683226edaa8c8c10a73326 --- nodepool/cmd/config_validator.py | 7 ++----- nodepool/driver/__init__.py | 9 +++++++++ nodepool/driver/kubernetes/config.py | 5 ++++- nodepool/driver/openstack/config.py | 4 +++- nodepool/driver/static/config.py | 4 +++- nodepool/driver/test/config.py | 4 +++- 6 files changed, 24 insertions(+), 9 deletions(-) diff --git a/nodepool/cmd/config_validator.py b/nodepool/cmd/config_validator.py index c06b64d55..c3dde57d7 100644 --- a/nodepool/cmd/config_validator.py +++ b/nodepool/cmd/config_validator.py @@ -14,6 +14,7 @@ import logging import voluptuous as v import yaml +from nodepool.driver import ProviderConfig from nodepool.config import get_provider_config log = logging.getLogger(__name__) @@ -26,11 +27,7 @@ class ConfigValidator: self.config_file = config_file def validate(self): - provider = { - 'name': v.Required(str), - 'driver': str, - 'max-concurrency': int, - } + provider = ProviderConfig.getCommonSchemaDict() label = { 'name': str, diff --git a/nodepool/driver/__init__.py b/nodepool/driver/__init__.py index dde69c168..1b958a402 100644 --- a/nodepool/driver/__init__.py +++ b/nodepool/driver/__init__.py @@ -22,6 +22,7 @@ import importlib import logging import math import os +import voluptuous as v from nodepool import zk from nodepool import exceptions @@ -910,6 +911,14 @@ class ProviderConfig(ConfigValue, metaclass=abc.ABCMeta): def __repr__(self): return "" % self.name + @classmethod + def getCommonSchemaDict(self): + return { + v.Required('name'): str, + 'driver': str, + 'max-concurrency': int + } + @property @abc.abstractmethod def pools(self): diff --git a/nodepool/driver/kubernetes/config.py b/nodepool/driver/kubernetes/config.py index db303a850..d58ffb032 100644 --- a/nodepool/driver/kubernetes/config.py +++ b/nodepool/driver/kubernetes/config.py @@ -109,7 +109,10 @@ class KubernetesProviderConfig(ProviderConfig): v.Required('context'): str, 'launch-retries': int, } - return v.Schema(provider) + + schema = ProviderConfig.getCommonSchemaDict() + schema.update(provider) + return v.Schema(schema) def getSupportedLabels(self, pool_name=None): labels = set() diff --git a/nodepool/driver/openstack/config.py b/nodepool/driver/openstack/config.py index 065ac9fcc..2f259978d 100644 --- a/nodepool/driver/openstack/config.py +++ b/nodepool/driver/openstack/config.py @@ -385,7 +385,8 @@ class OpenStackProviderConfig(ProviderConfig): 'security-groups': [str] }) - return v.Schema({ + schema = ProviderConfig.getCommonSchemaDict() + schema.update({ 'region-name': str, v.Required('cloud'): str, 'boot-timeout': int, @@ -400,6 +401,7 @@ class OpenStackProviderConfig(ProviderConfig): 'diskimages': [provider_diskimage], 'cloud-images': [provider_cloud_images], }) + return v.Schema(schema) def getSupportedLabels(self, pool_name=None): labels = set() diff --git a/nodepool/driver/static/config.py b/nodepool/driver/static/config.py index cef6a0994..b35b96539 100644 --- a/nodepool/driver/static/config.py +++ b/nodepool/driver/static/config.py @@ -112,7 +112,9 @@ class StaticProviderConfig(ProviderConfig): 'name': str, 'nodes': [pool_node], }) - return v.Schema({'pools': [pool]}) + schema = ProviderConfig.getCommonSchemaDict() + schema.update({'pools': [pool]}) + return v.Schema(schema) def getSupportedLabels(self, pool_name=None): labels = set() diff --git a/nodepool/driver/test/config.py b/nodepool/driver/test/config.py index a16480485..b90f268fe 100644 --- a/nodepool/driver/test/config.py +++ b/nodepool/driver/test/config.py @@ -58,7 +58,9 @@ class TestConfig(ProviderConfig): 'name': str, 'labels': [str] }) - return v.Schema({'pools': [pool]}) + schema = ProviderConfig.getCommonSchemaDict() + schema.update({'pools': [pool]}) + return v.Schema(schema) def getSupportedLabels(self, pool_name=None): return self.labels