From 39d79c9b9cf23411d935910685c40aa1a2fdb409 Mon Sep 17 00:00:00 2001 From: Simon Sawicki Date: Fri, 15 Nov 2024 22:06:15 +0100 Subject: [PATCH] [utils] Fix `join_nonempty`, add `**kwargs` to `unpack` (#11559) Authored by: Grub4K --- test/test_traversal.py | 2 +- test/test_utils.py | 5 ----- yt_dlp/utils/_utils.py | 3 +-- yt_dlp/utils/traversal.py | 4 ++-- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/test/test_traversal.py b/test/test_traversal.py index d48606e99..52ea19fab 100644 --- a/test/test_traversal.py +++ b/test/test_traversal.py @@ -525,7 +525,7 @@ class TestTraversalHelpers: def test_unpack(self): assert unpack(lambda *x: ''.join(map(str, x)))([1, 2, 3]) == '123' assert unpack(join_nonempty)([1, 2, 3]) == '1-2-3' - assert unpack(join_nonempty(delim=' '))([1, 2, 3]) == '1 2 3' + assert unpack(join_nonempty, delim=' ')([1, 2, 3]) == '1 2 3' with pytest.raises(TypeError): unpack(join_nonempty)() with pytest.raises(TypeError): diff --git a/test/test_utils.py b/test/test_utils.py index b5f35736b..835774a91 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -72,7 +72,6 @@ from yt_dlp.utils import ( intlist_to_bytes, iri_to_uri, is_html, - join_nonempty, js_to_json, limit_length, locked_file, @@ -2158,10 +2157,6 @@ Line 1 assert int_or_none(v=10) == 10, 'keyword passed positional should call function' assert int_or_none(scale=0.1)(10) == 100, 'call after partial application should call the function' - assert callable(join_nonempty(delim=', ')), 'varargs positional should apply partially' - assert callable(join_nonempty()), 'varargs positional should apply partially' - assert join_nonempty(None, delim=', ') == '', 'passed varargs should call the function' - if __name__ == '__main__': unittest.main() diff --git a/yt_dlp/utils/_utils.py b/yt_dlp/utils/_utils.py index b28bb555e..89c53c39e 100644 --- a/yt_dlp/utils/_utils.py +++ b/yt_dlp/utils/_utils.py @@ -216,7 +216,7 @@ def partial_application(func): sig = inspect.signature(func) required_args = [ param.name for param in sig.parameters.values() - if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.VAR_POSITIONAL) + if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) if param.default is inspect.Parameter.empty ] @@ -4837,7 +4837,6 @@ def number_of_digits(number): return len('%d' % number) -@partial_application def join_nonempty(*values, delim='-', from_dict=None): if from_dict is not None: values = (traversal.traverse_obj(from_dict, variadic(v)) for v in values) diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py index 361f239ba..6bb52050f 100644 --- a/yt_dlp/utils/traversal.py +++ b/yt_dlp/utils/traversal.py @@ -452,9 +452,9 @@ def trim_str(*, start=None, end=None): return trim -def unpack(func): +def unpack(func, **kwargs): @functools.wraps(func) - def inner(items, **kwargs): + def inner(items): return func(*items, **kwargs) return inner