Files
mongo/jstests/client_encrypt/lib/kms_http_server.py
2021-01-08 01:27:48 +00:00

245 lines
8.4 KiB
Python
Executable File

#! /usr/bin/env python3
"""Mock AWS KMS Endpoint."""
import argparse
import base64
import http.server
import json
import logging
import sys
import urllib.parse
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
import kms_http_common
# Marker prepended to every plaintext by the mock "encrypt"; the mock "decrypt"
# raises on any ciphertext that does not start with it.
SECRET_PREFIX = "00SECRET"

# Fault names accepted by the --fault command-line option (checked in main()).
SUPPORTED_FAULT_TYPES = [
    # Encrypt-path faults (dispatched in AwsKmsHandler._do_encrypt_faults).
    kms_http_common.FAULT_ENCRYPT,
    kms_http_common.FAULT_ENCRYPT_CORRECT_FORMAT,
    kms_http_common.FAULT_ENCRYPT_WRONG_FIELDS,
    kms_http_common.FAULT_ENCRYPT_BAD_BASE64,
    # Decrypt-path faults (dispatched in AwsKmsHandler._do_decrypt_faults).
    kms_http_common.FAULT_DECRYPT,
    kms_http_common.FAULT_DECRYPT_CORRECT_FORMAT,
    kms_http_common.FAULT_DECRYPT_WRONG_KEY,
]
def get_dict_subset(headers, subset):
    """Return the entries of ``headers`` whose names appear in a SigV4 signed-header list.

    :param headers: a mapping of header name -> value (only ``keys()`` and
        ``__getitem__`` are used, so an ``email.message.Message`` also works).
    :param subset: a ``;``-separated list of header names, optionally prefixed
        with ``SignedHeaders=`` (as sliced out of an ``Authorization`` header),
        e.g. ``"SignedHeaders=content-type;host;x-amz-date"``.
    :returns: a new dict containing only the headers named in ``subset``, with
        the original header-name casing preserved. Names are compared
        case-insensitively and must match exactly.
    """
    # Drop an optional "SignedHeaders=" prefix; rpartition leaves a bare list untouched.
    _, _, names = subset.rpartition("=")
    # Exact, case-insensitive names. A substring test would wrongly pick up e.g.
    # "Date" because "date" occurs inside "x-amz-date".
    wanted = {name.strip().lower() for name in names.split(";") if name.strip()}
    return {header: headers[header] for header in headers.keys() if header.lower() in wanted}
