sqlalchemy: use nested transaction when getting/creating event types

If adding the database fails, not using a nested transaction will make the
whole transaction passed from the caller fail. The code does not handle that at
all.

This switches to using a nested transaction, so only the insert is rolled-back
if it fails.

(cherry-picked from e4021dbeac9c3d72497f6de90a748c6b9c8165fc from panko)

Change-Id: I937b844cdaf543b128cd00a7916efd829df26ce0
This commit is contained in:
Julien Danjou 2017-06-19 16:57:05 +02:00
parent 1d2d6106d2
commit ffc9e56c3c
2 changed files with 16 additions and 8 deletions

View File

@ -163,15 +163,13 @@ class Connection(base.Connection):
engine.execute(table.delete())
engine.dispose()
def _get_or_create_event_type(self, event_type, session=None):
def _get_or_create_event_type(self, event_type, session):
"""Check if an event type with the supplied name is already exists.
If not, we create it and return the record. This may result in a flush.
"""
try:
if session is None:
session = self._engine_facade.get_session()
with session.begin(subtransactions=True):
with session.begin(nested=True):
et = session.query(models.EventType).filter(
models.EventType.desc == event_type).first()
if not et:

View File

@ -63,22 +63,32 @@ class EventTypeTest(tests_db.TestBase):
# EventType is a construct specific to sqlalchemy
# Not applicable to other drivers.
def setUp(self):
super(EventTypeTest, self).setUp()
self.session = self.conn._engine_facade.get_session()
self.session.begin()
def test_event_type_exists(self):
et1 = self.event_conn._get_or_create_event_type("foo")
et1 = self.event_conn._get_or_create_event_type("foo", self.session)
self.assertTrue(et1.id >= 0)
et2 = self.event_conn._get_or_create_event_type("foo")
et2 = self.event_conn._get_or_create_event_type("foo", self.session)
self.assertEqual(et2.id, et1.id)
self.assertEqual(et2.desc, et1.desc)
def test_event_type_unique(self):
et1 = self.event_conn._get_or_create_event_type("foo")
et1 = self.event_conn._get_or_create_event_type("foo", self.session)
self.assertTrue(et1.id >= 0)
et2 = self.event_conn._get_or_create_event_type("blah")
et2 = self.event_conn._get_or_create_event_type("blah", self.session)
self.assertNotEqual(et1.id, et2.id)
self.assertNotEqual(et1.desc, et2.desc)
# Test the method __repr__ returns a string
self.assertTrue(reprlib.repr(et2))
def tearDown(self):
self.session.rollback()
self.session.close()
super(EventTypeTest, self).tearDown()
@tests_db.run_with('sqlite', 'mysql', 'pgsql')
class EventTest(tests_db.TestBase):