From 4a2f19abbd61274358211c2e3b1d9658cfbdcdde Mon Sep 17 00:00:00 2001
From: Felix S
Date: Wed, 28 Apr 2021 16:17:30 +0530
Subject: [PATCH] [downloader/hls] Assemble single-file WebVTT subtitles from HLS segments

---
 yt_dlp/compat.py           |  14 ++
 yt_dlp/downloader/hls.py   |  44 ++++
 yt_dlp/extractor/common.py |   6 +
 yt_dlp/webvtt.py           | 428 +++++++++++++++++++++++++++++++++++++
 4 files changed, 492 insertions(+)
 create mode 100644 yt_dlp/webvtt.py

diff --git a/yt_dlp/compat.py b/yt_dlp/compat.py
index 3ebf1ee7a..863bd2287 100644
--- a/yt_dlp/compat.py
+++ b/yt_dlp/compat.py
@@ -3018,10 +3018,24 @@ else:
         return ctypes.WINFUNCTYPE(*args, **kwargs)
 
 
+try:
+    compat_Pattern = re.Pattern
+except AttributeError:
+    compat_Pattern = type(re.compile(''))
+
+
+try:
+    compat_Match = re.Match
+except AttributeError:
+    compat_Match = type(re.compile('').match(''))
+
+
 __all__ = [
     'compat_HTMLParseError',
     'compat_HTMLParser',
     'compat_HTTPError',
+    'compat_Match',
+    'compat_Pattern',
     'compat_Struct',
     'compat_b64decode',
     'compat_basestring',
diff --git a/yt_dlp/downloader/hls.py b/yt_dlp/downloader/hls.py
index f4e41a6c7..cee3807ce 100644
--- a/yt_dlp/downloader/hls.py
+++ b/yt_dlp/downloader/hls.py
@@ -2,6 +2,7 @@ from __future__ import unicode_literals
 
 import errno
 import re
+import io
 import binascii
 try:
     from Crypto.Cipher import AES
@@ -27,7 +28,9 @@ from ..utils import (
     parse_m3u8_attributes,
     sanitize_open,
     update_url_query,
+    bug_reports_message,
 )
+from .. import webvtt
 
 
 class HlsFD(FragmentFD):
@@ -78,6 +81,8 @@ class HlsFD(FragmentFD):
         man_url = info_dict['url']
         self.to_screen('[%s] Downloading m3u8 manifest' % self.FD_NAME)
 
+        is_webvtt = info_dict['ext'] == 'vtt'
+
         urlh = self.ydl.urlopen(self._prepare_url(info_dict, man_url))
         man_url = urlh.geturl()
         s = urlh.read().decode('utf-8', 'ignore')
@@ -142,6 +147,8 @@ class HlsFD(FragmentFD):
         else:
             self._prepare_and_start_frag_download(ctx)
 
+        extra_state = ctx.setdefault('extra_state', {})
+
         fragment_retries = self.params.get('fragment_retries', 0)
         skip_unavailable_fragments = self.params.get('skip_unavailable_fragments', True)
         test = self.params.get('test', False)
@@ -308,6 +315,42 @@ class HlsFD(FragmentFD):
 
                 return frag_content, frag_index
 
+            pack_fragment = lambda frag_content, _: frag_content
+
+            if is_webvtt:
+                def pack_fragment(frag_content, frag_index):
+                    output = io.StringIO()
+                    adjust = 0
+                    for block in webvtt.parse_fragment(frag_content):
+                        if isinstance(block, webvtt.CueBlock):
+                            block.start += adjust
+                            block.end += adjust
+                        elif isinstance(block, webvtt.Magic):
+                            # XXX: we do not handle MPEGTS overflow
+                            if frag_index == 1:
+                                extra_state['webvtt_mpegts'] = block.mpegts or 0
+                                extra_state['webvtt_local'] = block.local or 0
+                                # XXX: block.local = block.mpegts = None ?
+                            else:
+                                if block.mpegts is not None and block.local is not None:
+                                    adjust = (
+                                        (block.mpegts - extra_state.get('webvtt_mpegts', 0))
+                                        - (block.local - extra_state.get('webvtt_local', 0))
+                                    )
+                                continue
+                        elif isinstance(block, webvtt.HeaderBlock):
+                            if frag_index != 1:
+                                # XXX: this should probably be silent as well
+                                # or verify that all segments contain the same data
+                                self.report_warning(bug_reports_message(
+                                    'Discarding a %s block found in the middle of the stream; '
+                                    'if the subtitles display incorrectly,'
+                                    % (type(block).__name__)))
+                                continue
+                        block.write_into(output)
+
+                    return output.getvalue().encode('utf-8')
+
             def append_fragment(frag_content, frag_index):
                 if frag_content:
                     fragment_filename = '%s-Frag%d' % (ctx['tmpfilename'], frag_index)
@@ -315,6 +358,7 @@ class HlsFD(FragmentFD):
                         file, frag_sanitized = sanitize_open(fragment_filename, 'rb')
                         ctx['fragment_filename_sanitized'] = frag_sanitized
                         file.close()
+                        frag_content = pack_fragment(frag_content, frag_index)
                         self._append_fragment(ctx, frag_content)
                         return True
                     except EnvironmentError as ose:
diff --git a/yt_dlp/extractor/common.py b/yt_dlp/extractor/common.py
index 6257c17cd..803c7fa06 100644
--- a/yt_dlp/extractor/common.py
+++ b/yt_dlp/extractor/common.py
@@ -2035,6 +2035,12 @@ class InfoExtractor(object):
                     'url': url,
                     'ext': determine_ext(url),
                 }
