# Copyright 2013 Cloudbase Solutions Srl
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

"""Minimal ctypes bindings to OpenSSL for RSA public-key encryption.

At import time the module loads the C runtime and the OpenSSL library and
declares the prototypes of the functions it uses. :class:`CryptManager`
turns an OpenSSH ``ssh-rsa`` public key line into an :class:`RSAWrapper`,
which encrypts short payloads with PKCS#1 v1.5 padding.
"""

import base64
import ctypes
import ctypes.util
import struct
import sys

clib_path = ctypes.util.find_library("c")

if sys.platform == "win32":
    if clib_path:
        clib = ctypes.CDLL(clib_path)
    else:
        clib = ctypes.cdll.ucrtbase
    openssl = ctypes.cdll.libeay32
else:
    clib = ctypes.CDLL(clib_path)
    # libssl is loaded rather than libcrypto; the RSA/BN/ERR symbols used
    # below are expected to resolve through libssl's libcrypto dependency.
    openssl_lib_path = ctypes.util.find_library("ssl")
    openssl = ctypes.CDLL(openssl_lib_path)


class RSA(ctypes.Structure):
    """ctypes mirror of OpenSSL's ``RSA`` structure.

    NOTE(review): the field layout appears to match the public
    ``struct rsa_st`` of OpenSSL 1.0.x. From OpenSSL 1.1.0 on the structure
    is opaque. Its fields are therefore only written when ``RSA_set0_key``
    is unavailable; otherwise ``POINTER(RSA)`` is used purely as an opaque
    handle.
    """

    _fields_ = [
        ("pad", ctypes.c_int),
        ("version", ctypes.c_long),
        ("meth", ctypes.c_void_p),
        ("engine", ctypes.c_void_p),
        ("n", ctypes.c_void_p),
        ("e", ctypes.c_void_p),
        ("d", ctypes.c_void_p),
        ("p", ctypes.c_void_p),
        ("q", ctypes.c_void_p),
        ("dmp1", ctypes.c_void_p),
        ("dmq1", ctypes.c_void_p),
        ("iqmp", ctypes.c_void_p),
        ("sk", ctypes.c_void_p),
        ("dummy", ctypes.c_int),
        ("references", ctypes.c_int),
        ("flags", ctypes.c_int),
        ("_method_mod_n", ctypes.c_void_p),
        ("_method_mod_p", ctypes.c_void_p),
        ("_method_mod_q", ctypes.c_void_p),
        ("bignum_data", ctypes.c_char_p),
        ("blinding", ctypes.c_void_p),
        ("mt_blinding", ctypes.c_void_p)
    ]


openssl.RSA_PKCS1_PADDING = 1

openssl.RSA_new.restype = ctypes.POINTER(RSA)

openssl.BN_bin2bn.restype = ctypes.c_void_p
openssl.BN_bin2bn.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p]

openssl.BN_new.restype = ctypes.c_void_p

openssl.BN_free.restype = None
openssl.BN_free.argtypes = [ctypes.c_void_p]

openssl.RSA_size.restype = ctypes.c_int
openssl.RSA_size.argtypes = [ctypes.POINTER(RSA)]

openssl.RSA_public_encrypt.argtypes = [ctypes.c_int,
                                       ctypes.c_char_p,
                                       ctypes.c_char_p,
                                       ctypes.POINTER(RSA),
                                       ctypes.c_int]
openssl.RSA_public_encrypt.restype = ctypes.c_int

openssl.RSA_free.argtypes = [ctypes.POINTER(RSA)]

openssl.PEM_write_RSAPublicKey.restype = ctypes.c_int
openssl.PEM_write_RSAPublicKey.argtypes = [ctypes.c_void_p,
                                           ctypes.POINTER(RSA)]

# C prototypes: unsigned long ERR_get_error(void);
#               void ERR_error_string_n(unsigned long, char *, size_t);
openssl.ERR_get_error.restype = ctypes.c_ulong
openssl.ERR_get_error.argtypes = []

openssl.ERR_error_string_n.restype = None
openssl.ERR_error_string_n.argtypes = [ctypes.c_ulong,
                                       ctypes.c_char_p,
                                       ctypes.c_size_t]

