diff --git a/cloudinit/sources/base.py b/cloudinit/sources/base.py index 8453f5c4..65f58538 100644 --- a/cloudinit/sources/base.py +++ b/cloudinit/sources/base.py @@ -140,6 +140,10 @@ class BaseDataSource(object): This should return an instance of :class:`APIResponse`. """ + @abc.abstractmethod + def version(self): + """Get the version of the current data source.""" + def instance_id(self): """Get this instance's id.""" diff --git a/cloudinit/sources/openstack/base.py b/cloudinit/sources/openstack/base.py index 6b44f5b7..01f59a38 100644 --- a/cloudinit/sources/openstack/base.py +++ b/cloudinit/sources/openstack/base.py @@ -50,6 +50,10 @@ class BaseOpenStackSource(base.BaseDataSource): def _path_join(self, path, *addons): """Join one or more components together.""" + def version(self): + """Get the underlying data source version.""" + return self._version + def _working_version(self): versions = self._available_versions() # OS_VERSIONS is stored in chronological order, so diff --git a/cloudinit/sources/strategy.py b/cloudinit/sources/strategy.py index 32a9b66c..547f8b61 100644 --- a/cloudinit/sources/strategy.py +++ b/cloudinit/sources/strategy.py @@ -77,3 +77,22 @@ class SerialSearchStrategy(BaseSearchStrategy): for data_source in data_sources: if self.is_datasource_available(data_source): yield data_source + + +class FilterVersionStrategy(BaseSearchStrategy): + """A strategy for filtering data sources by their version + + :param versions: + A list of strings, where each strings is a possible + version that a data source can have. + """ + + def __init__(self, versions=None): + if versions is None: + versions = [] + self._versions = versions + super(FilterVersionStrategy, self).__init__() + + def search_data_sources(self, data_sources): + return (source for source in data_sources + if source.version() in self._versions) diff --git a/cloudinit/tests/sources/test_strategy.py b/cloudinit/tests/sources/test_strategy.py index fa1b0314..9d84a0bc 100644 --- a/cloudinit/tests/sources/test_strategy.py +++ b/cloudinit/tests/sources/test_strategy.py @@ -63,3 +63,38 @@ class TestStrategy(tests.TestCase): valid_sources = list(instance.search_data_sources(sources)) self.assertEqual(available_sources, valid_sources) + + def test_filter_version_strategy(self): + class SourceV1(object): + def version(self): + return 'first' + + class SourceV2(SourceV1): + def version(self): + return 'second' + + class SourceV3(object): + def version(self): + return 'third' + + sources = [SourceV1(), SourceV2(), SourceV3()] + instance = strategy.FilterVersionStrategy(['third', 'first']) + + filtered_sources = sorted( + source.version() + for source in instance.search_data_sources(sources)) + + self.assertEqual(len(filtered_sources), 2) + self.assertEqual(filtered_sources, ['first', 'third']) + + def test_filter_version_strategy_no_versions_given(self): + class SourceV1(object): + def version(self): + return 'first' + + sources = [SourceV1()] + instance = strategy.FilterVersionStrategy() + + filtered_sources = list(instance.search_data_sources(sources)) + + self.assertEqual(len(filtered_sources), 0)