From 422195ec70a00b0d2002b238cacbae7790c57fdf Mon Sep 17 00:00:00 2001 From: Simon Sawicki Date: Sat, 2 Nov 2024 21:42:00 +0100 Subject: [PATCH] [utils] Allow partial application for even more functions (#11437) Fixes b6dc2c49e8793c6dfa21275e61caf49ec1148b81 Authored by: Grub4K --- test/test_traversal.py | 11 +++++++++++ yt_dlp/utils/_utils.py | 1 + yt_dlp/utils/traversal.py | 8 ++++++++ 3 files changed, 20 insertions(+) diff --git a/test/test_traversal.py b/test/test_traversal.py index f1d123bd6..1c0cc5362 100644 --- a/test/test_traversal.py +++ b/test/test_traversal.py @@ -9,6 +9,7 @@ from yt_dlp.utils import ( determine_ext, dict_get, int_or_none, + join_nonempty, str_or_none, ) from yt_dlp.utils.traversal import ( @@ -16,6 +17,7 @@ from yt_dlp.utils.traversal import ( subs_list_to_dict, traverse_obj, trim_str, + unpack, ) _TEST_DATA = { @@ -510,6 +512,15 @@ class TestTraversalHelpers: assert trim_str(start='abc', end='abc')('abc') == '' assert trim_str(start='', end='')('abc') == 'abc' + def test_unpack(self): + assert unpack(lambda *x: ''.join(map(str, x)))([1, 2, 3]) == '123' + assert unpack(join_nonempty)([1, 2, 3]) == '1-2-3' + assert unpack(join_nonempty(delim=' '))([1, 2, 3]) == '1 2 3' + with pytest.raises(TypeError): + unpack(join_nonempty)() + with pytest.raises(TypeError): + unpack() + class TestDictGet: def test_dict_get(self): diff --git a/yt_dlp/utils/_utils.py b/yt_dlp/utils/_utils.py index e30008e93..2f4c0a00f 100644 --- a/yt_dlp/utils/_utils.py +++ b/yt_dlp/utils/_utils.py @@ -5294,6 +5294,7 @@ def make_archive_id(ie, video_id): return f'{ie_key.lower()} {video_id}' +@partial_application def truncate_string(s, left, right=0): assert left > 3 and right >= 0 if s is None or len(s) <= left + right: diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py index dd9b4690b..bc313d5c4 100644 --- a/yt_dlp/utils/traversal.py +++ b/yt_dlp/utils/traversal.py @@ -449,6 +449,14 @@ def trim_str(*, start=None, end=None): return trim +def unpack(func): + @functools.wraps(func) + def inner(items, **kwargs): + return func(*items, **kwargs) + + return inner + + def get_first(obj, *paths, **kwargs): return traverse_obj(obj, *((..., *variadic(keys)) for keys in paths), **kwargs, get_all=False)