Merge "Fixes to ip billing after integration testing"

This commit is contained in:
Jenkins 2016-07-15 19:03:27 +00:00 committed by Gerrit Code Review
commit a4e2fd5367
7 changed files with 217 additions and 43 deletions

View File

@ -93,7 +93,7 @@ def do_notify(context, event_type, payload):
@env.has_capability(env.Capabilities.IP_BILLING) @env.has_capability(env.Capabilities.IP_BILLING)
def notify(context, event_type, ipaddress, send_usage=False): def notify(context, event_type, ipaddress, send_usage=False, *args, **kwargs):
"""Method to send notifications. """Method to send notifications.
We must send USAGE when a public IPv4 address is deallocated or a FLIP is We must send USAGE when a public IPv4 address is deallocated or a FLIP is
@ -114,6 +114,12 @@ def notify(context, event_type, ipaddress, send_usage=False):
or (event_type == IP_EXISTS and not CONF.QUARK.notify_ip_exists): or (event_type == IP_EXISTS and not CONF.QUARK.notify_ip_exists):
LOG.debug('IP_BILL: notification {} is disabled by config'. LOG.debug('IP_BILL: notification {} is disabled by config'.
format(event_type)) format(event_type))
return
# Do not send notifications when we are undoing due to an error
if 'rollback' in kwargs and kwargs['rollback']:
LOG.debug('IP_BILL: not sending notification because we are in undo')
return
# ip.add needs the allocated_at time. # ip.add needs the allocated_at time.
# All other events need the current time. # All other events need the current time.
@ -292,10 +298,11 @@ def convert_timestamp(ts):
Examples of a good timestamp for startTime, endTime, and eventTime: Examples of a good timestamp for startTime, endTime, and eventTime:
'2016-05-20T00:00:00Z' '2016-05-20T00:00:00Z'
We must drop microseconds so that Yagi does not get upset.
Note the trailing 'Z'. Python does not add the 'Z' so we tack it on Note the trailing 'Z'. Python does not add the 'Z' so we tack it on
ourselves. ourselves.
""" """
return ts.isoformat() + 'Z' return ts.replace(microsecond=0).isoformat() + 'Z'
def _now(): def _now():

View File

