diff --git a/README.md b/README.md index 90758d0..9f9c729 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ A Python implementation of the authenticated encryption mode [Galois/Counter Mode (GCM)](http://en.wikipedia.org/wiki/Galois/Counter_Mode). -Currently it supports only 128-bit AES and 96-bit nonce. +Currently it supports 128, 192 and 256-bit AES and 96-bit nonce. ## Dependencies diff --git a/aes_gcm.py b/aes_gcm.py index 04e496b..f6b7e71 100755 --- a/aes_gcm.py +++ b/aes_gcm.py @@ -50,7 +50,7 @@ def __str__(self): class InvalidTagException(Exception): def __str__(self): - return 'The authenticaiton tag is invalid.' + return 'The authentication tag is invalid.' # Galois/Counter Mode with AES-128 and 96-bit IV @@ -59,10 +59,20 @@ def __init__(self, master_key): self.change_key(master_key) def change_key(self, master_key): - if master_key >= (1 << 128): - raise InvalidInputException('Master key should be 128-bit') + if type(master_key) in [bytes, str]: + # Assume this is a good key in a bytearray/bytes/string + self.__master_key = master_key + elif master_key <= (1 << 128): + self.__master_key = long_to_bytes(master_key, 16) + elif master_key <= (1 << 192): + self.__master_key = long_to_bytes(master_key, 24) + elif master_key <= (1 << 256): + self.__master_key = long_to_bytes(master_key, 32) + + if len(self.__master_key) > 32: + raise InvalidInputException( + 'Master key should be 128, 192 or 256-bit, (got: %s)' % len(self.__master_key * 8)) - self.__master_key = long_to_bytes(master_key, 16) self.__aes_ecb = AES.new(self.__master_key, AES.MODE_ECB) self.__auth_key = bytes_to_long(self.__aes_ecb.encrypt(b'\x00' * 16)) @@ -109,7 +119,10 @@ def __ghash(self, aad, txt): return tag - def encrypt(self, init_value, plaintext, auth_data=b''): + def encrypt(self, init_value, plaintext, auth_data=b'', tag_len = 16): + if type(init_value) in [bytes, str]: + # Assume the IV is provided as bytes + init_value = bytes_to_long(init_value) if init_value >= (1 << 96): raise InvalidInputException('IV should be 96-bit') # a naive checking for IV reuse @@ -145,17 +158,27 @@ def encrypt(self, init_value, plaintext, auth_data=b''): # assert len(ciphertext) == len(plaintext) assert auth_tag < (1 << 128) + auth_tag = auth_tag >> (16 - tag_len) * 8 return ciphertext, auth_tag - def decrypt(self, init_value, ciphertext, auth_tag, auth_data=b''): + def decrypt(self, init_value, ciphertext, auth_tag, auth_data=b'', tag_len = 16): + # Assume the IV and/or auth tag are provided as byte arrays when they look like strings + if type(init_value) in [bytes, str]: + init_value = bytes_to_long(init_value) + if type(auth_tag) in [bytes, str]: + tag_len = int(len(auth_tag) / 8) + auth_tag = bytes_to_long(auth_tag) + if init_value >= (1 << 96): raise InvalidInputException('IV should be 96-bit') if auth_tag >= (1 << 128): raise InvalidInputException('Tag should be 128-bit') - if auth_tag != self.__ghash(auth_data, ciphertext) ^ \ + ghash = self.__ghash(auth_data, ciphertext) ^ \ bytes_to_long(self.__aes_ecb.encrypt( - long_to_bytes((init_value << 32) | 1, 16))): + long_to_bytes((init_value << 32) | 1, 16))) + ghash = ghash >> (16 - tag_len) * 8 + if auth_tag != ghash: raise InvalidTagException len_ciphertext = len(ciphertext) diff --git a/test.py b/test.py index 4231d33..4bbf223 100755 --- a/test.py +++ b/test.py @@ -25,7 +25,8 @@ from aes_gcm import AES_GCM from pprint import pprint from Crypto.Random.random import getrandbits -from Crypto.Util.number import long_to_bytes +from Crypto.Util.number import long_to_bytes, bytes_to_long +from Crypto.Hash import SHA256 test_cases = ({ 'master_key': 0x00000000000000000000000000000000, @@ -34,6 +35,57 @@ 'init_value': 0x000000000000000000000000, 'ciphertext': b'', 'auth_tag': 0x58e2fccefa7e3061367f1d57a4e7455a, +}, { + 'name': '192 bit key', + 'master_key': 0xfffffffffffffffffffffffffffffffff, + 'plaintext': b'\x00\x00\x00\x00\x00\x00\x00\x00' + + b'\x00\x00\x00\x00\x00\x00\x00\x00', + 'auth_data': b'', + 'init_value': 0x000000000000000000000000, + 'ciphertext': b'\x72\x3e\x3e\x28\x03\x9e\xb7\x8c' + + b'\xd2\x95\x39\x7f\x9c\x27\x08\x32', + 'auth_tag': 0xe2deabd0b0d93a417facac9df9f6cf91, +}, { + 'name': '256 bit key as byte array, IV as byte array', + 'master_key': SHA256.new("").digest(), + 'plaintext': b'\x00\x00\x00\x00\x00\x00\x00\x00' + + b'\x00\x00\x00\x00\x00\x00\x00\x00', + 'auth_data': b'', + 'init_value': b'\x00\x00\x00\x00\x00\x00\x00\x00', + 'ciphertext': b'\x09\x33\x71\x35\x75\xe5\x1b\x11' + + b'\xca\xab\x53\x99\xb8\x8d\x48\xc6', + 'auth_tag': 0x761b53ebf18f95502fcd10865ba91e17, +}, { + 'name': '256 bit key as byte array, IV and auth tag as byte array', + 'master_key': SHA256.new("").digest(), + 'plaintext': b'\x00\x00\x00\x00\x00\x00\x00\x00' + + b'\x00\x00\x00\x00\x00\x00\x00\x00', + 'auth_data': b'', + 'init_value': b'\x00\x00\x00\x00\x00\x00\x00\x00', + 'ciphertext': b'\x09\x33\x71\x35\x75\xe5\x1b\x11' + + b'\xca\xab\x53\x99\xb8\x8d\x48\xc6', + 'auth_tag': b'\x76\x1b\x53\xeb\xf1\x8f\x95\x50' + + b'\x2f\xcd\x10\x86\x5b\xa9\x1e\x17', +}, { + 'name': '256 bit key as byte array', + 'master_key': SHA256.new("").digest(), + 'plaintext': b'\x00\x00\x00\x00\x00\x00\x00\x00' + + b'\x00\x00\x00\x00\x00\x00\x00\x00', + 'auth_data': b'', + 'init_value': 0x000000000000000000000000, + 'ciphertext': b'\x09\x33\x71\x35\x75\xe5\x1b\x11' + + b'\xca\xab\x53\x99\xb8\x8d\x48\xc6', + 'auth_tag': 0x761b53ebf18f95502fcd10865ba91e17, +}, { + 'name': '256 bit key', + 'master_key': bytes_to_long(SHA256.new("").digest()), + 'plaintext': b'\x00\x00\x00\x00\x00\x00\x00\x00' + + b'\x00\x00\x00\x00\x00\x00\x00\x00', + 'auth_data': b'', + 'init_value': 0x000000000000000000000000, + 'ciphertext': b'\x09\x33\x71\x35\x75\xe5\x1b\x11' + + b'\xca\xab\x53\x99\xb8\x8d\x48\xc6', + 'auth_tag': 0x761b53ebf18f95502fcd10865ba91e17, }, { 'master_key': 0x00000000000000000000000000000000, 'plaintext': b'\x00\x00\x00\x00\x00\x00\x00\x00' + @@ -94,12 +146,21 @@ num_failures = 0 for test_data in test_cases: + test_tag = test_data['auth_tag'] + if type(test_data['auth_tag']) in [bytes, str]: + test_tag = bytes_to_long(test_data['auth_tag']) + test_gcm = AES_GCM(test_data['master_key']) encrypted, tag = test_gcm.encrypt( test_data['init_value'], test_data['plaintext'], test_data['auth_data'] ) + if type(encrypted) == str: + enc_dbg = '\\x' + '\\x'.join('{:02x}'.format(ord(x)) for x in encrypted) + else: + enc_dbg = '\\x' + '\\x'.join('{:02x}'.format(x) for x in encrypted) + tag_dbg = hex(tag) states = [] tags = [] @@ -129,16 +190,19 @@ decrypted = test_gcm.decrypt( test_data['init_value'], encrypted, - tag, + test_data['auth_tag'], test_data['auth_data'] ) if encrypted != test_data['ciphertext'] or \ - tag != test_data['auth_tag'] or \ + tag != test_tag or \ decrypted != test_data['plaintext']: num_failures += 1 print('This test case failed:') pprint(test_data) + print("Encrypted: %s (%s)" % (enc_dbg, encrypted == test_data['ciphertext'])) + print("Tag: %s (%s)" % (tag_dbg, tag == test_tag)) + print("Decrypted: %s (%s)" % (decrypted, decrypted == test_data['plaintext'])) print() if num_failures == 0: diff --git a/test_nist.py b/test_nist.py new file mode 100755 index 0000000..abfd6a5 --- /dev/null +++ b/test_nist.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +import fileinput +import sys +from Crypto.Util.number import long_to_bytes, bytes_to_long + +from aes_gcm import AES_GCM, InvalidTagException + +current_test_parameters = {} +current_test = {} +success_count = 0 +fail_count = 0 + +def process(line): + global current_test + global success_count + global fail_count + sline = line.strip() + if sline.startswith("["): + data = sline[1:-1] + key, value = data.split("=", 1) + current_test_parameters[key.strip()] = int(value) + elif (sline == "" and not current_test) or line.startswith("#"): + return + elif sline == "" and 'count' in current_test.keys(): + errors = [] + if 'PT' not in current_test.keys(): + current_test['PT'] = '' + test_gcm = AES_GCM(int(current_test['Key'],16)) + test_aad = b'' if (len(current_test['AAD']) == 0) else long_to_bytes(int(current_test['AAD'], 16)) + test_tag = b'' if (len(current_test['Tag']) == 0) else int(current_test['Tag'], 16) + test_crypttext = b'' if (len(current_test['CT']) == 0) else long_to_bytes(int(current_test['CT'], 16)) + test_plaintext = b'' if (len(current_test['PT']) == 0) else long_to_bytes(int(current_test['PT'], 16)) + test_iv = int(current_test['IV'], 16) + tag_len = int(int(current_test_parameters['Taglen']) / 8) + try: + computed_crypttext, computed_tag = test_gcm.encrypt( + test_iv, + test_plaintext, + test_aad, + tag_len) + except ValueError as e: + errors.append(e) + if computed_tag != test_tag: + errors.append("Tag mismatch after encryption") + computed_plaintext = b'' + try: + computed_plaintext = test_gcm.decrypt( + test_iv, + test_crypttext, + test_tag, + test_aad, + tag_len) + if computed_plaintext != test_plaintext: + errors.append("Plaintext mismatch") + except InvalidTagException: + errors.append("Tag mismatch while decrypting") + test_passed = current_test['fail'] == (len(errors) > 0) + if not test_passed: + fail_count += 1 + print("\n\nFailed test %s" % current_test['count']) + print("Parameters:") + print(current_test_parameters) + print("Test case:") + print(current_test) + print(errors) + print("Crypttext") + print(" Test: %s" % test_crypttext) + print(" Computed: %s" % computed_crypttext) + print("Plaintext") + print(" Test: %s" % test_plaintext) + print(" Computed: %s" % computed_plaintext) + print("Tags") + print(" Test: %s" % hex(test_tag)) + print(" Computed: %s" % hex(computed_tag)) + print("Failed: %s | Success: %s" % (fail_count, success_count)) + else: + success_count += 1 + current_test = None + elif line.startswith("Count ="): + current_test = { + 'count': int(line.split("=", 1)[1]), + 'fail': False + } + elif " = " in line: + name, value = line.split(" = ", 1) + current_test[name.strip()] = value.strip() + elif sline == "FAIL": + current_test['fail'] = True + else: + print("unknown line: %s" % line) + +print("Parsing") + +total = 0 +last = 0 +for line in fileinput.input(): + process(line) + total = success_count + fail_count + if (total % 20) == 0 and last != total: + print("Failed: %s | Success: %s" % (fail_count, success_count)) + last = total + +print("Success: %s" % success_count) +print("Failed: %s" % fail_count)