diff --git a/packetary/drivers/rpm_driver.py b/packetary/drivers/rpm_driver.py
new file mode 100644
index 0000000..a00aef7
--- /dev/null
+++ b/packetary/drivers/rpm_driver.py
@@ -0,0 +1,289 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2015 Mirantis, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import copy
+import multiprocessing
+import os
+import shutil
+
+import createrepo
+import lxml.etree as etree
+import six
+
+from packetary.drivers.base import RepositoryDriverBase
+from packetary.library.streams import GzipDecompress
+from packetary.library import utils
+from packetary.objects import FileChecksum
+from packetary.objects import Package
+from packetary.objects import PackageRelation
+from packetary.objects import PackageVersion
+from packetary.objects import Repository
+
+
+urljoin = six.moves.urllib.parse.urljoin
+
+# TODO(configurable option for drivers)
+_CORE_GROUPS = ("core", "base")
+
+_MANDATORY_TYPES = ("mandatory", "default")
+
+# The namespaces are used in metadata xml of repository
+_NAMESPACES = {
+ "main": "http://linux.duke.edu/metadata/common",
+ "md": "http://linux.duke.edu/metadata/repo",
+ "rpm": "http://linux.duke.edu/metadata/rpm"
+}
+
+
+class CreaterepoCallBack(object):
+ """Callback object for createrepo"""
+ def __init__(self, logger):
+ self.logger = logger
+
+ def errorlog(self, msg):
+ """Error log output."""
+ self.logger.error(msg)
+
+ def log(self, msg):
+ """Logs message."""
+ self.logger.info(msg)
+
+ def progress(self, item, current, total):
+ """"Progress bar."""
+ pass
+
+
+class RpmRepositoryDriver(RepositoryDriverBase):
+ def parse_urls(self, urls):
+ """Overrides method of superclass."""
+ return (url.rstrip("/") for url in urls)
+
+ def get_repository(self, connection, url, arch, consumer):
+ """Overrides method of superclass."""
+ # Currently supported repositories, that has URL in following format:
+ # baseurl/{name}/{architecture}
+ # because the architecture is sentetic part of rpm repository URL
+ name = url.rsplit("/", 1)[-1]
+ baseurl = "/".join((url, arch, ""))
+ consumer(Repository(
+ name=name,
+ url=baseurl,
+ architecture=arch,
+ origin=""
+ ))
+
+ def get_packages(self, connection, repository, consumer):
+ """Overrides method of superclass."""
+ baseurl = repository.url
+ repomd = urljoin(baseurl, "repodata/repomd.xml")
+ self.logger.debug("repomd: %s", repomd)
+
+ repomd_tree = etree.parse(connection.open_stream(repomd))
+ mandatory = self._get_mandatory_packages(
+ self._load_db(
+ connection, baseurl, repomd_tree, "group_gz", "group"
+ )
+ )
+ primary_db = self._load_db(connection, baseurl, repomd_tree, "primary")
+ if primary_db is None:
+ raise ValueError("Malformed repository: {0}".format(repository))
+
+ counter = 0
+ for tag in primary_db.iterfind("./main:package", _NAMESPACES):
+ try:
+ name = tag.find("./main:name", _NAMESPACES).text
+ consumer(Package(
+ repository=repository,
+ name=tag.find("./main:name", _NAMESPACES).text,
+ version=self._unparse_version_attrs(
+ tag.find("./main:version", _NAMESPACES).attrib
+ ),
+ filesize=int(
+ tag.find("./main:size", _NAMESPACES)
+ .attrib.get("package", -1)
+ ),
+ filename=tag.find(
+ "./main:location", _NAMESPACES
+ ).attrib["href"],
+ checksum=self._get_checksum(tag),
+ mandatory=name in mandatory,
+ requires=self._get_relations(tag, "requires"),
+ obsoletes=self._get_relations(tag, "obsoletes"),
+ provides=self._get_relations(tag, "provides")
+ ))
+ except (ValueError, KeyError) as e:
+ self.logger.error(
+ "Malformed tag %s - %s: %s",
+ repository, etree.tostring(tag), six.text_type(e)
+ )
+ raise
+ counter += 1
+ self.logger.info("loaded: %d packages from %s.", counter, repository)
+
+ def rebuild_repository(self, repository, packages):
+ """Overrides method of superclass."""
+ basepath = utils.get_path_from_url(repository.url)
+ self.logger.info("rebuild repository in %s", basepath)
+ md_config = createrepo.MetaDataConfig()
+ try:
+ md_config.workers = multiprocessing.cpu_count()
+ md_config.directory = str(basepath)
+ md_config.update = True
+ mdgen = createrepo.MetaDataGenerator(
+ config_obj=md_config, callback=CreaterepoCallBack(self.logger)
+ )
+ mdgen.doPkgMetadata()
+ mdgen.doRepoMetadata()
+ mdgen.doFinalMove()
+ except createrepo.MDError as e:
+ err_msg = six.text_type(e)
+ self.logger.exception(
+ "failed to create yum repository in %s: %s",
+ basepath,
+ err_msg
+ )
+ shutil.rmtree(
+ os.path.join(md_config.outputdir, md_config.tempdir),
+ ignore_errors=True
+ )
+ raise RuntimeError(
+ "Failed to create yum repository in {0}."
+ .format(err_msg))
+
+ def fork_repository(self, connection, repository, destination,
+ source=False, locale=False):
+ """Overrides method of superclass."""
+ # TODO(download gpk)
+ # TODO(sources and locales)
+ destination = os.path.join(
+ destination, repository.name,
+ repository.architecture, ""
+ )
+ new_repo = copy.copy(repository)
+ new_repo.url = destination
+ self.logger.info("clone repository %s to %s", repository, destination)
+ utils.ensure_dir_exist(destination)
+ self.rebuild_repository(new_repo, set())
+ return new_repo
+
+ def _load_db(self, connection, baseurl, repomd, *aliases):
+ """Loads database.
+
+ :param connection: the connection object
+ :param baseurl: the base repository URL
+ :param repomd: the parsed metadata of repository
+ :param aliases: the aliases of database name
+ :return: parsed database file or None if db does not exist
+ """
+
+ for dbname in aliases:
+ self.logger.debug("loading %s database...", dbname)
+ node = repomd.find(
+ "./md:data[@type='{0}']".format(dbname), _NAMESPACES
+ )
+ if node is not None:
+ break
+ else:
+ return
+
+ url = urljoin(
+ baseurl,
+ node.find("./md:location", _NAMESPACES).attrib["href"]
+ )
+ self.logger.debug("loading %s - %s...", dbname, url)
+ stream = connection.open_stream(url)
+ if url.endswith(".gz"):
+ stream = GzipDecompress(stream)
+ return etree.parse(stream)
+
+ def _get_mandatory_packages(self, groups_db):
+ """Get the set of mandatory package names.
+
+ :param groups_db: the parsed groups database
+ """
+ package_names = set()
+ if groups_db is None:
+ return package_names
+ count = 0
+ for name in _CORE_GROUPS:
+ result = groups_db.xpath("./group/id[text()='{0}']".format(name))
+ if len(result) == 0:
+ self.logger.warning("the group '%s' is not found.", name)
+ continue
+ group = result[0].getparent()
+ for t in _MANDATORY_TYPES:
+ xpath = "./packagelist/packagereq[@type='{0}']".format(t)
+ for tag in group.iterfind(xpath):
+ package_names.add(tag.text)
+ count += 1
+ self.logger.info("detected %d mandatory packages.", count)
+ return package_names
+
+ def _get_relations(self, pkg_tag, name):
+ """Gets package relations by name from package tag.
+
+ :param pkg_tag: the xml-tag with package description
+ :param name: the relations name
+ :return: list of PackageRelation objects
+ """
+ relations = list()
+ append = relations.append
+ tags_iter = pkg_tag.iterfind(
+ "./main:format/rpm:%s/rpm:entry" % name,
+ _NAMESPACES
+ )
+ for elem in tags_iter:
+ append(PackageRelation.from_args(
+ self._unparse_relation_attrs(elem.attrib)
+ ))
+
+ return relations
+
+ def _get_checksum(self, pkg_tag):
+ """Gets checksum from package tag."""
+ checksum = dict.fromkeys(("md5", "sha1", "sha256"), None)
+ checksum_tag = pkg_tag.find("./main:checksum", _NAMESPACES)
+ checksum[checksum_tag.attrib["type"]] = checksum_tag.text
+ return FileChecksum(**checksum)
+
+ def _unparse_relation_attrs(self, attrs):
+ """Gets the package relation from attributes.
+
+ :param attrs: the relation tag attributes
+ :return tuple(name, version_op, version_edge)
+ """
+ if "flags" not in attrs:
+ return attrs['name'], None
+
+ return (
+ attrs['name'],
+ attrs["flags"].lower(),
+ self._unparse_version_attrs(attrs)
+ )
+
+ @staticmethod
+ def _unparse_version_attrs(attrs):
+ """Gets the package version from attributes.
+
+ :param attrs: the relation tag attributes
+ :return: the PackageVersion object
+ """
+
+ return PackageVersion(
+ int(attrs.get("epoch", 0)),
+ attrs.get("ver", "0.0").split("."),
+ attrs.get("rel", "0").split(".")
+ )
diff --git a/packetary/library/utils.py b/packetary/library/utils.py
index 440ab97..a3df3c1 100644
--- a/packetary/library/utils.py
+++ b/packetary/library/utils.py
@@ -21,7 +21,6 @@ import os
import six
-
urlparse = six.moves.urllib.parse.urlparse
@@ -72,19 +71,22 @@ def get_size_and_checksum_for_files(files, checksum_algo):
yield filename, size, checksum
-def get_path_from_url(url):
+def get_path_from_url(url, ensure_file=True):
"""Get the path from the URL.
:param url: the URL
- :return: the filepath
+ :param ensure_file: If True, ensure that scheme is "file"
+ :return: the path component from URL
:raises ValueError
"""
comps = urlparse(url, scheme="file")
- if comps.scheme != "file":
+ if ensure_file and comps.scheme != "file":
raise ValueError(
"The absolute path is expected, actual have: {0}.".format(url)
)
+ if os.sep != "/":
+ return comps.path.replace("/", os.sep)
return comps.path
diff --git a/packetary/objects/__init__.py b/packetary/objects/__init__.py
index 234dc3e..96bf21f 100644
--- a/packetary/objects/__init__.py
+++ b/packetary/objects/__init__.py
@@ -20,6 +20,7 @@ from packetary.objects.package import FileChecksum
from packetary.objects.package import Package
from packetary.objects.package_relation import PackageRelation
from packetary.objects.package_relation import VersionRange
+from packetary.objects.package_version import PackageVersion
from packetary.objects.packages_tree import PackagesTree
from packetary.objects.repository import Repository
@@ -30,6 +31,7 @@ __all__ = [
"Package",
"PackageRelation",
"PackagesTree",
+ "PackageVersion",
"Repository",
"VersionRange",
]
diff --git a/packetary/objects/package_relation.py b/packetary/objects/package_relation.py
index fe7ba86..31cdfd0 100644
--- a/packetary/objects/package_relation.py
+++ b/packetary/objects/package_relation.py
@@ -24,6 +24,9 @@ class VersionRange(object):
the compare operation can be one of:
equal, greater, less, greater or equal, less or equal.
"""
+
+ __slots__ = ["op", "edge"]
+
def __init__(self, op=None, edge=None):
"""Initialises.
@@ -96,6 +99,8 @@ class PackageRelation(object):
and range of versions that satisfies requirement.
"""
+ __slots__ = ["name", "version", "alternative"]
+
def __init__(self, name, version=None, alternative=None):
"""Initialises.
diff --git a/packetary/objects/package_version.py b/packetary/objects/package_version.py
new file mode 100644
index 0000000..03057f6
--- /dev/null
+++ b/packetary/objects/package_version.py
@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2015 Mirantis, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from packetary.objects.base import ComparableObject
+
+
+class PackageVersion(ComparableObject):
+ """The Package version."""
+
+ __slots__ = ["epoch", "version", "release"]
+
+ def __init__(self, epoch, version, release):
+ self.epoch = int(epoch)
+ self.version = tuple(version)
+ self.release = tuple(release)
+
+ @classmethod
+ def from_string(cls, text):
+ """Constructs from string.
+
+ :param text: the version in format '[{epoch-}]-{version}-{release}'
+ """
+ components = text.split("-")
+ if len(components) > 2:
+ epoch = components[0]
+ components = components[1:]
+ else:
+ epoch = 0
+ return cls(epoch, components[0].split("."), components[1].split("."))
+
+ def cmp(self, other):
+ if not isinstance(other, PackageVersion):
+ other = PackageVersion.from_string(str(other))
+
+ if not isinstance(other, PackageVersion):
+ raise TypeError
+ if self.epoch < other.epoch:
+ return -1
+ if self.epoch > other.epoch:
+ return 1
+
+ res = self._cmp_version_part(self.version, other.version)
+ if res != 0:
+ return res
+ return self._cmp_version_part(self.release, other.release)
+
+ def __eq__(self, other):
+ if other is self:
+ return True
+ return self.cmp(other) == 0
+
+ def __str__(self):
+ return "{0}-{1}-{2}".format(
+ self.epoch,
+ ".".join(str(x) for x in self.version),
+ ".".join(str(x) for x in self.release)
+ )
+
+ @classmethod
+ def _order(cls, x):
+ """Return an integer value for character x"""
+ if x.isdigit():
+ return int(x) + 1
+ if x.isalpha():
+ return ord(x)
+ return ord(x) + 256
+
+ @classmethod
+ def _cmp_version_string(cls, version1, version2):
+ """Compares two versions as string."""
+ la = [cls._order(x) for x in version1]
+ lb = [cls._order(x) for x in version2]
+ while la or lb:
+ a = 0
+ b = 0
+ if la:
+ a = la.pop(0)
+ if lb:
+ b = lb.pop(0)
+ if a < b:
+ return -1
+ elif a > b:
+ return 1
+ return 0
+
+ @classmethod
+ def _cmp_version_part(cls, version1, version2):
+ """Compares two versions."""
+ ver1_it = iter(version1)
+ ver2_it = iter(version2)
+ while True:
+ v1 = next(ver1_it, None)
+ v2 = next(ver2_it, None)
+
+ if v1 is None or v2 is None:
+ if v1 is not None:
+ return 1
+ if v2 is not None:
+ return -1
+ return 0
+
+ if v1.isdigit() and v2.isdigit():
+ a = int(v1)
+ b = int(v2)
+ if a < b:
+ return -1
+ if a > b:
+ return 1
+ else:
+ r = cls._cmp_version_string(v1, v2)
+ if r != 0:
+ return r
diff --git a/packetary/objects/repository.py b/packetary/objects/repository.py
index 5ba604b..551e8ef 100644
--- a/packetary/objects/repository.py
+++ b/packetary/objects/repository.py
@@ -32,12 +32,12 @@ class Repository(object):
def __str__(self):
if isinstance(self.name, tuple):
return ".".join(self.name)
- return str(self.name)
+ return self.name or self.url
def __unicode__(self):
if isinstance(self.name, tuple):
return u".".join(self.name)
- return unicode(self.name, "utf8")
+ return self.name or self.url
def __copy__(self):
"""Creates shallow copy of package."""
diff --git a/packetary/tests/data/groups.xml b/packetary/tests/data/groups.xml
new file mode 100644
index 0000000..4fadf5e
--- /dev/null
+++ b/packetary/tests/data/groups.xml
@@ -0,0 +1,17 @@
+
+
+
+
+ core
+
+ test1
+ test2
+
+
+
+ other
+
+ test1
+
+
+
diff --git a/packetary/tests/data/primary.xml b/packetary/tests/data/primary.xml
new file mode 100644
index 0000000..3be4161
--- /dev/null
+++ b/packetary/tests/data/primary.xml
@@ -0,0 +1,62 @@
+
+
+ test1
+ x86_64
+
+ e8ed9e0612e813491ed5e7c10502a39e43ec665afd1321541dea211202707a65
+ Test package
+ Test package
+ Test
+ http://localhost/
+
+
+
+
+ GPLv2 with exceptions
+ CentOS
+ System Environment/Daemons
+ worker1.bsys.centos.org
+ 389-ds-base-1.3.3.1-13.el7.src.rpm
+
+
+
+
+
+
+
+
+
+
+
+
+
+ test2
+ x86_64
+
+ e8ed9e0612e813491ed5e7c10502a39e43ec665afd1321541dea211202707a65
+ Test package
+ Test package
+ Test
+ http://localhost/
+
+
+
+
+ GPLv2 with exceptions
+ CentOS
+ System Environment/Daemons
+ worker1.bsys.centos.org
+ 389-ds-base-1.3.3.1-13.el7.src.rpm
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/packetary/tests/data/repomd.xml b/packetary/tests/data/repomd.xml
new file mode 100644
index 0000000..990ff06
--- /dev/null
+++ b/packetary/tests/data/repomd.xml
@@ -0,0 +1,20 @@
+
+
+ 1427842153
+
+ 32fa7089953ace14f4a6e722bd33c353fcb94d9678d8a062a3b028e54042319c
+ 32fa7089953ace14f4a6e722bd33c353fcb94d9678d8a062a3b028e54042319c
+
+ 1427842225
+ 2528031
+ 23175717
+
+
+ 689a2ef671fe1c2245539e9c7b90e9dcd1236f4a0dd376512cfd509531a2b70d
+ bb7a4b6a6ccc8b4875b569359aedf67f9678cd56da7f372c134200265e276951
+
+ 1427842225
+ 2528031
+ 23175717
+
+
diff --git a/packetary/tests/data/repomd2.xml b/packetary/tests/data/repomd2.xml
new file mode 100644
index 0000000..198ef47
--- /dev/null
+++ b/packetary/tests/data/repomd2.xml
@@ -0,0 +1,20 @@
+
+
+ 1427842153
+
+ 32fa7089953ace14f4a6e722bd33c353fcb94d9678d8a062a3b028e54042319c
+ 32fa7089953ace14f4a6e722bd33c353fcb94d9678d8a062a3b028e54042319c
+
+ 1427842225
+ 2528031
+ 23175717
+
+
+ 689a2ef671fe1c2245539e9c7b90e9dcd1236f4a0dd376512cfd509531a2b70d
+ bb7a4b6a6ccc8b4875b569359aedf67f9678cd56da7f372c134200265e276951
+
+ 1427842225
+ 2528031
+ 23175717
+
+
diff --git a/packetary/tests/test_library_utils.py b/packetary/tests/test_library_utils.py
index cdfce1a..b6fd5b3 100644
--- a/packetary/tests/test_library_utils.py
+++ b/packetary/tests/test_library_utils.py
@@ -87,6 +87,11 @@ class TestLibraryUtils(base.TestCase):
with self.assertRaises(ValueError):
utils.get_path_from_url("http:///a/f.txt")
+ self.assertEqual(
+ "/f.txt",
+ utils.get_path_from_url("http://host/f.txt", False)
+ )
+
@mock.patch("packetary.library.utils.os")
def test_ensure_dir_exist(self, os):
os.makedirs.side_effect = [
diff --git a/packetary/tests/test_objects.py b/packetary/tests/test_objects.py
index fa6a45b..7eb90a1 100644
--- a/packetary/tests/test_objects.py
+++ b/packetary/tests/test_objects.py
@@ -18,8 +18,10 @@ import copy
import six
from packetary.objects import PackageRelation
+from packetary.objects import PackageVersion
from packetary.objects import VersionRange
+
from packetary.tests import base
from packetary.tests.stubs import generator
@@ -94,6 +96,20 @@ class TestRepositoryObject(base.TestCase):
self.assertEqual(clone.name, origin.name)
self.assertEqual(clone.architecture, origin.architecture)
+ def test_str(self):
+ self.assertEqual(
+ "a.b",
+ str(generator.gen_repository(name=("a", "b")))
+ )
+ self.assertEqual(
+ "/a/b/",
+ str(generator.gen_repository(name="", url="/a/b/"))
+ )
+ self.assertEqual(
+ "a",
+ str(generator.gen_repository(name="a", url="/a/b/"))
+ )
+
class TestRelationObject(TestObjectBase):
def test_equal(self):
@@ -181,3 +197,25 @@ class TestVersionRange(TestObjectBase):
def test_intersection_is_typesafe(self):
with self.assertRaises(TypeError):
VersionRange("eq", 1).has_intersection(("eq", 1))
+
+
+class TestPackageVersion(base.TestCase):
+ def test_get_from_string(self):
+ ver = PackageVersion.from_string("1.0-22")
+ self.assertEqual(0, ver.epoch)
+ self.assertEqual(('1', '0'), ver.version)
+ self.assertEqual(('22',), ver.release)
+
+ ver2 = PackageVersion.from_string("1-11.0-2")
+ self.assertEqual(1, ver2.epoch)
+ self.assertEqual(('11', '0'), ver2.version)
+ self.assertEqual(('2',), ver2.release)
+
+ def test_compare(self):
+ ver1 = PackageVersion.from_string("6.3-31.5")
+ ver2 = PackageVersion.from_string("13.9-16.12")
+ self.assertLess(ver1, ver2)
+ self.assertGreater(ver2, ver1)
+ self.assertEqual(ver1, ver1)
+ self.assertLess(ver1, "6.3-40")
+ self.assertGreater(ver1, "6.3-31.4a")
diff --git a/packetary/tests/test_rpm_driver.py b/packetary/tests/test_rpm_driver.py
new file mode 100644
index 0000000..a4caec4
--- /dev/null
+++ b/packetary/tests/test_rpm_driver.py
@@ -0,0 +1,204 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2015 Mirantis, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import mock
+import os.path as path
+import sys
+
+import six
+
+from packetary.objects import FileChecksum
+from packetary.tests import base
+from packetary.tests.stubs.generator import gen_repository
+from packetary.tests.stubs.helpers import get_compressed
+
+
+REPOMD = path.join(path.dirname(__file__), "data", "repomd.xml")
+
+REPOMD2 = path.join(path.dirname(__file__), "data", "repomd2.xml")
+
+PRIMARY_DB = path.join(path.dirname(__file__), "data", "primary.xml")
+
+GROUPS_DB = path.join(path.dirname(__file__), "data", "groups.xml")
+
+
+class TestRpmDriver(base.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.createrepo = sys.modules["createrepo"] = mock.MagicMock()
+ # import driver class after patching sys.modules
+ from packetary.drivers import rpm_driver
+
+ super(TestRpmDriver, cls).setUpClass()
+ cls.driver = rpm_driver.RpmRepositoryDriver()
+
+ def setUp(self):
+ self.createrepo.reset_mock()
+ self.connection = mock.MagicMock()
+
+ def test_parse_urls(self):
+ self.assertItemsEqual(
+ [
+ "http://host/centos/os",
+ "http://host/centos/updates"
+ ],
+ self.driver.parse_urls([
+ "http://host/centos/os",
+ "http://host/centos/updates/",
+ ])
+ )
+
+ def test_get_repository(self):
+ repos = []
+
+ self.driver.get_repository(
+ self.connection, "http://host/centos/os", "x86_64", repos.append
+ )
+
+ self.assertEqual(1, len(repos))
+ repo = repos[0]
+ self.assertEqual("os", repo.name)
+ self.assertEqual("", repo.origin)
+ self.assertEqual("x86_64", repo.architecture)
+ self.assertEqual("http://host/centos/os/x86_64/", repo.url)
+
+ def test_get_packages(self):
+ streams = []
+ for conv, fname in zip(
+ (lambda x: six.BytesIO(x.read()),
+ get_compressed, get_compressed),
+ (REPOMD, GROUPS_DB, PRIMARY_DB)
+ ):
+ with open(fname, "rb") as s:
+ streams.append(conv(s))
+
+ packages = []
+ self.connection.open_stream.side_effect = streams
+ self.driver.get_packages(
+ self.connection,
+ gen_repository("test", url="http://host/centos/os/x86_64/"),
+ packages.append
+ )
+ self.connection.open_stream.assert_any_call(
+ "http://host/centos/os/x86_64/repodata/repomd.xml"
+ )
+ self.connection.open_stream.assert_any_call(
+ "http://host/centos/os/x86_64/repodata/groups.xml.gz"
+ )
+ self.connection.open_stream.assert_any_call(
+ "http://host/centos/os/x86_64/repodata/primary.xml.gz"
+ )
+ self.assertEqual(2, len(packages))
+ package = packages[0]
+ self.assertEqual("test1", package.name)
+ self.assertEqual("1.1.1.1-1.el7", package.version)
+ self.assertEqual(100, package.filesize)
+ self.assertEqual(
+ FileChecksum(
+ None,
+ None,
+ 'e8ed9e0612e813491ed5e7c10502a39e'
+ '43ec665afd1321541dea211202707a65'),
+ package.checksum
+ )
+ self.assertEqual(
+ "Packages/test1.rpm", package.filename
+ )
+ self.assertItemsEqual(
+ ['test2 (eq 0-1.1.1.1-1.el7)'],
+ (str(x) for x in package.requires)
+ )
+ self.assertItemsEqual(
+ ["file (any)"],
+ (str(x) for x in package.provides)
+ )
+ self.assertItemsEqual(
+ ["test-old (any)"],
+ (str(x) for x in package.obsoletes)
+ )
+ self.assertTrue(package.mandatory)
+ self.assertFalse(packages[1].mandatory)
+
+ def test_get_packages_if_group_not_gzipped(self):
+ streams = []
+ for conv, fname in zip(
+ (lambda x: six.BytesIO(x.read()),
+ lambda x: six.BytesIO(x.read()),
+ get_compressed),
+ (REPOMD2, GROUPS_DB, PRIMARY_DB)
+ ):
+ with open(fname, "rb") as s:
+ streams.append(conv(s))
+
+ packages = []
+ self.connection.open_stream.side_effect = streams
+ self.driver.get_packages(
+ self.connection,
+ gen_repository("test", url="http://host/centos/os/x86_64/"),
+ packages.append
+ )
+ self.connection.open_stream.assert_any_call(
+ "http://host/centos/os/x86_64/repodata/groups.xml"
+ )
+ self.assertEqual(2, len(packages))
+ package = packages[0]
+ self.assertTrue(package.mandatory)
+
+ @mock.patch("packetary.drivers.rpm_driver.shutil")
+ def test_rebuild_repository(self, shutil):
+ self.createrepo.MDError = ValueError
+ self.createrepo.MetaDataGenerator().doFinalMove.side_effect = [
+ None, self.createrepo.MDError()
+ ]
+ repo = gen_repository("test", url="file:///repo/os/x86_64")
+ self.createrepo.MetaDataConfig().outputdir = "/repo/os/x86_64"
+ self.createrepo.MetaDataConfig().tempdir = "tmp"
+
+ self.driver.rebuild_repository(repo, set())
+
+ self.assertEqual(
+ "/repo/os/x86_64",
+ self.createrepo.MetaDataConfig().directory
+ )
+ self.assertTrue(self.createrepo.MetaDataConfig().update)
+ self.createrepo.MetaDataGenerator()\
+ .doPkgMetadata.assert_called_once_with()
+ self.createrepo.MetaDataGenerator()\
+ .doRepoMetadata.assert_called_once_with()
+ self.createrepo.MetaDataGenerator()\
+ .doFinalMove.assert_called_once_with()
+
+ with self.assertRaises(RuntimeError):
+ self.driver.rebuild_repository(repo, set())
+ shutil.rmtree.assert_called_once_with(
+ "/repo/os/x86_64/tmp", ignore_errors=True
+ )
+
+ @mock.patch("packetary.drivers.rpm_driver.utils")
+ def test_fork_repository(self, utils):
+ repo = gen_repository("os", url="http://localhost/os/x86_64")
+ clone = self.driver.fork_repository(
+ self.connection,
+ repo,
+ "/repo"
+ )
+
+ utils.ensure_dir_exist.assert_called_once_with("/repo/os/x86_64/")
+ self.assertEqual(repo.name, clone.name)
+ self.assertEqual(repo.architecture, clone.architecture)
+ self.assertEqual("/repo/os/x86_64/", clone.url)
+ self.createrepo.MetaDataGenerator()\
+ .doFinalMove.assert_called_once_with()
diff --git a/requirements.txt b/requirements.txt
index 3af0b9e..26592e7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,3 +10,4 @@ chardet>=2.3.0
stevedore>=1.1.0
six>=1.5.2
python-debian>=0.1.23
+lxml>=3.2
diff --git a/setup.cfg b/setup.cfg
index b030539..5aaa750 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -31,6 +31,7 @@ console_scripts =
packetary.drivers =
deb=packetary.drivers.deb_driver:DebRepositoryDriver
+ rpm=packetary.drivers.rpm_driver:RpmRepositoryDriver
[build_sphinx]
source-dir = doc/source