874 lines
31 KiB
Python
874 lines
31 KiB
Python
# Script which deterministically generates certificates given a definitions file.
|
|
import argparse
|
|
import datetime
|
|
import hashlib
|
|
import ipaddress
|
|
import json
|
|
from pathlib import PurePath
|
|
from typing import Any
|
|
|
|
import asn1crypto.core as asn1
|
|
import cryptography.hazmat.primitives.serialization.pkcs12 as pkcs12
|
|
from cryptography import x509
|
|
from cryptography.hazmat._oid import _OID_NAMES
|
|
from cryptography.hazmat.primitives import hashes, serialization
|
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
from cryptography.x509.extensions import UnrecognizedExtension
|
|
from cryptography.x509.oid import ObjectIdentifier
|
|
from ecdsa import SigningKey
|
|
import os
|
|
import sys
|
|
|
|
# Reverse lookup from cryptography's attribute display names (e.g. "commonName")
# to their ObjectIdentifier objects.
NAME_TO_OID = {oid_name: oid for oid, oid_name in _OID_NAMES.items()}

# Short attribute names that we explicitly support, aliased to the OID of the
# corresponding long name.
NAME_TO_OID.update(
    {
        short_name: NAME_TO_OID[full_name]
        for short_name, full_name in (
            ("C", "countryName"),
            ("ST", "stateOrProvinceName"),
            ("O", "organizationName"),
            ("OU", "organizationalUnitName"),
            ("L", "localityName"),
            ("SN", "surname"),
            ("CN", "commonName"),
        )
    }
)

# The (partial) ordering of OIDs in subject name expected by our jstests.
OID_ORDER = [NAME_TO_OID[short_name].dotted_string for short_name in ("C", "ST", "L", "O", "OU", "CN")]
|
|
|
|
# Path to the file specifying the config.
CONFIGFILE = None

# Certificate definitions parsed from the JSON config file by
# setup_global_state(). Starts out as an empty dict (it previously held the
# type object ``dict[str, Any]`` by mistake instead of being annotated).
CONFIG: dict[str, Any] = {}

# <= 825 in order to abide by https://support.apple.com/en-us/HT210176.
MAX_VALIDITY_PERIOD_DAYS = 824

# Datetime to specify as the start time for all certs.
# NOTE: this is January 1st of the *current* year, so regenerated outputs are
# only byte-identical within the same calendar year.
DEFAULT_START_TIME = datetime.datetime(datetime.datetime.now().year, 1, 1)

# Allocate serial numbers sequentially; this is the last-used serial.
LAST_SERIAL_NUMBER = 999

# Cache the private key objects for static/*key.pem, keyed by keyfile name.
LOADED_KEYS = {}

# Base directory where outputs go.
OUTPUT_PATH = None

# Base path to keys that are used during generation.
STATIC_PATH = None

# Cache of (certificate, private key) pairs keyed by certificate name or path.
LOADED_CERT_AND_KEYS = {}

# True if this is a dry run.
DRY_RUN = False
|
|
|
|
|
|
class CertificateGenerationError(Exception):
    """Raised when a certificate definition cannot be turned into output files."""
|
|
|
|
|
|
def get_next_serial():
    """Return a fresh serial number and advance the module-level counter.

    Serial numbers 0..999 are reserved for fixed serial numbers, so generated
    serials start at 1000 and increase by one per generated certificate.
    """
    global LAST_SERIAL_NUMBER
    next_serial = LAST_SERIAL_NUMBER + 1
    LAST_SERIAL_NUMBER = next_serial
    return next_serial
|
|
|
|
|
|
def get_key(cert):
    """Return the private key object for ``cert``'s keyfile, loading it on first use.

    The keyfile is resolved through ``idx`` (so it may come from the global
    section) and read from ``STATIC_PATH``. Loaded keys are cached in
    ``LOADED_KEYS`` by keyfile name, so the passphrase of the first definition
    that loads a given keyfile is the one used. In dry-run mode the string
    "dummy" is cached instead of a real key.

    :raises CertificateGenerationError: if no keyfile is configured.
    """
    keyfile = idx(cert, "keyfile")
    if keyfile is None:
        raise CertificateGenerationError("All certificates require a keyfile")

    if keyfile not in LOADED_KEYS:
        if DRY_RUN:
            LOADED_KEYS[keyfile] = "dummy"
        else:
            passphrase = cert.get("passphrase")
            if passphrase is not None:
                # UTF-8, matching how get_cert_and_key() encodes the same
                # "passphrase" field (ASCII would reject non-ASCII passphrases).
                passphrase = passphrase.encode("utf-8")
            with open(str(STATIC_PATH / keyfile), "rb") as f:
                LOADED_KEYS[keyfile] = serialization.load_pem_private_key(
                    f.read(),
                    password=passphrase,
                )
    return LOADED_KEYS[keyfile]
|
|
|
|
|
|
def glbl(key, default=None):
    """Look up ``key`` in the config's optional "global" section, else return ``default``."""
    global_section = CONFIG.get("global", {})
    return global_section.get(key, default)
|
|
|
|
|
|
def idx(cert, key, default=None):
    """Return ``cert[key]``, falling back to the global section, then to ``default``.

    Only a missing or ``None`` value falls through to the global section.
    Explicit falsy values such as ``0`` (e.g. ``"not_before": 0``) or ``""``
    are honoured instead of being silently replaced by the global setting.
    """
    value = cert.get(key)
    if value is not None:
        return value
    return glbl(key, default)
|
|
|
|
|
|
def make_filename(cert):
    """Return the output path, as a string, for a certificate definition's "name"."""
    output_file = OUTPUT_PATH / cert["name"]
    return str(output_file)
