Skip to content

Commit

Permalink
Implement AES etypes from RFC8009
Browse files Browse the repository at this point in the history
Co-authored-by: Alexander Bokovoy <[email protected]>
  • Loading branch information
jrisc and abbra committed Jan 16, 2024
1 parent 3c7b02c commit 33eac39
Show file tree
Hide file tree
Showing 4 changed files with 436 additions and 53 deletions.
4 changes: 4 additions & 0 deletions impacket/krb5/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,8 @@ class EncryptionTypes(Enum):
des3_cbc_sha1_kd = 16
aes128_cts_hmac_sha1_96 = 17
aes256_cts_hmac_sha1_96 = 18
aes128_cts_hmac_sha256_128 = 19
aes256_cts_hmac_sha384_192 = 20
rc4_hmac = 23
rc4_hmac_exp = 24
subkey_keymaterial = 65
Expand All @@ -473,3 +475,5 @@ class ChecksumTypes(Enum):
hmac_sha1_des3_kd = 12
hmac_sha1_96_aes128 = 15
hmac_sha1_96_aes256 = 16
hmac_sha256_128_aes128 = 19
hmac_sha384_192_aes256 = 20
229 changes: 178 additions & 51 deletions impacket/krb5/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,37 +55,43 @@
from struct import pack, unpack

from Cryptodome.Cipher import AES, DES3, ARC4, DES
from Cryptodome.Hash import HMAC, MD4, MD5, SHA1
from Cryptodome.Protocol.KDF import PBKDF2
from Cryptodome.Hash import HMAC, MD4, MD5, SHA1, SHA256, SHA384
from Cryptodome.Protocol.KDF import PBKDF2, SP800_108_Counter
from Cryptodome.Util.number import GCD as gcd
from six import b, PY3, indexbytes, binary_type

from impacket.krb5 import constants


def get_random_bytes(lenBytes):
# We don't really need super strong randomness here to use PyCrypto.Random
return urandom(lenBytes)

class Enctype(object):
DES_CRC = 1
DES_MD4 = 2
DES_MD5 = 3
DES3 = 16
AES128_SHA1 = 17
AES256_SHA1 = 18
RC4 = 23
DES_CRC = constants.EncryptionTypes.des_cbc_crc.value
DES_MD4 = constants.EncryptionTypes.des_cbc_md4.value
DES_MD5 = constants.EncryptionTypes.des_cbc_md5.value
DES3 = constants.EncryptionTypes.des3_cbc_sha1_kd.value
AES128_SHA1 = constants.EncryptionTypes.aes128_cts_hmac_sha1_96.value
AES256_SHA1 = constants.EncryptionTypes.aes256_cts_hmac_sha1_96.value
AES128_SHA256 = constants.EncryptionTypes.aes128_cts_hmac_sha256_128.value
AES256_SHA384 = constants.EncryptionTypes.aes256_cts_hmac_sha384_192.value
RC4 = constants.EncryptionTypes.rc4_hmac.value


class Cksumtype(object):
CRC32 = 1
MD4 = 2
MD4_DES = 3
MD5 = 7
MD5_DES = 8
MD5_DES = constants.ChecksumTypes.rsa_md5_des.value
SHA1 = 9
SHA1_DES3 = 12
SHA1_AES128 = 15
SHA1_AES256 = 16
HMAC_MD5 = -138
SHA1_DES3 = constants.ChecksumTypes.hmac_sha1_des3_kd.value
SHA1_AES128 = constants.ChecksumTypes.hmac_sha1_96_aes128.value
SHA1_AES256 = constants.ChecksumTypes.hmac_sha1_96_aes256.value
SHA256_AES128 = constants.ChecksumTypes.hmac_sha256_128_aes128.value
SHA384_AES256 = constants.ChecksumTypes.hmac_sha384_192_aes256.value
HMAC_MD5 = constants.ChecksumTypes.hmac_md5.value


class InvalidChecksum(ValueError):
Expand Down Expand Up @@ -434,6 +440,47 @@ def basic_decrypt(cls, key, ciphertext):
return des3.decrypt(bytes(ciphertext))


