diff --git a/billistix/storage/__init__.py b/billistix/storage/__init__.py index 6f109c8..0d93dac 100644 --- a/billistix/storage/__init__.py +++ b/billistix/storage/__init__.py @@ -37,8 +37,16 @@ def register_opts(conf): engine.register_opts(conf) +def get_engine_name(string): + """ + Return the engine name from either a non-dialected or dialected string + """ + return string.split("+")[0] + + def get_engine(conf): - engine_name = urlparse(conf.database_connection).scheme + scheme = urlparse(conf.database_connection).scheme + engine_name = get_engine_name(scheme) LOG.debug('looking for %r engine in %r', engine_name, DRIVER_NAMESPACE) mgr = driver.DriverManager(DRIVER_NAMESPACE, diff --git a/billistix/tests/test_storage/test_scheme.py b/billistix/tests/test_storage/test_scheme.py new file mode 100644 index 0000000..f7b75d6 --- /dev/null +++ b/billistix/tests/test_storage/test_scheme.py @@ -0,0 +1,13 @@ +from moniker.tests import TestCase +from moniker.storage import get_engine_name + + +class TestEngineName(TestCase): + def test_engine_non_dialected(self): + name = get_engine_name("mysql") + self.assertEqual(name, "mysql") + + def test_engine_dialacted(self): + name = get_engine_name("mysql+oursql") + self.assertEqual(name, "mysql") +