|
|
|
|
|
|
def find_certificate_definition(name):
    """Return the first definition in ``CONFIG["certs"]`` named ``name``, or None."""
    matches = (definition for definition in CONFIG["certs"] if definition["name"] == name)
    return next(matches, None)
|
|
|
|
|
|
def _load_pem_cert_and_key(path, passphrase=None):
    """Read ``path`` and parse both a certificate and a private key from its PEM data.

    The file is expected to hold the certificate PEM and the key PEM together,
    which is how process_normal_cert() writes its outputs.
    """
    with open(path, "rb") as f:
        pem = f.read()
    certificate = x509.load_pem_x509_certificate(pem)
    key = serialization.load_pem_private_key(pem, password=passphrase)
    return certificate, key


def get_cert_and_key(cert_name):
    """Locate the cert and key file for a given cert name, load them, and return them.

    ``cert_name`` is first looked up among the config's definitions, in which
    case the generated output file is read and the definition's passphrase (if
    any) decrypts the key. Otherwise ``cert_name`` is treated as a path to an
    externally sourced PEM file with an unencrypted key. Results are cached in
    ``LOADED_CERT_AND_KEYS``. In dry-run mode ("dummy", "dummy") is returned.

    :returns: (x509.Certificate, private key) tuple.
    """
    if DRY_RUN:
        return "dummy", "dummy"

    if cert_name in LOADED_CERT_AND_KEYS:  # Cache hit, don't need to load again
        return LOADED_CERT_AND_KEYS[cert_name]

    ca_cert = find_certificate_definition(cert_name)
    if ca_cert:
        passphrase = ca_cert.get("passphrase", None)
        # An empty passphrase means "no passphrase"; the loader needs bytes or None.
        password = passphrase.encode("utf-8") if passphrase else None
        loaded = _load_pem_cert_and_key(make_filename(ca_cert), password)
    else:
        # Externally sourced certificate, try by path. Hopefully unencrypted.
        loaded = _load_pem_cert_and_key(cert_name)

    LOADED_CERT_AND_KEYS[cert_name] = loaded
    return loaded
|
|
|
|
|
|
def get_validity_period(cert):
    """Return the (not_valid_before, not_valid_after) datetimes for ``cert``.

    "not_before" and "not_after" are offsets in seconds from
    DEFAULT_START_TIME, looked up through ``idx``. When "not_after" is absent
    the certificate stays valid for MAX_VALIDITY_PERIOD_DAYS after its start.
    """
    default_lifetime_secs = MAX_VALIDITY_PERIOD_DAYS * 24 * 60 * 60
    begin_offset = int(idx(cert, "not_before", 0))
    finish_offset = int(idx(cert, "not_after", begin_offset + default_lifetime_secs))
    return (
        DEFAULT_START_TIME + datetime.timedelta(seconds=begin_offset),
        DEFAULT_START_TIME + datetime.timedelta(seconds=finish_offset),
    )
|
|
|
|
|
|
def get_oid(cn_or_oid):
    """Given a string containing an OID or a common name, return the corresponding OID object.

    Names (including the short aliases such as "CN") are resolved through
    NAME_TO_OID; anything else is parsed as a dotted OID string.

    :raises CertificateGenerationError: if the string is neither a known name
        nor a parseable OID.
    """
    if cn_or_oid in NAME_TO_OID:
        return NAME_TO_OID[cn_or_oid]
    try:
        return ObjectIdentifier(cn_or_oid)
    except (ValueError, TypeError) as e:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit pass
        # through; the parse error is chained for debugging.
        raise CertificateGenerationError(f"Name attribute {cn_or_oid} not recognized") from e
|
|
|
|
|
|
def _subject_attr_sort_key(item):
    """Sort key for (oid, value) pairs: OID_ORDER attributes first, in that order, then the rest by dotted OID."""
    dotted = item[0].dotted_string
    if dotted in OID_ORDER:
        return (0, OID_ORDER.index(dotted))
    return (1, dotted)


def _name_from_dict(subject, explicit):
    """Build a Name with one attribute per RDN from a ``{attribute: value}`` mapping.

    Unless ``explicit`` is set, the global "Subject" attributes are loaded first
    and the certificate's own attributes override them.
    """
    attr_dict = {}
    if not explicit:
        # Load the globally defined subject RDNs
        for key, val in glbl("Subject", {}).items():
            attr_dict[get_oid(key)] = val
    # Load the subject RDNs defined by the certificate over the globally defined ones
    for key, val in subject.items():
        attr_dict[get_oid(key)] = val
    ordered_attrs = sorted(attr_dict.items(), key=_subject_attr_sort_key)
    return x509.Name([x509.NameAttribute(oid, val) for oid, val in ordered_attrs])


def _name_from_rdn_list(cert, subject, explicit):
    """Build a Name from a list of ``{attribute: value}`` dicts, one (possibly multivalued) RDN per dict."""
    if not isinstance(subject, list):
        raise CertificateGenerationError(
            cert["name"] + ": Subject must be a dict or a list of dicts"
        )
    if not explicit:
        raise CertificateGenerationError(
            "explicit_subject must be set to true when using multivalued RDNs"
        )
    rdns = []
    for rdn_def in subject:
        ordered_attrs = sorted(
            ((get_oid(key), val) for key, val in rdn_def.items()),
            key=_subject_attr_sort_key,
        )
        rdns.append(
            x509.RelativeDistinguishedName(
                [x509.NameAttribute(oid, val) for oid, val in ordered_attrs]
            )
        )
    return x509.Name(rdns)


