diff --git a/packetary/drivers/rpm_driver.py b/packetary/drivers/rpm_driver.py new file mode 100644 index 0000000..a00aef7 --- /dev/null +++ b/packetary/drivers/rpm_driver.py @@ -0,0 +1,289 @@ +# -*- coding: utf-8 -*- + +# Copyright 2015 Mirantis, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import copy +import multiprocessing +import os +import shutil + +import createrepo +import lxml.etree as etree +import six + +from packetary.drivers.base import RepositoryDriverBase +from packetary.library.streams import GzipDecompress +from packetary.library import utils +from packetary.objects import FileChecksum +from packetary.objects import Package +from packetary.objects import PackageRelation +from packetary.objects import PackageVersion +from packetary.objects import Repository + + +urljoin = six.moves.urllib.parse.urljoin + +# TODO(configurable option for drivers) +_CORE_GROUPS = ("core", "base") + +_MANDATORY_TYPES = ("mandatory", "default") + +# The namespaces are used in metadata xml of repository +_NAMESPACES = { + "main": "http://linux.duke.edu/metadata/common", + "md": "http://linux.duke.edu/metadata/repo", + "rpm": "http://linux.duke.edu/metadata/rpm" +} + + +class CreaterepoCallBack(object): + """Callback object for createrepo""" + def __init__(self, logger): + self.logger = logger + + def errorlog(self, msg): + """Error log output.""" + self.logger.error(msg) + + def log(self, msg): + """Logs message.""" + self.logger.info(msg) + + def progress(self, item, current, total): + """"Progress bar.""" + pass + + +class RpmRepositoryDriver(RepositoryDriverBase): + def parse_urls(self, urls): + """Overrides method of superclass.""" + return (url.rstrip("/") for url in urls) + + def get_repository(self, connection, url, arch, consumer): + """Overrides method of superclass.""" + # Currently supported repositories, that has URL in following format: + # baseurl/{name}/{architecture} + # because the architecture is sentetic part of rpm repository URL + name = url.rsplit("/", 1)[-1] + baseurl = "/".join((url, arch, "")) + consumer(Repository( + name=name, + url=baseurl, + architecture=arch, + origin="" + )) + + def get_packages(self, connection, repository, consumer): + """Overrides method of superclass.""" + baseurl = repository.url + repomd = urljoin(baseurl, "repodata/repomd.xml") + self.logger.debug("repomd: %s", repomd) + + repomd_tree = etree.parse(connection.open_stream(repomd)) + mandatory = self._get_mandatory_packages( + self._load_db( + connection, baseurl, repomd_tree, "group_gz", "group" + ) + ) + primary_db = self._load_db(connection, baseurl, repomd_tree, "primary") + if primary_db is None: + raise ValueError("Malformed repository: {0}".format(repository)) + + counter = 0 + for tag in primary_db.iterfind("./main:package", _NAMESPACES): + try: + name = tag.find("./main:name", _NAMESPACES).text + consumer(Package( + repository=repository, + name=tag.find("./main:name", _NAMESPACES).text, + version=self._unparse_version_attrs( + tag.find("./main:version", _NAMESPACES).attrib + ), + filesize=int( + tag.find("./main:size", _NAMESPACES) + .attrib.get("package", -1) + ), + filename=tag.find( + "./main:location", _NAMESPACES + ).attrib["href"], + checksum=self._get_checksum(tag), + mandatory=name in mandatory, + requires=self._get_relations(tag, "requires"), + obsoletes=self._get_relations(tag, "obsoletes"), + provides=self._get_relations(tag, "provides") + )) + except (ValueError, KeyError) as e: + self.logger.error( + "Malformed tag %s - %s: %s", + repository, etree.tostring(tag), six.text_type(e) + ) + raise + counter += 1 + self.logger.info("loaded: %d packages from %s.", counter, repository) + + def rebuild_repository(self, repository, packages): + """Overrides method of superclass.""" + basepath = utils.get_path_from_url(repository.url) + self.logger.info("rebuild repository in %s", basepath) + md_config = createrepo.MetaDataConfig() + try: + md_config.workers = multiprocessing.cpu_count() + md_config.directory = str(basepath) + md_config.update = True + mdgen = createrepo.MetaDataGenerator( + config_obj=md_config, callback=CreaterepoCallBack(self.logger) + ) + mdgen.doPkgMetadata() + mdgen.doRepoMetadata() + mdgen.doFinalMove() + except createrepo.MDError as e: + err_msg = six.text_type(e) + self.logger.exception( + "failed to create yum repository in %s: %s", + basepath, + err_msg + ) + shutil.rmtree( + os.path.join(md_config.outputdir, md_config.tempdir), + ignore_errors=True + ) + raise RuntimeError( + "Failed to create yum repository in {0}." + .format(err_msg)) + + def fork_repository(self, connection, repository, destination, + source=False, locale=False): + """Overrides method of superclass.""" + # TODO(download gpk) + # TODO(sources and locales) + destination = os.path.join( + destination, repository.name, + repository.architecture, "" + ) + new_repo = copy.copy(repository) + new_repo.url = destination + self.logger.info("clone repository %s to %s", repository, destination) + utils.ensure_dir_exist(destination) + self.rebuild_repository(new_repo, set()) + return new_repo + + def _load_db(self, connection, baseurl, repomd, *aliases): + """Loads database. + + :param connection: the connection object + :param baseurl: the base repository URL + :param repomd: the parsed metadata of repository + :param aliases: the aliases of database name + :return: parsed database file or None if db does not exist + """ + + for dbname in aliases: + self.logger.debug("loading %s database...", dbname) + node = repomd.find( + "./md:data[@type='{0}']".format(dbname), _NAMESPACES + ) + if node is not None: + break + else: + return + + url = urljoin( + baseurl, + node.find("./md:location", _NAMESPACES).attrib["href"] + ) + self.logger.debug("loading %s - %s...", dbname, url) + stream = connection.open_stream(url) + if url.endswith(".gz"): + stream = GzipDecompress(stream) + return etree.parse(stream) + + def _get_mandatory_packages(self, groups_db): + """Get the set of mandatory package names. + + :param groups_db: the parsed groups database + """ + package_names = set() + if groups_db is None: + return package_names + count = 0 + for name in _CORE_GROUPS: + result = groups_db.xpath("./group/id[text()='{0}']".format(name)) + if len(result) == 0: + self.logger.warning("the group '%s' is not found.", name) + continue + group = result[0].getparent() + for t in _MANDATORY_TYPES: + xpath = "./packagelist/packagereq[@type='{0}']".format(t) + for tag in group.iterfind(xpath): + package_names.add(tag.text) + count += 1 + self.logger.info("detected %d mandatory packages.", count) + return package_names + + def _get_relations(self, pkg_tag, name): + """Gets package relations by name from package tag. + + :param pkg_tag: the xml-tag with package description + :param name: the relations name + :return: list of PackageRelation objects + """ + relations = list() + append = relations.append + tags_iter = pkg_tag.iterfind( + "./main:format/rpm:%s/rpm:entry" % name, + _NAMESPACES + ) + for elem in tags_iter: + append(PackageRelation.from_args( + self._unparse_relation_attrs(elem.attrib) + )) + + return relations + + def _get_checksum(self, pkg_tag): + """Gets checksum from package tag.""" + checksum = dict.fromkeys(("md5", "sha1", "sha256"), None) + checksum_tag = pkg_tag.find("./main:checksum", _NAMESPACES) + checksum[checksum_tag.attrib["type"]] = checksum_tag.text + return FileChecksum(**checksum) + + def _unparse_relation_attrs(self, attrs): + """Gets the package relation from attributes. + + :param attrs: the relation tag attributes + :return tuple(name, version_op, version_edge) + """ + if "flags" not in attrs: + return attrs['name'], None + + return ( + attrs['name'], + attrs["flags"].lower(), + self._unparse_version_attrs(attrs) + ) + + @staticmethod + def _unparse_version_attrs(attrs): + """Gets the package version from attributes. + + :param attrs: the relation tag attributes + :return: the PackageVersion object + """ + + return PackageVersion( + int(attrs.get("epoch", 0)), + attrs.get("ver", "0.0").split("."), + attrs.get("rel", "0").split(".") + ) diff --git a/packetary/library/utils.py b/packetary/library/utils.py index 440ab97..a3df3c1 100644 --- a/packetary/library/utils.py +++ b/packetary/library/utils.py @@ -21,7 +21,6 @@ import os import six - urlparse = six.moves.urllib.parse.urlparse @@ -72,19 +71,22 @@ def get_size_and_checksum_for_files(files, checksum_algo): yield filename, size, checksum -def get_path_from_url(url): +def get_path_from_url(url, ensure_file=True): """Get the path from the URL. :param url: the URL - :return: the filepath + :param ensure_file: If True, ensure that scheme is "file" + :return: the path component from URL :raises ValueError """ comps = urlparse(url, scheme="file") - if comps.scheme != "file": + if ensure_file and comps.scheme != "file": raise ValueError( "The absolute path is expected, actual have: {0}.".format(url) ) + if os.sep != "/": + return comps.path.replace("/", os.sep) return comps.path diff --git a/packetary/objects/__init__.py b/packetary/objects/__init__.py index 234dc3e..96bf21f 100644 --- a/packetary/objects/__init__.py +++ b/packetary/objects/__init__.py @@ -20,6 +20,7 @@ from packetary.objects.package import FileChecksum from packetary.objects.package import Package from packetary.objects.package_relation import PackageRelation from packetary.objects.package_relation import VersionRange +from packetary.objects.package_version import PackageVersion from packetary.objects.packages_tree import PackagesTree from packetary.objects.repository import Repository @@ -30,6 +31,7 @@ __all__ = [ "Package", "PackageRelation", "PackagesTree", + "PackageVersion", "Repository", "VersionRange", ] diff --git a/packetary/objects/package_relation.py b/packetary/objects/package_relation.py index fe7ba86..31cdfd0 100644 --- a/packetary/objects/package_relation.py +++ b/packetary/objects/package_relation.py @@ -24,6 +24,9 @@ class VersionRange(object): the compare operation can be one of: equal, greater, less, greater or equal, less or equal. """ + + __slots__ = ["op", "edge"] + def __init__(self, op=None, edge=None): """Initialises. @@ -96,6 +99,8 @@ class PackageRelation(object): and range of versions that satisfies requirement. """ + __slots__ = ["name", "version", "alternative"] + def __init__(self, name, version=None, alternative=None): """Initialises. diff --git a/packetary/objects/package_version.py b/packetary/objects/package_version.py new file mode 100644 index 0000000..03057f6 --- /dev/null +++ b/packetary/objects/package_version.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- + +# Copyright 2015 Mirantis, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from packetary.objects.base import ComparableObject + + +class PackageVersion(ComparableObject): + """The Package version.""" + + __slots__ = ["epoch", "version", "release"] + + def __init__(self, epoch, version, release): + self.epoch = int(epoch) + self.version = tuple(version) + self.release = tuple(release) + + @classmethod + def from_string(cls, text): + """Constructs from string. + + :param text: the version in format '[{epoch-}]-{version}-{release}' + """ + components = text.split("-") + if len(components) > 2: + epoch = components[0] + components = components[1:] + else: + epoch = 0 + return cls(epoch, components[0].split("."), components[1].split(".")) + + def cmp(self, other): + if not isinstance(other, PackageVersion): + other = PackageVersion.from_string(str(other)) + + if not isinstance(other, PackageVersion): + raise TypeError + if self.epoch < other.epoch: + return -1 + if self.epoch > other.epoch: + return 1 + + res = self._cmp_version_part(self.version, other.version) + if res != 0: + return res + return self._cmp_version_part(self.release, other.release) + + def __eq__(self, other): + if other is self: + return True + return self.cmp(other) == 0 + + def __str__(self): + return "{0}-{1}-{2}".format( + self.epoch, + ".".join(str(x) for x in self.version), + ".".join(str(x) for x in self.release) + ) + + @classmethod + def _order(cls, x): + """Return an integer value for character x""" + if x.isdigit(): + return int(x) + 1 + if x.isalpha(): + return ord(x) + return ord(x) + 256 + + @classmethod + def _cmp_version_string(cls, version1, version2): + """Compares two versions as string.""" + la = [cls._order(x) for x in version1] + lb = [cls._order(x) for x in version2] + while la or lb: + a = 0 + b = 0 + if la: + a = la.pop(0) + if lb: + b = lb.pop(0) + if a < b: + return -1 + elif a > b: + return 1 + return 0 + + @classmethod + def _cmp_version_part(cls, version1, version2): + """Compares two versions.""" + ver1_it = iter(version1) + ver2_it = iter(version2) + while True: + v1 = next(ver1_it, None) + v2 = next(ver2_it, None) + + if v1 is None or v2 is None: + if v1 is not None: + return 1 + if v2 is not None: + return -1 + return 0 + + if v1.isdigit() and v2.isdigit(): + a = int(v1) + b = int(v2) + if a < b: + return -1 + if a > b: + return 1 + else: + r = cls._cmp_version_string(v1, v2) + if r != 0: + return r diff --git a/packetary/objects/repository.py b/packetary/objects/repository.py index 5ba604b..551e8ef 100644 --- a/packetary/objects/repository.py +++ b/packetary/objects/repository.py @@ -32,12 +32,12 @@ class Repository(object): def __str__(self): if isinstance(self.name, tuple): return ".".join(self.name) - return str(self.name) + return self.name or self.url def __unicode__(self): if isinstance(self.name, tuple): return u".".join(self.name) - return unicode(self.name, "utf8") + return self.name or self.url def __copy__(self): """Creates shallow copy of package.""" diff --git a/packetary/tests/data/groups.xml b/packetary/tests/data/groups.xml new file mode 100644 index 0000000..4fadf5e --- /dev/null +++ b/packetary/tests/data/groups.xml @@ -0,0 +1,17 @@ + + + + + core + + test1 + test2 + + + + other + + test1 + + + diff --git a/packetary/tests/data/primary.xml b/packetary/tests/data/primary.xml new file mode 100644 index 0000000..3be4161 --- /dev/null +++ b/packetary/tests/data/primary.xml @@ -0,0 +1,62 @@ + + + test1 + x86_64 + + e8ed9e0612e813491ed5e7c10502a39e43ec665afd1321541dea211202707a65 + Test package + Test package + Test + http://localhost/ + + + test2 + x86_64 + + e8ed9e0612e813491ed5e7c10502a39e43ec665afd1321541dea211202707a65 + Test package + Test package + Test + http://localhost/ + + diff --git a/packetary/tests/data/repomd.xml b/packetary/tests/data/repomd.xml new file mode 100644 index 0000000..990ff06 --- /dev/null +++ b/packetary/tests/data/repomd.xml @@ -0,0 +1,20 @@ + + + 1427842153 + + 32fa7089953ace14f4a6e722bd33c353fcb94d9678d8a062a3b028e54042319c + 32fa7089953ace14f4a6e722bd33c353fcb94d9678d8a062a3b028e54042319c + + 1427842225 + 2528031 + 23175717 + + + 689a2ef671fe1c2245539e9c7b90e9dcd1236f4a0dd376512cfd509531a2b70d + bb7a4b6a6ccc8b4875b569359aedf67f9678cd56da7f372c134200265e276951 + + 1427842225 + 2528031 + 23175717 + + diff --git a/packetary/tests/data/repomd2.xml b/packetary/tests/data/repomd2.xml new file mode 100644 index 0000000..198ef47 --- /dev/null +++ b/packetary/tests/data/repomd2.xml @@ -0,0 +1,20 @@ + + + 1427842153 + + 32fa7089953ace14f4a6e722bd33c353fcb94d9678d8a062a3b028e54042319c + 32fa7089953ace14f4a6e722bd33c353fcb94d9678d8a062a3b028e54042319c + + 1427842225 + 2528031 + 23175717 + + + 689a2ef671fe1c2245539e9c7b90e9dcd1236f4a0dd376512cfd509531a2b70d + bb7a4b6a6ccc8b4875b569359aedf67f9678cd56da7f372c134200265e276951 + + 1427842225 + 2528031 + 23175717 + + diff --git a/packetary/tests/test_library_utils.py b/packetary/tests/test_library_utils.py index cdfce1a..b6fd5b3 100644 --- a/packetary/tests/test_library_utils.py +++ b/packetary/tests/test_library_utils.py @@ -87,6 +87,11 @@ class TestLibraryUtils(base.TestCase): with self.assertRaises(ValueError): utils.get_path_from_url("http:///a/f.txt") + self.assertEqual( + "/f.txt", + utils.get_path_from_url("http://host/f.txt", False) + ) + @mock.patch("packetary.library.utils.os") def test_ensure_dir_exist(self, os): os.makedirs.side_effect = [ diff --git a/packetary/tests/test_objects.py b/packetary/tests/test_objects.py index fa6a45b..7eb90a1 100644 --- a/packetary/tests/test_objects.py +++ b/packetary/tests/test_objects.py @@ -18,8 +18,10 @@ import copy import six from packetary.objects import PackageRelation +from packetary.objects import PackageVersion from packetary.objects import VersionRange + from packetary.tests import base from packetary.tests.stubs import generator @@ -94,6 +96,20 @@ class TestRepositoryObject(base.TestCase): self.assertEqual(clone.name, origin.name) self.assertEqual(clone.architecture, origin.architecture) + def test_str(self): + self.assertEqual( + "a.b", + str(generator.gen_repository(name=("a", "b"))) + ) + self.assertEqual( + "/a/b/", + str(generator.gen_repository(name="", url="/a/b/")) + ) + self.assertEqual( + "a", + str(generator.gen_repository(name="a", url="/a/b/")) + ) + class TestRelationObject(TestObjectBase): def test_equal(self): @@ -181,3 +197,25 @@ class TestVersionRange(TestObjectBase): def test_intersection_is_typesafe(self): with self.assertRaises(TypeError): VersionRange("eq", 1).has_intersection(("eq", 1)) + + +class TestPackageVersion(base.TestCase): + def test_get_from_string(self): + ver = PackageVersion.from_string("1.0-22") + self.assertEqual(0, ver.epoch) + self.assertEqual(('1', '0'), ver.version) + self.assertEqual(('22',), ver.release) + + ver2 = PackageVersion.from_string("1-11.0-2") + self.assertEqual(1, ver2.epoch) + self.assertEqual(('11', '0'), ver2.version) + self.assertEqual(('2',), ver2.release) + + def test_compare(self): + ver1 = PackageVersion.from_string("6.3-31.5") + ver2 = PackageVersion.from_string("13.9-16.12") + self.assertLess(ver1, ver2) + self.assertGreater(ver2, ver1) + self.assertEqual(ver1, ver1) + self.assertLess(ver1, "6.3-40") + self.assertGreater(ver1, "6.3-31.4a") diff --git a/packetary/tests/test_rpm_driver.py b/packetary/tests/test_rpm_driver.py new file mode 100644 index 0000000..a4caec4 --- /dev/null +++ b/packetary/tests/test_rpm_driver.py @@ -0,0 +1,204 @@ +# -*- coding: utf-8 -*- + +# Copyright 2015 Mirantis, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import mock +import os.path as path +import sys + +import six + +from packetary.objects import FileChecksum +from packetary.tests import base +from packetary.tests.stubs.generator import gen_repository +from packetary.tests.stubs.helpers import get_compressed + + +REPOMD = path.join(path.dirname(__file__), "data", "repomd.xml") + +REPOMD2 = path.join(path.dirname(__file__), "data", "repomd2.xml") + +PRIMARY_DB = path.join(path.dirname(__file__), "data", "primary.xml") + +GROUPS_DB = path.join(path.dirname(__file__), "data", "groups.xml") + + +class TestRpmDriver(base.TestCase): + @classmethod + def setUpClass(cls): + cls.createrepo = sys.modules["createrepo"] = mock.MagicMock() + # import driver class after patching sys.modules + from packetary.drivers import rpm_driver + + super(TestRpmDriver, cls).setUpClass() + cls.driver = rpm_driver.RpmRepositoryDriver() + + def setUp(self): + self.createrepo.reset_mock() + self.connection = mock.MagicMock() + + def test_parse_urls(self): + self.assertItemsEqual( + [ + "http://host/centos/os", + "http://host/centos/updates" + ], + self.driver.parse_urls([ + "http://host/centos/os", + "http://host/centos/updates/", + ]) + ) + + def test_get_repository(self): + repos = [] + + self.driver.get_repository( + self.connection, "http://host/centos/os", "x86_64", repos.append + ) + + self.assertEqual(1, len(repos)) + repo = repos[0] + self.assertEqual("os", repo.name) + self.assertEqual("", repo.origin) + self.assertEqual("x86_64", repo.architecture) + self.assertEqual("http://host/centos/os/x86_64/", repo.url) + + def test_get_packages(self): + streams = [] + for conv, fname in zip( + (lambda x: six.BytesIO(x.read()), + get_compressed, get_compressed), + (REPOMD, GROUPS_DB, PRIMARY_DB) + ): + with open(fname, "rb") as s: + streams.append(conv(s)) + + packages = [] + self.connection.open_stream.side_effect = streams + self.driver.get_packages( + self.connection, + gen_repository("test", url="http://host/centos/os/x86_64/"), + packages.append + ) + self.connection.open_stream.assert_any_call( + "http://host/centos/os/x86_64/repodata/repomd.xml" + ) + self.connection.open_stream.assert_any_call( + "http://host/centos/os/x86_64/repodata/groups.xml.gz" + ) + self.connection.open_stream.assert_any_call( + "http://host/centos/os/x86_64/repodata/primary.xml.gz" + ) + self.assertEqual(2, len(packages)) + package = packages[0] + self.assertEqual("test1", package.name) + self.assertEqual("1.1.1.1-1.el7", package.version) + self.assertEqual(100, package.filesize) + self.assertEqual( + FileChecksum( + None, + None, + 'e8ed9e0612e813491ed5e7c10502a39e' + '43ec665afd1321541dea211202707a65'), + package.checksum + ) + self.assertEqual( + "Packages/test1.rpm", package.filename + ) + self.assertItemsEqual( + ['test2 (eq 0-1.1.1.1-1.el7)'], + (str(x) for x in package.requires) + ) + self.assertItemsEqual( + ["file (any)"], + (str(x) for x in package.provides) + ) + self.assertItemsEqual( + ["test-old (any)"], + (str(x) for x in package.obsoletes) + ) + self.assertTrue(package.mandatory) + self.assertFalse(packages[1].mandatory) + + def test_get_packages_if_group_not_gzipped(self): + streams = [] + for conv, fname in zip( + (lambda x: six.BytesIO(x.read()), + lambda x: six.BytesIO(x.read()), + get_compressed), + (REPOMD2, GROUPS_DB, PRIMARY_DB) + ): + with open(fname, "rb") as s: + streams.append(conv(s)) + + packages = [] + self.connection.open_stream.side_effect = streams + self.driver.get_packages( + self.connection, + gen_repository("test", url="http://host/centos/os/x86_64/"), + packages.append + ) + self.connection.open_stream.assert_any_call( + "http://host/centos/os/x86_64/repodata/groups.xml" + ) + self.assertEqual(2, len(packages)) + package = packages[0] + self.assertTrue(package.mandatory) + + @mock.patch("packetary.drivers.rpm_driver.shutil") + def test_rebuild_repository(self, shutil): + self.createrepo.MDError = ValueError + self.createrepo.MetaDataGenerator().doFinalMove.side_effect = [ + None, self.createrepo.MDError() + ] + repo = gen_repository("test", url="file:///repo/os/x86_64") + self.createrepo.MetaDataConfig().outputdir = "/repo/os/x86_64" + self.createrepo.MetaDataConfig().tempdir = "tmp" + + self.driver.rebuild_repository(repo, set()) + + self.assertEqual( + "/repo/os/x86_64", + self.createrepo.MetaDataConfig().directory + ) + self.assertTrue(self.createrepo.MetaDataConfig().update) + self.createrepo.MetaDataGenerator()\ + .doPkgMetadata.assert_called_once_with() + self.createrepo.MetaDataGenerator()\ + .doRepoMetadata.assert_called_once_with() + self.createrepo.MetaDataGenerator()\ + .doFinalMove.assert_called_once_with() + + with self.assertRaises(RuntimeError): + self.driver.rebuild_repository(repo, set()) + shutil.rmtree.assert_called_once_with( + "/repo/os/x86_64/tmp", ignore_errors=True + ) + + @mock.patch("packetary.drivers.rpm_driver.utils") + def test_fork_repository(self, utils): + repo = gen_repository("os", url="http://localhost/os/x86_64") + clone = self.driver.fork_repository( + self.connection, + repo, + "/repo" + ) + + utils.ensure_dir_exist.assert_called_once_with("/repo/os/x86_64/") + self.assertEqual(repo.name, clone.name) + self.assertEqual(repo.architecture, clone.architecture) + self.assertEqual("/repo/os/x86_64/", clone.url) + self.createrepo.MetaDataGenerator()\ + .doFinalMove.assert_called_once_with() diff --git a/requirements.txt b/requirements.txt index 3af0b9e..26592e7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ chardet>=2.3.0 stevedore>=1.1.0 six>=1.5.2 python-debian>=0.1.23 +lxml>=3.2 diff --git a/setup.cfg b/setup.cfg index b030539..5aaa750 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,6 +31,7 @@ console_scripts = packetary.drivers = deb=packetary.drivers.deb_driver:DebRepositoryDriver + rpm=packetary.drivers.rpm_driver:RpmRepositoryDriver [build_sphinx] source-dir = doc/source