cloudbase-init/cloudbaseinit/utils/windows/x509.py

523 lines
20 KiB
Python

# Copyright 2013 Cloudbase Solutions Srl
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import ctypes
from ctypes import wintypes
import uuid
import six
from cloudbaseinit.utils import encoding
from cloudbaseinit.utils.windows import cryptoapi
from cloudbaseinit.utils import x509constants
malloc = ctypes.cdll.msvcrt.malloc
malloc.restype = ctypes.c_void_p
malloc.argtypes = [ctypes.c_size_t]
free = ctypes.cdll.msvcrt.free
free.restype = None
free.argtypes = [ctypes.c_void_p]
STORE_NAME_MY = "My"
STORE_NAME_ROOT = "Root"
STORE_NAME_TRUSTED_PEOPLE = "TrustedPeople"
X509_START_DATE_INTERVAL = -24 * 60 * 60 * 10000000
X509_END_DATE_INTERVAL = 10 * 365 * 24 * 60 * 60 * 10000000
class CryptoAPICertManager(object):
@staticmethod
def _get_thumprint_str(thumbprint, size):
thumbprint_ar = ctypes.cast(
thumbprint,
ctypes.POINTER(ctypes.c_ubyte *
size)).contents
thumbprint_str = ""
for b in thumbprint_ar:
thumbprint_str += "%02x" % b
return thumbprint_str
@staticmethod
def _get_thumbprint_buffer(thumbprint_str):
thumbprint_bytes = encoding.hex_to_bytes(thumbprint_str)
return ctypes.cast(
ctypes.create_string_buffer(thumbprint_bytes),
ctypes.POINTER(wintypes.BYTE *
len(thumbprint_bytes))).contents
def _get_cert_thumprint(self, cert_context_p):
thumbprint = None
try:
thumprint_len = wintypes.DWORD()
if not cryptoapi.CertGetCertificateContextProperty(
cert_context_p,
cryptoapi.CERT_SHA1_HASH_PROP_ID,
None, ctypes.byref(thumprint_len)):
raise cryptoapi.CryptoAPIException()
size = ctypes.c_size_t(thumprint_len.value)
thumbprint = malloc(size)
if not cryptoapi.CertGetCertificateContextProperty(
cert_context_p,
cryptoapi.CERT_SHA1_HASH_PROP_ID,
thumbprint, ctypes.byref(thumprint_len)):
raise cryptoapi.CryptoAPIException()
return self._get_thumprint_str(thumbprint, thumprint_len.value)
finally:
if thumbprint:
free(thumbprint)
def _generate_key(self, container_name, machine_keyset):
crypt_prov_handle = wintypes.HANDLE()
key_handle = wintypes.HANDLE()
try:
flags = 0
if machine_keyset:
flags |= cryptoapi.CRYPT_MACHINE_KEYSET
if not cryptoapi.CryptAcquireContext(
ctypes.byref(crypt_prov_handle),
container_name,
None,
cryptoapi.PROV_RSA_FULL,
flags):
flags |= cryptoapi.CRYPT_NEWKEYSET
if not cryptoapi.CryptAcquireContext(
ctypes.byref(crypt_prov_handle),
container_name,
None,
cryptoapi.PROV_RSA_FULL,
flags):
raise cryptoapi.CryptoAPIException()
# RSA 2048 bits
if not cryptoapi.CryptGenKey(crypt_prov_handle,
cryptoapi.AT_KEYEXCHANGE,
0x08000000,
ctypes.byref(key_handle)):
raise cryptoapi.CryptoAPIException()
return key_handle
finally:
if key_handle:
cryptoapi.CryptDestroyKey(key_handle)
if crypt_prov_handle:
cryptoapi.CryptReleaseContext(crypt_prov_handle, 0)
@staticmethod
def _add_system_time_interval(system_time, increment):
'''increment's unit: 10ns'''
file_time = cryptoapi.FILETIME()
if not cryptoapi.SystemTimeToFileTime(ctypes.byref(system_time),
ctypes.byref(file_time)):
raise cryptoapi.CryptoAPIException()
t = file_time.dwLowDateTime + (file_time.dwHighDateTime << 32)
t += increment
file_time.dwLowDateTime = t & 0xFFFFFFFF
file_time.dwHighDateTime = t >> 32 & 0xFFFFFFFF
new_system_time = cryptoapi.SYSTEMTIME()
if not cryptoapi.FileTimeToSystemTime(ctypes.byref(file_time),
ctypes.byref(new_system_time)):
raise cryptoapi.CryptoAPIException()
return new_system_time
def create_self_signed_cert(self, subject, validity_years=10,
machine_keyset=True, store_name=STORE_NAME_MY):
subject_encoded = None
cert_context_p = None
store_handle = None
container_name = str(uuid.uuid4())
self._generate_key(container_name, machine_keyset)
try:
subject_encoded_len = wintypes.DWORD()
if not cryptoapi.CertStrToName(cryptoapi.X509_ASN_ENCODING,
subject,
cryptoapi.CERT_X500_NAME_STR, None,
None,
ctypes.byref(subject_encoded_len),
None):
raise cryptoapi.CryptoAPIException()
size = ctypes.c_size_t(subject_encoded_len.value)
subject_encoded = ctypes.cast(malloc(size),
ctypes.POINTER(wintypes.BYTE))
if not cryptoapi.CertStrToName(cryptoapi.X509_ASN_ENCODING,
subject,
cryptoapi.CERT_X500_NAME_STR, None,
subject_encoded,
ctypes.byref(subject_encoded_len),
None):
raise cryptoapi.CryptoAPIException()
subject_blob = cryptoapi.CRYPTOAPI_BLOB()
subject_blob.cbData = subject_encoded_len
subject_blob.pbData = subject_encoded
key_prov_info = cryptoapi.CRYPT_KEY_PROV_INFO()
key_prov_info.pwszContainerName = container_name
key_prov_info.pwszProvName = None
key_prov_info.dwProvType = cryptoapi.PROV_RSA_FULL
key_prov_info.cProvParam = None
key_prov_info.rgProvParam = None
key_prov_info.dwKeySpec = cryptoapi.AT_KEYEXCHANGE
if machine_keyset:
key_prov_info.dwFlags = cryptoapi.CRYPT_MACHINE_KEYSET
else:
key_prov_info.dwFlags = 0
sign_alg = cryptoapi.CRYPT_ALGORITHM_IDENTIFIER()
sign_alg.pszObjId = cryptoapi.szOID_RSA_SHA1RSA
start_time = cryptoapi.SYSTEMTIME()
cryptoapi.GetSystemTime(ctypes.byref(start_time))
end_time = self._add_system_time_interval(
start_time, X509_END_DATE_INTERVAL)
# Needed in case of time sync issues as PowerShell remoting
# enforces a valid time interval even for self signed certificates
start_time = self._add_system_time_interval(
start_time, X509_START_DATE_INTERVAL)
cert_context_p = cryptoapi.CertCreateSelfSignCertificate(
None, ctypes.byref(subject_blob), 0,
ctypes.byref(key_prov_info),
ctypes.byref(sign_alg), ctypes.byref(start_time),
ctypes.byref(end_time), None)
if not cert_context_p:
raise cryptoapi.CryptoAPIException()
if not cryptoapi.CertAddEnhancedKeyUsageIdentifier(
cert_context_p, cryptoapi.szOID_PKIX_KP_SERVER_AUTH):
raise cryptoapi.CryptoAPIException()
if machine_keyset:
flags = cryptoapi.CERT_SYSTEM_STORE_LOCAL_MACHINE
else:
flags = cryptoapi.CERT_SYSTEM_STORE_CURRENT_USER
store_handle = cryptoapi.CertOpenStore(
cryptoapi.CERT_STORE_PROV_SYSTEM, 0, 0, flags,
six.text_type(store_name))
if not store_handle:
raise cryptoapi.CryptoAPIException()
if not cryptoapi.CertAddCertificateContextToStore(
store_handle, cert_context_p,
cryptoapi.CERT_STORE_ADD_REPLACE_EXISTING, None):
raise cryptoapi.CryptoAPIException()
return self._get_cert_thumprint(cert_context_p)
finally:
if store_handle:
cryptoapi.CertCloseStore(store_handle, 0)
if cert_context_p:
cryptoapi.CertFreeCertificateContext(cert_context_p)
if subject_encoded:
free(subject_encoded)
def _get_cert_base64(self, cert_data):
"""Remove certificate header and footer and also new lines."""
# It's assured that the certificate is already a string.
removal = [
x509constants.PEM_HEADER,
x509constants.PEM_FOOTER,
"\r",
"\n"
]
for remove in removal:
cert_data = cert_data.replace(remove, "")
return cert_data
def _find_certificate_in_store(self, thumbprint_str, machine_keyset=True,
store_name=STORE_NAME_MY):
store_handle = None
thumbprint = self._get_thumbprint_buffer(thumbprint_str)
hash_blob = cryptoapi.CRYPTOAPI_BLOB()
hash_blob.cbData = len(thumbprint)
hash_blob.pbData = thumbprint
try:
flags = cryptoapi.CERT_STORE_OPEN_EXISTING_FLAG
if machine_keyset:
flags |= cryptoapi.CERT_SYSTEM_STORE_LOCAL_MACHINE
else:
flags |= cryptoapi.CERT_SYSTEM_STORE_CURRENT_USER
store_handle = cryptoapi.CertOpenStore(
cryptoapi.CERT_STORE_PROV_SYSTEM, 0, 0, flags,
six.text_type(store_name))
if not store_handle:
raise cryptoapi.CryptoAPIException()
cert_context_p = cryptoapi.CertFindCertificateInStore(
store_handle,
cryptoapi.X509_ASN_ENCODING | cryptoapi.PKCS_7_ASN_ENCODING,
0,
cryptoapi.CERT_FIND_SHA1_HASH,
ctypes.pointer(hash_blob),
None)
if not cert_context_p:
raise cryptoapi.CryptoAPIException()
return cert_context_p
finally:
if store_handle:
cryptoapi.CertCloseStore(store_handle, 0)
def delete_certificate_from_store(self, thumbprint_str,
machine_keyset=True,
store_name=STORE_NAME_MY):
cert_context_p = None
try:
cert_context_p = self._find_certificate_in_store(
thumbprint_str, machine_keyset, store_name)
if not cert_context_p:
raise cryptoapi.CryptoAPIException()
if not cryptoapi.CertDeleteCertificateFromStore(cert_context_p):
raise cryptoapi.CryptoAPIException()
finally:
if cert_context_p:
cryptoapi.CertFreeCertificateContext(cert_context_p)
def import_pfx_certificate(self, pfx_data, pfx_password=None,
machine_keyset=True, store_name=STORE_NAME_MY):
cert_context_p = None
import_store_handle = None
store_handle = None
try:
pfx_blob = cryptoapi.CRYPTOAPI_BLOB()
pfx_blob.cbData = len(pfx_data)
pfx_blob.pbData = ctypes.cast(
pfx_data, ctypes.POINTER(wintypes.BYTE))
import_store_handle = cryptoapi.PFXImportCertStore(
ctypes.pointer(pfx_blob), pfx_password, 0)
if not import_store_handle:
raise cryptoapi.CryptoAPIException()
cert_context_p = cryptoapi.CertFindCertificateInStore(
import_store_handle,
cryptoapi.X509_ASN_ENCODING | cryptoapi.PKCS_7_ASN_ENCODING,
0, cryptoapi.CERT_FIND_ANY, None, None)
if not cert_context_p:
raise cryptoapi.CryptoAPIException()
if machine_keyset:
flags = cryptoapi.CERT_SYSTEM_STORE_LOCAL_MACHINE
else:
flags = cryptoapi.CERT_SYSTEM_STORE_CURRENT_USER
store_handle = cryptoapi.CertOpenStore(
cryptoapi.CERT_STORE_PROV_SYSTEM, 0, 0, flags,
six.text_type(store_name))
if not store_handle:
raise cryptoapi.CryptoAPIException()
if not cryptoapi.CertAddCertificateContextToStore(
store_handle, cert_context_p,
cryptoapi.CERT_STORE_ADD_REPLACE_EXISTING, None):
raise cryptoapi.CryptoAPIException()
finally:
if import_store_handle:
cryptoapi.CertCloseStore(import_store_handle, 0)
if cert_context_p:
cryptoapi.CertFreeCertificateContext(cert_context_p)
if store_handle:
cryptoapi.CertCloseStore(store_handle, 0)
def decode_pkcs7_base64_blob(self, data, thumbprint_str,
machine_keyset=True,
store_name=STORE_NAME_MY):
base64_data = data.replace('\r', '').replace('\n', '')
store_handle = None
cert_context_p = None
try:
data_encoded_len = wintypes.DWORD()
if not cryptoapi.CryptStringToBinaryW(
base64_data, len(base64_data),
cryptoapi.CRYPT_STRING_BASE64,
None, ctypes.byref(data_encoded_len),
None, None):
raise cryptoapi.CryptoAPIException()
data_encoded = ctypes.cast(
ctypes.create_string_buffer(data_encoded_len.value),
ctypes.POINTER(wintypes.BYTE))
if not cryptoapi.CryptStringToBinaryW(
base64_data, len(base64_data),
cryptoapi.CRYPT_STRING_BASE64,
data_encoded, ctypes.byref(data_encoded_len),
None, None):
raise cryptoapi.CryptoAPIException()
store_handle = cryptoapi.CertOpenStore(
cryptoapi.CERT_STORE_PROV_MEMORY,
cryptoapi.X509_ASN_ENCODING | cryptoapi.PKCS_7_ASN_ENCODING,
None, cryptoapi.CERT_STORE_CREATE_NEW_FLAG, None)
if not store_handle:
raise cryptoapi.CryptoAPIException()
cert_context_p = self._find_certificate_in_store(
thumbprint_str, machine_keyset, store_name)
if not cryptoapi.CertAddCertificateLinkToStore(
store_handle, cert_context_p,
cryptoapi.CERT_STORE_ADD_NEW, None):
raise cryptoapi.CryptoAPIException()
para = cryptoapi.CRYPT_DECRYPT_MESSAGE_PARA()
para.cbSize = ctypes.sizeof(cryptoapi.CRYPT_DECRYPT_MESSAGE_PARA)
para.dwMsgAndCertEncodingType = (cryptoapi.X509_ASN_ENCODING |
cryptoapi.PKCS_7_ASN_ENCODING)
para.cCertStore = 1
para.rghCertStore = ctypes.pointer(wintypes.HANDLE(store_handle))
para.dwFlags = cryptoapi.CRYPT_SILENT
data_decoded_len = wintypes.DWORD()
if not cryptoapi.CryptDecryptMessage(
ctypes.byref(para),
data_encoded,
data_encoded_len,
None,
ctypes.byref(data_decoded_len),
None):
raise cryptoapi.CryptoAPIException()
data_decoded_buf = ctypes.create_string_buffer(
data_decoded_len.value)
data_decoded = ctypes.cast(
data_decoded_buf, ctypes.POINTER(wintypes.BYTE))
if not cryptoapi.CryptDecryptMessage(
ctypes.pointer(para),
data_encoded,
data_encoded_len,
data_decoded,
ctypes.byref(data_decoded_len),
None):
raise cryptoapi.CryptoAPIException()
return bytes(data_decoded_buf)
finally:
if cert_context_p:
cryptoapi.CertFreeCertificateContext(cert_context_p)
if store_handle:
cryptoapi.CertCloseStore(store_handle, 0)
def import_cert(self, cert_data, machine_keyset=True,
store_name=STORE_NAME_MY):
base64_cert_data = self._get_cert_base64(cert_data)
cert_encoded = None
store_handle = None
cert_context_p = None
try:
cert_encoded_len = wintypes.DWORD()
if not cryptoapi.CryptStringToBinaryW(
base64_cert_data, len(base64_cert_data),
cryptoapi.CRYPT_STRING_BASE64,
None, ctypes.byref(cert_encoded_len),
None, None):
raise cryptoapi.CryptoAPIException()
size = ctypes.c_size_t(cert_encoded_len.value)
cert_encoded = ctypes.cast(malloc(size),
ctypes.POINTER(wintypes.BYTE))
if not cryptoapi.CryptStringToBinaryW(
base64_cert_data, len(base64_cert_data),
cryptoapi.CRYPT_STRING_BASE64,
cert_encoded, ctypes.byref(cert_encoded_len),
None, None):
raise cryptoapi.CryptoAPIException()
if machine_keyset:
flags = cryptoapi.CERT_SYSTEM_STORE_LOCAL_MACHINE
else:
flags = cryptoapi.CERT_SYSTEM_STORE_CURRENT_USER
store_handle = cryptoapi.CertOpenStore(
cryptoapi.CERT_STORE_PROV_SYSTEM, 0, 0, flags,
six.text_type(store_name))
if not store_handle:
raise cryptoapi.CryptoAPIException()
cert_context_p = ctypes.POINTER(cryptoapi.CERT_CONTEXT)()
if not cryptoapi.CertAddEncodedCertificateToStore(
store_handle,
cryptoapi.X509_ASN_ENCODING |
cryptoapi.PKCS_7_ASN_ENCODING,
cert_encoded, cert_encoded_len,
cryptoapi.CERT_STORE_ADD_REPLACE_EXISTING,
ctypes.byref(cert_context_p)):
raise cryptoapi.CryptoAPIException()
# Get the UPN (1.3.6.1.4.1.311.20.2.3 OID) from the
# certificate subject alt name
upn = None
upn_len = cryptoapi.CertGetNameString(
cert_context_p,
cryptoapi.CERT_NAME_UPN_TYPE, 0,
None, None, 0)
if upn_len > 1:
upn_ar = ctypes.create_unicode_buffer(upn_len)
if cryptoapi.CertGetNameString(
cert_context_p,
cryptoapi.CERT_NAME_UPN_TYPE,
0, None, upn_ar, upn_len) != upn_len:
raise cryptoapi.CryptoAPIException()
upn = upn_ar.value
thumbprint = self._get_cert_thumprint(cert_context_p)
return thumbprint, upn
finally:
if cert_context_p:
cryptoapi.CertFreeCertificateContext(cert_context_p)
if store_handle:
cryptoapi.CertCloseStore(store_handle, 0)
if cert_encoded:
free(cert_encoded)