def set_subject(builder, cert, set_issuer=False):
    """Set the subject on the certificate builder according to the certificate definition.

    Also set the issuer to the same name if ``set_issuer`` is true (self-signed).
    A dict "Subject" yields one attribute per RDN (merged with the global
    Subject unless "explicit_subject" is set). A list "Subject" yields
    multivalued RDNs and requires "explicit_subject". An empty Subject is only
    allowed together with "explicit_subject", producing an empty Name.

    :returns: the updated builder.
    :raises CertificateGenerationError: for a missing or malformed Subject.
    """
    subject = cert.get("Subject")
    explicit = cert.get("explicit_subject", False)

    if not subject:
        if not explicit:
            raise CertificateGenerationError(cert["name"] + " requires a Subject")
        # An empty subject was explicitly requested.
        name = x509.Name([])
    elif isinstance(subject, dict):
        name = _name_from_dict(subject, explicit)
    else:
        name = _name_from_rdn_list(cert, subject, explicit)

    if set_issuer:  # When issuer = self, set the issuer as well
        builder = builder.issuer_name(name)
    return builder.subject_name(name)
|
|
|
|
|
|
def set_validity(builder, cert):
    """Apply the definition's validity window to ``builder`` and return the resulting builder."""
    not_before, not_after = get_validity_period(cert)
    return builder.not_valid_before(not_before).not_valid_after(not_after)
