diff --git a/ovsdbapp/api.py b/ovsdbapp/api.py index b0777ced..dd0b9243 100644 --- a/ovsdbapp/api.py +++ b/ovsdbapp/api.py @@ -16,6 +16,11 @@ import abc import contextlib import six +try: + # Python 3 no longer has thread module + import thread # noqa +except ImportError: + import threading as thread @six.add_metaclass(abc.ABCMeta) @@ -67,7 +72,8 @@ class Transaction(object): @six.add_metaclass(abc.ABCMeta) class API(object): def __init__(self): - self._nested_txn = None + # Mapping between a (green)thread and its transaction. + self._nested_txns_map = {} @abc.abstractmethod def create_transaction(self, check_error=False, log_errors=True, **kwargs): @@ -92,16 +98,18 @@ class API(object): :returns: Either a new transaction or an existing one. :rtype: :class:`Transaction` """ - if self._nested_txn: - yield self._nested_txn - else: + cur_thread_id = thread.get_ident() + + try: + yield self._nested_txns_map[cur_thread_id] + except KeyError: with self.create_transaction( check_error, log_errors, **kwargs) as txn: - self._nested_txn = txn + self._nested_txns_map[cur_thread_id] = txn try: yield txn finally: - self._nested_txn = None + del self._nested_txns_map[cur_thread_id] @abc.abstractmethod def db_create(self, table, **col_values): diff --git a/ovsdbapp/tests/unit/test_api.py b/ovsdbapp/tests/unit/test_api.py index 213c4299..09b5e775 100644 --- a/ovsdbapp/tests/unit/test_api.py +++ b/ovsdbapp/tests/unit/test_api.py @@ -14,10 +14,24 @@ import mock import testtools +import time from ovsdbapp import api from ovsdbapp.tests import base +try: + import eventlet + + def create_thread(executable): + eventlet.spawn_n(executable) + +except ImportError: + import threading + + def create_thread(executable): + thread = threading.Thread(target=executable) + thread.start() + class FakeTransaction(object): def __enter__(self): @@ -60,3 +74,30 @@ class TransactionTestCase(base.TestCase): with self.api.transaction() as txn2: self.assertIsNot(txn1, txn2) + + def test_transaction_nested_multiple_threads(self): + shared_resource = [] + + def thread1(): + with self.api.transaction() as txn: + shared_resource.append(txn) + while len(shared_resource) == 1: + time.sleep(0.1) + shared_resource.append(0) + + def thread2(): + while len(shared_resource) != 1: + time.sleep(0.1) + with self.api.transaction() as txn: + shared_resource.append(txn) + shared_resource.append(0) + + create_thread(thread1) + create_thread(thread2) + + while len(shared_resource) != 4: + time.sleep(0.1) + + txn1, txn2 = shared_resource[:2] + + self.assertNotEqual(txn1, txn2)