class AwsKmsHandler(kms_http_common.KmsHandlerBase):
    """Mock AWS KMS request handler.

    Serves ``TrentService.Encrypt`` and ``TrentService.Decrypt`` requests POSTed
    to ``/``. The "encryption" is only a reversible encoding: the request's
    ``Plaintext`` string is prefixed with ``SECRET_PREFIX`` and base64-encoded,
    and decryption undoes that.

    When ``kms_http_common.fault_type`` names an encrypt/decrypt fault and
    ``kms_http_common.disable_faults`` is false, a fault response is sent
    instead of the normal one.
    """

    def do_POST(self):
        """Route POST requests: only the root path is served; anything else gets a 404."""
        print("Received POST: " + self.path)
        path = urllib.parse.urlsplit(self.path).path

        if path == "/":
            self._do_post()
        else:
            self.send_response(http.HTTPStatus.NOT_FOUND)
            self.end_headers()
            self.wfile.write("Unknown URL".encode())

    def _do_post(self):
        """Validate a KMS request and dispatch it on its ``X-Amz-Target`` operation.

        Exactly one reply is written per request. A failed host check or
        signature check replies with an error message and stops; previously,
        processing fell through and a second response was written.
        """
        # A missing Content-Length is treated as an empty body instead of crashing on int(None).
        clen = int(self.headers.get('content-length', 0))
        raw_input = self.rfile.read(clen)

        print("RAW INPUT: " + str(raw_input))

        if self.headers["Host"] != "localhost":
            data = "Unexpected host"
            self._send_reply(data.encode("utf-8"))
            return

        if not self._validate_signature(self.headers, raw_input):
            data = "Bad Signature"
            self._send_reply(data.encode("utf-8"))
            return

        # X-Amz-Target: TrentService.Encrypt
        aws_operation = self.headers['X-Amz-Target']

        if aws_operation == "TrentService.Encrypt":
            kms_http_common.stats.encrypt_calls += 1
            self._do_encrypt(raw_input)
        elif aws_operation == "TrentService.Decrypt":
            kms_http_common.stats.decrypt_calls += 1
            self._do_decrypt(raw_input)
        else:
            data = "Unknown AWS Operation"
            self._send_reply(data.encode("utf-8"))

    def _validate_signature(self, headers, raw_input):
        """Recompute the request's SigV4 signature and compare it with the client's.

        The mock always signs with credentials ``access``/``secret`` and the
        ``kms`` service. The region is read from the ``Credential=`` scope.

        :param headers: the request headers.
        :param raw_input: the raw request body bytes.
        :returns: True if the ``Signature=`` value in ``Authorization`` matches,
            False otherwise (including when ``Authorization`` is absent).
        """
        auth_header = headers["Authorization"]
        if not auth_header:
            # Header lookup returns None when the header is missing. Treat an
            # unsigned request as a bad signature rather than raising AttributeError.
            return False

        # The parsing below assumes a header shaped like
        #   ... Credential=access/YYYYMMDD/<region>/..., SignedHeaders=<a;b;c>, Signature=<sig>
        # with Signature as the last component.
        signed_headers_start = auth_header.find("SignedHeaders")
        signed_headers = auth_header[signed_headers_start:auth_header.find(",", signed_headers_start)]
        signed_headers_dict = get_dict_subset(headers, signed_headers)

        request = AWSRequest(method="POST", url="/", data=raw_input, headers=signed_headers_dict)
        # SigV4Auth assumes this header exists even though it is not required by the algorithm
        request.context['timestamp'] = headers['X-Amz-Date']

        region_start = auth_header.find("Credential=access/") + len("Credential=access/YYYYMMDD/")
        region = auth_header[region_start:auth_header.find("/", region_start)]

        credentials = Credentials("access", "secret")
        auth = SigV4Auth(credentials, "kms", region)
        string_to_sign = auth.string_to_sign(request, auth.canonical_request(request))
        expected_signature = auth.signature(string_to_sign, request)

        signature_headers_start = auth_header.find("Signature=") + len("Signature=")
        actual_signature = auth_header[signature_headers_start:]

        return expected_signature == actual_signature

    @staticmethod
    def _fault_active(prefix):
        """Return True if a configured fault starts with ``prefix`` and faults are not disabled."""
        fault_type = kms_http_common.fault_type
        return bool(fault_type) and fault_type.startswith(prefix) \
            and not kms_http_common.disable_faults

    def _do_encrypt(self, raw_input):
        """Handle ``TrentService.Encrypt``.

        :param raw_input: JSON request body containing ``Plaintext`` and ``KeyId``.
        """
        request = json.loads(raw_input)
        print(request)

        plaintext = request["Plaintext"]
        keyid = request["KeyId"]

        ciphertext = SECRET_PREFIX.encode() + plaintext.encode()
        ciphertext = base64.b64encode(ciphertext).decode()

        if self._fault_active(kms_http_common.FAULT_ENCRYPT):
            return self._do_encrypt_faults(ciphertext)

        response = {
            "CiphertextBlob": ciphertext,
            "KeyId": keyid,
        }

        self._send_reply(json.dumps(response).encode('utf-8'))

    def _do_encrypt_faults(self, raw_ciphertext):
        """Send the configured encrypt-path fault response.

        :param raw_ciphertext: the well-formed ciphertext that would otherwise be returned.
        :raises ValueError: if ``kms_http_common.fault_type`` is not an encrypt fault.
        """
        kms_http_common.stats.fault_calls += 1

        if kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT:
            self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR)
            return
        elif kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT_WRONG_FIELDS:
            response = {
                "SomeBlob": raw_ciphertext,
                "KeyId": "foo",
            }
            self._send_reply(json.dumps(response).encode('utf-8'))
            return
        elif kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT_BAD_BASE64:
            response = {
                "CiphertextBlob": "foo",
                "KeyId": "foo",
            }
            self._send_reply(json.dumps(response).encode('utf-8'))
            return
        elif kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT_CORRECT_FORMAT:
            response = {
                "__type": "NotFoundException",
                "Message": "Error encrypting message",
            }
            self._send_reply(json.dumps(response).encode('utf-8'))
            return

        raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type)

    def _do_decrypt(self, raw_input):
        """Handle ``TrentService.Decrypt``.

        :param raw_input: JSON request body containing a base64 ``CiphertextBlob``.
        :raises ValueError: if the decoded blob does not start with ``SECRET_PREFIX``.
        """
        request = json.loads(raw_input)
        blob = base64.b64decode(request["CiphertextBlob"]).decode()

        print("FOUND SECRET: " + blob)

        # our "encrypted" values start with the word SECRET_PREFIX otherwise they did not come from us
        if not blob.startswith(SECRET_PREFIX):
            raise ValueError("Ciphertext was not produced by this mock KMS server")

        blob = blob[len(SECRET_PREFIX):]

        if self._fault_active(kms_http_common.FAULT_DECRYPT):
            return self._do_decrypt_faults(blob)

        response = {
            "Plaintext": blob,
            "KeyId": "Not a clue",
        }

        self._send_reply(json.dumps(response).encode('utf-8'))

    def _do_decrypt_faults(self, blob):
        """Send the configured decrypt-path fault response.

        :param blob: the decoded plaintext that would otherwise be returned (unused by the faults).
        :raises ValueError: if ``kms_http_common.fault_type`` is not a decrypt fault.
        """
        kms_http_common.stats.fault_calls += 1

        if kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT:
            self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR)
            return
        elif kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT_WRONG_KEY:
            response = {
                "Plaintext": "ta7DXE7J0OiCRw03dYMJSeb8nVF5qxTmZ9zWmjuX4zW/SOorSCaY8VMTWG+cRInMx/rr/+QeVw2WjU2IpOSvMg==",
                "KeyId": "Not a clue",
            }
            self._send_reply(json.dumps(response).encode('utf-8'))
            return
        elif kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT_CORRECT_FORMAT:
            response = {
                "__type": "NotFoundException",
                "Message": "Error decrypting message",
            }
            self._send_reply(json.dumps(response).encode('utf-8'))
            return

        raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type)