# Since OpenSSL 1.1.0, ERR_load_crypto_strings() is a macro: error strings
# are loaded automatically and the exported symbol may be absent.
# Looking up a missing symbol on a CDLL raises AttributeError, which is
# what hasattr() checks.
_HAS_ERR_LOAD_CRYPTO_STRINGS = hasattr(openssl, "ERR_load_crypto_strings")
if _HAS_ERR_LOAD_CRYPTO_STRINGS:
    openssl.ERR_load_crypto_strings.restype = None
    openssl.ERR_load_crypto_strings.argtypes = []

# RSA_set0_key() exists from OpenSSL 1.1.0 on, where the RSA struct is
# opaque and must not be written through the ctypes mirror above.
_HAS_RSA_SET0_KEY = hasattr(openssl, "RSA_set0_key")
if _HAS_RSA_SET0_KEY:
    openssl.RSA_set0_key.restype = ctypes.c_int
    openssl.RSA_set0_key.argtypes = [ctypes.POINTER(RSA),
                                     ctypes.c_void_p,
                                     ctypes.c_void_p,
                                     ctypes.c_void_p]

clib.fopen.restype = ctypes.c_void_p
clib.fopen.argtypes = [ctypes.c_char_p, ctypes.c_char_p]

clib.fclose.restype = ctypes.c_int
clib.fclose.argtypes = [ctypes.c_void_p]


class CryptException(Exception):
    """Base error raised by this module (e.g. for malformed SSH keys)."""
    pass


class OpenSSLException(CryptException):
    """Error whose message is taken from the OpenSSL error queue."""

    def __init__(self):
        message = self._get_openssl_error_msg()
        super(OpenSSLException, self).__init__(message)

    def _get_openssl_error_msg(self):
        """Pop the earliest queued OpenSSL error and return its text."""
        if _HAS_ERR_LOAD_CRYPTO_STRINGS:
            openssl.ERR_load_crypto_strings()

        err_code = openssl.ERR_get_error()
        errbuf = ctypes.create_string_buffer(1024)
        openssl.ERR_error_string_n(err_code, errbuf, 1024)
        # "replace" keeps a decoding problem from masking the real error.
        return errbuf.value.decode("ascii", "replace")


def _bin2bn(data):
    """Return a newly allocated BIGNUM holding the big-endian bytes *data*.

    :raises OpenSSLException: if OpenSSL cannot allocate the BIGNUM.
    """
    # Passing NULL as the target makes BN_bin2bn allocate a new BIGNUM.
    bn = openssl.BN_bin2bn(data, len(data), None)
    if not bn:
        raise OpenSSLException()
    return bn


def _rsa_set_public_key(rsa_p, n_bin, e_bin):
    """Store modulus *n_bin* and exponent *e_bin* (big-endian) in *rsa_p*.

    On success, ownership of both BIGNUMs passes to *rsa_p*, and RSA_free()
    releases them. On failure they are freed here and an exception is
    raised.
    """
    bn_n = _bin2bn(n_bin)
    try:
        bn_e = _bin2bn(e_bin)
    except Exception:
        openssl.BN_free(bn_n)
        raise

    if _HAS_RSA_SET0_KEY:
        if openssl.RSA_set0_key(rsa_p, bn_n, bn_e, None) != 1:
            openssl.BN_free(bn_n)
            openssl.BN_free(bn_e)
            raise OpenSSLException()
    else:
        # Pre-1.1.0 OpenSSL: the RSA struct is public, so assign directly.
        # RSA_new() zero-initialises n and e, so nothing is leaked here.
        rsa_p.contents.n = bn_n
        rsa_p.contents.e = bn_e