|
|
|
|
|
|
def to_der_varint(val):
    """Encode a non-negative int as a DER definite-length field.

    Values below 0x80 use the single-byte short form. Larger values use the
    long form: one byte ``0x80 | n`` followed by the value in ``n`` big-endian
    bytes. At most 8 length bytes are supported.

    :returns: the encoding as ``bytes`` (in both forms).
    :raises CertificateGenerationError: for negative values or values needing
        more than 8 bytes.
    """
    if val < 0:
        raise CertificateGenerationError("Negative values not permitted in DER payload")

    if val < 0x80:
        return bytes([val])

    length_bytes = val.to_bytes((val.bit_length() + 7) // 8, byteorder="big")
    if len(length_bytes) > 8:
        raise CertificateGenerationError("Length is too large to represent in 64bits")

    return bytes([0x80 | len(length_bytes)]) + length_bytes
|
|
|
|
|
|
def to_der_utf8_string(val):
    """Return ``str(val)`` DER-encoded as an ASN.1 UTF8String (tag 0x0C)."""
    payload = str(val).encode("utf-8")
    header = b"\x0c" + to_der_varint(len(payload))
    return header + payload
|
|
|
|
|
|
def to_der_sequence_pair(name, value):
    """DER-encode ``name`` and ``value`` as two UTF8Strings inside a SEQUENCE (tag 0x30)."""
    body = to_der_utf8_string(name) + to_der_utf8_string(value)
    return b"\x30" + to_der_varint(len(body)) + body
|
|
|
|
|
|
class ExtensionParser:
    """Collection of methods to convert extension definitions into cryptography extension objects.

    Each parser receives the extension's config value ``v`` plus keyword
    context from set_extensions (``public_key``, ``issuer_public_key``,
    ``issuer_ski``). Parsers that do not need that context swallow it via
    ``**_``. ``parsers`` maps config extension names to these functions.
    """

    @staticmethod
    def basic_constraints(v, **_):
        """basicConstraints from a dict with optional "CA" and "pathlen" keys."""
        return x509.BasicConstraints(ca=v.get("CA", False), path_length=v.get("pathlen"))

    @staticmethod
    def key_usage(v, **_):
        """keyUsage from a list of OpenSSL-style usage names.

        Entries not in the table below (e.g. the "critical" marker that
        set_extensions looks for) are ignored.
        """
        to_param_name = {
            "digitalSignature": "digital_signature",
            "nonRepudiation": "content_commitment",
            "keyEncipherment": "key_encipherment",
            "dataEncipherment": "data_encipherment",
            "keyAgreement": "key_agreement",
            "keyCertSign": "key_cert_sign",
            "cRLSign": "crl_sign",
            "encipherOnly": "encipher_only",
            "decipherOnly": "decipher_only",
        }
        params = {name: False for name in to_param_name.values()}
        for usage in v:
            if usage in to_param_name:
                params[to_param_name[usage]] = True
        return x509.KeyUsage(**params)

    @staticmethod
    def ext_usage_name_to_oid(name):
        """Map an extended key usage name to its OID under 1.3.6.1.5.5.7.3."""
        ext_usage_name_map = {
            "serverAuth": 1,
            "clientAuth": 2,
            "codeSigning": 3,
            "emailProtection": 4,
            "timeStamping": 8,
            "OCSPSigning": 9,
        }
        if name not in ext_usage_name_map:
            raise CertificateGenerationError(f'Unknown extended key usage identifier: "{name}"')
        return ObjectIdentifier("1.3.6.1.5.5.7.3." + str(ext_usage_name_map[name]))

    @staticmethod
    def extended_key_usage(v, **_):
        """extendedKeyUsage from a list of usage names.

        The "critical" marker is skipped: set_extensions reads it from this same
        list to decide criticality, and it is not a usage name (mapping it
        would raise).
        """
        return x509.ExtendedKeyUsage(
            [ExtensionParser.ext_usage_name_to_oid(name) for name in v if name != "critical"]
        )

    @staticmethod
    def subject_alt_name(v, **_):
        """subjectAltName from a dict with "DNS" and/or "IP" entries (single value or list)."""
        names = []
        for key, val in v.items():
            if key == "critical":
                continue
            elif key == "DNS":
                if not isinstance(val, list):
                    val = [val]
                for name in val:
                    names.append(x509.DNSName(name))
            elif key == "IP":
                if not isinstance(val, list):
                    val = [val]
                for ip in val:
                    names.append(x509.IPAddress(ipaddress.ip_address(ip)))
            else:
                raise CertificateGenerationError(f'Unknown subject alt name type: "{key}"')
        return x509.SubjectAlternativeName(names)

    @staticmethod
    def subject_key_identifier(v, public_key, **_):
        """subjectKeyIdentifier derived from ``public_key``; only the "hash" form is supported."""
        assert v == "hash"
        return x509.SubjectKeyIdentifier.from_public_key(public_key)

    @staticmethod
    def mongo_roles(v, **_):
        """MongoDB roles extension: a DER SET of (role, db) UTF8String SEQUENCE pairs."""
        oid = ObjectIdentifier("1.3.6.1.4.1.34601.2.1.1")
        pair = b""
        for role in v:
            if (len(role) != 2) or ("role" not in role) or ("db" not in role):
                raise CertificateGenerationError(
                    "mongoRoles must consist of a series of role/db pairs"
                )
            pair = pair + to_der_sequence_pair(role["role"], role["db"])

        # 0x31 is the DER tag for SET.
        val = b"\x31" + to_der_varint(len(pair)) + pair

        return UnrecognizedExtension(oid, val)

    @staticmethod
    def authority_key_identifier(v, issuer_public_key, issuer_ski, **_):
        """authorityKeyIdentifier from the issuer's public key ("issuer") or its SKI ("keyid")."""
        if v not in ["keyid", "issuer"]:
            raise CertificateGenerationError(
                "Only the 'keyid' or 'issuer' values are accepted for authorityKeyIdentifier"
            )

        if v == "issuer":
            return x509.AuthorityKeyIdentifier.from_issuer_public_key(issuer_public_key)
        else:
            return x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(issuer_ski)

    @staticmethod
    def mongo_cluster_membership(v, **_):
        """Encode a symbolic name to a mongodbClusterMembership extension."""
        oid = ObjectIdentifier("1.3.6.1.4.1.34601.2.1.2")
        val = to_der_utf8_string(v)
        return UnrecognizedExtension(oid, val)

    @staticmethod
    def authority_information_access(v, **_):
        """authorityInfoAccess from one or more {"method": "OCSP", "location": url} entries."""
        if not isinstance(v, list):
            v = [v]
        assert all(entry["method"] == "OCSP" for entry in v)
        return x509.AuthorityInformationAccess(
            [
                x509.AccessDescription(
                    x509.oid.AuthorityInformationAccessOID.OCSP,
                    x509.UniformResourceIdentifier(entry["location"]),
                )
                for entry in v
            ]
        )

    @staticmethod
    def must_staple(v, **_):
        """TLS Feature extension (OID 1.3.6.1.5.5.7.1.24) holding SEQUENCE { INTEGER 5 }."""
        assert v, "If set, mustStaple must be true"
        oid = ObjectIdentifier("1.3.6.1.5.5.7.1.24")
        val = b"\x30\x03\x02\x01\x05"
        return UnrecognizedExtension(oid, val)

    @staticmethod
    def ns_comment(v, **_):
        """Netscape comment extension: ``v`` encoded as an ASN.1 IA5String (tag 0x16)."""
        oid = ObjectIdentifier("2.16.840.1.113730.1.13")
        comment = bytes(v, "ascii")
        # Encode the actual length; the previous hard-coded 0x1d byte was only
        # correct for 29-character comments.
        val = b"\x16" + to_der_varint(len(comment)) + comment
        return UnrecognizedExtension(oid, val)

    # Store the underlying functions: raw staticmethod objects are not callable
    # before Python 3.10, and set_extensions calls these entries directly.
    parsers = {
        "basicConstraints": basic_constraints.__func__,
        "keyUsage": key_usage.__func__,
        "extendedKeyUsage": extended_key_usage.__func__,
        "subjectAltName": subject_alt_name.__func__,
        "subjectKeyIdentifier": subject_key_identifier.__func__,
        "mongoRoles": mongo_roles.__func__,
        "authorityKeyIdentifier": authority_key_identifier.__func__,
        "mongoClusterMembership": mongo_cluster_membership.__func__,
        "authorityInfoAccess": authority_information_access.__func__,
        "mustStaple": must_staple.__func__,
        "nsComment": ns_comment.__func__,
    }
|
|
|
|
|
|
def set_extensions(builder, cert, **kwargs):
    """Add every X.509 extension listed in the definition's "extensions" to ``builder``.

    Criticality is taken from a "critical" entry in list-valued definitions or
    a "critical" key in dict-valued ones. String and bool definitions are never
    critical. ``kwargs`` are forwarded to each ExtensionParser handler.

    :returns: the updated builder.
    :raises CertificateGenerationError: for unknown extensions or unparseable values.
    """

    def is_critical(ext_name, ext_def):
        if isinstance(ext_def, list):
            return "critical" in ext_def
        if isinstance(ext_def, dict):
            return ext_def.get("critical", False)
        if isinstance(ext_def, (str, bool)):
            return False
        raise CertificateGenerationError(f"Could not parse extension: {ext_name} -> {ext_def}")

    for ext_name, ext_def in cert.get("extensions", {}).items():
        parser = ExtensionParser.parsers.get(ext_name)
        if parser is None:
            raise CertificateGenerationError(f'Extension "{ext_name}" is not handled yet')
        extension = parser(ext_def, **kwargs)
        builder = builder.add_extension(extension, critical=is_critical(ext_name, ext_def))
    return builder
|
|
|
|
|
|
def get_issuer_cert_and_key(cert, key):
    """Get the issuer certificate object (or 'self') and key for the given certificate definition.

    For a self-signed definition (Issuer == "self") the string "self" is
    returned in place of a certificate, paired with the certificate's own
    ``key``. Otherwise the issuer is resolved by get_cert_and_key, which
    accepts either another definition's name or a path to a PEM file.

    :raises CertificateGenerationError: if the definition has no Issuer.
    """
    issuer = cert.get("Issuer")
    if issuer == "self":
        return "self", key

    if not issuer:
        # Without this check get_cert_and_key would fail with an opaque
        # TypeError from open(None).
        raise CertificateGenerationError(f"{cert.get('name', '<unnamed>')} requires an Issuer")

    # Signed by a CA, find the key...
    return get_cert_and_key(issuer)
|
|
|
|
|
|
class SignedCertificateSequence(asn1.Sequence):
    """Python representation of the ASN1 structure of a signed certificate.

    Mirrors the outer X.509 Certificate SEQUENCE (RFC 5280 section 4.1). It is
    used by sign_ecdsa_deterministic to replace the signature while keeping
    the signed content byte-for-byte.
    """

    _fields = [
        # tbsCertificate: the certificate body covered by the signature.
        ("cert_content", asn1.Sequence),
        # signatureAlgorithm: AlgorithmIdentifier of the signature.
        ("algo_type", asn1.Sequence),
        # signatureValue: the encoded signature, as a BIT STRING.
        ("signature", asn1.BitString),
    ]
|
|
|
|
|
|
def to_bits(bytestr):
    """Expand a byte string into a tuple of its bits, most significant bit of each byte first."""
    return tuple((octet >> shift) & 1 for octet in bytestr for shift in range(7, -1, -1))
|
|
|
|
|
|
def _encode_der_integer(magnitude):
    """DER-encode big-endian unsigned bytes as a minimal ASN.1 INTEGER.

    Leading zero bytes are stripped, because DER requires the shortest
    encoding. A single 0x00 is prepended when the top bit is set, so the value
    stays positive.
    """
    digits = magnitude.lstrip(b"\x00") or b"\x00"
    if digits[0] & 0x80:
        digits = b"\x00" + digits
    return b"\x02" + to_der_varint(len(digits)) + digits


def sign_ecdsa_deterministic(key, cert):
    """Re-sign a signed certificate with the given ECDSA key in a deterministic fashion.

    The certificate's to-be-signed content is signed with SHA-256 via
    python-ecdsa's ``sign_deterministic`` (RFC 6979). The original signature
    is replaced by the DER ECDSA-Sig-Value SEQUENCE { r INTEGER, s INTEGER }.

    :returns: the newly signed x509.Certificate.
    """
    ecdsa_pkey = SigningKey.from_pem(
        key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )
    )
    # Get bytes of our signed certificate as DER and load them.
    all_bytes = cert.public_bytes(encoding=serialization.Encoding.DER)
    seq = SignedCertificateSequence.load(all_bytes)
    # Get just the certificate content and sign it.
    cert_bytes = seq["cert_content"].dump()
    # Default encoding is raw r || s, each zero-padded to the same fixed width.
    sig = ecdsa_pkey.sign_deterministic(cert_bytes, hashfunc=hashlib.sha256)
    if len(sig) % 2:
        raise CertificateGenerationError("Unexpected raw ECDSA signature length: " + str(len(sig)))
    half = len(sig) // 2
    # Encode each half as a minimal INTEGER. Fixed-width halves with leading
    # zero bytes would otherwise yield invalid (non-minimal) DER.
    body = _encode_der_integer(sig[:half]) + _encode_der_integer(sig[half:])
    der_sig = b"\x30" + to_der_varint(len(body)) + body
    # Set this as the signature, then dump the new certificate.
    seq["signature"] = to_bits(der_sig)
    signed_bytes = seq.dump()
    # Load the new certificate.
    return x509.load_der_x509_certificate(signed_bytes)
