# -*- coding: utf-8 -*-

# Copyright 2015 Mirantis, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import copy
import multiprocessing
import os
import shutil
import tempfile

import createrepo
import lxml.etree as etree
import six

from packetary.drivers.base import RepositoryDriverBase
from packetary.library.checksum import composite as checksum_composite
from packetary.library.streams import GzipDecompress
from packetary.library import utils
from packetary.objects import FileChecksum
from packetary.objects import Package
from packetary.objects import PackageRelation
from packetary.objects import PackageVersion
from packetary.objects import Repository
from packetary.objects import VersionRange
from packetary.schemas import RPM_REPO_SCHEMA


urljoin = six.moves.urllib.parse.urljoin

# TODO(configurable option for drivers)
_CORE_GROUPS = ("core", "base")

_MANDATORY_TYPES = ("mandatory", "default")

# The namespaces used in the repository metadata XML
_NAMESPACES = {
    "main": "http://linux.duke.edu/metadata/common",
    "md": "http://linux.duke.edu/metadata/repo",
    "rpm": "http://linux.duke.edu/metadata/rpm"
}

_checksum_collector = checksum_composite('md5', 'sha1', 'sha256')

_OPERATORS_MAPPING = {
    None: None,
    'GT': '>',
    'LT': '<',
    'EQ': '=',
    'GE': '>=',
    'LE': '<=',
}

_DEFAULT_PRIORITY = 10


class CreaterepoCallBack(object):
    """Callback object for createrepo."""

    def __init__(self, logger):
        self.logger = logger

    def errorlog(self, msg):
        """Error log output."""
        self.logger.error(msg)

    def log(self, msg):
        """Logs message."""
        self.logger.info(msg)

    def progress(self, item, current, total):
        """Progress bar."""
        pass


class RpmRepositoryDriver(RepositoryDriverBase):
    def get_repository_data_schema(self):
        return RPM_REPO_SCHEMA

    def priority_sort(self, repo_data):
        # A DEB repository expects values from 0 to 1000, where 0 has the
        # lowest priority and 1000 the highest.
        # Note that a priority above 1000 allows even downgrades, no matter
        # which version of the prioritized package is installed.
        priority = repo_data.get('priority')
        if priority is None:
            priority = _DEFAULT_PRIORITY
        return priority
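
    # Illustration (not part of the original module): priority_sort simply
    # returns the configured priority, falling back to _DEFAULT_PRIORITY.
    # Assuming ``driver`` is an RpmRepositoryDriver instance:
    #
    #   driver.priority_sort({"name": "os", "priority": 1})   # -> 1
    #   driver.priority_sort({"name": "extras"})              # -> 10
    #
    # How the returned value is ordered against other repositories is up to
    # the calling code.
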
    def get_repository(self, connection, repository_data, arch, consumer):
        consumer(Repository(
            name=repository_data['name'],
            url=utils.normalize_repository_url(repository_data["uri"]),
            architecture=arch,
            path=repository_data.get('path'),
            origin=""
        ))

    def get_packages(self, connection, repository, consumer):
        baseurl = repository.url
        repomd = urljoin(baseurl, "repodata/repomd.xml")
        self.logger.debug("repomd: %s", repomd)

        repomd_tree = etree.parse(connection.open_stream(repomd))

        mandatory = self._get_mandatory_packages(
            self._load_db(
                connection, baseurl, repomd_tree, "group_gz", "group"
            )
        )
        primary_db = self._load_db(connection, baseurl, repomd_tree, "primary")
        if primary_db is None:
            raise ValueError("Malformed repository: {0}".format(repository))

        counter = 0
        for tag in primary_db.iterfind("./main:package", _NAMESPACES):
            try:
                name = tag.find("./main:name", _NAMESPACES).text
                consumer(Package(
                    repository=repository,
                    name=name,
                    version=self._unparse_version_attrs(
                        tag.find("./main:version", _NAMESPACES).attrib
                    ),
                    filesize=int(
                        tag.find("./main:size", _NAMESPACES)
                        .attrib.get("package", -1)
                    ),
                    filename=tag.find(
                        "./main:location", _NAMESPACES
                    ).attrib["href"],
                    checksum=self._get_checksum(tag),
                    mandatory=name in mandatory,
                    requires=self._get_relations(tag, "requires"),
                    obsoletes=self._get_relations(tag, "obsoletes"),
                    provides=self._get_relations(tag, "provides")
                ))
            except (ValueError, KeyError) as e:
                self.logger.error(
                    "Malformed tag %s - %s: %s",
                    repository, etree.tostring(tag), six.text_type(e)
                )
                raise
            counter += 1
        self.logger.info("loaded: %d packages from %s.", counter, repository)

    def add_packages(self, connection, repository, packages):
        groupstree = self._load_groups(connection, repository)
        self._rebuild_repository(connection, repository, packages, groupstree)

    def fork_repository(self, connection, repository, destination,
                        source=False, locale=False):
        # TODO(download gpk)
        # TODO(sources and locales)
        new_repo = copy.copy(repository)
        new_repo.url = utils.normalize_repository_url(destination)
        utils.ensure_dir_exist(destination)
        groupstree = self._load_groups(connection, repository)
        self._rebuild_repository(connection, new_repo, set(), groupstree)
        return new_repo

    def create_repository(self, connection, repository_data, arch):
        repository = Repository(
            name=repository_data['name'],
            url=utils.normalize_repository_url(repository_data["uri"]),
            architecture=arch,
            path=repository_data.get('path'),
            origin=repository_data.get('origin')
        )
        utils.ensure_dir_exist(utils.get_path_from_url(repository.url))
        self._rebuild_repository(connection, repository, None, None)
        return repository

    def load_package_from_file(self, repository, filepath):
        fullpath = utils.get_path_from_url(repository.url + filepath)
        _, size, checksum = next(iter(utils.get_size_and_checksum_for_files(
            [fullpath], _checksum_collector)
        ))
        pkg = createrepo.yumbased.YumLocalPackage(filename=fullpath)
        hdr = pkg.returnLocalHeader()
        return Package(
            repository=repository,
            name=hdr["name"],
            version=PackageVersion(
                hdr['epoch'], hdr['version'], hdr['release']
            ),
            filesize=int(hdr['size']),
            filename=filepath,
            checksum=FileChecksum(*checksum),
            mandatory=False,
            requires=self._parse_package_relations(pkg.requires),
            obsoletes=self._parse_package_relations(pkg.obsoletes),
            provides=self._parse_package_relations(pkg.provides),
        )
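
    # For reference, a sketch of the primary.xml <package> element that
    # get_packages() above iterates over (an illustrative assumption of the
    # usual createrepo output, not data taken from a specific repository):
    #
    #   <package type="rpm">                       <!-- "main" namespace -->
    #     <name>bash</name>
    #     <version epoch="0" ver="4.2.46" rel="19.el7"/>
    #     <checksum type="sha256" pkgid="YES">...</checksum>
    #     <size package="1036008" installed="3497899"/>
    #     <location href="Packages/bash-4.2.46-19.el7.x86_64.rpm"/>
    #     <format>
    #       <rpm:requires>
    #         <rpm:entry name="glibc" flags="GE" epoch="0" ver="2.15"/>
    #       </rpm:requires>
    #     </format>
    #   </package>
    #
    # The "main" and "rpm" prefixes map to the URIs listed in _NAMESPACES.
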
    def get_relative_path(self, repository, filename):
        return "packages/" + filename

    def _rebuild_repository(self, conn, repo, packages, groupstree=None):
        basepath = utils.get_path_from_url(repo.url)
        self.logger.info("rebuild repository in %s", basepath)
        md_config = createrepo.MetaDataConfig()
        mdfile_path = os.path.join(
            basepath, md_config.finaldir, md_config.repomdfile
        )
        update = packages is not None and os.path.exists(mdfile_path)
        groupsfile = None
        if groupstree is None and update:
            # createrepo loses the group information on update;
            # to prevent this, set the group info manually.
            groupstree = self._load_groups(conn, repo)

        if groupstree is not None:
            groupsfile = os.path.join(tempfile.gettempdir(), "groups.xml")
            with open(groupsfile, "w") as fd:
                groupstree.write(fd)

        try:
            md_config.workers = multiprocessing.cpu_count()
            md_config.directory = str(basepath)
            md_config.groupfile = groupsfile
            md_config.update = update
            if not packages:
                # only generate meta-files, without packages info
                md_config.excludes = ["*"]
            mdgen = createrepo.MetaDataGenerator(
                config_obj=md_config, callback=CreaterepoCallBack(self.logger)
            )
            mdgen.doPkgMetadata()
            mdgen.doRepoMetadata()
            mdgen.doFinalMove()
        except createrepo.MDError as e:
            err_msg = six.text_type(e)
            self.logger.exception(
                "failed to create yum repository in %s: %s",
                basepath,
                err_msg
            )
            shutil.rmtree(
                os.path.join(md_config.outputdir, md_config.tempdir),
                ignore_errors=True
            )
            raise RuntimeError(
                "Failed to create yum repository in {0}: {1}"
                .format(basepath, err_msg)
            )
        finally:
            if groupsfile is not None:
                os.unlink(groupsfile)

    def _load_groups(self, connection, repository):
        repomd = urljoin(repository.url, "repodata/repomd.xml")
        self.logger.debug("load repomd: %s", repomd)
        repomd_tree = etree.parse(connection.open_stream(repomd))
        return self._load_db(
            connection, repository.url, repomd_tree, "group_gz", "group"
        )

    def _load_db(self, connection, baseurl, repomd, *aliases):
        """Loads a metadata database from the repository.

        :param connection: the connection object
        :param baseurl: the base repository URL
        :param repomd: the parsed metadata of the repository
        :param aliases: the aliases of the database name
        :return: the parsed database file or None if the db does not exist
        """
        for dbname in aliases:
            self.logger.debug("loading %s database...", dbname)
            node = repomd.find(
                "./md:data[@type='{0}']".format(dbname), _NAMESPACES
            )
            if node is not None:
                break
        else:
            return None

        url = urljoin(
            baseurl,
            node.find("./md:location", _NAMESPACES).attrib["href"]
        )
        self.logger.debug("loading %s - %s...", dbname, url)
        stream = connection.open_stream(url)
        if url.endswith(".gz"):
            stream = GzipDecompress(stream)
        return etree.parse(stream)

    def _get_mandatory_packages(self, groups_db):
        """Gets the set of mandatory package names.

        :param groups_db: the parsed groups database
        """
        package_names = set()
        if groups_db is None:
            return package_names
        count = 0
        for name in _CORE_GROUPS:
            result = groups_db.xpath("./group/id[text()='{0}']".format(name))
            if len(result) == 0:
                self.logger.warning("the group '%s' is not found.", name)
                continue
            group = result[0].getparent()
            for t in _MANDATORY_TYPES:
                xpath = "./packagelist/packagereq[@type='{0}']".format(t)
                for tag in group.iterfind(xpath):
                    package_names.add(tag.text)
                    count += 1
        self.logger.info("detected %d mandatory packages.", count)
        return package_names
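
    # For reference, a sketch of the comps (groups) XML walked by
    # _get_mandatory_packages() above (an illustrative assumption of the
    # usual comps layout, not data taken from a specific repository):
    #
    #   <comps>
    #     <group>
    #       <id>core</id>
    #       <packagelist>
    #         <packagereq type="mandatory">kernel</packagereq>
    #         <packagereq type="default">yum</packagereq>
    #         <packagereq type="optional">dracut-config-generic</packagereq>
    #       </packagelist>
    #     </group>
    #   </comps>
    #
    # Only the groups listed in _CORE_GROUPS and the entry types listed in
    # _MANDATORY_TYPES contribute to the resulting set of package names.
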
    def _get_relations(self, pkg_tag, name):
        """Gets package relations by name from a package tag.

        :param pkg_tag: the xml-tag with the package description
        :param name: the relations name
                     ("requires", "obsoletes" or "provides")
        :return: list of PackageRelation objects
        """
        relations = list()
        append = relations.append
        tags_iter = pkg_tag.iterfind(
            "./main:format/rpm:%s/rpm:entry" % name, _NAMESPACES
        )
        for elem in tags_iter:
            append(PackageRelation.from_args(
                self._unparse_relation_attrs(elem.attrib)
            ))
        return relations

    @staticmethod
    def _parse_package_relations(relations):
        """Parses yum package relations.

        :param relations: list of tuples
                          (name, flags, (epoch, version, release))
        :return: list of PackageRelation objects
        """
        return [
            PackageRelation(
                x[0],
                VersionRange(
                    _OPERATORS_MAPPING[x[1]],
                    x[1] and PackageVersion(*x[2])
                )
            )
            for x in relations
        ]

    def _get_checksum(self, pkg_tag):
        """Gets the checksum from a package tag."""
        checksum = dict.fromkeys(("md5", "sha1", "sha256"), None)
        checksum_tag = pkg_tag.find("./main:checksum", _NAMESPACES)
        checksum[checksum_tag.attrib["type"]] = checksum_tag.text
        return FileChecksum(**checksum)

    def _unparse_relation_attrs(self, attrs):
        """Gets the package relation from attributes.

        :param attrs: the relation tag attributes
        :return: tuple(name, version_op, version_edge)
        """
        if "flags" not in attrs:
            return attrs['name'], None

        return (
            attrs['name'],
            _OPERATORS_MAPPING[attrs["flags"]],
            self._unparse_version_attrs(attrs)
        )

    @staticmethod
    def _unparse_version_attrs(attrs):
        """Gets the package version from attributes.

        :param attrs: the relation tag attributes
        :return: the PackageVersion object
        """
        return PackageVersion(
            attrs.get("epoch", 0),
            attrs.get("ver", "0.0"),
            attrs.get("rel", "0")
        )
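
# A minimal usage sketch (illustrative only; the ``connection`` object comes
# from packetary's connection layer and is assumed here, as is constructing
# the driver directly -- normally packetary instantiates it itself):
#
#   driver = RpmRepositoryDriver()
#   repos, packages = [], []
#   driver.get_repository(
#       connection,
#       {"name": "os",
#        "uri": "http://mirror.example.com/centos/7/os/x86_64/"},
#       "x86_64",
#       repos.append,
#   )
#   driver.get_packages(connection, repos[0], packages.append)
#   # ``packages`` now holds Package objects with versions, checksums and
#   # requires/provides/obsoletes relations parsed from primary.xml.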