diff --git a/cloudinit/plugin_finder.py b/cloudinit/plugin_finder.py new file mode 100644 index 00000000..862da66a --- /dev/null +++ b/cloudinit/plugin_finder.py @@ -0,0 +1,54 @@ +# Copyright 2015 Canonical Ltd. +# This file is part of cloud-init. See LICENCE file for license information. +# +# vi: ts=4 expandtab + +"""Various base classes and implementations for finding *plugins*.""" + +import abc +import pkgutil + +import six + +from cloudinit import logging + + +LOG = logging.getLogger(__name__) + + +@six.add_metaclass(abc.ABCMeta) +class BaseModuleIterator(object): + """Base class for describing a *module iterator* + + A module iterator is a class that's capable of listing + modules or packages from a specific location, which are + already loaded. + """ + + def __init__(self, search_paths): + self._search_paths = search_paths + + @abc.abstractmethod + def list_modules(self): + """List all the modules that this finder knows about.""" + + +class PkgutilModuleIterator(BaseModuleIterator): + """A class based on the *pkgutil* module for discovering modules.""" + + @staticmethod + def _find_module(finder, module): + """Delegate to the *finder* for finding the given module.""" + return finder.find_module(module).load_module(module) + + def list_modules(self): + """List all modules that this class knows about.""" + for finder, name, _ in pkgutil.walk_packages(self._search_paths): + try: + module = self._find_module(finder, name) + except ImportError: + LOG.debug('Could not import the module %r using the ' + 'search path %r', name, finder.path) + continue + + yield module diff --git a/cloudinit/sources/base.py b/cloudinit/sources/base.py index 3de21900..8453f5c4 100644 --- a/cloudinit/sources/base.py +++ b/cloudinit/sources/base.py @@ -4,9 +4,18 @@ # vi: ts=4 expandtab import abc +import itertools import six +from cloudinit import exceptions +from cloudinit import logging +from cloudinit import sources +from cloudinit.sources import strategy + + +LOG = logging.getLogger(__name__) + class APIResponse(object): """Holds API response content @@ -35,6 +44,64 @@ class APIResponse(object): return self.decoded_buffer +class DataSourceLoader(object): + """Class for retrieving an available data source instance + + :param names: + A list of possible data source names, from which the loader + should pick. This can be used to filter the data sources + that can be found from outside of cloudinit control. + + :param module_iterator: + An instance of :class:`cloudinit.plugin_finder.BaseModuleIterator`, + which is used to find possible modules where the data sources + can be found. + + :param strategies: + An iterator of search strategy classes, where each strategy is capable + of filtering the data sources that can be used by cloudinit. + Possible strategies includes serial data source search or + parallel data source or filtering data sources according to + some criteria (only network data sources) + + """ + + def __init__(self, names, module_iterator, strategies): + self._names = names + self._module_iterator = module_iterator + self._strategies = strategies + + @staticmethod + def _implements_source_api(module): + """Check if the given module implements the data source API.""" + return hasattr(module, 'data_sources') + + def _valid_modules(self): + """Return all the modules that are *valid* + + Valid modules are those that implements a particular API + for declaring the data sources it exports. + """ + modules = self._module_iterator.list_modules() + return filter(self._implements_source_api, modules) + + def all_data_sources(self): + """Get all the data source classes that this finder knows about.""" + return itertools.chain.from_iterable( + module.data_sources() + for module in self._valid_modules()) + + def valid_data_sources(self): + """Get the data sources that are valid for this run.""" + data_sources = self.all_data_sources() + # Instantiate them before passing to the strategies. + data_sources = (data_source() for data_source in data_sources) + + for strategy_instance in self._strategies: + data_sources = strategy_instance.search_data_sources(data_sources) + return data_sources + + @six.add_metaclass(abc.ABCMeta) class BaseDataSource(object): """Base class for the data sources.""" @@ -106,3 +173,41 @@ class BaseDataSource(object): def is_password_set(self): """Check if the password was already posted to the metadata service.""" + + +def get_data_source(names, module_iterator, strategies=None): + """Get an instance of any data source available. + + :param names: + A list of possible data source names, from which the loader + should pick. This can be used to filter the data sources + that can be found from outside of cloudinit control. + + :param module_iterator: + A subclass of :class:`cloudinit.plugin_finder.BaseModuleIterator`, + which is used to find possible modules where the data sources + can be found. + + :param strategies: + An iterator of search strategy classes, where each strategy is capable + of filtering the data sources that can be used by cloudinit. + """ + if names: + default_strategies = [strategy.FilterNameStrategy(names)] + else: + default_strategies = [] + if strategies is None: + strategies = [] + + strategy_instances = [strategy_cls() for strategy_cls in strategies] + strategies = default_strategies + strategy_instances + + iterator = module_iterator(sources.__path__) + loader = DataSourceLoader(names, iterator, strategies) + valid_sources = loader.valid_data_sources() + + data_source = next(valid_sources, None) + if not data_source: + raise exceptions.CloudInitError('No available data source found') + + return data_source diff --git a/cloudinit/sources/openstack/httpopenstack.py b/cloudinit/sources/openstack/httpopenstack.py index d2f49acc..618a62df 100644 --- a/cloudinit/sources/openstack/httpopenstack.py +++ b/cloudinit/sources/openstack/httpopenstack.py @@ -125,3 +125,8 @@ class HttpOpenStackSource(baseopenstack.BaseOpenStackSource): return False else: raise + + +def data_sources(): + """Get the data sources exported in this module.""" + return (HttpOpenStackSource,) diff --git a/cloudinit/sources/strategy.py b/cloudinit/sources/strategy.py new file mode 100644 index 00000000..32a9b66c --- /dev/null +++ b/cloudinit/sources/strategy.py @@ -0,0 +1,79 @@ +# Copyright 2015 Canonical Ltd. +# This file is part of cloud-init. See LICENCE file for license information. +# +# vi: ts=4 expandtab + +import abc + +import six + +from cloudinit import logging + + +LOG = logging.getLogger(__name__) + + +@six.add_metaclass(abc.ABCMeta) +class BaseSearchStrategy(object): + """Declare search strategies for data sources + + A *search strategy* represents a decoupled way of choosing + one or more data sources from a list of data sources. + Each strategy can be used interchangeably and they can + be composed. For instance, once can apply a filtering strategy + over a parallel search strategy, which looks for the available + data sources. + """ + + @abc.abstractmethod + def search_data_sources(self, data_sources): + """Search the possible data sources for this strategy + + The method should filter the data sources that can be + considered *valid* for the given strategy. + + :param data_sources: + An iterator of data source instances, where the lookup + will be done. + """ + + @staticmethod + def is_datasource_available(data_source): + """Check if the given *data_source* is considered *available* + + A data source is considered available if it can be loaded, + but other strategies could implement their own behaviour. + """ + try: + if data_source.load(): + return True + except Exception: + LOG.error("Failed to load data source %r", data_source) + return False + + +class FilterNameStrategy(BaseSearchStrategy): + """A strategy for filtering data sources by name + + :param names: + A list of strings, where each string is a name for a possible + data source. Only the data sources that are in this list will + be loaded and filtered. + """ + + def __init__(self, names=None): + self._names = names + super(FilterNameStrategy, self).__init__() + + def search_data_sources(self, data_sources): + return (source for source in data_sources + if source.__class__.__name__ in self._names) + + +class SerialSearchStrategy(BaseSearchStrategy): + """A strategy that chooses a data source in serial.""" + + def search_data_sources(self, data_sources): + for data_source in data_sources: + if self.is_datasource_available(data_source): + yield data_source diff --git a/cloudinit/tests/sources/test_base.py b/cloudinit/tests/sources/test_base.py new file mode 100644 index 00000000..a39aa729 --- /dev/null +++ b/cloudinit/tests/sources/test_base.py @@ -0,0 +1,115 @@ +# Copyright 2015 Canonical Ltd. +# This file is part of cloud-init. See LICENCE file for license information. +# +# vi: ts=4 expandtab + +import functools +import string +import types + +from cloudinit import exceptions +from cloudinit import plugin_finder +from cloudinit.sources import base +from cloudinit.sources import strategy +from cloudinit import tests + + +class TestDataSourceDiscovery(tests.TestCase): + + def setUp(self): + super(TestDataSourceDiscovery, self).setUp() + self._modules = None + + @property + def modules(self): + if self._modules: + return self._modules + + class Module(types.ModuleType): + def data_sources(self): + return (self, ) + + def __call__(self): + return self + + @property + def __class__(self): + return self + + modules = self._modules = list(map(Module, string.ascii_letters)) + return modules + + @property + def module_iterator(self): + modules = self.modules + + class ModuleIterator(plugin_finder.BaseModuleIterator): + def list_modules(self): + return modules + [None, "", 42] + + return ModuleIterator(None) + + def test_loader_api(self): + # Test that the API of DataSourceLoader is sane + loader = base.DataSourceLoader( + names=[], module_iterator=self.module_iterator, + strategies=[]) + + all_data_sources = list(loader.all_data_sources()) + valid_data_sources = list(loader.valid_data_sources()) + + self.assertEqual(all_data_sources, self.modules) + self.assertEqual(valid_data_sources, self.modules) + + def test_loader_strategies(self): + class OrdStrategy(strategy.BaseSearchStrategy): + def search_data_sources(self, data_sources): + return filter(lambda source: ord(source.__name__) < 100, + data_sources) + + class NameStrategy(strategy.BaseSearchStrategy): + def search_data_sources(self, data_sources): + return (source for source in data_sources + if source.__name__ in ('a', 'b', 'c')) + + loader = base.DataSourceLoader( + names=[], module_iterator=self.module_iterator, + strategies=(OrdStrategy(), NameStrategy(), )) + valid_data_sources = list(loader.valid_data_sources()) + + self.assertEqual(len(valid_data_sources), 3) + self.assertEqual([source.__name__ for source in valid_data_sources], + ['a', 'b', 'c']) + + def test_get_data_source_filtered_by_name(self): + source = base.get_data_source( + names=['a', 'c'], + module_iterator=self.module_iterator.__class__) + self.assertEqual(source.__name__, 'a') + + def test_get_data_source_multiple_strategies(self): + class ReversedStrategy(strategy.BaseSearchStrategy): + def search_data_sources(self, data_sources): + return reversed(list(data_sources)) + + source = base.get_data_source( + names=['a', 'b', 'c'], + module_iterator=self.module_iterator.__class__, + strategies=(ReversedStrategy, )) + + self.assertEqual(source.__name__, 'c') + + def test_get_data_source_no_data_source(self): + get_data_source = functools.partial( + base.get_data_source, + names=['totallymissing'], + module_iterator=self.module_iterator.__class__) + + exc = self.assertRaises(exceptions.CloudInitError, + get_data_source) + self.assertEqual(str(exc), 'No available data source found') + + def test_get_data_source_no_name_filtering(self): + source = base.get_data_source( + names=[], module_iterator=self.module_iterator.__class__) + self.assertEqual(source.__name__, 'a') diff --git a/cloudinit/tests/sources/test_strategy.py b/cloudinit/tests/sources/test_strategy.py new file mode 100644 index 00000000..fa1b0314 --- /dev/null +++ b/cloudinit/tests/sources/test_strategy.py @@ -0,0 +1,65 @@ +# Copyright 2015 Canonical Ltd. +# This file is part of cloud-init. See LICENCE file for license information. +# +# vi: ts=4 expandtab + +from cloudinit.sources import strategy +from cloudinit import tests +from cloudinit.tests.util import mock + + +class TestStrategy(tests.TestCase): + + def test_custom_strategy(self): + class CustomStrategy(strategy.BaseSearchStrategy): + + def search_data_sources(self, data_sources): + # Return them in reverse order + return list(reversed(data_sources)) + + data_sources = [mock.sentinel.first, mock.sentinel.second] + instance = CustomStrategy() + sources = instance.search_data_sources(data_sources) + + self.assertEqual(sources, [mock.sentinel.second, mock.sentinel.first]) + + def test_is_datasource_available(self): + class CustomStrategy(strategy.BaseSearchStrategy): + def search_data_sources(self, _): + pass + + instance = CustomStrategy() + good_source = mock.Mock() + good_source.load.return_value = True + bad_source = mock.Mock() + bad_source.load.return_value = False + + self.assertTrue(instance.is_datasource_available(good_source)) + self.assertFalse(instance.is_datasource_available(bad_source)) + + def test_filter_name_strategy(self): + names = ['first', 'second', 'third'] + full_names = names + ['fourth', 'fifth'] + sources = [type(name, (object, ), {})() for name in full_names] + instance = strategy.FilterNameStrategy(names) + + sources = list(instance.search_data_sources(sources)) + + self.assertEqual(len(sources), 3) + self.assertEqual([source.__class__.__name__ for source in sources], + names) + + def test_serial_search_strategy(self): + def is_available(self, data_source): + return data_source in available_sources + + sources = [mock.sentinel.first, mock.sentinel.second, + mock.sentinel.third, mock.sentinel.fourth] + available_sources = [mock.sentinel.second, mock.sentinel.fourth] + + with mock.patch('cloudinit.sources.strategy.BaseSearchStrategy.' + 'is_datasource_available', new=is_available): + instance = strategy.SerialSearchStrategy() + valid_sources = list(instance.search_data_sources(sources)) + + self.assertEqual(available_sources, valid_sources) diff --git a/cloudinit/tests/test_plugin_finder.py b/cloudinit/tests/test_plugin_finder.py new file mode 100644 index 00000000..2cb244d8 --- /dev/null +++ b/cloudinit/tests/test_plugin_finder.py @@ -0,0 +1,55 @@ +# Copyright 2015 Canonical Ltd. +# This file is part of cloud-init. See LICENCE file for license information. +# +# vi: ts=4 expandtab + +import contextlib +import os +import shutil +import tempfile + +from cloudinit import plugin_finder +from cloudinit.tests import TestCase +from cloudinit.tests import util + + +class TestPkgutilModuleIterator(TestCase): + + @staticmethod + @contextlib.contextmanager + def _create_tmpdir(): + tmpdir = tempfile.mkdtemp() + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir) + + @contextlib.contextmanager + def _create_package(self): + with self._create_tmpdir() as tmpdir: + path = os.path.join(tmpdir, 'good.py') + with open(path, 'w') as stream: + stream.write('name = 42') + + # Make sure this fails. + bad = os.path.join(tmpdir, 'bad.py') + with open(bad, 'w') as stream: + stream.write('import missingmodule') + + yield tmpdir + + def test_pkgutil_module_iterator(self): + logging_format = ("Could not import the module 'bad' " + "using the search path %r") + + with util.LogSnatcher('cloudinit.plugin_finder') as snatcher: + with self._create_package() as tmpdir: + expected_logging = logging_format % tmpdir + iterator = plugin_finder.PkgutilModuleIterator([tmpdir]) + modules = list(iterator.list_modules()) + + self.assertEqual(len(modules), 1) + module = modules[0] + self.assertEqual(module.name, 42) + self.assertEqual(len(snatcher.output), 1) + self.assertEqual(snatcher.output[0], expected_logging)