|
|
|
|
|
|
def write_cert_as_pkcs12(cert, key, cert_obj, issuer_obj):
    """Makes a new copy of the cert/key pair using PKCS#12 encoding.

    Options come from the definition's "pkcs12" dict: "passphrase" (required)
    and "name" (output file name under OUTPUT_PATH, defaulting to the
    certificate's own name). ``issuer_obj`` is included as a CA certificate
    unless it is None or the string "self" (self-signed). Nothing is written
    in dry-run mode, but the passphrase requirement is still checked.

    :raises CertificateGenerationError: if no passphrase is configured.
    """
    pkcs12_opts = cert.get("pkcs12")
    if not pkcs12_opts.get("passphrase"):
        raise CertificateGenerationError("PKCS#12 requires a passphrase")

    fname = pkcs12_opts.get("name", cert["name"])
    if DRY_RUN:
        return

    # Self-signed callers pass the sentinel "self"; it is not a certificate, so
    # it must not be placed in the CA list.
    cas = None if issuer_obj is None or issuer_obj == "self" else [issuer_obj]
    serialized = pkcs12.serialize_key_and_certificates(
        fname.encode("ascii"),
        key,
        cert_obj,
        cas=cas,
        encryption_algorithm=serialization.BestAvailableEncryption(
            pkcs12_opts["passphrase"].encode("ascii")
        ),
    )
    with open(OUTPUT_PATH / fname, "wb") as f:
        f.write(serialized)
|
|
|
|
|
|
def _issuer_subject_key_identifier(cert, key, issuer_cert):
    """Return the issuer's SubjectKeyIdentifier value, or None if it has none.

    This feeds the "keyid" form of the authorityKeyIdentifier extension. For a
    self-signed definition it is computed from the definition's own
    subjectKeyIdentifier setting and ``key``.
    """
    if issuer_cert == "self":
        my_ski = cert.get("extensions", {}).get("subjectKeyIdentifier")
        if my_ski is None:
            return None
        return ExtensionParser.subject_key_identifier(my_ski, key.public_key())
    try:
        # get_extension_for_class returns an Extension wrapper, while
        # AuthorityKeyIdentifier.from_issuer_subject_key_identifier needs the
        # SubjectKeyIdentifier value itself.
        return issuer_cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier).value
    except x509.ExtensionNotFound:
        return None


