sqlalchemy: use nested transaction when getting/creating event types

If adding the database fails, not using a nested transaction will make the
whole transaction passed from the caller fail. The code does not handle that at
all.

This switches to using a nested transaction, so only the insert is rolled-back
if it fails.

(cherry-picked from e4021dbeac9c3d72497f6de90a748c6b9c8165fc from panko)

Change-Id: I937b844cdaf543b128cd00a7916efd829df26ce0
This commit is contained in:
Julien Danjou 2017-06-19 16:57:05 +02:00
parent 1d2d6106d2
commit ffc9e56c3c
2 changed files with 16 additions and 8 deletions

View File

@ -163,15 +163,13 @@ class Connection(base.Connection):
engine.execute(table.delete()) engine.execute(table.delete())
engine.dispose() engine.dispose()
def _get_or_create_event_type(self, event_type, session=None): def _get_or_create_event_type(self, event_type, session):
"""Check if an event type with the supplied name is already exists. """Check if an event type with the supplied name is already exists.
If not, we create it and return the record. This may result in a flush. If not, we create it and return the record. This may result in a flush.
""" """
try: try:
if session is None: with session.begin(nested=True):
session = self._engine_facade.get_session()
with session.begin(subtransactions=True):
et = session.query(models.EventType).filter( et = session.query(models.EventType).filter(
models.EventType.desc == event_type).first() models.EventType.desc == event_type).first()
if not et: if not et:

View File

@ -63,22 +63,32 @@ class EventTypeTest(tests_db.TestBase):
# EventType is a construct specific to sqlalchemy # EventType is a construct specific to sqlalchemy
# Not applicable to other drivers. # Not applicable to other drivers.
def setUp(self):
super(EventTypeTest, self).setUp()
self.session = self.conn._engine_facade.get_session()
self.session.begin()
def test_event_type_exists(self): def test_event_type_exists(self):
et1 = self.event_conn._get_or_create_event_type("foo") et1 = self.event_conn._get_or_create_event_type("foo", self.session)
self.assertTrue(et1.id >= 0) self.assertTrue(et1.id >= 0)
et2 = self.event_conn._get_or_create_event_type("foo") et2 = self.event_conn._get_or_create_event_type("foo", self.session)
self.assertEqual(et2.id, et1.id) self.assertEqual(et2.id, et1.id)
self.assertEqual(et2.desc, et1.desc) self.assertEqual(et2.desc, et1.desc)
def test_event_type_unique(self): def test_event_type_unique(self):
et1 = self.event_conn._get_or_create_event_type("foo") et1 = self.event_conn._get_or_create_event_type("foo", self.session)
self.assertTrue(et1.id >= 0) self.assertTrue(et1.id >= 0)
et2 = self.event_conn._get_or_create_event_type("blah") et2 = self.event_conn._get_or_create_event_type("blah", self.session)
self.assertNotEqual(et1.id, et2.id) self.assertNotEqual(et1.id, et2.id)
self.assertNotEqual(et1.desc, et2.desc) self.assertNotEqual(et1.desc, et2.desc)
# Test the method __repr__ returns a string # Test the method __repr__ returns a string
self.assertTrue(reprlib.repr(et2)) self.assertTrue(reprlib.repr(et2))
def tearDown(self):
self.session.rollback()
self.session.close()
super(EventTypeTest, self).tearDown()
@tests_db.run_with('sqlite', 'mysql', 'pgsql') @tests_db.run_with('sqlite', 'mysql', 'pgsql')
class EventTest(tests_db.TestBase): class EventTest(tests_db.TestBase):