diff --git a/eventlet/patcher.py b/eventlet/patcher.py index c0ef377..a672aa6 100644 --- a/eventlet/patcher.py +++ b/eventlet/patcher.py @@ -1,5 +1,7 @@ +import collections import imp import sys +import types from eventlet.support import six @@ -7,6 +9,7 @@ from eventlet.support import six __all__ = ['inject', 'import_patched', 'monkey_patch', 'is_monkey_patched'] __exclude = set(('__builtins__', '__file__', '__name__')) +_MISSING = object() class SysModulesSaver(object): @@ -33,10 +36,7 @@ class SysModulesSaver(object): if mod is not None: sys.modules[modname] = mod else: - try: - del sys.modules[modname] - except KeyError: - pass + sys.modules.pop(modname, None) finally: imp.release_lock() @@ -78,7 +78,7 @@ def inject(module_name, new_globals, *additional_modules): # after this we are gonna screw with sys.modules, so capture the # state of all the modules we're going to mess with, and lock - saver = SysModulesSaver([name for name, m in additional_modules]) + saver = SysModulesSaver([name for name, _ in additional_modules]) saver.save(module_name) # Cover the target modules so that when you import the module it @@ -89,8 +89,14 @@ def inject(module_name, new_globals, *additional_modules): # Remove the old module from sys.modules and reimport it while # the specified modules are in place sys.modules.pop(module_name, None) + + import_control = _patch_import(dict(additional_modules)) try: - module = __import__(module_name, {}, {}, module_name.split('.')[:-1]) + import_control.begin() + try: + module = import_control.original(module_name, fromlist=module_name.split('.')[:-1]) + finally: + import_control.end() if new_globals is not None: # Update the given globals dictionary with everything from this new module @@ -106,6 +112,32 @@ def inject(module_name, new_globals, *additional_modules): return module +ImportControl = collections.namedtuple('ImportControl', 'begin end original patched') + + +def _patch_import(patch_map, _missing=object()): + original = __builtins__['__import__'] + + def fun(name, *args, **kwargs): + module = original(name, *args, **kwargs) + for k in dir(module): + v = getattr(module, k, None) + replacement = patch_map.get(k) + if isinstance(v, types.ModuleType) and replacement is not None: + # print(' _patch_import {0}.{1} = {2}'.format(name, k, v)) + setattr(module, k, replacement) + return module + + def begin(): + __builtins__['__import__'] = fun + + def end(): + __builtins__['__import__'] = original + + control = ImportControl(begin, end, original, fun) + return control + + def import_patched(module_name, *additional_modules, **kw_additional_modules): """Imports a module in a way that ensures that the module uses "green" versions of the standard library modules, so that everything works diff --git a/tests/isolated/patcher_import_patched_defaults.py b/tests/isolated/patcher_import_patched_defaults.py index 62b0807..12673da 100644 --- a/tests/isolated/patcher_import_patched_defaults.py +++ b/tests/isolated/patcher_import_patched_defaults.py @@ -1,4 +1,5 @@ import os +import sys __test__ = False @@ -9,7 +10,8 @@ if os.environ.get('eventlet_test_import_patched_defaults') == '1': import urllib as target t = target.socket.socket import eventlet.green.socket - if issubclass(t, eventlet.green.socket.socket): - print('pass') - else: + if not issubclass(t, eventlet.green.socket.socket): print('Fail. Target socket not green: {0} bases {1}'.format(t, t.__bases__)) + sys.exit(1) + + print('pass')