299 lines
9.0 KiB
Python
Executable File
299 lines
9.0 KiB
Python
Executable File
#! /usr/bin/env python3
|
|
"""Mock AWS KMS Endpoint."""
|
|
|
|
import argparse
|
|
import collections
|
|
import base64
|
|
import http.server
|
|
import json
|
|
import logging
|
|
import socketserver
|
|
import sys
|
|
import urllib.parse
|
|
import ssl
|
|
|
|
import kms_http_common
|
|
|
|
SECRET_PREFIX = "00SECRET"
|
|
|
|
# Pass this data out of band instead of storing it in AwsKmsHandler since the
|
|
# BaseHTTPRequestHandler does not call the methods as object methods but as class methods. This
|
|
# means there is not self.
|
|
stats = kms_http_common.Stats()
|
|
disable_faults = False
|
|
fault_type = None
|
|
|
|
"""Fault which causes encrypt to return 500."""
|
|
FAULT_ENCRYPT = "fault_encrypt"
|
|
|
|
"""Fault which causes encrypt to return an error that contains a type and message"""
|
|
FAULT_ENCRYPT_CORRECT_FORMAT = "fault_encrypt_correct_format"
|
|
|
|
"""Fault which causes encrypt to return wrong fields in JSON."""
|
|
FAULT_ENCRYPT_WRONG_FIELDS = "fault_encrypt_wrong_fields"
|
|
|
|
"""Fault which causes encrypt to return bad BASE64."""
|
|
FAULT_ENCRYPT_BAD_BASE64 = "fault_encrypt_bad_base64"
|
|
|
|
"""Fault which causes decrypt to return 500."""
|
|
FAULT_DECRYPT = "fault_decrypt"
|
|
|
|
"""Fault which causes decrypt to return an error that contains a type and message"""
|
|
FAULT_DECRYPT_CORRECT_FORMAT = "fault_decrypt_correct_format"
|
|
|
|
"""Fault which causes decrypt to return wrong key."""
|
|
FAULT_DECRYPT_WRONG_KEY = "fault_decrypt_wrong_key"
|
|
|
|
|
|
# List of supported fault types
|
|
SUPPORTED_FAULT_TYPES = [
|
|
FAULT_ENCRYPT,
|
|
FAULT_ENCRYPT_CORRECT_FORMAT,
|
|
FAULT_ENCRYPT_WRONG_FIELDS,
|
|
FAULT_ENCRYPT_BAD_BASE64,
|
|
FAULT_DECRYPT,
|
|
FAULT_DECRYPT_CORRECT_FORMAT,
|
|
FAULT_DECRYPT_WRONG_KEY,
|
|
]
|
|
|
|
class AwsKmsHandler(http.server.BaseHTTPRequestHandler):
|
|
"""
|
|
Handle requests from AWS KMS Monitoring and test commands
|
|
"""
|
|
protocol_version = "HTTP/1.1"
|
|
|
|
def do_GET(self):
|
|
"""Serve a Test GET request."""
|
|
parts = urllib.parse.urlsplit(self.path)
|
|
path = parts[2]
|
|
|
|
if path == kms_http_common.URL_PATH_STATS:
|
|
self._do_stats()
|
|
elif path == kms_http_common.URL_DISABLE_FAULTS:
|
|
self._do_disable_faults()
|
|
elif path == kms_http_common.URL_ENABLE_FAULTS:
|
|
self._do_enable_faults()
|
|
else:
|
|
self.send_response(http.HTTPStatus.NOT_FOUND)
|
|
self.end_headers()
|
|
self.wfile.write("Unknown URL".encode())
|
|
|
|
def do_POST(self):
|
|
"""Serve a POST request."""
|
|
parts = urllib.parse.urlsplit(self.path)
|
|
path = parts[2]
|
|
|
|
if path == "/":
|
|
self._do_post()
|
|
else:
|
|
self.send_response(http.HTTPStatus.NOT_FOUND)
|
|
self.end_headers()
|
|
self.wfile.write("Unknown URL".encode())
|
|
|
|
def _send_reply(self, data, status=http.HTTPStatus.OK):
|
|
print("Sending Response: " + data.decode())
|
|
|
|
self.send_response(status)
|
|
self.send_header("content-type", "application/octet-stream")
|
|
self.send_header("Content-Length", str(len(data)))
|
|
self.end_headers()
|
|
|
|
self.wfile.write(data)
|
|
|
|
def _do_post(self):
|
|
global stats
|
|
clen = int(self.headers.get('content-length'))
|
|
|
|
raw_input = self.rfile.read(clen)
|
|
|
|
print("RAW INPUT: " + str(raw_input))
|
|
|
|
# X-Amz-Target: TrentService.Encrypt
|
|
aws_operation = self.headers['X-Amz-Target']
|
|
|
|
if aws_operation == "TrentService.Encrypt":
|
|
stats.encrypt_calls += 1
|
|
self._do_encrypt(raw_input)
|
|
elif aws_operation == "TrentService.Decrypt":
|
|
stats.decrypt_calls += 1
|
|
self._do_decrypt(raw_input)
|
|
else:
|
|
data = "Unknown AWS Operation"
|
|
self._send_reply(data.encode("utf-8"))
|
|
|
|
def _do_encrypt(self, raw_input):
|
|
request = json.loads(raw_input)
|
|
|
|
print(request)
|
|
|
|
plaintext = request["Plaintext"]
|
|
keyid = request["KeyId"]
|
|
|
|
ciphertext = SECRET_PREFIX.encode() + plaintext.encode()
|
|
ciphertext = base64.b64encode(ciphertext).decode()
|
|
|
|
if fault_type and fault_type.startswith(FAULT_ENCRYPT) and not disable_faults:
|
|
return self._do_encrypt_faults(ciphertext)
|
|
|
|
response = {
|
|
"CiphertextBlob" : ciphertext,
|
|
"KeyId" : keyid,
|
|
}
|
|
|
|
self._send_reply(json.dumps(response).encode('utf-8'))
|
|
|
|
def _do_encrypt_faults(self, raw_ciphertext):
|
|
stats.fault_calls += 1
|
|
|
|
if fault_type == FAULT_ENCRYPT:
|
|
self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR)
|
|
return
|
|
elif fault_type == FAULT_ENCRYPT_WRONG_FIELDS:
|
|
response = {
|
|
"SomeBlob" : raw_ciphertext,
|
|
"KeyId" : "foo",
|
|
}
|
|
|
|
self._send_reply(json.dumps(response).encode('utf-8'))
|
|
return
|
|
elif fault_type == FAULT_ENCRYPT_BAD_BASE64:
|
|
response = {
|
|
"CiphertextBlob" : "foo",
|
|
"KeyId" : "foo",
|
|
}
|
|
|
|
self._send_reply(json.dumps(response).encode('utf-8'))
|
|
return
|
|
elif fault_type == FAULT_ENCRYPT_CORRECT_FORMAT:
|
|
response = {
|
|
"__type" : "NotFoundException",
|
|
"message" : "Error encrypting message",
|
|
}
|
|
|
|
self._send_reply(json.dumps(response).encode('utf-8'))
|
|
return
|
|
|
|
raise ValueError("Unknown Fault Type: " + fault_type)
|
|
|
|
def _do_decrypt(self, raw_input):
|
|
request = json.loads(raw_input)
|
|
blob = base64.b64decode(request["CiphertextBlob"]).decode()
|
|
|
|
print("FOUND SECRET: " + blob)
|
|
|
|
# our "encrypted" values start with the word SECRET_PREFIX otherwise they did not come from us
|
|
if not blob.startswith(SECRET_PREFIX):
|
|
raise ValueError()
|
|
|
|
blob = blob[len(SECRET_PREFIX):]
|
|
|
|
if fault_type and fault_type.startswith(FAULT_DECRYPT) and not disable_faults:
|
|
return self._do_decrypt_faults(blob)
|
|
|
|
response = {
|
|
"Plaintext" : blob,
|
|
"KeyId" : "Not a clue",
|
|
}
|
|
|
|
self._send_reply(json.dumps(response).encode('utf-8'))
|
|
|
|
def _do_decrypt_faults(self, blob):
|
|
stats.fault_calls += 1
|
|
|
|
if fault_type == FAULT_DECRYPT:
|
|
self._send_reply("Internal Error of some sort.".encode(), http.HTTPStatus.INTERNAL_SERVER_ERROR)
|
|
return
|
|
elif fault_type == FAULT_DECRYPT_WRONG_KEY:
|
|
response = {
|
|
"Plaintext" : "ta7DXE7J0OiCRw03dYMJSeb8nVF5qxTmZ9zWmjuX4zW/SOorSCaY8VMTWG+cRInMx/rr/+QeVw2WjU2IpOSvMg==",
|
|
"KeyId" : "Not a clue",
|
|
}
|
|
|
|
self._send_reply(json.dumps(response).encode('utf-8'))
|
|
return
|
|
elif fault_type == FAULT_DECRYPT_CORRECT_FORMAT:
|
|
response = {
|
|
"__type" : "NotFoundException",
|
|
"message" : "Error decrypting message",
|
|
}
|
|
|
|
self._send_reply(json.dumps(response).encode('utf-8'))
|
|
return
|
|
|
|
raise ValueError("Unknown Fault Type: " + fault_type)
|
|
|
|
def _send_header(self):
|
|
self.send_response(http.HTTPStatus.OK)
|
|
self.send_header("content-type", "application/octet-stream")
|
|
self.end_headers()
|
|
|
|
def _do_stats(self):
|
|
self._send_header()
|
|
|
|
self.wfile.write(str(stats).encode('utf-8'))
|
|
|
|
def _do_disable_faults(self):
|
|
global disable_faults
|
|
disable_faults = True
|
|
self._send_header()
|
|
|
|
def _do_enable_faults(self):
|
|
global disable_faults
|
|
disable_faults = False
|
|
self._send_header()
|
|
|
|
def run(port, cert_file, ca_file, server_class=http.server.HTTPServer, handler_class=AwsKmsHandler):
|
|
"""Run web server."""
|
|
server_address = ('', port)
|
|
|
|
httpd = server_class(server_address, handler_class)
|
|
|
|
httpd.socket = ssl.wrap_socket (httpd.socket,
|
|
certfile=cert_file,
|
|
ca_certs=ca_file, server_side=True)
|
|
|
|
print("Mock KMS Web Server Listening on %s" % (str(server_address)))
|
|
|
|
httpd.serve_forever()
|
|
|
|
|
|
def main():
|
|
"""Main Method."""
|
|
global fault_type
|
|
global disable_faults
|
|
|
|
parser = argparse.ArgumentParser(description='MongoDB Mock AWS KMS Endpoint.')
|
|
|
|
parser.add_argument('-p', '--port', type=int, default=8000, help="Port to listen on")
|
|
|
|
parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing")
|
|
|
|
parser.add_argument('--fault', type=str, help="Type of fault to inject")
|
|
|
|
parser.add_argument('--disable-faults', action='store_true', help="Disable faults on startup")
|
|
|
|
parser.add_argument('--ca_file', type=str, required=True, help="TLS CA PEM file")
|
|
|
|
parser.add_argument('--cert_file', type=str, required=True, help="TLS Server PEM file")
|
|
|
|
args = parser.parse_args()
|
|
if args.verbose:
|
|
logging.basicConfig(level=logging.DEBUG)
|
|
|
|
if args.fault:
|
|
if args.fault not in SUPPORTED_FAULT_TYPES:
|
|
print("Unsupported fault type %s, supports types are %s" % (args.fault, SUPPORTED_FAULT_TYPES))
|
|
sys.exit(1)
|
|
|
|
fault_type = args.fault
|
|
|
|
if args.disable_faults:
|
|
disable_faults = True
|
|
|
|
run(args.port, args.cert_file, args.ca_file)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
main()
|