crypto: replaced openssl with cryptography module

Use cryptography instead of the flaky openssl libraries loading.
If the libssl DLLs were not present or were present in another order
then the required one, there could be errors when securing the password
before sending it to the metadata service.

Fixes: https://github.com/cloudbase/cloudbase-init/issues/34

Change-Id: I1a2245e199f65f4665071ada9576dcae77a3a432
This commit is contained in:
Adrian Vladu 2020-02-03 19:31:17 +02:00
parent 576db310c2
commit ea5e8da627
5 changed files with 39 additions and 240 deletions

View File

@ -12,8 +12,6 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import base64
from oslo_log import log as oslo_logging from oslo_log import log as oslo_logging
from cloudbaseinit import conf as cloudbaseinit_conf from cloudbaseinit import conf as cloudbaseinit_conf
@ -23,7 +21,6 @@ from cloudbaseinit.plugins.common import base
from cloudbaseinit.plugins.common import constants as plugin_constant from cloudbaseinit.plugins.common import constants as plugin_constant
from cloudbaseinit.utils import crypt from cloudbaseinit.utils import crypt
CONF = cloudbaseinit_conf.CONF CONF = cloudbaseinit_conf.CONF
LOG = oslo_logging.getLogger(__name__) LOG = oslo_logging.getLogger(__name__)
@ -32,9 +29,7 @@ class SetUserPasswordPlugin(base.BasePlugin):
def _encrypt_password(self, ssh_pub_key, password): def _encrypt_password(self, ssh_pub_key, password):
cm = crypt.CryptManager() cm = crypt.CryptManager()
with cm.load_ssh_rsa_public_key(ssh_pub_key) as rsa: return cm.public_encrypt(ssh_pub_key, password)
enc_password = rsa.public_encrypt(password.encode())
return base64.b64encode(enc_password)
def _get_password(self, service, shared_data): def _get_password(self, service, shared_data):
injected = False injected = False

View File

@ -37,25 +37,17 @@ class SetUserPasswordPluginTests(unittest.TestCase):
self.fake_data = fake_json_response.get_fake_metadata_json( self.fake_data = fake_json_response.get_fake_metadata_json(
'2013-04-04') '2013-04-04')
@mock.patch('base64.b64encode')
@mock.patch('cloudbaseinit.utils.crypt.CryptManager' @mock.patch('cloudbaseinit.utils.crypt.CryptManager'
'.load_ssh_rsa_public_key') '.public_encrypt')
def test_encrypt_password(self, mock_load_ssh_key, mock_b64encode): def test_encrypt_password(self, mock_public_encrypt):
mock_rsa = mock.MagicMock() fake_ssh_pub_key = 'ssh-rsa key'
fake_ssh_pub_key = 'fake key'
fake_password = 'fake password' fake_password = 'fake password'
mock_load_ssh_key.return_value = mock_rsa fake_encrypt_pwd = 'encrypted password'
mock_rsa.__enter__().public_encrypt.return_value = 'public encrypted' mock_public_encrypt.return_value = fake_encrypt_pwd
mock_b64encode.return_value = 'encrypted password'
response = self._setpassword_plugin._encrypt_password( response = self._setpassword_plugin._encrypt_password(
fake_ssh_pub_key, fake_password) fake_ssh_pub_key, fake_password)
mock_load_ssh_key.assert_called_with(fake_ssh_pub_key) self.assertEqual(fake_encrypt_pwd, response)
mock_rsa.__enter__().public_encrypt.assert_called_with(
b'fake password')
mock_b64encode.assert_called_with('public encrypted')
self.assertEqual('encrypted password', response)
def _test_get_password(self, inject_password): def _test_get_password(self, inject_password):
shared_data = {} shared_data = {}

View File

