From beae2db127d3b5017cbcf685da9de7a9ef496541 Mon Sep 17 00:00:00 2001
From: sepro <sepro@sepr0.com>
Date: Sun, 3 Nov 2024 21:03:09 +0100
Subject: [PATCH] [aes] Fix GCM pad length calculation (#11438)

Closes #10169
Authored by: seproDev
---
 test/test_aes.py | 12 ++++++++++++
 yt_dlp/aes.py    |  4 ++--
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/test/test_aes.py b/test/test_aes.py
index 5f975efec..6fe6059a1 100644
--- a/test/test_aes.py
+++ b/test/test_aes.py
@@ -83,6 +83,18 @@ class TestAES(unittest.TestCase):
                 data, intlist_to_bytes(self.key), authentication_tag, intlist_to_bytes(self.iv[:12]))
             self.assertEqual(decrypted.rstrip(b'\x08'), self.secret_msg)
 
+    def test_gcm_aligned_decrypt(self):
+        data = b'\x159Y\xcf5eud\x90\x9c\x85&]\x14\x1d\x0f'
+        authentication_tag = b'\x08\xb1\x9d!&\x98\xd0\xeaRq\x90\xe6;\xb5]\xd8'
+
+        decrypted = intlist_to_bytes(aes_gcm_decrypt_and_verify(
+            list(data), self.key, list(authentication_tag), self.iv[:12]))
+        self.assertEqual(decrypted.rstrip(b'\x08'), self.secret_msg[:16])
+        if Cryptodome.AES:
+            decrypted = aes_gcm_decrypt_and_verify_bytes(
+                data, bytes(self.key), authentication_tag, bytes(self.iv[:12]))
+            self.assertEqual(decrypted.rstrip(b'\x08'), self.secret_msg[:16])
+
     def test_decrypt_text(self):
         password = intlist_to_bytes(self.key).decode()
         encrypted = base64.b64encode(
diff --git a/yt_dlp/aes.py b/yt_dlp/aes.py
index abf54a998..be67b40fe 100644
--- a/yt_dlp/aes.py
+++ b/yt_dlp/aes.py
@@ -230,11 +230,11 @@ def aes_gcm_decrypt_and_verify(data, key, tag, nonce):
     iv_ctr = inc(j0)
 
     decrypted_data = aes_ctr_decrypt(data, key, iv_ctr + [0] * (BLOCK_SIZE_BYTES - len(iv_ctr)))
-    pad_len = len(data) // 16 * 16
+    pad_len = (BLOCK_SIZE_BYTES - (len(data) % BLOCK_SIZE_BYTES)) % BLOCK_SIZE_BYTES
     s_tag = ghash(
         hash_subkey,
         data
-        + [0] * (BLOCK_SIZE_BYTES - len(data) + pad_len)        # pad
+        + [0] * pad_len                                         # pad
         + bytes_to_intlist((0 * 8).to_bytes(8, 'big')           # length of associated data
                            + ((len(data) * 8).to_bytes(8, 'big'))),  # length of data
     )