def _sign_certificate(builder, issuer_key):
    """Sign ``builder`` with SHA-256 using ``issuer_key`` and return the certificate."""
    if isinstance(issuer_key, ec.EllipticCurvePrivateKey):
        # For EC, we need to compute a deterministic signature ourselves. While
        # newer versions of OpenSSL support deterministic signing with ECDSA,
        # some of the platforms we run tests on use old versions, so we
        # unfortunately cannot use this feature. Sign normally, then replace
        # the signature.
        bad_sig_obj = builder.sign(issuer_key, hashes.SHA256())
        return sign_ecdsa_deterministic(issuer_key, bad_sig_obj)
    return builder.sign(issuer_key, hashes.SHA256())


def _write_split_cert_and_key(cert, cert_obj):
    """Write just the certificate to <stem>.crt and just the key to <stem>.key.

    The ".pem" name check runs in dry-run mode too; files are only written
    when DRY_RUN is false.
    """
    if not cert["name"].endswith(".pem"):
        raise CertificateGenerationError(
            cert["name"] + ": split_cert_and_key requires a name ending in .pem"
        )
    stem = cert["name"][: -len(".pem")]
    if DRY_RUN:
        return
    with open(OUTPUT_PATH / (stem + ".crt"), "wt") as f:
        f.write(cert_obj.public_bytes(serialization.Encoding.PEM).decode("ascii"))
    with open(OUTPUT_PATH / (stem + ".key"), "wt") as f:
        with open(str(STATIC_PATH / idx(cert, "keyfile")), "r") as keyf:
            f.write(keyf.read())


def process_normal_cert(cert):
    """Given a certificate definition which has a subject, deterministically generate its certificate.

    The combined file (certificate PEM followed by the key PEM) is written to
    the output path. The result is cached in LOADED_CERT_AND_KEYS, and
    optional PKCS#12 and split .crt/.key copies are written. In dry-run mode
    only the key lookup and the split-name check run.
    """
    key = get_key(cert)
    cert_obj = None
    if not DRY_RUN:
        issuer_cert, issuer_key = get_issuer_cert_and_key(cert, key)
        # Get SKI of issuer if it exists; we need it for the AuthorityKeyIdentifier extension
        issuer_ski = _issuer_subject_key_identifier(cert, key, issuer_cert)

        # Set all fields of the certificate.
        builder = x509.CertificateBuilder()
        builder = builder.public_key(key.public_key())
        serial = cert.get("serial")
        serial = get_next_serial() if serial is None else int(serial)
        builder = builder.serial_number(serial)
        builder = set_subject(builder, cert, set_issuer=issuer_cert == "self")
        if issuer_cert != "self":
            builder = builder.issuer_name(issuer_cert.subject)
        builder = set_validity(builder, cert)
        builder = set_extensions(
            builder,
            cert,
            public_key=key.public_key(),
            issuer_public_key=issuer_key.public_key(),
            issuer_ski=issuer_ski,
        )
        cert_obj = _sign_certificate(builder, issuer_key)

        # Write certificate PEM + key PEM to the output file.
        with open(make_filename(cert), "wt") as f:
            f.write(cert_obj.public_bytes(serialization.Encoding.PEM).decode("ascii"))
            with open(str(STATIC_PATH / idx(cert, "keyfile")), "r") as keyf:
                f.write(keyf.read())
        LOADED_CERT_AND_KEYS[cert["name"]] = (cert_obj, key)
        if cert.get("pkcs12", None) is not None:
            write_cert_as_pkcs12(cert, key, cert_obj, issuer_cert)

    if cert.get("split_cert_and_key", False):
        _write_split_cert_and_key(cert, cert_obj)
|
|
|
|
|
|
def process_cert(cert):
    """Produce every output file for one certificate definition.

    A definition with a Subject (or an explicitly empty one) is generated by
    process_normal_cert. The certificates named in "append_cert" (a single
    name or a list) are then appended, in PEM form, to the output file.
    Appending is skipped in dry-run mode.

    :raises CertificateGenerationError: if there is neither a subject nor anything to append.
    """
    print("Processing certificate: " + cert["name"] + ", writing to: " + make_filename(cert))

    to_append = cert.get("append_cert", [])
    if isinstance(to_append, str):
        to_append = [to_append]

    # A truthy Subject, or an explicitly requested empty one, means we generate.
    if cert.get("Subject") or cert.get("explicit_subject", False):
        process_normal_cert(cert)
    elif not to_append:
        raise CertificateGenerationError(
            "Certificate definitions must have at least one of 'Subject' and/or 'append_cert'"
        )

    if DRY_RUN:
        return

    for appended_name in to_append:
        appended_cert, _ = get_cert_and_key(appended_name)
        pem_text = appended_cert.public_bytes(serialization.Encoding.PEM).decode("ascii")
        with open(make_filename(cert), "at") as out:
            out.write(pem_text + "\n")