@ -16,17 +16,13 @@ import unittest
from cloudbaseinit.utils import crypt from cloudbaseinit.utils import crypt
PUB_KEY = '''
class TestOpenSSLException(unittest.TestCase): AAAAB3NzaC1yc2EAAAADAQABAAABAQDP1e9IAYXwwUKuFtoReGXidwnM1RuXWB53IO0Hg
mbZArXvEIOfgm/l6IsOJwF7znOBn0hClW7ZONPweX1Al9Hy/LInX1x96Aamq4yyKQCmHDiuZc7Qwu
def setUp(self): xr82Ph8XfWic/wo4es/ODSYeFT5NoFDhsYII8O9EGoubpQdakxt9skX0X+zg8TYPuIOANGhlaN8nn
self._openssl = crypt.OpenSSLException() U7gYbO7Gt9vZDmYeRACthNzCIg+w38oxmcgmQqQHxPEp4tUtuFfpjptyVvHz273QvisbdymD3RO0L
9oGMdKzjGgcdE1VuhXuucnUWlZuKe7BirxF8glF5NHKzWto67lDRzVI/F1snkTAorm5EWkA9 test
def test_get_openssl_error_msg(self): '''
expected_err_msg = u'error:00000000:lib(0):func(0):reason(0)'
expected_err_msg_py10 = u'error:00000000:lib(0)::reason(0)'
err_msg = self._openssl._get_openssl_error_msg()
self.assertIn(err_msg, [expected_err_msg, expected_err_msg_py10])
class TestCryptManager(unittest.TestCase): class TestCryptManager(unittest.TestCase):
@ -36,6 +32,16 @@ class TestCryptManager(unittest.TestCase):
def test_load_ssh_rsa_public_key_invalid(self): def test_load_ssh_rsa_public_key_invalid(self):
ssh_pub_key = "ssh" ssh_pub_key = "ssh"
exc = Exception exc = crypt.CryptException
self.assertRaises(exc, self._crypt_manager.load_ssh_rsa_public_key, self.assertRaises(exc, self._crypt_manager.public_encrypt,
ssh_pub_key) ssh_pub_key, '')
def test_encrypt_password(self):
ssh_pub_key = "ssh-rsa " + PUB_KEY.replace('\n', "")
password = 'testpassword'
response = self._crypt_manager.public_encrypt(
ssh_pub_key, password)
self.assertTrue(len(response) > 0)
self.assertTrue(isinstance(response, bytes))

View File

