Add synchronized_domain decorator

As a temporary fix for bug 1392762, we serialize concurrent
modifications to a domain, preventing the issue described in
the bug. We choose a temporary fix, as pools will make a no
locking fix significantly easier.

Closes-Bug: 1392762
Change-Id: Ifb1bba170983023aedbc63ed3559fe8b28359efa
This commit is contained in:
Kiall Mac Innes 2014-11-17 13:41:59 +00:00
parent 06a7629812
commit 768ee18830
1 changed files with 106 additions and 15 deletions

View File

@ -17,10 +17,13 @@
import re
import contextlib
import functools
import threading
import itertools
from oslo.config import cfg
from oslo import messaging
from oslo.utils import excutils
from oslo.concurrency import lockutils
from designate.openstack.common import log as logging
from designate.i18n import _LI
@ -38,6 +41,7 @@ from designate import storage
LOG = logging.getLogger(__name__)
DOMAIN_LOCKS = threading.local()
@contextlib.contextmanager
@ -69,6 +73,79 @@ def transaction(f):
return wrapper
def synchronized_domain(domain_arg=1, new_domain=False):
"""Ensures only a single operation is in progress for each domain
A Decorator which ensures only a single operation can be happening
on a single domain at once, within the current designate-central instance
"""
def outer(f):
@functools.wraps(f)
def wrapper(self, *args, **kwargs):
if not hasattr(DOMAIN_LOCKS, 'held'):
# Create the held set if necessary
DOMAIN_LOCKS.held = set()
domain_id = None
if 'domain_id' in kwargs:
domain_id = kwargs['domain_id']
elif 'domain' in kwargs:
domain_id = kwargs['domain'].id
elif 'recordset' in kwargs:
domain_id = kwargs['recordset'].domain_id
elif 'record' in kwargs:
domain_id = kwargs['record'].domain_id
# The various objects won't always have an ID set, we should
# attempt to locate an Object containing the ID.
if domain_id is None:
for arg in itertools.chain(kwargs.values(), args):
if isinstance(arg, objects.Domain):
domain_id = arg.id
if domain_id is not None:
break
elif (isinstance(arg, objects.RecordSet) or
isinstance(arg, objects.Record)):
domain_id = arg.domain_id
if domain_id is not None:
break
# If we still don't have an ID, find the Nth argument as
# defined by the domain_arg decorator option.
if domain_id is None and len(args) > domain_arg:
domain_id = args[domain_arg]
if isinstance(domain_id, objects.Domain):
# If the value is a Domain object, extract it's ID.
domain_id = domain_id.id
if not new_domain and domain_id is None:
raise Exception('Failed to determine domain id for '
'synchronized operation')
if domain_id in DOMAIN_LOCKS.held:
# Call the wrapped function
return f(self, *args, **kwargs)
else:
with lockutils.lock('domain-%s' % domain_id):
DOMAIN_LOCKS.held.add(domain_id)
# Call the wrapped function
result = f(self, *args, **kwargs)
DOMAIN_LOCKS.held.remove(domain_id)
return result
return wrapper
return outer
class Service(service.RPCService):
RPC_API_VERSION = '4.2'
@ -335,8 +412,10 @@ class Service(service.RPCService):
'type': "SOA",
'records': recordlist
}
soa = self.create_recordset(context, zone['id'],
objects.RecordSet(**values), False)
soa = self.create_recordset(context,
domain_id=zone['id'],
recordset=objects.RecordSet(**values),
increment_serial=False)
return soa
def _update_soa(self, context, zone):
@ -368,8 +447,10 @@ class Service(service.RPCService):
'type': "NS",
'records': recordlist
}
ns = self.create_recordset(context, zone['id'],
objects.RecordSet(**values), False)
ns = self.create_recordset(context,
domain_id=zone['id'],
recordset=objects.RecordSet(**values),
increment_serial=False)
return ns
@ -684,6 +765,7 @@ class Service(service.RPCService):
return self.storage.count_tenants(context)
# Domain Methods
@synchronized_domain(new_domain=True)
@transaction
def create_domain(self, context, domain):
# TODO(kiall): Refactor this method into *MUCH* smaller chunks.
@ -829,6 +911,7 @@ class Service(service.RPCService):
return self.storage.find_domain(context, criterion)
@synchronized_domain()
@transaction
def update_domain(self, context, domain, increment_serial=True):
# TODO(kiall): Refactor this method into *MUCH* smaller chunks.
@ -880,6 +963,7 @@ class Service(service.RPCService):
return domain
@synchronized_domain()
@transaction
def delete_domain(self, context, domain_id):
domain = self.storage.get_domain(context, domain_id)
@ -940,6 +1024,7 @@ class Service(service.RPCService):
return reports
@synchronized_domain()
@transaction
def touch_domain(self, context, domain_id):
domain = self.storage.get_domain(context, domain_id)
@ -959,6 +1044,7 @@ class Service(service.RPCService):
return domain
# RecordSet Methods
@synchronized_domain()
@transaction
def create_recordset(self, context, domain_id, recordset,
increment_serial=True):
@ -1044,6 +1130,7 @@ class Service(service.RPCService):
return recordset
@synchronized_domain()
@transaction
def update_recordset(self, context, recordset, increment_serial=True):
domain_id = recordset.obj_get_original_value('domain_id')
@ -1099,6 +1186,7 @@ class Service(service.RPCService):
return recordset
@synchronized_domain()
@transaction
def delete_recordset(self, context, domain_id, recordset_id,
increment_serial=True):
@ -1144,6 +1232,7 @@ class Service(service.RPCService):
return self.storage.count_recordsets(context, criterion)
# Record Methods
@synchronized_domain()
@transaction
def create_record(self, context, domain_id, recordset_id, record,
increment_serial=True):
@ -1224,6 +1313,7 @@ class Service(service.RPCService):
return self.storage.find_record(context, criterion)
@synchronized_domain()
@transaction
def update_record(self, context, record, increment_serial=True):
domain_id = record.obj_get_original_value('domain_id')
@ -1272,6 +1362,7 @@ class Service(service.RPCService):
return record
@synchronized_domain()
@transaction
def delete_record(self, context, domain_id, recordset_id, record_id,
increment_serial=True):
@ -1626,9 +1717,9 @@ class Service(service.RPCService):
for record in records:
self.delete_record(
elevated_context,
rset['domain_id'],
rset['id'],
record['id'])
domain_id=rset['domain_id'],
recordset_id=rset['id'],
record_id=record['id'])
self.delete_recordset(elevated_context, zone['id'], rset['id'])
except exceptions.RecordSetNotFound:
pass
@ -1641,8 +1732,8 @@ class Service(service.RPCService):
recordset = self.create_recordset(
elevated_context,
zone['id'],
objects.RecordSet(**recordset_values))
domain_id=zone['id'],
recordset=objects.RecordSet(**recordset_values))
record_values = {
'data': values['ptrdname'],
@ -1657,9 +1748,9 @@ class Service(service.RPCService):
record = self.create_record(
elevated_context,
zone['id'],
recordset['id'],
objects.Record(**record_values))
domain_id=zone['id'],
recordset_id=recordset['id'],
record=objects.Record(**record_values))
mangled = self._format_floatingips(
context, {(region, floatingip_id): (fip, record)},
@ -1693,9 +1784,9 @@ class Service(service.RPCService):
self.delete_record(
elevated_context,
record['domain_id'],
record['recordset_id'],
record['id'])
domain_id=record['domain_id'],
recordset_id=record['recordset_id'],
record_id=record['id'])
@transaction
def update_floatingip(self, context, region, floatingip_id, values):