|
|
|
|
|
|
# Digest names accepted by write_digest(), mapped to cryptography hash objects.
DIGEST_NAME_TO_HASH = dict(sha256=hashes.SHA256(), sha1=hashes.SHA1())
|
|
|
|
|
|
def write_digest(filename, item_type, digest_type):
    """Write the fingerprint of a PEM certificate or CRL to <filename>.digest.<digest_type>.

    :param filename: path of the PEM file to fingerprint.
    :param item_type: "cert" or "crl".
    :param digest_type: a key of DIGEST_NAME_TO_HASH ("sha256" or "sha1").

    The digest is written as uppercase hex. Nothing is read or written in
    dry-run mode, although the arguments are still checked.
    """
    assert item_type in {"cert", "crl"}
    assert digest_type in DIGEST_NAME_TO_HASH
    digest_path = f"{filename}.digest.{digest_type}"
    if DRY_RUN:
        return

    with open(filename, "rb") as pem_file:
        pem_data = pem_file.read()

    loader = x509.load_pem_x509_certificate if item_type == "cert" else x509.load_pem_x509_crl
    fingerprint = loader(pem_data).fingerprint(DIGEST_NAME_TO_HASH[digest_type])

    with open(digest_path, "w") as digest_file:
        digest_file.write(fingerprint.hex().upper())
|
|
|
|
|
|
def generate_crl(issuer_cert, issuer_key, dest, cert_to_revoke=None):
    """Generate a CRL and its SHA-256/SHA-1 digest files.

    :param issuer_cert: x509.Certificate object which issues this CRL.
    :param issuer_key: Private key object to sign the CRL with.
    :param dest: Path to output CRL to.
    :param cert_to_revoke: x509.Certificate object which this CRL should revoke. Empty for no revocation.

    The CRL's lastUpdate is DEFAULT_START_TIME and its nextUpdate is
    MAX_VALIDITY_PERIOD_DAYS later. The CRL is only built and written when
    not in dry-run mode.
    """
    print(f"Writing CRL: {dest}")
    if not DRY_RUN:
        next_update = DEFAULT_START_TIME + datetime.timedelta(days=MAX_VALIDITY_PERIOD_DAYS)
        builder = x509.CertificateRevocationListBuilder()
        builder = builder.issuer_name(issuer_cert.subject)
        builder = builder.last_update(DEFAULT_START_TIME)
        builder = builder.next_update(next_update)

        if cert_to_revoke is not None:
            revoked_entry = (
                x509.RevokedCertificateBuilder()
                .serial_number(cert_to_revoke.serial_number)
                .revocation_date(DEFAULT_START_TIME)
                .build()
            )
            builder = builder.add_revoked_certificate(revoked_entry)

        crl = builder.sign(issuer_key, hashes.SHA256())
        with open(dest, "wb") as crl_file:
            crl_file.write(crl.public_bytes(serialization.Encoding.PEM))

    for digest_type in ("sha256", "sha1"):
        write_digest(dest, "crl", digest_type)
|
|
|
|
|
|
def generate_all_crls():
    """Generate all required CRLs. Hardcoded with the expectation that we won't need to add new ones frequently.

    Writes, under OUTPUT_PATH:
    - crl.pem: empty CRL from ca.pem.
    - crl_client_revoked.pem: ca.pem revoking client_revoked.pem.
    - crl_intermediate_ca_B_revoked.pem: ca.pem revoking intermediate-ca-B.pem.
    - crl_from_trusted_ca.pem: empty CRL from trusted-ca.pem.
    - crl_from_intermediate_ca_B.pem: empty CRL from intermediate-ca-B.pem.

    :raises CertificateGenerationError: if any required input certificate is missing.
    """
    try:
        ca, ca_key = get_cert_and_key("ca.pem")
        trusted_ca, trusted_ca_key = get_cert_and_key("trusted-ca.pem")
        client_revoked, _ = get_cert_and_key("client_revoked.pem")
        # Intermediate CA B: revoked by ca.pem's CRL and issuer of its own CRL
        # below. (This previously re-loaded "ca.pem" by mistake.)
        intermediate_ca, intermediate_ca_key = get_cert_and_key("intermediate-ca-B.pem")
    except FileNotFoundError as e:
        raise CertificateGenerationError(
            "ca.pem, trusted-ca.pem, client_revoked.pem, and intermediate-ca-B.pem are required in order to generate CRLs"
        ) from e

    generate_crl(ca, ca_key, OUTPUT_PATH / "crl.pem")
    generate_crl(ca, ca_key, OUTPUT_PATH / "crl_client_revoked.pem", client_revoked)
    generate_crl(ca, ca_key, OUTPUT_PATH / "crl_intermediate_ca_B_revoked.pem", intermediate_ca)
    generate_crl(trusted_ca, trusted_ca_key, OUTPUT_PATH / "crl_from_trusted_ca.pem")
    generate_crl(
        intermediate_ca, intermediate_ca_key, OUTPUT_PATH / "crl_from_intermediate_ca_B.pem"
    )
