diff --git a/ooi/tests/middleware/test_compute_controller.py b/ooi/tests/middleware/test_compute_controller.py index 50fbfa1..5324bdd 100644 --- a/ooi/tests/middleware/test_compute_controller.py +++ b/ooi/tests/middleware/test_compute_controller.py @@ -102,7 +102,7 @@ class TestComputeController(test_middleware.TestMiddleware): resp = req.get_response(app) expected_result = "" - self.assertContentType(resp) + self.assertDefaults(resp) self.assertExpectedResult(expected_result, resp) self.assertEqual(204, resp.status_code) @@ -121,7 +121,7 @@ class TestComputeController(test_middleware.TestMiddleware): ("X-OCCI-Location", utils.join_url(self.application_url + "/", "compute/%s" % s["id"])) ) - self.assertContentType(resp) + self.assertDefaults(resp) self.assertExpectedResult(expected, resp) def test_show_vm(self): @@ -134,7 +134,7 @@ class TestComputeController(test_middleware.TestMiddleware): resp = req.get_response(app) expected = build_occi_server(server) - self.assertContentType(resp) + self.assertDefaults(resp) self.assertExpectedResult(expected, resp) self.assertEqual(200, resp.status_code) @@ -172,7 +172,7 @@ class TestComputeController(test_middleware.TestMiddleware): "compute/%s" % "foo"))] self.assertEqual(200, resp.status_code) self.assertExpectedResult(expected, resp) - self.assertContentType(resp) + self.assertDefaults(resp) def test_create_vm_incomplete(self): tenant = fakes.tenants["foo"] @@ -193,7 +193,7 @@ class TestComputeController(test_middleware.TestMiddleware): resp = req.get_response(app) self.assertEqual(400, resp.status_code) - self.assertContentType(resp) + self.assertDefaults(resp) def test_create_with_context(self): tenant = fakes.tenants["foo"] @@ -228,7 +228,7 @@ class TestComputeController(test_middleware.TestMiddleware): "compute/%s" % "foo"))] self.assertEqual(200, resp.status_code) self.assertExpectedResult(expected, resp) - self.assertContentType(resp) + self.assertDefaults(resp) def test_vm_links(self): tenant = fakes.tenants["baz"] @@ -244,7 +244,7 @@ class TestComputeController(test_middleware.TestMiddleware): vol_id = server["os-extended-volumes:volumes_attached"][0]["id"] link_id = '_'.join([server["id"], vol_id]) - self.assertContentType(resp) + self.assertDefaults(resp) self.assertResultIncludesLink(link_id, server["id"], vol_id, resp) self.assertEqual(200, resp.status_code) diff --git a/ooi/tests/middleware/test_middleware.py b/ooi/tests/middleware/test_middleware.py index 14067da..fbf5554 100644 --- a/ooi/tests/middleware/test_middleware.py +++ b/ooi/tests/middleware/test_middleware.py @@ -37,9 +37,15 @@ class TestMiddleware(base.TestCase): self.accept = self.content_type = None self.application_url = fakes.application_url + self.occi_string = "OCCI/1.1" + def get_app(self, resp=None): return wsgi.OCCIMiddleware(fakes.FakeApp()) + def assertDefaults(self, result): + self.assertContentType(result) + self.assertServerHeader(result) + def assertContentType(self, result): if self.accept in (None, "*/*"): expected = "text/plain" @@ -47,6 +53,10 @@ class TestMiddleware(base.TestCase): expected = self.accept self.assertEqual(expected, result.content_type) + def assertServerHeader(self, result): + self.assertIn("Server", result.headers) + self.assertIn(self.occi_string, result.headers["server"]) + def assertExpectedResult(self, expected, result): expected = ["%s: %s" % e for e in expected] # NOTE(aloga): the order of the result does not matter @@ -85,6 +95,7 @@ class TestMiddleware(base.TestCase): def test_404(self): result = self._build_req("/", "tenant").get_response(self.get_app()) self.assertEqual(404, result.status_code) + self.assertDefaults(result) def test_400_from_openstack(self): @webob.dec.wsgify() @@ -93,8 +104,10 @@ class TestMiddleware(base.TestCase): resp = fakes.FakeOpenStackFault(exc) return resp - result = self._build_req("/-/", "tenant").get_response(_fake_app) + mdl = wsgi.OCCIMiddleware(_fake_app) + result = self._build_req("/-/", "tenant").get_response(mdl) self.assertEqual(400, result.status_code) + self.assertDefaults(result) class TestMiddlewareTextPlain(TestMiddleware): diff --git a/ooi/tests/middleware/test_query_controller.py b/ooi/tests/middleware/test_query_controller.py index 5fbd87c..a445b42 100644 --- a/ooi/tests/middleware/test_query_controller.py +++ b/ooi/tests/middleware/test_query_controller.py @@ -25,7 +25,7 @@ class TestQueryController(test_middleware.TestMiddleware): def test_query(self): tenant_id = fakes.tenants["bar"]["id"] result = self._build_req("/-/", tenant_id).get_response(self.get_app()) - self.assertContentType(result) + self.assertDefaults(result) self.assertExpectedResult(fakes.fake_query_results(), result) self.assertEqual(200, result.status_code) diff --git a/ooi/wsgi/__init__.py b/ooi/wsgi/__init__.py index cfa57d9..d4c6223 100644 --- a/ooi/wsgi/__init__.py +++ b/ooi/wsgi/__init__.py @@ -19,6 +19,7 @@ import routes import routes.middleware import webob.dec +import ooi import ooi.api.compute from ooi.api import query import ooi.api.storage @@ -69,6 +70,9 @@ class Request(webob.Request): class OCCIMiddleware(object): + + occi_version = "1.1" + @classmethod def factory(cls, global_conf, **local_conf): """Factory method for paste.deploy.""" @@ -147,20 +151,30 @@ class OCCIMiddleware(object): @webob.dec.wsgify(RequestClass=Request) def __call__(self, req): response = self.process_request(req) - if response: - return response + if not response: + response = req.get_response(self.application) - response = req.get_response(self.application) return self.process_response(response) def process_request(self, req): match = self.mapper.match(req.path_info, req.environ) if not match: - return webob.exc.HTTPNotFound() + return Fault(webob.exc.HTTPNotFound()) method = match["controller"] return method(req, match) def process_response(self, response): + """Process a response by adding our headers.""" + server_string = "ooi/%s OCCI/%s" % (ooi.__version__, + self.occi_version) + + headers = (("server", server_string),) + if isinstance(response, Fault): + for key, val in headers: + response.wrapped_exc.headers.add(key, val) + else: + for key, val in headers: + response.headers.add(key, val) return response