Files
mongo/jstests/client_encrypt/lib/kms_http_server.py

299 lines
9.0 KiB
Python
Executable File

#! /usr/bin/env python3
"""Mock AWS KMS Endpoint."""
import argparse
import collections
import base64
import http.server
import json
import logging
import socketserver
import sys
import urllib.parse
import ssl
import kms_http_common
SECRET_PREFIX = "00SECRET"
# Pass this data out of band instead of storing it in AwsKmsHandler since the
# BaseHTTPRequestHandler does not call the methods as object methods but as class methods. This
# means there is not self.
stats = kms_http_common.Stats()
disable_faults = False
fault_type = None
"""Fault which causes encrypt to return 500."""
FAULT_ENCRYPT = "fault_encrypt"
"""Fault which causes encrypt to return an error that contains a type and message"""
FAULT_ENCRYPT_CORRECT_FORMAT = "fault_encrypt_correct_format"
"""Fault which causes encrypt to return wrong fields in JSON."""
FAULT_ENCRYPT_WRONG_FIELDS = "fault_encrypt_wrong_fields"
"""Fault which causes encrypt to return bad BASE64."""
FAULT_ENCRYPT_BAD_BASE64 = "fault_encrypt_bad_base64"
"""Fault which causes decrypt to return 500."""
FAULT_DECRYPT = "fault_decrypt"
"""Fault which causes decrypt to return an error that contains a type and message"""
FAULT_DECRYPT_CORRECT_FORMAT = "fault_decrypt_correct_format"
"""Fault which causes decrypt to return wrong key."""
FAULT_DECRYPT_WRONG_KEY = "fault_decrypt_wrong_key"
# List of supported fault types
SUPPORTED_FAULT_TYPES = [
FAULT_ENCRYPT,
FAULT_ENCRYPT_CORRECT_FORMAT,
FAULT_ENCRYPT_WRONG_FIELDS,
FAULT_ENCRYPT_BAD_BASE64,
FAULT_DECRYPT,
FAULT_DECRYPT_CORRECT_FORMAT,
FAULT_DECRYPT_WRONG_KEY,
]
class AwsKmsHandler(http.server.BaseHTTPRequestHandler):
"""
Handle requests from AWS KMS Monitoring and test commands
"""
protocol_version = "HTTP/1.1"
def do_GET(self):
"""Serve a Test GET request."""
parts = urllib.parse.urlsplit(self.path)
path = parts[2]
if path == kms_http_common.URL_PATH_STATS:
self._do_stats()
elif path == kms_http_common.URL_DISABLE_FAULTS:
self._do_disable_faults()
elif path == kms_http_common.URL_ENABLE_FAULTS:
self._do_enable_faults()
else:
self.send_response(http.HTTPStatus.NOT_FOUND)
self.end_headers()
self.wfile.write("Unknown URL".encode())
def do_POST(self):
"""Serve a POST request."""
parts = urllib.parse.urlsplit(self.path)
path = parts[2]
if path == "/":
self._do_post()
else:
self.send_response(http.HTTPStatus.NOT_FOUND)
self.end_headers()
self.wfile.write("Unknown URL".encode())
def _send_reply(self, data, status=http.HTTPStatus.OK):
print("Sending Response: " + data.decode())
self.send_response(status)
self.send_header("content-type", "application/octet-stream")
self.send_header("Content-Length", str(len(data)))
self.end_headers()
self.wfile.write(data)
def _do_post(self):
global stats
clen = int(self.headers.get('content-length'))
raw_input = self.rfile.read(clen)
print("RAW INPUT: " + str(raw_input))
# X-Amz-Target: TrentService.Encrypt
aws_operation = self.headers['X-Amz-Target']
if aws_operation == "TrentService.Encrypt":
stats.encrypt_calls += 1
self._do_encrypt(raw_input)
elif aws_operation == "TrentService.Decrypt":
stats.decrypt_calls += 1
self._do_decrypt(raw_input)
else:
data = "Unknown AWS Operation"
self._send_reply(data.encode("utf-8"))
def _do_encrypt(self, raw_input):
request = json.loads(raw_input)
print(request)
plaintext = request["Plaintext"]
keyid = request["KeyId"]
ciphertext = SECRET_PREFIX.encode() + plaintext.encode()
ciphertext = base64.b64encode(ciphertext).decode()
if fault_type and fault_type.startswith(FAULT_ENCRYPT) and not disable_faults:
return self._do_encrypt_faults(ciphertext)
response = {
"CiphertextBlob" : ciphertext,
"KeyId" : keyid,
}
self._send_reply(json.dumps(response).encode('utf-8'))
def _do_encrypt_faults(self, raw_ciphertext):
stats.fault_calls += 1
if fault_type == FAULT_ENCRYPT:
self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR)
return
elif fault_type == FAULT_ENCRYPT_WRONG_FIELDS:
response = {
"SomeBlob" : raw_ciphertext,
"KeyId" : "foo",
}
self._send_reply(json.dumps(response).encode('utf-8'))
return
elif fault_type == FAULT_ENCRYPT_BAD_BASE64:
response = {
"CiphertextBlob" : "foo",
"KeyId" : "foo",
}
self._send_reply(json.dumps(response).encode('utf-8'))
return
elif fault_type == FAULT_ENCRYPT_CORRECT_FORMAT:
response = {
"__type" : "NotFoundException",
"message" : "Error encrypting message",
}
self._send_reply(json.dumps(response).encode('utf-8'))
return
raise ValueError("Unknown Fault Type: " + fault_type)
def _do_decrypt(self, raw_input):
request = json.loads(raw_input)
blob = base64.b64decode(request["CiphertextBlob"]).decode()
print("FOUND SECRET: " + blob)
# our "encrypted" values start with the word SECRET_PREFIX otherwise they did not come from us
if not blob.startswith(SECRET_PREFIX):
raise ValueError()
blob = blob[len(SECRET_PREFIX):]
if fault_type and fault_type.startswith(FAULT_DECRYPT) and not disable_faults:
return self._do_decrypt_faults(blob)
response = {
"Plaintext" : blob,
"KeyId" : "Not a clue",
}
self._send_reply(json.dumps(response).encode('utf-8'))
def _do_decrypt_faults(self, blob):
stats.fault_calls += 1
if fault_type == FAULT_DECRYPT:
self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR)
return
elif fault_type == FAULT_DECRYPT_WRONG_KEY:
response = {
"Plaintext" : "ta7DXE7J0OiCRw03dYMJSeb8nVF5qxTmZ9zWmjuX4zW/SOorSCaY8VMTWG+cRInMx/rr/+QeVw2WjU2IpOSvMg==",
"KeyId" : "Not a clue",
}
self._send_reply(json.dumps(response).encode('utf-8'))
return
elif fault_type == FAULT_DECRYPT_CORRECT_FORMAT:
response = {
"__type" : "NotFoundException",
"message" : "Error decrypting message",
}
self._send_reply(json.dumps(response).encode('utf-8'))
return
raise ValueError("Unknown Fault Type: " + fault_type)
def _send_header(self):
self.send_response(http.HTTPStatus.OK)
self.send_header("content-type", "application/octet-stream")
self.end_headers()
def _do_stats(self):
self._send_header()
self.wfile.write(str(stats).encode('utf-8'))
def _do_disable_faults(self):
global disable_faults
disable_faults = True
self._send_header()
def _do_enable_faults(self):
global disable_faults
disable_faults = False
self._send_header()
def run(port, cert_file, ca_file, server_class=http.server.HTTPServer, handler_class=AwsKmsHandler):
"""Run web server."""
server_address = ('', port)
httpd = server_class(server_address, handler_class)
httpd.socket = ssl.wrap_socket (httpd.socket,
certfile=cert_file,
ca_certs=ca_file, server_side=True)
print("Mock KMS Web Server Listening on %s" % (str(server_address)))
httpd.serve_forever()
def main():
"""Main Method."""
global fault_type
global disable_faults
parser = argparse.ArgumentParser(description='MongoDB Mock AWS KMS Endpoint.')
parser.add_argument('-p', '--port', type=int, default=8000, help="Port to listen on")
parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing")
parser.add_argument('--fault', type=str, help="Type of fault to inject")
parser.add_argument('--disable-faults', action='store_true', help="Disable faults on startup")
parser.add_argument('--ca_file', type=str, required=True, help="TLS CA PEM file")
parser.add_argument('--cert_file', type=str, required=True, help="TLS Server PEM file")
args = parser.parse_args()
if args.verbose:
logging.basicConfig(level=logging.DEBUG)
if args.fault:
if args.fault not in SUPPORTED_FAULT_TYPES:
print("Unsupported fault type %s, supports types are %s" % (args.fault, SUPPORTED_FAULT_TYPES))
sys.exit(1)
fault_type = args.fault
if args.disable_faults:
disable_faults = True
run(args.port, args.cert_file, args.ca_file)
if __name__ == '__main__':
main()