From 69bec6730ec9d724bcedeab199d9d684d61423ba Mon Sep 17 00:00:00 2001 From: coletdjnz Date: Sun, 21 May 2023 09:56:23 +1200 Subject: [PATCH] [cleanup, utils] Split into submodules (#7090) Closes https://github.com/yt-dlp/yt-dlp/pull/2173 Authored by: pukkandan, coletdjnz Co-authored-by: pukkandan --- Makefile | 2 +- setup.cfg | 1 + yt_dlp/YoutubeDL.py | 2 - yt_dlp/utils/__init__.py | 14 + yt_dlp/utils/_deprecated.py | 30 ++ yt_dlp/utils/_legacy.py | 163 ++++++++++ yt_dlp/{utils.py => utils/_utils.py} | 458 +-------------------------- yt_dlp/utils/traversal.py | 254 +++++++++++++++ 8 files changed, 480 insertions(+), 444 deletions(-) create mode 100644 yt_dlp/utils/__init__.py create mode 100644 yt_dlp/utils/_deprecated.py create mode 100644 yt_dlp/utils/_legacy.py rename yt_dlp/{utils.py => utils/_utils.py} (92%) create mode 100644 yt_dlp/utils/traversal.py diff --git a/Makefile b/Makefile index d5d47629b..f03fe2052 100644 --- a/Makefile +++ b/Makefile @@ -74,7 +74,7 @@ offlinetest: codetest $(PYTHON) -m pytest -k "not download" # XXX: This is hard to maintain -CODE_FOLDERS = yt_dlp yt_dlp/downloader yt_dlp/extractor yt_dlp/postprocessor yt_dlp/compat yt_dlp/dependencies +CODE_FOLDERS = yt_dlp yt_dlp/downloader yt_dlp/extractor yt_dlp/postprocessor yt_dlp/compat yt_dlp/utils yt_dlp/dependencies yt-dlp: yt_dlp/*.py yt_dlp/*/*.py mkdir -p zip for d in $(CODE_FOLDERS) ; do \ diff --git a/setup.cfg b/setup.cfg index 6deaa7971..68d9e516d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,7 @@ ignore = E402,E501,E731,E741,W503 max_line_length = 120 per_file_ignores = devscripts/lazy_load_template.py: F401 + yt_dlp/utils/__init__.py: F401, F403 [autoflake] diff --git a/yt_dlp/YoutubeDL.py b/yt_dlp/YoutubeDL.py index 91aec1fe6..b8f1a05a0 100644 --- a/yt_dlp/YoutubeDL.py +++ b/yt_dlp/YoutubeDL.py @@ -124,7 +124,6 @@ from .utils import ( parse_filesize, preferredencoding, prepend_extension, - register_socks_protocols, remove_terminal_sequences, render_table, replace_extension, @@ -739,7 +738,6 @@ class YoutubeDL: when=when) self._setup_opener() - register_socks_protocols() def preload_download_archive(fn): """Preload the archive, if any is specified""" diff --git a/yt_dlp/utils/__init__.py b/yt_dlp/utils/__init__.py new file mode 100644 index 000000000..74b39e2c7 --- /dev/null +++ b/yt_dlp/utils/__init__.py @@ -0,0 +1,14 @@ +import warnings + +from ..compat.compat_utils import passthrough_module + +# XXX: Implement this the same way as other DeprecationWarnings without circular import +passthrough_module(__name__, '._legacy', callback=lambda attr: warnings.warn( + DeprecationWarning(f'{__name__}.{attr} is deprecated'), stacklevel=5)) +del passthrough_module + +# isort: off +from .traversal import * +from ._utils import * +from ._utils import _configuration_args, _get_exe_version_output +from ._deprecated import * diff --git a/yt_dlp/utils/_deprecated.py b/yt_dlp/utils/_deprecated.py new file mode 100644 index 000000000..4454d84a7 --- /dev/null +++ b/yt_dlp/utils/_deprecated.py @@ -0,0 +1,30 @@ +"""Deprecated - New code should avoid these""" + +from ._utils import preferredencoding + + +def encodeFilename(s, for_subprocess=False): + assert isinstance(s, str) + return s + + +def decodeFilename(b, for_subprocess=False): + return b + + +def decodeArgument(b): + return b + + +def decodeOption(optval): + if optval is None: + return optval + if isinstance(optval, bytes): + optval = optval.decode(preferredencoding()) + + assert isinstance(optval, str) + return optval + + +def error_to_compat_str(err): + return str(err) diff --git a/yt_dlp/utils/_legacy.py b/yt_dlp/utils/_legacy.py new file mode 100644 index 000000000..cd009b504 --- /dev/null +++ b/yt_dlp/utils/_legacy.py @@ -0,0 +1,163 @@ +"""No longer used and new code should not use. Exists only for API compat.""" + +import platform +import struct +import sys +import urllib.parse +import zlib + +from ._utils import decode_base_n, preferredencoding +from .traversal import traverse_obj +from ..dependencies import certifi, websockets + +has_certifi = bool(certifi) +has_websockets = bool(websockets) + + +def load_plugins(name, suffix, namespace): + from ..plugins import load_plugins + ret = load_plugins(name, suffix) + namespace.update(ret) + return ret + + +def traverse_dict(dictn, keys, casesense=True): + return traverse_obj(dictn, keys, casesense=casesense, is_user_input=True, traverse_string=True) + + +def decode_base(value, digits): + return decode_base_n(value, table=digits) + + +def platform_name(): + """ Returns the platform name as a str """ + return platform.platform() + + +def get_subprocess_encoding(): + if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5: + # For subprocess calls, encode with locale encoding + # Refer to http://stackoverflow.com/a/9951851/35070 + encoding = preferredencoding() + else: + encoding = sys.getfilesystemencoding() + if encoding is None: + encoding = 'utf-8' + return encoding + + +# UNUSED +# Based on png2str() written by @gdkchan and improved by @yokrysty +# Originally posted at https://github.com/ytdl-org/youtube-dl/issues/9706 +def decode_png(png_data): + # Reference: https://www.w3.org/TR/PNG/ + header = png_data[8:] + + if png_data[:8] != b'\x89PNG\x0d\x0a\x1a\x0a' or header[4:8] != b'IHDR': + raise OSError('Not a valid PNG file.') + + int_map = {1: '>B', 2: '>H', 4: '>I'} + unpack_integer = lambda x: struct.unpack(int_map[len(x)], x)[0] + + chunks = [] + + while header: + length = unpack_integer(header[:4]) + header = header[4:] + + chunk_type = header[:4] + header = header[4:] + + chunk_data = header[:length] + header = header[length:] + + header = header[4:] # Skip CRC + + chunks.append({ + 'type': chunk_type, + 'length': length, + 'data': chunk_data + }) + + ihdr = chunks[0]['data'] + + width = unpack_integer(ihdr[:4]) + height = unpack_integer(ihdr[4:8]) + + idat = b'' + + for chunk in chunks: + if chunk['type'] == b'IDAT': + idat += chunk['data'] + + if not idat: + raise OSError('Unable to read PNG data.') + + decompressed_data = bytearray(zlib.decompress(idat)) + + stride = width * 3 + pixels = [] + + def _get_pixel(idx): + x = idx % stride + y = idx // stride + return pixels[y][x] + + for y in range(height): + basePos = y * (1 + stride) + filter_type = decompressed_data[basePos] + + current_row = [] + + pixels.append(current_row) + + for x in range(stride): + color = decompressed_data[1 + basePos + x] + basex = y * stride + x + left = 0 + up = 0 + + if x > 2: + left = _get_pixel(basex - 3) + if y > 0: + up = _get_pixel(basex - stride) + + if filter_type == 1: # Sub + color = (color + left) & 0xff + elif filter_type == 2: # Up + color = (color + up) & 0xff + elif filter_type == 3: # Average + color = (color + ((left + up) >> 1)) & 0xff + elif filter_type == 4: # Paeth + a = left + b = up + c = 0 + + if x > 2 and y > 0: + c = _get_pixel(basex - stride - 3) + + p = a + b - c + + pa = abs(p - a) + pb = abs(p - b) + pc = abs(p - c) + + if pa <= pb and pa <= pc: + color = (color + a) & 0xff + elif pb <= pc: + color = (color + b) & 0xff + else: + color = (color + c) & 0xff + + current_row.append(color) + + return width, height, pixels + + +def register_socks_protocols(): + # "Register" SOCKS protocols + # In Python < 2.6.5, urlsplit() suffers from bug https://bugs.python.org/issue7904 + # URLs with protocols not in urlparse.uses_netloc are not handled correctly + for scheme in ('socks', 'socks4', 'socks4a', 'socks5'): + if scheme not in urllib.parse.uses_netloc: + urllib.parse.uses_netloc.append(scheme) diff --git a/yt_dlp/utils.py b/yt_dlp/utils/_utils.py similarity index 92% rename from yt_dlp/utils.py rename to yt_dlp/utils/_utils.py index 190af1b7d..f032af901 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils/_utils.py @@ -47,26 +47,18 @@ import urllib.request import xml.etree.ElementTree import zlib -from .compat import functools # isort: split -from .compat import ( +from . import traversal + +from ..compat import functools # isort: split +from ..compat import ( compat_etree_fromstring, compat_expanduser, compat_HTMLParseError, compat_os_name, compat_shlex_quote, ) -from .dependencies import brotli, certifi, websockets, xattr -from .socks import ProxyType, sockssocket - - -def register_socks_protocols(): - # "Register" SOCKS protocols - # In Python < 2.6.5, urlsplit() suffers from bug https://bugs.python.org/issue7904 - # URLs with protocols not in urlparse.uses_netloc are not handled correctly - for scheme in ('socks', 'socks4', 'socks4a', 'socks5'): - if scheme not in urllib.parse.uses_netloc: - urllib.parse.uses_netloc.append(scheme) - +from ..dependencies import brotli, certifi, websockets, xattr +from ..socks import ProxyType, sockssocket # This is not clearly defined otherwise compiled_regex_type = type(re.compile('')) @@ -928,27 +920,6 @@ class Popen(subprocess.Popen): return stdout or default, stderr or default, proc.returncode -def get_subprocess_encoding(): - if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5: - # For subprocess calls, encode with locale encoding - # Refer to http://stackoverflow.com/a/9951851/35070 - encoding = preferredencoding() - else: - encoding = sys.getfilesystemencoding() - if encoding is None: - encoding = 'utf-8' - return encoding - - -def encodeFilename(s, for_subprocess=False): - assert isinstance(s, str) - return s - - -def decodeFilename(b, for_subprocess=False): - return b - - def encodeArgument(s): # Legacy code that uses byte strings # Uncomment the following line after fixing all post processors @@ -956,20 +927,6 @@ def encodeArgument(s): return s if isinstance(s, str) else s.decode('ascii') -def decodeArgument(b): - return b - - -def decodeOption(optval): - if optval is None: - return optval - if isinstance(optval, bytes): - optval = optval.decode(preferredencoding()) - - assert isinstance(optval, str) - return optval - - _timetuple = collections.namedtuple('Time', ('hours', 'minutes', 'seconds', 'milliseconds')) @@ -1034,7 +991,7 @@ def make_HTTPS_handler(params, **kwargs): context.verify_mode = ssl.CERT_REQUIRED if opts_check_certificate else ssl.CERT_NONE if opts_check_certificate: - if has_certifi and 'no-certifi' not in params.get('compat_opts', []): + if certifi and 'no-certifi' not in params.get('compat_opts', []): context.load_verify_locations(cafile=certifi.where()) else: try: @@ -1068,7 +1025,7 @@ def make_HTTPS_handler(params, **kwargs): def bug_reports_message(before=';'): - from .update import REPOSITORY + from ..update import REPOSITORY msg = (f'please report this issue on https://github.com/{REPOSITORY}/issues?q= , ' 'filling out the appropriate issue template. Confirm you are on the latest version using yt-dlp -U') @@ -2019,12 +1976,6 @@ class DateRange: and self.start == other.start and self.end == other.end) -def platform_name(): - """ Returns the platform name as a str """ - deprecation_warning(f'"{__name__}.platform_name" is deprecated, use "platform.platform" instead') - return platform.platform() - - @functools.cache def system_identifier(): python_implementation = platform.python_implementation() @@ -2076,7 +2027,7 @@ def write_string(s, out=None, encoding=None): def deprecation_warning(msg, *, printer=None, stacklevel=0, **kwargs): - from . import _IN_CLI + from .. import _IN_CLI if _IN_CLI: if msg in deprecation_warning._cache: return @@ -3284,13 +3235,6 @@ def variadic(x, allowed_types=NO_DEFAULT): return x if is_iterable_like(x, blocked_types=allowed_types) else (x, ) -def dict_get(d, key_or_keys, default=None, skip_false_values=True): - for val in map(d.get, variadic(key_or_keys)): - if val is not None and (val or not skip_false_values): - return val - return default - - def try_call(*funcs, expected_type=None, args=[], kwargs={}): for f in funcs: try: @@ -3528,7 +3472,7 @@ def is_outdated_version(version, limit, assume_new=True): def ytdl_is_updateable(): """ Returns if yt-dlp can be updated with -U """ - from .update import is_non_updateable + from ..update import is_non_updateable return not is_non_updateable() @@ -3538,10 +3482,6 @@ def args_to_str(args): return ' '.join(compat_shlex_quote(a) for a in args) -def error_to_compat_str(err): - return str(err) - - def error_to_str(err): return f'{type(err).__name__}: {err}' @@ -3628,7 +3568,7 @@ def mimetype2ext(mt, default=NO_DEFAULT): mimetype = mt.partition(';')[0].strip().lower() _, _, subtype = mimetype.rpartition('/') - ext = traverse_obj(MAP, mimetype, subtype, subtype.rsplit('+')[-1]) + ext = traversal.traverse_obj(MAP, mimetype, subtype, subtype.rsplit('+')[-1]) if ext: return ext elif default is not NO_DEFAULT: @@ -3660,7 +3600,7 @@ def parse_codecs(codecs_str): vcodec = full_codec if parts[0] in ('dvh1', 'dvhe'): hdr = 'DV' - elif parts[0] == 'av1' and traverse_obj(parts, 3) == '10': + elif parts[0] == 'av1' and traversal.traverse_obj(parts, 3) == '10': hdr = 'HDR10' elif parts[:2] == ['vp9', '2']: hdr = 'HDR10' @@ -3706,8 +3646,7 @@ def get_compatible_ext(*, vcodecs, acodecs, vexts, aexts, preferences=None): }, } - sanitize_codec = functools.partial( - try_get, getter=lambda x: x[0].split('.')[0].replace('0', '').lower()) + sanitize_codec = functools.partial(try_get, getter=lambda x: x[0].split('.')[0].replace('0', '')) vcodec, acodec = sanitize_codec(vcodecs), sanitize_codec(acodecs) for ext in preferences or COMPATIBLE_CODECS.keys(): @@ -5088,12 +5027,6 @@ def decode_base_n(string, n=None, table=None): return result -def decode_base(value, digits): - deprecation_warning(f'{__name__}.decode_base is deprecated and may be removed ' - f'in a future version. Use {__name__}.decode_base_n instead') - return decode_base_n(value, table=digits) - - def decode_packed_codes(code): mobj = re.search(PACKED_CODES_RE, code) obfuscated_code, base, count, symbols = mobj.groups() @@ -5138,113 +5071,6 @@ def urshift(val, n): return val >> n if val >= 0 else (val + 0x100000000) >> n -# Based on png2str() written by @gdkchan and improved by @yokrysty -# Originally posted at https://github.com/ytdl-org/youtube-dl/issues/9706 -def decode_png(png_data): - # Reference: https://www.w3.org/TR/PNG/ - header = png_data[8:] - - if png_data[:8] != b'\x89PNG\x0d\x0a\x1a\x0a' or header[4:8] != b'IHDR': - raise OSError('Not a valid PNG file.') - - int_map = {1: '>B', 2: '>H', 4: '>I'} - unpack_integer = lambda x: struct.unpack(int_map[len(x)], x)[0] - - chunks = [] - - while header: - length = unpack_integer(header[:4]) - header = header[4:] - - chunk_type = header[:4] - header = header[4:] - - chunk_data = header[:length] - header = header[length:] - - header = header[4:] # Skip CRC - - chunks.append({ - 'type': chunk_type, - 'length': length, - 'data': chunk_data - }) - - ihdr = chunks[0]['data'] - - width = unpack_integer(ihdr[:4]) - height = unpack_integer(ihdr[4:8]) - - idat = b'' - - for chunk in chunks: - if chunk['type'] == b'IDAT': - idat += chunk['data'] - - if not idat: - raise OSError('Unable to read PNG data.') - - decompressed_data = bytearray(zlib.decompress(idat)) - - stride = width * 3 - pixels = [] - - def _get_pixel(idx): - x = idx % stride - y = idx // stride - return pixels[y][x] - - for y in range(height): - basePos = y * (1 + stride) - filter_type = decompressed_data[basePos] - - current_row = [] - - pixels.append(current_row) - - for x in range(stride): - color = decompressed_data[1 + basePos + x] - basex = y * stride + x - left = 0 - up = 0 - - if x > 2: - left = _get_pixel(basex - 3) - if y > 0: - up = _get_pixel(basex - stride) - - if filter_type == 1: # Sub - color = (color + left) & 0xff - elif filter_type == 2: # Up - color = (color + up) & 0xff - elif filter_type == 3: # Average - color = (color + ((left + up) >> 1)) & 0xff - elif filter_type == 4: # Paeth - a = left - b = up - c = 0 - - if x > 2 and y > 0: - c = _get_pixel(basex - stride - 3) - - p = a + b - c - - pa = abs(p - a) - pb = abs(p - b) - pc = abs(p - c) - - if pa <= pb and pa <= pc: - color = (color + a) & 0xff - elif pb <= pc: - color = (color + b) & 0xff - else: - color = (color + c) & 0xff - - current_row.append(color) - - return width, height, pixels - - def write_xattr(path, key, value): # Windows: Write xattrs to NTFS Alternate Data Streams: # http://en.wikipedia.org/wiki/NTFS#Alternate_data_streams_.28ADS.29 @@ -5403,7 +5229,7 @@ def to_high_limit_path(path): def format_field(obj, field=None, template='%s', ignore=NO_DEFAULT, default='', func=IDENTITY): - val = traverse_obj(obj, *variadic(field)) + val = traversal.traverse_obj(obj, *variadic(field)) if not val if ignore is NO_DEFAULT else val in variadic(ignore): return default return template % func(val) @@ -5441,12 +5267,12 @@ def make_dir(path, to_screen=None): return True except OSError as err: if callable(to_screen) is not None: - to_screen('unable to create directory ' + error_to_compat_str(err)) + to_screen(f'unable to create directory {err}') return False def get_executable_path(): - from .update import _get_variant_and_executable_path + from ..update import _get_variant_and_executable_path return os.path.dirname(os.path.abspath(_get_variant_and_executable_path()[1])) @@ -5470,244 +5296,6 @@ def get_system_config_dirs(package_name): yield os.path.join('/etc', package_name) -def traverse_obj( - obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True, - casesense=True, is_user_input=False, traverse_string=False): - """ - Safely traverse nested `dict`s and `Iterable`s - - >>> obj = [{}, {"key": "value"}] - >>> traverse_obj(obj, (1, "key")) - "value" - - Each of the provided `paths` is tested and the first producing a valid result will be returned. - The next path will also be tested if the path branched but no results could be found. - Supported values for traversal are `Mapping`, `Iterable` and `re.Match`. - Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded. - - The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. - - The keys in the path can be one of: - - `None`: Return the current object. - - `set`: Requires the only item in the set to be a type or function, - like `{type}`/`{func}`. If a `type`, returns only values - of this type. If a function, returns `func(obj)`. - - `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`. - - `slice`: Branch out and return all values in `obj[key]`. - - `Ellipsis`: Branch out and return a list of all values. - - `tuple`/`list`: Branch out and return a list of all matching values. - Read as: `[traverse_obj(obj, branch) for branch in branches]`. - - `function`: Branch out and return values filtered by the function. - Read as: `[value for key, value in obj if function(key, value)]`. - For `Iterable`s, `key` is the index of the value. - For `re.Match`es, `key` is the group number (0 = full match) - as well as additionally any group names, if given. - - `dict` Transform the current object and return a matching dict. - Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. - - `tuple`, `list`, and `dict` all support nested paths and branches. - - @params paths Paths which to traverse by. - @param default Value to return if the paths do not match. - If the last key in the path is a `dict`, it will apply to each value inside - the dict instead, depth first. Try to avoid if using nested `dict` keys. - @param expected_type If a `type`, only accept final values of this type. - If any other callable, try to call the function on each result. - If the last key in the path is a `dict`, it will apply to each value inside - the dict instead, recursively. This does respect branching paths. - @param get_all If `False`, return the first matching result, otherwise all matching ones. - @param casesense If `False`, consider string dictionary keys as case insensitive. - - The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API - - @param is_user_input Whether the keys are generated from user input. - If `True` strings get converted to `int`/`slice` if needed. - @param traverse_string Whether to traverse into objects as strings. - If `True`, any non-compatible object will first be - converted into a string and then traversed into. - The return value of that path will be a string instead, - not respecting any further branching. - - - @returns The result of the object traversal. - If successful, `get_all=True`, and the path branches at least once, - then a list of results is returned instead. - If no `default` is given and the last path branches, a `list` of results - is always returned. If a path ends on a `dict` that result will always be a `dict`. - """ - casefold = lambda k: k.casefold() if isinstance(k, str) else k - - if isinstance(expected_type, type): - type_test = lambda val: val if isinstance(val, expected_type) else None - else: - type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) - - def apply_key(key, obj, is_last): - branching = False - result = None - - if obj is None and traverse_string: - if key is ... or callable(key) or isinstance(key, slice): - branching = True - result = () - - elif key is None: - result = obj - - elif isinstance(key, set): - assert len(key) == 1, 'Set should only be used to wrap a single item' - item = next(iter(key)) - if isinstance(item, type): - if isinstance(obj, item): - result = obj - else: - result = try_call(item, args=(obj,)) - - elif isinstance(key, (list, tuple)): - branching = True - result = itertools.chain.from_iterable( - apply_path(obj, branch, is_last)[0] for branch in key) - - elif key is ...: - branching = True - if isinstance(obj, collections.abc.Mapping): - result = obj.values() - elif is_iterable_like(obj): - result = obj - elif isinstance(obj, re.Match): - result = obj.groups() - elif traverse_string: - branching = False - result = str(obj) - else: - result = () - - elif callable(key): - branching = True - if isinstance(obj, collections.abc.Mapping): - iter_obj = obj.items() - elif is_iterable_like(obj): - iter_obj = enumerate(obj) - elif isinstance(obj, re.Match): - iter_obj = itertools.chain( - enumerate((obj.group(), *obj.groups())), - obj.groupdict().items()) - elif traverse_string: - branching = False - iter_obj = enumerate(str(obj)) - else: - iter_obj = () - - result = (v for k, v in iter_obj if try_call(key, args=(k, v))) - if not branching: # string traversal - result = ''.join(result) - - elif isinstance(key, dict): - iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items()) - result = { - k: v if v is not None else default for k, v in iter_obj - if v is not None or default is not NO_DEFAULT - } or None - - elif isinstance(obj, collections.abc.Mapping): - result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else - next((v for k, v in obj.items() if casefold(k) == key), None)) - - elif isinstance(obj, re.Match): - if isinstance(key, int) or casesense: - with contextlib.suppress(IndexError): - result = obj.group(key) - - elif isinstance(key, str): - result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) - - elif isinstance(key, (int, slice)): - if is_iterable_like(obj, collections.abc.Sequence): - branching = isinstance(key, slice) - with contextlib.suppress(IndexError): - result = obj[key] - elif traverse_string: - with contextlib.suppress(IndexError): - result = str(obj)[key] - - return branching, result if branching else (result,) - - def lazy_last(iterable): - iterator = iter(iterable) - prev = next(iterator, NO_DEFAULT) - if prev is NO_DEFAULT: - return - - for item in iterator: - yield False, prev - prev = item - - yield True, prev - - def apply_path(start_obj, path, test_type): - objs = (start_obj,) - has_branched = False - - key = None - for last, key in lazy_last(variadic(path, (str, bytes, dict, set))): - if is_user_input and isinstance(key, str): - if key == ':': - key = ... - elif ':' in key: - key = slice(*map(int_or_none, key.split(':'))) - elif int_or_none(key) is not None: - key = int(key) - - if not casesense and isinstance(key, str): - key = key.casefold() - - if __debug__ and callable(key): - # Verify function signature - inspect.signature(key).bind(None, None) - - new_objs = [] - for obj in objs: - branching, results = apply_key(key, obj, last) - has_branched |= branching - new_objs.append(results) - - objs = itertools.chain.from_iterable(new_objs) - - if test_type and not isinstance(key, (dict, list, tuple)): - objs = map(type_test, objs) - - return objs, has_branched, isinstance(key, dict) - - def _traverse_obj(obj, path, allow_empty, test_type): - results, has_branched, is_dict = apply_path(obj, path, test_type) - results = LazyList(item for item in results if item not in (None, {})) - if get_all and has_branched: - if results: - return results.exhaust() - if allow_empty: - return [] if default is NO_DEFAULT else default - return None - - return results[0] if results else {} if allow_empty and is_dict else None - - for index, path in enumerate(paths, 1): - result = _traverse_obj(obj, path, index == len(paths), True) - if result is not None: - return result - - return None if default is NO_DEFAULT else default - - -def traverse_dict(dictn, keys, casesense=True): - deprecation_warning(f'"{__name__}.traverse_dict" is deprecated and may be removed ' - f'in a future version. Use "{__name__}.traverse_obj" instead') - return traverse_obj(dictn, keys, casesense=casesense, is_user_input=True, traverse_string=True) - - -def get_first(obj, *paths, **kwargs): - return traverse_obj(obj, *((..., *variadic(keys)) for keys in paths), **kwargs, get_all=False) - - def time_seconds(**kwargs): """ Returns TZ-aware time in seconds since the epoch (1970-01-01T00:00:00Z) @@ -5803,7 +5391,7 @@ def number_of_digits(number): def join_nonempty(*values, delim='-', from_dict=None): if from_dict is not None: - values = (traverse_obj(from_dict, variadic(v)) for v in values) + values = (traversal.traverse_obj(from_dict, variadic(v)) for v in values) return delim.join(map(str, filter(None, values))) @@ -6514,15 +6102,3 @@ class FormatSorter: format['abr'] = format.get('tbr') - format.get('vbr', 0) return tuple(self._calculate_field_preference(format, field) for field in self._order) - - -# Deprecated -has_certifi = bool(certifi) -has_websockets = bool(websockets) - - -def load_plugins(name, suffix, namespace): - from .plugins import load_plugins - ret = load_plugins(name, suffix) - namespace.update(ret) - return ret diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py new file mode 100644 index 000000000..462c3ba5d --- /dev/null +++ b/yt_dlp/utils/traversal.py @@ -0,0 +1,254 @@ +import collections.abc +import contextlib +import inspect +import itertools +import re + +from ._utils import ( + IDENTITY, + NO_DEFAULT, + LazyList, + int_or_none, + is_iterable_like, + try_call, + variadic, +) + + +def traverse_obj( + obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True, + casesense=True, is_user_input=False, traverse_string=False): + """ + Safely traverse nested `dict`s and `Iterable`s + + >>> obj = [{}, {"key": "value"}] + >>> traverse_obj(obj, (1, "key")) + "value" + + Each of the provided `paths` is tested and the first producing a valid result will be returned. + The next path will also be tested if the path branched but no results could be found. + Supported values for traversal are `Mapping`, `Iterable` and `re.Match`. + Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded. + + The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. + + The keys in the path can be one of: + - `None`: Return the current object. + - `set`: Requires the only item in the set to be a type or function, + like `{type}`/`{func}`. If a `type`, returns only values + of this type. If a function, returns `func(obj)`. + - `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`. + - `slice`: Branch out and return all values in `obj[key]`. + - `Ellipsis`: Branch out and return a list of all values. + - `tuple`/`list`: Branch out and return a list of all matching values. + Read as: `[traverse_obj(obj, branch) for branch in branches]`. + - `function`: Branch out and return values filtered by the function. + Read as: `[value for key, value in obj if function(key, value)]`. + For `Iterable`s, `key` is the index of the value. + For `re.Match`es, `key` is the group number (0 = full match) + as well as additionally any group names, if given. + - `dict` Transform the current object and return a matching dict. + Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. + + `tuple`, `list`, and `dict` all support nested paths and branches. + + @params paths Paths which to traverse by. + @param default Value to return if the paths do not match. + If the last key in the path is a `dict`, it will apply to each value inside + the dict instead, depth first. Try to avoid if using nested `dict` keys. + @param expected_type If a `type`, only accept final values of this type. + If any other callable, try to call the function on each result. + If the last key in the path is a `dict`, it will apply to each value inside + the dict instead, recursively. This does respect branching paths. + @param get_all If `False`, return the first matching result, otherwise all matching ones. + @param casesense If `False`, consider string dictionary keys as case insensitive. + + The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API + + @param is_user_input Whether the keys are generated from user input. + If `True` strings get converted to `int`/`slice` if needed. + @param traverse_string Whether to traverse into objects as strings. + If `True`, any non-compatible object will first be + converted into a string and then traversed into. + The return value of that path will be a string instead, + not respecting any further branching. + + + @returns The result of the object traversal. + If successful, `get_all=True`, and the path branches at least once, + then a list of results is returned instead. + If no `default` is given and the last path branches, a `list` of results + is always returned. If a path ends on a `dict` that result will always be a `dict`. + """ + casefold = lambda k: k.casefold() if isinstance(k, str) else k + + if isinstance(expected_type, type): + type_test = lambda val: val if isinstance(val, expected_type) else None + else: + type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) + + def apply_key(key, obj, is_last): + branching = False + result = None + + if obj is None and traverse_string: + if key is ... or callable(key) or isinstance(key, slice): + branching = True + result = () + + elif key is None: + result = obj + + elif isinstance(key, set): + assert len(key) == 1, 'Set should only be used to wrap a single item' + item = next(iter(key)) + if isinstance(item, type): + if isinstance(obj, item): + result = obj + else: + result = try_call(item, args=(obj,)) + + elif isinstance(key, (list, tuple)): + branching = True + result = itertools.chain.from_iterable( + apply_path(obj, branch, is_last)[0] for branch in key) + + elif key is ...: + branching = True + if isinstance(obj, collections.abc.Mapping): + result = obj.values() + elif is_iterable_like(obj): + result = obj + elif isinstance(obj, re.Match): + result = obj.groups() + elif traverse_string: + branching = False + result = str(obj) + else: + result = () + + elif callable(key): + branching = True + if isinstance(obj, collections.abc.Mapping): + iter_obj = obj.items() + elif is_iterable_like(obj): + iter_obj = enumerate(obj) + elif isinstance(obj, re.Match): + iter_obj = itertools.chain( + enumerate((obj.group(), *obj.groups())), + obj.groupdict().items()) + elif traverse_string: + branching = False + iter_obj = enumerate(str(obj)) + else: + iter_obj = () + + result = (v for k, v in iter_obj if try_call(key, args=(k, v))) + if not branching: # string traversal + result = ''.join(result) + + elif isinstance(key, dict): + iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items()) + result = { + k: v if v is not None else default for k, v in iter_obj + if v is not None or default is not NO_DEFAULT + } or None + + elif isinstance(obj, collections.abc.Mapping): + result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else + next((v for k, v in obj.items() if casefold(k) == key), None)) + + elif isinstance(obj, re.Match): + if isinstance(key, int) or casesense: + with contextlib.suppress(IndexError): + result = obj.group(key) + + elif isinstance(key, str): + result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) + + elif isinstance(key, (int, slice)): + if is_iterable_like(obj, collections.abc.Sequence): + branching = isinstance(key, slice) + with contextlib.suppress(IndexError): + result = obj[key] + elif traverse_string: + with contextlib.suppress(IndexError): + result = str(obj)[key] + + return branching, result if branching else (result,) + + def lazy_last(iterable): + iterator = iter(iterable) + prev = next(iterator, NO_DEFAULT) + if prev is NO_DEFAULT: + return + + for item in iterator: + yield False, prev + prev = item + + yield True, prev + + def apply_path(start_obj, path, test_type): + objs = (start_obj,) + has_branched = False + + key = None + for last, key in lazy_last(variadic(path, (str, bytes, dict, set))): + if is_user_input and isinstance(key, str): + if key == ':': + key = ... + elif ':' in key: + key = slice(*map(int_or_none, key.split(':'))) + elif int_or_none(key) is not None: + key = int(key) + + if not casesense and isinstance(key, str): + key = key.casefold() + + if __debug__ and callable(key): + # Verify function signature + inspect.signature(key).bind(None, None) + + new_objs = [] + for obj in objs: + branching, results = apply_key(key, obj, last) + has_branched |= branching + new_objs.append(results) + + objs = itertools.chain.from_iterable(new_objs) + + if test_type and not isinstance(key, (dict, list, tuple)): + objs = map(type_test, objs) + + return objs, has_branched, isinstance(key, dict) + + def _traverse_obj(obj, path, allow_empty, test_type): + results, has_branched, is_dict = apply_path(obj, path, test_type) + results = LazyList(item for item in results if item not in (None, {})) + if get_all and has_branched: + if results: + return results.exhaust() + if allow_empty: + return [] if default is NO_DEFAULT else default + return None + + return results[0] if results else {} if allow_empty and is_dict else None + + for index, path in enumerate(paths, 1): + result = _traverse_obj(obj, path, index == len(paths), True) + if result is not None: + return result + + return None if default is NO_DEFAULT else default + + +def get_first(obj, *paths, **kwargs): + return traverse_obj(obj, *((..., *variadic(keys)) for keys in paths), **kwargs, get_all=False) + + +def dict_get(d, key_or_keys, default=None, skip_false_values=True): + for val in map(d.get, variadic(key_or_keys)): + if val is not None and (val or not skip_false_values): + return val + return default