From efacfce94c52c91b17677188af55cd1c97d0088d Mon Sep 17 00:00:00 2001 From: Julien Danjou Date: Wed, 22 Mar 2017 13:09:29 +0100 Subject: [PATCH] sqlalchemy: use nested transaction when getting/creating event types If adding the database fails, not using a nested transaction will make the whole transaction passed from the caller fail. The code does not handle that at all. This switches to using a nested transaction, so only the insert is rolled-back if it fails. Change-Id: I5196147524e5fdd0d46d4c1995a8afb964ce3f6f (cherry picked from commit e4021dbeac9c3d72497f6de90a748c6b9c8165fc) --- panko/storage/impl_sqlalchemy.py | 6 ++---- .../functional/storage/test_impl_sqlalchemy.py | 18 ++++++++++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/panko/storage/impl_sqlalchemy.py b/panko/storage/impl_sqlalchemy.py index e5bbeb4b..521e268e 100644 --- a/panko/storage/impl_sqlalchemy.py +++ b/panko/storage/impl_sqlalchemy.py @@ -146,15 +146,13 @@ class Connection(base.Connection): engine.execute(table.delete()) engine.dispose() - def _get_or_create_event_type(self, event_type, session=None): + def _get_or_create_event_type(self, event_type, session): """Check if an event type with the supplied name is already exists. If not, we create it and return the record. This may result in a flush. """ try: - if session is None: - session = self._engine_facade.get_session() - with session.begin(subtransactions=True): + with session.begin(nested=True): et = session.query(models.EventType).filter( models.EventType.desc == event_type).first() if not et: diff --git a/panko/tests/functional/storage/test_impl_sqlalchemy.py b/panko/tests/functional/storage/test_impl_sqlalchemy.py index 5bc07417..dd64b844 100644 --- a/panko/tests/functional/storage/test_impl_sqlalchemy.py +++ b/panko/tests/functional/storage/test_impl_sqlalchemy.py @@ -43,22 +43,32 @@ class EventTypeTest(tests_db.TestBase): # EventType is a construct specific to sqlalchemy # Not applicable to other drivers. + def setUp(self): + super(EventTypeTest, self).setUp() + self.session = self.conn._engine_facade.get_session() + self.session.begin() + def test_event_type_exists(self): - et1 = self.conn._get_or_create_event_type("foo") + et1 = self.conn._get_or_create_event_type("foo", self.session) self.assertTrue(et1.id >= 0) - et2 = self.conn._get_or_create_event_type("foo") + et2 = self.conn._get_or_create_event_type("foo", self.session) self.assertEqual(et2.id, et1.id) self.assertEqual(et2.desc, et1.desc) def test_event_type_unique(self): - et1 = self.conn._get_or_create_event_type("foo") + et1 = self.conn._get_or_create_event_type("foo", self.session) self.assertTrue(et1.id >= 0) - et2 = self.conn._get_or_create_event_type("blah") + et2 = self.conn._get_or_create_event_type("blah", self.session) self.assertNotEqual(et1.id, et2.id) self.assertNotEqual(et1.desc, et2.desc) # Test the method __repr__ returns a string self.assertTrue(reprlib.repr(et2)) + def tearDown(self): + self.session.rollback() + self.session.close() + super(EventTypeTest, self).tearDown() + @tests_db.run_with('sqlite', 'mysql', 'pgsql') class EventTest(tests_db.TestBase):