Merge "Cleanup context code"

This commit is contained in:
Zuul 2023-12-09 16:15:39 +00:00 committed by Gerrit Code Review
commit ca4b0ef515
3 changed files with 160 additions and 121 deletions

View File

@ -63,9 +63,7 @@ class DesignateContext(context.RequestContext):
self.delete_shares = delete_shares self.delete_shares = delete_shares
def deepcopy(self): def deepcopy(self):
d = self.to_dict() return self.from_dict(self.to_dict())
return self.from_dict(d)
def to_dict(self): def to_dict(self):
d = super().to_dict() d = super().to_dict()
@ -215,14 +213,13 @@ class DesignateContext(context.RequestContext):
def get_auth_plugin(self): def get_auth_plugin(self):
if self.user_auth_plugin: if self.user_auth_plugin:
return self.user_auth_plugin return self.user_auth_plugin
else: return _ContextAuthPlugin(self.auth_token, self.service_catalog)
return _ContextAuthPlugin(self.auth_token, self.service_catalog)
class _ContextAuthPlugin(plugin.BaseAuthPlugin): class _ContextAuthPlugin(plugin.BaseAuthPlugin):
"""A keystoneauth auth plugin that uses the values from the Context. """A keystoneauth auth plugin that uses the values from the Context.
Ideally we would use the plugin provided by auth_token middleware however Ideally we would use the plugin provided by auth_token middleware however
this plugin isn't serialized yet so we construct one from the serialized this plugin isn't serialized yet, so we construct one from the serialized
auth data. auth data.
""" """
def __init__(self, auth_token, sc): def __init__(self, auth_token, sc):
@ -234,26 +231,13 @@ class _ContextAuthPlugin(plugin.BaseAuthPlugin):
def get_token(self, *args, **kwargs): def get_token(self, *args, **kwargs):
return self.auth_token return self.auth_token
def get_endpoint(self, session, **kwargs): def get_endpoint(self, session, service_type=None, interface=None,
endpoint_data = self.get_endpoint_data(session, **kwargs) region_name=None, service_name=None, **kwargs):
if not endpoint_data: endpoint = self.service_catalog.url_for(
return None service_type=service_type, service_name=service_name,
return endpoint_data.url interface=interface, region_name=region_name
)
def get_endpoint_data(self, session, return self.get_endpoint_data(session, endpoint_override=endpoint).url
endpoint_override=None,
discover_versions=True,
**kwargs):
urlkw = {}
for k in ('service_type', 'service_name', 'service_id', 'endpoint_id',
'region_name', 'interface'):
if k in kwargs:
urlkw[k] = kwargs[k]
endpoint = endpoint_override or self.service_catalog.url_for(**urlkw)
return super().get_endpoint_data(
session, endpoint_override=endpoint,
discover_versions=discover_versions, **kwargs)
def get_current(): def get_current():

View File