def basic_decrypt_all_aes(cls, key, ciphertext, iv):
bs = cls.blocksize
assert len(ciphertext) >= bs
aes = AES.new(key.contents, AES.MODE_ECB)
if len(ciphertext) == bs:
return aes.decrypt(ciphertext)
# Split the ciphertext into blocks. The last block may be partial.
cblocks = [bytearray(ciphertext[p:p+bs]) for p in range(0, len(ciphertext), bs)]
lastlen = len(cblocks[-1])
# CBC-decrypt all but the last two blocks.
prev_cblock = bytearray(iv)
plaintext = b''
for bb in cblocks[:-2]:
plaintext += _xorbytes(bytearray(aes.decrypt(bytes(bb))), prev_cblock)
prev_cblock = bb
# Decrypt the second-to-last cipher block. The left side of
# the decrypted block will be the final block of plaintext
# xor'd with the final partial cipher block; the right side
# will be the omitted bytes of ciphertext from the final
# block.
bb = bytearray(aes.decrypt(bytes(cblocks[-2])))
lastplaintext =_xorbytes(bb[:lastlen], cblocks[-1])
omitted = bb[lastlen:]
# Decrypt the final cipher block plus the omitted bytes to get
# the second-to-last plaintext block.
plaintext += _xorbytes(bytearray(aes.decrypt(bytes(cblocks[-1]) + bytes(omitted))), prev_cblock)
return plaintext + lastplaintext

def basic_encrypt_all_aes(cls, key, plaintext, iv):
bs = cls.blocksize
assert len(plaintext) >= bs
aes = AES.new(key.contents, AES.MODE_CBC, iv)
ctext = aes.encrypt(_zeropad(bytes(plaintext), bs))
if len(plaintext) > bs:
# Swap the last two ciphertext blocks and truncate the
# final block to match the plaintext length.
lastlen = len(plaintext) % bs or bs
ctext = ctext[:-(bs*2)] + ctext[-bs:] + ctext[-(bs*2):-bs][:lastlen]
return ctext


class _AES_SHA1_Enctype(_SimplifiedEnctype):
# Base class for aes128-cts and aes256-cts.
blocksize = 16
Expand All @@ -456,43 +503,13 @@ def string_to_key(cls, string, salt, params):

@classmethod
def basic_encrypt(cls, key, plaintext):
assert len(plaintext) >= 16
aes = AES.new(key.contents, AES.MODE_CBC, b'\0' * 16)
ctext = aes.encrypt(_zeropad(bytes(plaintext), 16))
if len(plaintext) > 16:
# Swap the last two ciphertext blocks and truncate the
# final block to match the plaintext length.
lastlen = len(plaintext) % 16 or 16
ctext = ctext[:-32] + ctext[-16:] + ctext[-32:-16][:lastlen]
return ctext
iv = bytes(cls.blocksize)
return basic_encrypt_all_aes(cls, key, plaintext, iv)

@classmethod
def basic_decrypt(cls, key, ciphertext):
assert len(ciphertext) >= 16
aes = AES.new(key.contents, AES.MODE_ECB)
if len(ciphertext) == 16:
return aes.decrypt(ciphertext)
# Split the ciphertext into blocks. The last block may be partial.
cblocks = [bytearray(ciphertext[p:p+16]) for p in range(0, len(ciphertext), 16)]
lastlen = len(cblocks[-1])
# CBC-decrypt all but the last two blocks.
prev_cblock = bytearray(16)
plaintext = b''
for bb in cblocks[:-2]:
plaintext += _xorbytes(bytearray(aes.decrypt(bytes(bb))), prev_cblock)
prev_cblock = bb
# Decrypt the second-to-last cipher block. The left side of
# the decrypted block will be the final block of plaintext
# xor'd with the final partial cipher block; the right side
# will be the omitted bytes of ciphertext from the final
# block.
bb = bytearray(aes.decrypt(bytes(cblocks[-2])))
lastplaintext =_xorbytes(bb[:lastlen], cblocks[-1])
omitted = bb[lastlen:]
# Decrypt the final cipher block plus the omitted bytes to get
# the second-to-last plaintext block.
plaintext += _xorbytes(bytearray(aes.decrypt(bytes(cblocks[-1]) + bytes(omitted))), prev_cblock)
return plaintext + lastplaintext
iv = bytes(cls.blocksize)
return basic_decrypt_all_aes(cls, key, ciphertext, iv)


