diff --git a/ovsdbapp/backend/ovs_idl/connection.py b/ovsdbapp/backend/ovs_idl/connection.py index 2fcce6ef..4097f764 100644 --- a/ovsdbapp/backend/ovs_idl/connection.py +++ b/ovsdbapp/backend/ovs_idl/connection.py @@ -16,6 +16,7 @@ import os import threading import traceback +from ovs.db import idl from ovs import poller import six from six.moves import queue as Queue @@ -52,26 +53,30 @@ class TransactionQueue(Queue.Queue, object): class Connection(object): - def __init__(self, idl_factory, timeout): + def __init__(self, idl, timeout): """Create a connection to an OVSDB server using the OVS IDL :param timeout: The timeout value for OVSDB operations - :param idl_factory: A factory function that produces an Idl instance + :param idl: A newly created ovs.db.Idl instance (run never called) """ - self.idl = None self.timeout = timeout self.txns = TransactionQueue(1) self.lock = threading.Lock() - self.idl_factory = idl_factory + self.idl = idl + self.thread = None def start(self): """Start the connection.""" with self.lock: - if self.idl is not None: - return - - self.idl = self.idl_factory() - idlutils.wait_for_change(self.idl, self.timeout) + if self.thread is not None: + return False + if not self.idl.has_ever_connected(): + idlutils.wait_for_change(self.idl, self.timeout) + try: + self.idl.post_connect() + except AttributeError: + # An ovs.db.Idl class has no post_connect + pass self.poller = poller.Poller() self.thread = threading.Thread(target=self.run) self.thread.setDaemon(True) @@ -98,3 +103,19 @@ class Connection(object): def queue_txn(self, txn): self.txns.put(txn) + + +class OvsdbIdl(idl.Idl): + @classmethod + def from_server(cls, connection_string, schema_name): + """Create the Idl instance by pulling the schema from OVSDB server""" + helper = idlutils.get_schema_helper(connection_string, schema_name) + helper.register_all() + return cls(connection_string, helper) + + def post_connect(self): + """Operations to execute after the Idl has connected to the server + + An example would be to set up Idl notification handling for watching + and unwatching certain OVSDB change events + """ diff --git a/ovsdbapp/tests/functional/schema/open_vswitch/test_impl_idl.py b/ovsdbapp/tests/functional/schema/open_vswitch/test_impl_idl.py index 07a69927..8b3f557c 100644 --- a/ovsdbapp/tests/functional/schema/open_vswitch/test_impl_idl.py +++ b/ovsdbapp/tests/functional/schema/open_vswitch/test_impl_idl.py @@ -13,25 +13,16 @@ # License for the specific language governing permissions and limitations # under the License. -from ovs.db import idl - from ovsdbapp.backend.ovs_idl import connection -from ovsdbapp.backend.ovs_idl import idlutils from ovsdbapp import constants from ovsdbapp.schema.open_vswitch import impl_idl from ovsdbapp.tests import base from ovsdbapp.tests import utils - -def default_idl_factory(): - helper = idlutils.get_schema_helper(constants.DEFAULT_OVSDB_CONNECTION, - 'Open_vSwitch') - helper.register_all() - return idl.Idl(constants.DEFAULT_OVSDB_CONNECTION, helper) - - ovsdb_connection = connection.Connection( - idl_factory=default_idl_factory, timeout=constants.DEFAULT_TIMEOUT) + idl=connection.OvsdbIdl.from_server( + constants.DEFAULT_OVSDB_CONNECTION, 'Open_vSwitch'), + timeout=constants.DEFAULT_TIMEOUT) class TestOvsdbIdl(base.TestCase): diff --git a/ovsdbapp/tests/unit/backend/ovs_idl/test_connection.py b/ovsdbapp/tests/unit/backend/ovs_idl/test_connection.py index 056b52ff..6d85edc1 100644 --- a/ovsdbapp/tests/unit/backend/ovs_idl/test_connection.py +++ b/ovsdbapp/tests/unit/backend/ovs_idl/test_connection.py @@ -27,19 +27,18 @@ class TestOVSNativeConnection(base.TestCase): @mock.patch.object(connection, 'TransactionQueue') def setUp(self, mock_trans_queue): super(TestOVSNativeConnection, self).setUp() - self.idl_factory = mock.Mock() + self.idl = mock.Mock() self.mock_trans_queue = mock_trans_queue - self.conn = connection.Connection(self.idl_factory, - timeout=1) + self.conn = connection.Connection(self.idl, timeout=1) self.mock_trans_queue.assert_called_once_with(1) @mock.patch.object(threading, 'Thread') @mock.patch.object(poller, 'Poller') @mock.patch.object(idlutils, 'wait_for_change') def test_start(self, mock_wait_for_change, mock_poller, mock_thread): + self.idl.has_ever_connected.return_value = False self.conn.start() - - self.idl_factory.assert_called_once_with() + self.idl.has_ever_connected.assert_called_once() mock_wait_for_change.assert_called_once_with(self.conn.idl, self.conn.timeout) mock_poller.assert_called_once_with()