@ -0,0 +1,114 @@
# Copyright 2012 Managed I.T.
#
# Author: Kiall Mac Innes <kiall@managedit.ie>
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import testtools
from designate import context
from designate import exceptions
import designate.tests
class TestDesignateContext(designate.tests.TestCase):
def test_deepcopy(self):
orig = context.DesignateContext(
user_id='12345', project_id='54321'
)
copy = orig.deepcopy()
self.assertEqual(orig.to_dict(), copy.to_dict())
def test_elevated(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertEqual(0, len(ctxt.roles))
def test_elevated_hard_delete(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated(hard_delete=True)
self.assertTrue(admin_ctxt.hard_delete)
def test_elevated_with_show_deleted(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated(show_deleted=True)
self.assertTrue(admin_ctxt.show_deleted)
def test_all_tenants(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
admin_ctxt.all_tenants = True
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertTrue(admin_ctxt.all_tenants)
def test_all_tenants_policy_failure(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
with testtools.ExpectedException(exceptions.Forbidden):
ctxt.all_tenants = True
def test_edit_managed_records(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
admin_ctxt.edit_managed_records = True
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertTrue(admin_ctxt.edit_managed_records)
def test_edit_managed_records_failure(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
with testtools.ExpectedException(exceptions.Forbidden):
ctxt.edit_managed_records = True
def test_hard_delete(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
admin_ctxt.hard_delete = True
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertTrue(admin_ctxt.hard_delete)
def test_hard_delete_failure(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
with testtools.ExpectedException(exceptions.Forbidden):
ctxt.hard_delete = True

View File

@ -13,25 +13,16 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
from unittest import mock from unittest import mock
import testtools import oslotest.base
from designate import context from designate import context
from designate import exceptions
from designate import policy from designate import policy
import designate.tests
class TestDesignateContext(designate.tests.TestCase): class TestDesignateContext(oslotest.base.BaseTestCase):
def test_deepcopy(self):
orig = context.DesignateContext(
user_id='12345', project_id='54321'
)
copy = orig.deepcopy()
self.assertEqual(orig.to_dict(), copy.to_dict())
def test_tsigkey_id_override(self): def test_tsigkey_id_override(self):
orig = context.DesignateContext( orig = context.DesignateContext(
tsigkey_id='12345', project_id='54321' tsigkey_id='12345', project_id='54321'
@ -40,89 +31,6 @@ class TestDesignateContext(designate.tests.TestCase):
self.assertEqual('TSIG:12345 54321 - - -', copy['user_identity']) self.assertEqual('TSIG:12345 54321 - - -', copy['user_identity'])
def test_elevated(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertEqual(0, len(ctxt.roles))
def test_elevated_hard_delete(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated(hard_delete=True)
self.assertTrue(admin_ctxt.hard_delete)
def test_elevated_with_show_deleted(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated(show_deleted=True)
self.assertTrue(admin_ctxt.show_deleted)
def test_all_tenants(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
admin_ctxt.all_tenants = True
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertTrue(admin_ctxt.all_tenants)
def test_all_tenants_policy_failure(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
with testtools.ExpectedException(exceptions.Forbidden):
ctxt.all_tenants = True
def test_edit_managed_records(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
admin_ctxt.edit_managed_records = True
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertTrue(admin_ctxt.edit_managed_records)
def test_edit_managed_records_failure(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
with testtools.ExpectedException(exceptions.Forbidden):
ctxt.edit_managed_records = True
def test_hard_delete(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
admin_ctxt.hard_delete = True
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertTrue(admin_ctxt.hard_delete)
def test_hard_delete_failure(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
with testtools.ExpectedException(exceptions.Forbidden):
ctxt.hard_delete = True
@mock.patch.object(policy, 'check') @mock.patch.object(policy, 'check')
def test_sudo(self, mock_policy_check): def test_sudo(self, mock_policy_check):
ctxt = context.DesignateContext( ctxt = context.DesignateContext(
@ -133,3 +41,36 @@ class TestDesignateContext(designate.tests.TestCase):
self.assertTrue(mock_policy_check.called) self.assertTrue(mock_policy_check.called)
self.assertEqual('new_project', ctxt.project_id) self.assertEqual('new_project', ctxt.project_id)
self.assertEqual('old_project', ctxt.original_project_id) self.assertEqual('old_project', ctxt.original_project_id)
def test_get_auth_plugin(self):
ctx = context.DesignateContext()
self.assertIsInstance(
ctx.get_auth_plugin(), context._ContextAuthPlugin
)
@mock.patch('keystoneauth1.access.service_catalog.ServiceCatalogV2')
def test_get_auth_plugin_get_endpoint(self, mock_sc):
mock_session = mock.Mock()
mock_service_catalog = mock.Mock()
mock_sc.return_value = mock_service_catalog
ctx = context.DesignateContext(
auth_token='token', service_catalog='catalog'
)
auth_plugin = ctx.get_auth_plugin()
auth_plugin.get_endpoint_data = mock.Mock()
auth_plugin.get_endpoint(mock_session)
mock_sc.assert_called_with('catalog')
mock_service_catalog.url_for.assert_called_with(
service_type=None, service_name=None, interface=None,
region_name=None
)
auth_plugin.get_endpoint_data.assert_called()
def test_get_auth_plugin_user(self):
ctx = context.DesignateContext(
user_auth_plugin='user_auth_plugin'
)
self.assertEqual('user_auth_plugin', ctx.get_auth_plugin())