class _AES128_SHA1_CTS(_AES_SHA1_Enctype):
Expand All @@ -507,6 +524,102 @@ class _AES256_SHA1_CTS(_AES_SHA1_Enctype):
seedsize = 32


class _RFC8009_Enctype(_EnctypeProfile):
# Base class for aes128-cts-hmac-sha256-128 and aes256-cts-hmac-sha384-192.
blocksize = 128 // 8 # Cipher block size
seedsize = None # PRF output size
keysize = None # Encryption key size
macsize = None # Integrity key size
hashmod = None # Hash function module
enctype_name = None # Encryption type name as byte string

@classmethod
def random_to_key(cls, seed):
return Key(cls.enctype, seed)

@classmethod
def basic_encrypt(cls, key, plaintext, iv):
return basic_encrypt_all_aes(cls, key, plaintext, iv)

@classmethod
def basic_decrypt(cls, key, ciphertext, iv):
return basic_decrypt_all_aes(cls, key, ciphertext, iv)

@classmethod
def kdf_hmac_sha2(cls, key, label, k, context=b''):
hmac_sha2 = lambda p, s: HMAC.new(p, s, cls.hashmod).digest()
return SP800_108_Counter(master=key, key_len=k, prf=hmac_sha2, label=label, context=context)

@classmethod
def derive(cls, key, constant):
return cls.random_to_key(cls.kdf_hmac_sha2(key=key.contents, label=constant, k=cls.macsize))

@classmethod
def prf(cls, input_key, string):
return cls.kdf_hmac_sha2(key=input_key.contents, label=b'prf',
k=cls.seedsize, context=string)

@classmethod
def string_to_key(cls, string, salt, params):
if not isinstance(string, binary_type):
string = string.encode("utf-8")
if not isinstance(salt, binary_type):
salt = salt.encode("utf-8")

saltp = cls.enctype_name + b'\0' + salt

iter_count = unpack('>L', params)[0] if params else 32768
tkey = PBKDF2(password=string, salt=saltp, count=iter_count, dkLen=cls.keysize, hmac_hash_module=cls.hashmod)
return cls.random_to_key(cls.kdf_hmac_sha2(key=tkey, label=b'kerberos', k=cls.keysize))


@classmethod
def encrypt(cls, key, keyusage, plaintext, confounder):
ke = cls.random_to_key(cls.kdf_hmac_sha2(key.contents, pack('>IB', keyusage, 0xAA), cls.keysize))
ki = cls.random_to_key(cls.kdf_hmac_sha2(key.contents, pack('>IB', keyusage, 0x55), cls.macsize))
n = get_random_bytes(cls.blocksize)
# Initial cipher state is a zeroed buffer
iv = bytes(cls.blocksize)
c = cls.basic_encrypt(ke, n + plaintext, iv)
h = HMAC.new(ki.contents, iv + c, cls.hashmod).digest()
ciphertext = c + h[:cls.macsize]
assert(plaintext == cls.decrypt(key, keyusage, ciphertext))
return ciphertext

@classmethod
def decrypt(cls, key, keyusage, ciphertext):
if not isinstance(ciphertext, binary_type):
ciphertext = bytes(ciphertext)
ke = cls.random_to_key(cls.kdf_hmac_sha2(key.contents, pack('>IB', keyusage, 0xAA), cls.keysize))
ki = cls.random_to_key(cls.kdf_hmac_sha2(key.contents, pack('>IB', keyusage, 0x55), cls.macsize))
c = ciphertext[:-cls.macsize]
h = ciphertext[-cls.macsize:]
# Initial cipher state is a zeroed buffer
iv = bytes(cls.blocksize)
if h != HMAC.new(ki.contents, iv + c, cls.hashmod).digest()[:cls.macsize]:
raise InvalidChecksum('ciphertext integrity failure')
plaintext = cls.basic_decrypt(ke, c, iv)[cls.blocksize:]
return plaintext


class _AES128_SHA256_CTS(_RFC8009_Enctype):
enctype = Enctype.AES128_SHA256
seedsize = 256 // 8
macsize = 128 // 8
keysize = 128 // 8
hashmod = SHA256
enctype_name = b'aes128-cts-hmac-sha256-128'


class _AES256_SHA384_CTS(_RFC8009_Enctype):
enctype = Enctype.AES256_SHA384
seedsize = 384 // 8
macsize = 192 // 8
keysize = 256 // 8
hashmod = SHA384
enctype_name = b'aes256-cts-hmac-sha384-192'