def main():
    """Parse command-line options, configure fault injection, and start the mock KMS server.

    Exits with status 1 if ``--fault`` names a fault not in ``SUPPORTED_FAULT_TYPES``.
    Serving is delegated to ``kms_http_common.run`` with ``AwsKmsHandler``.
    """
    parser = argparse.ArgumentParser(description='MongoDB Mock AWS KMS Endpoint.')
    add_option = parser.add_argument
    add_option('-p', '--port', type=int, default=8000, help="Port to listen on")
    add_option('-v', '--verbose', action='count', help="Enable verbose tracing")
    add_option('--fault', type=str, help="Type of fault to inject")
    add_option('--disable-faults', action='store_true', help="Disable faults on startup")
    add_option('--ca_file', type=str, required=True, help="TLS CA PEM file")
    add_option('--cert_file', type=str, required=True, help="TLS Server PEM file")
    options = parser.parse_args()

    if options.verbose:
        logging.basicConfig(level=logging.DEBUG)

    requested_fault = options.fault
    # Reject unknown fault names before touching any shared state.
    if requested_fault and requested_fault not in SUPPORTED_FAULT_TYPES:
        message = "Unsupported fault type %s, supports types are %s" % (requested_fault, SUPPORTED_FAULT_TYPES)
        print(message)
        sys.exit(1)
    if requested_fault:
        kms_http_common.fault_type = requested_fault

    if options.disable_faults:
        kms_http_common.disable_faults = True

    kms_http_common.run(options.port, options.cert_file, options.ca_file, AwsKmsHandler)


if __name__ == '__main__':
    main()