From 1ce69b4feff844959ce2f881227af19f74a32b00 Mon Sep 17 00:00:00 2001 From: Bulat Gaifullin Date: Fri, 22 Jan 2016 17:22:44 +0300 Subject: [PATCH] Unify input data format The input data is described in YAML or JSON format Also implemented support of priorities for repositories Change-Id: I02f11714ba8880dd06c3ceeadf230c1d812ff0be Implements: blueprint unify-input-data --- packetary/__init__.py | 10 +- packetary/api.py | 174 +++++----- packetary/cli/commands/base.py | 48 ++- packetary/cli/commands/clone.py | 47 +-- packetary/cli/commands/packages.py | 50 +-- packetary/cli/commands/unresolved.py | 25 +- packetary/cli/commands/utils.py | 33 +- packetary/controllers/repository.py | 127 +++----- packetary/drivers/base.py | 25 +- packetary/drivers/deb_driver.py | 99 +++--- packetary/drivers/rpm_driver.py | 42 ++- packetary/library/utils.py | 27 +- packetary/objects/__init__.py | 2 + packetary/objects/index.py | 100 ++---- packetary/objects/package.py | 4 +- packetary/objects/package_relation.py | 57 ++-- packetary/objects/packages_forest.py | 90 ++++++ packetary/objects/packages_tree.py | 161 +++++----- packetary/objects/repository.py | 32 +- packetary/tests/stubs/generator.py | 6 +- packetary/tests/test_cli_commands.py | 135 ++++---- packetary/tests/test_command_utils.py | 28 +- packetary/tests/test_deb_driver.py | 143 ++++----- packetary/tests/test_index.py | 154 +++------ packetary/tests/test_library_utils.py | 23 ++ packetary/tests/test_objects.py | 84 +++-- packetary/tests/test_packages_forest.py | 104 ++++++ packetary/tests/test_packages_tree.py | 162 ++++------ packetary/tests/test_repository_api.py | 314 +++++++++---------- packetary/tests/test_repository_contoller.py | 96 +++--- packetary/tests/test_rpm_driver.py | 50 +-- tox.ini | 2 +- 32 files changed, 1223 insertions(+), 1231 deletions(-) create mode 100644 packetary/objects/packages_forest.py create mode 100644 packetary/tests/test_packages_forest.py diff --git a/packetary/__init__.py b/packetary/__init__.py index 9ee18e2..bb8ecae 100644 --- a/packetary/__init__.py +++ b/packetary/__init__.py @@ -29,5 +29,11 @@ __all__ = [ "RepositoryApi", ] -__version__ = pbr.version.VersionInfo( - 'packetary').version_string() +try: + __version__ = pbr.version.VersionInfo( + 'packetary').version_string() +except Exception as e: + # when run tests without installing package + # pbr may raise exception. + print("ERROR:", e) + __version__ = "0.0.0" diff --git a/packetary/api.py b/packetary/api.py index d6d65fc..ea21448 100644 --- a/packetary/api.py +++ b/packetary/api.py @@ -16,6 +16,7 @@ # with this program; if not, write to the Free Software Foundation, Inc., # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +from collections import defaultdict import logging import six @@ -23,8 +24,8 @@ import six from packetary.controllers import RepositoryController from packetary.library.connections import ConnectionsManager from packetary.library.executor import AsynchronousSection -from packetary.objects import Index from packetary.objects import PackageRelation +from packetary.objects import PackagesForest from packetary.objects import PackagesTree from packetary.objects.statistics import CopyStatistics @@ -111,120 +112,111 @@ class RepositoryApi(object): context = config if isinstance(config, Context) else Context(config) return cls(RepositoryController.load(context, repotype, repoarch)) - def get_packages(self, origin, debs=None, requirements=None): + def get_packages(self, repos_data, requirements_data=None, + include_mandatory=False): """Gets the list of packages from repository(es). - :param origin: The list of repository`s URLs - :param debs: the list of repository`s URL to calculate list of - dependencies, that will be used to filter packages. - :param requirements: the list of package relations, - to resolve the list of mandatory packages. + :param repos_data: The list of repository descriptions + :param requirements_data: The list of package`s requirements + that should be included + :param include_mandatory: if True, all mandatory packages will be :return: the set of packages """ - repositories = self._get_repositories(origin) - return self._get_packages(repositories, debs, requirements) + repos = self._load_repositories(repos_data) + requirements = self._load_requirements(requirements_data) + return self._get_packages(repos, requirements, include_mandatory) - def clone_repositories(self, origin, destination, debs=None, - requirements=None, keep_existing=True, - include_source=False, include_locale=False): + def clone_repositories(self, repos_data, requirements_data, destination, + include_source=False, include_locale=False, + include_mandatory=False): """Creates the clones of specified repositories in local folder. - :param origin: The list of repository`s URLs + :param repos_data: The list of repository descriptions + :param requirements_data: The list of package`s requirements + that should be included :param destination: the destination folder path - :param debs: the list of repository`s URL to calculate list of - dependencies, that will be used to filter packages. - :param requirements: the list of package relations, - to resolve the list of mandatory packages. - :param keep_existing: If False - local packages that does not exist - in original repo will be removed. :param include_source: if True, the source packages will be copied as well. - :param include_locale: if True, the locales - will be copied as well. + :param include_locale: if True, the locales will be copied as well. + :param include_mandatory: if True, all mandatory packages will be + included :return: count of copied and total packages. """ - repositories = self._get_repositories(origin) - packages = self._get_packages(repositories, debs, requirements) - mirrors = self.controller.clone_repositories( - repositories, destination, include_source, include_locale - ) - package_groups = dict((x, set()) for x in repositories) - for pkg in packages: + repos = self._load_repositories(repos_data) + reqs = self._load_requirements(requirements_data) + all_packages = self._get_packages(repos, reqs, include_mandatory) + package_groups = defaultdict(set) + for pkg in all_packages: package_groups[pkg.repository].add(pkg) stat = CopyStatistics() + mirrors = defaultdict(set) + # group packages by mirror for repo, packages in six.iteritems(package_groups): - mirror = mirrors[repo] - logger.info("copy packages from - %s", repo) - self.controller.copy_packages( - mirror, packages, keep_existing, stat.on_package_copied + mirror = self.controller.fork_repository( + repo, destination, include_source, include_locale + ) + mirrors[mirror].update(packages) + + # add new packages to mirrors + for mirror, packages in six.iteritems(mirrors): + self.controller.assign_packages( + mirror, packages, stat.on_package_copied ) return stat - def get_unresolved_dependencies(self, origin, main=None): + def get_unresolved_dependencies(self, repos_data): """Gets list of unresolved dependencies for repository(es). - :param origin: The list of repository`s URLs - :param main: The main repository(es) URL + :param repos_data: The list of repository descriptions :return: list of unresolved dependencies """ packages = PackagesTree() - self.controller.load_packages( - self._get_repositories(origin), - packages.add - ) + self._load_packages(self._load_repositories(repos_data), packages.add) + return packages.get_unresolved_dependencies() - if main is not None: - base = Index() - self.controller.load_packages( - self._get_repositories(main), - base.add - ) - else: - base = None - - return packages.get_unresolved_dependencies(base) - - def _get_repositories(self, urls): - """Gets the set of repositories by url.""" - repositories = set() - self.controller.load_repositories(urls, repositories.add) - return repositories - - def _get_packages(self, repositories, master, requirements): - """Gets the list of packages according to master and requirements.""" - if master is None and requirements is None: - packages = set() - self.controller.load_packages(repositories, packages.add) - return packages - - packages = PackagesTree() - self.controller.load_packages(repositories, packages.add) - if master is not None: - main_index = Index() - self.controller.load_packages( - self._get_repositories(master), - main_index.add - ) - else: - main_index = None - - return packages.get_minimal_subset( - main_index, - self._parse_requirements(requirements) - ) - - @staticmethod - def _parse_requirements(requirements): - """Gets the list of relations from requirements. - - :param requirements: the list of requirement in next format: - 'name [cmp version]|[alt [cmp version]]' - """ + def _get_packages(self, repos, requirements, include_mandatory): if requirements is not None: - return set( - PackageRelation.from_args( - *(x.split() for x in r.split("|"))) for r in requirements - ) - return set() + forest = PackagesForest() + for repo in repos: + self.controller.load_packages(repo, forest.add_tree().add) + return forest.get_packages(requirements, include_mandatory) + + packages = set() + self._load_packages(repos, packages.add) + return packages + + def _load_packages(self, repos, consumer): + for repo in repos: + self.controller.load_packages(repo, consumer) + + def _load_repositories(self, repos_data): + self._validate_repos_data(repos_data) + return self.controller.load_repositories(repos_data) + + def _load_requirements(self, requirements_data): + if requirements_data is None: + return + + self._validate_requirements_data(requirements_data) + result = [] + for r in requirements_data: + self._validate_requirements_data(r) + versions = r.get('versions', None) + if versions is None: + result.append(PackageRelation.from_args((r['name'],))) + else: + for version in versions: + result.append(PackageRelation.from_args( + ([r['name']] + version.split(None, 1)) + )) + return result + + def _validate_repos_data(self, repos_data): + # TODO(bgaifullin) implement me + pass + + def _validate_requirements_data(self, requirements_data): + # TODO(bgaifullin) implement me + pass diff --git a/packetary/cli/commands/base.py b/packetary/cli/commands/base.py index bb672cf..a466623 100644 --- a/packetary/cli/commands/base.py +++ b/packetary/cli/commands/base.py @@ -22,7 +22,7 @@ from cliff import command import six from packetary.cli.commands.utils import make_display_attr_getter -from packetary.cli.commands.utils import read_lines_from_file +from packetary.cli.commands.utils import read_from_file from packetary import RepositoryApi @@ -56,21 +56,15 @@ class BaseRepoCommand(command.Command): default="x86_64", help='The target architecture.') - origin_gr = parser.add_mutually_exclusive_group(required=True) - origin_gr.add_argument( - '-o', '--origin-url', - nargs="+", - dest='origins', - type=six.text_type, - metavar='URL', - help='Space separated list of URLs of origin repositories.') - - origin_gr.add_argument( - '-O', '--origin-file', - type=read_lines_from_file, - dest='origins', + parser.add_argument( + '-r', '--repositories', + dest='repositories', + type=read_from_file, metavar='FILENAME', - help='The path to file with URLs of origin repositories.') + required=True, + help="The path to file with list of repositories." + "See documentation about format." + ) return parser @@ -98,6 +92,30 @@ class BaseRepoCommand(command.Command): """ +class PackagesMixin(object): + """Added arguments to declare list of packages.""" + + def get_parser(self, prog_name): + parser = super(PackagesMixin, self).get_parser(prog_name) + parser.add_argument( + "--skip-mandatory", + dest='include_mandatory', + action='store_false', + default=True, + help="Do not copy mandatory packages." + ) + + parser.add_argument( + "-p", "--packages", + dest='requirements', + type=read_from_file, + metavar='FILENAME', + help="The path to file with list of packages." + "See documentation about format." + ) + return parser + + class BaseProduceOutputCommand(BaseRepoCommand): columns = None diff --git a/packetary/cli/commands/clone.py b/packetary/cli/commands/clone.py index fb9c3dc..1e0d22e 100644 --- a/packetary/cli/commands/clone.py +++ b/packetary/cli/commands/clone.py @@ -17,10 +17,10 @@ # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. from packetary.cli.commands.base import BaseRepoCommand -from packetary.cli.commands.utils import read_lines_from_file +from packetary.cli.commands.base import PackagesMixin -class CloneCommand(BaseRepoCommand): +class CloneCommand(PackagesMixin, BaseRepoCommand): """Clones the specified repository to local folder.""" def get_parser(self, prog_name): @@ -53,52 +53,17 @@ class CloneCommand(BaseRepoCommand): help="Also copy localisation files." ) - bootstrap_group = parser.add_mutually_exclusive_group(required=False) - bootstrap_group.add_argument( - "-b", "--bootstrap", - nargs='+', - dest='bootstrap', - metavar='PACKAGE [OP VERSION]', - help="Space separated list of package relations, " - "to resolve the list of mandatory packages." - ) - bootstrap_group.add_argument( - "-B", "--bootstrap-file", - type=read_lines_from_file, - dest='bootstrap', - metavar='FILENAME', - help="Path to the file with list of package relations, " - "to resolve the list of mandatory packages." - ) - - requires_group = parser.add_mutually_exclusive_group(required=False) - requires_group.add_argument( - '-r', '--requires-url', - nargs="+", - dest='requires', - metavar='URL', - help="Space separated list of repository`s URL to calculate list " - "of dependencies, that will be used to filter packages") - - requires_group.add_argument( - '-R', '--requires-file', - type=read_lines_from_file, - dest='requires', - metavar='FILENAME', - help="The path to the file with list of repository`s URL " - "to calculate list of dependencies, " - "that will be used to filter packages") return parser def take_repo_action(self, api, parsed_args): stat = api.clone_repositories( - parsed_args.origins, + parsed_args.repositories, + parsed_args.requirements, parsed_args.destination, - parsed_args.requires, - parsed_args.bootstrap, parsed_args.keep_existing, parsed_args.sources, - parsed_args.locales + parsed_args.locales, + parsed_args.include_mandatory ) self.stdout.write( "Packages copied: {0.copied}/{0.total}.\n".format(stat) diff --git a/packetary/cli/commands/packages.py b/packetary/cli/commands/packages.py index 7f02bd7..ce4780d 100644 --- a/packetary/cli/commands/packages.py +++ b/packetary/cli/commands/packages.py @@ -17,10 +17,10 @@ # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. from packetary.cli.commands.base import BaseProduceOutputCommand -from packetary.cli.commands.utils import read_lines_from_file +from packetary.cli.commands.base import PackagesMixin -class ListOfPackages(BaseProduceOutputCommand): +class ListOfPackages(PackagesMixin, BaseProduceOutputCommand): """Gets the list of packages from repository(es).""" columns = ( @@ -35,51 +35,11 @@ class ListOfPackages(BaseProduceOutputCommand): "requires", ) - def get_parser(self, prog_name): - parser = super(ListOfPackages, self).get_parser(prog_name) - - bootstrap_group = parser.add_mutually_exclusive_group(required=False) - bootstrap_group.add_argument( - "-b", "--bootstrap", - nargs='+', - dest='bootstrap', - metavar='PACKAGE [OP VERSION]', - help="Space separated list of package relations, " - "to resolve the list of mandatory packages." - ) - bootstrap_group.add_argument( - "-B", "--bootstrap-file", - type=read_lines_from_file, - dest='bootstrap', - metavar='FILENAME', - help="Path to the file with list of package relations, " - "to resolve the list of mandatory packages." - ) - - requires_group = parser.add_mutually_exclusive_group(required=False) - requires_group.add_argument( - '-r', '--requires-url', - nargs="+", - dest='requires', - metavar='URL', - help="Space separated list of repository`s URL to calculate list " - "of dependencies, that will be used to filter packages") - - requires_group.add_argument( - '-R', '--requires-file', - type=read_lines_from_file, - dest='requires', - metavar='FILENAME', - help="The path to the file with list of repository`s URL " - "to calculate list of dependencies, " - "that will be used to filter packages") - return parser - def take_repo_action(self, api, parsed_args): return api.get_packages( - parsed_args.origins, - parsed_args.requires, - parsed_args.bootstrap, + parsed_args.repositories, + parsed_args.requirements, + parsed_args.include_mandatory ) diff --git a/packetary/cli/commands/unresolved.py b/packetary/cli/commands/unresolved.py index a4ce6f3..a4b3669 100644 --- a/packetary/cli/commands/unresolved.py +++ b/packetary/cli/commands/unresolved.py @@ -17,7 +17,6 @@ # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. from packetary.cli.commands.base import BaseProduceOutputCommand -from packetary.cli.commands.utils import read_lines_from_file class ListOfUnresolved(BaseProduceOutputCommand): @@ -29,31 +28,9 @@ class ListOfUnresolved(BaseProduceOutputCommand): "alternative", ) - def get_parser(self, prog_name): - parser = super(ListOfUnresolved, self).get_parser(prog_name) - main_group = parser.add_mutually_exclusive_group(required=False) - main_group.add_argument( - '-m', '--main-url', - nargs="+", - dest='main', - metavar='URL', - help='Space separated list of URLs of repository(es) ' - ' that are used to resolve dependencies.') - - main_group.add_argument( - '-M', '--main-file', - type=read_lines_from_file, - dest='main', - metavar='FILENAME', - help='The path to the file, that contains ' - 'list of URLs of repository(es) ' - ' that are used to resolve dependencies.') - return parser - def take_repo_action(self, api, parsed_args): return api.get_unresolved_dependencies( - parsed_args.origins, - parsed_args.main, + parsed_args.repositories ) diff --git a/packetary/cli/commands/utils.py b/packetary/cli/commands/utils.py index de89158..fc09aee 100644 --- a/packetary/cli/commands/utils.py +++ b/packetary/cli/commands/utils.py @@ -16,25 +16,44 @@ # with this program; if not, write to the Free Software Foundation, Inc., # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -import operator +import json +import os + +import yaml import six -def read_lines_from_file(filename): +_PARSERS = { + "": yaml.safe_load, + ".json": json.load, + ".yaml": yaml.safe_load, + ".yml": yaml.safe_load, +} + + +def read_from_file(filename): """Reads lines from file. Note: the line starts with '#' will be skipped. :param filename: the path of target file :return: the list of lines from file + :raise ValuerError: when file-ext is unknown. """ + if filename is None: + return + + file_ext = os.path.splitext(filename)[-1].lower() + try: + parser = _PARSERS[file_ext] + except KeyError: + raise ValueError("Unsupported file format: {0}.\n" + "Please use '.json' or '.yaml' file extension" + .format(file_ext)) + with open(filename, 'r') as f: - return [ - x - for x in six.moves.map(operator.methodcaller("strip"), f) - if x and not x.startswith("#") - ] + return parser(f) def get_object_attrs(obj, attrs): diff --git a/packetary/controllers/repository.py b/packetary/controllers/repository.py index da67e2a..3f33776 100644 --- a/packetary/controllers/repository.py +++ b/packetary/controllers/repository.py @@ -22,6 +22,7 @@ import os import six import stevedore +from packetary.library import utils logger = logging.getLogger(__package__) @@ -58,114 +59,86 @@ class RepositoryController(object): ) return cls(context, driver, repoarch) - def load_repositories(self, urls, consumer): + def load_repositories(self, repositories_data): """Loads the repository objects from url. - :param urls: the list of repository urls. - :param consumer: the callback to consume objects + :param repositories_data: the list of repository`s descriptions + :return: the list of repositories sorted according to priority """ - if isinstance(urls, six.string_types): - urls = [urls] connection = self.context.connection - for parsed_url in self.driver.parse_urls(urls): + repositories_data.sort(key=self.driver.priority_sort) + repos = [] + for repo_data in repositories_data: self.driver.get_repository( - connection, parsed_url, self.arch, consumer + connection, repo_data, self.arch, repos.append ) + return repos - def load_packages(self, repositories, consumer): + def load_packages(self, repository, consumer): """Loads packages from repository. - :param repositories: the repository object + :param repository: the repository object :param consumer: the callback to consume objects """ connection = self.context.connection - for r in repositories: - self.driver.get_packages(connection, r, consumer) + self.driver.get_packages(connection, repository, consumer) - def assign_packages(self, repository, packages, keep_existing=True): + def fork_repository(self, repository, destination, source, locale): + """Creates copy of repositories. + + :param repository: the origin repository + :param destination: the target folder + :param source: If True, the source packages will be copied too. + :param locale: If True, the localisation will be copied too. + :return: the mapping origin to cloned repository. + """ + new_path = os.path.join( + destination, + repository.path or utils.get_path_from_url(repository.url, False) + ) + return self.driver.fork_repository( + self.context.connection, repository, new_path, source, locale + ) + + def assign_packages(self, repository, packages, observer=None): """Assigns new packages to the repository. It replaces the current repository`s packages. :param repository: the target repository :param packages: the set of new packages - :param keep_existing: - if True, all existing packages will be kept as is. - if False, all existing packages, that are not included - to new packages will be removed. + :param observer: the package copying process observer """ - if not isinstance(packages, set): packages = set(packages) else: packages = packages.copy() - if keep_existing: - consume_exist = packages.add - else: - def consume_exist(package): - if package not in packages: - filepath = os.path.join( - package.repository.url, package.filename - ) - logger.info("remove package - %s.", filepath) - os.remove(filepath) - - self.driver.get_packages( - self.context.connection, repository, consume_exist + self._copy_packages(repository, packages, observer) + self.driver.add_packages( + self.context.connection, repository, packages ) - self.driver.rebuild_repository(repository, packages) - def copy_packages(self, repository, packages, keep_existing, observer): - """Copies packages to repository. - - :param repository: the target repository - :param packages: the set of packages - :param keep_existing: see assign_packages for more details - :param observer: the package copying process observer - """ + def _copy_packages(self, target, packages, observer): with self.context.async_section() as section: for package in packages: section.execute( - self._copy_package, repository, package, observer + self._copy_package, target, package, observer ) - self.assign_packages(repository, packages, keep_existing) - - def clone_repositories(self, repositories, destination, - source=False, locale=False): - """Creates copy of repositories. - - :param repositories: the origin repositories - :param destination: the target folder - :param source: If True, the source packages will be copied too. - :param locale: If True, the localisation will be copied too. - :return: the mapping origin to cloned repository. - """ - mirros = dict() - destination = os.path.abspath(destination) - with self.context.async_section(0) as section: - for r in repositories: - section.execute( - self._fork_repository, - r, destination, source, locale, mirros - ) - return mirros - - def _fork_repository(self, r, destination, source, locale, mirrors): - """Creates clone of repository and stores it in mirrors.""" - new_repository = self.driver.fork_repository( - self.context.connection, r, destination, source, locale - ) - mirrors[r] = new_repository def _copy_package(self, target, package, observer): - """Synchronises remote file to local fs.""" - dst_path = os.path.join(target.url, package.filename) - src_path = urljoin(package.repository.url, package.filename) - bytes_copied = self.context.connection.retrieve( - src_path, dst_path, size=package.filesize - ) - if package.filesize < 0: - package.filesize = bytes_copied - observer(bytes_copied) + bytes_copied = 0 + if target.url != package.repository.url: + dst_path = os.path.join( + utils.get_path_from_url(target.url), package.filename + ) + src_path = urljoin(package.repository.url, package.filename) + bytes_copied = self.context.connection.retrieve( + src_path, dst_path, size=package.filesize + ) + if package.filesize < 0: + package.filesize = bytes_copied + + if observer: + observer(bytes_copied) diff --git a/packetary/drivers/base.py b/packetary/drivers/base.py index 3175acb..62fbb0d 100644 --- a/packetary/drivers/base.py +++ b/packetary/drivers/base.py @@ -35,18 +35,11 @@ class RepositoryDriverBase(object): self.logger = logging.getLogger(__package__) @abc.abstractmethod - def parse_urls(self, urls): - """Parses the repository url. - - :return: the sequence of parsed urls - """ - - @abc.abstractmethod - def get_repository(self, connection, url, arch, consumer): + def get_repository(self, connection, repository_data, arch, consumer): """Loads the repository meta information from URL. :param connection: the connection manager instance - :param url: the repository`s url + :param repository_data: the repository`s url :param arch: the repository`s architecture :param consumer: the callback to consume result """ @@ -74,9 +67,19 @@ class RepositoryDriverBase(object): """ @abc.abstractmethod - def rebuild_repository(self, repository, packages): - """Re-builds the repository. + def add_packages(self, connection, repository, packages): + """Adds new packages to the repository. + :param connection: the connection manager instance :param repository: the target repository :param packages: the set of packages """ + + @abc.abstractmethod + def priority_sort(self, repo_data): + """Key method to sort repositories data by priority. + + :param repo_data: the repository`s description + :return: the integer value that is relevant repository`s priority + less number means greater priority + """ diff --git a/packetary/drivers/deb_driver.py b/packetary/drivers/deb_driver.py index dd054cd..05272ca 100644 --- a/packetary/drivers/deb_driver.py +++ b/packetary/drivers/deb_driver.py @@ -39,11 +39,11 @@ from packetary.objects import Repository _OPERATORS_MAPPING = { - '>>': 'gt', - '<<': 'lt', - '=': 'eq', - '>=': 'ge', - '<=': 'le', + '>>': '>', + '<<': '<', + '=': '=', + '>=': '>=', + '<=': '<=', } _ARCHITECTURES = { @@ -77,46 +77,49 @@ _CHECKSUM_METHODS = ( "SHA256" ) +_DEFAULT_PRIORITY = 500 + _checksum_collector = checksum_composite('md5', 'sha1', 'sha256') class DebRepositoryDriver(RepositoryDriverBase): - def parse_urls(self, urls): - """Overrides method of superclass.""" - for url in urls: - try: - tokens = iter(x for x in url.split(" ") if x) - base, suite = next(tokens), next(tokens) - components = list(tokens) - except StopIteration: - raise ValueError("Invalid url: {0}".format(url)) + def priority_sort(self, repo_data): + # DEB repository expects general values from 0 to 1000. 0 + # to have lowest priority and 1000 -- the highest. Note that a + # priority above 1000 will allow even downgrades no matter the version + # of the prioritary package + priority = repo_data.get('priority') + if priority is None: + priority = _DEFAULT_PRIORITY + return -priority - base = base.rstrip("/") - if base.endswith("/dists"): - base = base[:-6] + def get_repository(self, connection, repository_data, arch, consumer): + url = utils.normalize_repository_url(repository_data['url']) + suite = repository_data['suite'] + components = repository_data.get('section') + path = repository_data.get('path') + name = repository_data.get('name') - # TODO(Flat Repository Format[1]) - # [1] https://wiki.debian.org/RepositoryFormat - for component in components: - yield (base, suite, component) + # TODO(bgaifullin) implement support for flat repisotory format [1] + # [1] https://wiki.debian.org/RepositoryFormat#Flat_Repository_Format + if components is None: + raise ValueError("The flat format does not supported.") - def get_repository(self, connection, url, arch, consumer): - """Overrides method of superclass.""" - - base, suite, component = url - release = self._get_url_of_metafile( - (base, suite, component, arch), "Release" - ) - deb_release = deb822.Release(connection.open_stream(release)) - consumer(Repository( - name=(deb_release["Archive"], deb_release["Component"]), - architecture=arch, - origin=deb_release["origin"], - url=base + "/" - )) + for component in components: + release = self._get_url_of_metafile( + (url, suite, component, arch), "Release" + ) + deb_release = deb822.Release(connection.open_stream(release)) + consumer(Repository( + name=name, + architecture=arch, + origin=deb_release["origin"], + url=url, + section=(suite, component), + path=path + )) def get_packages(self, connection, repository, consumer): - """Overrides method of superclass.""" index = self._get_url_of_metafile(repository, "Packages.gz") stream = GzipDecompress(connection.open_stream(index)) self.logger.info("loading packages from %s ...", repository) @@ -140,7 +143,8 @@ class DebRepositoryDriver(RepositoryDriverBase): requires=self._get_relations( dpkg, "depends", "pre-depends", "recommends" ), - obsoletes=self._get_relations(dpkg, "replaces"), + # The deb does not have obsoletes section + obsoletes=[], provides=self._get_relations(dpkg, "provides"), )) except KeyError as e: @@ -153,8 +157,7 @@ class DebRepositoryDriver(RepositoryDriverBase): self.logger.info("loaded: %d packages from %s.", counter, repository) - def rebuild_repository(self, repository, packages): - """Overrides method of superclass.""" + def add_packages(self, connection, repository, packages): basedir = utils.get_path_from_url(repository.url) index_file = utils.get_path_from_url( self._get_url_of_metafile(repository, "Packages") @@ -162,6 +165,8 @@ class DebRepositoryDriver(RepositoryDriverBase): utils.ensure_dir_exist(os.path.dirname(index_file)) index_gz = index_file + ".gz" count = 0 + # load existing packages + self.get_packages(connection, repository, packages.add) with open(index_file, "wb") as fd1: with closing(gzip.open(index_gz, "wb")) as fd2: writer = utils.composite_writer(fd1, fd2) @@ -185,7 +190,7 @@ class DebRepositoryDriver(RepositoryDriverBase): # TODO(download gpk) # TODO(sources and locales) new_repo = copy.copy(repository) - new_repo.url = utils.localize_repo_url(destination, repository.url) + new_repo.url = utils.normalize_repository_url(destination) packages_file = utils.get_path_from_url( self._get_url_of_metafile(new_repo, "Packages") ) @@ -200,8 +205,8 @@ class DebRepositoryDriver(RepositoryDriverBase): release = deb822.Release() release["Origin"] = repository.origin release["Label"] = repository.origin - release["Archive"] = repository.name[0] - release["Component"] = repository.name[1] + release["Archive"] = repository.section[0] + release["Component"] = repository.section[1] release["Architecture"] = _ARCHITECTURES[repository.architecture] with open(release_file, "wb") as fd: release.dump(fd) @@ -214,7 +219,7 @@ class DebRepositoryDriver(RepositoryDriverBase): """Updates the Release file in the suite.""" path = os.path.join( utils.get_path_from_url(repository.url), - "dists", repository.name[0] + "dists", repository.section[0] ) release_path = os.path.join(path, "Release") self.logger.info( @@ -304,7 +309,7 @@ class DebRepositoryDriver(RepositoryDriverBase): """ if isinstance(repo_or_comps, Repository): baseurl = repo_or_comps.url - suite, component = repo_or_comps.name + suite, component = repo_or_comps.section arch = repo_or_comps.architecture else: baseurl, suite, component, arch = repo_or_comps @@ -329,12 +334,12 @@ class DebRepositoryDriver(RepositoryDriverBase): ) release.setdefault("Origin", repository.origin) release.setdefault("Label", repository.origin) - release.setdefault("Suite", repository.name[0]) - release.setdefault("Codename", repository.name[0].split("-", 1)[0]) + release.setdefault("Suite", repository.section[0]) + release.setdefault("Codename", repository.section[0].split("-", 1)[0]) release.setdefault("Description", "The packages repository.") keys = ("Architectures", "Components") - values = (repository.architecture, repository.name[1]) + values = (repository.architecture, repository.section[1]) for key, value in six.moves.zip(keys, values): if key in release: release[key] = utils.append_token_to_string( diff --git a/packetary/drivers/rpm_driver.py b/packetary/drivers/rpm_driver.py index 8453e6e..18eec2a 100644 --- a/packetary/drivers/rpm_driver.py +++ b/packetary/drivers/rpm_driver.py @@ -49,6 +49,17 @@ _NAMESPACES = { "rpm": "http://linux.duke.edu/metadata/rpm" } +_OPERATORS_MAPPING = { + 'GT': '>', + 'LT': '<', + 'EQ': '=', + 'GE': '>=', + 'LE': '<=', +} + + +_DEFAULT_PRIORITY = 10 + class CreaterepoCallBack(object): """Callback object for createrepo""" @@ -69,21 +80,25 @@ class CreaterepoCallBack(object): class RpmRepositoryDriver(RepositoryDriverBase): - def parse_urls(self, urls): - """Overrides method of superclass.""" - return (url.rstrip("/") for url in urls) + def priority_sort(self, repo_data): + # DEB repository expects general values from 0 to 1000. 0 + # to have lowest priority and 1000 -- the highest. Note that a + # priority above 1000 will allow even downgrades no matter the version + # of the prioritary package + priority = repo_data.get('priority') + if priority is None: + priority = _DEFAULT_PRIORITY + return priority - def get_repository(self, connection, url, arch, consumer): - name = utils.get_path_from_url(url, False) + def get_repository(self, connection, repository_data, arch, consumer): consumer(Repository( - name=name, - url=url + "/", + name=repository_data['name'], + url=utils.normalize_repository_url(repository_data["url"]), architecture=arch, origin="" )) def get_packages(self, connection, repository, consumer): - """Overrides method of superclass.""" baseurl = repository.url repomd = urljoin(baseurl, "repodata/repomd.xml") self.logger.debug("repomd: %s", repomd) @@ -130,8 +145,7 @@ class RpmRepositoryDriver(RepositoryDriverBase): counter += 1 self.logger.info("loaded: %d packages from %s.", counter, repository) - def rebuild_repository(self, repository, packages): - """Overrides method of superclass.""" + def add_packages(self, connection, repository, packages): basepath = utils.get_path_from_url(repository.url) self.logger.info("rebuild repository in %s", basepath) md_config = createrepo.MetaDataConfig() @@ -165,12 +179,12 @@ class RpmRepositoryDriver(RepositoryDriverBase): # TODO(download gpk) # TODO(sources and locales) new_repo = copy.copy(repository) - new_repo.url = utils.localize_repo_url(destination, repository.url) + new_repo.url = utils.normalize_repository_url(destination) self.logger.info( "clone repository %s to %s", repository, new_repo.url ) - utils.ensure_dir_exist(new_repo.url) - self.rebuild_repository(new_repo, set()) + utils.ensure_dir_exist(destination) + self.add_packages(connection, new_repo, set()) return new_repo def _load_db(self, connection, baseurl, repomd, *aliases): @@ -264,7 +278,7 @@ class RpmRepositoryDriver(RepositoryDriverBase): return ( attrs['name'], - attrs["flags"].lower(), + _OPERATORS_MAPPING[attrs["flags"]], self._unparse_version_attrs(attrs) ) diff --git a/packetary/library/utils.py b/packetary/library/utils.py index 350892f..284c3c8 100644 --- a/packetary/library/utils.py +++ b/packetary/library/utils.py @@ -79,7 +79,7 @@ def get_path_from_url(url, ensure_file=True): :param url: the URL :param ensure_file: If True, ensure that scheme is "file" :return: the path component from URL - :raises ValueError + :raise ValueError: if expected local path and schema of URL is not file """ comps = urlparse(url, scheme="file") @@ -92,14 +92,27 @@ def get_path_from_url(url, ensure_file=True): return comps.path -def localize_repo_url(localurl, repo_url): - """Gets local repository url. +def get_url_from_path(path): + """Get the URL from local path. - :param localurl: the base local URL - :param repo_url: the origin URL of repository - :return: localurl + get_path_from_url(repo_url) + :param path: the local path + :return: the URL """ - return localurl.rstrip("/") + urlparse(repo_url).path + path = os.path.abspath(path) + if os.sep != "/": + path = path.replace(os.sep, "/") + return "file://" + path + + +def normalize_repository_url(url): + """Convert URL of repository to normal form. + + :param url: the origin URL + :return: normalized URL + """ + if url and url[0] in ("/", "."): + url = get_url_from_path(url) + return url.rstrip("/") + "/" def ensure_dir_exist(path): diff --git a/packetary/objects/__init__.py b/packetary/objects/__init__.py index e1b4e5b..dd1e147 100644 --- a/packetary/objects/__init__.py +++ b/packetary/objects/__init__.py @@ -22,6 +22,7 @@ from packetary.objects.package import Package from packetary.objects.package_relation import PackageRelation from packetary.objects.package_relation import VersionRange from packetary.objects.package_version import PackageVersion +from packetary.objects.packages_forest import PackagesForest from packetary.objects.packages_tree import PackagesTree from packetary.objects.repository import Repository @@ -31,6 +32,7 @@ __all__ = [ "Index", "Package", "PackageRelation", + "PackagesForest", "PackagesTree", "PackageVersion", "Repository", diff --git a/packetary/objects/index.py b/packetary/objects/index.py index fc43fcd..aa2e7ae 100644 --- a/packetary/objects/index.py +++ b/packetary/objects/index.py @@ -66,16 +66,22 @@ def _lowerbound_end(versions, version, condition): return result -def _equal(tree, version): - """Gets the package with specified version.""" - if version in tree: - return [tree[version]] - return [] +def _equal(versions, version): + """Gets the package with specified version. + + :param versions: the tree of versions. + :param version: the required version + """ + value = versions.get(version, None) + return [] if value is None else [value] -def _any(tree, _): - """Gets the package with max version.""" - return list(tree.values()) +def _any(versions, _): + """Gets the package with max version. + + :param versions: the tree of versions. + """ + return list(versions.values()) class Index(object): @@ -91,17 +97,15 @@ class Index(object): operators = { None: _any, - "lt": _make_operator(_start_upperbound, operator.lt), - "le": _make_operator(_start_upperbound, operator.le), - "gt": _make_operator(_lowerbound_end, operator.gt), - "ge": _make_operator(_lowerbound_end, operator.ge), - "eq": _equal, + "<": _make_operator(_start_upperbound, operator.lt), + "<=": _make_operator(_start_upperbound, operator.le), + ">": _make_operator(_lowerbound_end, operator.gt), + ">=": _make_operator(_lowerbound_end, operator.ge), + "=": _equal, } def __init__(self): self.packages = defaultdict(FastRBTree) - self.obsoletes = defaultdict(FastRBTree) - self.provides = defaultdict(FastRBTree) def __iter__(self): """Iterates over all packages including versions.""" @@ -115,6 +119,10 @@ class Index(object): 0 ) + def __contains__(self, name): + """Checks that index contains any package with such name.""" + return name in self.packages + def get_all(self): """Gets sequence from all of packages including versions.""" @@ -122,42 +130,15 @@ class Index(object): for version in versions.values(): yield version - def find(self, name, version): - """Finds the package by name and range of versions. - - :param name: the package`s name. - :param version: the range of versions. - :return: the package if it is found, otherwise None - """ - candidates = self.find_all(name, version) - if len(candidates) > 0: - return candidates[-1] - return None - - def find_all(self, name, version): + def find_all(self, name, version_range): """Finds the packages by name and range of versions. :param name: the package`s name. - :param version: the range of versions. + :param version_range: the range of versions. :return: the list of suitable packages """ - if name in self.packages: - candidates = self._find_versions( - self.packages[name], version - ) - if len(candidates) > 0: - return candidates - - if name in self.obsoletes: - return self._resolve_relation( - self.obsoletes[name], version - ) - - if name in self.provides: - return self._resolve_relation( - self.provides[name], version - ) + return self._find_versions(self.packages[name], version_range) return [] def add(self, package): @@ -166,43 +147,24 @@ class Index(object): :param package: the package object. """ self.packages[package.name][package.version] = package - key = package.name, package.version - - for obsolete in package.obsoletes: - self.obsoletes[obsolete.name][key] = obsolete - - for provide in package.provides: - self.provides[provide.name][key] = provide - - def _resolve_relation(self, relations, version): - """Resolve relation according to relations index. - - :param relations: the index of relations - :param version: the range of versions - :return: package if found, otherwise None - """ - for key, candidate in relations.iter_items(reverse=True): - if candidate.version.has_intersection(version): - return [self.packages[key[0]][key[1]]] - return [] @staticmethod - def _find_versions(versions, version): + def _find_versions(versions, version_range): """Searches accurate version. Search for the highest version out of intersection of existing and required range of versions. :param versions: the existing versions - :param version: the required range of versions + :param version_range: the required range of versions :return: package if found, otherwise None """ try: - op = Index.operators[version.op] + op = Index.operators[version_range.op] except KeyError: raise ValueError( "Unsupported operation: {0}" - .format(version.op) + .format(version_range.op) ) - return op(versions, version.edge) + return op(versions, version_range.edge) diff --git a/packetary/objects/package.py b/packetary/objects/package.py index cb5a7d7..145b5e5 100644 --- a/packetary/objects/package.py +++ b/packetary/objects/package.py @@ -59,10 +59,10 @@ class Package(ComparableObject): return Package(**self.__dict__) def __str__(self): - return "{0} {1}".format(self.name, self.version) + return "{0} ({1})".format(self.name, self.version) def __unicode__(self): - return u"{0} {1}".format(self.name, self.version) + return u"{0} ({1})".format(self.name, self.version) def __hash__(self): return hash((self.name, self.version)) diff --git a/packetary/objects/package_relation.py b/packetary/objects/package_relation.py index 65f3b73..0900f94 100644 --- a/packetary/objects/package_relation.py +++ b/packetary/objects/package_relation.py @@ -19,6 +19,16 @@ import operator +_OPERATORS = { + None: lambda x: True, + '=': operator.eq, + '>': operator.gt, + '<': operator.lt, + '>=': operator.ge, + '<=': operator.le, +} + + class VersionRange(object): """Describes the range of versions. @@ -27,17 +37,24 @@ class VersionRange(object): equal, greater, less, greater or equal, less or equal. """ - __slots__ = ["op", "edge"] + __slots__ = ("op", "edge") def __init__(self, op=None, edge=None): """Initialises. :param op: the name of operator to compare. :param edge: the edge of versions. + :raise ValueError: if comparison operator is invalid """ + if op not in _OPERATORS: + raise ValueError("Invalid comparison operator: '{0}'".format(op)) + self.op = op self.edge = edge + def __contains__(self, point): + return _OPERATORS[self.op](point, self.edge) + def __hash__(self): return hash((self.op, self.edge)) @@ -59,7 +76,11 @@ class VersionRange(object): return u"any" def has_intersection(self, other): - """Checks that 2 ranges has intersection.""" + """Checks that 2 ranges has intersection. + + :param other: the candidate to check + :return: True if intersection exists, otherwise False + """ if not isinstance(other, VersionRange): raise TypeError( @@ -70,28 +91,16 @@ class VersionRange(object): if self.op is None or other.op is None: return True - my_op = getattr(operator, self.op) - other_op = getattr(operator, other.op) if self.op[0] == other.op[0]: - if self.op[0] == 'l': - if self.edge < other.edge: - return my_op(self.edge, other.edge) - return other_op(other.edge, self.edge) - elif self.op[0] == 'g': - if self.edge > other.edge: - return my_op(self.edge, other.edge) - return other_op(other.edge, self.edge) - - if self.op == 'eq': - return other_op(self.edge, other.edge) - - if other.op == 'eq': - return my_op(other.edge, self.edge) - - return ( - my_op(other.edge, self.edge) and - other_op(self.edge, other.edge) - ) + if self.op == '=': + return self.edge == other.edge + # the intersection is -inf or +inf + return True + if self.edge == other.edge: + # need to cover case < a and >= a + return self.edge in other and other.edge in self + # all other cases + return self.edge in other or other.edge in self class PackageRelation(object): @@ -101,7 +110,7 @@ class PackageRelation(object): and range of versions that satisfies requirement. """ - __slots__ = ["name", "version", "alternative"] + __slots__ = ("name", "version", "alternative") def __init__(self, name, version=None, alternative=None): """Initialises. diff --git a/packetary/objects/packages_forest.py b/packetary/objects/packages_forest.py new file mode 100644 index 0000000..c8051e0 --- /dev/null +++ b/packetary/objects/packages_forest.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- + +# Copyright 2016 Mirantis, Inc. +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with this program; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +import logging + +from packetary.objects.packages_tree import PackagesTree + + +logger = logging.getLogger(__package__) + + +class PackagesForest(object): + """Helper class to deal with dependency graph.""" + + def __init__(self): + self.trees = [] + + def add_tree(self): + """Add new tree to end of forest. + + :return: The added tree + """ + tree = PackagesTree() + self.trees.append(tree) + return tree + + def get_packages(self, requirements, include_mandatory=False): + """Get the packages according requirements. + + :param requirements: the list of requirements + :param include_mandatory: if true, the mandatory packages will be + included to result + :return list of packages to copy + """ + + # TODO(bgaifullin): use versions intersection instead of union + # now the all versions that fit requirements are selected + # need to select only one version that fits all requirements + + resolved = set() + unresolved = set() + stack = [requirements] + + if include_mandatory: + for tree in self.trees: + for mandatory in tree.mandatory_packages: + resolved.add(mandatory) + stack.append(mandatory.requires) + + while stack: + requirements = stack.pop() + for required in requirements: + for rel in required: + if rel not in unresolved: + candidate = self.find(rel) + if candidate is not None: + if candidate not in resolved: + stack.append(candidate.requires) + resolved.add(candidate) + break + else: + unresolved.add(required) + logger.warning("Unresolved relation: %s", required) + return resolved + + def find(self, relation): + """Finds package in forest. + + :param relation: the package relation + :return: the packages from first tree if found otherwise empty list + """ + for tree in self.trees: + candidate = tree.find(relation.name, relation.version) + if candidate is not None: + return candidate diff --git a/packetary/objects/packages_tree.py b/packetary/objects/packages_tree.py index 94f7b89..3da14f8 100644 --- a/packetary/objects/packages_tree.py +++ b/packetary/objects/packages_tree.py @@ -16,119 +16,98 @@ # with this program; if not, write to the Free Software Foundation, Inc., # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -import warnings +from collections import defaultdict + +import six from packetary.objects.index import Index +from packetary.objects.package_relation import VersionRange -class UnresolvedWarning(UserWarning): - """Warning about unresolved depends.""" - pass - - -class PackagesTree(Index): +class PackagesTree(object): """Helper class to deal with dependency graph.""" def __init__(self): super(PackagesTree, self).__init__() self.mandatory_packages = [] + self.packages = Index() + self.provides = defaultdict(dict) + self.obsoletes = defaultdict(dict) def add(self, package): - super(PackagesTree, self).add(package) # store all mandatory packages in separated list for quick access if package.mandatory: self.mandatory_packages.append(package) - def get_unresolved_dependencies(self, base=None): + self.packages.add(package) + key = package.name, package.version + + for obsolete in package.obsoletes: + self.obsoletes[obsolete.name][key] = obsolete + + for provide in package.provides: + self.provides[provide.name][key] = provide + + def find(self, name, version_range): + """Finds the package by name and range of versions. + + :param name: the package`s name. + :param version_range: the range of versions. + :return: the package if it is found, otherwise None + """ + candidates = self.find_all(name, version_range) + if len(candidates) > 0: + return candidates[-1] + return None + + def find_all(self, name, version_range): + """Finds the packages by name and range of versions. + + :param name: the package`s name. + :param version_range: the range of versions. + :return: the list of suitable packages + """ + if name in self.packages: + candidates = self.packages.find_all(name, version_range) + if len(candidates) > 0: + return candidates + + if name in self.obsoletes: + return self._resolve_relation(self.obsoletes[name], version_range) + + if name in self.provides: + return self._resolve_relation(self.provides[name], version_range) + return [] + + def get_unresolved_dependencies(self): """Gets the set of unresolved dependencies. - :param base: the base index to resolve dependencies :return: the set of unresolved depends. """ - external = self.__get_unresolved_dependencies(self) - if base is None: - return external - unresolved = set() - for relation in external: - for rel in relation: - if base.find(rel.name, rel.version) is not None: - break - else: - unresolved.add(relation) - return unresolved - def get_minimal_subset(self, main, requirements): - """Gets the minimal work subset. - - :param main: the main index, to complete requirements. - :param requirements: additional requirements. - :return: The set of resolved depends. - """ - - unresolved = set() - resolved = set() - if main is None: - def pkg_filter(*_): - pass - else: - pkg_filter = main.find - self.__get_unresolved_dependencies(main, requirements) - - stack = list() - stack.append((None, requirements)) - - # add all mandatory packages - for pkg in self.mandatory_packages: - stack.append((pkg, pkg.requires)) - - while len(stack) > 0: - pkg, required = stack.pop() - resolved.add(pkg) - for require in required: - for rel in require: + for pkg in self.packages: + for required in pkg.requires: + for rel in required: if rel not in unresolved: - if pkg_filter(rel.name, rel.version) is not None: - break - # use all packages that meets depends - candidates = self.find_all(rel.name, rel.version) - found = False - for cand in candidates: - if cand == pkg: - continue - found = True - if cand not in resolved: - stack.append((cand, cand.requires)) - - if found: + if self.find(rel.name, rel.version) is not None: break else: - unresolved.add(require) - msg = "Unresolved depends: {0}".format(require) - warnings.warn(UnresolvedWarning(msg)) - - resolved.remove(None) - return resolved - - @staticmethod - def __get_unresolved_dependencies(index, unresolved=None): - """Gets the set of unresolved dependencies. - - :param index: the search index. - :param unresolved: the known list of unresolved packages. - :return: the set of unresolved depends. - """ - - if unresolved is None: - unresolved = set() - - for pkg in index: - for require in pkg.requires: - for rel in require: - if rel not in unresolved: - candidate = index.find(rel.name, rel.version) - if candidate is not None and candidate != pkg: - break - else: - unresolved.add(require) + unresolved.add(required) return unresolved + + def _resolve_relation(self, relations, version_range): + """Resolve relation according to relations index. + + :param relations: the index of relations + :param version_range: the range of versions + :return: package if found, otherwise None + """ + result = [] + for key, candidate in six.iteritems(relations): + if version_range.has_intersection(candidate.version): + result.extend( + self.packages.find_all(key[0], VersionRange('=', key[1])) + ) + result.sort(key=lambda x: x.version) + return result diff --git a/packetary/objects/repository.py b/packetary/objects/repository.py index f2302e7..4b3ba859 100644 --- a/packetary/objects/repository.py +++ b/packetary/objects/repository.py @@ -20,29 +20,37 @@ class Repository(object): """Structure to describe repository object.""" - def __init__(self, name, url, architecture, origin): + def __init__(self, name, url, architecture, origin=None, + path=None, section=None): """Initialises. :param name: the repository`s name, may be tuple of strings :param url: the repository`s URL :param architecture: the repository`s architecture - :param origin: the repository`s origin + :param origin: optional, the repository`s origin + :param path: the repository relative path, used for mirroring + :param section: the repository section """ - self.name = name - self.url = url self.architecture = architecture - self.origin = origin + self.name = name + self.origin = origin or "" + self.url = url + self.section = section + self.path = path def __str__(self): - if isinstance(self.name, tuple): - return ".".join(self.name) - return self.name or self.url + if not self.section: + return self.url - def __unicode__(self): - if isinstance(self.name, tuple): - return u".".join(self.name) - return self.name or self.url + if isinstance(self.section, tuple): + section_str = " ".join(self.section) + else: + section_str = self.section + return " ".join((self.url, section_str)) def __copy__(self): """Creates shallow copy of package.""" return Repository(**self.__dict__) + + def __hash__(self): + return hash((self.url, self.section)) diff --git a/packetary/tests/stubs/generator.py b/packetary/tests/stubs/generator.py index 5c196bd..28535fb 100644 --- a/packetary/tests/stubs/generator.py +++ b/packetary/tests/stubs/generator.py @@ -20,9 +20,9 @@ from packetary import objects def gen_repository(name="test", url="file:///test", - architecture="x86_64", origin="Test"): + architecture="x86_64", origin="Test", **kwargs): """Helper to create Repository object with default attributes.""" - return objects.Repository(name, url, architecture, origin) + return objects.Repository(name, url, architecture, origin, **kwargs) def gen_relation(name="test", version=None, alternative=None): @@ -46,7 +46,7 @@ def gen_package(idx=1, **kwargs): for relation in ("requires", "provides", "obsoletes"): if relation not in kwargs: kwargs[relation] = [gen_relation( - "{0}{1}".format(relation, idx), ["le", idx + 1] + "{0}{1}".format(relation, idx), ["<=", idx + 1] )] return objects.Package(**kwargs) diff --git a/packetary/tests/test_cli_commands.py b/packetary/tests/test_cli_commands.py index 69e30d5..d93dd0d 100644 --- a/packetary/tests/test_cli_commands.py +++ b/packetary/tests/test_cli_commands.py @@ -24,25 +24,19 @@ import subprocess # that was removed in 3.5 subprocess.mswindows = False +from packetary.api import RepositoryApi from packetary.cli.commands import clone from packetary.cli.commands import packages from packetary.cli.commands import unresolved +from packetary.objects.statistics import CopyStatistics from packetary.tests import base from packetary.tests.stubs.generator import gen_package from packetary.tests.stubs.generator import gen_relation -from packetary.tests.stubs.generator import gen_repository -from packetary.tests.stubs.helpers import CallbacksAdapter -@mock.patch.multiple( - "packetary.api", - RepositoryController=mock.DEFAULT, - ConnectionsManager=mock.DEFAULT, - AsynchronousSection=mock.MagicMock() -) -@mock.patch( - "packetary.cli.commands.base.BaseRepoCommand.stdout" -) +@mock.patch("packetary.cli.commands.base.BaseRepoCommand.stdout") +@mock.patch("packetary.cli.commands.base.read_from_file") +@mock.patch("packetary.cli.commands.base.RepositoryApi") class TestCliCommands(base.TestCase): common_argv = [ "--ignore-errors-num=3", @@ -53,23 +47,24 @@ class TestCliCommands(base.TestCase): ] clone_argv = [ - "-o", "http://localhost/origin", - "-d", ".", - "-r", "http://localhost/requires", - "-b", "test-package", + "-r", "repositories.yaml", + "-p", "packages.yaml", + "-d", "/root", "-t", "deb", "-a", "x86_64", "--clean", + "--skip-mandatory" ] packages_argv = [ - "-o", "http://localhost/origin", + "-r", "repositories.yaml", "-t", "deb", - "-a", "x86_64" + "-a", "x86_64", + "-c", "name", "filename" ] unresolved_argv = [ - "-o", "http://localhost/origin", + "-r", "repositories.yaml", "-t", "deb", "-a", "x86_64" ] @@ -77,76 +72,76 @@ class TestCliCommands(base.TestCase): def start_cmd(self, cmd, argv): cmd.debug(argv + self.common_argv) - def check_context(self, context, ConnectionsManager): - self.assertEqual(3, context._ignore_errors_num) - self.assertEqual(8, context._threads_num) - self.assertIs(context._connection, ConnectionsManager.return_value) - ConnectionsManager.assert_called_once_with( - proxy="http://proxy", - secure_proxy="https://proxy", - retries_num=10 - ) + def check_common_config(self, config): + self.assertEqual("http://proxy", config.http_proxy) + self.assertEqual("https://proxy", config.https_proxy) + self.assertEqual(3, config.ignore_errors_num) + self.assertEqual(8, config.threads_num) + self.assertEqual(10, config.retries_num) - def test_clone_cmd(self, stdout, RepositoryController, **kwargs): - ctrl = RepositoryController.load() - ctrl.copy_packages = CallbacksAdapter() - ctrl.load_repositories = CallbacksAdapter() - ctrl.load_packages = CallbacksAdapter() - ctrl.copy_packages.return_value = [1, 0] - repo = gen_repository() - ctrl.load_repositories.side_effect = [repo, gen_repository()] - ctrl.load_packages.side_effect = [ - gen_package(repository=repo), - gen_package() + def test_clone_cmd(self, api_mock, read_file_mock, stdout_mock): + read_file_mock.side_effect = [ + [{"name": "repo"}], + [{"name": "package"}], ] + api_instance = mock.MagicMock(spec=RepositoryApi) + api_mock.create.return_value = api_instance + api_instance.clone_repositories.return_value = CopyStatistics() self.start_cmd(clone, self.clone_argv) - RepositoryController.load.assert_called_with( + api_mock.create.assert_called_once_with( mock.ANY, "deb", "x86_64" ) - self.check_context( - RepositoryController.load.call_args[0][0], **kwargs + self.check_common_config(api_mock.create.call_args[0][0]) + read_file_mock.assert_any_call("repositories.yaml") + read_file_mock.assert_any_call("packages.yaml") + api_instance.clone_repositories.assert_called_once_with( + [{"name": "repo"}], [{"name": "package"}], "/root", + False, False, False, False ) - stdout.write.assert_called_once_with( - "Packages copied: 1/2.\n" + stdout_mock.write.assert_called_once_with( + "Packages copied: 0/0.\n" ) - def test_get_packages_cmd(self, stdout, RepositoryController, **kwargs): - ctrl = RepositoryController.load() - ctrl.load_packages = CallbacksAdapter() - ctrl.load_packages.return_value = gen_package( - name="test1", - filesize=1, - requires=None, - obsoletes=None, - provides=None - ) + def test_get_packages_cmd(self, api_mock, read_file_mock, stdout_mock): + read_file_mock.return_value = [{"name": "repo"}] + api_instance = mock.MagicMock(spec=RepositoryApi) + api_mock.create.return_value = api_instance + api_instance.get_packages.return_value = [ + gen_package(name="test1", filesize=1, requires=None, + obsoletes=None, provides=None) + ] + self.start_cmd(packages, self.packages_argv) - RepositoryController.load.assert_called_with( + read_file_mock.assert_called_with("repositories.yaml") + api_mock.create.assert_called_once_with( mock.ANY, "deb", "x86_64" ) - self.check_context( - RepositoryController.load.call_args[0][0], **kwargs + self.check_common_config(api_mock.create.call_args[0][0]) + api_instance.get_packages.assert_called_once_with( + [{"name": "repo"}], None, True ) self.assertIn( - "test1; test; 1; test1.pkg; 1;", - stdout.write.call_args_list[3][0][0] + "test1; test1.pkg", + stdout_mock.write.call_args_list[3][0][0] ) - def test_get_unresolved_cmd(self, stdout, RepositoryController, **kwargs): - ctrl = RepositoryController.load() - ctrl.load_packages = CallbacksAdapter() - ctrl.load_packages.return_value = gen_package( - name="test1", - requires=[gen_relation("test2")] - ) + def test_get_unresolved_cmd(self, api_mock, read_file_mock, stdout_mock): + read_file_mock.return_value = [{"name": "repo"}] + api_instance = mock.MagicMock(spec=RepositoryApi) + api_mock.create.return_value = api_instance + api_instance.get_unresolved_dependencies.return_value = [ + gen_relation(name="test") + ] + self.start_cmd(unresolved, self.unresolved_argv) - RepositoryController.load.assert_called_with( + api_mock.create.assert_called_once_with( mock.ANY, "deb", "x86_64" ) - self.check_context( - RepositoryController.load.call_args[0][0], **kwargs + self.check_common_config(api_mock.create.call_args[0][0]) + api_instance.get_unresolved_dependencies.assert_called_once_with( + [{"name": "repo"}] ) self.assertIn( - "test2; any; -", - stdout.write.call_args_list[3][0][0] + "test; any; -", + stdout_mock.write.call_args_list[3][0][0] ) diff --git a/packetary/tests/test_command_utils.py b/packetary/tests/test_command_utils.py index 0022208..c704ed9 100644 --- a/packetary/tests/test_command_utils.py +++ b/packetary/tests/test_command_utils.py @@ -28,18 +28,28 @@ class Dummy(object): class TestCommandUtils(base.TestCase): @mock.patch("packetary.cli.commands.utils.open") - def test_read_lines_from_file(self, open_mock): - open_mock().__enter__.return_value = [ - "line1\n", - " # comment\n", - "line2 \n" - ] - + def test_read_from_json_file(self, open_mock): + mock.mock_open(open_mock, read_data='{"key": "value"}') self.assertEqual( - ["line1", "line2"], - utils.read_lines_from_file("test.txt") + {"key": "value"}, + utils.read_from_file("test.json") ) + @mock.patch("packetary.cli.commands.utils.open") + def test_read_from_yaml_file(self, open_mock): + mock.mock_open(open_mock, read_data='key: value') + self.assertEqual( + {"key": "value"}, + utils.read_from_file("test.YAML") + ) + + def test_read_from_from_file_if_none(self): + self.assertIsNone(utils.read_from_file(None)) + + def test_read_from_from_file_fails_if_unknown_extension(self): + with self.assertRaisesRegexp(ValueError, "txt"): + utils.read_from_file("test.txt") + def test_get_object_attrs(self): obj = Dummy() obj.attr_int = 0 diff --git a/packetary/tests/test_deb_driver.py b/packetary/tests/test_deb_driver.py index d0f831f..44e3dea 100644 --- a/packetary/tests/test_deb_driver.py +++ b/packetary/tests/test_deb_driver.py @@ -20,9 +20,7 @@ import mock import os.path as path import six - from packetary.drivers import deb_driver -from packetary.library.utils import localize_repo_url from packetary.tests import base from packetary.tests.stubs.generator import gen_package from packetary.tests.stubs.generator import gen_repository @@ -30,7 +28,6 @@ from packetary.tests.stubs.helpers import get_compressed PACKAGES = path.join(path.dirname(__file__), "data", "Packages") -RELEASE = path.join(path.dirname(__file__), "data", "Release") class TestDebDriver(base.TestCase): @@ -42,75 +39,79 @@ class TestDebDriver(base.TestCase): def setUp(self): self.connection = mock.MagicMock() + self.repo = gen_repository( + name="trusty", section=("trusty", "main"), url="file:///repo" + ) - def test_parse_urls(self): - self.assertItemsEqual( - [ - ("http://host", "trusty", "main"), - ("http://host", "trusty", "restricted"), - ], - self.driver.parse_urls( - ["http://host/dists/ trusty main restricted"] - ) - ) - self.assertItemsEqual( - [("http://host", "trusty", "main")], - self.driver.parse_urls( - ["http://host/dists trusty main"] - ) - ) - self.assertItemsEqual( - [("http://host", "trusty", "main")], - self.driver.parse_urls( - ["http://host/ trusty main"] - ) - ) - self.assertItemsEqual( - [ - ("http://host", "trusty", "main"), - ("http://host2", "trusty", "main"), - ], - self.driver.parse_urls( - [ - "http://host/ trusty main", - "http://host2/dists/ trusty main", - ] - ) + def test_priority_sort(self): + repos = [ + {"name": "repo0"}, + {"name": "repo1", "priority": 0}, + {"name": "repo2", "priority": 1000}, + {"name": "repo3", "priority": None} + ] + repos.sort(key=self.driver.priority_sort) + + self.assertEqual( + ["repo2", "repo0", "repo3", "repo1"], + [x['name'] for x in repos] ) def test_get_repository(self): repos = [] - with open(RELEASE, "rb") as stream: - self.connection.open_stream.return_value = stream - self.driver.get_repository( - self.connection, - ("http://host", "trusty", "main"), - "x86_64", - repos.append - ) - self.connection.open_stream.assert_called_once_with( + repo_data = { + "name": "repo1", "url": "http://host", "suite": "trusty", + "section": ["main", "universe"], "path": "my_path" + } + self.connection.open_stream.return_value = {"Origin": "Ubuntu"} + self.driver.get_repository( + self.connection, + repo_data, + "x86_64", + repos.append + ) + self.connection.open_stream.assert_any_call( "http://host/dists/trusty/main/binary-amd64/Release" ) - self.assertEqual(1, len(repos)) + self.connection.open_stream.assert_any_call( + "http://host/dists/trusty/universe/binary-amd64/Release" + ) + self.assertEqual(2, len(repos)) repo = repos[0] - self.assertEqual(("trusty", "main"), repo.name) + self.assertEqual("repo1", repo.name) + self.assertEqual(("trusty", "main"), repo.section) + self.assertEqual("Ubuntu", repo.origin) + self.assertEqual("x86_64", repo.architecture) + self.assertEqual("http://host/", repo.url) + self.assertEqual("my_path", repo.path) + repo = repos[1] + self.assertEqual("repo1", repo.name) + self.assertEqual(("trusty", "universe"), repo.section) self.assertEqual("Ubuntu", repo.origin) self.assertEqual("x86_64", repo.architecture) self.assertEqual("http://host/", repo.url) + def test_get_flat_repository(self): + with self.assertRaisesRegexp(ValueError, "does not supported"): + self.driver.get_repository( + self.connection, + {"url": "http://host", "suite": "trusty"}, + "x86_64", + lambda x: None + ) + def test_get_packages(self): packages = [] - repo = gen_repository(name=("trusty", "main"), url="http://host/") with open(PACKAGES, "rb") as s: self.connection.open_stream.return_value = get_compressed(s) self.driver.get_packages( self.connection, - repo, + self.repo, packages.append ) self.connection.open_stream.assert_called_once_with( - "http://host/dists/trusty/main/binary-amd64/Packages.gz", + "file:///repo/dists/trusty/main/binary-amd64/Packages.gz", ) self.assertEqual(1, len(packages)) package = packages[0] @@ -132,7 +133,7 @@ class TestDebDriver(base.TestCase): self.assertItemsEqual( [ 'test-main (any)', - 'test2 (ge 0.8.16~exp9) | tes2-old (any)', + 'test2 (>= 0.8.16~exp9) | tes2-old (any)', 'test3 (any)' ], (str(x) for x in package.requires) @@ -142,7 +143,7 @@ class TestDebDriver(base.TestCase): (str(x) for x in package.provides) ) self.assertItemsEqual( - ["test-old (any)"], + [], (str(x) for x in package.obsoletes) ) @@ -156,10 +157,8 @@ class TestDebDriver(base.TestCase): os=mock.DEFAULT, open=mock.DEFAULT ) - def test_rebuild_repository(self, os, debfile, deb822, fcntl, - gzip, utils, open): - repo = gen_repository(name=("trusty", "main"), url="file:///repo") - package = gen_package(name="test", repository=repo) + def test_add_packages(self, os, debfile, deb822, fcntl, gzip, utils, open): + package = gen_package(name="test", repository=self.repo) os.path.join = lambda *x: "/".join(x) utils.get_path_from_url = lambda x: x[7:] @@ -171,7 +170,7 @@ class TestDebDriver(base.TestCase): mock.MagicMock() # Packages.gz, rb ] open.side_effect = files - self.driver.rebuild_repository(repo, [package]) + self.driver.add_packages(self.connection, self.repo, {package}) open.assert_any_call( "/repo/dists/trusty/main/binary-amd64/Packages", "wb" ) @@ -186,27 +185,24 @@ class TestDebDriver(base.TestCase): gzip=mock.DEFAULT, open=mock.DEFAULT, os=mock.DEFAULT, - utils=mock.DEFAULT ) - def test_fork_repository(self, deb822, gzip, open, os, utils): + @mock.patch("packetary.drivers.deb_driver.utils.ensure_dir_exist") + def test_fork_repository(self, mkdir_mock, deb822, gzip, open, os): os.path.sep = "/" os.path.join = lambda *x: "/".join(x) - utils.get_path_from_url = lambda x: x - utils.localize_repo_url = localize_repo_url - repo = gen_repository( - name=("trusty", "main"), url="http://localhost/test/" - ) files = [ mock.MagicMock(), mock.MagicMock() ] open.side_effect = files - new_repo = self.driver.fork_repository(self.connection, repo, "/root") - self.assertEqual(repo.name, new_repo.name) - self.assertEqual(repo.architecture, new_repo.architecture) - self.assertEqual(repo.origin, new_repo.origin) - self.assertEqual("/root/test/", new_repo.url) - utils.ensure_dir_exist.assert_called_once_with(os.path.dirname()) + new_repo = self.driver.fork_repository( + self.connection, self.repo, "/root/test" + ) + self.assertEqual(self.repo.name, new_repo.name) + self.assertEqual(self.repo.architecture, new_repo.architecture) + self.assertEqual(self.repo.origin, new_repo.origin) + self.assertEqual("file:///root/test/", new_repo.url) + mkdir_mock.assert_called_once_with(os.path.dirname()) open.assert_any_call( "/root/test/dists/trusty/main/binary-amd64/Release", "wb" ) @@ -225,9 +221,7 @@ class TestDebDriver(base.TestCase): os=mock.DEFAULT, utils=mock.DEFAULT ) - def test_update_suite_index( - self, os, fcntl, gzip, open, utils): - repo = gen_repository(name=("trusty", "main"), url="/repo") + def test_update_suite_index(self, os, fcntl, gzip, open, utils): files = [ mock.MagicMock(), # Release, a+b mock.MagicMock(), # Packages, rb @@ -254,7 +248,7 @@ class TestDebDriver(base.TestCase): ) for name in deb_driver._REPOSITORY_FILES ) - self.driver._update_suite_index(repo) + self.driver._update_suite_index(self.repo) open.assert_any_call("/root/dists/trusty/Release", "a+b") files[0].seek.assert_called_once_with(0) files[0].truncate.assert_called_once_with(0) @@ -269,6 +263,5 @@ class TestDebDriver(base.TestCase): .format(k, k + "_value") )) open.assert_any_call("/root/dists/trusty/Release", "a+b") - print([x.fileno() for x in files]) fcntl.flock.assert_any_call(files[0].fileno(), fcntl.LOCK_EX) fcntl.flock.assert_any_call(files[0].fileno(), fcntl.LOCK_UN) diff --git a/packetary/tests/test_index.py b/packetary/tests/test_index.py index 8739ca0..a7d665e 100644 --- a/packetary/tests/test_index.py +++ b/packetary/tests/test_index.py @@ -23,79 +23,25 @@ from packetary.objects.index import Index from packetary import objects from packetary.tests import base from packetary.tests.stubs.generator import gen_package -from packetary.tests.stubs.generator import gen_relation class TestIndex(base.TestCase): def test_add(self): index = Index() - index.add(gen_package(version=1)) - self.assertIn("package1", index.packages) - self.assertIn(1, index.packages["package1"]) - self.assertIn("obsoletes1", index.obsoletes) - self.assertIn("provides1", index.provides) + package1 = gen_package(version=1) + index.add(package1) + self.assertIn(package1.name, index.packages) + self.assertEqual( + [(1, package1)], + list(index.packages[package1.name].items()) + ) - index.add(gen_package(version=2)) + package2 = gen_package(version=2) + index.add(package2) self.assertEqual(1, len(index.packages)) - self.assertIn(1, index.packages["package1"]) - self.assertIn(2, index.packages["package1"]) - self.assertEqual(1, len(index.obsoletes)) - self.assertEqual(1, len(index.provides)) - - def test_find(self): - index = Index() - p1 = gen_package(version=1) - p2 = gen_package(version=2) - index.add(p1) - index.add(p2) - - self.assertIs( - p1, - index.find("package1", objects.VersionRange("eq", 1)) - ) - self.assertIs( - p2, - index.find("package1", objects.VersionRange()) - ) - self.assertIsNone( - index.find("package1", objects.VersionRange("gt", 2)) - ) - - def test_find_all(self): - index = Index() - p11 = gen_package(idx=1, version=1) - p12 = gen_package(idx=1, version=2) - p21 = gen_package(idx=2, version=1) - p22 = gen_package(idx=2, version=2) - index.add(p11) - index.add(p12) - index.add(p21) - index.add(p22) - - self.assertItemsEqual( - [p11, p12], - index.find_all("package1", objects.VersionRange()) - ) - self.assertItemsEqual( - [p21, p22], - index.find_all("package2", objects.VersionRange("le", 2)) - ) - - def test_find_newest_package(self): - index = Index() - p1 = gen_package(idx=1, version=2) - p2 = gen_package(idx=2, version=2) - p2.obsoletes.append( - gen_relation(p1.name, ["lt", p1.version]) - ) - index.add(p1) - index.add(p2) - - self.assertIs( - p1, index.find(p1.name, objects.VersionRange("eq", p1.version)) - ) - self.assertIs( - p2, index.find(p1.name, objects.VersionRange("eq", 1)) + self.assertEqual( + [(1, package1), (2, package2)], + list(index.packages[package1.name].items()) ) def test_find_top_down(self): @@ -104,16 +50,17 @@ class TestIndex(base.TestCase): p2 = gen_package(version=2) index.add(p1) index.add(p2) - self.assertIs( - p2, - index.find("package1", objects.VersionRange("le", 2)) + self.assertEqual( + [p1, p2], + index.find_all(p1.name, objects.VersionRange("<=", 2)) ) - self.assertIs( - p1, - index.find("package1", objects.VersionRange("lt", 2)) + self.assertEqual( + [p1], + index.find_all(p1.name, objects.VersionRange("<", 2)) ) - self.assertIsNone( - index.find("package1", objects.VersionRange("lt", 1)) + self.assertEqual( + [], + index.find_all(p1.name, objects.VersionRange("<", 1)) ) def test_find_down_up(self): @@ -122,56 +69,33 @@ class TestIndex(base.TestCase): p2 = gen_package(version=2) index.add(p1) index.add(p2) - self.assertIs( - p2, - index.find("package1", objects.VersionRange("ge", 2)) + self.assertEqual( + [p2], + index.find_all(p1.name, objects.VersionRange(">=", 2)) ) - self.assertIs( - p2, - index.find("package1", objects.VersionRange("gt", 1)) + self.assertEqual( + [p2], + index.find_all(p1.name, objects.VersionRange(">", 1)) ) - self.assertIsNone( - index.find("package1", objects.VersionRange("gt", 2)) + self.assertEqual( + [], + index.find_all(p1.name, objects.VersionRange(">", 2)) ) - def test_find_accurate(self): + def test_find_with_specified_version(self): index = Index() - p1 = gen_package(version=1) - p2 = gen_package(version=2) - index.add(p1) - index.add(p2) - self.assertIs( - p1, - index.find("package1", objects.VersionRange("eq", 1)) - ) - self.assertIsNone( - index.find("package1", objects.VersionRange("eq", 3)) - ) - - def test_find_obsolete(self): - index = Index() - p1 = gen_package(version=1) - index.add(p1) - - self.assertIs( - p1, index.find("obsoletes1", objects.VersionRange("le", 2)) - ) - self.assertIsNone( - index.find("obsoletes1", objects.VersionRange("gt", 2)) - ) - - def test_find_provides(self): - index = Index() - p1 = gen_package(version=1) - p2 = gen_package(version=2) + p1 = gen_package(idx=1, version=1) + p2 = gen_package(idx=1, version=2) index.add(p1) index.add(p2) - self.assertIs( - p2, index.find("provides1", objects.VersionRange("ge", 2)) + self.assertItemsEqual( + [p1], + index.find_all(p1.name, objects.VersionRange("=", p1.version)) ) - self.assertIsNone( - index.find("provides1", objects.VersionRange("lt", 2)) + self.assertItemsEqual( + [p2], + index.find_all(p2.name, objects.VersionRange("=", p2.version)) ) def test_len(self): diff --git a/packetary/tests/test_library_utils.py b/packetary/tests/test_library_utils.py index b3a8d0b..6928d4f 100644 --- a/packetary/tests/test_library_utils.py +++ b/packetary/tests/test_library_utils.py @@ -94,6 +94,29 @@ class TestLibraryUtils(base.TestCase): utils.get_path_from_url("http://host/f.txt", False) ) + @mock.patch("packetary.library.utils.os") + def test_normalize_repository_url(self, os_mock): + def abs_patch_mock(p): + if p.startswith("/"): + return p + return "/root/" + p[2:] + + os_mock.sep = "/" + os_mock.path.abspath.side_effect = abs_patch_mock + + cases = [ + ("file:///repo/", "/repo"), + ("file:///root/repo/", "./repo"), + ("http://localhost/repo/", "http://localhost/repo"), + ("http://localhost/repo/", "http://localhost/repo/"), + ] + + for expected, url in cases: + self.assertEqual( + expected, utils.normalize_repository_url(url), + "URL: {0}".format(url) + ) + @mock.patch("packetary.library.utils.os") def test_ensure_dir_exist(self, os): os.makedirs.side_effect = [ diff --git a/packetary/tests/test_objects.py b/packetary/tests/test_objects.py index 3f0e027..255fd54 100644 --- a/packetary/tests/test_objects.py +++ b/packetary/tests/test_objects.py @@ -32,7 +32,7 @@ class TestObjectBase(base.TestCase): def check_copy(self, origin): clone = copy.copy(origin) self.assertIsNot(origin, clone) - self.assertEqual(origin, clone) + self.assertEqual(origin.name, clone.name) origin_name = origin.name origin.name += "1" self.assertEqual( @@ -91,25 +91,30 @@ class TestPackageObject(TestObjectBase): ) -class TestRepositoryObject(base.TestCase): +class TestRepositoryObject(TestObjectBase): def test_copy(self): - origin = generator.gen_repository() - clone = copy.copy(origin) - self.assertEqual(clone.name, origin.name) - self.assertEqual(clone.architecture, origin.architecture) + self.check_copy(generator.gen_repository()) + + def test_hashable(self): + self.check_hashable( + generator.gen_repository(name="test1", url="file:///repo"), + generator.gen_repository(name="test1", url="file:///repo", + section=("a", "b")), + ) def test_str(self): - self.assertEqual( - "a.b", - str(generator.gen_repository(name=("a", "b"))) - ) self.assertEqual( "/a/b/", - str(generator.gen_repository(name="", url="/a/b/")) + str(generator.gen_repository(name="a", url="/a/b/")) ) self.assertEqual( - "a", - str(generator.gen_repository(name="a", url="/a/b/")) + "/a/b/ c", + str(generator.gen_repository(name="a", url="/a/b/", section="c")) + ) + self.assertEqual( + "/a/b/ c d", + str(generator.gen_repository( + name="a", url="/a/b/", section=("c", "d"))) ) @@ -124,15 +129,15 @@ class TestRelationObject(TestObjectBase): def test_hashable(self): self.check_hashable( generator.gen_relation(name="test1"), - generator.gen_relation(name="test1", version=["le", 1]) + generator.gen_relation(name="test1", version=["<=", 1]) ) def test_from_args(self): r = PackageRelation.from_args( - ("test", "le", 2), ("test2",), ("test3",) + ("test", "<=", 2), ("test2",), ("test3",) ) self.assertEqual("test", r.name) - self.assertEqual("le", r.version.op) + self.assertEqual("<=", r.version.op) self.assertEqual(2, r.version.edge) self.assertEqual("test2", r.alternative.name) self.assertEqual(VersionRange(), r.alternative.version) @@ -142,7 +147,7 @@ class TestRelationObject(TestObjectBase): def test_iter(self): it = iter(PackageRelation.from_args( - ("test", "le", 2), ("test2", "ge", 3)) + ("test", "<=", 2), ("test2", ">=", 3)) ) self.assertEqual("test", next(it).name) self.assertEqual("test2", next(it).name) @@ -153,15 +158,15 @@ class TestRelationObject(TestObjectBase): class TestVersionRange(TestObjectBase): def test_equal(self): self.check_equal( - VersionRange("eq", 1), - VersionRange("eq", 1), - VersionRange("le", 1) + VersionRange("=", 1), + VersionRange("=", 1), + VersionRange("<=", 1) ) def test_hashable(self): self.check_hashable( - VersionRange(op="le"), - VersionRange(op="le", edge=3) + VersionRange(op="<="), + VersionRange(op="<=", edge=3) ) def __check_intersection(self, assertion, cases): @@ -177,28 +182,39 @@ class TestVersionRange(TestObjectBase): def test_have_intersection(self): cases = [ - (("lt", 2), ("gt", 1)), - (("lt", 3), ("lt", 4)), - (("gt", 3), ("gt", 4)), - (("eq", 1), ("eq", 1)), - (("ge", 1), ("le", 1)), - (("eq", 1), ("lt", 2)), - ((None, None), ("le", 10)), + (("=", 2), ("=", 2)), + (("=", 2), ("<", 3)), + (("=", 2), (">", 1)), + (("<", 2), (">", 1)), + (("<", 2), ("<", 3)), + (("<", 2), ("<", 2)), + (("<", 2), ("<=", 2)), + ((">", 2), (">", 1)), + ((">", 2), ("<", 3)), + ((">", 2), (">=", 2)), + ((">", 2), (">", 2)), + ((">=", 2), ("<=", 2)), + ((None, None), ("=", 2)), ] self.__check_intersection(self.assertTrue, cases) def test_does_not_have_intersection(self): cases = [ - (("lt", 2), ("gt", 2)), - (("ge", 2), ("lt", 2)), - (("gt", 2), ("le", 2)), - (("gt", 1), ("lt", 1)), + (("=", 2), ("=", 1)), + (("=", 2), ("<", 2)), + (("=", 2), (">", 2)), + (("=", 2), (">", 3)), + (("=", 2), ("<", 1)), + (("<", 2), (">=", 2)), + (("<", 2), (">", 3)), + ((">", 2), ("<=", 2)), + ((">", 2), ("<", 1)), ] self.__check_intersection(self.assertFalse, cases) def test_intersection_is_typesafe(self): with self.assertRaises(TypeError): - VersionRange("eq", 1).has_intersection(("eq", 1)) + VersionRange("=", 1).has_intersection(("=", 1)) class TestPackageVersion(base.TestCase): diff --git a/packetary/tests/test_packages_forest.py b/packetary/tests/test_packages_forest.py new file mode 100644 index 0000000..594560e --- /dev/null +++ b/packetary/tests/test_packages_forest.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- + +# Copyright 2016 Mirantis, Inc. +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with this program; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +from packetary.objects import PackagesForest +from packetary.tests import base +from packetary.tests.stubs import generator + + +class TestPackagesForest(base.TestCase): + def setUp(self): + super(TestPackagesForest, self).setUp() + + def _add_packages(self, tree, packages): + for pkg in packages: + tree.add(pkg) + + def _generate_packages(self, forest): + packages1 = [ + generator.gen_package( + name="package1", version=1, mandatory=True, + requires=None + ), + generator.gen_package( + name="package2", version=1, + requires=None + ), + generator.gen_package( + name="package3", version=1, + requires=[generator.gen_relation("package5")] + ) + ] + packages2 = [ + generator.gen_package( + name="package4", version=1, mandatory=True, + requires=None + ), + generator.gen_package( + name="package5", version=1, + requires=[generator.gen_relation("package2")] + ), + ] + self._add_packages(forest.add_tree(), packages1) + self._add_packages(forest.add_tree(), packages2) + + def test_add_tree(self): + forest = PackagesForest() + tree = forest.add_tree() + self.assertIs(tree, forest.trees[-1]) + + def test_find(self): + forest = PackagesForest() + p11 = generator.gen_package(name="package1", version=1) + p12 = generator.gen_package(name="package1", version=2) + p21 = generator.gen_package(name="package2", version=1) + p22 = generator.gen_package(name="package2", version=2) + self._add_packages(forest.add_tree(), [p11, p22]) + self._add_packages(forest.add_tree(), [p12, p21]) + self.assertEqual( + p11, forest.find(generator.gen_relation("package1", [">=", 1])) + ) + self.assertEqual( + p12, forest.find(generator.gen_relation("package1", [">", 1])) + ) + self.assertEqual(p22, forest.find(generator.gen_relation("package2"))) + self.assertEqual( + p21, forest.find(generator.gen_relation("package2", ["<", 2])) + ) + + def test_get_packages_with_mandatory(self): + forest = PackagesForest() + self._generate_packages(forest) + packages = forest.get_packages( + [generator.gen_relation("package3")], True + ) + self.assertItemsEqual( + ["package1", "package2", "package3", "package4", "package5"], + (x.name for x in packages) + ) + + def test_get_packages_without_mandatory(self): + forest = PackagesForest() + self._generate_packages(forest) + packages = forest.get_packages( + [generator.gen_relation("package3")], False + ) + self.assertItemsEqual( + ["package2", "package3", "package5"], + (x.name for x in packages) + ) diff --git a/packetary/tests/test_packages_tree.py b/packetary/tests/test_packages_tree.py index f7e936e..2d183b0 100644 --- a/packetary/tests/test_packages_tree.py +++ b/packetary/tests/test_packages_tree.py @@ -16,120 +16,82 @@ # with this program; if not, write to the Free Software Foundation, Inc., # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -import warnings - -from packetary.objects import Index from packetary.objects import PackagesTree +from packetary.objects import VersionRange from packetary.tests import base from packetary.tests.stubs import generator class TestPackagesTree(base.TestCase): - def setUp(self): - super(TestPackagesTree, self).setUp() + def test_add(self): + tree = PackagesTree() + pkg = generator.gen_package(version=1, mandatory=True) + tree.add(pkg) + self.assertIs(pkg, tree.find(pkg.name, VersionRange('=', pkg.version))) + self.assertIs( + pkg.obsoletes[0], + tree.obsoletes[pkg.obsoletes[0].name][(pkg.name, pkg.version)] + ) + self.assertIs( + pkg.provides[0], + tree.provides[pkg.provides[0].name][(pkg.name, pkg.version)] + ) + tree.add(generator.gen_package(version=1, mandatory=False)) + self.assertItemsEqual([pkg], tree.mandatory_packages) - def test_get_unresolved_dependencies(self): - ptree = PackagesTree() - ptree.add(generator.gen_package( - 1, requires=[generator.gen_relation("unresolved")])) - ptree.add(generator.gen_package(2, requires=None)) - ptree.add(generator.gen_package( - 3, requires=[generator.gen_relation("package1")] - )) - ptree.add(generator.gen_package( - 4, - requires=[generator.gen_relation("loop")], - obsoletes=[generator.gen_relation("loop", ["le", 1])] - )) + def test_find_package(self): + tree = PackagesTree() + p1 = generator.gen_package(idx=1, version=1) + p2 = generator.gen_package(idx=1, version=2) + tree.add(p1) + tree.add(p2) - unresolved = ptree.get_unresolved_dependencies() - self.assertItemsEqual( - ["loop", "unresolved"], - (x.name for x in unresolved) + self.assertIs(p1, tree.find(p1.name, VersionRange("<", p2.version))) + self.assertIs(p2, tree.find(p1.name, VersionRange(">=", p1.version))) + self.assertIsNone(tree.find(p1.name, VersionRange(">", p2.version))) + + def test_find_obsolete(self): + tree = PackagesTree() + p1 = generator.gen_package( + version=1, obsoletes=[generator.gen_relation('obsolete', ('<', 2))] + ) + p2 = generator.gen_package( + version=2, obsoletes=[generator.gen_relation('obsolete', ('<', 2))] + ) + tree.add(p1) + tree.add(p2) + + self.assertEqual( + [p1, p2], tree.find_all("obsolete", VersionRange("<=", 2)) + ) + self.assertIsNone( + tree.find("obsolete", VersionRange(">", 2)) ) - def test_get_unresolved_dependencies_with_main(self): - ptree = PackagesTree() - ptree.add(generator.gen_package( + def test_find_provides(self): + tree = PackagesTree() + p1 = generator.gen_package( + version=1, obsoletes=[generator.gen_relation('provide', ('<', 2))] + ) + tree.add(p1) + + self.assertIs( + p1, tree.find("provide", VersionRange("<=", 2)) + ) + self.assertIsNone( + tree.find("provide", VersionRange(">", 2)) + ) + + def test_get_unresolved_dependencies(self): + tree = PackagesTree() + tree.add(generator.gen_package( 1, requires=[generator.gen_relation("unresolved")])) - ptree.add(generator.gen_package(2, requires=None)) - ptree.add(generator.gen_package( + tree.add(generator.gen_package(2, requires=None)) + tree.add(generator.gen_package( 3, requires=[generator.gen_relation("package1")] )) - ptree.add(generator.gen_package( - 4, - requires=[generator.gen_relation("package5")] - )) - main = Index() - main.add(generator.gen_package(5, requires=[ - generator.gen_relation("package6") - ])) - - unresolved = ptree.get_unresolved_dependencies(main) + unresolved = tree.get_unresolved_dependencies() self.assertItemsEqual( ["unresolved"], (x.name for x in unresolved) ) - - def test_get_minimal_subset_with_master(self): - ptree = PackagesTree() - ptree.add(generator.gen_package(1, requires=None)) - ptree.add(generator.gen_package(2, requires=None)) - ptree.add(generator.gen_package(3, requires=None)) - ptree.add(generator.gen_package( - 4, requires=[generator.gen_relation("package1")] - )) - - master = Index() - master.add(generator.gen_package(1, requires=None)) - master.add(generator.gen_package( - 5, - requires=[generator.gen_relation( - "package10", - alternative=generator.gen_relation("package4") - )] - )) - - unresolved = set([generator.gen_relation("package3")]) - resolved = ptree.get_minimal_subset(master, unresolved) - self.assertItemsEqual( - ["package3", "package4"], - (x.name for x in resolved) - ) - - def test_get_minimal_subset_without_master(self): - ptree = PackagesTree() - ptree.add(generator.gen_package(1, requires=None)) - ptree.add(generator.gen_package(2, requires=None)) - ptree.add(generator.gen_package( - 3, requires=[generator.gen_relation("package1")] - )) - unresolved = set([generator.gen_relation("package3")]) - resolved = ptree.get_minimal_subset(None, unresolved) - self.assertItemsEqual( - ["package3", "package1"], - (x.name for x in resolved) - ) - - def test_mandatory_packages_always_included(self): - ptree = PackagesTree() - ptree.add(generator.gen_package(1, requires=None, mandatory=True)) - ptree.add(generator.gen_package(2, requires=None)) - ptree.add(generator.gen_package(3, requires=None)) - unresolved = set([generator.gen_relation("package3")]) - resolved = ptree.get_minimal_subset(None, unresolved) - self.assertItemsEqual( - ["package3", "package1"], - (x.name for x in resolved) - ) - - def test_warning_if_unresolved(self): - ptree = PackagesTree() - ptree.add(generator.gen_package( - 1, requires=None)) - - with warnings.catch_warnings(record=True) as log: - ptree.get_minimal_subset( - None, [generator.gen_relation("package2")] - ) - self.assertIn("package2", str(log[0])) diff --git a/packetary/tests/test_repository_api.py b/packetary/tests/test_repository_api.py index 730662c..dad47ac 100644 --- a/packetary/tests/test_repository_api.py +++ b/packetary/tests/test_repository_api.py @@ -16,6 +16,7 @@ # with this program; if not, write to the Free Software Foundation, Inc., # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +import copy import mock from packetary.api import Configuration @@ -27,197 +28,174 @@ from packetary.tests.stubs.helpers import CallbacksAdapter class TestRepositoryApi(base.TestCase): - def test_get_packages_as_is(self): - controller = CallbacksAdapter() - pkg = generator.gen_package(name="test") - controller.load_packages.side_effect = [ - pkg - ] - api = RepositoryApi(controller) - packages = api.get_packages("file:///repo1") - self.assertEqual(1, len(packages)) - package = packages.pop() - self.assertIs(pkg, package) + def setUp(self): + self.controller = CallbacksAdapter() + self.api = RepositoryApi(self.controller) + self.repo_data = {"name": "repo1", "url": "file:///repo1"} + self.repo = generator.gen_repository(**self.repo_data) + self.controller.load_repositories.return_value = [self.repo] + self._generate_packages() - def test_get_packages_with_depends_resolving(self): - controller = CallbacksAdapter() - controller.load_packages.side_effect = [ - [ - generator.gen_package(idx=1, requires=None), - generator.gen_package( - idx=2, requires=[generator.gen_relation("package1")] - ), - generator.gen_package( - idx=3, requires=[generator.gen_relation("package1")] - ), - generator.gen_package(idx=4, requires=None), - generator.gen_package(idx=5, requires=None), - ], + def _generate_packages(self): + self.packages = [ + generator.gen_package(idx=1, repository=self.repo, requires=None), + generator.gen_package(idx=2, repository=self.repo, requires=None), generator.gen_package( - idx=6, requires=[generator.gen_relation("package2")] + idx=3, repository=self.repo, mandatory=True, + requires=[generator.gen_relation("package2")] ), + generator.gen_package( + idx=4, repository=self.repo, mandatory=False, + requires=[generator.gen_relation("package1")] + ), + generator.gen_package( + idx=5, repository=self.repo, + requires=[generator.gen_relation("package6")]) ] + self.controller.load_packages.return_value = self.packages - api = RepositoryApi(controller) - packages = api.get_packages([ - "file:///repo1", "file:///repo2" - ], - "file:///repo3", ["package4"] + @mock.patch("packetary.api.RepositoryController") + @mock.patch("packetary.api.ConnectionsManager") + def test_create_with_config(self, connection_mock, controller_mock): + config = Configuration( + http_proxy="http://localhost", https_proxy="https://localhost", + retries_num=10, threads_num=8, ignore_errors_num=6 + ) + RepositoryApi.create(config, "deb", "x86_64") + connection_mock.assert_called_once_with( + proxy="http://localhost", + secure_proxy="https://localhost", + retries_num=10 + ) + controller_mock.load.assert_called_once_with( + mock.ANY, "deb", "x86_64" ) + @mock.patch("packetary.api.RepositoryController") + @mock.patch("packetary.api.ConnectionsManager") + def test_create_with_context(self, connection_mock, controller_mock): + config = Configuration( + http_proxy="http://localhost", https_proxy="https://localhost", + retries_num=10, threads_num=8, ignore_errors_num=6 + ) + context = Context(config) + RepositoryApi.create(context, "deb", "x86_64") + connection_mock.assert_called_once_with( + proxy="http://localhost", + secure_proxy="https://localhost", + retries_num=10 + ) + controller_mock.load.assert_called_once_with( + context, "deb", "x86_64" + ) + + def test_get_packages_as_is(self): + packages = self.api.get_packages([self.repo_data], None) + self.assertEqual(5, len(packages)) + self.assertItemsEqual( + self.packages, + packages + ) + + def test_get_packages_by_requirements_with_mandatory(self): + packages = self.api.get_packages( + [self.repo_data], [{"name": "package1"}], True + ) self.assertEqual(3, len(packages)) self.assertItemsEqual( - ["package1", "package4", "package2"], + ["package1", "package2", "package3"], (x.name for x in packages) ) - controller.load_repositories.assert_any_call( - ["file:///repo1", "file:///repo2"] + + def test_get_packages_by_requirements_without_mandatory(self): + packages = self.api.get_packages( + [self.repo_data], [{"name": "package4"}], False ) - controller.load_repositories.assert_any_call( - "file:///repo3" + self.assertEqual(2, len(packages)) + self.assertItemsEqual( + ["package1", "package4"], + (x.name for x in packages) ) def test_clone_repositories_as_is(self): - controller = CallbacksAdapter() - repo = generator.gen_repository(name="repo1") - packages = [ - generator.gen_package(name="test1", repository=repo), - generator.gen_package(name="test2", repository=repo) - ] - mirror = generator.gen_repository(name="mirror") - controller.load_repositories.return_value = repo - controller.load_packages.return_value = packages - controller.clone_repositories.return_value = {repo: mirror} - controller.copy_packages.return_value = [0, 1] - api = RepositoryApi(controller) - stats = api.clone_repositories( - ["file:///repo1"], "/mirror", keep_existing=True + # return value is used as statistics + mirror = copy.copy(self.repo) + mirror.url = "file:///mirror/repo" + self.controller.fork_repository.return_value = mirror + self.controller.assign_packages.return_value = [0, 1, 1, 1, 0, 6] + stats = self.api.clone_repositories([self.repo_data], None, "/mirror") + self.controller.fork_repository.assert_called_once_with( + self.repo, '/mirror', False, False + ) + self.controller.assign_packages.assert_called_once_with( + mirror, set(self.packages) + ) + self.assertEqual(6, stats.total) + self.assertEqual(4, stats.copied) + + def test_clone_by_requirements_with_mandatory(self): + # return value is used as statistics + mirror = copy.copy(self.repo) + mirror.url = "file:///mirror/repo" + self.controller.fork_repository.return_value = mirror + self.controller.assign_packages.return_value = [0, 1, 1] + stats = self.api.clone_repositories( + [self.repo_data], [{"name": "package1"}], + "/mirror", include_mandatory=True + ) + packages = {self.packages[0], self.packages[1], self.packages[2]} + self.controller.fork_repository.assert_called_once_with( + self.repo, '/mirror', False, False + ) + self.controller.assign_packages.assert_called_once_with( + mirror, packages + ) + self.assertEqual(3, stats.total) + self.assertEqual(2, stats.copied) + + def test_clone_by_requirements_without_mandatory(self): + # return value is used as statistics + mirror = copy.copy(self.repo) + mirror.url = "file:///mirror/repo" + self.controller.fork_repository.return_value = mirror + self.controller.assign_packages.return_value = [0, 4] + stats = self.api.clone_repositories( + [self.repo_data], [{"name": "package4"}], + "/mirror", include_mandatory=False + ) + packages = {self.packages[0], self.packages[3]} + self.controller.fork_repository.assert_called_once_with( + self.repo, '/mirror', False, False + ) + self.controller.assign_packages.assert_called_once_with( + mirror, packages ) self.assertEqual(2, stats.total) self.assertEqual(1, stats.copied) - controller.copy_packages.assert_called_once_with( - mirror, set(packages), True - ) - - def test_copy_minimal_subset_of_repository(self): - controller = CallbacksAdapter() - repo1 = generator.gen_repository(name="repo1") - repo2 = generator.gen_repository(name="repo2") - repo3 = generator.gen_repository(name="repo3") - mirror1 = generator.gen_repository(name="mirror1") - mirror2 = generator.gen_repository(name="mirror2") - pkg_group1 = [ - generator.gen_package( - idx=1, requires=None, repository=repo1 - ), - generator.gen_package( - idx=1, version=2, requires=None, repository=repo1 - ), - generator.gen_package( - idx=2, requires=None, repository=repo1 - ) - ] - pkg_group2 = [ - generator.gen_package( - idx=4, - requires=[generator.gen_relation("package1")], - repository=repo2, - mandatory=True, - ) - ] - pkg_group3 = [ - generator.gen_package( - idx=3, requires=None, repository=repo1 - ) - ] - controller.load_repositories.side_effect = [[repo1, repo2], repo3] - controller.load_packages.side_effect = [ - pkg_group1 + pkg_group2 + pkg_group3, - generator.gen_package( - idx=6, - repository=repo3, - requires=[generator.gen_relation("package2")] - ) - ] - controller.clone_repositories.return_value = { - repo1: mirror1, repo2: mirror2 - } - controller.copy_packages.return_value = 1 - api = RepositoryApi(controller) - api.clone_repositories( - ["file:///repo1", "file:///repo2"], "/mirror", - ["file:///repo3"], - keep_existing=True - ) - controller.copy_packages.assert_any_call( - mirror1, set(pkg_group1), True - ) - controller.copy_packages.assert_any_call( - mirror2, set(pkg_group2), True - ) - self.assertEqual(2, controller.copy_packages.call_count) def test_get_unresolved(self): - controller = CallbacksAdapter() - pkg = generator.gen_package( - name="test", requires=[generator.gen_relation("test2")] - ) - controller.load_packages.side_effect = [ - pkg - ] - api = RepositoryApi(controller) - r = api.get_unresolved_dependencies("file:///repo1") - controller.load_repositories.assert_called_once_with("file:///repo1") - self.assertItemsEqual( - ["test2"], - (x.name for x in r) - ) + unresolved = self.api.get_unresolved_dependencies([self.repo_data]) + self.assertItemsEqual(["package6"], (x.name for x in unresolved)) - def test_get_unresolved_with_main(self): - controller = CallbacksAdapter() - pkg1 = generator.gen_package( - name="test1", requires=[ - generator.gen_relation("test2"), - generator.gen_relation("test3") - ] - ) - pkg2 = generator.gen_package( - name="test2", requires=[generator.gen_relation("test4")] - ) - controller.load_packages.side_effect = [ - pkg1, pkg2 - ] - api = RepositoryApi(controller) - r = api.get_unresolved_dependencies("file:///repo1", "file:///repo2") - controller.load_repositories.assert_any_call("file:///repo1") - controller.load_repositories.assert_any_call("file:///repo2") - self.assertItemsEqual( - ["test3"], - (x.name for x in r) - ) + def test_load_requirements(self): + expected = { + generator.gen_relation("test1"), + generator.gen_relation("test2", ["<", "3"]), + generator.gen_relation("test2", [">", "1"]), + } + actual = set(self.api._load_requirements( + [{"name": "test1"}, {"name": "test2", "versions": ["< 3", "> 1"]}] + )) + self.assertEqual(expected, actual) + self.assertIsNone(self.api._load_requirements(None)) - def test_parse_requirements(self): - requirements = RepositoryApi._parse_requirements( - ["p1 le 2 | p2 | p3 ge 2"] - ) + def test_validate_repos_data(self): + # TODO(bgaifullin) implement me + pass - expected = generator.gen_relation( - "p1", - ["le", '2'], - generator.gen_relation( - "p2", - None, - generator.gen_relation( - "p3", - ["ge", '2'] - ) - ) - ) - self.assertEqual(1, len(requirements)) - self.assertEqual( - list(expected), - list(requirements.pop()) - ) + def test_validate_requirements_data(self): + # TODO(bgaifullin) implement me + pass class TestContext(base.TestCase): diff --git a/packetary/tests/test_repository_contoller.py b/packetary/tests/test_repository_contoller.py index 9eb6cba..32768fa 100644 --- a/packetary/tests/test_repository_contoller.py +++ b/packetary/tests/test_repository_contoller.py @@ -18,9 +18,9 @@ import copy import mock -import six from packetary.controllers import RepositoryController +from packetary.drivers.base import RepositoryDriverBase from packetary.tests import base from packetary.tests.stubs.executor import Executor from packetary.tests.stubs.generator import gen_package @@ -30,7 +30,7 @@ from packetary.tests.stubs.helpers import CallbacksAdapter class TestRepositoryController(base.TestCase): def setUp(self): - self.driver = mock.MagicMock() + self.driver = mock.MagicMock(spec=RepositoryDriverBase) self.context = mock.MagicMock() self.context.async_section.return_value = Executor() self.ctrl = RepositoryController(self.context, self.driver, "x86_64") @@ -53,24 +53,21 @@ class TestRepositoryController(base.TestCase): self.assertIs(self.driver, controller.driver) def test_load_repositories(self): - self.driver.parse_urls.return_value = ["test1"] - consumer = mock.MagicMock() - self.ctrl.load_repositories("file:///test1", consumer) - self.driver.parse_urls.assert_called_once_with(["file:///test1"]) + repo_data = {"name": "test", "url": "file:///test1"} + repo = gen_repository(**repo_data) + self.driver.get_repository = CallbacksAdapter() + self.driver.get_repository.side_effect = [repo] + + repos = self.ctrl.load_repositories([repo_data]) self.driver.get_repository.assert_called_once_with( - self.context.connection, "test1", "x86_64", consumer + self.context.connection, repo_data, self.ctrl.arch ) - for url in [six.u("file:///test1"), ["file:///test1"]]: - self.driver.reset_mock() - self.ctrl.load_repositories(url, consumer) - if not isinstance(url, list): - url = [url] - self.driver.parse_urls.assert_called_once_with(url) + self.assertEqual([repo], repos) def test_load_packages(self): repo = mock.MagicMock() consumer = mock.MagicMock() - self.ctrl.load_packages([repo], consumer) + self.ctrl.load_packages(repo, consumer) self.driver.get_packages.assert_called_once_with( self.context.connection, repo, consumer ) @@ -78,30 +75,33 @@ class TestRepositoryController(base.TestCase): @mock.patch("packetary.controllers.repository.os") def test_assign_packages(self, os): repo = gen_repository(url="/test/repo") - packages = [ + packages = { gen_package(name="test1", repository=repo), gen_package(name="test2", repository=repo) - ] - existed_packages = [ - gen_package(name="test3", repository=repo), - gen_package(name="test2", repository=repo) - ] - + } os.path.join = lambda *x: "/".join(x) - self.driver.get_packages = CallbacksAdapter() - self.driver.get_packages.return_value = existed_packages - self.ctrl.assign_packages(repo, packages, True) - os.remove.assert_not_called() - all_packages = set(packages + existed_packages) - self.driver.rebuild_repository.assert_called_once_with( - repo, all_packages + self.ctrl.assign_packages(repo, packages) + self.driver.add_packages.assert_called_once_with( + self.ctrl.context.connection, repo, packages ) - self.driver.rebuild_repository.reset_mock() - self.ctrl.assign_packages(repo, packages, False) - self.driver.rebuild_repository.assert_called_once_with( - repo, set(packages) + + @mock.patch("packetary.controllers.repository.os") + def test_fork_repository(self, os): + os.path.join.side_effect = lambda *args: "".join(args) + repo = gen_repository(name="test1", url="file:///test") + clone = copy.copy(repo) + clone.url = "/root/repo" + self.driver.fork_repository.return_value = clone + self.context.connection.retrieve.side_effect = [0, 10] + self.ctrl.fork_repository(repo, "./repo", False, False) + self.driver.fork_repository.assert_called_once_with( + self.context.connection, repo, "./repo/test", False, False + ) + repo.path = "os" + self.ctrl.fork_repository(repo, "./repo/", False, False) + self.driver.fork_repository.assert_called_with( + self.context.connection, repo, "./repo/os", False, False ) - os.remove.assert_called_once_with("/test/repo/test3.pkg") def test_copy_packages(self): repo = gen_repository(url="file:///repo/") @@ -112,8 +112,9 @@ class TestRepositoryController(base.TestCase): target = gen_repository(url="/test/repo") self.context.connection.retrieve.side_effect = [0, 10] observer = mock.MagicMock() - self.ctrl.copy_packages(target, packages, True, observer) - observer.assert_has_calls([mock.call(0), mock.call(10)]) + self.ctrl._copy_packages(target, packages, observer) + observer.assert_any_call(0) + observer.assert_any_call(10) self.context.connection.retrieve.assert_any_call( "file:///repo/test1.pkg", "/test/repo/test1.pkg", @@ -124,22 +125,13 @@ class TestRepositoryController(base.TestCase): "/test/repo/test2.pkg", size=-1 ) - self.driver.rebuild_repository.assert_called_once_with( - target, set(packages) - ) - @mock.patch("packetary.controllers.repository.os") - def test_clone_repository(self, os): - os.path.abspath.return_value = "/root/repo" - repos = [ - gen_repository(name="test1"), - gen_repository(name="test2") + def test_copy_packages_does_not_affect_packages_in_same_repo(self): + repo = gen_repository(url="file:///repo/") + packages = [ + gen_package(name="test1", repository=repo, filesize=10), + gen_package(name="test2", repository=repo, filesize=-1) ] - clones = [copy.copy(x) for x in repos] - self.driver.fork_repository.side_effect = clones - mirrors = self.ctrl.clone_repositories(repos, "./repo") - for r in repos: - self.driver.fork_repository.assert_any_call( - self.context.connection, r, "/root/repo", False, False - ) - self.assertEqual(mirrors, dict(zip(repos, clones))) + observer = mock.MagicMock() + self.ctrl._copy_packages(repo, packages, observer) + self.assertFalse(self.context.connection.retrieve.called) diff --git a/packetary/tests/test_rpm_driver.py b/packetary/tests/test_rpm_driver.py index efb9bde..7207380 100644 --- a/packetary/tests/test_rpm_driver.py +++ b/packetary/tests/test_rpm_driver.py @@ -22,7 +22,6 @@ import sys import six -from packetary.library.utils import localize_repo_url from packetary.objects import FileChecksum from packetary.tests import base from packetary.tests.stubs.generator import gen_repository @@ -53,31 +52,33 @@ class TestRpmDriver(base.TestCase): self.createrepo.reset_mock() self.connection = mock.MagicMock() - def test_parse_urls(self): - self.assertItemsEqual( - [ - "http://host/centos/os", - "http://host/centos/updates" - ], - self.driver.parse_urls([ - "http://host/centos/os", - "http://host/centos/updates/", - ]) + def test_priority_sort(self): + repos = [ + {"name": "repo0"}, + {"name": "repo1", "priority": 1}, + {"name": "repo2", "priority": 99}, + {"name": "repo3", "priority": None} + ] + repos.sort(key=self.driver.priority_sort) + + self.assertEqual( + ["repo1", "repo0", "repo3", "repo2"], + [x['name'] for x in repos] ) def test_get_repository(self): repos = [] - + repo_data = {"name": "os", "url": "http://host/centos/os/x86_64/"} self.driver.get_repository( self.connection, - "http://host/centos/os/x86_64", + repo_data, "x86_64", repos.append ) self.assertEqual(1, len(repos)) repo = repos[0] - self.assertEqual("/centos/os/x86_64", repo.name) + self.assertEqual("os", repo.name) self.assertEqual("", repo.origin) self.assertEqual("x86_64", repo.architecture) self.assertEqual("http://host/centos/os/x86_64/", repo.url) @@ -125,7 +126,7 @@ class TestRpmDriver(base.TestCase): "Packages/test1.rpm", package.filename ) self.assertItemsEqual( - ['test2 (eq 0-1.1.1.1-1.el7)'], + ['test2 (= 0-1.1.1.1-1.el7)'], (str(x) for x in package.requires) ) self.assertItemsEqual( @@ -165,7 +166,7 @@ class TestRpmDriver(base.TestCase): self.assertTrue(package.mandatory) @mock.patch("packetary.drivers.rpm_driver.shutil") - def test_rebuild_repository(self, shutil): + def test_add_packages(self, shutil): self.createrepo.MDError = ValueError self.createrepo.MetaDataGenerator().doFinalMove.side_effect = [ None, self.createrepo.MDError() @@ -174,7 +175,7 @@ class TestRpmDriver(base.TestCase): self.createrepo.MetaDataConfig().outputdir = "/repo/os/x86_64" self.createrepo.MetaDataConfig().tempdir = "tmp" - self.driver.rebuild_repository(repo, set()) + self.driver.add_packages(self.connection, repo, set()) self.assertEqual( "/repo/os/x86_64", @@ -189,24 +190,23 @@ class TestRpmDriver(base.TestCase): .doFinalMove.assert_called_once_with() with self.assertRaises(RuntimeError): - self.driver.rebuild_repository(repo, set()) + self.driver.add_packages(self.connection, repo, set()) shutil.rmtree.assert_called_once_with( "/repo/os/x86_64/tmp", ignore_errors=True ) - @mock.patch("packetary.drivers.rpm_driver.utils") - def test_fork_repository(self, utils): + @mock.patch("packetary.drivers.rpm_driver.utils.ensure_dir_exist") + def test_fork_repository(self, ensure_dir_exists_mock): repo = gen_repository("os", url="http://localhost/os/x86_64/") - utils.localize_repo_url = localize_repo_url + self.createrepo.MetaDataGenerator().doFinalMove.side_effect = [None] new_repo = self.driver.fork_repository( self.connection, repo, - "/repo" + "/repo/os/x86_64" ) - - utils.ensure_dir_exist.assert_called_once_with("/repo/os/x86_64/") + ensure_dir_exists_mock.assert_called_once_with("/repo/os/x86_64") self.assertEqual(repo.name, new_repo.name) self.assertEqual(repo.architecture, new_repo.architecture) - self.assertEqual("/repo/os/x86_64/", new_repo.url) + self.assertEqual("file:///repo/os/x86_64/", new_repo.url) self.createrepo.MetaDataGenerator()\ .doFinalMove.assert_called_once_with() diff --git a/tox.ini b/tox.ini index 07bdb03..3e02f93 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,6 @@ [tox] minversion = 1.6 -envlist = py34,py27,py26,pep8 +envlist = py34,py27,pep8 skipsdist = True [testenv]