From d8f60687dcda5c40fee348c68fd97ac8daf09ba8 Mon Sep 17 00:00:00 2001 From: Clint Byrum Date: Mon, 1 Jul 2013 22:28:01 -0700 Subject: [PATCH] Refactor to avoid monkeypatching requests --- os_collect_config/collect.py | 5 ++- os_collect_config/common.py | 3 ++ os_collect_config/ec2.py | 53 +++++++++++++------------ os_collect_config/tests/test_collect.py | 10 +---- os_collect_config/tests/test_ec2.py | 48 +++++++++++----------- 5 files changed, 60 insertions(+), 59 deletions(-) create mode 100644 os_collect_config/common.py diff --git a/os_collect_config/collect.py b/os_collect_config/collect.py index cd56961..cc0270d 100644 --- a/os_collect_config/collect.py +++ b/os_collect_config/collect.py @@ -19,6 +19,7 @@ import subprocess from openstack.common import log from os_collect_config import cache +from os_collect_config import common from os_collect_config import ec2 from oslo.config import cfg @@ -45,11 +46,11 @@ def setup_conf(): CONF.register_cli_opts(opts) -def __main__(): +def __main__(ec2_requests=common.requests): setup_conf() CONF(prog="os-collect-config") log.setup("os-collect-config") - ec2_content = ec2.collect() + ec2_content = ec2.CollectEc2(requests_impl=ec2_requests).collect() if CONF.command: (changed, ec2_path) = cache.store('ec2', ec2_content) diff --git a/os_collect_config/common.py b/os_collect_config/common.py new file mode 100644 index 0000000..f78b6ec --- /dev/null +++ b/os_collect_config/common.py @@ -0,0 +1,3 @@ +import requests + +__all__ = ['requests'] diff --git a/os_collect_config/ec2.py b/os_collect_config/ec2.py index ab57458..c355fdd 100644 --- a/os_collect_config/ec2.py +++ b/os_collect_config/ec2.py @@ -14,9 +14,9 @@ # limitations under the License. from oslo.config import cfg -import requests from openstack.common import log +from os_collect_config import common from os_collect_config import exc EC2_METADATA_URL = 'http://169.254.169.254/latest/meta-data' @@ -29,30 +29,31 @@ opts = [ ] -def _fetch_metadata(fetch_url, session): - try: - r = session.get(fetch_url) - r.raise_for_status() - except (requests.HTTPError, - requests.ConnectionError, - requests.Timeout) as e: - log.getLogger(__name__).warn(e) - raise exc.Ec2MetadataNotAvailable - content = r.text - if fetch_url[-1] == '/': - new_content = {} - for subkey in content.split("\n"): - if '=' in subkey: - subkey = subkey[:subkey.index('=')] + '/' - sub_fetch_url = fetch_url + subkey - if subkey[-1] == '/': - subkey = subkey[:-1] - new_content[subkey] = _fetch_metadata(sub_fetch_url, session) - content = new_content - return content +class CollectEc2(object): + def __init__(self, requests_impl=common.requests): + self._requests_impl = requests_impl + self.session = requests_impl.Session() + def _fetch_metadata(self, fetch_url): + try: + r = self.session.get(fetch_url) + r.raise_for_status() + except self._requests_impl.exceptions.RequestException as e: + log.getLogger(__name__).warn(e) + raise exc.Ec2MetadataNotAvailable + content = r.text + if fetch_url[-1] == '/': + new_content = {} + for subkey in content.split("\n"): + if '=' in subkey: + subkey = subkey[:subkey.index('=')] + '/' + sub_fetch_url = fetch_url + subkey + if subkey[-1] == '/': + subkey = subkey[:-1] + new_content[subkey] = self._fetch_metadata(sub_fetch_url) + content = new_content + return content -def collect(): - root_url = '%s/' % (CONF.ec2.metadata_url) - session = requests.Session() - return _fetch_metadata(root_url, session) + def collect(self): + root_url = '%s/' % (CONF.ec2.metadata_url) + return self._fetch_metadata(root_url) diff --git a/os_collect_config/tests/test_collect.py b/os_collect_config/tests/test_collect.py index 2b032dc..cef89c1 100644 --- a/os_collect_config/tests/test_collect.py +++ b/os_collect_config/tests/test_collect.py @@ -25,12 +25,6 @@ from os_collect_config.tests import test_ec2 class TestCollect(testtools.TestCase): - def setUp(self): - super(TestCollect, self).setUp() - self.useFixture( - fixtures.MonkeyPatch( - 'requests.Session', test_ec2.FakeSession)) - def tearDown(self): super(TestCollect, self).tearDown() cfg.CONF.reset() @@ -67,7 +61,7 @@ class TestCollect(testtools.TestCase): self.useFixture(fixtures.MonkeyPatch('subprocess.call', fake_call)) - collect.__main__() + collect.__main__(ec2_requests=test_ec2.FakeRequests) self.assertTrue(self.called_fake_call) @@ -82,7 +76,7 @@ class TestCollect(testtools.TestCase): output = self.useFixture(fixtures.ByteStream('stdout')) self.useFixture( fixtures.MonkeyPatch('sys.stdout', output.stream)) - collect.__main__() + collect.__main__(ec2_requests=test_ec2.FakeRequests) out_struct = json.loads(output.stream.getvalue()) self.assertThat(out_struct, matchers.IsInstance(dict)) self.assertIn('ec2', out_struct) diff --git a/os_collect_config/tests/test_ec2.py b/os_collect_config/tests/test_ec2.py index fc1502f..ac8f938 100644 --- a/os_collect_config/tests/test_ec2.py +++ b/os_collect_config/tests/test_ec2.py @@ -58,36 +58,40 @@ class FakeResponse(dict): pass -class FakeSession(object): - def get(self, url): - url = urlparse.urlparse(url) +class FakeRequests(object): + exceptions = requests.exceptions - if url.path == '/latest/meta-data/': - # Remove keys which have anything after / - ks = [x for x in META_DATA.keys() if ('/' not in x - or not len(x.split('/')[1]))] - return FakeResponse("\n".join(ks)) + class Session(object): + def get(self, url): + url = urlparse.urlparse(url) - path = url.path - path = path.replace('/latest/meta-data/', '') - return FakeResponse(META_DATA[path]) + if url.path == '/latest/meta-data/': + # Remove keys which have anything after / + ks = [x for x in META_DATA.keys() if ( + '/' not in x or not len(x.split('/')[1]))] + return FakeResponse("\n".join(ks)) + + path = url.path + path = path.replace('/latest/meta-data/', '') + return FakeResponse(META_DATA[path]) -class FakeFailSession(object): - def get(self, url): - raise requests.exceptions.HTTPError(403, 'Forbidden') +class FakeFailRequests(object): + exceptions = requests.exceptions + + class Session(object): + def get(self, url): + raise requests.exceptions.HTTPError(403, 'Forbidden') -class TestCollect(testtools.TestCase): +class TestEc2(testtools.TestCase): def setUp(self): - super(TestCollect, self).setUp() + super(TestEc2, self).setUp() self.log = self.useFixture(fixtures.FakeLogger()) def test_collect_ec2(self): - self.useFixture( - fixtures.MonkeyPatch('requests.Session', FakeSession)) collect.setup_conf() - ec2_md = ec2.collect() + ec2_md = ec2.CollectEc2(requests_impl=FakeRequests).collect() self.assertThat(ec2_md, matchers.IsInstance(dict)) for k in ('public-ipv4', 'instance-id', 'hostname'): @@ -103,9 +107,7 @@ class TestCollect(testtools.TestCase): self.assertEquals('', self.log.output) def test_collect_ec2_fail(self): - self.useFixture( - fixtures.MonkeyPatch( - 'requests.Session', FakeFailSession)) collect.setup_conf() - self.assertRaises(exc.Ec2MetadataNotAvailable, ec2.collect) + collect_ec2 = ec2.CollectEc2(requests_impl=FakeFailRequests) + self.assertRaises(exc.Ec2MetadataNotAvailable, collect_ec2.collect) self.assertIn('Forbidden', self.log.output)