class _RC4(_EnctypeProfile):
enctype = Enctype.RC4
keysize = 16
Expand Down Expand Up @@ -591,6 +704,16 @@ def verify(cls, key, keyusage, text, cksum):
super(_SimplifiedChecksum, cls).verify(key, keyusage, text, cksum)


class _SHA256AES128(_SimplifiedChecksum):
macsize = _AES128_SHA256_CTS.macsize
enc = _AES128_SHA256_CTS


class _SHA384AES256(_SimplifiedChecksum):
macsize = _AES256_SHA384_CTS.macsize
enc = _AES256_SHA384_CTS


class _SHA1AES128(_SimplifiedChecksum):
macsize = 12
enc = _AES128_SHA1_CTS
Expand Down Expand Up @@ -625,7 +748,9 @@ def verify(cls, key, keyusage, text, cksum):
Enctype.DES3: _DES3CBC,
Enctype.AES128_SHA1: _AES128_SHA1_CTS,
Enctype.AES256_SHA1: _AES256_SHA1_CTS,
Enctype.RC4: _RC4
Enctype.RC4: _RC4,
Enctype.AES128_SHA256: _AES128_SHA256_CTS,
Enctype.AES256_SHA384: _AES256_SHA384_CTS
}


Expand All @@ -634,7 +759,9 @@ def verify(cls, key, keyusage, text, cksum):
Cksumtype.SHA1_AES128: _SHA1AES128,
Cksumtype.SHA1_AES256: _SHA1AES256,
Cksumtype.HMAC_MD5: _HMACMD5,
0xffffff76: _HMACMD5
0xffffff76: _HMACMD5,
Cksumtype.SHA256_AES128: _SHA256AES128,
Cksumtype.SHA384_AES256: _SHA384AES256
}


Expand All @@ -653,7 +780,7 @@ def _get_checksum_profile(cksumtype):
class Key(object):
def __init__(self, enctype, contents):
e = _get_enctype_profile(enctype)
if len(contents) != e.keysize:
if len(contents) != e.keysize and len(contents) != e.macsize:
raise ValueError('Wrong key length')
self.enctype = enctype
self.contents = contents
Expand Down
20 changes: 18 additions & 2 deletions impacket/krb5/gssapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def GSSAPI(cipher):
return GSSAPI_AES256_SHA1()
if cipher.enctype == constants.EncryptionTypes.aes128_cts_hmac_sha1_96.value:
return GSSAPI_AES128_SHA1()
if cipher.enctype == constants.EncryptionTypes.aes256_cts_hmac_sha384_192.value:
return GSSAPI_AES256_SHA2()
if cipher.enctype == constants.EncryptionTypes.aes128_cts_hmac_sha256_128.value:
return GSSAPI_AES128_SHA2()
elif cipher.enctype == constants.EncryptionTypes.rc4_hmac.value:
return GSSAPI_RC4()
else:
Expand Down Expand Up @@ -289,10 +293,22 @@ def GSS_Unwrap(self, sessionKey, data, sequenceNumber, direction = 'init', encry

return plainText[:-(token['EC']+len(self.WRAP()))], None

class GSSAPI_AES256_SHA1(GSSAPI_AES):
class GSSAPI_AES256_SHA1(GSSAPI_AES_SHA1):
checkSumProfile = crypto._SHA1AES256
cipherType = crypto._AES256_SHA1_CTS

class GSSAPI_AES128_SHA1(GSSAPI_AES):
class GSSAPI_AES128_SHA1(GSSAPI_AES_SHA1):
checkSumProfile = crypto._SHA1AES128
cipherType = crypto._AES128_SHA1_CTS

class GSSAPI_AES_SHA2():
checkSumProfile = None
cipherType = None

class GSSAPI_AES128_SHA256(GSSAPI_AES_SHA2):
checkSumProfile = crypto._SHA256AES128
cipherType = crypto._AES128_SHA256_CTS

class GSSAPI_AES256_SHA384(GSSAPI_AES_SHA2):
checkSumProfile = crypto._SHA384AES256
cipherType = crypto._AES256_SHA384_CTS
Loading

0 comments on commit 33eac39

Please sign in to comment.