@ -215,13 +215,13 @@ class QuarkIpam(object):
@synchronized(named("allocate_mac_address")) @synchronized(named("allocate_mac_address"))
def allocate_mac_address(self, context, net_id, port_id, reuse_after, def allocate_mac_address(self, context, net_id, port_id, reuse_after,
mac_address=None, mac_address=None,
use_forbidden_mac_range=False): use_forbidden_mac_range=False, **kwargs):
if mac_address: if mac_address:
mac_address = netaddr.EUI(mac_address).value mac_address = netaddr.EUI(mac_address).value
kwargs = {"network_id": net_id, "port_id": port_id, kwargs.update({"network_id": net_id, "port_id": port_id,
"mac_address": mac_address, "mac_address": mac_address,
"use_forbidden_mac_range": use_forbidden_mac_range} "use_forbidden_mac_range": use_forbidden_mac_range})
LOG.info(("Attempting to allocate a new MAC address " LOG.info(("Attempting to allocate a new MAC address "
"[{0}]").format(utils.pretty_kwargs(**kwargs))) "[{0}]").format(utils.pretty_kwargs(**kwargs)))
@ -556,9 +556,6 @@ class QuarkIpam(object):
version=subnet["ip_version"], network_id=net_id, version=subnet["ip_version"], network_id=net_id,
address_type=kwargs.get('address_type', address_type=kwargs.get('address_type',
ip_types.FIXED)) ip_types.FIXED))
# alexm: need to notify from here because this code
# does not go through the _allocate_from_subnet() path.
billing.notify(context, billing.IP_ADD, address)
return address return address
except db_exception.DBDuplicateEntry: except db_exception.DBDuplicateEntry:
# This shouldn't ever happen, since we hold a unique MAC # This shouldn't ever happen, since we hold a unique MAC
@ -701,7 +698,7 @@ class QuarkIpam(object):
if self.is_strategy_satisfied(new_addresses, allocate_complete=True): if self.is_strategy_satisfied(new_addresses, allocate_complete=True):
# Only notify when all went well # Only notify when all went well
for address in new_addresses: for address in new_addresses:
billing.notify(context, billing.IP_ADD, address) billing.notify(context, billing.IP_ADD, address, **kwargs)
LOG.info("IPAM for port ID {0} completed with addresses " LOG.info("IPAM for port ID {0} completed with addresses "
"{1}".format(port_id, "{1}".format(port_id,
[a["address_readable"] [a["address_readable"]
@ -711,14 +708,15 @@ class QuarkIpam(object):
raise ip_address_failure(net_id) raise ip_address_failure(net_id)
def deallocate_ip_address(self, context, address): def deallocate_ip_address(self, context, address, **kwargs):
if address["version"] == 6: if address["version"] == 6:
db_api.ip_address_delete(context, address) db_api.ip_address_delete(context, address)
else: else:
address["deallocated"] = 1 address["deallocated"] = 1
address["address_type"] = None address["address_type"] = None
billing.notify(context, billing.IP_DEL, address, send_usage=True) billing.notify(context, billing.IP_DEL, address, send_usage=True,
**kwargs)
def deallocate_ips_by_port(self, context, port=None, **kwargs): def deallocate_ips_by_port(self, context, port=None, **kwargs):
ips_to_remove = [] ips_to_remove = []
@ -768,7 +766,7 @@ class QuarkIpam(object):
# SQLAlchemy caching. # SQLAlchemy caching.
context.session.add(flip) context.session.add(flip)
context.session.flush() context.session.flush()
billing.notify(context, billing.IP_DISASSOC, flip) billing.notify(context, billing.IP_DISASSOC, flip, **kwargs)
driver = registry.DRIVER_REGISTRY.get_driver() driver = registry.DRIVER_REGISTRY.get_driver()
driver.remove_floating_ip(flip) driver.remove_floating_ip(flip)
elif len(flip.fixed_ips) > 1: elif len(flip.fixed_ips) > 1:
@ -782,7 +780,8 @@ class QuarkIpam(object):
context, flip, fix_ip) context, flip, fix_ip)
context.session.add(flip) context.session.add(flip)
context.session.flush() context.session.flush()
billing.notify(context, billing.IP_DISASSOC, flip) billing.notify(context, billing.IP_DISASSOC, flip,
**kwargs)
else: else:
remaining_fixed_ips.append(fix_ip) remaining_fixed_ips.append(fix_ip)
port_fixed_ips = {} port_fixed_ips = {}
@ -800,7 +799,7 @@ class QuarkIpam(object):
# NCP-1509(roaet): # NCP-1509(roaet):
# - started using admin_context due to tenant not claiming when realloc # - started using admin_context due to tenant not claiming when realloc
def deallocate_mac_address(self, context, address): def deallocate_mac_address(self, context, address, **kwargs):
admin_context = context.elevated() admin_context = context.elevated()
mac = db_api.mac_address_find(admin_context, address=address, mac = db_api.mac_address_find(admin_context, address=address,
scope=db_api.ONE) scope=db_api.ONE)

View File

@ -233,8 +233,8 @@ def create_port(context, port):
with utils.CommandManager().execute() as cmd_mgr: with utils.CommandManager().execute() as cmd_mgr:
@cmd_mgr.do @cmd_mgr.do
def _allocate_ips(fixed_ips, net, port_id, segment_id, mac): def _allocate_ips(fixed_ips, net, port_id, segment_id, mac,
fixed_ip_kwargs = {} **kwargs):
if fixed_ips: if fixed_ips:
if (STRATEGY.is_provider_network(net_id) and if (STRATEGY.is_provider_network(net_id) and
not context.is_admin): not context.is_admin):
@ -244,36 +244,38 @@ def create_port(context, port):
net_id, net_id,
segment_id, segment_id,
fixed_ips) fixed_ips)
fixed_ip_kwargs["ip_addresses"] = ips kwargs["ip_addresses"] = ips
fixed_ip_kwargs["subnets"] = subnets kwargs["subnets"] = subnets
ipam_driver.allocate_ip_address( ipam_driver.allocate_ip_address(
context, addresses, net["id"], port_id, context, addresses, net["id"], port_id,
CONF.QUARK.ipam_reuse_after, segment_id=segment_id, CONF.QUARK.ipam_reuse_after, segment_id=segment_id,
mac_address=mac, **fixed_ip_kwargs) mac_address=mac, **kwargs)
@cmd_mgr.undo @cmd_mgr.undo
def _allocate_ips_undo(addr): def _allocate_ips_undo(addr, **kwargs):
LOG.info("Rolling back IP addresses...") LOG.info("Rolling back IP addresses...")
if addresses: if addresses:
for address in addresses: for address in addresses:
try: try:
with context.session.begin(): with context.session.begin():
ipam_driver.deallocate_ip_address(context, address) ipam_driver.deallocate_ip_address(context, address,
**kwargs)
except Exception: except Exception:
LOG.exception("Couldn't release IP %s" % address) LOG.exception("Couldn't release IP %s" % address)
@cmd_mgr.do @cmd_mgr.do
def _allocate_mac(net, port_id, mac_address, def _allocate_mac(net, port_id, mac_address,
use_forbidden_mac_range=False): use_forbidden_mac_range=False,
**kwargs):
mac = ipam_driver.allocate_mac_address( mac = ipam_driver.allocate_mac_address(
context, net["id"], port_id, CONF.QUARK.ipam_reuse_after, context, net["id"], port_id, CONF.QUARK.ipam_reuse_after,
mac_address=mac_address, mac_address=mac_address,
use_forbidden_mac_range=use_forbidden_mac_range) use_forbidden_mac_range=use_forbidden_mac_range, **kwargs)
return mac return mac
@cmd_mgr.undo @cmd_mgr.undo
def _allocate_mac_undo(mac): def _allocate_mac_undo(mac, **kwargs):
LOG.info("Rolling back MAC address...") LOG.info("Rolling back MAC address...")
if mac: if mac:
try: try:
@ -284,7 +286,7 @@ def create_port(context, port):
LOG.exception("Couldn't release MAC %s" % mac) LOG.exception("Couldn't release MAC %s" % mac)
@cmd_mgr.do @cmd_mgr.do
def _allocate_backend_port(mac, addresses, net, port_id): def _allocate_backend_port(mac, addresses, net, port_id, **kwargs):
backend_port = net_driver.create_port( backend_port = net_driver.create_port(
context, net["id"], context, net["id"],
port_id=port_id, port_id=port_id,
@ -298,7 +300,8 @@ def create_port(context, port):
return backend_port return backend_port
@cmd_mgr.undo @cmd_mgr.undo
def _allocate_back_port_undo(backend_port): def _allocate_back_port_undo(backend_port,
**kwargs):
LOG.info("Rolling back backend port...") LOG.info("Rolling back backend port...")
try: try:
backend_port_uuid = None backend_port_uuid = None
@ -310,7 +313,8 @@ def create_port(context, port):
"Couldn't rollback backend port %s" % backend_port) "Couldn't rollback backend port %s" % backend_port)
@cmd_mgr.do @cmd_mgr.do
def _allocate_db_port(port_attrs, backend_port, addresses, mac): def _allocate_db_port(port_attrs, backend_port, addresses, mac,
**kwargs):
port_attrs["network_id"] = net["id"] port_attrs["network_id"] = net["id"]
port_attrs["id"] = port_id port_attrs["id"] = port_id
port_attrs["security_groups"] = security_groups port_attrs["security_groups"] = security_groups
@ -325,7 +329,8 @@ def create_port(context, port):
return new_port return new_port
@cmd_mgr.undo @cmd_mgr.undo
def _allocate_db_port_undo(new_port): def _allocate_db_port_undo(new_port,
**kwargs):
LOG.info("Rolling back database port...") LOG.info("Rolling back database port...")
if not new_port: if not new_port:
return return

View File

@ -1445,12 +1445,13 @@ class TestPortDriverSelection(test_quark_plugin.TestQuarkPlugin):
ipam["FOO"].allocate_ip_address.assert_called_once_with( ipam["FOO"].allocate_ip_address.assert_called_once_with(
admin_ctx, [], network["id"], 1, admin_ctx, [], network["id"], 1,
cfg.CONF.QUARK.ipam_reuse_after, cfg.CONF.QUARK.ipam_reuse_after,
segment_id=None, mac_address=mac) segment_id=None, rollback=False, mac_address=mac)
ipam["FOO"].allocate_mac_address.assert_called_once_with( ipam["FOO"].allocate_mac_address.assert_called_once_with(
admin_ctx, network["id"], 1, admin_ctx, network["id"], 1,
cfg.CONF.QUARK.ipam_reuse_after, cfg.CONF.QUARK.ipam_reuse_after,
mac_address=expected_mac, use_forbidden_mac_range=False) mac_address=expected_mac, rollback=False,
use_forbidden_mac_range=False)
port_create.assert_called_once_with( port_create.assert_called_once_with(
admin_ctx, bridge=self.expected_bridge, uuid=1, name="foobar", admin_ctx, bridge=self.expected_bridge, uuid=1, name="foobar",
@ -1497,12 +1498,14 @@ class TestPortDriverSelection(test_quark_plugin.TestQuarkPlugin):
ipam["FOO"].allocate_ip_address.assert_called_once_with( ipam["FOO"].allocate_ip_address.assert_called_once_with(
admin_ctx, [], network["id"], 1, admin_ctx, [], network["id"], 1,
cfg.CONF.QUARK.ipam_reuse_after, cfg.CONF.QUARK.ipam_reuse_after,
segment_id=None, mac_address=mac) segment_id=None, rollback=False, mac_address=mac)
ipam["FOO"].allocate_mac_address.assert_called_once_with( ipam["FOO"].allocate_mac_address.assert_called_once_with(
admin_ctx, network["id"], 1, admin_ctx, network["id"], 1,
cfg.CONF.QUARK.ipam_reuse_after, cfg.CONF.QUARK.ipam_reuse_after,
mac_address=expected_mac, use_forbidden_mac_range=False) mac_address=expected_mac,
rollback=False,
use_forbidden_mac_range=False)
port_create.assert_called_once_with( port_create.assert_called_once_with(
admin_ctx, bridge=self.expected_bridge, uuid=1, name="foobar", admin_ctx, bridge=self.expected_bridge, uuid=1, name="foobar",
@ -1551,12 +1554,13 @@ class TestPortDriverSelection(test_quark_plugin.TestQuarkPlugin):
ipam["BAR"].allocate_ip_address.assert_called_once_with( ipam["BAR"].allocate_ip_address.assert_called_once_with(
admin_ctx, [], network["id"], 1, admin_ctx, [], network["id"], 1,
cfg.CONF.QUARK.ipam_reuse_after, cfg.CONF.QUARK.ipam_reuse_after,
segment_id=None, mac_address=mac) segment_id=None, rollback=False, mac_address=mac)
ipam["BAR"].allocate_mac_address.assert_called_once_with( ipam["BAR"].allocate_mac_address.assert_called_once_with(
admin_ctx, network["id"], 1, admin_ctx, network["id"], 1,
cfg.CONF.QUARK.ipam_reuse_after, cfg.CONF.QUARK.ipam_reuse_after,
mac_address=expected_mac, use_forbidden_mac_range=False) mac_address=expected_mac, rollback=False,
use_forbidden_mac_range=False)
port_create.assert_called_once_with( port_create.assert_called_once_with(
admin_ctx, bridge=self.expected_bridge, uuid=1, name="foobar", admin_ctx, bridge=self.expected_bridge, uuid=1, name="foobar",
@ -1601,12 +1605,13 @@ class TestPortDriverSelection(test_quark_plugin.TestQuarkPlugin):
ipam["FOO"].allocate_ip_address.assert_called_once_with( ipam["FOO"].allocate_ip_address.assert_called_once_with(
admin_ctx, [], network["id"], 1, admin_ctx, [], network["id"], 1,
cfg.CONF.QUARK.ipam_reuse_after, cfg.CONF.QUARK.ipam_reuse_after,
segment_id=None, mac_address=mac) segment_id=None, rollback=False, mac_address=mac)
ipam["FOO"].allocate_mac_address.assert_called_once_with( ipam["FOO"].allocate_mac_address.assert_called_once_with(
admin_ctx, network["id"], 1, admin_ctx, network["id"], 1,
cfg.CONF.QUARK.ipam_reuse_after, cfg.CONF.QUARK.ipam_reuse_after,
mac_address=expected_mac, use_forbidden_mac_range=False) mac_address=expected_mac, rollback=False,
use_forbidden_mac_range=False)
port_create.assert_called_once_with( port_create.assert_called_once_with(
admin_ctx, bridge=self.expected_bridge, uuid=1, name="foobar", admin_ctx, bridge=self.expected_bridge, uuid=1, name="foobar",
@ -1651,12 +1656,14 @@ class TestPortDriverSelection(test_quark_plugin.TestQuarkPlugin):
ipam["BAR"].allocate_ip_address.assert_called_once_with( ipam["BAR"].allocate_ip_address.assert_called_once_with(
admin_ctx, [], network["id"], 1, admin_ctx, [], network["id"], 1,
cfg.CONF.QUARK.ipam_reuse_after, cfg.CONF.QUARK.ipam_reuse_after,
segment_id=None, mac_address=mac) segment_id=None, rollback=False, mac_address=mac)
ipam["BAR"].allocate_mac_address.assert_called_once_with( ipam["BAR"].allocate_mac_address.assert_called_once_with(
admin_ctx, network["id"], 1, admin_ctx, network["id"], 1,
cfg.CONF.QUARK.ipam_reuse_after, cfg.CONF.QUARK.ipam_reuse_after,
mac_address=expected_mac, use_forbidden_mac_range=False) mac_address=expected_mac,
rollback=False,
use_forbidden_mac_range=False)
port_create.assert_called_once_with( port_create.assert_called_once_with(
admin_ctx, bridge=self.expected_bridge, uuid=1, name="foobar", admin_ctx, bridge=self.expected_bridge, uuid=1, name="foobar",
@ -1736,12 +1743,13 @@ class TestPortDriverSelection(test_quark_plugin.TestQuarkPlugin):
ipam["FOO"].allocate_ip_address.assert_called_once_with( ipam["FOO"].allocate_ip_address.assert_called_once_with(
admin_ctx, [], network["id"], 1, admin_ctx, [], network["id"], 1,
cfg.CONF.QUARK.ipam_reuse_after, cfg.CONF.QUARK.ipam_reuse_after,
segment_id=None, mac_address=mac) segment_id=None, rollback=False, mac_address=mac)
ipam["FOO"].allocate_mac_address.assert_called_once_with( ipam["FOO"].allocate_mac_address.assert_called_once_with(
admin_ctx, network["id"], 1, admin_ctx, network["id"], 1,
cfg.CONF.QUARK.ipam_reuse_after, cfg.CONF.QUARK.ipam_reuse_after,
mac_address=expected_mac, use_forbidden_mac_range=False) mac_address=expected_mac, rollback=False,
use_forbidden_mac_range=False)
port_create.assert_called_once_with( port_create.assert_called_once_with(
admin_ctx, bridge=expected_bridge, uuid=5, name="foobar", admin_ctx, bridge=expected_bridge, uuid=5, name="foobar",
@ -1800,7 +1808,8 @@ class TestQuarkPortCreateFiltering(test_quark_plugin.TestQuarkPlugin):
alloc_mac.assert_called_once_with( alloc_mac.assert_called_once_with(
self.context, network["id"], 1, self.context, network["id"], 1,
cfg.CONF.QUARK.ipam_reuse_after, cfg.CONF.QUARK.ipam_reuse_after,
mac_address=None, use_forbidden_mac_range=False) mac_address=None, rollback=False,
use_forbidden_mac_range=False)
port_create.assert_called_once_with( port_create.assert_called_once_with(
self.context, addresses=[], network_id=network["id"], self.context, addresses=[], network_id=network["id"],
tenant_id="fake", uuid=1, name="foobar", tenant_id="fake", uuid=1, name="foobar",
@ -1841,7 +1850,9 @@ class TestQuarkPortCreateFiltering(test_quark_plugin.TestQuarkPlugin):
alloc_mac.assert_called_once_with( alloc_mac.assert_called_once_with(
admin_ctx, network["id"], 1, admin_ctx, network["id"], 1,
cfg.CONF.QUARK.ipam_reuse_after, cfg.CONF.QUARK.ipam_reuse_after,
mac_address=expected_mac, use_forbidden_mac_range=False) mac_address=expected_mac,
rollback=False,
use_forbidden_mac_range=False)
port_create.assert_called_once_with( port_create.assert_called_once_with(
admin_ctx, bridge=expected_bridge, uuid=1, name="foobar", admin_ctx, bridge=expected_bridge, uuid=1, name="foobar",

View File

@ -21,6 +21,7 @@ from quark import billing
from quark.db.models import IPAddress from quark.db.models import IPAddress
from quark import network_strategy from quark import network_strategy
from quark.tests import test_base from quark.tests import test_base
from quark import utils
class QuarkBillingBaseTest(test_base.TestBase): class QuarkBillingBaseTest(test_base.TestBase):
@ -185,3 +186,30 @@ class QuarkBillingEnvironmentCapabilityTest(QuarkBillingBaseTest):
billing.notify(self.context, billing.IP_ADD, ipaddress) billing.notify(self.context, billing.IP_ADD, ipaddress)
self.assertFalse(notifier.called) self.assertFalse(notifier.called)
cfg.CONF.clear_override('environment_capabilities', 'QUARK') cfg.CONF.clear_override('environment_capabilities', 'QUARK')
@mock.patch('neutron.common.rpc.get_notifier')
def test_do_not_notify_in_undo_cmd_mgr(self, notifier):
"""Wraps a call to notify in CommandManager's rollback()"""
cfg.CONF.set_override('environment_capabilities',
'security_groups,ip_billing',
'QUARK')
ipaddress = get_fake_fixed_address()
ipaddress.allocated_at = datetime.datetime.utcnow()
try:
with utils.CommandManager().execute() as cmd_mgr:
@cmd_mgr.do
def f():
raise Exception
@cmd_mgr.undo
def f_undo(*args, **kwargs):
billing.notify(self.context, billing.IP_ADD, ipaddress,
*args, **kwargs)
f()
except Exception:
pass
# notifier should NOT have been called
self.assertFalse(notifier.called)
cfg.CONF.clear_override('environment_capabilities', 'QUARK')

View File

@ -0,0 +1,121 @@
# Copyright (c) 2016 OpenStack Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mock
from quark.tests import test_base
from quark import utils
def func_do(**kwargs):
if 'rollback' in kwargs and kwargs['rollback']:
return True
else:
return False
def func_undo(**kwargs):
if 'rollback' in kwargs and kwargs['rollback']:
return True
else:
return False
class QuarkCommandManagerTest(test_base.TestBase):
def setUp(self):
super(QuarkCommandManagerTest, self).setUp()
# Using globals here because of callbacks
self.is_rollback_do = False
self.is_rollback_undo = True
@mock.patch('quark.tests.test_command_manager.func_undo')
@mock.patch('quark.tests.test_command_manager.func_do')
def test_command_manager_no_undo(self,
func_do_notifier,
func_undo_notifier):
"""Test that undo is not called when everything is good"""
try:
with utils.CommandManager().execute() as cmd_mgr:
@cmd_mgr.do
def f(**kwargs):
func_do(**kwargs)
@cmd_mgr.undo
def f_undo(*args, **kwargs):
func_undo(**kwargs)
f()
except Exception:
pass
self.assertTrue(func_do_notifier.called)
self.assertFalse(func_undo_notifier.called)
@mock.patch('quark.tests.test_command_manager.func_undo')
def test_command_manager_undo(self, func_undo_notifier):
"""Test that undo is called when the do function raises"""
try:
with utils.CommandManager().execute() as cmd_mgr:
@cmd_mgr.do
def f(**kwargs):
raise Exception
@cmd_mgr.undo
def f_undo(*args, **kwargs):
func_undo(**kwargs)
f()
except Exception:
pass
self.assertTrue(func_undo_notifier.called)
def test_rollback_is_passed_to_do(self):
"""Tests that the do function has rollback set to False"""
self.is_rollback_do = True
try:
with utils.CommandManager().execute() as cmd_mgr:
@cmd_mgr.do
def f(**kwargs):
return func_do(**kwargs)
@cmd_mgr.undo
def f_undo(*args, **kwargs):
func_undo(**kwargs)
self.is_rollback_do = f()
except Exception:
pass
self.assertFalse(self.is_rollback_do)
def test_rollback_is_passed_to_undo(self):
"""Tests that the undo function has rollback set to True"""
self.is_rollback_undo = False
try:
with utils.CommandManager().execute() as cmd_mgr:
@cmd_mgr.do
def f(**kwargs):
raise Exception
@cmd_mgr.undo
def f_undo(*args, **kwargs):
self.is_rollback_undo = func_undo(**kwargs)
f()
except Exception:
pass
self.assertTrue(self.is_rollback_undo)

View File

@ -130,9 +130,11 @@ class Command(object):
self.func = func self.func = func
self.result = None self.result = None
self.called = False self.called = False
self.is_rollback = False
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
self.called = True self.called = True
kwargs['rollback'] = self.is_rollback
self.result = self.func(*args, **kwargs) self.result = self.func(*args, **kwargs)
return self.result return self.result
@ -159,6 +161,7 @@ class CommandManager(object):
def undo(self, func): def undo(self, func):
cmd = Command(func) cmd = Command(func)
cmd.is_rollback = True
self.undo_commands.append(cmd) self.undo_commands.append(cmd)
return cmd return cmd