diff --git a/ironic_inspector/db.py b/ironic_inspector/db.py index 042ff06e1..99bc41302 100644 --- a/ironic_inspector/db.py +++ b/ironic_inspector/db.py @@ -17,10 +17,11 @@ import contextlib +from oslo_concurrency import lockutils from oslo_config import cfg from oslo_db import options as db_opts +from oslo_db.sqlalchemy import enginefacade from oslo_db.sqlalchemy import models -from oslo_db.sqlalchemy import session as db_session from oslo_db.sqlalchemy import types as db_types from sqlalchemy import (Boolean, Column, DateTime, Enum, ForeignKey, Integer, String, Text) @@ -39,9 +40,11 @@ class ModelBase(models.ModelBase): Base = declarative_base(cls=ModelBase) CONF = cfg.CONF _DEFAULT_SQL_CONNECTION = 'sqlite:///ironic_inspector.sqlite' -_FACADE = None +_CTX_MANAGER = None -db_opts.set_defaults(cfg.CONF, connection=_DEFAULT_SQL_CONNECTION) +db_opts.set_defaults(CONF, connection=_DEFAULT_SQL_CONNECTION) + +_synchronized = lockutils.synchronized_with_prefix("ironic-inspector-") class Node(Base): @@ -131,18 +134,12 @@ class RuleAction(Base): def init(): - """Initialize the database.""" - return get_session() + """Initialize the database. - -def get_session(**kwargs): - facade = create_facade_lazily() - return facade.get_session(**kwargs) - - -def get_engine(): - facade = create_facade_lazily() - return facade.get_engine() + Method called on service start up, initialize transaction + context manager and try to create db session. + """ + get_writer_session() def model_query(model, *args, **kwargs): @@ -150,21 +147,51 @@ def model_query(model, *args, **kwargs): :param session: if present, the session to use """ - - session = kwargs.get('session') or get_session() + session = kwargs.get('session') or get_reader_session() query = session.query(model, *args) return query -def create_facade_lazily(): - global _FACADE - if _FACADE is None: - _FACADE = db_session.EngineFacade.from_config(cfg.CONF) - return _FACADE - - @contextlib.contextmanager def ensure_transaction(session=None): - session = session or get_session() + session = session or get_writer_session() with session.begin(subtransactions=True): yield session + + +@_synchronized("transaction-context-manager") +def _create_context_manager(): + _ctx_mgr = enginefacade.transaction_context() + # TODO(aarefiev): enable foreign keys for SQLite once all unit + # tests with failed constraint will be fixed. + _ctx_mgr.configure(sqlite_fk=False) + + return _ctx_mgr + + +def get_context_manager(): + """Create transaction context manager lazily. + + :returns: The transaction context manager. + """ + global _CTX_MANAGER + if _CTX_MANAGER is None: + _CTX_MANAGER = _create_context_manager() + + return _CTX_MANAGER + + +def get_reader_session(): + """Help method to get reader session. + + :returns: The reader session. + """ + return get_context_manager().reader.get_sessionmaker()() + + +def get_writer_session(): + """Help method to get writer session. + + :returns: The writer session. + """ + return get_context_manager().writer.get_sessionmaker()() diff --git a/ironic_inspector/migrations/env.py b/ironic_inspector/migrations/env.py index 775d1f979..71b186f07 100644 --- a/ironic_inspector/migrations/env.py +++ b/ironic_inspector/migrations/env.py @@ -66,8 +66,8 @@ def run_migrations_online(): and associate a connection with the context. """ - connectable = db.create_facade_lazily().get_engine() - with connectable.connect() as connection: + session = db.get_writer_session() + with session.connection() as connection: context.configure( connection=connection, target_metadata=target_metadata diff --git a/ironic_inspector/test/base.py b/ironic_inspector/test/base.py index 69ce18a53..9142ef533 100644 --- a/ironic_inspector/test/base.py +++ b/ironic_inspector/test/base.py @@ -45,11 +45,11 @@ class BaseTest(test_base.BaseTestCase): super(BaseTest, self).setUp() if not self.IS_FUNCTIONAL: self.init_test_conf() - self.session = db.get_session() - engine = db.get_engine() + self.session = db.get_writer_session() + engine = self.session.get_bind() db.Base.metadata.create_all(engine) engine.connect() - self.addCleanup(db.get_engine().dispose) + self.addCleanup(engine.dispose) plugins_base._HOOKS_MGR = None node_cache._SEMAPHORES = lockutils.Semaphores() patch = mock.patch.object(i18n, '_', lambda s: s) diff --git a/ironic_inspector/test/unit/test_db.py b/ironic_inspector/test/unit/test_db.py new file mode 100644 index 000000000..7b2a445c8 --- /dev/null +++ b/ironic_inspector/test/unit/test_db.py @@ -0,0 +1,77 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import mock + +from ironic_inspector import db +from ironic_inspector.test import base as test_base + + +class TestDB(test_base.NodeTest): + @mock.patch.object(db, 'get_reader_session', autospec=True) + def test_model_query(self, mock_reader): + mock_session = mock_reader.return_value + fake_query = mock_session.query.return_value + + query = db.model_query('db.Node') + + mock_reader.assert_called_once_with() + mock_session.query.assert_called_once_with('db.Node') + self.assertEqual(fake_query, query) + + @mock.patch.object(db, 'get_writer_session', autospec=True) + def test_ensure_transaction_new_session(self, mock_writer): + mock_session = mock_writer.return_value + + with db.ensure_transaction() as session: + mock_writer.assert_called_once_with() + mock_session.begin.assert_called_once_with(subtransactions=True) + self.assertEqual(mock_session, session) + + @mock.patch.object(db, 'get_writer_session', autospec=True) + def test_ensure_transaction_session(self, mock_writer): + mock_session = mock.MagicMock() + + with db.ensure_transaction(session=mock_session) as session: + self.assertFalse(mock_writer.called) + mock_session.begin.assert_called_once_with(subtransactions=True) + self.assertEqual(mock_session, session) + + @mock.patch.object(db.enginefacade, 'transaction_context', autospec=True) + def test__create_context_manager(self, mock_cnxt): + mock_ctx_mgr = mock_cnxt.return_value + + ctx_mgr = db._create_context_manager() + + mock_ctx_mgr.configure.assert_called_once_with(sqlite_fk=False) + self.assertEqual(mock_ctx_mgr, ctx_mgr) + + @mock.patch.object(db, 'get_context_manager', autospec=True) + def test_get_reader_session(self, mock_cnxt_mgr): + mock_cnxt = mock_cnxt_mgr.return_value + mock_sess_maker = mock_cnxt.reader.get_sessionmaker.return_value + + session = db.get_reader_session() + + mock_sess_maker.assert_called_once_with() + self.assertEqual(mock_sess_maker.return_value, session) + + @mock.patch.object(db, 'get_context_manager', autospec=True) + def test_get_writer_session(self, mock_cnxt_mgr): + mock_cnxt = mock_cnxt_mgr.return_value + mock_sess_maker = mock_cnxt.writer.get_sessionmaker.return_value + + session = db.get_writer_session() + + mock_sess_maker.assert_called_once_with() + self.assertEqual(mock_sess_maker.return_value, session) diff --git a/ironic_inspector/test/unit/test_migrations.py b/ironic_inspector/test/unit/test_migrations.py index 407fdacb9..dd7fa37a0 100644 --- a/ironic_inspector/test/unit/test_migrations.py +++ b/ironic_inspector/test/unit/test_migrations.py @@ -31,6 +31,7 @@ from alembic import script import mock from oslo_config import cfg from oslo_db.sqlalchemy.migration_cli import ext_alembic +from oslo_db.sqlalchemy import orm from oslo_db.sqlalchemy import test_base from oslo_db.sqlalchemy import test_migrations from oslo_db.sqlalchemy import utils as db_utils @@ -82,18 +83,12 @@ def _is_backend_avail(backend, user, passwd, database): return True -class FakeFacade(object): - def __init__(self, engine): - self.engine = engine - - def get_engine(self): - return self.engine - - @contextlib.contextmanager def patch_with_engine(engine): - with mock.patch.object(db, 'create_facade_lazily') as patch_engine: - patch_engine.return_value = FakeFacade(engine) + with mock.patch.object(db, 'get_writer_session') as patch_w_sess, \ + mock.patch.object(db, 'get_reader_session') as patch_r_sess: + patch_w_sess.return_value = patch_r_sess.return_value = ( + orm.get_maker(engine)()) yield diff --git a/ironic_inspector/test/unit/test_node_cache.py b/ironic_inspector/test/unit/test_node_cache.py index ff0f1f79f..6ce9bab9b 100644 --- a/ironic_inspector/test/unit/test_node_cache.py +++ b/ironic_inspector/test/unit/test_node_cache.py @@ -38,7 +38,7 @@ class TestNodeCache(test_base.NodeTest): def test_add_node(self): # Ensure previous node information is cleared uuid2 = uuidutils.generate_uuid() - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.Node(uuid=self.node.uuid, state=istate.States.starting).save(session) @@ -75,7 +75,7 @@ class TestNodeCache(test_base.NodeTest): [(row.name, row.value, row.node_uuid) for row in res]) def test__delete_node(self): - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.Node(uuid=self.node.uuid, state=istate.States.finished).save(session) @@ -88,7 +88,7 @@ class TestNodeCache(test_base.NodeTest): session) node_cache._delete_node(self.uuid) - session = db.get_session() + session = db.get_writer_session() row_node = db.model_query(db.Node).filter_by( uuid=self.uuid).first() self.assertIsNone(row_node) @@ -108,7 +108,7 @@ class TestNodeCache(test_base.NodeTest): uuid2 = uuidutils.generate_uuid() uuids = {self.uuid} mock__list_node_uuids.return_value = {self.uuid, uuid2} - session = db.get_session() + session = db.get_writer_session() with session.begin(): node_cache.delete_nodes_not_in_list(uuids) mock__delete_node.assert_called_once_with(uuid2) @@ -116,7 +116,7 @@ class TestNodeCache(test_base.NodeTest): mock__get_lock_ctx.return_value.__enter__.assert_called_once_with() def test_active_macs(self): - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.Node(uuid=self.node.uuid, state=istate.States.starting).save(session) @@ -129,7 +129,7 @@ class TestNodeCache(test_base.NodeTest): node_cache.active_macs()) def test__list_node_uuids(self): - session = db.get_session() + session = db.get_writer_session() uuid2 = uuidutils.generate_uuid() with session.begin(): db.Node(uuid=self.node.uuid, @@ -141,7 +141,7 @@ class TestNodeCache(test_base.NodeTest): self.assertEqual({self.uuid, uuid2}, node_uuid_list) def test_add_attribute(self): - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.Node(uuid=self.node.uuid, state=istate.States.starting).save(session) @@ -158,7 +158,7 @@ class TestNodeCache(test_base.NodeTest): self.assertEqual({'key': ['value']}, node_info.attributes) def test_add_attribute_same_name(self): - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.Node(uuid=self.node.uuid, state=istate.States.starting).save(session) @@ -175,7 +175,7 @@ class TestNodeCache(test_base.NodeTest): [tuple(row) for row in res]) def test_add_attribute_same_value(self): - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.Node(uuid=self.node.uuid, state=istate.States.starting).save(session) @@ -198,7 +198,7 @@ class TestNodeCache(test_base.NodeTest): 'mac': self.macs}, node_info.attributes) # check invalidation - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.Attribute(uuid=uuidutils.generate_uuid(), name='foo', value='bar', node_uuid=self.uuid).save(session) @@ -285,7 +285,7 @@ class TestNodeCacheFind(test_base.NodeTest): self.assertTrue(res._locked) def test_inconsistency(self): - session = db.get_session() + session = db.get_writer_session() with session.begin(): (db.model_query(db.Node).filter_by(uuid=self.uuid). delete()) @@ -293,7 +293,7 @@ class TestNodeCacheFind(test_base.NodeTest): bmc_address='1.2.3.4') def test_already_finished(self): - session = db.get_session() + session = db.get_writer_session() with session.begin(): (db.model_query(db.Node).filter_by(uuid=self.uuid). update({'finished_at': datetime.datetime.utcnow()})) @@ -305,7 +305,7 @@ class TestNodeCacheCleanUp(test_base.NodeTest): def setUp(self): super(TestNodeCacheCleanUp, self).setUp() self.started_at = datetime.datetime.utcnow() - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.Node(uuid=self.uuid, state=istate.States.waiting, @@ -350,7 +350,7 @@ class TestNodeCacheCleanUp(test_base.NodeTest): def test_timeout(self, time_mock, get_lock_mock): # Add a finished node to confirm we don't try to timeout it time_mock.return_value = self.started_at - session = db.get_session() + session = db.get_writer_session() finished_at = self.started_at + datetime.timedelta(seconds=60) with session.begin(): db.Node(uuid=self.uuid + '1', started_at=self.started_at, @@ -380,7 +380,7 @@ class TestNodeCacheCleanUp(test_base.NodeTest): @mock.patch.object(timeutils, 'utcnow') def test_timeout_active_state(self, time_mock, get_lock_mock): time_mock.return_value = self.started_at - session = db.get_session() + session = db.get_writer_session() CONF.set_override('timeout', 1) for state in [istate.States.starting, istate.States.enrolling, istate.States.processing, istate.States.reapplying]: @@ -400,7 +400,7 @@ class TestNodeCacheCleanUp(test_base.NodeTest): def test_old_status(self): CONF.set_override('node_status_keep_time', 42) - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.model_query(db.Node).update( {'finished_at': (datetime.datetime.utcnow() - @@ -412,7 +412,7 @@ class TestNodeCacheCleanUp(test_base.NodeTest): def test_old_status_disabled(self): # Status clean up is disabled by default - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.model_query(db.Node).update( {'finished_at': (datetime.datetime.utcnow() - @@ -427,7 +427,7 @@ class TestNodeCacheGetNode(test_base.NodeTest): def test_ok(self): started_at = (datetime.datetime.utcnow() - datetime.timedelta(seconds=42)) - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.Node(uuid=self.uuid, state=istate.States.starting, @@ -443,7 +443,7 @@ class TestNodeCacheGetNode(test_base.NodeTest): def test_locked(self): started_at = (datetime.datetime.utcnow() - datetime.timedelta(seconds=42)) - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.Node(uuid=self.uuid, state=istate.States.starting, @@ -463,7 +463,7 @@ class TestNodeCacheGetNode(test_base.NodeTest): def test_with_name(self): started_at = (datetime.datetime.utcnow() - datetime.timedelta(seconds=42)) - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.Node(uuid=self.uuid, state=istate.States.starting, @@ -491,7 +491,7 @@ class TestNodeInfoFinished(test_base.NodeTest): mac=self.macs) self.node_info = node_cache.NodeInfo( uuid=self.uuid, started_at=datetime.datetime(3, 1, 4)) - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.Option(uuid=self.uuid, name='foo', value='bar').save( session) @@ -499,7 +499,7 @@ class TestNodeInfoFinished(test_base.NodeTest): def test_success(self): self.node_info.finished() - session = db.get_session() + session = db.get_writer_session() with session.begin(): self.assertEqual((datetime.datetime(1, 1, 1), None), tuple(db.model_query( @@ -533,7 +533,7 @@ class TestNodeInfoOptions(test_base.NodeTest): bmc_address='1.2.3.4', mac=self.macs) self.node_info = node_cache.NodeInfo(uuid=self.uuid, started_at=3.14) - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.Option(uuid=self.uuid, name='foo', value='"bar"').save( session) @@ -927,7 +927,7 @@ class TestNodeCacheListNode(test_base.NodeTest): def setUp(self): super(TestNodeCacheListNode, self).setUp() self.uuid2 = uuidutils.generate_uuid() - session = db.get_session() + session = db.get_writer_session() with session.begin(): db.Node(uuid=self.uuid, started_at=datetime.datetime(1, 1, 2)).save(session)