diff --git a/designate/context.py b/designate/context.py index ca47d31dc..104835bd7 100644 --- a/designate/context.py +++ b/designate/context.py @@ -63,9 +63,7 @@ class DesignateContext(context.RequestContext): self.delete_shares = delete_shares def deepcopy(self): - d = self.to_dict() - - return self.from_dict(d) + return self.from_dict(self.to_dict()) def to_dict(self): d = super().to_dict() @@ -215,14 +213,13 @@ class DesignateContext(context.RequestContext): def get_auth_plugin(self): if self.user_auth_plugin: return self.user_auth_plugin - else: - return _ContextAuthPlugin(self.auth_token, self.service_catalog) + return _ContextAuthPlugin(self.auth_token, self.service_catalog) class _ContextAuthPlugin(plugin.BaseAuthPlugin): """A keystoneauth auth plugin that uses the values from the Context. Ideally we would use the plugin provided by auth_token middleware however - this plugin isn't serialized yet so we construct one from the serialized + this plugin isn't serialized yet, so we construct one from the serialized auth data. """ def __init__(self, auth_token, sc): @@ -234,26 +231,13 @@ class _ContextAuthPlugin(plugin.BaseAuthPlugin): def get_token(self, *args, **kwargs): return self.auth_token - def get_endpoint(self, session, **kwargs): - endpoint_data = self.get_endpoint_data(session, **kwargs) - if not endpoint_data: - return None - return endpoint_data.url - - def get_endpoint_data(self, session, - endpoint_override=None, - discover_versions=True, - **kwargs): - urlkw = {} - for k in ('service_type', 'service_name', 'service_id', 'endpoint_id', - 'region_name', 'interface'): - if k in kwargs: - urlkw[k] = kwargs[k] - - endpoint = endpoint_override or self.service_catalog.url_for(**urlkw) - return super().get_endpoint_data( - session, endpoint_override=endpoint, - discover_versions=discover_versions, **kwargs) + def get_endpoint(self, session, service_type=None, interface=None, + region_name=None, service_name=None, **kwargs): + endpoint = self.service_catalog.url_for( + service_type=service_type, service_name=service_name, + interface=interface, region_name=region_name + ) + return self.get_endpoint_data(session, endpoint_override=endpoint).url def get_current(): diff --git a/designate/tests/test_context.py b/designate/tests/test_context.py new file mode 100644 index 000000000..cd85c1570 --- /dev/null +++ b/designate/tests/test_context.py @@ -0,0 +1,114 @@ +# Copyright 2012 Managed I.T. +# +# Author: Kiall Mac Innes +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import testtools + +from designate import context +from designate import exceptions +import designate.tests + + +class TestDesignateContext(designate.tests.TestCase): + def test_deepcopy(self): + orig = context.DesignateContext( + user_id='12345', project_id='54321' + ) + copy = orig.deepcopy() + + self.assertEqual(orig.to_dict(), copy.to_dict()) + + def test_elevated(self): + ctxt = context.DesignateContext( + user_id='12345', project_id='54321' + ) + admin_ctxt = ctxt.elevated() + + self.assertFalse(ctxt.is_admin) + self.assertTrue(admin_ctxt.is_admin) + self.assertEqual(0, len(ctxt.roles)) + + def test_elevated_hard_delete(self): + ctxt = context.DesignateContext( + user_id='12345', project_id='54321' + ) + admin_ctxt = ctxt.elevated(hard_delete=True) + + self.assertTrue(admin_ctxt.hard_delete) + + def test_elevated_with_show_deleted(self): + ctxt = context.DesignateContext( + user_id='12345', project_id='54321' + ) + admin_ctxt = ctxt.elevated(show_deleted=True) + + self.assertTrue(admin_ctxt.show_deleted) + + def test_all_tenants(self): + ctxt = context.DesignateContext( + user_id='12345', project_id='54321' + ) + admin_ctxt = ctxt.elevated() + admin_ctxt.all_tenants = True + + self.assertFalse(ctxt.is_admin) + self.assertTrue(admin_ctxt.is_admin) + self.assertTrue(admin_ctxt.all_tenants) + + def test_all_tenants_policy_failure(self): + ctxt = context.DesignateContext( + user_id='12345', project_id='54321' + ) + + with testtools.ExpectedException(exceptions.Forbidden): + ctxt.all_tenants = True + + def test_edit_managed_records(self): + ctxt = context.DesignateContext( + user_id='12345', project_id='54321' + ) + admin_ctxt = ctxt.elevated() + + admin_ctxt.edit_managed_records = True + + self.assertFalse(ctxt.is_admin) + self.assertTrue(admin_ctxt.is_admin) + self.assertTrue(admin_ctxt.edit_managed_records) + + def test_edit_managed_records_failure(self): + ctxt = context.DesignateContext( + user_id='12345', project_id='54321' + ) + with testtools.ExpectedException(exceptions.Forbidden): + ctxt.edit_managed_records = True + + def test_hard_delete(self): + ctxt = context.DesignateContext( + user_id='12345', project_id='54321' + ) + admin_ctxt = ctxt.elevated() + + admin_ctxt.hard_delete = True + + self.assertFalse(ctxt.is_admin) + self.assertTrue(admin_ctxt.is_admin) + self.assertTrue(admin_ctxt.hard_delete) + + def test_hard_delete_failure(self): + ctxt = context.DesignateContext( + user_id='12345', project_id='54321' + ) + with testtools.ExpectedException(exceptions.Forbidden): + ctxt.hard_delete = True diff --git a/designate/tests/unit/test_context.py b/designate/tests/unit/test_context.py index f02a29e86..5ce8ba4e0 100644 --- a/designate/tests/unit/test_context.py +++ b/designate/tests/unit/test_context.py @@ -13,25 +13,16 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + from unittest import mock -import testtools +import oslotest.base from designate import context -from designate import exceptions from designate import policy -import designate.tests -class TestDesignateContext(designate.tests.TestCase): - def test_deepcopy(self): - orig = context.DesignateContext( - user_id='12345', project_id='54321' - ) - copy = orig.deepcopy() - - self.assertEqual(orig.to_dict(), copy.to_dict()) - +class TestDesignateContext(oslotest.base.BaseTestCase): def test_tsigkey_id_override(self): orig = context.DesignateContext( tsigkey_id='12345', project_id='54321' @@ -40,89 +31,6 @@ class TestDesignateContext(designate.tests.TestCase): self.assertEqual('TSIG:12345 54321 - - -', copy['user_identity']) - def test_elevated(self): - ctxt = context.DesignateContext( - user_id='12345', project_id='54321' - ) - admin_ctxt = ctxt.elevated() - - self.assertFalse(ctxt.is_admin) - self.assertTrue(admin_ctxt.is_admin) - self.assertEqual(0, len(ctxt.roles)) - - def test_elevated_hard_delete(self): - ctxt = context.DesignateContext( - user_id='12345', project_id='54321' - ) - admin_ctxt = ctxt.elevated(hard_delete=True) - - self.assertTrue(admin_ctxt.hard_delete) - - def test_elevated_with_show_deleted(self): - ctxt = context.DesignateContext( - user_id='12345', project_id='54321' - ) - admin_ctxt = ctxt.elevated(show_deleted=True) - - self.assertTrue(admin_ctxt.show_deleted) - - def test_all_tenants(self): - ctxt = context.DesignateContext( - user_id='12345', project_id='54321' - ) - admin_ctxt = ctxt.elevated() - admin_ctxt.all_tenants = True - - self.assertFalse(ctxt.is_admin) - self.assertTrue(admin_ctxt.is_admin) - self.assertTrue(admin_ctxt.all_tenants) - - def test_all_tenants_policy_failure(self): - ctxt = context.DesignateContext( - user_id='12345', project_id='54321' - ) - - with testtools.ExpectedException(exceptions.Forbidden): - ctxt.all_tenants = True - - def test_edit_managed_records(self): - ctxt = context.DesignateContext( - user_id='12345', project_id='54321' - ) - admin_ctxt = ctxt.elevated() - - admin_ctxt.edit_managed_records = True - - self.assertFalse(ctxt.is_admin) - self.assertTrue(admin_ctxt.is_admin) - self.assertTrue(admin_ctxt.edit_managed_records) - - def test_edit_managed_records_failure(self): - ctxt = context.DesignateContext( - user_id='12345', project_id='54321' - ) - with testtools.ExpectedException(exceptions.Forbidden): - ctxt.edit_managed_records = True - - def test_hard_delete(self): - ctxt = context.DesignateContext( - user_id='12345', project_id='54321' - ) - admin_ctxt = ctxt.elevated() - - admin_ctxt.hard_delete = True - - self.assertFalse(ctxt.is_admin) - self.assertTrue(admin_ctxt.is_admin) - self.assertTrue(admin_ctxt.hard_delete) - - def test_hard_delete_failure(self): - ctxt = context.DesignateContext( - user_id='12345', project_id='54321' - ) - with testtools.ExpectedException(exceptions.Forbidden): - ctxt.hard_delete = True - @mock.patch.object(policy, 'check') def test_sudo(self, mock_policy_check): ctxt = context.DesignateContext( @@ -133,3 +41,36 @@ class TestDesignateContext(designate.tests.TestCase): self.assertTrue(mock_policy_check.called) self.assertEqual('new_project', ctxt.project_id) self.assertEqual('old_project', ctxt.original_project_id) + + def test_get_auth_plugin(self): + ctx = context.DesignateContext() + self.assertIsInstance( + ctx.get_auth_plugin(), context._ContextAuthPlugin + ) + + @mock.patch('keystoneauth1.access.service_catalog.ServiceCatalogV2') + def test_get_auth_plugin_get_endpoint(self, mock_sc): + mock_session = mock.Mock() + mock_service_catalog = mock.Mock() + mock_sc.return_value = mock_service_catalog + + ctx = context.DesignateContext( + auth_token='token', service_catalog='catalog' + ) + + auth_plugin = ctx.get_auth_plugin() + auth_plugin.get_endpoint_data = mock.Mock() + auth_plugin.get_endpoint(mock_session) + + mock_sc.assert_called_with('catalog') + mock_service_catalog.url_for.assert_called_with( + service_type=None, service_name=None, interface=None, + region_name=None + ) + auth_plugin.get_endpoint_data.assert_called() + + def test_get_auth_plugin_user(self): + ctx = context.DesignateContext( + user_auth_plugin='user_auth_plugin' + ) + self.assertEqual('user_auth_plugin', ctx.get_auth_plugin())