Deny API requests where context doesn't match path

We shouldn't overwrite the context tenant_id (which comes from the
scope of the auth_token) with that from the path, instead raise a
HTTPForbidden exception if the path-provided tenant_id doesn't match
the context.

Change-Id: Ib6fb9881103312f7492081a20178f12309f35d81
Closes-Bug: #1256983
This commit is contained in:
Steven Hardy 2013-12-02 23:59:19 +00:00
parent dc21dc16ed
commit 759ee38c53
3 changed files with 39 additions and 24 deletions

View File

@ -21,12 +21,13 @@ from heat.common import identifier
def tenant_local(handler):
'''
Decorator for a handler method that sets the correct tenant_id in the
Decorator for a handler method that checks the path matches the
request context.
'''
@wraps(handler)
def handle_stack_method(controller, req, tenant_id, **kwargs):
req.context.tenant_id = tenant_id
if req.context.tenant_id != tenant_id:
raise exc.HTTPForbidden()
return handler(controller, req, **kwargs)
return handle_stack_method

View File

@ -366,7 +366,7 @@ class StackControllerTest(ControllerTest, HeatTestCase):
req = self._get('/stacks', params=params)
mock_call.return_value = []
result = self.controller.index(req, tenant_id='fake_tenant_id')
result = self.controller.index(req, tenant_id=self.tenant)
rpc_call_args, _ = mock_call.call_args
engine_args = rpc_call_args[2]['args']
@ -388,7 +388,7 @@ class StackControllerTest(ControllerTest, HeatTestCase):
req = self._get('/stacks', params=params)
mock_call.return_value = []
result = self.controller.index(req, tenant_id='fake_tenant_id')
result = self.controller.index(req, tenant_id=self.tenant)
rpc_call_args, _ = mock_call.call_args
engine_args = rpc_call_args[2]['args']
@ -408,7 +408,7 @@ class StackControllerTest(ControllerTest, HeatTestCase):
engine.list_stacks = mock.Mock(return_value=[])
engine.count_stacks = mock.Mock(return_value=0)
result = self.controller.index(req, tenant_id='fake_tenant_id')
result = self.controller.index(req, tenant_id=self.tenant)
self.assertEqual(0, result['count'])
def test_index_doesnt_return_stack_count_if_with_count_is_falsy(self):
@ -419,7 +419,7 @@ class StackControllerTest(ControllerTest, HeatTestCase):
engine.list_stacks = mock.Mock(return_value=[])
engine.count_stacks = mock.Mock()
result = self.controller.index(req, tenant_id='fake_tenant_id')
result = self.controller.index(req, tenant_id=self.tenant)
self.assertNotIn('count', result)
assert not engine.count_stacks.called
@ -432,7 +432,7 @@ class StackControllerTest(ControllerTest, HeatTestCase):
engine.list_stacks = mock.Mock(return_value=[])
mock_count_stacks.side_effect = AttributeError("Should not exist")
result = self.controller.index(req, tenant_id='fake_tenant_id')
result = self.controller.index(req, tenant_id=self.tenant)
self.assertNotIn('count', result)
@mock.patch.object(rpc, 'call')
@ -988,14 +988,6 @@ class StackControllerTest(ControllerTest, HeatTestCase):
req = self._get('/stacks/%(stack_name)s/%(stack_id)s' % identity)
error = heat_exc.InvalidTenant(target='a', actual='b')
self.m.StubOutWithMock(rpc, 'call')
rpc.call(req.context, self.topic,
{'namespace': None,
'method': 'show_stack',
'args': {'stack_identity': dict(identity)},
'version': self.api_version},
None).AndRaise(to_remote_error(error))
self.m.ReplayAll()
resp = request_with_middleware(fault.FaultWrapper,
@ -1004,8 +996,8 @@ class StackControllerTest(ControllerTest, HeatTestCase):
stack_name=identity.stack_name,
stack_id=identity.stack_id)
self.assertEqual(resp.json['code'], 403)
self.assertEqual(resp.json['error']['type'], 'InvalidTenant')
self.assertEqual(resp.status_int, 403)
self.assertIn('403 Forbidden', str(resp))
self.m.VerifyAll()
def test_get_template(self):
@ -2694,37 +2686,37 @@ class ActionControllerTest(ControllerTest, HeatTestCase):
self.m.VerifyAll()
class BuildInfoControllerTest(HeatTestCase):
class BuildInfoControllerTest(ControllerTest, HeatTestCase):
def test_theres_a_default_api_build_revision(self):
req = mock.Mock()
req = self._get('/build_info')
controller = build_info.BuildInfoController({})
controller.engine = mock.Mock()
response = controller.build_info(req, tenant_id='tenant_id')
response = controller.build_info(req, tenant_id=self.tenant)
self.assertIn('api', response)
self.assertIn('revision', response['api'])
self.assertEqual('unknown', response['api']['revision'])
@mock.patch.object(build_info.cfg, 'CONF')
def test_response_api_build_revision_from_config_file(self, mock_conf):
req = mock.Mock()
req = self._get('/build_info')
controller = build_info.BuildInfoController({})
mock_engine = mock.Mock()
mock_engine.get_revision.return_value = 'engine_revision'
controller.engine = mock_engine
mock_conf.revision = {'heat_revision': 'test'}
response = controller.build_info(req, tenant_id='tenant_id')
response = controller.build_info(req, tenant_id=self.tenant)
self.assertEqual('test', response['api']['revision'])
def test_retrieves_build_revision_from_the_engine(self):
req = mock.Mock()
req = self._get('/build_info')
controller = build_info.BuildInfoController({})
mock_engine = mock.Mock()
mock_engine.get_revision.return_value = 'engine_revision'
controller.engine = mock_engine
response = controller.build_info(req, tenant_id='tenant_id')
response = controller.build_info(req, tenant_id=self.tenant)
self.assertIn('engine', response)
self.assertIn('revision', response['engine'])
self.assertEqual('engine_revision', response['engine']['revision'])

View File

@ -12,7 +12,10 @@
# License for the specific language governing permissions and limitations
# under the License.
from webob import exc
from heat.api.openstack.v1 import util
from heat.common import context
from heat.common.wsgi import Request
from heat.tests.common import HeatTestCase
@ -77,3 +80,22 @@ class TestGetAllowedParams(HeatTestCase):
self.whitelist = {'foo': 'blah'}
result = util.get_allowed_params(self.params, self.whitelist)
self.assertNotIn('foo', result)
class TestTenantLocal(HeatTestCase):
def setUp(self):
super(TestTenantLocal, self).setUp()
self.req = Request({})
self.req.context = context.RequestContext(tenant_id='foo',
is_admin=False)
def test_tenant_local(self):
@util.tenant_local
def an_action(controller, req):
return 'woot'
self.assertEqual('woot',
an_action(None, self.req, tenant_id='foo'))
self.assertRaises(exc.HTTPForbidden,
an_action, None, self.req, tenant_id='bar')