diff --git a/simport/__init__.py b/simport/__init__.py index 7028bc9..de3a7a9 100644 --- a/simport/__init__.py +++ b/simport/__init__.py @@ -74,22 +74,28 @@ def _get_module(target): if not class_or_function: raise MissingMethodOrFunction("No Method or Function specified in '%s'" % target) - try: - __import__(module) - except ImportError as e: - raise ImportFailed("Failed to import '%s'. Error: %s" % (module, e)) + if module: + try: + __import__(module) + except ImportError as e: + raise ImportFailed("Failed to import '%s'. Error: %s" % (module, e)) klass, sep, function = class_or_function.rpartition('.') return module, klass, function -def load(target): +def load(target, source_module=None): """Get the actual implementation of the target.""" module, klass, function = _get_module(target) + if not module and source_module: + module = source_module + if not module: + raise MissingModule("No module name supplied or source_module provided.") + actual_module = sys.modules[module] if not klass: - return getattr(sys.modules[module], function) + return getattr(actual_module, function) - class_object = getattr(sys.modules[module], klass) + class_object = getattr(actual_module, klass) if function: return getattr(class_object, function) return class_object diff --git a/tests/test_simport.py b/tests/test_simport.py index 5f45f71..b026735 100644 --- a/tests/test_simport.py +++ b/tests/test_simport.py @@ -14,6 +14,11 @@ class DummyClass(object): pass +class LocalClass(object): + def my_method(self): + pass + + class TestSimport(unittest.TestCase): def test_bad_targets(self): self.assertRaises(simport.BadDirectory, simport._get_module, @@ -78,3 +83,7 @@ class TestSimport(unittest.TestCase): "external.externalmodule:Blah") import external.externalmodule self.assertEqual(klass, external.externalmodule.Blah) + + def test_local_class(self): + klass = simport.load("LocalClass", __name__) + self.assertEqual(klass, LocalClass)