|
|
|
|
|
|
def parse_command_line(argv):
    """Build the command-line parser and parse ``argv`` (``None`` means ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(description="X509 Test Certificate Generator")
    add = parser.add_argument

    add("config", help="Certificate definition file", type=str)
    add(
        "--mkcrl",
        action=argparse.BooleanOptionalAction,
        help="Set to generate the default list of CRLs as well",
        default=False,
    )
    add("-o", "--output", help="Output path for certs", type=str, default=str(PurePath(".")))
    add(
        "--static-dir",
        help="Path to directory containing signing keys for certs",
        type=str,
        default=str(PurePath("x509/static")),
    )
    add(
        "-d",
        "--dry-run",
        help="If set, just parse the config, but don't generate any certs. If the file/input list paths are set, they will be written.",
        action=argparse.BooleanOptionalAction,
        default=False,
    )
    add(
        "--quiet",
        action=argparse.BooleanOptionalAction,
        help="If set, suppresses all output",
        default=False,
    )
    add("cert", nargs="*", help="Certificate to generate (blank for all)")

    return parser.parse_args(argv)
|
|
|
|
|
|
def validate_config():
    """Sanity-check CONFIG at start-up.

    Requires a non-empty "certs" list. Every definition needs "name" and
    "description" and may only use the keys in ``allowed_keys``.

    :raises CertificateGenerationError: on the first problem found.
    """
    if not CONFIG.get("certs"):
        raise CertificateGenerationError("No certificates defined")

    allowed_keys = frozenset(
        (
            "name",
            "description",
            "Subject",
            "Issuer",
            "append_cert",
            "extensions",
            "passphrase",
            "keyfile",
            "split_cert_and_key",
            "explicit_subject",
            "serial",
            "not_before",
            "not_after",
            "pkcs12",
        )
    )

    for definition in CONFIG.get("certs", []):
        if "name" not in definition:
            raise CertificateGenerationError("Name field required for all certificate definitions")
        if "description" not in definition:
            raise CertificateGenerationError(
                "description field required for all certificate definitions"
            )
        unknown_keys = [field for field in definition if field not in allowed_keys]
        if unknown_keys:
            raise CertificateGenerationError(
                "Unknown element '" + unknown_keys[0] + "' in certificate: " + definition["name"]
            )
|
|
|
|
|
|
def select_items(names):
    """Select all certificates requested and their ancestor nodes.

    With no ``names``, every definition in the config is returned. Otherwise
    the named definitions are returned together with, transitively, every
    definition they reference through "Issuer" or "append_cert".

    :raises CertificateGenerationError: if a requested name is not defined.
    """
    if not names:
        return CONFIG["certs"]

    # Temporarily treat like dictionary for easy de-duping.
    ret = {}
    # Start with the cert(s) explicitly asked for.
    for name in names:
        cert = find_certificate_definition(name)
        if not cert:
            raise CertificateGenerationError("Unknown certificate: " + name)
        ret[name] = cert

    # Keep pulling in referenced definitions until nothing new is added.
    last_count = -1
    while last_count != len(ret):
        last_count = len(ret)
        req_names = set()
        for selected in ret.values():
            req_names.add(selected.get("Issuer"))
            appends = selected.get("append_cert", [])
            # A single name may be given as a plain string; don't iterate its characters.
            if isinstance(appends, str):
                appends = [appends]
            req_names.update(appends)
        ret.update({cert["name"]: cert for cert in CONFIG["certs"] if cert["name"] in req_names})

    return ret.values()
|
|
|
|
|
|
def sort_items(items):
    """Ensure that leaves are produced after roots (as much as possible within one file).

    A definition is emitted once its Issuer is "self", an already emitted
    definition, or not among ``items`` (externally signed). Every
    "append_cert" entry that is among ``items`` must also already be emitted.

    :returns: a list of the definitions in dependency order.
    :raises CertificateGenerationError: if the remaining definitions can never
        be ordered (circular dependencies or duplicate names), instead of
        looping forever.
    """
    all_names = {cert["name"] for cert in items}
    processed_names = set()

    ret = []
    while len(ret) != len(items):
        progressed = False
        for cert in items:
            if cert["name"] in processed_names:
                continue

            prependents = cert.get("append_cert", [])
            # A single name may be given as a plain string; don't iterate its characters.
            if isinstance(prependents, str):
                prependents = [prependents]
            # only concern ourselves with prependents in this config file.
            unmet_prependents = [
                name
                for name in prependents
                if (name in all_names) and (name not in processed_names)
            ]

            # Self-signed, signed by someone in ret already, or signed externally
            issuer = cert.get("Issuer")
            has_issuer = (
                (issuer == "self") or (issuer in processed_names) or (issuer not in all_names)
            )

            if has_issuer and not unmet_prependents:
                ret.append(cert)
                processed_names.add(cert["name"])
                progressed = True

        if not progressed:
            pending = sorted(
                {cert["name"] for cert in items if cert["name"] not in processed_names}
            )
            raise CertificateGenerationError(
                "Cannot order certificates (circular Issuer/append_cert dependencies or duplicate names): "
                + ", ".join(pending)
            )

    return ret
|
|
|
|
|
|
def setup_global_state(parsed_args):
    """Initialise the module-level settings from parsed command-line arguments.

    Sets CONFIGFILE, OUTPUT_PATH, STATIC_PATH and DRY_RUN, and loads CONFIG
    from the JSON definitions file. If requested, stdout is redirected to the
    null device. Finally the config is validated.
    """
    global CONFIG, CONFIGFILE, OUTPUT_PATH, STATIC_PATH, DRY_RUN

    CONFIGFILE = parsed_args.config
    OUTPUT_PATH, STATIC_PATH = PurePath(parsed_args.output), PurePath(parsed_args.static_dir)
    DRY_RUN = parsed_args.dry_run

    with open(CONFIGFILE, "r", encoding="utf-8") as config_file:
        CONFIG = json.load(config_file)

    if parsed_args.quiet:
        sys.stdout = open(os.devnull, "w")

    validate_config()
|
|
|
|
|
|
def main(argv=None):
    """Command-line entry point: generate the requested certificates, their digests and optional CRLs.

    Any failure while processing a certificate is re-raised as a
    CertificateGenerationError naming that certificate.
    """
    args = parse_command_line(argv)
    setup_global_state(args)

    ordered = sort_items(select_items(args.cert or []))
    for definition in ordered:
        try:
            process_cert(definition)
        except Exception as err:
            raise CertificateGenerationError(
                f"Failed to process certificate {definition['name']}: {str(err)}"
            ) from err
        output_file = make_filename(definition)
        for digest_type in ("sha256", "sha1"):
            write_digest(output_file, "cert", digest_type)

    if args.mkcrl:
        generate_all_crls()
|
|
|
|
|
|
if __name__ == "__main__":
    # Equivalent to main(): argparse defaults to sys.argv[1:] anyway.
    main(sys.argv[1:])
|