@ -13,223 +13,28 @@
# under the License. # under the License.
import base64 import base64
import ctypes
import ctypes.util
import struct
import sys
clib_path = ctypes.util.find_library("c") from cryptography.hazmat import backends
from cryptography.hazmat.primitives.asymmetric import padding
if sys.platform == "win32": from cryptography.hazmat.primitives import serialization
if clib_path:
clib = ctypes.CDLL(clib_path)
else:
clib = ctypes.cdll.ucrtbase
# for backwards compatibility, try the older names
# libcrypto-1_1 comes bundled with PY 3.7 to 3.12
ssl_lib_names = [
"libcrypto-1_1",
"libcrypto",
"libeay32"
]
for ssl_lib_name in ssl_lib_names:
try:
openssl = ctypes.CDLL(ssl_lib_name)
break
except Exception:
pass
else:
clib = ctypes.CDLL(clib_path)
openssl_lib_path = ctypes.util.find_library("ssl")
openssl = ctypes.CDLL(openssl_lib_path)
class RSA(ctypes.Structure):
_fields_ = [
("pad", ctypes.c_int),
("version", ctypes.c_long),
("meth", ctypes.c_void_p),
("engine", ctypes.c_void_p),
("n", ctypes.c_void_p),
("e", ctypes.c_void_p),
("d", ctypes.c_void_p),
("p", ctypes.c_void_p),
("q", ctypes.c_void_p),
("dmp1", ctypes.c_void_p),
("dmq1", ctypes.c_void_p),
("iqmp", ctypes.c_void_p),
("sk", ctypes.c_void_p),
("dummy", ctypes.c_int),
("references", ctypes.c_int),
("flags", ctypes.c_int),
("_method_mod_n", ctypes.c_void_p),
("_method_mod_p", ctypes.c_void_p),
("_method_mod_q", ctypes.c_void_p),
("bignum_data", ctypes.c_char_p),
("blinding", ctypes.c_void_p),
("mt_blinding", ctypes.c_void_p)
]
openssl.RSA_PKCS1_PADDING = 1
openssl.RSA_new.restype = ctypes.POINTER(RSA)
openssl.BN_bin2bn.restype = ctypes.c_void_p
openssl.BN_bin2bn.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p]
openssl.BN_new.restype = ctypes.c_void_p
openssl.RSA_size.restype = ctypes.c_int
openssl.RSA_size.argtypes = [ctypes.POINTER(RSA)]
openssl.RSA_public_encrypt.argtypes = [ctypes.c_int,
ctypes.c_char_p,
ctypes.c_char_p,
ctypes.POINTER(RSA),
ctypes.c_int]
openssl.RSA_public_encrypt.restype = ctypes.c_int
openssl.RSA_free.argtypes = [ctypes.POINTER(RSA)]
openssl.PEM_write_RSAPublicKey.restype = ctypes.c_int
openssl.PEM_write_RSAPublicKey.argtypes = [ctypes.c_void_p,
ctypes.POINTER(RSA)]
openssl.ERR_get_error.restype = ctypes.c_long
openssl.ERR_get_error.argtypes = []
openssl.ERR_error_string_n.restype = ctypes.c_void_p
openssl.ERR_error_string_n.argtypes = [ctypes.c_long,
ctypes.c_char_p,
ctypes.c_int]
try:
openssl.ERR_load_crypto_strings.restype = ctypes.c_int
openssl.ERR_load_crypto_strings.argtypes = []
except AttributeError:
# NOTE(avladu): This function is deprecated and no longer needed
# since OpenSSL 1.1
pass
clib.fopen.restype = ctypes.c_void_p
clib.fopen.argtypes = [ctypes.c_char_p, ctypes.c_char_p]
clib.fclose.restype = ctypes.c_int
clib.fclose.argtypes = [ctypes.c_void_p]
class CryptException(Exception): class CryptException(Exception):
pass pass
class OpenSSLException(CryptException):
def __init__(self):
message = self._get_openssl_error_msg()
super(OpenSSLException, self).__init__(message)
def _get_openssl_error_msg(self):
try:
openssl.ERR_load_crypto_strings()
except AttributeError:
pass
errno = openssl.ERR_get_error()
errbuf = ctypes.create_string_buffer(1024)
openssl.ERR_error_string_n(errno, errbuf, 1024)
return errbuf.value.decode("ascii")
class RSAWrapper(object):
def __init__(self, rsa_p):
self._rsa_p = rsa_p
def __enter__(self):
return self
def __exit__(self, tp, value, tb):
self.free()
def free(self):
openssl.RSA_free(self._rsa_p)
def public_encrypt(self, clear_text):
flen = len(clear_text)
rsa_size = openssl.RSA_size(self._rsa_p)
enc_text = ctypes.create_string_buffer(rsa_size)
enc_text_len = openssl.RSA_public_encrypt(flen,
clear_text,
enc_text,
self._rsa_p,
openssl.RSA_PKCS1_PADDING)
if enc_text_len == -1:
raise OpenSSLException()
return enc_text[:enc_text_len]
class CryptManager(object): class CryptManager(object):
def load_ssh_rsa_public_key(self, ssh_pub_key): def public_encrypt(self, ssh_pub_key, password):
ssh_rsa_prefix = "ssh-rsa " ssh_rsa_prefix = "ssh-rsa "
if not ssh_pub_key.startswith(ssh_rsa_prefix): if not ssh_pub_key.startswith(ssh_rsa_prefix):
raise CryptException('Invalid SSH key') raise CryptException('Invalid SSH key')
s = ssh_pub_key[len(ssh_rsa_prefix):] rsa_public_key = serialization.load_ssh_public_key(
idx = s.find(' ') ssh_pub_key.encode(), backends.default_backend())
if idx >= 0: enc_password = rsa_public_key.encrypt(
b64_pub_key = s[:idx] password.encode(),
else: padding.PKCS1v15()
b64_pub_key = s )
return base64.b64encode(enc_password)
pub_key = base64.b64decode(b64_pub_key)
offset = 0
key_type_len = struct.unpack('>I', pub_key[offset:offset + 4])[0]
offset += 4
key_type = pub_key[offset:offset + key_type_len].decode('utf-8')
offset += key_type_len
if key_type not in ['ssh-rsa', 'rsa', 'rsa1']:
raise CryptException('Unsupported SSH key type "%s". '
'Only RSA keys are currently supported'
% key_type)
rsa_p = openssl.RSA_new()
try:
rsa_p.contents.e = openssl.BN_new()
rsa_p.contents.n = openssl.BN_new()
e_len = struct.unpack('>I', pub_key[offset:offset + 4])[0]
offset += 4
e_key_bin = pub_key[offset:offset + e_len]
offset += e_len
if not openssl.BN_bin2bn(e_key_bin, e_len, rsa_p.contents.e):
raise OpenSSLException()
n_len = struct.unpack('>I', pub_key[offset:offset + 4])[0]
offset += 4
n_key_bin = pub_key[offset:offset + n_len]
offset += n_len
if offset != len(pub_key):
raise CryptException('Invalid SSH key')
if not openssl.BN_bin2bn(n_key_bin, n_len, rsa_p.contents.n):
raise OpenSSLException()
return RSAWrapper(rsa_p)
except Exception:
openssl.RSA_free(rsa_p)
raise

View File

@ -12,6 +12,7 @@ PyYAML
requests requests
untangle==1.2.1 untangle==1.2.1
jinja2 jinja2
cryptography
pywin32;sys_platform=="win32" pywin32;sys_platform=="win32"
comtypes;sys_platform=="win32" comtypes;sys_platform=="win32"
pymi;sys_platform=="win32" pymi;sys_platform=="win32"