diff --git a/zun/common/crypt.py b/zun/common/crypt.py new file mode 100644 index 000000000..a9ce538ed --- /dev/null +++ b/zun/common/crypt.py @@ -0,0 +1,53 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import base64 + +from cryptography import fernet +from oslo_config import cfg +from oslo_utils import encodeutils + +from zun.common import exception + + +def encrypt(value, encryption_key=None): + if value is None: + return None + + encryption_key = get_valid_encryption_key(encryption_key) + encoded_key = base64.b64encode(encryption_key.encode('utf-8')) + sym = fernet.Fernet(encoded_key) + res = sym.encrypt(encodeutils.safe_encode(value)) + return encodeutils.safe_decode(res) + + +def decrypt(data, encryption_key=None): + if data is None: + return None + + encryption_key = get_valid_encryption_key(encryption_key) + encoded_key = base64.b64encode(encryption_key.encode('utf-8')) + sym = fernet.Fernet(encoded_key) + try: + value = sym.decrypt(encodeutils.safe_encode(data)) + if value is not None: + return encodeutils.safe_decode(value, 'utf-8') + except fernet.InvalidToken: + raise exception.InvalidEncryptionKey() + + +def get_valid_encryption_key(encryption_key): + if encryption_key is None: + encryption_key = cfg.CONF.auth_encryption_key + + return encryption_key[:32] diff --git a/zun/common/exception.py b/zun/common/exception.py index ef9e66629..591b92218 100644 --- a/zun/common/exception.py +++ b/zun/common/exception.py @@ -754,3 +754,8 @@ class QuotaResourceUnknown(QuotaNotFound): class Base64Exception(Invalid): message = _("Invalid Base 64 file data") + + +class InvalidEncryptionKey(ZunException): + message = _('Can not decrypt data with the auth_encryption_key ' + 'in zun config.') diff --git a/zun/conf/utils.py b/zun/conf/utils.py index a632fdaf4..f7e28a090 100644 --- a/zun/conf/utils.py +++ b/zun/conf/utils.py @@ -18,6 +18,11 @@ utils_opts = [ default="/etc/zun/rootwrap.conf", help='Path to the rootwrap configuration file to use for ' 'running commands as root.'), + cfg.StrOpt('auth_encryption_key', + secret=True, + default='notgood but just long enough i t', + help='Key used to encrypt authentication info in the ' + 'database. Length of this key must be 32 characters.'), ] diff --git a/zun/db/sqlalchemy/api.py b/zun/db/sqlalchemy/api.py index 11438dba6..8df29ef1d 100644 --- a/zun/db/sqlalchemy/api.py +++ b/zun/db/sqlalchemy/api.py @@ -29,6 +29,7 @@ from sqlalchemy.sql.expression import desc from sqlalchemy.sql import func from zun.common import consts +from zun.common import crypt from zun.common import exception from zun.common.i18n import _ import zun.conf @@ -1364,14 +1365,21 @@ class Connection(object): query = model_query(models.Registry) query = self._add_project_filters(context, query) query = self._add_registries_filters(query, filters) - return _paginate_query(models.Registry, limit, marker, - sort_key, sort_dir, query) + result = _paginate_query(models.Registry, limit, marker, + sort_key, sort_dir, query) + for row in result: + row['password'] = crypt.decrypt(row['password']) + return result def create_registry(self, context, values): # ensure defaults are present for new registries if not values.get('uuid'): values['uuid'] = uuidutils.generate_uuid() + original_password = values.get('password') + if original_password: + values['password'] = crypt.encrypt(values.get('password')) + registry = models.Registry() registry.update(values) try: @@ -1379,6 +1387,12 @@ class Connection(object): except db_exc.DBDuplicateEntry: raise exception.RegistryAlreadyExists( field='UUID', value=values['uuid']) + + if original_password: + # the password is encrypted but we want to return the original + # password + registry['password'] = original_password + return registry def update_registry(self, context, registry_uuid, values): @@ -1386,7 +1400,18 @@ class Connection(object): if 'uuid' in values: msg = _("Cannot overwrite UUID for an existing registry.") raise exception.InvalidParameterValue(err=msg) - return self._do_update_registry(registry_uuid, values) + + original_password = values.get('password') + if original_password: + values['password'] = crypt.encrypt(values.get('password')) + + updated = self._do_update_registry(registry_uuid, values) + if original_password: + # the password is encrypted but we want to return the original + # password + updated['password'] = original_password + + return updated def _do_update_registry(self, registry_uuid, values): session = get_session() @@ -1406,7 +1431,9 @@ class Connection(object): query = self._add_project_filters(context, query) query = query.filter_by(uuid=registry_uuid) try: - return query.one() + result = query.one() + result['password'] = crypt.decrypt(result['password']) + return result except NoResultFound: raise exception.RegistryNotFound(registry=registry_uuid) @@ -1415,7 +1442,9 @@ class Connection(object): query = self._add_project_filters(context, query) query = query.filter_by(name=registry_name) try: - return query.one() + result = query.one() + result['password'] = crypt.decrypt(result['password']) + return result except NoResultFound: raise exception.RegistryNotFound(registry=registry_name) except MultipleResultsFound: diff --git a/zun/tests/unit/db/test_registry.py b/zun/tests/unit/db/test_registry.py index d1961379d..60b80cc49 100644 --- a/zun/tests/unit/db/test_registry.py +++ b/zun/tests/unit/db/test_registry.py @@ -29,7 +29,13 @@ class DbRegistryTestCase(base.DbTestCase): super(DbRegistryTestCase, self).setUp() def test_create_registry(self): - utils.create_test_registry(context=self.context) + username = 'fake-user' + password = 'fake-pass' + registry = utils.create_test_registry(context=self.context, + username=username, + password=password) + self.assertEqual(username, registry.username) + self.assertEqual(password, registry.password) def test_create_registry_already_exists(self): utils.create_test_registry(context=self.context, @@ -40,18 +46,30 @@ class DbRegistryTestCase(base.DbTestCase): uuid='123') def test_get_registry_by_uuid(self): - registry = utils.create_test_registry(context=self.context) + username = 'fake-user' + password = 'fake-pass' + registry = utils.create_test_registry(context=self.context, + username=username, + password=password) res = dbapi.get_registry_by_uuid(self.context, registry.uuid) self.assertEqual(registry.id, res.id) self.assertEqual(registry.uuid, res.uuid) + self.assertEqual(username, res.username) + self.assertEqual(password, res.password) def test_get_registry_by_name(self): - registry = utils.create_test_registry(context=self.context) + username = 'fake-user' + password = 'fake-pass' + registry = utils.create_test_registry(context=self.context, + username=username, + password=password) res = dbapi.get_registry_by_name( self.context, registry.name) self.assertEqual(registry.id, res.id) self.assertEqual(registry.uuid, res.uuid) + self.assertEqual(username, res.username) + self.assertEqual(password, res.password) def test_get_registry_that_does_not_exist(self): self.assertRaises(exception.RegistryNotFound, @@ -61,15 +79,21 @@ class DbRegistryTestCase(base.DbTestCase): def test_list_registries(self): uuids = [] + passwords = [] for i in range(1, 6): + password = 'pass' + str(i) + passwords.append(password) registry = utils.create_test_registry( uuid=uuidutils.generate_uuid(), context=self.context, - name='registry' + str(i)) + name='registry' + str(i), + password=password) uuids.append(six.text_type(registry['uuid'])) res = dbapi.list_registries(self.context) res_uuids = [r.uuid for r in res] self.assertEqual(sorted(uuids), sorted(res_uuids)) + res_passwords = [r.password for r in res] + self.assertEqual(sorted(passwords), sorted(res_passwords)) def test_list_registries_sorted(self): uuids = [] @@ -153,11 +177,14 @@ class DbRegistryTestCase(base.DbTestCase): registry = utils.create_test_registry(context=self.context) old_name = registry.name new_name = 'new-name' + new_password = 'new-pass' self.assertNotEqual(old_name, new_name) res = dbapi.update_registry(self.context, registry.id, - {'name': new_name}) + {'name': new_name, + 'password': new_password}) self.assertEqual(new_name, res.name) + self.assertEqual(new_password, res.password) def test_update_registry_not_found(self): registry_uuid = uuidutils.generate_uuid()