+                if sub_info['ext'] == 'm3u8':
+                    # Per RFC 8216 §3.1, the only possible subtitle format m3u8
+                    # files may contain is WebVTT:
+                    # <https://tools.ietf.org/html/rfc8216#section-3.1>
+                    sub_info['ext'] = 'vtt'
+                    sub_info['protocol'] = 'm3u8_native'
                 subtitles.setdefault(lang, []).append(sub_info)
             if media_type not in ('VIDEO', 'AUDIO'):
                 return
diff --git a/yt_dlp/webvtt.py b/yt_dlp/webvtt.py
new file mode 100644
index 000000000..4d026834a
--- /dev/null
+++ b/yt_dlp/webvtt.py
@@ -0,0 +1,428 @@
+# coding: utf-8
+from __future__ import unicode_literals, print_function, division
+
+"""
+A partial parser for WebVTT segments. Interprets enough of the WebVTT stream
+to be able to assemble a single stand-alone subtitle file, suitably adjusting
+timestamps on the way, while everything else is passed through unmodified.
+
+Regular expressions based on the W3C WebVTT specification
+<https://w3c.github.io/webvtt/>. The X-TIMESTAMP-MAP extension is described
+in RFC 8216 §3.5 <https://tools.ietf.org/html/rfc8216#section-3.5>.
+"""
+
+import re
+import io
+from .utils import int_or_none
+from .compat import (
+    compat_str as str,
+    compat_Pattern,
+    compat_Match,
+)
+
+
+class _MatchParser(object):
+    """
+    An object that maintains the current parsing position and allows
+    conveniently advancing it as syntax elements are successfully parsed.
+    """
+
+    def __init__(self, string):
+        self._data = string
+        self._pos = 0
+
+    def match(self, r):
+        """
+        Test for r (a compiled regex or a literal string) at the current
+        position without advancing. Returns the match object for a regex,
+        the matched length for a string, or None when nothing matches.
+        """
+        if isinstance(r, compat_Pattern):
+            return r.match(self._data, self._pos)
+        if isinstance(r, str):
+            if self._data.startswith(r, self._pos):
+                return len(r)
+            return None
+        raise ValueError(r)
+
+    def advance(self, by):
+        """
+        Move the position forward by the length of a match object or
+        string, by an integer count, or not at all for None. Returns by.
+        """
+        if by is None:
+            amt = 0
+        elif isinstance(by, compat_Match):
+            amt = len(by.group(0))
+        elif isinstance(by, str):
+            amt = len(by)
+        elif isinstance(by, int):
+            amt = by
+        else:
+            raise ValueError(by)
+        self._pos += amt
+        return by
+
+    def consume(self, r):
+        """
+        Match r at the current position and advance past it on success.
+        Returns whatever match() returned.
+        """
+        return self.advance(self.match(r))
+
+    def child(self):
+        """
+        Create a child parser that starts at the current position.
+        """
+        return _MatchChildParser(self)
+
+
+class _MatchChildParser(_MatchParser):
+    """
+    A child parser state, which advances through the same data as
+    its parent, but has an independent position. This is useful when
+    advancing through syntax elements we might later want to backtrack
+    from.
+    """
+
+    def __init__(self, parent):
+        super(_MatchChildParser, self).__init__(parent._data)
+        self.__parent = parent
+        self._pos = parent._pos
+
+    def commit(self):
+        """
+        Advance the parent state to the current position of this child state.
+        """
+        self.__parent._pos = self._pos
+        return self.__parent
+
+
+class ParseError(Exception):
+    """
+    Raised when the input does not match the supported WebVTT grammar.
+    """
+
+    def __init__(self, parser):
+        super(ParseError, self).__init__("Parse error at position %u (near %r)" % (
+            parser._pos, parser._data[parser._pos:parser._pos + 20]
+        ))
+
+
+_REGEX_TS = re.compile(r'''(?x)
+    (?:([0-9]{2,}):)?
+    ([0-9]{2}):
+    ([0-9]{2})\.
+    ([0-9]{3})?
+''')
+_REGEX_EOF = re.compile(r'\Z')
+_REGEX_NL = re.compile(r'(?:\r\n|[\r\n])')
+_REGEX_BLANK = re.compile(r'(?:\r\n|[\r\n])+')
+
+
+def _parse_ts(ts):
+    """
+    Convert a parsed WebVTT timestamp (a re.Match obtained from _REGEX_TS)
+    into an MPEG PES timestamp: a tick counter at 90 kHz resolution.
+    """
+
+    h, min, s, ms = ts.groups()
+    # The hours and milliseconds groups of _REGEX_TS are optional and are
+    # None when absent; count both as zero instead of crashing in int().
+    return 90 * (
+        int(h or 0) * 3600000
+        + int(min) * 60000
+        + int(s) * 1000
+        + int(ms or 0)
+    )
+
+
+def _format_ts(ts):
+    """
+    Convert an MPEG PES timestamp into a WebVTT timestamp.
+    This will lose sub-millisecond precision.
+    """
+
+    # Round to the nearest millisecond (90 ticks per millisecond).
+    ts = int((ts + 45) // 90)
+    # divmod() returns (quotient, remainder): the remainder is the unit
+    # being peeled off, the quotient carries over to the next larger unit.
+    ts, ms = divmod(ts, 1000)
+    ts, s = divmod(ts, 60)
+    h, min = divmod(ts, 60)
+    return '%02u:%02u:%02u.%03u' % (h, min, s, ms)
+
+
+class Block(object):
+    """
+    An abstract WebVTT block.
+    """
+
+    def __init__(self, **kwargs):
+        for key, val in kwargs.items():
+            setattr(self, key, val)
+
+    @classmethod
+    def parse(cls, parser):
+        """
+        Match cls._REGEX at the parser position. On success, advance the
+        parser and return a block holding the raw text; otherwise None.
+        """
+        m = parser.match(cls._REGEX)
+        if not m:
+            return None
+        parser.advance(m)
+        return cls(raw=m.group(0))
+
+    def write_into(self, stream):
+        """
+        Write the block, exactly as it was parsed, into a text stream.
+        """
+        stream.write(self.raw)
+
+
+class HeaderBlock(Block):
+    """
+    A WebVTT block that may only appear in the header part of the file,
+    i.e. before any cue blocks.
+    """
+
+    pass
+
+
+class Magic(HeaderBlock):
+    """
+    The "WEBVTT" signature line, optionally followed by X-TIMESTAMP-MAP.
+    """
+
+    _REGEX = re.compile(r'\ufeff?WEBVTT([ \t][^\r\n]*)?(?:\r\n|[\r\n])')
+
+    # XXX: The X-TIMESTAMP-MAP extension is described in RFC 8216 §3.5
+    # <https://tools.ietf.org/html/rfc8216#section-3.5>, but the RFC
+    # doesn’t specify the exact grammar nor where in the WebVTT
+    # syntax it should be placed; the below has been devised based
+    # on usage in the wild
+    #
+    # And strictly speaking, the presence of this extension violates
+    # the W3C WebVTT spec. Oh well.
+
+    _REGEX_TSMAP = re.compile(r'X-TIMESTAMP-MAP=')
+    _REGEX_TSMAP_LOCAL = re.compile(r'LOCAL:')
+    _REGEX_TSMAP_MPEGTS = re.compile(r'MPEGTS:([0-9]+)')
+
+    @classmethod
+    def __parse_tsmap(cls, parser):
+        """
+        Parse the comma-separated LOCAL:/MPEGTS: fields up to the end of
+        the line. Returns (local, mpegts); a field that does not appear
+        in the input is returned as None.
+        """
+        parser = parser.child()
+
+        # Either field may be missing; without defaults the return below
+        # would raise UnboundLocalError for a single-field map.
+        local, mpegts = None, None
+        while True:
+            m = parser.consume(cls._REGEX_TSMAP_LOCAL)
+            if m:
+                m = parser.consume(_REGEX_TS)
+                if m is None:
+                    raise ParseError(parser)
+                local = _parse_ts(m)
+            else:
+                m = parser.consume(cls._REGEX_TSMAP_MPEGTS)
+                if m:
+                    mpegts = int_or_none(m.group(1))
+                    if mpegts is None:
+                        raise ParseError(parser)
+                else:
+                    raise ParseError(parser)
+            if parser.consume(','):
+                continue
+            if parser.consume(_REGEX_NL):
+                break
+            raise ParseError(parser)
+
+        parser.commit()
+        return local, mpegts
+
+    @classmethod
+    def parse(cls, parser):
+        """
+        Parse the signature line, an optional X-TIMESTAMP-MAP line and the
+        line terminator that must follow. Raises ParseError on mismatch.
+        """
+        parser = parser.child()
+
+        m = parser.consume(cls._REGEX)
+        if not m:
+            raise ParseError(parser)
+
+        extra = m.group(1)
+        local, mpegts = None, None
+        if parser.consume(cls._REGEX_TSMAP):
+            local, mpegts = cls.__parse_tsmap(parser)
+        if not parser.consume(_REGEX_NL):
+            raise ParseError(parser)
+        parser.commit()
+        return cls(extra=extra, mpegts=mpegts, local=local)
+
+    def write_into(self, stream):
+        """
+        Write the signature, plus the timestamp map if either value is set.
+        """
+        stream.write('WEBVTT')
+        if self.extra is not None:
+            stream.write(self.extra)
+        stream.write('\n')
+        if self.local or self.mpegts:
+            stream.write('X-TIMESTAMP-MAP=LOCAL:')
+            stream.write(_format_ts(self.local if self.local is not None else 0))
+            stream.write(',MPEGTS:')
+            stream.write(str(self.mpegts if self.mpegts is not None else 0))
+            stream.write('\n')
+        stream.write('\n')
+
+
+class StyleBlock(HeaderBlock):
+    _REGEX = re.compile(r'''(?x)
+        STYLE[\ \t]*(?:\r\n|[\r\n])
+        ((?:(?!-->)[^\r\n])+(?:\r\n|[\r\n]))*
+        (?:\r\n|[\r\n])
+    ''')
+
+
+class RegionBlock(HeaderBlock):
+    # The line terminator after the REGION keyword has to be consumed
+    # here; otherwise the block would end right after the keyword and
+    # the region settings lines would be left behind.
+    _REGEX = re.compile(r'''(?x)
+        REGION[\ \t]*(?:\r\n|[\r\n])
+        ((?:(?!-->)[^\r\n])+(?:\r\n|[\r\n]))*
+        (?:\r\n|[\r\n])
+    ''')
+
+
+class CommentBlock(Block):
+    _REGEX = re.compile(r'''(?x)
+        NOTE(?:\r\n|[\ \t\r\n])
+        ((?:(?!-->)[^\r\n])+(?:\r\n|[\r\n]))*
+        (?:\r\n|[\r\n])
+    ''')
+
+
+class CueBlock(Block):
+    """
+    A cue block. The payload is not interpreted.
+    """
+
+    _REGEX_ID = re.compile(r'((?:(?!-->)[^\r\n])+)(?:\r\n|[\r\n])')
+    _REGEX_ARROW = re.compile(r'[ \t]+-->[ \t]+')
+    _REGEX_SETTINGS = re.compile(r'[ \t]+((?:(?!-->)[^\r\n])+)')
+    _REGEX_PAYLOAD = re.compile(r'[^\r\n]+(?:\r\n|[\r\n])?')
+
+    @classmethod
+    def parse(cls, parser):
+        """
+        Parse a cue: optional identifier line, timings with optional
+        settings, then payload lines. Returns None (leaving the parser
+        position untouched) if the input is not a cue.
+        """
+        parser = parser.child()
+
+        id = None
+        m = parser.consume(cls._REGEX_ID)
+        if m:
+            id = m.group(1)
+
+        m0 = parser.consume(_REGEX_TS)
+        if not m0:
+            return None
+        if not parser.consume(cls._REGEX_ARROW):
+            return None
+        m1 = parser.consume(_REGEX_TS)
+        if not m1:
+            return None
+        m2 = parser.consume(cls._REGEX_SETTINGS)
+        if not parser.consume(_REGEX_NL):
+            return None
+
+        start = _parse_ts(m0)
+        end = _parse_ts(m1)
+        settings = m2.group(1) if m2 is not None else None
+
+        text = io.StringIO()
+        while True:
+            m = parser.consume(cls._REGEX_PAYLOAD)
+            if not m:
+                break
+            text.write(m.group(0))
+
+        parser.commit()
+        return cls(
+            id=id,
+            start=start, end=end, settings=settings,
+            text=text.getvalue()
+        )
+
+    def write_into(self, stream):
+        """
+        Write the cue, formatting start and end back into WebVTT timestamps.
+        """
+        if self.id is not None:
+            stream.write(self.id)
+            stream.write('\n')
+        stream.write(_format_ts(self.start))
+        stream.write(' --> ')
+        stream.write(_format_ts(self.end))
+        if self.settings is not None:
+            stream.write(' ')
+            stream.write(self.settings)
+        stream.write('\n')
+        stream.write(self.text)
+        stream.write('\n')
+
+
+def parse_fragment(frag_content):
+    """
+    A generator that yields (partially) parsed WebVTT blocks when given
+    a bytes object containing the raw contents of a WebVTT file.
+    """
+
+    parser = _MatchParser(frag_content.decode('utf-8'))
+
+    yield Magic.parse(parser)
+
+    while not parser.match(_REGEX_EOF):
+        if parser.consume(_REGEX_BLANK):
+            continue
+
+        block = RegionBlock.parse(parser)
+        if block:
+            yield block
+            continue
+        block = StyleBlock.parse(parser)
+        if block:
+            yield block
+            continue
+        block = CommentBlock.parse(parser)
+        if block:
+            yield block  # XXX: or skip
+            continue
+
+        break
+
+    while not parser.match(_REGEX_EOF):
+        if parser.consume(_REGEX_BLANK):
+            continue
+
+        block = CommentBlock.parse(parser)
+        if block:
+            yield block  # XXX: or skip
+            continue
+        block = CueBlock.parse(parser)
+        if block:
+            yield block
+            continue
+
+        raise ParseError(parser)