diff --git a/etc/glance-api.conf b/etc/glance-api.conf index 9aa8a5ee87..9df88dc342 100644 --- a/etc/glance-api.conf +++ b/etc/glance-api.conf @@ -125,6 +125,11 @@ workers = 1 # The default value for property_protection_rule_format is 'roles'. #property_protection_rule_format = roles +# Specifies how long (in hours) a task is supposed to live in the tasks DB +# after succeeding or failing before getting soft-deleted. +# The default value for task_time_to_live is 48 hours. +# task_time_to_live = 48 + # ================= Syslog Options ============================ # Send logs to syslog (/dev/log) instead of to file specified diff --git a/etc/policy.json b/etc/policy.json index 24709d2f69..310679a2de 100644 --- a/etc/policy.json +++ b/etc/policy.json @@ -22,5 +22,10 @@ "get_members": "", "modify_member": "", - "manage_image_cache": "role:admin" + "manage_image_cache": "role:admin", + + "get_task": "", + "get_tasks": "", + "add_task": "", + "modify_task": "" } diff --git a/glance/api/authorization.py b/glance/api/authorization.py index c659cc5986..bcc23008b3 100644 --- a/glance/api/authorization.py +++ b/glance/api/authorization.py @@ -1,4 +1,5 @@ # Copyright 2012 OpenStack Foundation +# Copyright 2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -53,6 +54,24 @@ def proxy_member(context, member): return ImmutableMemberProxy(member) +def is_task_mutable(context, task): + """Return True if the task is mutable in this context.""" + if context.is_admin: + return True + + if context.owner is None: + return False + + return task.owner == context.owner + + +def proxy_task(context, task): + if is_task_mutable(context, task): + return task + else: + return ImmutableTaskProxy(task) + + class ImageRepoProxy(glance.domain.proxy.Repo): def __init__(self, image_repo, context): @@ -294,6 +313,36 @@ class ImmutableMemberProxy(object): updated_at = _immutable_attr('base', 'updated_at') +class ImmutableTaskProxy(object): + def __init__(self, base): + self.base = base + + task_id = _immutable_attr('base', 'task_id') + type = _immutable_attr('base', 'type') + status = _immutable_attr('base', 'status') + input = _immutable_attr('base', 'input') + owner = _immutable_attr('base', 'owner') + message = _immutable_attr('base', 'message') + expires_at = _immutable_attr('base', 'expires_at') + created_at = _immutable_attr('base', 'created_at') + updated_at = _immutable_attr('base', 'updated_at') + + def run(self, executor): + raise NotImplementedError() + + def begin_processing(self): + message = _("You are not permitted to set status on this task.") + raise exception.Forbidden(message) + + def succeed(self, result): + message = _("You are not permitted to set status on this task.") + raise exception.Forbidden(message) + + def fail(self, message): + message = _("You are not permitted to set status on this task.") + raise exception.Forbidden(message) + + class ImageProxy(glance.domain.proxy.Image): def __init__(self, image, context): @@ -308,3 +357,53 @@ class ImageProxy(glance.domain.proxy.Image): else: member_repo = self.image.get_member_repo(**kwargs) return ImageMemberRepoProxy(member_repo, self, self.context) + + +class TaskProxy(glance.domain.proxy.Task): + + def __init__(self, task): + self.task = task + super(TaskProxy, self).__init__(task) + + +class TaskFactoryProxy(glance.domain.proxy.TaskFactory): + + def __init__(self, task_factory, context): + self.task_factory = task_factory + self.context = context + super(TaskFactoryProxy, self).__init__( + task_factory, + proxy_class=TaskProxy, + proxy_kwargs=None + ) + + def new_task(self, task_type, task_input, owner): + #NOTE(nikhil): Unlike Images, Tasks are expected to have owner. + # We currently do not allow even admins to set the owner to None. + if owner is not None and (owner == self.context.owner + or self.context.is_admin): + return super(TaskFactoryProxy, self).new_task( + task_type, + task_input, + owner + ) + else: + message = _("You are not permitted to create this task with " + "owner as: %s") + raise exception.Forbidden(message % owner) + + +class TaskRepoProxy(glance.domain.proxy.Repo): + + def __init__(self, task_repo, context): + self.task_repo = task_repo + self.context = context + super(TaskRepoProxy, self).__init__(task_repo) + + def get(self, task_id): + task = self.task_repo.get(task_id) + return proxy_task(self.context, task) + + def list(self, *args, **kwargs): + tasks = self.task_repo.list(*args, **kwargs) + return [proxy_task(self.context, t) for t in tasks] diff --git a/glance/api/policy.py b/glance/api/policy.py index 0367a93b23..2ba532977a 100644 --- a/glance/api/policy.py +++ b/glance/api/policy.py @@ -1,6 +1,7 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright (c) 2011 OpenStack Foundation +# Copyright 2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -353,3 +354,59 @@ class ImageLocationsProxy(object): __delslice__ = _get_checker('delete_image_location', '__delslice__') del _get_checker + + +class TaskProxy(glance.domain.proxy.Task): + + def __init__(self, task, context, policy): + self.task = task + self.context = context + self.policy = policy + super(TaskProxy, self).__init__(task) + + def run(self, executor): + self.base.run(executor) + + +class TaskRepoProxy(glance.domain.proxy.Repo): + + def __init__(self, task_repo, context, policy): + self.context = context + self.policy = policy + self.task_repo = task_repo + proxy_kwargs = {'context': self.context, 'policy': self.policy} + super(TaskRepoProxy, self).__init__( + task_repo, + item_proxy_class=TaskProxy, + item_proxy_kwargs=proxy_kwargs + ) + + def get(self, task_id): + self.policy.enforce(self.context, 'get_task', {}) + return super(TaskRepoProxy, self).get(task_id) + + def list(self, *args, **kwargs): + self.policy.enforce(self.context, 'get_tasks', {}) + return super(TaskRepoProxy, self).list(*args, **kwargs) + + def add(self, task): + self.policy.enforce(self.context, 'add_task', {}) + return super(TaskRepoProxy, self).add(task) + + def save(self, task): + self.policy.enforce(self.context, 'modify_task', {}) + return super(TaskRepoProxy, self).save(task) + + +class TaskFactoryProxy(glance.domain.proxy.TaskFactory): + + def __init__(self, task_factory, context, policy): + self.task_factory = task_factory + self.context = context + self.policy = policy + proxy_kwargs = {'context': self.context, 'policy': self.policy} + super(TaskFactoryProxy, self).__init__( + task_factory, + proxy_class=TaskProxy, + proxy_kwargs=proxy_kwargs + ) diff --git a/glance/common/exception.py b/glance/common/exception.py index 10666cdf2d..99f9fbc201 100644 --- a/glance/common/exception.py +++ b/glance/common/exception.py @@ -298,5 +298,22 @@ class RPCError(GlanceException): message = _("%(cls)s exception was raised in the last rpc call: %(val)s") -class TaskNotFound(GlanceException): +class TaskException(GlanceException): + message = _("An unknown task exception occurred") + + +class TaskNotFound(TaskException, NotFound): message = _("Task with the given id %(task_id)s was not found") + + +class InvalidTaskStatus(TaskException, Invalid): + message = _("Provided status of task is unsupported: %(status)s") + + +class InvalidTaskType(TaskException, Invalid): + message = _("Provided type of task is unsupported: %(type)s") + + +class InvalidTaskStatusTransition(TaskException, Invalid): + message = _("Status transition from %(cur_status)s to" + " %(new_status)s is not allowed") diff --git a/glance/db/__init__.py b/glance/db/__init__.py index 1cef15db9e..2eb1ef98f6 100644 --- a/glance/db/__init__.py +++ b/glance/db/__init__.py @@ -3,6 +3,7 @@ # Copyright 2010 United States Government as represented by the # Administrator of the National Aeronautics and Space Administration. # Copyright 2010-2012 OpenStack Foundation +# Copyright 2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -309,3 +310,87 @@ class ThreadPoolWrapper(object): def unwrap(self): return self.wrapped + + +class TaskRepo(object): + + def _format_task_from_db(self, db_task): + return glance.domain.Task( + task_id=db_task['id'], + type=db_task['type'], + status=db_task['status'], + input=db_task['input'], + result=db_task['result'], + owner=db_task['owner'], + message=db_task['message'], + expires_at=db_task['expires_at'], + created_at=db_task['created_at'], + updated_at=db_task['updated_at'], + ) + + def _format_task_to_db(self, task): + return {'id': task.task_id, + 'type': task.type, + 'status': task.status, + 'input': task.input, + 'result': task.result, + 'owner': task.owner, + 'message': task.message, + 'expires_at': task.expires_at, + 'created_at': task.created_at, + 'updated_at': task.updated_at} + + def __init__(self, context, db_api): + self.context = context + self.db_api = db_api + + def get(self, task_id): + try: + db_api_task = self.db_api.task_get(self.context, task_id) + except (exception.NotFound, exception.Forbidden): + msg = _('Could not find task %s') % task_id + raise exception.NotFound(msg) + return self._format_task_from_db(db_api_task) + + def list(self, + marker=None, + limit=None, + sort_key='created_at', + sort_dir='desc', + filters=None): + db_api_tasks = self.db_api.task_get_all(self.context, + filters=filters, + marker=marker, + limit=limit, + sort_key=sort_key, + sort_dir=sort_dir) + return [self._format_task_from_db(task) for task in db_api_tasks] + + def save(self, task): + task_values = self._format_task_to_db(task) + try: + updated_values = self.db_api.task_update(self.context, + task.task_id, + task_values) + except (exception.NotFound, exception.Forbidden): + msg = _('Could not find task %s') % task.task_id + raise exception.NotFound(msg) + task.updated_at = updated_values['updated_at'] + + def add(self, task): + task_values = self._format_task_to_db(task) + updated_values = self.db_api.task_create(self.context, task_values) + task.created_at = updated_values['created_at'] + task.updated_at = updated_values['updated_at'] + + def remove(self, task): + task_values = self._format_task_to_db(task) + try: + self.db_api.task_update(self.context, task.task_id, task_values) + updated_values = self.db_api.task_delete(self.context, + task.task_id) + except (exception.NotFound, exception.Forbidden): + msg = _('Could not find task %s') % task.task_id + raise exception.NotFound(msg) + task.updated_at = updated_values['updated_at'] + task.deleted_at = updated_values['deleted_at'] diff --git a/glance/db/sqlalchemy/api.py b/glance/db/sqlalchemy/api.py index 0a5042c237..1a6cc80fb5 100644 --- a/glance/db/sqlalchemy/api.py +++ b/glance/db/sqlalchemy/api.py @@ -1194,6 +1194,7 @@ def task_delete(context, task_id, session=None): raise exception.TaskNotFound(task_id=task_id) task_ref.delete(session=session) + return _task_format(task_ref) def task_get_all(context, filters=None, marker=None, limit=None, diff --git a/glance/domain/__init__.py b/glance/domain/__init__.py index 5947c1c3ff..ddbe70ed27 100644 --- a/glance/domain/__init__.py +++ b/glance/domain/__init__.py @@ -1,4 +1,5 @@ # Copyright 2012 OpenStack Foundation +# Copyright 2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -14,13 +15,19 @@ # under the License. import collections +import datetime + from oslo.config import cfg from glance.common import exception +import glance.openstack.common.log as logging from glance.openstack.common import timeutils from glance.openstack.common import uuidutils +LOG = logging.getLogger(__name__) + + image_format_opts = [ cfg.ListOpt('container_formats', default=['ami', 'ari', 'aki', 'bare', 'ovf'], @@ -31,6 +38,10 @@ image_format_opts = [ 'vdi', 'iso'], help=_("Supported values for the 'disk_format' " "image attribute")), + cfg.IntOpt('task_time_to_live', + default=48, + help=_("Time in hours for which a task lives after, either " + "succeeding or failing")), ] @@ -216,3 +227,117 @@ class ImageMemberFactory(object): return ImageMembership(image_id=image.image_id, member_id=member_id, created_at=created_at, updated_at=updated_at, status='pending') + + +class Task(object): + _supported_task_type = ('import',) + + _supported_task_status = ('pending', 'processing', 'success', 'failure') + + def __init__(self, task_id, type, status, input, result, owner, message, + expires_at, created_at, updated_at): + + if type not in self._supported_task_type: + raise exception.InvalidTaskType(type) + + if status not in self._supported_task_status: + raise exception.InvalidTaskStatus(status) + + self.task_id = task_id + self._status = status + self.type = type + self.input = input + self.result = result + self.owner = owner + self.message = message + self.expires_at = expires_at + # NOTE(nikhil): We use '_time_to_live' to determine how long a + # task should live from the time it succeeds or fails. + self._time_to_live = datetime.timedelta(hours=CONF.task_time_to_live) + self.created_at = created_at + self.updated_at = updated_at + + @property + def status(self): + return self._status + + def run(self, executor): + # NOTE(flwang): The task status won't be set here but handled by the + # executor. + # NOTE(nikhil): Ideally, a task should always be instantiated with an + # executor. However, we need to make that a part of the framework + # and we are planning to add such logic when Controller would + # be introduced. + if executor: + executor.run(self.task_id) + + def _validate_task_status_transition(self, cur_status, new_status): + valid_transitions = { + 'pending': ['processing', 'failure'], + 'processing': ['success', 'failure'], + 'success': [], + 'failure': [], + } + + if new_status in valid_transitions[cur_status]: + return True + else: + return False + + def _set_task_status(self, new_status): + if self._validate_task_status_transition(self.status, new_status): + self._status = new_status + log_msg = (_("Task status changed from %(cur_status)s to " + "%(new_status)s") % {'cur_status': self.status, + 'new_status': new_status}) + LOG.info(log_msg) + else: + log_msg = (_("Task status failed to change from %(cur_status)s " + "to %(new_status)s") % {'cur_status': self.status, + 'new_status': new_status}) + LOG.error(log_msg) + raise exception.InvalidTaskStatusTransition( + cur_status=self.status, + new_status=new_status + ) + + def begin_processing(self): + new_status = 'processing' + self._set_task_status(new_status) + + def succeed(self, result): + new_status = 'success' + self.result = result + self._set_task_status(new_status) + self.expires_at = timeutils.utcnow() + self._time_to_live + + def fail(self, message): + new_status = 'failure' + self.message = message + self._set_task_status(new_status) + self.expires_at = timeutils.utcnow() + self._time_to_live + + +class TaskFactory(object): + def new_task(self, task_type, task_input, owner): + task_id = uuidutils.generate_uuid() + status = 'pending' + result = None + message = None + # Note(nikhil): expires_at would be set on the task, only when it + # succeeds or fails. + expires_at = None + created_at = timeutils.utcnow() + updated_at = created_at + return Task( + task_id, + task_type, + status, + task_input, + result, + owner, + message, + expires_at, + created_at, + updated_at + ) diff --git a/glance/domain/proxy.py b/glance/domain/proxy.py index ce6359147e..12ba814da3 100644 --- a/glance/domain/proxy.py +++ b/glance/domain/proxy.py @@ -1,4 +1,5 @@ # Copyright 2013 OpenStack Foundation +# Copyright 2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -130,3 +131,41 @@ class Image(object): def get_member_repo(self): return self.helper.proxy(self.base.get_member_repo()) + + +class Task(object): + def __init__(self, base): + self.base = base + + task_id = _proxy('base', 'task_id') + type = _proxy('base', 'type') + status = _proxy('base', 'status') + input = _proxy('base', 'input') + result = _proxy('base', 'result') + owner = _proxy('base', 'owner') + message = _proxy('base', 'message') + expires_at = _proxy('base', 'expires_at') + created_at = _proxy('base', 'created_at') + updated_at = _proxy('base', 'updated_at') + + def run(self, executor): + raise NotImplementedError() + + def begin_processing(self): + self.base.begin_processing() + + def succeed(self, result): + self.base.succeed(result) + + def fail(self, message): + self.base.fail(message) + + +class TaskFactory(object): + def __init__(self, base, proxy_class=None, proxy_kwargs=None): + self.helper = Helper(proxy_class, proxy_kwargs) + self.base = base + + def new_task(self, task_type, task_input, owner): + t = self.base.new_task(task_type, task_input, owner) + return self.helper.proxy(t) diff --git a/glance/gateway.py b/glance/gateway.py index f550c1ec91..e29f51e991 100644 --- a/glance/gateway.py +++ b/glance/gateway.py @@ -1,4 +1,5 @@ # Copyright 2012 OpenStack Foundation +# Copyright 2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -85,3 +86,23 @@ class Gateway(object): notifier_image_repo, context) return authorized_image_repo + + def get_task_factory(self, context): + task_factory = glance.domain.TaskFactory() + policy_task_factory = policy.TaskFactoryProxy( + task_factory, context, self.policy) + notifier_task_factory = glance.notifier.TaskFactoryProxy( + policy_task_factory, context, self.notifier) + authorized_task_factory = authorization.TaskFactoryProxy( + notifier_task_factory, context) + return authorized_task_factory + + def get_task_repo(self, context): + task_repo = glance.db.TaskRepo(context, self.db_api) + policy_task_repo = policy.TaskRepoProxy( + task_repo, context, self.policy) + notifier_task_repo = glance.notifier.TaskRepoProxy( + policy_task_repo, context, self.notifier) + authorized_task_repo = authorization.TaskRepoProxy( + notifier_task_repo, context) + return authorized_task_repo diff --git a/glance/notifier/__init__.py b/glance/notifier/__init__.py index c631f85873..126a5c461d 100644 --- a/glance/notifier/__init__.py +++ b/glance/notifier/__init__.py @@ -2,6 +2,7 @@ # Copyright 2011, OpenStack Foundation # Copyright 2012, Red Hat, Inc. +# Copyright 2013 IBM Corp. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain @@ -125,6 +126,23 @@ def format_image_notification(image): } +def format_task_notification(task): + # NOTE(nikhil): input is not passed to the notifier payload as it may + # contain sensitive info. + return {'id': task.task_id, + 'type': task.type, + 'status': task.status, + 'result': None, + 'owner': task.owner, + 'message': None, + 'expires_at': timeutils.isotime(task.expires_at), + 'created_at': timeutils.isotime(task.created_at), + 'updated_at': timeutils.isotime(task.updated_at), + 'deleted': False, + 'deleted_at': None, + } + + class ImageRepoProxy(glance.domain.proxy.Repo): def __init__(self, image_repo, context, notifier): @@ -246,3 +264,64 @@ class ImageProxy(glance.domain.proxy.Image): payload = format_image_notification(self.image) self.notifier.info('image.upload', payload) self.notifier.info('image.activate', payload) + + +class TaskRepoProxy(glance.domain.proxy.Repo): + + def __init__(self, task_repo, context, notifier): + self.task_repo = task_repo + self.context = context + self.notifier = notifier + proxy_kwargs = {'context': self.context, 'notifier': self.notifier} + super(TaskRepoProxy, self).__init__(task_repo, + item_proxy_class=TaskProxy, + item_proxy_kwargs=proxy_kwargs) + + def add(self, task): + self.notifier.info('task.create', format_task_notification(task)) + return super(TaskRepoProxy, self).add(task) + + def remove(self, task): + payload = format_task_notification(task) + payload['deleted'] = True + payload['deleted_at'] = timeutils.isotime() + self.notifier.info('task.delete', payload) + return super(TaskRepoProxy, self).add(task) + + +class TaskFactoryProxy(glance.domain.proxy.TaskFactory): + def __init__(self, factory, context, notifier): + kwargs = {'context': context, 'notifier': notifier} + super(TaskFactoryProxy, self).__init__(factory, + proxy_class=TaskProxy, + proxy_kwargs=kwargs) + + +class TaskProxy(glance.domain.proxy.Task): + + def __init__(self, task, context, notifier): + self.task = task + self.context = context + self.notifier = notifier + super(TaskProxy, self).__init__(task) + + def run(self, executor): + self.notifier.info('task.run', format_task_notification(self.task)) + return super(TaskProxy, self).run(executor) + + def begin_processing(self): + self.notifier.info( + 'task.processing', + format_task_notification(self.task) + ) + return super(TaskProxy, self).begin_processing() + + def succeed(self, result): + self.notifier.info('task.success', + format_task_notification(self.task)) + return super(TaskProxy, self).succeed(result) + + def fail(self, message): + self.notifier.info('task.failure', + format_task_notification(self.task)) + return super(TaskProxy, self).fail(message) diff --git a/glance/tests/etc/policy.json b/glance/tests/etc/policy.json index fe4b2c317a..4d27cdc385 100644 --- a/glance/tests/etc/policy.json +++ b/glance/tests/etc/policy.json @@ -23,5 +23,10 @@ "get_members": "", "modify_member": "", - "manage_image_cache": "" + "manage_image_cache": "", + + "get_task": "", + "get_tasks": "", + "add_task": "", + "modify_task": "" } diff --git a/glance/tests/unit/test_auth.py b/glance/tests/unit/test_auth.py index 8206246b0e..f5b20730e6 100644 --- a/glance/tests/unit/test_auth.py +++ b/glance/tests/unit/test_auth.py @@ -1,6 +1,7 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright 2011 OpenStack Foundation +# Copyright 2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -26,6 +27,7 @@ from glance.common import exception import glance.domain from glance.openstack.common import timeutils from glance.tests import utils +from glance.tests.unit import utils as unittest_utils TENANT1 = '6838eb7b-6ded-434a-882c-b344c77fe8df' @@ -857,3 +859,206 @@ class TestImageRepoProxy(utils.BaseTestCase): setattr, images[1], 'name', 'Wally') self.assertRaises(exception.Forbidden, setattr, images[2], 'name', 'Calvin') + + +class TestImmutableTask(utils.BaseTestCase): + def setUp(self): + super(TestImmutableTask, self).setUp() + task_factory = glance.domain.TaskFactory() + self.context = glance.context.RequestContext(tenant=TENANT2) + task_type = 'import' + task_input = '{"loc": "fake"}' + owner = TENANT2 + task = task_factory.new_task(task_type, task_input, owner) + self.task = authorization.ImmutableTaskProxy(task) + + def _test_change(self, attr, value): + self.assertRaises( + exception.Forbidden, + setattr, + self.task, + attr, + value + ) + self.assertRaises( + exception.Forbidden, + delattr, + self.task, + attr + ) + + def test_change_id(self): + self._test_change('task_id', UUID2) + + def test_change_type(self): + self._test_change('type', 'fake') + + def test_change_status(self): + self._test_change('status', 'success') + + def test_change_input(self): + self._test_change('input', {'foo': 'bar'}) + + def test_change_owner(self): + self._test_change('owner', 'fake') + + def test_change_message(self): + self._test_change('message', 'fake') + + def test_change_expires_at(self): + self._test_change('expires_at', 'fake') + + def test_change_created_at(self): + self._test_change('created_at', 'fake') + + def test_change_updated_at(self): + self._test_change('updated_at', 'fake') + + def test_run(self): + self.assertRaises( + NotImplementedError, + self.task.run, + 'executor' + ) + + def test_begin_processing(self): + self.assertRaises( + exception.Forbidden, + self.task.begin_processing + ) + + def test_succeed(self): + self.assertRaises( + exception.Forbidden, + self.task.succeed, + 'result' + ) + + def test_fail(self): + self.assertRaises( + exception.Forbidden, + self.task.fail, + 'message' + ) + + +class TestTaskFactoryProxy(utils.BaseTestCase): + def setUp(self): + super(TestTaskFactoryProxy, self).setUp() + factory = glance.domain.TaskFactory() + self.context = glance.context.RequestContext(tenant=TENANT1) + self.context_owner_is_none = glance.context.RequestContext() + self.task_factory = authorization.TaskFactoryProxy( + factory, + self.context + ) + self.task_type = 'import' + self.task_input = '{"loc": "fake"}' + self.owner = 'foo' + + self.request1 = unittest_utils.get_fake_request(tenant=TENANT1) + self.request2 = unittest_utils.get_fake_request(tenant=TENANT2) + + def test_task_create_default_owner(self): + owner = self.request1.context.owner + task = self.task_factory.new_task( + self.task_type, + self.task_input, + owner + ) + self.assertEqual(task.owner, TENANT1) + + def test_task_create_wrong_owner(self): + self.assertRaises( + exception.Forbidden, + self.task_factory.new_task, + self.task_type, + self.task_input, + self.owner + ) + + def test_task_create_owner_as_None(self): + self.assertRaises( + exception.Forbidden, + self.task_factory.new_task, + self.task_type, + self.task_input, + None + ) + + def test_task_create_admin_context_owner_as_None(self): + self.context.is_admin = True + self.assertRaises( + exception.Forbidden, + self.task_factory.new_task, + self.task_type, + self.task_input, + None + ) + + +class TestTaskRepoProxy(utils.BaseTestCase): + + class TaskRepoStub(object): + def __init__(self, fixtures): + self.fixtures = fixtures + + def get(self, task_id): + for f in self.fixtures: + if f.task_id == task_id: + return f + else: + raise ValueError(task_id) + + def list(self, *args, **kwargs): + return self.fixtures + + def setUp(self): + super(TestTaskRepoProxy, self).setUp() + task_factory = glance.domain.TaskFactory() + task_type = 'import' + task_input = '{"loc": "fake"}' + owner = None + self.fixtures = [ + task_factory.new_task(task_type, task_input, owner), + task_factory.new_task(task_type, task_input, owner), + task_factory.new_task(task_type, task_input, owner), + ] + self.context = glance.context.RequestContext(tenant=TENANT1) + task_repo = self.TaskRepoStub(self.fixtures) + self.task_repo = authorization.TaskRepoProxy( + task_repo, + self.context + ) + + def test_get_mutable_task(self): + task = self.task_repo.get(self.fixtures[0].task_id) + self.assertEqual(task.task_id, self.fixtures[0].task_id) + + def test_get_immutable_task(self): + task = self.task_repo.get(self.fixtures[1].task_id) + self.assertRaises( + exception.Forbidden, + setattr, + task, + 'input', + 'foo' + ) + + def test_list(self): + tasks = self.task_repo.list() + self.assertEqual(tasks[0].task_id, self.fixtures[0].task_id) + self.assertRaises( + exception.Forbidden, + setattr, + tasks[1], + 'input', + 'foo' + ) + self.assertRaises( + exception.Forbidden, + setattr, + tasks[2], + 'input', + 'foo' + ) diff --git a/glance/tests/unit/test_db.py b/glance/tests/unit/test_db.py index a84b9ba606..42284e0917 100644 --- a/glance/tests/unit/test_db.py +++ b/glance/tests/unit/test_db.py @@ -1,4 +1,5 @@ # Copyright 2012 OpenStack Foundation. +# Copyright 2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -19,7 +20,7 @@ from glance.common import crypt from glance.common import exception import glance.context import glance.db -from glance.openstack.common import uuidutils +from glance.openstack.common import uuidutils, timeutils import glance.tests.unit.utils as unit_test_utils import glance.tests.utils as test_utils @@ -79,6 +80,21 @@ def _db_image_member_fixture(image_id, member_id, **kwargs): return obj +def _db_task_fixture(task_id, type, status, **kwargs): + obj = { + 'id': task_id, + 'type': type, + 'status': status, + 'input': None, + 'result': None, + 'owner': None, + 'message': None, + 'deleted': False, + } + obj.update(kwargs) + return obj + + class TestImageRepo(test_utils.BaseTestCase): def setUp(self): @@ -494,3 +510,135 @@ class TestImageMemberRepo(test_utils.BaseTestCase): self.image_member_repo.remove, fake_member) self.assertTrue(fake_uuid in unicode(exc)) + + +class TestTaskRepo(test_utils.BaseTestCase): + + def setUp(self): + super(TestTaskRepo, self).setUp() + self.db = unit_test_utils.FakeDB() + self.db.reset() + self.context = glance.context.RequestContext(user=USER1, + tenant=TENANT1) + self.task_repo = glance.db.TaskRepo(self.context, self.db) + self.task_factory = glance.domain.TaskFactory() + self.fake_task_input = ('{"import_from": ' + '"swift://cloud.foo/account/mycontainer/path"' + ',"image_from_format": "qcow2"}') + self._create_tasks() + + def _create_tasks(self): + self.db.reset() + self.tasks = [ + _db_task_fixture(UUID1, type='import', status='pending', + input=self.fake_task_input, + result='', + owner=TENANT1, + message='', + ), + _db_task_fixture(UUID2, type='import', status='processing', + input=self.fake_task_input, + result='', + owner=TENANT1, + message='', + ), + _db_task_fixture(UUID3, type='import', status='failure', + input=self.fake_task_input, + result='', + owner=TENANT1, + message='', + ), + _db_task_fixture(UUID4, type='import', status='success', + input=self.fake_task_input, + result='', + owner=TENANT2, + message='', + ), + ] + [self.db.task_create(None, task) for task in self.tasks] + + def test_get(self): + task = self.task_repo.get(UUID1) + self.assertEqual(task.task_id, UUID1) + self.assertEqual(task.type, 'import') + self.assertEqual(task.status, 'pending') + self.assertEqual(task.input, self.fake_task_input) + self.assertEqual(task.result, '') + self.assertEqual(task.owner, TENANT1) + + def test_get_not_found(self): + self.assertRaises(exception.NotFound, self.task_repo.get, + uuidutils.generate_uuid()) + + def test_get_forbidden(self): + self.assertRaises(exception.NotFound, self.task_repo.get, UUID4) + + def test_list(self): + tasks = self.task_repo.list() + task_ids = set([i.task_id for i in tasks]) + self.assertEqual(set([UUID1, UUID2, UUID3]), task_ids) + + def test_list_with_type(self): + filters = {'type': 'import'} + tasks = self.task_repo.list(filters=filters) + task_ids = set([i.task_id for i in tasks]) + self.assertEqual(set([UUID1, UUID2, UUID3]), task_ids) + + def test_list_with_status(self): + filters = {'status': 'failure'} + tasks = self.task_repo.list(filters=filters) + task_ids = set([i.task_id for i in tasks]) + self.assertEqual(set([UUID3]), task_ids) + + def test_list_with_marker(self): + full_tasks = self.task_repo.list() + full_ids = [i.task_id for i in full_tasks] + marked_tasks = self.task_repo.list(marker=full_ids[0]) + actual_ids = [i.task_id for i in marked_tasks] + self.assertEqual(actual_ids, full_ids[1:]) + + def test_list_with_last_marker(self): + tasks = self.task_repo.list() + marked_tasks = self.task_repo.list(marker=tasks[-1].task_id) + self.assertEqual(len(marked_tasks), 0) + + def test_limited_list(self): + limited_tasks = self.task_repo.list(limit=2) + self.assertEqual(len(limited_tasks), 2) + + def test_list_with_marker_and_limit(self): + full_tasks = self.task_repo.list() + full_ids = [i.task_id for i in full_tasks] + marked_tasks = self.task_repo.list(marker=full_ids[0], limit=1) + actual_ids = [i.task_id for i in marked_tasks] + self.assertEqual(actual_ids, full_ids[1:2]) + + def test_sorted_list(self): + tasks = self.task_repo.list(sort_key='status', sort_dir='desc') + task_ids = [i.task_id for i in tasks] + self.assertEqual([UUID2, UUID1, UUID3], task_ids) + + def test_add_task(self): + task_type = 'import' + task = self.task_factory.new_task(task_type, self.fake_task_input, + None) + self.assertEqual(task.updated_at, task.created_at) + self.task_repo.add(task) + retrieved_task = self.task_repo.get(task.task_id) + self.assertEqual(retrieved_task.updated_at, task.updated_at) + + def test_save_task(self): + task = self.task_repo.get(UUID1) + original_update_time = task.updated_at + self.task_repo.save(task) + current_update_time = task.updated_at + self.assertTrue(current_update_time > original_update_time) + task = self.task_repo.get(UUID1) + self.assertEqual(task.updated_at, current_update_time) + + def test_remove_task(self): + task = self.task_repo.get(UUID1) + self.task_repo.remove(task) + self.assertRaises(exception.NotFound, + self.task_repo.get, + task.task_id) diff --git a/glance/tests/unit/test_domain.py b/glance/tests/unit/test_domain.py index 686e3560df..095819ff04 100644 --- a/glance/tests/unit/test_domain.py +++ b/glance/tests/unit/test_domain.py @@ -1,4 +1,5 @@ # Copyright 2012 OpenStack Foundation. +# Copyright 2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -13,9 +14,18 @@ # License for the specific language governing permissions and limitations # under the License. +import datetime + +from oslo.config import cfg + from glance.common import exception from glance import domain +from glance.openstack.common import uuidutils, timeutils import glance.tests.utils as test_utils +import glance.tests.unit.utils as unittest_utils + + +CONF = cfg.CONF UUID1 = 'c80a1a6c-bd1f-41c5-90ee-81afedb1d58d' @@ -275,3 +285,147 @@ class TestExtraProperties(test_utils.BaseTestCase): extra_properties = domain.ExtraProperties(a_dict) random_list = ['foo', 'bar'] self.assertFalse(extra_properties.__eq__(random_list)) + + +class TestTaskFactory(test_utils.BaseTestCase): + + def setUp(self): + super(TestTaskFactory, self).setUp() + self.task_factory = domain.TaskFactory() + + def test_new_task(self): + task_type = 'import' + task_input = '{"import_from": "fake"}' + owner = TENANT1 + task = self.task_factory.new_task(task_type, task_input, owner) + self.assertTrue(task.task_id is not None) + self.assertTrue(task.created_at is not None) + self.assertEqual(task.created_at, task.updated_at) + self.assertEqual(task.status, 'pending') + self.assertEqual(task.owner, TENANT1) + self.assertEqual(task.input, '{"import_from": "fake"}') + + def test_new_task_invalid_type(self): + task_type = 'blah' + task_input = '{"import_from": "fake"}' + owner = TENANT1 + self.assertRaises( + exception.InvalidTaskType, + self.task_factory.new_task, + task_type, + task_input, + owner, + ) + + +class TestTask(test_utils.BaseTestCase): + + def setUp(self): + super(TestTask, self).setUp() + self.task_factory = domain.TaskFactory() + task_type = 'import' + task_input = ('{"import_from": "file:///home/a.img",' + ' "import_from_format": "qcow2"}') + owner = TENANT1 + self.gateway = unittest_utils.FakeGateway() + self.task = self.task_factory.new_task(task_type, task_input, owner) + + def test_task_invalid_status(self): + task_id = uuidutils.generate_uuid() + status = 'blah' + self.assertRaises( + exception.InvalidTaskStatus, + domain.Task, + task_id, + type='import', + status=status, + input=None, + result=None, + owner=None, + message=None, + expires_at=None, + created_at=timeutils.utcnow(), + updated_at=timeutils.utcnow() + ) + + def test_validate_status_transition_from_pending(self): + self.task.begin_processing() + self.assertEqual(self.task.status, 'processing') + + def test_validate_status_transition_from_processing_to_success(self): + self.task.begin_processing() + self.task.succeed('') + self.assertEqual(self.task.status, 'success') + + def test_validate_status_transition_from_processing_to_failure(self): + self.task.begin_processing() + self.task.fail('') + self.assertEqual(self.task.status, 'failure') + + def test_invalid_status_transitions_from_pending(self): + #test do not allow transition from pending to success + self.assertRaises( + exception.InvalidTaskStatusTransition, + self.task.succeed, + '' + ) + + def test_invalid_status_transitions_from_success(self): + #test do not allow transition from success to processing + self.task.begin_processing() + self.task.succeed('') + self.assertRaises( + exception.InvalidTaskStatusTransition, + self.task.begin_processing + ) + #test do not allow transition from success to failure + self.assertRaises( + exception.InvalidTaskStatusTransition, + self.task.fail, + '' + ) + + def test_invalid_status_transitions_from_failure(self): + #test do not allow transition from failure to processing + self.task.begin_processing() + self.task.fail('') + self.assertRaises( + exception.InvalidTaskStatusTransition, + self.task.begin_processing + ) + #test do not allow transition from failure to success + self.assertRaises( + exception.InvalidTaskStatusTransition, + self.task.succeed, + '' + ) + + def test_begin_processing(self): + self.task.begin_processing() + self.assertEqual(self.task.status, 'processing') + + def test_succeed(self): + timeutils.set_time_override() + self.task.begin_processing() + self.task.succeed('{"location": "file://home"}') + self.assertEqual(self.task.status, 'success') + expected = (timeutils.utcnow() + + datetime.timedelta(hours=CONF.task_time_to_live)) + self.assertEqual( + self.task.expires_at, + expected + ) + timeutils.clear_time_override() + + def test_fail(self): + timeutils.set_time_override() + self.task.begin_processing() + self.task.fail('{"message": "connection failed"}') + self.assertEqual(self.task.status, 'failure') + expected = (timeutils.utcnow() + + datetime.timedelta(hours=CONF.task_time_to_live)) + self.assertEqual( + self.task.expires_at, + expected + ) + timeutils.clear_time_override() diff --git a/glance/tests/unit/test_domain_proxy.py b/glance/tests/unit/test_domain_proxy.py index af0ee632a2..677baef56f 100644 --- a/glance/tests/unit/test_domain_proxy.py +++ b/glance/tests/unit/test_domain_proxy.py @@ -1,4 +1,5 @@ # Copyright 2013 OpenStack Foundation. +# Copyright 2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -13,6 +14,8 @@ # License for the specific language governing permissions and limitations # under the License. +import mock + from glance.domain import proxy import glance.tests.utils as test_utils @@ -279,3 +282,49 @@ class TestImage(test_utils.BaseTestCase): member_repo = proxy_image.get_member_repo() self.assertTrue(isinstance(member_repo, FakeProxy)) self.assertEqual(member_repo.base, 'corn') + + +class TestTaskFactory(test_utils.BaseTestCase): + def setUp(self): + super(TestTaskFactory, self).setUp() + self.factory = mock.Mock() + self.fake_type = 'import' + self.fake_input = "fake input" + self.fake_owner = "owner" + + def test_proxy_plain(self): + proxy_factory = proxy.TaskFactory(self.factory) + + proxy_factory.new_task( + self.fake_type, + self.fake_input, + self.fake_owner + ) + + self.factory.new_task.assert_called_once_with( + self.fake_type, + self.fake_input, + self.fake_owner + ) + + def test_proxy_wrapping(self): + proxy_factory = proxy.TaskFactory( + self.factory, + proxy_class=FakeProxy, + proxy_kwargs={'dog': 'bark'} + ) + self.factory.new_task.return_value = 'fake_task' + + task = proxy_factory.new_task( + self.fake_type, + self.fake_input, + self.fake_owner + ) + + self.factory.new_task.assert_called_once_with( + self.fake_type, + self.fake_input, + self.fake_owner + ) + self.assertTrue(isinstance(task, FakeProxy)) + self.assertEqual(task.base, 'fake_task') diff --git a/glance/tests/unit/test_notifier.py b/glance/tests/unit/test_notifier.py index 490682f365..2a02a4a40b 100644 --- a/glance/tests/unit/test_notifier.py +++ b/glance/tests/unit/test_notifier.py @@ -1,6 +1,7 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright 2011 OpenStack Foundation +# Copyright 2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -30,7 +31,7 @@ from glance.common import exception import glance.context from glance import notifier from glance.notifier import notify_kombu -from glance.openstack.common import importutils +from glance.openstack.common import importutils, timeutils import glance.openstack.common.log as logging import glance.tests.unit.utils as unit_test_utils from glance.tests import utils @@ -71,6 +72,34 @@ class ImageRepoStub(object): return ['images_from_list'] +class TaskStub(glance.domain.Task): + def run(self): + pass + + def succeed(self, result): + pass + + def fail(self, message): + pass + + +class TaskRepoStub(object): + def remove(self, *args, **kwargs): + return 'task_from_remove' + + def save(self, *args, **kwargs): + return 'task_from_save' + + def add(self, *args, **kwargs): + return 'task_from_add' + + def get(self, *args, **kwargs): + return 'task_from_get' + + def list(self, *args, **kwargs): + return ['tasks_from_list'] + + class TestNotifier(utils.BaseTestCase): def test_invalid_strategy(self): @@ -848,3 +877,127 @@ class RabbitStrategyTestCase(utils.BaseTestCase): self.rabbit_strategy._send_message. \ assert_called_with('fake_msg', 'notifications.warn') self.rabbit_strategy.log_failure.assert_called_with('fake_msg', "WARN") + + +class TestTaskNotifications(utils.BaseTestCase): + """Test Task Notifications work""" + + def setUp(self): + super(TestTaskNotifications, self).setUp() + self.task = TaskStub( + task_id='aaa', + type='import', + status='pending', + input={"loc": "fake"}, + result='', + owner=TENANT2, + message='', + expires_at=None, + created_at=DATETIME, + updated_at=DATETIME + ) + self.context = glance.context.RequestContext( + tenant=TENANT2, + user=USER1 + ) + self.task_repo_stub = TaskRepoStub() + self.notifier = unit_test_utils.FakeNotifier() + self.task_repo_proxy = glance.notifier.TaskRepoProxy( + self.task_repo_stub, + self.context, + self.notifier + ) + self.task_proxy = glance.notifier.TaskProxy( + self.task, + self.context, + self.notifier + ) + timeutils.set_time_override() + + def tearDown(self): + super(TestTaskNotifications, self).tearDown() + timeutils.clear_time_override() + + def test_task_create_notification(self): + self.task_repo_proxy.add(self.task_proxy) + output_logs = self.notifier.get_logs() + self.assertEqual(len(output_logs), 1) + output_log = output_logs[0] + self.assertEqual(output_log['notification_type'], 'INFO') + self.assertEqual(output_log['event_type'], 'task.create') + self.assertEqual(output_log['payload']['id'], self.task.task_id) + self.assertEqual( + output_log['payload']['updated_at'], + timeutils.isotime(self.task.updated_at) + ) + self.assertEqual( + output_log['payload']['created_at'], + timeutils.isotime(self.task.created_at) + ) + if 'location' in output_log['payload']: + self.fail('Notification contained location field.') + + def test_task_delete_notification(self): + now = timeutils.isotime() + self.task_repo_proxy.remove(self.task_proxy) + output_logs = self.notifier.get_logs() + self.assertEqual(len(output_logs), 1) + output_log = output_logs[0] + self.assertEqual(output_log['notification_type'], 'INFO') + self.assertEqual(output_log['event_type'], 'task.delete') + self.assertEqual(output_log['payload']['id'], self.task.task_id) + self.assertEqual( + output_log['payload']['updated_at'], + timeutils.isotime(self.task.updated_at) + ) + self.assertEqual( + output_log['payload']['created_at'], + timeutils.isotime(self.task.created_at) + ) + self.assertEqual( + output_log['payload']['deleted_at'], + now + ) + if 'location' in output_log['payload']: + self.fail('Notification contained location field.') + + def test_task_run_notification(self): + self.assertRaises( + NotImplementedError, + self.task_proxy.run, + executor=None + ) + output_logs = self.notifier.get_logs() + self.assertEqual(len(output_logs), 1) + output_log = output_logs[0] + self.assertEqual(output_log['notification_type'], 'INFO') + self.assertEqual(output_log['event_type'], 'task.run') + self.assertEqual(output_log['payload']['id'], self.task.task_id) + + def test_task_processing_notification(self): + self.task_proxy.begin_processing() + output_logs = self.notifier.get_logs() + self.assertEqual(len(output_logs), 1) + output_log = output_logs[0] + self.assertEqual(output_log['notification_type'], 'INFO') + self.assertEqual(output_log['event_type'], 'task.processing') + self.assertEqual(output_log['payload']['id'], self.task.task_id) + + def test_task_success_notification(self): + self.task_proxy.begin_processing() + self.task_proxy.succeed(result=None) + output_logs = self.notifier.get_logs() + self.assertEqual(len(output_logs), 2) + output_log = output_logs[1] + self.assertEqual(output_log['notification_type'], 'INFO') + self.assertEqual(output_log['event_type'], 'task.success') + self.assertEqual(output_log['payload']['id'], self.task.task_id) + + def test_task_failure_notification(self): + self.task_proxy.fail(message=None) + output_logs = self.notifier.get_logs() + self.assertEqual(len(output_logs), 1) + output_log = output_logs[0] + self.assertEqual(output_log['notification_type'], 'INFO') + self.assertEqual(output_log['event_type'], 'task.failure') + self.assertEqual(output_log['payload']['id'], self.task.task_id) diff --git a/glance/tests/unit/test_policy.py b/glance/tests/unit/test_policy.py index c959b589fd..4a1102ee4e 100644 --- a/glance/tests/unit/test_policy.py +++ b/glance/tests/unit/test_policy.py @@ -1,4 +1,5 @@ # Copyright 2012 OpenStack Foundation +# Copyright 2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -22,6 +23,7 @@ import glance.api.policy from glance.common import exception import glance.context from glance.tests.unit import base +import glance.tests.unit.utils as unit_test_utils from glance.tests import utils as test_utils UUID1 = 'c80a1a6c-bd1f-41c5-90ee-81afedb1d58d' @@ -77,6 +79,31 @@ class MemberRepoStub(object): return 'member_repo_remove' +class TaskRepoStub(object): + def get(self, *args, **kwargs): + return 'task_from_get' + + def add(self, *args, **kwargs): + return 'task_from_add' + + def list(self, *args, **kwargs): + return ['task_from_list_0', 'task_from_list_1'] + + +class TaskStub(object): + def __init__(self, task_id): + self.task_id = task_id + self.status = 'pending' + + def run(self): + self.status = 'processing' + + +class TaskFactoryStub(object): + def new_task(self, *args): + return 'new_task' + + class TestPolicyEnforcer(base.IsolatedUnitTest): def test_policy_file_default_rules_default_location(self): enforcer = glance.api.policy.Enforcer() @@ -334,6 +361,82 @@ class TestMemberPolicy(test_utils.BaseTestCase): self.policy.enforce.assert_called_once_with({}, "delete_member", {}) +class TestTaskPolicy(test_utils.BaseTestCase): + def setUp(self): + self.task_stub = TaskStub(UUID1) + self.task_repo_stub = TaskRepoStub() + self.task_factory_stub = TaskFactoryStub() + self.policy = unit_test_utils.FakePolicyEnforcer() + super(TestTaskPolicy, self).setUp() + + def test_get_task_not_allowed(self): + rules = {"get_task": False} + self.policy.set_rules(rules) + task_repo = glance.api.policy.TaskRepoProxy( + self.task_repo_stub, + {}, + self.policy + ) + self.assertRaises(exception.Forbidden, task_repo.get, UUID1) + + def test_get_task_allowed(self): + rules = {"get_task": True} + self.policy.set_rules(rules) + task_repo = glance.api.policy.TaskRepoProxy( + self.task_repo_stub, + {}, + self.policy + ) + output = task_repo.get(UUID1) + self.assertTrue(isinstance(output, glance.api.policy.TaskProxy)) + self.assertEqual(output.task, 'task_from_get') + + def test_get_tasks_not_allowed(self): + rules = {"get_tasks": False} + self.policy.set_rules(rules) + task_repo = glance.api.policy.TaskRepoProxy( + self.task_repo_stub, + {}, + self.policy + ) + self.assertRaises(exception.Forbidden, task_repo.list) + + def test_get_tasks_allowed(self): + rules = {"get_task": True} + self.policy.set_rules(rules) + task_repo = glance.api.policy.TaskRepoProxy( + self.task_repo_stub, + {}, + self.policy + ) + tasks = task_repo.list() + for i, task in enumerate(tasks): + self.assertTrue(isinstance(task, glance.api.policy.TaskProxy)) + self.assertEqual(task.task, 'task_from_list_%d' % i) + + def test_add_task_not_allowed(self): + rules = {"add_task": False} + self.policy.set_rules(rules) + task_repo = glance.api.policy.TaskRepoProxy( + self.task_repo_stub, + {}, + self.policy + ) + task = glance.api.policy.TaskProxy(self.task_stub, {}, self.policy) + self.assertRaises(exception.Forbidden, task_repo.add, task) + + def test_add_task_allowed(self): + rules = {"add_task": True} + self.policy.set_rules(rules) + task_repo = glance.api.policy.TaskRepoProxy( + self.task_repo_stub, + {}, + self.policy + ) + task = glance.api.policy.TaskProxy(self.task_stub, {}, self.policy) + task_repo.add(task) + + class TestContextPolicyEnforcer(base.IsolatedUnitTest): def _do_test_policy_influence_context_admin(self, policy_admin_role, diff --git a/glance/tests/unit/utils.py b/glance/tests/unit/utils.py index dff87b1b49..0c7b311d64 100644 --- a/glance/tests/unit/utils.py +++ b/glance/tests/unit/utils.py @@ -86,6 +86,7 @@ class FakeDB(object): 'members': [], 'tags': {}, 'locations': [], + 'tasks': {} } def __getattr__(self, key): @@ -204,3 +205,46 @@ class FakeNotifier(object): def get_logs(self): return self.log + + +class FakeGateway(object): + def __init__(self, image_factory=None, image_member_factory=None, + image_repo=None, task_factory=None, task_repo=None): + self.image_factory = image_factory + self.image_member_factory = image_member_factory + self.image_repo = image_repo + self.task_factory = task_factory + self.task_repo = task_repo + + def get_image_factory(self, context): + return self.image_factory + + def get_image_member_factory(self, context): + return self.image_member_factory + + def get_repo(self, context): + return self.image_repo + + def get_task_factory(self, context): + return self.task_factory + + def get_task_repo(self, context): + return self.task_repo + + +class FakeTask(object): + def __init__(self, task_id, type=None, status=None): + self.task_id = task_id + self.type = type + self.message = None + self.input = None + self._status = status + self._executor = None + + def success(self, result): + self.result = result + self._status = 'success' + + def fail(self, message): + self.message = message + self._status = 'failure'