diff --git a/cloudbaseinit/plugins/common/setuserpassword.py b/cloudbaseinit/plugins/common/setuserpassword.py index 76f436a2..a9af9311 100644 --- a/cloudbaseinit/plugins/common/setuserpassword.py +++ b/cloudbaseinit/plugins/common/setuserpassword.py @@ -12,8 +12,6 @@ # License for the specific language governing permissions and limitations # under the License. -import base64 - from oslo_log import log as oslo_logging from cloudbaseinit import conf as cloudbaseinit_conf @@ -23,7 +21,6 @@ from cloudbaseinit.plugins.common import base from cloudbaseinit.plugins.common import constants as plugin_constant from cloudbaseinit.utils import crypt - CONF = cloudbaseinit_conf.CONF LOG = oslo_logging.getLogger(__name__) @@ -32,9 +29,7 @@ class SetUserPasswordPlugin(base.BasePlugin): def _encrypt_password(self, ssh_pub_key, password): cm = crypt.CryptManager() - with cm.load_ssh_rsa_public_key(ssh_pub_key) as rsa: - enc_password = rsa.public_encrypt(password.encode()) - return base64.b64encode(enc_password) + return cm.public_encrypt(ssh_pub_key, password) def _get_password(self, service, shared_data): injected = False diff --git a/cloudbaseinit/tests/plugins/common/test_setuserpassword.py b/cloudbaseinit/tests/plugins/common/test_setuserpassword.py index 5b174889..ed005c74 100644 --- a/cloudbaseinit/tests/plugins/common/test_setuserpassword.py +++ b/cloudbaseinit/tests/plugins/common/test_setuserpassword.py @@ -37,25 +37,17 @@ class SetUserPasswordPluginTests(unittest.TestCase): self.fake_data = fake_json_response.get_fake_metadata_json( '2013-04-04') - @mock.patch('base64.b64encode') @mock.patch('cloudbaseinit.utils.crypt.CryptManager' - '.load_ssh_rsa_public_key') - def test_encrypt_password(self, mock_load_ssh_key, mock_b64encode): - mock_rsa = mock.MagicMock() - fake_ssh_pub_key = 'fake key' + '.public_encrypt') + def test_encrypt_password(self, mock_public_encrypt): + fake_ssh_pub_key = 'ssh-rsa key' fake_password = 'fake password' - mock_load_ssh_key.return_value = mock_rsa - mock_rsa.__enter__().public_encrypt.return_value = 'public encrypted' - mock_b64encode.return_value = 'encrypted password' - + fake_encrypt_pwd = 'encrypted password' + mock_public_encrypt.return_value = fake_encrypt_pwd response = self._setpassword_plugin._encrypt_password( fake_ssh_pub_key, fake_password) - mock_load_ssh_key.assert_called_with(fake_ssh_pub_key) - mock_rsa.__enter__().public_encrypt.assert_called_with( - b'fake password') - mock_b64encode.assert_called_with('public encrypted') - self.assertEqual('encrypted password', response) + self.assertEqual(fake_encrypt_pwd, response) def _test_get_password(self, inject_password): shared_data = {} diff --git a/cloudbaseinit/tests/utils/test_crypt.py b/cloudbaseinit/tests/utils/test_crypt.py index 4abfaaa9..83f08710 100644 --- a/cloudbaseinit/tests/utils/test_crypt.py +++ b/cloudbaseinit/tests/utils/test_crypt.py @@ -16,17 +16,13 @@ import unittest from cloudbaseinit.utils import crypt - -class TestOpenSSLException(unittest.TestCase): - - def setUp(self): - self._openssl = crypt.OpenSSLException() - - def test_get_openssl_error_msg(self): - expected_err_msg = u'error:00000000:lib(0):func(0):reason(0)' - expected_err_msg_py10 = u'error:00000000:lib(0)::reason(0)' - err_msg = self._openssl._get_openssl_error_msg() - self.assertIn(err_msg, [expected_err_msg, expected_err_msg_py10]) +PUB_KEY = ''' +AAAAB3NzaC1yc2EAAAADAQABAAABAQDP1e9IAYXwwUKuFtoReGXidwnM1RuXWB53IO0Hg +mbZArXvEIOfgm/l6IsOJwF7znOBn0hClW7ZONPweX1Al9Hy/LInX1x96Aamq4yyKQCmHDiuZc7Qwu +xr82Ph8XfWic/wo4es/ODSYeFT5NoFDhsYII8O9EGoubpQdakxt9skX0X+zg8TYPuIOANGhlaN8nn +U7gYbO7Gt9vZDmYeRACthNzCIg+w38oxmcgmQqQHxPEp4tUtuFfpjptyVvHz273QvisbdymD3RO0L +9oGMdKzjGgcdE1VuhXuucnUWlZuKe7BirxF8glF5NHKzWto67lDRzVI/F1snkTAorm5EWkA9 test +''' class TestCryptManager(unittest.TestCase): @@ -36,6 +32,16 @@ class TestCryptManager(unittest.TestCase): def test_load_ssh_rsa_public_key_invalid(self): ssh_pub_key = "ssh" - exc = Exception - self.assertRaises(exc, self._crypt_manager.load_ssh_rsa_public_key, - ssh_pub_key) + exc = crypt.CryptException + self.assertRaises(exc, self._crypt_manager.public_encrypt, + ssh_pub_key, '') + + def test_encrypt_password(self): + ssh_pub_key = "ssh-rsa " + PUB_KEY.replace('\n', "") + password = 'testpassword' + + response = self._crypt_manager.public_encrypt( + ssh_pub_key, password) + + self.assertTrue(len(response) > 0) + self.assertTrue(isinstance(response, bytes)) diff --git a/cloudbaseinit/utils/crypt.py b/cloudbaseinit/utils/crypt.py index 22b66103..eb1d7fcf 100644 --- a/cloudbaseinit/utils/crypt.py +++ b/cloudbaseinit/utils/crypt.py @@ -13,223 +13,28 @@ # under the License. import base64 -import ctypes -import ctypes.util -import struct -import sys -clib_path = ctypes.util.find_library("c") - -if sys.platform == "win32": - if clib_path: - clib = ctypes.CDLL(clib_path) - else: - clib = ctypes.cdll.ucrtbase - - # for backwards compatibility, try the older names - # libcrypto-1_1 comes bundled with PY 3.7 to 3.12 - ssl_lib_names = [ - "libcrypto-1_1", - "libcrypto", - "libeay32" - ] - - for ssl_lib_name in ssl_lib_names: - try: - openssl = ctypes.CDLL(ssl_lib_name) - break - except Exception: - pass -else: - clib = ctypes.CDLL(clib_path) - openssl_lib_path = ctypes.util.find_library("ssl") - openssl = ctypes.CDLL(openssl_lib_path) - - -class RSA(ctypes.Structure): - _fields_ = [ - ("pad", ctypes.c_int), - ("version", ctypes.c_long), - ("meth", ctypes.c_void_p), - ("engine", ctypes.c_void_p), - ("n", ctypes.c_void_p), - ("e", ctypes.c_void_p), - ("d", ctypes.c_void_p), - ("p", ctypes.c_void_p), - ("q", ctypes.c_void_p), - ("dmp1", ctypes.c_void_p), - ("dmq1", ctypes.c_void_p), - ("iqmp", ctypes.c_void_p), - ("sk", ctypes.c_void_p), - ("dummy", ctypes.c_int), - ("references", ctypes.c_int), - ("flags", ctypes.c_int), - ("_method_mod_n", ctypes.c_void_p), - ("_method_mod_p", ctypes.c_void_p), - ("_method_mod_q", ctypes.c_void_p), - ("bignum_data", ctypes.c_char_p), - ("blinding", ctypes.c_void_p), - ("mt_blinding", ctypes.c_void_p) - ] - - -openssl.RSA_PKCS1_PADDING = 1 - -openssl.RSA_new.restype = ctypes.POINTER(RSA) - -openssl.BN_bin2bn.restype = ctypes.c_void_p -openssl.BN_bin2bn.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p] - -openssl.BN_new.restype = ctypes.c_void_p - -openssl.RSA_size.restype = ctypes.c_int -openssl.RSA_size.argtypes = [ctypes.POINTER(RSA)] - -openssl.RSA_public_encrypt.argtypes = [ctypes.c_int, - ctypes.c_char_p, - ctypes.c_char_p, - ctypes.POINTER(RSA), - ctypes.c_int] -openssl.RSA_public_encrypt.restype = ctypes.c_int - -openssl.RSA_free.argtypes = [ctypes.POINTER(RSA)] - -openssl.PEM_write_RSAPublicKey.restype = ctypes.c_int -openssl.PEM_write_RSAPublicKey.argtypes = [ctypes.c_void_p, - ctypes.POINTER(RSA)] - -openssl.ERR_get_error.restype = ctypes.c_long -openssl.ERR_get_error.argtypes = [] - -openssl.ERR_error_string_n.restype = ctypes.c_void_p -openssl.ERR_error_string_n.argtypes = [ctypes.c_long, - ctypes.c_char_p, - ctypes.c_int] - -try: - openssl.ERR_load_crypto_strings.restype = ctypes.c_int - openssl.ERR_load_crypto_strings.argtypes = [] -except AttributeError: - # NOTE(avladu): This function is deprecated and no longer needed - # since OpenSSL 1.1 - pass - -clib.fopen.restype = ctypes.c_void_p -clib.fopen.argtypes = [ctypes.c_char_p, ctypes.c_char_p] - -clib.fclose.restype = ctypes.c_int -clib.fclose.argtypes = [ctypes.c_void_p] +from cryptography.hazmat import backends +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.primitives import serialization class CryptException(Exception): pass -class OpenSSLException(CryptException): - - def __init__(self): - message = self._get_openssl_error_msg() - super(OpenSSLException, self).__init__(message) - - def _get_openssl_error_msg(self): - try: - openssl.ERR_load_crypto_strings() - except AttributeError: - pass - - errno = openssl.ERR_get_error() - errbuf = ctypes.create_string_buffer(1024) - openssl.ERR_error_string_n(errno, errbuf, 1024) - return errbuf.value.decode("ascii") - - -class RSAWrapper(object): - - def __init__(self, rsa_p): - self._rsa_p = rsa_p - - def __enter__(self): - return self - - def __exit__(self, tp, value, tb): - self.free() - - def free(self): - openssl.RSA_free(self._rsa_p) - - def public_encrypt(self, clear_text): - flen = len(clear_text) - rsa_size = openssl.RSA_size(self._rsa_p) - enc_text = ctypes.create_string_buffer(rsa_size) - - enc_text_len = openssl.RSA_public_encrypt(flen, - clear_text, - enc_text, - self._rsa_p, - openssl.RSA_PKCS1_PADDING) - if enc_text_len == -1: - raise OpenSSLException() - - return enc_text[:enc_text_len] - - class CryptManager(object): - def load_ssh_rsa_public_key(self, ssh_pub_key): + def public_encrypt(self, ssh_pub_key, password): ssh_rsa_prefix = "ssh-rsa " if not ssh_pub_key.startswith(ssh_rsa_prefix): raise CryptException('Invalid SSH key') - s = ssh_pub_key[len(ssh_rsa_prefix):] - idx = s.find(' ') - if idx >= 0: - b64_pub_key = s[:idx] - else: - b64_pub_key = s - - pub_key = base64.b64decode(b64_pub_key) - - offset = 0 - - key_type_len = struct.unpack('>I', pub_key[offset:offset + 4])[0] - offset += 4 - - key_type = pub_key[offset:offset + key_type_len].decode('utf-8') - offset += key_type_len - - if key_type not in ['ssh-rsa', 'rsa', 'rsa1']: - raise CryptException('Unsupported SSH key type "%s". ' - 'Only RSA keys are currently supported' - % key_type) - - rsa_p = openssl.RSA_new() - try: - rsa_p.contents.e = openssl.BN_new() - rsa_p.contents.n = openssl.BN_new() - - e_len = struct.unpack('>I', pub_key[offset:offset + 4])[0] - offset += 4 - - e_key_bin = pub_key[offset:offset + e_len] - offset += e_len - - if not openssl.BN_bin2bn(e_key_bin, e_len, rsa_p.contents.e): - raise OpenSSLException() - - n_len = struct.unpack('>I', pub_key[offset:offset + 4])[0] - offset += 4 - - n_key_bin = pub_key[offset:offset + n_len] - offset += n_len - - if offset != len(pub_key): - raise CryptException('Invalid SSH key') - - if not openssl.BN_bin2bn(n_key_bin, n_len, rsa_p.contents.n): - raise OpenSSLException() - - return RSAWrapper(rsa_p) - except Exception: - openssl.RSA_free(rsa_p) - raise + rsa_public_key = serialization.load_ssh_public_key( + ssh_pub_key.encode(), backends.default_backend()) + enc_password = rsa_public_key.encrypt( + password.encode(), + padding.PKCS1v15() + ) + return base64.b64encode(enc_password) diff --git a/requirements.txt b/requirements.txt index b132bc8a..c53d37aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ PyYAML requests untangle==1.2.1 jinja2 +cryptography pywin32;sys_platform=="win32" comtypes;sys_platform=="win32" pymi;sys_platform=="win32"