Merge "Cleanup context code"

This commit is contained in:
Zuul 2023-12-09 16:15:39 +00:00 committed by Gerrit Code Review
commit ca4b0ef515
3 changed files with 160 additions and 121 deletions

View File

@ -63,9 +63,7 @@ class DesignateContext(context.RequestContext):
self.delete_shares = delete_shares
def deepcopy(self):
d = self.to_dict()
return self.from_dict(d)
return self.from_dict(self.to_dict())
def to_dict(self):
d = super().to_dict()
@ -215,14 +213,13 @@ class DesignateContext(context.RequestContext):
def get_auth_plugin(self):
if self.user_auth_plugin:
return self.user_auth_plugin
else:
return _ContextAuthPlugin(self.auth_token, self.service_catalog)
return _ContextAuthPlugin(self.auth_token, self.service_catalog)
class _ContextAuthPlugin(plugin.BaseAuthPlugin):
"""A keystoneauth auth plugin that uses the values from the Context.
Ideally we would use the plugin provided by auth_token middleware however
this plugin isn't serialized yet so we construct one from the serialized
this plugin isn't serialized yet, so we construct one from the serialized
auth data.
"""
def __init__(self, auth_token, sc):
@ -234,26 +231,13 @@ class _ContextAuthPlugin(plugin.BaseAuthPlugin):
def get_token(self, *args, **kwargs):
return self.auth_token
def get_endpoint(self, session, **kwargs):
endpoint_data = self.get_endpoint_data(session, **kwargs)
if not endpoint_data:
return None
return endpoint_data.url
def get_endpoint_data(self, session,
endpoint_override=None,
discover_versions=True,
**kwargs):
urlkw = {}
for k in ('service_type', 'service_name', 'service_id', 'endpoint_id',
'region_name', 'interface'):
if k in kwargs:
urlkw[k] = kwargs[k]
endpoint = endpoint_override or self.service_catalog.url_for(**urlkw)
return super().get_endpoint_data(
session, endpoint_override=endpoint,
discover_versions=discover_versions, **kwargs)
def get_endpoint(self, session, service_type=None, interface=None,
region_name=None, service_name=None, **kwargs):
endpoint = self.service_catalog.url_for(
service_type=service_type, service_name=service_name,
interface=interface, region_name=region_name
)
return self.get_endpoint_data(session, endpoint_override=endpoint).url
def get_current():

View File

@ -0,0 +1,114 @@
# Copyright 2012 Managed I.T.
#
# Author: Kiall Mac Innes <kiall@managedit.ie>
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import testtools
from designate import context
from designate import exceptions
import designate.tests
class TestDesignateContext(designate.tests.TestCase):
def test_deepcopy(self):
orig = context.DesignateContext(
user_id='12345', project_id='54321'
)
copy = orig.deepcopy()
self.assertEqual(orig.to_dict(), copy.to_dict())
def test_elevated(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertEqual(0, len(ctxt.roles))
def test_elevated_hard_delete(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated(hard_delete=True)
self.assertTrue(admin_ctxt.hard_delete)
def test_elevated_with_show_deleted(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated(show_deleted=True)
self.assertTrue(admin_ctxt.show_deleted)
def test_all_tenants(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
admin_ctxt.all_tenants = True
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertTrue(admin_ctxt.all_tenants)
def test_all_tenants_policy_failure(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
with testtools.ExpectedException(exceptions.Forbidden):
ctxt.all_tenants = True
def test_edit_managed_records(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
admin_ctxt.edit_managed_records = True
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertTrue(admin_ctxt.edit_managed_records)
def test_edit_managed_records_failure(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
with testtools.ExpectedException(exceptions.Forbidden):
ctxt.edit_managed_records = True
def test_hard_delete(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
admin_ctxt.hard_delete = True
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertTrue(admin_ctxt.hard_delete)
def test_hard_delete_failure(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
with testtools.ExpectedException(exceptions.Forbidden):
ctxt.hard_delete = True

View File

@ -13,25 +13,16 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from unittest import mock
import testtools
import oslotest.base
from designate import context
from designate import exceptions
from designate import policy
import designate.tests
class TestDesignateContext(designate.tests.TestCase):
def test_deepcopy(self):
orig = context.DesignateContext(
user_id='12345', project_id='54321'
)
copy = orig.deepcopy()
self.assertEqual(orig.to_dict(), copy.to_dict())
class TestDesignateContext(oslotest.base.BaseTestCase):
def test_tsigkey_id_override(self):
orig = context.DesignateContext(
tsigkey_id='12345', project_id='54321'
@ -40,89 +31,6 @@ class TestDesignateContext(designate.tests.TestCase):
self.assertEqual('TSIG:12345 54321 - - -', copy['user_identity'])
def test_elevated(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertEqual(0, len(ctxt.roles))
def test_elevated_hard_delete(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated(hard_delete=True)
self.assertTrue(admin_ctxt.hard_delete)
def test_elevated_with_show_deleted(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated(show_deleted=True)
self.assertTrue(admin_ctxt.show_deleted)
def test_all_tenants(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
admin_ctxt.all_tenants = True
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertTrue(admin_ctxt.all_tenants)
def test_all_tenants_policy_failure(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
with testtools.ExpectedException(exceptions.Forbidden):
ctxt.all_tenants = True
def test_edit_managed_records(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
admin_ctxt.edit_managed_records = True
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertTrue(admin_ctxt.edit_managed_records)
def test_edit_managed_records_failure(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
with testtools.ExpectedException(exceptions.Forbidden):
ctxt.edit_managed_records = True
def test_hard_delete(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
admin_ctxt = ctxt.elevated()
admin_ctxt.hard_delete = True
self.assertFalse(ctxt.is_admin)
self.assertTrue(admin_ctxt.is_admin)
self.assertTrue(admin_ctxt.hard_delete)
def test_hard_delete_failure(self):
ctxt = context.DesignateContext(
user_id='12345', project_id='54321'
)
with testtools.ExpectedException(exceptions.Forbidden):
ctxt.hard_delete = True
@mock.patch.object(policy, 'check')
def test_sudo(self, mock_policy_check):
ctxt = context.DesignateContext(
@ -133,3 +41,36 @@ class TestDesignateContext(designate.tests.TestCase):
self.assertTrue(mock_policy_check.called)
self.assertEqual('new_project', ctxt.project_id)
self.assertEqual('old_project', ctxt.original_project_id)
def test_get_auth_plugin(self):
ctx = context.DesignateContext()
self.assertIsInstance(
ctx.get_auth_plugin(), context._ContextAuthPlugin
)
@mock.patch('keystoneauth1.access.service_catalog.ServiceCatalogV2')
def test_get_auth_plugin_get_endpoint(self, mock_sc):
mock_session = mock.Mock()
mock_service_catalog = mock.Mock()
mock_sc.return_value = mock_service_catalog
ctx = context.DesignateContext(
auth_token='token', service_catalog='catalog'
)
auth_plugin = ctx.get_auth_plugin()
auth_plugin.get_endpoint_data = mock.Mock()
auth_plugin.get_endpoint(mock_session)
mock_sc.assert_called_with('catalog')
mock_service_catalog.url_for.assert_called_with(
service_type=None, service_name=None, interface=None,
region_name=None
)
auth_plugin.get_endpoint_data.assert_called()
def test_get_auth_plugin_user(self):
ctx = context.DesignateContext(
user_auth_plugin='user_auth_plugin'
)
self.assertEqual('user_auth_plugin', ctx.get_auth_plugin())