[aes] Fix GCM pad length calculation (#11438)

Closes #10169
Authored by: seproDev
This commit is contained in:
sepro 2024-11-03 21:03:09 +01:00 committed by GitHub
parent 3945677a75
commit beae2db127
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 2 deletions

View File

@ -83,6 +83,18 @@ class TestAES(unittest.TestCase):
data, intlist_to_bytes(self.key), authentication_tag, intlist_to_bytes(self.iv[:12])) data, intlist_to_bytes(self.key), authentication_tag, intlist_to_bytes(self.iv[:12]))
self.assertEqual(decrypted.rstrip(b'\x08'), self.secret_msg) self.assertEqual(decrypted.rstrip(b'\x08'), self.secret_msg)
def test_gcm_aligned_decrypt(self):
data = b'\x159Y\xcf5eud\x90\x9c\x85&]\x14\x1d\x0f'
authentication_tag = b'\x08\xb1\x9d!&\x98\xd0\xeaRq\x90\xe6;\xb5]\xd8'
decrypted = intlist_to_bytes(aes_gcm_decrypt_and_verify(
list(data), self.key, list(authentication_tag), self.iv[:12]))
self.assertEqual(decrypted.rstrip(b'\x08'), self.secret_msg[:16])
if Cryptodome.AES:
decrypted = aes_gcm_decrypt_and_verify_bytes(
data, bytes(self.key), authentication_tag, bytes(self.iv[:12]))
self.assertEqual(decrypted.rstrip(b'\x08'), self.secret_msg[:16])
def test_decrypt_text(self): def test_decrypt_text(self):
password = intlist_to_bytes(self.key).decode() password = intlist_to_bytes(self.key).decode()
encrypted = base64.b64encode( encrypted = base64.b64encode(

View File

@ -230,11 +230,11 @@ def aes_gcm_decrypt_and_verify(data, key, tag, nonce):
iv_ctr = inc(j0) iv_ctr = inc(j0)
decrypted_data = aes_ctr_decrypt(data, key, iv_ctr + [0] * (BLOCK_SIZE_BYTES - len(iv_ctr))) decrypted_data = aes_ctr_decrypt(data, key, iv_ctr + [0] * (BLOCK_SIZE_BYTES - len(iv_ctr)))
pad_len = len(data) // 16 * 16 pad_len = (BLOCK_SIZE_BYTES - (len(data) % BLOCK_SIZE_BYTES)) % BLOCK_SIZE_BYTES
s_tag = ghash( s_tag = ghash(
hash_subkey, hash_subkey,
data data
+ [0] * (BLOCK_SIZE_BYTES - len(data) + pad_len) # pad + [0] * pad_len # pad
+ bytes_to_intlist((0 * 8).to_bytes(8, 'big') # length of associated data + bytes_to_intlist((0 * 8).to_bytes(8, 'big') # length of associated data
+ ((len(data) * 8).to_bytes(8, 'big'))), # length of data + ((len(data) * 8).to_bytes(8, 'big'))), # length of data
) )