From 32abfb00bdbd119ca675fdc6d1719331f0a2741a Mon Sep 17 00:00:00 2001
From: Simon Sawicki <contact@grub4k.xyz>
Date: Mon, 1 Apr 2024 02:12:03 +0200
Subject: [PATCH] [utils] `traverse_obj`: Convenience improvements (#9577)

Add support for:
- `http.cookies.Morsel`
- Multi type filters (`{type, type}`)

Authored by: Grub4K
---
 test/test_traversal.py    | 33 ++++++++++++++++++++++++++++++++-
 yt_dlp/utils/traversal.py | 28 +++++++++++++++++++---------
 2 files changed, 51 insertions(+), 10 deletions(-)

diff --git a/test/test_traversal.py b/test/test_traversal.py
index 0b2f3fb5d..ed29d03ad 100644
--- a/test/test_traversal.py
+++ b/test/test_traversal.py
@@ -1,3 +1,4 @@
+import http.cookies
 import re
 import xml.etree.ElementTree
 
@@ -94,6 +95,8 @@ class TestTraversal:
             'Function in set should be a transformation'
         assert traverse_obj(_TEST_DATA, (..., {str})) == ['str'], \
             'Type in set should be a type filter'
+        assert traverse_obj(_TEST_DATA, (..., {str, int})) == [100, 'str'], \
+            'Multiple types in set should be a type filter'
         assert traverse_obj(_TEST_DATA, {dict}) == _TEST_DATA, \
             'A single set should be wrapped into a path'
         assert traverse_obj(_TEST_DATA, (..., {str.upper})) == ['STR'], \
@@ -103,7 +106,7 @@ class TestTraversal:
             'Function in set should be a transformation'
         assert traverse_obj(_TEST_DATA, ('fail', {lambda _: 'const'})) == 'const', \
             'Function in set should always be called'
-        # Sets with length != 1 should raise in debug
+        # Sets with length < 1 or > 1 not including only types should raise
         with pytest.raises(Exception):
             traverse_obj(_TEST_DATA, set())
         with pytest.raises(Exception):
@@ -409,3 +412,31 @@ class TestTraversal:
             '`all` should allow further branching'
         assert traverse_obj(_TEST_DATA, [('dict', 'None', 'urls', 'data'), any, ..., 'index']) == [0, 1], \
             '`any` should allow further branching'
+
+    def test_traversal_morsel(self):
+        values = {
+            'expires': 'a',
+            'path': 'b',
+            'comment': 'c',
+            'domain': 'd',
+            'max-age': 'e',
+            'secure': 'f',
+            'httponly': 'g',
+            'version': 'h',
+            'samesite': 'i',
+        }
+        morsel = http.cookies.Morsel()
+        morsel.set('item_key', 'item_value', 'coded_value')
+        morsel.update(values)
+        values['key'] = 'item_key'
+        values['value'] = 'item_value'
+
+        for key, value in values.items():
+            assert traverse_obj(morsel, key) == value, \
+                'Morsel should provide access to all values'
+        assert traverse_obj(morsel, ...) == list(values.values()), \
+            '`...` should yield all values'
+        assert traverse_obj(morsel, lambda k, v: True) == list(values.values()), \
+            'function key should yield all values'
+        assert traverse_obj(morsel, [(None,), any]) == morsel, \
+            'Morsel should not be implicitly changed to dict on usage'
diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py
index 926a3d0a1..96eb2eddf 100644
--- a/yt_dlp/utils/traversal.py
+++ b/yt_dlp/utils/traversal.py
@@ -1,5 +1,6 @@
 import collections.abc
 import contextlib
+import http.cookies
 import inspect
 import itertools
 import re
@@ -28,7 +29,8 @@ def traverse_obj(
 
     Each of the provided `paths` is tested and the first producing a valid result will be returned.
     The next path will also be tested if the path branched but no results could be found.
-    Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
+    Supported values for traversal are `Mapping`, `Iterable`, `re.Match`,
+    `xml.etree.ElementTree` (xpath) and `http.cookies.Morsel`.
     Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
 
     The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
@@ -36,8 +38,8 @@ def traverse_obj(
     The keys in the path can be one of:
         - `None`:           Return the current object.
         - `set`:            Requires the only item in the set to be a type or function,
-                            like `{type}`/`{func}`. If a `type`, returns only values
-                            of this type. If a function, returns `func(obj)`.
+                            like `{type}`/`{type, type, ...}/`{func}`. If a `type`, return only
+                            values of this type. If a function, returns `func(obj)`.
         - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
         - `slice`:          Branch out and return all values in `obj[key]`.
         - `Ellipsis`:       Branch out and return a list of all values.
@@ -48,8 +50,10 @@ def traverse_obj(
                             For `Iterable`s, `key` is the index of the value.
                             For `re.Match`es, `key` is the group number (0 = full match)
                             as well as additionally any group names, if given.
-        - `dict`            Transform the current object and return a matching dict.
+        - `dict`:           Transform the current object and return a matching dict.
                             Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
+        - `any`-builtin:    Take the first matching object and return it, resetting branching.
+        - `all`-builtin:    Take all matching objects and return them as a list, resetting branching.
 
         `tuple`, `list`, and `dict` all support nested paths and branches.
 
@@ -102,10 +106,10 @@ def traverse_obj(
             result = obj
 
         elif isinstance(key, set):
-            assert len(key) == 1, 'Set should only be used to wrap a single item'
             item = next(iter(key))
-            if isinstance(item, type):
-                if isinstance(obj, item):
+            if len(key) > 1 or isinstance(item, type):
+                assert all(isinstance(item, type) for item in key)
+                if isinstance(obj, tuple(key)):
                     result = obj
             else:
                 result = try_call(item, args=(obj,))
@@ -117,6 +121,8 @@ def traverse_obj(
 
         elif key is ...:
             branching = True
+            if isinstance(obj, http.cookies.Morsel):
+                obj = dict(obj, key=obj.key, value=obj.value)
             if isinstance(obj, collections.abc.Mapping):
                 result = obj.values()
             elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
@@ -131,6 +137,8 @@ def traverse_obj(
 
         elif callable(key):
             branching = True
+            if isinstance(obj, http.cookies.Morsel):
+                obj = dict(obj, key=obj.key, value=obj.value)
             if isinstance(obj, collections.abc.Mapping):
                 iter_obj = obj.items()
             elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
@@ -157,6 +165,8 @@ def traverse_obj(
             } or None
 
         elif isinstance(obj, collections.abc.Mapping):
+            if isinstance(obj, http.cookies.Morsel):
+                obj = dict(obj, key=obj.key, value=obj.value)
             result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
                       next((v for k, v in obj.items() if casefold(k) == key), None))
 
@@ -179,7 +189,7 @@ def traverse_obj(
 
         elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str):
             xpath, _, special = key.rpartition('/')
-            if not special.startswith('@') and special != 'text()':
+            if not special.startswith('@') and not special.endswith('()'):
                 xpath = key
                 special = None
 
@@ -198,7 +208,7 @@ def traverse_obj(
                     return try_call(element.attrib.get, args=(special[1:],))
                 if special == 'text()':
                     return element.text
-                assert False, f'apply_specials is missing case for {special!r}'
+                raise SyntaxError(f'apply_specials is missing case for {special!r}')
 
             if xpath:
                 result = list(map(apply_specials, obj.iterfind(xpath)))