diff --git a/panko/storage/impl_sqlalchemy.py b/panko/storage/impl_sqlalchemy.py index e5bbeb4b..521e268e 100644 --- a/panko/storage/impl_sqlalchemy.py +++ b/panko/storage/impl_sqlalchemy.py @@ -146,15 +146,13 @@ class Connection(base.Connection): engine.execute(table.delete()) engine.dispose() - def _get_or_create_event_type(self, event_type, session=None): + def _get_or_create_event_type(self, event_type, session): """Check if an event type with the supplied name is already exists. If not, we create it and return the record. This may result in a flush. """ try: - if session is None: - session = self._engine_facade.get_session() - with session.begin(subtransactions=True): + with session.begin(nested=True): et = session.query(models.EventType).filter( models.EventType.desc == event_type).first() if not et: diff --git a/panko/tests/functional/storage/test_impl_sqlalchemy.py b/panko/tests/functional/storage/test_impl_sqlalchemy.py index 5bc07417..dd64b844 100644 --- a/panko/tests/functional/storage/test_impl_sqlalchemy.py +++ b/panko/tests/functional/storage/test_impl_sqlalchemy.py @@ -43,22 +43,32 @@ class EventTypeTest(tests_db.TestBase): # EventType is a construct specific to sqlalchemy # Not applicable to other drivers. + def setUp(self): + super(EventTypeTest, self).setUp() + self.session = self.conn._engine_facade.get_session() + self.session.begin() + def test_event_type_exists(self): - et1 = self.conn._get_or_create_event_type("foo") + et1 = self.conn._get_or_create_event_type("foo", self.session) self.assertTrue(et1.id >= 0) - et2 = self.conn._get_or_create_event_type("foo") + et2 = self.conn._get_or_create_event_type("foo", self.session) self.assertEqual(et2.id, et1.id) self.assertEqual(et2.desc, et1.desc) def test_event_type_unique(self): - et1 = self.conn._get_or_create_event_type("foo") + et1 = self.conn._get_or_create_event_type("foo", self.session) self.assertTrue(et1.id >= 0) - et2 = self.conn._get_or_create_event_type("blah") + et2 = self.conn._get_or_create_event_type("blah", self.session) self.assertNotEqual(et1.id, et2.id) self.assertNotEqual(et1.desc, et2.desc) # Test the method __repr__ returns a string self.assertTrue(reprlib.repr(et2)) + def tearDown(self): + self.session.rollback() + self.session.close() + super(EventTypeTest, self).tearDown() + @tests_db.run_with('sqlite', 'mysql', 'pgsql') class EventTest(tests_db.TestBase):