class RSAWrapper(object):
    """Owns an OpenSSL ``RSA *`` handle; usable as a context manager."""

    def __init__(self, rsa_p):
        self._rsa_p = rsa_p

    def __enter__(self):
        return self

    def __exit__(self, tp, value, tb):
        self.free()

    def free(self):
        """Release the RSA handle. Calling this more than once is safe."""
        if self._rsa_p:
            openssl.RSA_free(self._rsa_p)
        # Drop the reference so a later free() (e.g. an explicit call
        # followed by __exit__) cannot double-free the handle.
        self._rsa_p = None

    def public_encrypt(self, clear_text):
        """Encrypt *clear_text* (bytes) with PKCS#1 v1.5 padding.

        :returns: the ciphertext as bytes.
        :raises CryptException: if the key has already been freed.
        :raises OpenSSLException: if OpenSSL rejects the input
            (e.g. the payload is too long for the key size).
        """
        if not self._rsa_p:
            raise CryptException('RSA key has already been freed')

        flen = len(clear_text)
        rsa_size = openssl.RSA_size(self._rsa_p)
        enc_text = ctypes.create_string_buffer(rsa_size)

        enc_text_len = openssl.RSA_public_encrypt(flen,
                                                  clear_text,
                                                  enc_text,
                                                  self._rsa_p,
                                                  openssl.RSA_PKCS1_PADDING)
        if enc_text_len == -1:
            raise OpenSSLException()

        return enc_text[:enc_text_len]


class CryptManager(object):
    """Loads RSA public keys from OpenSSH ``ssh-rsa`` key lines."""

    @staticmethod
    def _read_ssh_field(blob, offset):
        """Read one uint32-big-endian length-prefixed field from *blob*.

        :returns: ``(field_bytes, offset_after_field)``.
        :raises CryptException: if *blob* is too short for the length
            prefix or for the length it declares.
        """
        if offset + 4 > len(blob):
            raise CryptException('Invalid SSH key')
        field_len = struct.unpack('>I', blob[offset:offset + 4])[0]
        offset += 4
        end = offset + field_len
        # Without this check a forged length would make OpenSSL read past
        # the end of the Python buffer in BN_bin2bn.
        if end > len(blob):
            raise CryptException('Invalid SSH key')
        return blob[offset:end], end

    def load_ssh_rsa_public_key(self, ssh_pub_key):
        """Parse an OpenSSH ``ssh-rsa`` public key line into an RSAWrapper.

        :param ssh_pub_key: text of the form
            ``"ssh-rsa <base64 blob> [comment]"``.
        :returns: an RSAWrapper owning the key. The caller must call
            free() on it or use it as a context manager.
        :raises CryptException: if the line or the decoded blob is
            malformed, or the blob declares a non-RSA key type.
        :raises OpenSSLException: if an OpenSSL call fails.
        """
        ssh_rsa_prefix = "ssh-rsa "

        if not ssh_pub_key.startswith(ssh_rsa_prefix):
            raise CryptException('Invalid SSH key')

        s = ssh_pub_key[len(ssh_rsa_prefix):]
        idx = s.find(' ')
        if idx >= 0:
            b64_pub_key = s[:idx]
        else:
            b64_pub_key = s

        try:
            pub_key = base64.b64decode(b64_pub_key)
        except (TypeError, ValueError):
            # binascii.Error (a ValueError subclass) on Python 3,
            # TypeError on Python 2, for malformed base64.
            raise CryptException('Invalid SSH key')

        # The blob holds three length-prefixed fields:
        # key type, public exponent e, modulus n.
        offset = 0
        key_type_bin, offset = self._read_ssh_field(pub_key, offset)
        try:
            key_type = key_type_bin.decode('utf-8')
        except UnicodeDecodeError:
            raise CryptException('Invalid SSH key')

        if key_type not in ['ssh-rsa', 'rsa', 'rsa1']:
            raise CryptException('Unsupported SSH key type "%s". '
                                 'Only RSA keys are currently supported'
                                 % key_type)

        e_key_bin, offset = self._read_ssh_field(pub_key, offset)
        n_key_bin, offset = self._read_ssh_field(pub_key, offset)

        if offset != len(pub_key):
            raise CryptException('Invalid SSH key')

        rsa_p = openssl.RSA_new()
        if not rsa_p:
            raise OpenSSLException()

        try:
            _rsa_set_public_key(rsa_p, n_key_bin, e_key_bin)
        except Exception:
            openssl.RSA_free(rsa_p)
            raise

        return RSAWrapper(rsa_p)