From b6bccfaf61f08cd05baab0ebd91c03f8e1b80386 Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Fri, 19 Apr 2024 12:15:34 +0200 Subject: [PATCH 01/16] Refactor --- setup.cfg | 2 +- tests/test_sbv_parser.py | 16 +- tests/test_srt.py | 2 +- tests/test_srt_parser.py | 14 +- tests/test_webvtt.py | 16 +- tests/test_webvtt_parser.py | 31 +-- webvtt/parsers.py | 365 +++++++++++++----------------------- webvtt/structures.py | 289 ++++++++++++++++++++-------- webvtt/webvtt.py | 25 ++- webvtt/writers.py | 50 ++--- 10 files changed, 404 insertions(+), 406 deletions(-) diff --git a/setup.cfg b/setup.cfg index 5aef279..ddb7da9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,2 @@ [metadata] -description-file = README.rst +description_file = README.rst diff --git a/tests/test_sbv_parser.py b/tests/test_sbv_parser.py index 0a6f899..77cc7bb 100644 --- a/tests/test_sbv_parser.py +++ b/tests/test_sbv_parser.py @@ -32,21 +32,13 @@ def test_sbv_parse_captions(self): ) def test_sbv_missing_timeframe_line(self): - self.assertRaises( - webvtt.errors.MalformedCaptionError, - webvtt.from_sbv, - self._get_file('missing_timeframe.sbv') - ) + self.assertEqual(len(webvtt.from_sbv(self._get_file('missing_timeframe.sbv')).captions), 4) def test_sbv_missing_caption_text(self): self.assertTrue(webvtt.from_sbv(self._get_file('missing_caption_text.sbv')).captions) def test_sbv_invalid_timestamp(self): - self.assertRaises( - webvtt.errors.MalformedCaptionError, - webvtt.from_sbv, - self._get_file('invalid_timeframe.sbv') - ) + self.assertEqual(len(webvtt.from_sbv(self._get_file('invalid_timeframe.sbv')).captions), 4) def test_sbv_timestamps_format(self): vtt = webvtt.from_sbv(self._get_file('sample.sbv')) @@ -55,8 +47,8 @@ def test_sbv_timestamps_format(self): def test_sbv_timestamps_in_seconds(self): vtt = webvtt.from_sbv(self._get_file('sample.sbv')) - self.assertEqual(vtt.captions[1].start_in_seconds, 11.378) - self.assertEqual(vtt.captions[1].end_in_seconds, 12.305) + 
self.assertEqual(vtt.captions[1].start_in_seconds, 11) + self.assertEqual(vtt.captions[1].end_in_seconds, 12) def test_sbv_get_caption_text(self): vtt = webvtt.from_sbv(self._get_file('sample.sbv')) diff --git a/tests/test_srt.py b/tests/test_srt.py index eed186d..0e6a645 100644 --- a/tests/test_srt.py +++ b/tests/test_srt.py @@ -32,4 +32,4 @@ def test_convert_from_srt_to_vtt_and_back_gives_same_file(self): with open(os.path.join(OUTPUT_DIR, 'sample_converted.srt'), 'r', encoding='utf-8') as f: converted = f.read() - self.assertEqual(original.strip(), converted.strip()) + self.assertEqual(original, converted) diff --git a/tests/test_srt_parser.py b/tests/test_srt_parser.py index 7885ec5..e09e703 100644 --- a/tests/test_srt_parser.py +++ b/tests/test_srt_parser.py @@ -30,11 +30,7 @@ def test_srt_parse_captions(self): self.assertTrue(webvtt.from_srt(self._get_file('sample.srt')).captions) def test_srt_missing_timeframe_line(self): - self.assertRaises( - webvtt.errors.MalformedCaptionError, - webvtt.from_srt, - self._get_file('missing_timeframe.srt') - ) + self.assertEqual(len(webvtt.from_srt(self._get_file('missing_timeframe.srt')).captions), 4) def test_srt_empty_caption_text(self): self.assertTrue(webvtt.from_srt(self._get_file('missing_caption_text.srt')).captions) @@ -44,11 +40,7 @@ def test_srt_empty_gets_removed(self): self.assertEqual(len(captions), 4) def test_srt_invalid_timestamp(self): - self.assertRaises( - webvtt.errors.MalformedCaptionError, - webvtt.from_srt, - self._get_file('invalid_timeframe.srt') - ) + self.assertEqual(len(webvtt.from_srt(self._get_file('invalid_timeframe.srt')).captions), 4) def test_srt_timestamps_format(self): vtt = webvtt.from_srt(self._get_file('sample.srt')) @@ -57,7 +49,7 @@ def test_srt_timestamps_format(self): def test_srt_parse_get_caption_data(self): vtt = webvtt.from_srt(self._get_file('one_caption.srt')) - self.assertEqual(vtt.captions[0].start_in_seconds, 0.5) + self.assertEqual(vtt.captions[0].start_in_seconds, 0) 
self.assertEqual(vtt.captions[0].start, '00:00:00.500') self.assertEqual(vtt.captions[0].end_in_seconds, 7) self.assertEqual(vtt.captions[0].end, '00:00:07.000') diff --git a/tests/test_webvtt.py b/tests/test_webvtt.py index 9da3238..c5509df 100644 --- a/tests/test_webvtt.py +++ b/tests/test_webvtt.py @@ -20,10 +20,10 @@ def tearDown(self): rmtree(OUTPUT_DIR) def test_create_caption(self): - caption = Caption('00:00:00.500', '00:00:07.000', ['Caption test line 1', 'Caption test line 2']) + caption = Caption('00:00:00.500', '00:00:07.900', ['Caption test line 1', 'Caption test line 2']) self.assertEqual(caption.start, '00:00:00.500') - self.assertEqual(caption.start_in_seconds, 0.5) - self.assertEqual(caption.end, '00:00:07.000') + self.assertEqual(caption.start_in_seconds, 0) + self.assertEqual(caption.end, '00:00:07.900') self.assertEqual(caption.end_in_seconds, 7) self.assertEqual(caption.lines, ['Caption test line 1', 'Caption test line 2']) @@ -297,19 +297,17 @@ def test_malformed_start_timestamp(self): ) def test_set_styles_from_text(self): - style = Style() - style.text = '::cue(b) {\n color: peachpuff;\n}' + style = Style('::cue(b) {\n color: peachpuff;\n}') self.assertListEqual( style.lines, ['::cue(b) {', ' color: peachpuff;', '}'] ) def test_get_styles_as_text(self): - style = Style() - style.lines = ['::cue(b) {', ' color: peachpuff;', '}'] + style = Style(['::cue(b) {', ' color: peachpuff;', '}']) self.assertEqual( style.text, - '::cue(b) {color: peachpuff;}' + '::cue(b) {\n color: peachpuff;\n}' ) def test_save_identifiers(self): @@ -402,7 +400,7 @@ def test_content_formatting(self): Caption('00:00:00.500', '00:00:07.000', ['Caption test line 1', 'Caption test line 2']), Caption('00:00:08.000', '00:00:15.000', ['Caption test line 3', 'Caption test line 4']), ] - expected_content = textwrap.dedent("""\ + expected_content = textwrap.dedent(""" WEBVTT 00:00:00.500 --> 00:00:07.000 diff --git a/tests/test_webvtt_parser.py b/tests/test_webvtt_parser.py 
index a26741a..1262b7a 100644 --- a/tests/test_webvtt_parser.py +++ b/tests/test_webvtt_parser.py @@ -54,20 +54,16 @@ def test_webvtt_parse_get_captions(self): ) def test_webvtt_parse_invalid_timeframe_line(self): - self.assertRaises( - MalformedCaptionError, - webvtt.read, - self._get_file('invalid_timeframe.vtt') - ) + self.assertEqual(len(webvtt.read(self._get_file('invalid_timeframe.vtt')).captions), 6) def test_webvtt_parse_invalid_timeframe_in_cue_text(self): vtt = webvtt.read(self._get_file('invalid_timeframe_in_cue_text.vtt')) - self.assertEqual(4, len(vtt.captions)) - self.assertEqual('', vtt.captions[1].text) + self.assertEqual(2, len(vtt.captions)) + self.assertEqual('Caption text #3', vtt.captions[1].text) def test_webvtt_parse_get_caption_data(self): vtt = webvtt.read(self._get_file('one_caption.vtt')) - self.assertEqual(vtt.captions[0].start_in_seconds, 0.5) + self.assertEqual(vtt.captions[0].start_in_seconds, 0) self.assertEqual(vtt.captions[0].start, '00:00:00.500') self.assertEqual(vtt.captions[0].end_in_seconds, 7) self.assertEqual(vtt.captions[0].end, '00:00:07.000') @@ -75,15 +71,12 @@ def test_webvtt_parse_get_caption_data(self): self.assertEqual(len(vtt.captions[0].lines), 1) def test_webvtt_caption_without_timeframe(self): - self.assertRaises( - MalformedCaptionError, - webvtt.read, - self._get_file('missing_timeframe.vtt') - ) + vtt = webvtt.read(self._get_file('missing_timeframe.vtt')) + self.assertEqual(len(vtt.captions), 6) def test_webvtt_caption_without_cue_text(self): vtt = webvtt.read(self._get_file('missing_caption_text.vtt')) - self.assertEqual(len(vtt.captions), 5) + self.assertEqual(len(vtt.captions), 4) def test_webvtt_timestamps_format(self): vtt = webvtt.read(self._get_file('sample.vtt')) @@ -94,16 +87,12 @@ def test_parse_timestamp(self): caption = Caption(start='02:03:11.890') self.assertEqual( caption.start_in_seconds, - 7391.89 + 7391 ) def test_captions_attribute(self): self.assertListEqual([], webvtt.WebVTT().captions) - 
def test_webvtt_timestamp_format(self): - self.assertTrue(WebVTTParser()._validate_timeframe_line('00:00:00.000 --> 00:00:00.000')) - self.assertTrue(WebVTTParser()._validate_timeframe_line('00:00.000 --> 00:00.000')) - def test_metadata_headers(self): vtt = webvtt.read(self._get_file('metadata_headers.vtt')) self.assertEqual(len(vtt.captions), 2) @@ -138,7 +127,7 @@ def test_parse_styles(self): self.assertEqual(len(vtt.captions), 1) self.assertEqual( vtt.styles[0].text, - '::cue {background-image: linear-gradient(to bottom, dimgray, lightgray);color: papayawhip;}' + '::cue {\n background-image: linear-gradient(to bottom, dimgray, lightgray);\n color: papayawhip;\n}' ) def test_clean_cue_tags(self): @@ -167,6 +156,6 @@ def test_empty_lines_are_not_included_in_result(self): def test_can_parse_youtube_dl_files(self): vtt = webvtt.read(self._get_file('youtube_dl.vtt')) self.assertEqual( - "this will happen is I'm telling", + "this will happen is I'm telling\n ", vtt.captions[2].text ) diff --git a/webvtt/parsers.py b/webvtt/parsers.py index 3a978ca..ad3e70b 100644 --- a/webvtt/parsers.py +++ b/webvtt/parsers.py @@ -1,45 +1,43 @@ -import re import os import codecs - -from .errors import MalformedFileError, MalformedCaptionError -from .structures import Block, Style, Caption - - -class TextBasedParser(object): - """ - Parser for plain text caption files. - This is a generic class, do not use directly. 
- """ - - TIMEFRAME_LINE_PATTERN = '' - PARSER_OPTIONS = {} - - def __init__(self, parse_options=None): - self.captions = [] - self.parse_options = parse_options or {} - - def read(self, file): +import typing +from abc import ABC, abstractmethod + +from .errors import MalformedFileError +from .structures import (Style, + Caption, + WebVTTCueBlock, + WebVTTCommentBlock, + WebVTTStyleBlock, + SRTCueBlock, + SBVCueBlock, + ) + + +class BaseParser(ABC): + @classmethod + def read(cls, file): """Reads the captions file.""" - content = self._get_content_from_file(file_path=file) - self._validate(content) - self._parse(content) - - return self + return cls._parse(cls._get_content_from_file(file_path=file)) - def read_from_buffer(self, buffer): - content = self._read_content_lines(buffer) - self._validate(content) - self._parse(content) + @classmethod + def read_from_buffer(cls, buffer): + return cls._parse(cls._read_content_lines(buffer)) - return self + @classmethod + def _parse(cls, content): + if not cls.validate(content): + raise MalformedFileError('Invalid format') + return cls.parse(content) - def _get_content_from_file(self, file_path): - encoding = self._read_file_encoding(file_path) + @classmethod + def _get_content_from_file(cls, file_path): + encoding = cls._read_file_encoding(file_path) with open(file_path, encoding=encoding) as f: - return self._read_content_lines(f) + return cls._read_content_lines(f) - def _read_file_encoding(self, file_path): + @staticmethod + def _read_file_encoding(file_path): first_bytes = min(32, os.path.getsize(file_path)) with open(file_path, 'rb') as f: raw = f.read(first_bytes) @@ -49,7 +47,8 @@ def _read_file_encoding(self, file_path): else: return 'utf-8' - def _read_content_lines(self, file_obj): + @staticmethod + def _read_content_lines(file_obj: typing.IO[str]): lines = [line.rstrip('\n\r') for line in file_obj.readlines()] @@ -58,236 +57,126 @@ def _read_content_lines(self, file_obj): return lines - def _read_content(self, 
file): - return self._get_content_from_file(file_path=file) + @staticmethod + def iter_blocks_of_lines(lines) -> typing.Generator[typing.List[str], None, None]: + current_text_block = [] - def _parse_timeframe_line(self, line): - """Parse timeframe line and return start and end timestamps.""" - tf = self._validate_timeframe_line(line) - if not tf: - raise MalformedCaptionError('Invalid time format') + for line in lines: + if line.strip(): + current_text_block.append(line) + elif current_text_block: + yield current_text_block + current_text_block = [] - return tf.group(1), tf.group(2) + if current_text_block: + yield current_text_block - def _validate_timeframe_line(self, line): - return re.match(self.TIMEFRAME_LINE_PATTERN, line) - - def _is_timeframe_line(self, line): - """ - This method returns True if the line contains the timeframes. - To be implemented by child classes. - """ + @classmethod + @abstractmethod + def validate(cls, lines: typing.Sequence[str]) -> bool: raise NotImplementedError - def _validate(self, lines): - """ - Validates the format of the parsed file. - To be implemented by child classes. - """ + @classmethod + @abstractmethod + def parse(cls, lines: typing.Sequence[str]): raise NotImplementedError - def _should_skip_line(self, line, index, caption): - """ - This method returns True for a line that should be skipped. - Implement in child classes if needed. 
- """ - return False - - def _parse(self, lines): - self.captions = [] - c = None - - for index, line in enumerate(lines): - if self._is_timeframe_line(line): - try: - start, end = self._parse_timeframe_line(line) - except MalformedCaptionError as e: - raise MalformedCaptionError('{} in line {}'.format(e, index + 1)) - c = Caption(start, end) - elif self._should_skip_line(line, index, c): # allow child classes to skip lines based on the content - continue - elif line: - if c is None: - raise MalformedCaptionError( - 'Caption missing timeframe in line {}.'.format(index + 1)) - else: - c.add_line(line) - else: - if c is None: - continue - if not c.lines: - if self.PARSER_OPTIONS.get('ignore_empty_captions', False): - c = None - continue - raise MalformedCaptionError('Caption missing text in line {}.'.format(index + 1)) - - self.captions.append(c) - c = None - - if c is not None and c.lines: - self.captions.append(c) - - -class SRTParser(TextBasedParser): + +class WebVTTParser(BaseParser): """ - SRT parser. + Web Video Text Track parser. 
""" - TIMEFRAME_LINE_PATTERN = re.compile(r'\s*(\d+:\d{2}:\d{2},\d{3})\s*-->\s*(\d+:\d{2}:\d{2},\d{3})') + @classmethod + def validate(cls, lines: typing.Sequence[str]) -> bool: + return lines[0].startswith('WEBVTT') + + @classmethod + def parse(cls, lines: typing.Sequence[str]) -> typing.List[typing.Union[Style, Caption]]: + items = [] + comments = [] + + for block_lines in cls.iter_blocks_of_lines(lines): + if WebVTTCueBlock.is_valid(block_lines): + cue_block = WebVTTCueBlock.from_lines(block_lines) + caption = Caption(cue_block.start, + cue_block.end, + cue_block.payload, + cue_block.identifier + ) - PARSER_OPTIONS = { - 'ignore_empty_captions': True - } + if comments: + caption.comments = [comment.text for comment in comments] + comments = [] + items.append(caption) - def _validate(self, lines): - if len(lines) < 2 or lines[0] != '1' or not self._validate_timeframe_line(lines[1]): - raise MalformedFileError('The file does not have a valid format.') + elif WebVTTCommentBlock.is_valid(block_lines): + comments.append(WebVTTCommentBlock.from_lines(block_lines)) - def _is_timeframe_line(self, line): - return '-->' in line + elif WebVTTStyleBlock.is_valid(block_lines): + style = Style(text=WebVTTStyleBlock.from_lines(block_lines).text) + if comments: + style.comments = [comment.text for comment in comments] + comments = [] + items.append(style) - def _should_skip_line(self, line, index, caption): - return caption is None and line.isdigit() + if comments and items: + items[-1].comments = [comment.text for comment in comments] + return items -class WebVTTParser(TextBasedParser): + +class SRTParser(BaseParser): """ - WebVTT parser. + SubRip SRT parser. 
""" - TIMEFRAME_LINE_PATTERN = re.compile(r'\s*((?:\d+:)?\d{2}:\d{2}.\d{3})\s*-->\s*((?:\d+:)?\d{2}:\d{2}.\d{3})') - COMMENT_PATTERN = re.compile(r'NOTE(?:\s.+|$)') - STYLE_PATTERN = re.compile(r'STYLE[ \t]*$') - - def __init__(self): - super().__init__() - self.styles = [] - - def _compute_blocks(self, lines): - blocks = [] - - for index, line in enumerate(lines, start=1): - if line: - if not blocks: - blocks.append(Block(index)) - if not blocks[-1].lines: - if not line.strip(): - continue - blocks[-1].line_number = index - blocks[-1].lines.append(line) - else: - blocks.append(Block(index)) - - # filter out empty blocks and skip signature - return list(filter(lambda x: x.lines, blocks))[1:] - - def _parse_cue_block(self, block): - caption = Caption() - cue_timings = None - additional_blocks = None - - for line_number, line in enumerate(block.lines): - if self._is_cue_timings_line(line): - if cue_timings is None: - try: - cue_timings = self._parse_timeframe_line(line) - except MalformedCaptionError as e: - raise MalformedCaptionError( - '{} in line {}'.format(e, block.line_number + line_number)) - else: - additional_blocks = self._compute_blocks( - ['WEBVTT', ''] + block.lines[line_number:] - ) - break - elif line_number == 0: - caption.identifier = line - else: - caption.add_line(line) - - caption.start = cue_timings[0] - caption.end = cue_timings[1] - return caption, additional_blocks - - def _parse(self, lines): - self.captions = [] - blocks = self._compute_blocks(lines) - self._parse_blocks(blocks) - - def _is_empty(self, block): - is_empty = True - - for line in block.lines: - if line.strip() != "": - is_empty = False - - return is_empty - - def _parse_blocks(self, blocks): - for block in blocks: - # skip empty blocks - if self._is_empty(block): + @classmethod + def validate(cls, lines: typing.Sequence[str]) -> bool: + return len(lines) >= 3 and lines[0].isdigit() and '-->' in lines[1] and lines[2].strip() + + @classmethod + def parse(cls, lines: 
typing.Sequence[str]) -> typing.List[Caption]: + captions = [] + + for block_lines in cls.iter_blocks_of_lines(lines): + if not SRTCueBlock.is_valid(block_lines): continue - if self._is_cue_block(block): - caption, additional_blocks = self._parse_cue_block(block) - self.captions.append(caption) + cue_block = SRTCueBlock.from_lines(block_lines) + captions.append(Caption(cue_block.start, + cue_block.end, + cue_block.payload + )) - if additional_blocks: - self._parse_blocks(additional_blocks) + return captions - elif self._is_comment_block(block): - continue - elif self._is_style_block(block): - if self.captions: - raise MalformedFileError( - 'Style block defined after the first cue in line {}.' - .format(block.line_number)) - style = Style() - style.lines = block.lines[1:] - self.styles.append(style) - else: - if len(block.lines) == 1: - raise MalformedCaptionError( - 'Standalone cue identifier in line {}.'.format(block.line_number)) - else: - raise MalformedCaptionError( - 'Missing timing cue in line {}.'.format(block.line_number+1)) - - def _validate(self, lines): - if not re.match('WEBVTT', lines[0]): - raise MalformedFileError('The file does not have a valid format') - - def _is_cue_timings_line(self, line): - return '-->' in line - - def _is_cue_block(self, block): - """Returns True if it is a cue block - (one of the two first lines being a cue timing line)""" - return any(map(self._is_cue_timings_line, block.lines[:2])) - - def _is_comment_block(self, block): - """Returns True if it is a comment block""" - return re.match(self.COMMENT_PATTERN, block.lines[0]) - - def _is_style_block(self, block): - """Returns True if it is a style block""" - return re.match(self.STYLE_PATTERN, block.lines[0]) - - -class SBVParser(TextBasedParser): + +class SBVParser(BaseParser): """ YouTube SBV parser. 
""" - TIMEFRAME_LINE_PATTERN = re.compile(r'\s*(\d+:\d{2}:\d{2}.\d{3}),(\d+:\d{2}:\d{2}.\d{3})') + @classmethod + def validate(cls, lines: typing.Sequence[str]) -> bool: + if len(lines) < 2: + return False + + first_block = next(cls.iter_blocks_of_lines(lines)) + return first_block and SBVCueBlock.is_valid(first_block) - PARSER_OPTIONS = { - 'ignore_empty_captions': True - } + @classmethod + def parse(cls, lines: typing.Sequence[str]) -> typing.List[Caption]: + captions = [] + + for block_lines in cls.iter_blocks_of_lines(lines): + if not SBVCueBlock.is_valid(block_lines): + continue - def _validate(self, lines): - if not self._validate_timeframe_line(lines[0]): - raise MalformedFileError('The file does not have a valid format') + cue_block = SBVCueBlock.from_lines(block_lines) + captions.append(Caption(cue_block.start, + cue_block.end, + cue_block.payload + )) - def _is_timeframe_line(self, line): - return self._validate_timeframe_line(line) + return captions diff --git a/webvtt/structures.py b/webvtt/structures.py index c4a576d..6fbf9f4 100644 --- a/webvtt/structures.py +++ b/webvtt/structures.py @@ -1,135 +1,268 @@ import re +import typing +from datetime import datetime, time +from abc import ABC, abstractmethod from .errors import MalformedCaptionError -TIMESTAMP_PATTERN = re.compile(r'(\d+)?:?(\d{2}):(\d{2})[.,](\d{3})') +__all__ = ['Caption', 'Style'] -__all__ = ['Caption'] +class BlockItem(ABC): + @classmethod + @abstractmethod + def is_valid(cls, lines: typing.Sequence[str]) -> bool: + raise NotImplementedError -class Caption(object): + @classmethod + @abstractmethod + def from_lines(cls, lines: typing.Sequence[str]) -> 'BlockItem': + raise NotImplementedError - CUE_TEXT_TAGS = re.compile('<.*?>') - """ - Represents a caption. 
- """ - def __init__(self, start='00:00:00.000', end='00:00:00.000', text=None): +class WebVTTCueBlock(BlockItem): + CUE_TIMINGS_PATTERN = re.compile(r'\s*((?:\d+:)?\d{2}:\d{2}.\d{3})\s*-->\s*((?:\d+:)?\d{2}:\d{2}.\d{3})') + + def __init__(self, identifier, start, end, payload): + self.identifier = identifier self.start = start self.end = end - self.identifier = None + self.payload = payload + + @classmethod + def is_valid(cls, lines: typing.Sequence[str]) -> bool: + return bool( + ( + len(lines) >= 2 and + re.match(cls.CUE_TIMINGS_PATTERN, lines[0]) and + "-->" not in lines[1] + ) or + ( + len(lines) >= 3 and + "-->" not in lines[0] and + re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) and + "-->" not in lines[2] + ) + ) + + @classmethod + def from_lines(cls, lines: typing.Sequence[str]) -> 'WebVTTCueBlock': + identifier = None + start = None + end = None + payload = [] + + for line in lines: + timing_match = re.match(cls.CUE_TIMINGS_PATTERN, line) + if timing_match: + start = timing_match.group(1) + end = timing_match.group(2) + elif not start: + identifier = line + else: + payload.append(line) + + return cls(identifier, start, end, payload) + + +class WebVTTCommentBlock(BlockItem): + COMMENT_PATTERN = re.compile(r'NOTE\s(.*?)\Z', re.DOTALL) + + def __init__(self, text): + self.text = text + + @classmethod + def is_valid(cls, lines: typing.Sequence[str]) -> bool: + return lines[0].startswith('NOTE') + + @classmethod + def from_lines(cls, lines: typing.Sequence[str]) -> 'WebVTTCommentBlock': + match = cls.COMMENT_PATTERN.match('\n'.join(lines)) + return cls(text=match.group(1).strip() if match else '') + + +class WebVTTStyleBlock(BlockItem): + STYLE_PATTERN = re.compile(r'STYLE\s(.*?)\Z', re.DOTALL) + + def __init__(self, text): + self.text = text + + @classmethod + def is_valid(cls, lines: typing.Sequence[str]) -> bool: + return lines[0].startswith('STYLE') + + @classmethod + def from_lines(cls, lines: typing.Sequence[str]) -> 'WebVTTStyleBlock': + match = 
cls.STYLE_PATTERN.match('\n'.join(lines)) + return cls(text=match.group(1).strip() if match else '') + + +class SRTCueBlock(BlockItem): + CUE_TIMINGS_PATTERN = re.compile(r'\s*(\d+:\d{2}:\d{2},\d{3})\s*-->\s*(\d+:\d{2}:\d{2},\d{3})') + + def __init__(self, index, start, end, payload): + self.index = index + self.start = start + self.end = end + self.payload = payload - # If lines is a string convert to a list - if text and isinstance(text, str): - text = text.splitlines() + @classmethod + def is_valid(cls, lines: typing.Sequence[str]) -> bool: + return bool( + len(lines) >= 3 and + lines[0].isdigit() and + re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) + ) - self._lines = text or [] + @classmethod + def from_lines(cls, lines: typing.Sequence[str]) -> 'SRTCueBlock': + index = lines[0] - def __repr__(self): - return '<%(cls)s start=%(start)s end=%(end)s text=%(text)s>' % { - 'cls': self.__class__.__name__, - 'start': self.start, - 'end': self.end, - 'text': self.text.replace('\n', '\\n') - } + match = re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) + start, end = map(lambda x: datetime.strptime(x, '%H:%M:%S,%f').time(), + (match.group(1), match.group(2)) + ) - def __str__(self): - return '%(start)s %(end)s %(text)s' % { - 'start': self.start, - 'end': self.end, - 'text': self.text.replace('\n', '\\n') - } + payload = lines[2:] - def add_line(self, line): - self.lines.append(line) + return cls(index, start, end, payload) - def _to_seconds(self, hours, minutes, seconds, milliseconds): - return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000 - def _parse_timestamp(self, timestamp): - res = re.match(TIMESTAMP_PATTERN, timestamp) - if not res: - raise MalformedCaptionError('Invalid timestamp: {}'.format(timestamp)) +class SBVCueBlock(BlockItem): + CUE_TIMINGS_PATTERN = re.compile(r'\s*(\d+:\d{2}:\d{2}.\d{3}),(\d+:\d{2}:\d{2}.\d{3})') - values = list(map(lambda x: int(x) if x else 0, res.groups())) - return self._to_seconds(*values) + def __init__(self, start, end, 
payload): + self.start = start + self.end = end + self.payload = payload - def _to_timestamp(self, total_seconds): - hours = int(total_seconds / 3600) - minutes = int(total_seconds / 60 - hours * 60) - seconds = total_seconds - hours * 3600 - minutes * 60 - return '{:02d}:{:02d}:{:06.3f}'.format(hours, minutes, seconds) + @classmethod + def is_valid(cls, lines: typing.Sequence[str]) -> bool: + return bool( + len(lines) >= 2 and + re.match(cls.CUE_TIMINGS_PATTERN, lines[0]) and + lines[1].strip() + ) - def _clean_cue_tags(self, text): - return re.sub(self.CUE_TEXT_TAGS, '', text) + @classmethod + def from_lines(cls, lines: typing.Sequence[str]) -> 'SBVCueBlock': + match = re.match(cls.CUE_TIMINGS_PATTERN, lines[0]) + start, end = map(lambda x: datetime.strptime(x, '%H:%M:%S.%f').time(), + (match.group(1), match.group(2)) + ) - @property - def start_in_seconds(self): - return self._start + payload = lines[1:] - @property - def end_in_seconds(self): - return self._end + return cls(start, end, payload) + + +class Caption: + CUE_TEXT_TAGS = re.compile('<.*?>') + + def __init__(self, + start: typing.Optional[typing.Union[str, time]] = None, + end: typing.Optional[typing.Union[str, time]] = None, + text: typing.Optional[typing.Union[str, typing.List[str]]] = None, + identifier: typing.Optional[str] = None + ): + self.start = start or time() + self.end = end or time() + self.identifier = identifier + self.lines = text.splitlines() if isinstance(text, str) else text or [] + self.comments = [] + + def __repr__(self): + cleaned_text = self.text.replace('\n', '\\n') + return (f'<{self.__class__.__name__} ' + f'start={self.start!r} ' + f'end={self.end!r} ' + f'text={cleaned_text!r} ' + f'identifier={self.identifier!r}>' + ) + + def __str__(self): + cleaned_text = self.text.replace('\n', '\\n') + return f'{self.start} {self.end} {cleaned_text}' + + def add_line(self, line: str): + self.lines.append(line) + + def stream(self): + yield from self.lines @property def start(self): - 
return self._to_timestamp(self._start) + return self.format_timestamp(self._start) @start.setter def start(self, value): - self._start = self._parse_timestamp(value) + self._start = self.parse_timestamp(value) @property def end(self): - return self._to_timestamp(self._end) + return self.format_timestamp(self._end) @end.setter def end(self, value): - self._end = self._parse_timestamp(value) + self._end = self.parse_timestamp(value) @property - def lines(self): - return self._lines + def start_in_seconds(self): + return self.time_in_seconds(self._start) @property - def text(self): - """Returns the captions lines as a text (without cue tags)""" - return self._clean_cue_tags(self.raw_text) + def end_in_seconds(self): + return self.time_in_seconds(self._end) @property - def raw_text(self): + def raw_text(self) -> str: """Returns the captions lines as a text (may include cue tags)""" return '\n'.join(self.lines) + @property + def text(self) -> str: + """Returns the captions lines as a text (without cue tags)""" + return re.sub(self.CUE_TEXT_TAGS, '', self.raw_text) + @text.setter - def text(self, value): + def text(self, value: str): if not isinstance(value, str): raise AttributeError('String value expected but received {}.'.format(type(value))) - self._lines = value.splitlines() + self.lines = value.splitlines() + @staticmethod + def parse_timestamp(value: typing.Union[str, time]): + if isinstance(value, str): + time_format = '%H:%M:%S.%f' if len(value) >= 11 else '%M:%S.%f' + try: + return datetime.strptime(value, time_format).time() + except ValueError: + raise MalformedCaptionError(f'Invalid timestamp: {value}') + elif isinstance(value, time): + return value -class GenericBlock(object): - """Generic class that defines a data structure holding an array of lines""" - def __init__(self): - self.lines = [] + raise AttributeError(f'The type {type(value)} is not supported') + @staticmethod + def format_timestamp(time_obj: time) -> str: + microseconds = 
int(time_obj.microsecond / 1000) + return f'{time_obj.strftime("%H:%M:%S")}.{microseconds:03d}' -class Block(GenericBlock): - def __init__(self, line_number): - super().__init__() - self.line_number = line_number + @staticmethod + def time_in_seconds(time_obj: time) -> int: + return (time_obj.hour * 3600 + + time_obj.minute * 60 + + time_obj.second + + time_obj.microsecond // 1_000_000 + ) -class Style(GenericBlock): +class Style: + def __init__(self, text: typing.Union[str, typing.List[str]]): + self.lines = text.splitlines() if isinstance(text, str) else text + self.comments = [] @property def text(self): - """Returns the style lines as a text""" - return ''.join(map(lambda x: x.strip(), self.lines)) - - @text.setter - def text(self, value): - if type(value) != str: - raise TypeError('The text value must be a string.') - self.lines = value.split('\n') + return '\n'.join(self.lines) diff --git a/webvtt/webvtt.py b/webvtt/webvtt.py index adec7c9..a1d88a2 100644 --- a/webvtt/webvtt.py +++ b/webvtt/webvtt.py @@ -1,13 +1,15 @@ import os +import typing from .parsers import WebVTTParser, SRTParser, SBVParser from .writers import WebVTTWriter, SRTWriter +from .structures import Caption, Style from .errors import MissingFilenameError __all__ = ['WebVTT'] -class WebVTT(object): +class WebVTT: """ Parse captions in WebVTT format and also from other formats like SRT. 
@@ -45,28 +47,31 @@ def __str__(self): @classmethod def from_srt(cls, file): """Reads captions from a file in SubRip format.""" - parser = SRTParser().read(file) - return cls(file=file, captions=parser.captions) + return cls(file=file, captions=SRTParser.read(file)) @classmethod def from_sbv(cls, file): """Reads captions from a file in YouTube SBV format.""" - parser = SBVParser().read(file) - return cls(file=file, captions=parser.captions) + return cls(file=file, captions=SBVParser.read(file)) @classmethod def read(cls, file): """Reads a WebVTT captions file.""" - parser = WebVTTParser().read(file) - return cls(file=file, captions=parser.captions, styles=parser.styles) + items = WebVTTParser.read(file) + return cls(file=file, + captions=[it for it in items if isinstance(it, Caption)], + styles=[it for it in items if isinstance(it, Style)] + ) @classmethod - def read_buffer(cls, buffer): + def read_buffer(cls, buffer: typing.IO[str]): """Reads a WebVTT captions from a file-like object. Such file-like object may be the return of an io.open call, io.StringIO object, tempfile.TemporaryFile object, etc.""" - parser = WebVTTParser().read_from_buffer(buffer) - return cls(captions=parser.captions, styles=parser.styles) + items = WebVTTParser.read_from_buffer(buffer) + return cls(captions=[it for it in items if isinstance(it, Caption)], + styles=[it for it in items if isinstance(it, Style)] + ) def _get_output_file(self, output, extension='vtt'): if not output: diff --git a/webvtt/writers.py b/webvtt/writers.py index 5ec551b..71b911e 100644 --- a/webvtt/writers.py +++ b/webvtt/writers.py @@ -1,40 +1,40 @@ +import typing -class WebVTTWriter(object): +from .structures import Caption - def write(self, captions, f): + +class WebVTTWriter: + + def write(self, captions: typing.Iterable[Caption], f: typing.IO[str]): f.write(self.webvtt_content(captions)) - def webvtt_content(self, captions): + def webvtt_content(self, captions: typing.Iterable[Caption]) -> str: """ Return 
captions content with webvtt formatting. """ - output = ["WEBVTT"] + output = ['WEBVTT'] for caption in captions: - output.append("") - if caption.identifier: - output.append(caption.identifier) - output.append('{} --> {}'.format(caption.start, caption.end)) - output.extend(caption.lines) + output.extend([ + '', + *(identifier for identifier in {caption.identifier} if identifier), + f'{caption.start} --> {caption.end}', + *caption.lines + ]) return '\n'.join(output) -class SRTWriter(object): - - def write(self, captions, f): - for line_number, caption in enumerate(captions, start=1): - f.write('{}\n'.format(line_number)) - f.write('{} --> {}\n'.format(self._to_srt_timestamp(caption.start_in_seconds), - self._to_srt_timestamp(caption.end_in_seconds))) - f.writelines(['{}\n'.format(l) for l in caption.lines]) - f.write('\n') - - def _to_srt_timestamp(self, total_seconds): - hours = int(total_seconds / 3600) - minutes = int(total_seconds / 60 - hours * 60) - seconds = int(total_seconds - hours * 3600 - minutes * 60) - milliseconds = round((total_seconds - seconds - hours * 3600 - minutes * 60)*1000) +class SRTWriter: - return '{:02d}:{:02d}:{:02d},{:03d}'.format(hours, minutes, seconds, milliseconds) + def write(self, captions, f: typing.IO[str]): + output = [] + for index, caption in enumerate(captions, start=1): + output.extend([ + f'{index}', + '{} --> {}'.format(*map(lambda x: x.replace('.', ','), (caption.start, caption.end))), + *caption.lines, + '' + ]) + f.write('\n'.join(output).rstrip()) class SBVWriter(object): From e495532f2b1f8b729ca38853128dd6d0618c80f6 Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Sat, 27 Apr 2024 20:32:06 +0200 Subject: [PATCH 02/16] Improvements --- setup.cfg | 31 +++ setup.py | 8 +- tests/subtitles/styles_with_comments.vtt | 23 ++ tests/test_cli.py | 184 ++++++++++++++++ tests/test_segmenter.py | 139 ++++++------ tests/test_srt.py | 1 - tests/test_webvtt.py | 175 ++++++++------- tests/test_webvtt_parser.py | 37 +++- tox.ini | 
18 -- webvtt/__init__.py | 13 +- webvtt/cli.py | 88 ++++---- webvtt/errors.py | 3 - webvtt/parsers.py | 91 +++----- webvtt/segmenter.py | 52 ++--- webvtt/structures.py | 83 ++++--- webvtt/webvtt.py | 265 +++++++++++++++-------- webvtt/writers.py | 63 +++--- 17 files changed, 795 insertions(+), 479 deletions(-) create mode 100644 tests/subtitles/styles_with_comments.vtt create mode 100644 tests/test_cli.py delete mode 100644 tox.ini diff --git a/setup.cfg b/setup.cfg index ddb7da9..2233487 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,33 @@ [metadata] description_file = README.rst + +[tox:tox] +envlist = codestyle,py,coverage + +[coverage:run] +source = webvtt +branch = true + +[coverage:report] +fail_under = 97 + +[testenv] +deps = + coverage +description = run the tests and provide coverage metrics +commands = + coverage run -m unittest discover + +[testenv:codestyle] +deps = + flake8 + mypy +commands = + flake8 webvtt setup.py + mypy webvtt + +[testenv:coverage] +deps = coverage +commands = + coverage html --fail-under=0 + coverage report diff --git a/setup.py b/setup.py index 92384b3..3219c24 100644 --- a/setup.py +++ b/setup.py @@ -19,25 +19,19 @@ url='https://github.com/glut23/webvtt-py', packages=find_packages('.', exclude=['tests']), include_package_data=True, - install_requires=[ - 'docopt' - ], entry_points={ 'console_scripts': [ 'webvtt=webvtt.cli:main' ] }, license='MIT', - python_requires='>=3.4', + python_requires='>=3.7', classifiers=[ 'Development Status :: 3 - Alpha', 'Intended Audience :: Developers', 'License :: OSI Approved :: MIT License', 'Programming Language :: Python', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', diff --git a/tests/subtitles/styles_with_comments.vtt 
b/tests/subtitles/styles_with_comments.vtt new file mode 100644 index 0000000..2a69de0 --- /dev/null +++ b/tests/subtitles/styles_with_comments.vtt @@ -0,0 +1,23 @@ +WEBVTT + +NOTE This is the first style block + +STYLE +::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; +} + +NOTE This is the second block of styles + +NOTE +Multiline comment for the same +second block of styles + +STYLE +::cue(b) { + color: peachpuff; +} + +00:00:00.000 --> 00:00:10.000 +- Hello world. \ No newline at end of file diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..bb31353 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,184 @@ +import unittest +import tempfile +import os +import pathlib +import textwrap + +from webvtt.cli import main +import subprocess + +class CLITestCase(unittest.TestCase): + + def test_cli(self): + vtt_file = pathlib.Path(__file__).resolve().parent / 'subtitles' / 'sample.vtt' + + with tempfile.TemporaryDirectory() as temp_dir: + + + # result = subprocess.run(['python', 'webvtt/cli.py', *['segment', str(vtt_file.resolve()), '-o', temp_dir]], capture_output=True, + # text=True) + + main(['segment', str(vtt_file.resolve()), '-o', temp_dir]) + _, dirs, files = next(os.walk(temp_dir)) + + self.assertEqual(len(dirs), 0) + self.assertEqual(len(files), 8) + for expected_file in ('prog_index.m3u8', + 'fileSequence0.webvtt', + 'fileSequence1.webvtt', + 'fileSequence2.webvtt', + 'fileSequence3.webvtt', + 'fileSequence4.webvtt', + 'fileSequence5.webvtt', + 'fileSequence6.webvtt', + ): + self.assertIn(expected_file, files) + + self.assertEqual( + (pathlib.Path(temp_dir) / 'prog_index.m3u8').read_text(), + textwrap.dedent( + ''' + #EXTM3U + #EXT-X-TARGETDURATION:10 + #EXT-X-VERSION:3 + #EXT-X-PLAYLIST-TYPE:VOD + #EXTINF:30.00000 + fileSequence0.webvtt + #EXTINF:30.00000 + fileSequence1.webvtt + #EXTINF:30.00000 + fileSequence2.webvtt + #EXTINF:30.00000 + fileSequence3.webvtt + #EXTINF:30.00000 
+ fileSequence4.webvtt + #EXTINF:30.00000 + fileSequence5.webvtt + #EXTINF:30.00000 + fileSequence6.webvtt + #EXT-X-ENDLIST + ''' + ).lstrip()) + self.assertEqual( + (pathlib.Path(temp_dir) / 'fileSequence0.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + ''' + ).lstrip()) + self.assertEqual( + (pathlib.Path(temp_dir) / 'fileSequence1.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + + 00:00:11.890 --> 00:00:16.320 + Caption text #3 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + ''' + ).lstrip()) + self.assertEqual( + (pathlib.Path(temp_dir) / 'fileSequence2.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + + 00:00:21.580 --> 00:00:23.880 + Caption text #5 + + 00:00:23.880 --> 00:00:27.280 + Caption text #6 + + 00:00:27.280 --> 00:00:30.280 + Caption text #7 + ''' + ).lstrip()) + self.assertEqual( + (pathlib.Path(temp_dir) / 'fileSequence3.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:27.280 --> 00:00:30.280 + Caption text #7 + + 00:00:30.280 --> 00:00:36.510 + Caption text #8 + + 00:00:36.510 --> 00:00:38.870 + Caption text #9 + + 00:00:38.870 --> 00:00:45.000 + Caption text #10 + ''' + ).lstrip()) + self.assertEqual( + (pathlib.Path(temp_dir) / 'fileSequence4.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:38.870 --> 00:00:45.000 + Caption text #10 + + 00:00:45.000 --> 00:00:47.000 + Caption text #11 + + 00:00:47.000 --> 00:00:50.970 + Caption text #12 + ''' + ).lstrip()) + self.assertEqual( + (pathlib.Path(temp_dir) / 
'fileSequence5.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:47.000 --> 00:00:50.970 + Caption text #12 + + 00:00:50.970 --> 00:00:54.440 + Caption text #13 + + 00:00:54.440 --> 00:00:58.600 + Caption text #14 + + 00:00:58.600 --> 00:01:01.350 + Caption text #15 + ''' + ).lstrip()) + self.assertEqual( + (pathlib.Path(temp_dir) / 'fileSequence6.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:58.600 --> 00:01:01.350 + Caption text #15 + + 00:01:01.350 --> 00:01:04.300 + Caption text #16 + ''' + ).lstrip()) + diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py index d82a520..b5ce88e 100644 --- a/tests/test_segmenter.py +++ b/tests/test_segmenter.py @@ -2,8 +2,7 @@ import unittest from shutil import rmtree -from webvtt import WebVTTSegmenter, Caption -from webvtt.errors import InvalidCaptionsError +from webvtt import WebVTTSegmenter from webvtt import WebVTT BASE_DIR = os.path.dirname(__file__) @@ -11,7 +10,7 @@ OUTPUT_DIR = os.path.join(BASE_DIR, 'output') -class WebVTTSegmenterTestCase(unittest.TestCase): +class TestWebVTTSegmenter(unittest.TestCase): def setUp(self): self.segmenter = WebVTTSegmenter() @@ -20,100 +19,87 @@ def tearDown(self): if os.path.exists(OUTPUT_DIR): rmtree(OUTPUT_DIR) - def _parse_captions(self, filename): - self.webvtt = WebVTT().read(os.path.join(SUBTITLES_DIR, filename)) - - def test_invalid_captions(self): - self.assertRaises( - FileNotFoundError, - self.segmenter.segment, - 'text' - ) - - self.assertRaises( - InvalidCaptionsError, - self.segmenter.segment, - 10 - ) - - def test_single_invalid_caption(self): - self.assertRaises( - InvalidCaptionsError, - self.segmenter.segment, - [Caption(), Caption(), 'text', Caption()] - ) + @staticmethod + def _get_file(filename: str) -> str: + return os.path.join(SUBTITLES_DIR, filename) def test_total_segments(self): # segment with default 10 seconds - 
self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR) + self.segmenter.segment(self._get_file('sample.vtt'), OUTPUT_DIR) self.assertEqual(self.segmenter.total_segments, 7) # segment with custom 30 seconds - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR, 30) + self.segmenter.segment(self._get_file('sample.vtt'), OUTPUT_DIR, 30) self.assertEqual(self.segmenter.total_segments, 3) def test_output_folder_is_created(self): self.assertFalse(os.path.exists(OUTPUT_DIR)) - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR) + self.segmenter.segment(self._get_file('sample.vtt'), OUTPUT_DIR) self.assertTrue(os.path.exists(OUTPUT_DIR)) def test_segmentation_files_exist(self): - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR) + self.segmenter.segment(self._get_file('sample.vtt'), OUTPUT_DIR) for i in range(7): self.assertTrue( - os.path.exists(os.path.join(OUTPUT_DIR, 'fileSequence{}.webvtt'.format(i))) + os.path.exists(os.path.join(OUTPUT_DIR, + f'fileSequence{i}.webvtt' + ) + ) ) - self.assertTrue(os.path.exists(os.path.join(OUTPUT_DIR, 'prog_index.m3u8'))) + self.assertTrue(os.path.exists(os.path.join(OUTPUT_DIR, + 'prog_index.m3u8' + ))) def test_segmentation(self): - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR) + filepath = self._get_file('sample.vtt') + self.segmenter.segment(filepath, OUTPUT_DIR) + + webvtt = WebVTT.read(filepath) # segment 1 should have caption 1 and 2 self.assertEqual(len(self.segmenter.segments[0]), 2) - self.assertIn(self.webvtt.captions[0], self.segmenter.segments[0]) - self.assertIn(self.webvtt.captions[1], self.segmenter.segments[0]) + self.assertIn(webvtt.captions[0], self.segmenter.segments[0]) + self.assertIn(webvtt.captions[1], self.segmenter.segments[0]) # segment 2 should have caption 2 again (overlap), 3 and 4 self.assertEqual(len(self.segmenter.segments[1]), 3) - 
self.assertIn(self.webvtt.captions[2], self.segmenter.segments[1]) - self.assertIn(self.webvtt.captions[3], self.segmenter.segments[1]) + self.assertIn(webvtt.captions[2], self.segmenter.segments[1]) + self.assertIn(webvtt.captions[3], self.segmenter.segments[1]) # segment 3 should have caption 4 again (overlap), 5, 6 and 7 self.assertEqual(len(self.segmenter.segments[2]), 4) - self.assertIn(self.webvtt.captions[3], self.segmenter.segments[2]) - self.assertIn(self.webvtt.captions[4], self.segmenter.segments[2]) - self.assertIn(self.webvtt.captions[5], self.segmenter.segments[2]) - self.assertIn(self.webvtt.captions[6], self.segmenter.segments[2]) + self.assertIn(webvtt.captions[3], self.segmenter.segments[2]) + self.assertIn(webvtt.captions[4], self.segmenter.segments[2]) + self.assertIn(webvtt.captions[5], self.segmenter.segments[2]) + self.assertIn(webvtt.captions[6], self.segmenter.segments[2]) # segment 4 should have caption 7 again (overlap), 8, 9 and 10 self.assertEqual(len(self.segmenter.segments[3]), 4) - self.assertIn(self.webvtt.captions[6], self.segmenter.segments[3]) - self.assertIn(self.webvtt.captions[7], self.segmenter.segments[3]) - self.assertIn(self.webvtt.captions[8], self.segmenter.segments[3]) - self.assertIn(self.webvtt.captions[9], self.segmenter.segments[3]) + self.assertIn(webvtt.captions[6], self.segmenter.segments[3]) + self.assertIn(webvtt.captions[7], self.segmenter.segments[3]) + self.assertIn(webvtt.captions[8], self.segmenter.segments[3]) + self.assertIn(webvtt.captions[9], self.segmenter.segments[3]) # segment 5 should have caption 10 again (overlap), 11 and 12 self.assertEqual(len(self.segmenter.segments[4]), 3) - self.assertIn(self.webvtt.captions[9], self.segmenter.segments[4]) - self.assertIn(self.webvtt.captions[10], self.segmenter.segments[4]) - self.assertIn(self.webvtt.captions[11], self.segmenter.segments[4]) + self.assertIn(webvtt.captions[9], self.segmenter.segments[4]) + self.assertIn(webvtt.captions[10], 
self.segmenter.segments[4]) + self.assertIn(webvtt.captions[11], self.segmenter.segments[4]) # segment 6 should have caption 12 again (overlap), 13, 14 and 15 self.assertEqual(len(self.segmenter.segments[5]), 4) - self.assertIn(self.webvtt.captions[11], self.segmenter.segments[5]) - self.assertIn(self.webvtt.captions[12], self.segmenter.segments[5]) - self.assertIn(self.webvtt.captions[13], self.segmenter.segments[5]) - self.assertIn(self.webvtt.captions[14], self.segmenter.segments[5]) + self.assertIn(webvtt.captions[11], self.segmenter.segments[5]) + self.assertIn(webvtt.captions[12], self.segmenter.segments[5]) + self.assertIn(webvtt.captions[13], self.segmenter.segments[5]) + self.assertIn(webvtt.captions[14], self.segmenter.segments[5]) # segment 7 should have caption 15 again (overlap) and 16 self.assertEqual(len(self.segmenter.segments[6]), 2) - self.assertIn(self.webvtt.captions[14], self.segmenter.segments[6]) - self.assertIn(self.webvtt.captions[15], self.segmenter.segments[6]) + self.assertIn(webvtt.captions[14], self.segmenter.segments[6]) + self.assertIn(webvtt.captions[15], self.segmenter.segments[6]) def test_segment_content(self): - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR, 10) + self.segmenter.segment(self._get_file('sample.vtt'), OUTPUT_DIR, 10) - with open(os.path.join(OUTPUT_DIR, 'fileSequence0.webvtt'), 'r', encoding='utf-8') as f: + with open( + os.path.join(OUTPUT_DIR, 'fileSequence0.webvtt'), + 'r', + encoding='utf-8' + ) as f: lines = [line.rstrip() for line in f.readlines()] expected_lines = [ @@ -130,10 +116,13 @@ def test_segment_content(self): self.assertListEqual(lines, expected_lines) def test_manifest_content(self): - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR, 10) + self.segmenter.segment(self._get_file('sample.vtt'), OUTPUT_DIR, 10) - with open(os.path.join(OUTPUT_DIR, 'prog_index.m3u8'), 'r', encoding='utf-8') as f: + with open( + 
os.path.join(OUTPUT_DIR, 'prog_index.m3u8'), + 'r', + encoding='utf-8' + ) as f: lines = [line.rstrip() for line in f.readlines()] expected_lines = [ @@ -155,19 +144,29 @@ def test_manifest_content(self): self.assertEqual(lines[index], line) def test_customize_mpegts(self): - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR, mpegts=800000) - - with open(os.path.join(OUTPUT_DIR, 'fileSequence0.webvtt'), 'r', encoding='utf-8') as f: + self.segmenter.segment(self._get_file('sample.vtt'), + OUTPUT_DIR, + mpegts=800000 + ) + + with open( + os.path.join(OUTPUT_DIR, 'fileSequence0.webvtt'), + 'r', + encoding='utf-8' + ) as f: lines = f.readlines() self.assertIn('MPEGTS:800000', lines[1]) def test_segment_from_file(self): - self.segmenter.segment(os.path.join(SUBTITLES_DIR, 'sample.vtt'), OUTPUT_DIR), + self.segmenter.segment( + os.path.join(SUBTITLES_DIR, 'sample.vtt'), OUTPUT_DIR + ), self.assertEqual(self.segmenter.total_segments, 7) def test_segment_with_no_captions(self): - self.segmenter.segment(os.path.join(SUBTITLES_DIR, 'no_captions.vtt'), OUTPUT_DIR), + self.segmenter.segment( + os.path.join(SUBTITLES_DIR, 'no_captions.vtt'), OUTPUT_DIR + ), self.assertEqual(self.segmenter.total_segments, 0) def test_total_segments_readonly(self): diff --git a/tests/test_srt.py b/tests/test_srt.py index 0e6a645..70476af 100644 --- a/tests/test_srt.py +++ b/tests/test_srt.py @@ -1,5 +1,4 @@ import os -import unittest from shutil import rmtree, copy import webvtt diff --git a/tests/test_webvtt.py b/tests/test_webvtt.py index c5509df..9bf50d4 100644 --- a/tests/test_webvtt.py +++ b/tests/test_webvtt.py @@ -15,12 +15,28 @@ class WebVTTTestCase(GenericParserTestCase): + def setUp(self): + os.makedirs(OUTPUT_DIR) + def tearDown(self): - if os.path.exists(OUTPUT_DIR): - rmtree(OUTPUT_DIR) + rmtree(OUTPUT_DIR) def test_create_caption(self): - caption = Caption('00:00:00.500', '00:00:07.900', ['Caption test line 1', 'Caption test line 2']) + caption = 
Caption(start='00:00:00.500', + end='00:00:07.900', + text=['Caption test line 1', 'Caption test line 2'] + ) + self.assertEqual(caption.start, '00:00:00.500') + self.assertEqual(caption.start_in_seconds, 0) + self.assertEqual(caption.end, '00:00:07.900') + self.assertEqual(caption.end_in_seconds, 7) + self.assertEqual(caption.lines, ['Caption test line 1', 'Caption test line 2']) + + def test_create_caption_with_text(self): + caption = Caption(start='00:00:00.500', + end='00:00:07.900', + text='Caption test line 1\nCaption test line 2' + ) self.assertEqual(caption.start, '00:00:00.500') self.assertEqual(caption.start_in_seconds, 0) self.assertEqual(caption.end, '00:00:07.900') @@ -28,58 +44,58 @@ def test_create_caption(self): self.assertEqual(caption.lines, ['Caption test line 1', 'Caption test line 2']) def test_write_captions(self): - os.makedirs(OUTPUT_DIR) copy(self._get_file('one_caption.vtt'), OUTPUT_DIR) out = io.StringIO() vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'one_caption.vtt')) - new_caption = Caption('00:00:07.000', '00:00:11.890', ['New caption text line1', 'New caption text line2']) + new_caption = Caption(start='00:00:07.000', + end='00:00:11.890', + text=['New caption text line1', 'New caption text line2'] + ) vtt.captions.append(new_caption) vtt.write(out) out.seek(0) - lines = [line.rstrip() for line in out.readlines()] - expected_lines = [ - 'WEBVTT', - '', - '00:00:00.500 --> 00:00:07.000', - 'Caption text #1', - '', - '00:00:07.000 --> 00:00:11.890', - 'New caption text line1', - 'New caption text line2' - ] - - self.assertListEqual(lines, expected_lines) + self.assertListEqual([line for line in out], + [ + 'WEBVTT\n', + '\n', + '00:00:00.500 --> 00:00:07.000\n', + 'Caption text #1\n', + '\n', + '00:00:07.000 --> 00:00:11.890\n', + 'New caption text line1\n', + 'New caption text line2' + ] + ) def test_save_captions(self): - os.makedirs(OUTPUT_DIR) copy(self._get_file('one_caption.vtt'), OUTPUT_DIR) vtt = 
webvtt.read(os.path.join(OUTPUT_DIR, 'one_caption.vtt')) - new_caption = Caption('00:00:07.000', '00:00:11.890', ['New caption text line1', 'New caption text line2']) + new_caption = Caption(start='00:00:07.000', + end='00:00:11.890', + text=['New caption text line1', 'New caption text line2'] + ) vtt.captions.append(new_caption) vtt.save() with open(os.path.join(OUTPUT_DIR, 'one_caption.vtt'), 'r', encoding='utf-8') as f: - lines = [line.rstrip() for line in f.readlines()] - - expected_lines = [ - 'WEBVTT', - '', - '00:00:00.500 --> 00:00:07.000', - 'Caption text #1', - '', - '00:00:07.000 --> 00:00:11.890', - 'New caption text line1', - 'New caption text line2' - ] - - self.assertListEqual(lines, expected_lines) + self.assertListEqual([line for line in f], + [ + 'WEBVTT\n', + '\n', + '00:00:00.500 --> 00:00:07.000\n', + 'Caption text #1\n', + '\n', + '00:00:07.000 --> 00:00:11.890\n', + 'New caption text line1\n', + 'New caption text line2' + ] + ) def test_srt_conversion(self): - os.makedirs(OUTPUT_DIR) copy(self._get_file('one_caption.srt'), OUTPUT_DIR) vtt = webvtt.from_srt(os.path.join(OUTPUT_DIR, 'one_caption.srt')) @@ -88,19 +104,16 @@ def test_srt_conversion(self): self.assertTrue(os.path.exists(os.path.join(OUTPUT_DIR, 'one_caption.vtt'))) with open(os.path.join(OUTPUT_DIR, 'one_caption.vtt'), 'r', encoding='utf-8') as f: - lines = [line.rstrip() for line in f.readlines()] - - expected_lines = [ - 'WEBVTT', - '', - '00:00:00.500 --> 00:00:07.000', - 'Caption text #1', - ] - - self.assertListEqual(lines, expected_lines) + self.assertListEqual([line for line in f], + [ + 'WEBVTT\n', + '\n', + '00:00:00.500 --> 00:00:07.000\n', + 'Caption text #1', + ] + ) def test_sbv_conversion(self): - os.makedirs(OUTPUT_DIR) copy(self._get_file('two_captions.sbv'), OUTPUT_DIR) vtt = webvtt.from_sbv(os.path.join(OUTPUT_DIR, 'two_captions.sbv')) @@ -109,20 +122,18 @@ def test_sbv_conversion(self): self.assertTrue(os.path.exists(os.path.join(OUTPUT_DIR, 
'two_captions.vtt'))) with open(os.path.join(OUTPUT_DIR, 'two_captions.vtt'), 'r', encoding='utf-8') as f: - lines = [line.rstrip() for line in f.readlines()] - - expected_lines = [ - 'WEBVTT', - '', - '00:00:00.378 --> 00:00:11.378', - 'Caption text #1', - '', - '00:00:11.378 --> 00:00:12.305', - 'Caption text #2 (line 1)', - 'Caption text #2 (line 2)', - ] - - self.assertListEqual(lines, expected_lines) + self.assertListEqual([line for line in f], + [ + 'WEBVTT\n', + '\n', + '00:00:00.378 --> 00:00:11.378\n', + 'Caption text #1\n', + '\n', + '00:00:11.378 --> 00:00:12.305\n', + 'Caption text #2 (line 1)\n', + 'Caption text #2 (line 2)', + ] + ) def test_save_to_other_location(self): target_path = os.path.join(OUTPUT_DIR, 'test_folder') @@ -226,16 +237,15 @@ def test_manipulate_lines(self): def test_read_file_buffer(self): with open(self._get_file('sample.vtt'), 'r', encoding='utf-8') as f: - vtt = webvtt.read_buffer(f) + vtt = webvtt.from_buffer(f) self.assertIsInstance(vtt.captions, list) def test_read_memory_buffer(self): - payload = '' with open(self._get_file('sample.vtt'), 'r', encoding='utf-8') as f: payload = f.read() buffer = io.StringIO(payload) - vtt = webvtt.read_buffer(buffer) + vtt = webvtt.from_buffer(buffer) self.assertIsInstance(vtt.captions, list) def test_read_memory_buffer_carriage_return(self): @@ -252,31 +262,20 @@ def test_read_memory_buffer_carriage_return(self): 00:00:11.890 --> 00:00:16.320\r Caption text #3\r ''')) - vtt = webvtt.read_buffer(buffer) + vtt = webvtt.from_buffer(buffer) self.assertEqual(len(vtt.captions), 3) def test_read_malformed_buffer(self): - malformed_payloads = ['', 'MOCK MELFORMED CONTENT'] + malformed_payloads = ['', 'MOCK MALFORMED CONTENT'] for payload in malformed_payloads: buffer = io.StringIO(payload) with self.assertRaises(MalformedFileError): - webvtt.read_buffer(buffer) - + webvtt.from_buffer(buffer) def test_captions(self): vtt = webvtt.read(self._get_file('sample.vtt')) self.assertIsInstance(vtt.captions, 
list) - def test_captions_prevent_write(self): - vtt = webvtt.read(self._get_file('sample.vtt')) - self.assertRaises( - AttributeError, - setattr, - vtt, - 'captions', - [] - ) - def test_sequence_iteration(self): vtt = webvtt.read(self._get_file('sample.vtt')) self.assertIsInstance(vtt[0], Caption) @@ -311,7 +310,6 @@ def test_get_styles_as_text(self): ) def test_save_identifiers(self): - os.makedirs(OUTPUT_DIR) copy(self._get_file('using_identifiers.vtt'), OUTPUT_DIR) vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'using_identifiers.vtt')) @@ -347,7 +345,6 @@ def test_save_identifiers(self): self.assertListEqual(lines, expected_lines) def test_save_updated_identifiers(self): - os.makedirs(OUTPUT_DIR) copy(self._get_file('using_identifiers.vtt'), OUTPUT_DIR) vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'using_identifiers.vtt')) @@ -401,15 +398,15 @@ def test_content_formatting(self): Caption('00:00:08.000', '00:00:15.000', ['Caption test line 3', 'Caption test line 4']), ] expected_content = textwrap.dedent(""" - WEBVTT - - 00:00:00.500 --> 00:00:07.000 - Caption test line 1 - Caption test line 2 - - 00:00:08.000 --> 00:00:15.000 - Caption test line 3 - Caption test line 4 - """).strip() + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption test line 1 + Caption test line 2 + + 00:00:08.000 --> 00:00:15.000 + Caption test line 3 + Caption test line 4 + """).strip() vtt = webvtt.WebVTT(captions=captions) self.assertEqual(expected_content, vtt.content) diff --git a/tests/test_webvtt_parser.py b/tests/test_webvtt_parser.py index 1262b7a..644dc0f 100644 --- a/tests/test_webvtt_parser.py +++ b/tests/test_webvtt_parser.py @@ -1,12 +1,30 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import unittest from .generic import GenericParserTestCase import webvtt -from webvtt.parsers import WebVTTParser +from webvtt.parsers import Parser from webvtt.structures import Caption -from webvtt.errors import MalformedFileError, MalformedCaptionError +from webvtt.errors import 
MalformedFileError + + +class ParserTestCase(unittest.TestCase): + + def test_validate_not_callable(self): + self.assertRaises( + NotImplementedError, + Parser.validate, + [] + ) + + def test_parse_content_not_callable(self): + self.assertRaises( + NotImplementedError, + Parser.parse_content, + [] + ) class WebVTTParserTestCase(GenericParserTestCase): @@ -130,6 +148,21 @@ def test_parse_styles(self): '::cue {\n background-image: linear-gradient(to bottom, dimgray, lightgray);\n color: papayawhip;\n}' ) + def test_parse_styles_with_comments(self): + vtt = webvtt.read(self._get_file('styles_with_comments.vtt')) + self.assertEqual(len(vtt.captions), 1) + self.assertEqual(len(vtt.styles), 2) + self.assertEqual( + vtt.styles[0].comments, + ['This is the first style block'] + ) + self.assertEqual( + vtt.styles[1].comments, + ['This is the second block of styles', + 'Multiline comment for the same\nsecond block of styles' + ] + ) + def test_clean_cue_tags(self): vtt = webvtt.read(self._get_file('cue_tags.vtt')) self.assertEqual( diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 06fa4ff..0000000 --- a/tox.ini +++ /dev/null @@ -1,18 +0,0 @@ -[tox] -envlist = py34, py35, py36, py37, py38, py39 - -[travis] -python = - 3.9: py39 - 3.8: py38 - 3.7: py37 - 3.6: py36 - 3.5: py35 - 3.4: py34 - -[testenv] -setenv = - PYTHONPATH = {toxinidir} -deps = pytest -commands = - pytest diff --git a/webvtt/__init__.py b/webvtt/__init__.py index b6239f4..45719af 100644 --- a/webvtt/__init__.py +++ b/webvtt/__init__.py @@ -1,15 +1,14 @@ -__version__ = '0.4.6' +__version__ = '0.5.0' -from .webvtt import * -from .segmenter import * -from .structures import * -from .errors import * +from .webvtt import WebVTT +from .segmenter import WebVTTSegmenter +from .structures import Caption, Style # noqa -__all__ = webvtt.__all__ + segmenter.__all__ + structures.__all__ + errors.__all__ +__all__ = ['WebVTT', 'WebVTTSegmenter', 'Caption', 'Style'] read = WebVTT.read read_buffer = 
WebVTT.read_buffer +from_buffer = WebVTT.from_buffer from_srt = WebVTT.from_srt from_sbv = WebVTT.from_sbv -list_formats = WebVTT.list_formats segment = WebVTTSegmenter().segment diff --git a/webvtt/cli.py b/webvtt/cli.py index ad8ebbf..9c722b2 100644 --- a/webvtt/cli.py +++ b/webvtt/cli.py @@ -1,47 +1,53 @@ -""" -Usage: - webvtt segment [--target-duration=SECONDS] [--mpegts=OFFSET] [--output=] - webvtt -h | --help - webvtt --version +import argparse +import typing -Options: - -h --help Show this screen. - --version Show version. - --target-duration=SECONDS Target duration of each segment in seconds [default: 10]. - --mpegts=OFFSET Presentation timestamp value [default: 900000]. - --output= Output to directory [default: ./]. +from . import WebVTTSegmenter -Examples: - webvtt segment captions.vtt --output destination/directory -""" -from docopt import docopt - -from . import WebVTTSegmenter, __version__ - - -def main(): - """Main entry point for CLI commands.""" - options = docopt(__doc__, version=__version__) - if options['segment']: - segment( - options[''], - options['--output'], - options['--target-duration'], - options['--mpegts'], +def main(argv: typing.Optional[typing.Sequence] = None): + """ + Segment WebVTT from command line. + """ + arguments = argparse.ArgumentParser( + description='Segment WebVTT files.' 
+ ) + arguments.add_argument( + 'command', + choices=['segment'], + help='command to perform' + ) + arguments.add_argument( + 'file', + metavar='PATH', + help='WebVTT file' + ) + arguments.add_argument( + '-o', '--output', + metavar='PATH', + help='output directory' + ) + arguments.add_argument( + '-d', '--target-duration', + metavar='NUMBER', + type=int, + default=10, + help='target duration of each segment in seconds, default: 10' + ) + arguments.add_argument( + '-m', '--mpegts', + metavar='NUMBER', + type=int, + default=900000, + help='presentation timestamp value, default: 900000' + ) + args = arguments.parse_args(argv) + WebVTTSegmenter().segment( + args.file, + args.output, + args.target_duration, + args.mpegts ) -def segment(f, output, target_duration, mpegts): - """Segment command.""" - try: - target_duration = int(target_duration) - except ValueError: - exit('Error: Invalid target duration.') - - try: - mpegts = int(mpegts) - except ValueError: - exit('Error: Invalid MPEGTS value.') - - WebVTTSegmenter().segment(f, output, target_duration, mpegts) \ No newline at end of file +if __name__ == '__main__': + main() diff --git a/webvtt/errors.py b/webvtt/errors.py index 9ee549b..4a09dbb 100644 --- a/webvtt/errors.py +++ b/webvtt/errors.py @@ -1,7 +1,4 @@ -__all__ = ['MalformedFileError', 'MalformedCaptionError', 'InvalidCaptionsError', 'MissingFilenameError'] - - class MalformedFileError(Exception): """Error raised when the file is not well formatted""" diff --git a/webvtt/parsers.py b/webvtt/parsers.py index ad3e70b..7450565 100644 --- a/webvtt/parsers.py +++ b/webvtt/parsers.py @@ -1,5 +1,3 @@ -import os -import codecs import typing from abc import ABC, abstractmethod @@ -14,51 +12,19 @@ ) -class BaseParser(ABC): - @classmethod - def read(cls, file): - """Reads the captions file.""" - return cls._parse(cls._get_content_from_file(file_path=file)) - - @classmethod - def read_from_buffer(cls, buffer): - return cls._parse(cls._read_content_lines(buffer)) +class 
Parser(ABC): @classmethod - def _parse(cls, content): - if not cls.validate(content): + def parse(cls, lines: typing.Sequence[str]): + if not cls.validate(lines): raise MalformedFileError('Invalid format') - return cls.parse(content) - - @classmethod - def _get_content_from_file(cls, file_path): - encoding = cls._read_file_encoding(file_path) - with open(file_path, encoding=encoding) as f: - return cls._read_content_lines(f) - - @staticmethod - def _read_file_encoding(file_path): - first_bytes = min(32, os.path.getsize(file_path)) - with open(file_path, 'rb') as f: - raw = f.read(first_bytes) - - if raw.startswith(codecs.BOM_UTF8): - return 'utf-8-sig' - else: - return 'utf-8' - - @staticmethod - def _read_content_lines(file_obj: typing.IO[str]): - - lines = [line.rstrip('\n\r') for line in file_obj.readlines()] - - if not lines: - raise MalformedFileError('The file is empty.') - - return lines + return cls.parse_content(lines) @staticmethod - def iter_blocks_of_lines(lines) -> typing.Generator[typing.List[str], None, None]: + def iter_blocks_of_lines(lines) -> typing.Generator[typing.List[str], + None, + None + ]: current_text_block = [] for line in lines: @@ -78,23 +44,25 @@ def validate(cls, lines: typing.Sequence[str]) -> bool: @classmethod @abstractmethod - def parse(cls, lines: typing.Sequence[str]): + def parse_content(cls, lines: typing.Sequence[str]): raise NotImplementedError -class WebVTTParser(BaseParser): +class WebVTTParser(Parser): """ Web Video Text Track parser. 
""" @classmethod def validate(cls, lines: typing.Sequence[str]) -> bool: - return lines[0].startswith('WEBVTT') + return bool(lines and lines[0].startswith('WEBVTT')) @classmethod - def parse(cls, lines: typing.Sequence[str]) -> typing.List[typing.Union[Style, Caption]]: - items = [] - comments = [] + def parse_content(cls, + lines: typing.Sequence[str] + ) -> typing.List[typing.Union[Style, Caption]]: + items: typing.List[typing.Union[Caption, Style]] = [] + comments: typing.List[WebVTTCommentBlock] = [] for block_lines in cls.iter_blocks_of_lines(lines): if WebVTTCueBlock.is_valid(block_lines): @@ -114,7 +82,7 @@ def parse(cls, lines: typing.Sequence[str]) -> typing.List[typing.Union[Style, C comments.append(WebVTTCommentBlock.from_lines(block_lines)) elif WebVTTStyleBlock.is_valid(block_lines): - style = Style(text=WebVTTStyleBlock.from_lines(block_lines).text) + style = Style(WebVTTStyleBlock.from_lines(block_lines).text) if comments: style.comments = [comment.text for comment in comments] comments = [] @@ -126,18 +94,26 @@ def parse(cls, lines: typing.Sequence[str]) -> typing.List[typing.Union[Style, C return items -class SRTParser(BaseParser): +class SRTParser(Parser): """ SubRip SRT parser. """ @classmethod def validate(cls, lines: typing.Sequence[str]) -> bool: - return len(lines) >= 3 and lines[0].isdigit() and '-->' in lines[1] and lines[2].strip() + return bool( + len(lines) >= 3 and + lines[0].isdigit() and + '-->' in lines[1] and + lines[2].strip() + ) @classmethod - def parse(cls, lines: typing.Sequence[str]) -> typing.List[Caption]: - captions = [] + def parse_content( + cls, + lines: typing.Sequence[str] + ) -> typing.List[Caption]: + captions: typing.List[Caption] = [] for block_lines in cls.iter_blocks_of_lines(lines): if not SRTCueBlock.is_valid(block_lines): @@ -152,7 +128,7 @@ def parse(cls, lines: typing.Sequence[str]) -> typing.List[Caption]: return captions -class SBVParser(BaseParser): +class SBVParser(Parser): """ YouTube SBV parser. 
""" @@ -163,10 +139,13 @@ def validate(cls, lines: typing.Sequence[str]) -> bool: return False first_block = next(cls.iter_blocks_of_lines(lines)) - return first_block and SBVCueBlock.is_valid(first_block) + return bool(first_block and SBVCueBlock.is_valid(first_block)) @classmethod - def parse(cls, lines: typing.Sequence[str]) -> typing.List[Caption]: + def parse_content( + cls, + lines: typing.Sequence[str] + ) -> typing.List[Caption]: captions = [] for block_lines in cls.iter_blocks_of_lines(lines): diff --git a/webvtt/segmenter.py b/webvtt/segmenter.py index 9378ad6..4822468 100644 --- a/webvtt/segmenter.py +++ b/webvtt/segmenter.py @@ -1,17 +1,13 @@ import os from math import ceil, floor -from .errors import InvalidCaptionsError from .webvtt import WebVTT -from .structures import Caption MPEGTS = 900000 SECONDS = 10 # default number of seconds per segment -__all__ = ['WebVTTSegmenter'] - -class WebVTTSegmenter(object): +class WebVTTSegmenter: """ Provides segmentation of WebVTT captions for HTTP Live Streaming (HLS). """ @@ -22,15 +18,6 @@ def __init__(self): self._mpegts = 0 self._segments = [] - def _validate_webvtt(self, webvtt): - # Validates that the captions is a list and all the captions are instances of Caption. 
- if not isinstance(webvtt, WebVTT): - return False - for c in webvtt.captions: - if not isinstance(c, Caption): - return False - return True - def _slice_segments(self, captions): self._segments = [[] for _ in range(self.total_segments)] @@ -46,15 +33,19 @@ def _slice_segments(self, captions): def _write_segments(self): for index in range(self.total_segments): - segment_file = os.path.join(self._output_folder, 'fileSequence{}.webvtt'.format(index)) + segment_file = os.path.join(self._output_folder, + f'fileSequence{index}.webvtt' + ) with open(segment_file, 'w', encoding='utf-8') as f: f.write('WEBVTT\n') - f.write('X-TIMESTAMP-MAP=MPEGTS:{},LOCAL:00:00:00.000\n'.format(self._mpegts)) + f.write(f'X-TIMESTAMP-MAP=MPEGTS:{self._mpegts},' + 'LOCAL:00:00:00.000\n' + ) for caption in self.segments[index]: f.write('\n{} --> {}\n'.format(caption.start, caption.end)) - f.writelines(['{}\n'.format(l) for l in caption.lines]) + f.writelines(f'{line}\n' for line in caption.lines) def _write_manifest(self): manifest_file = os.path.join(self._output_folder, 'prog_index.m3u8') @@ -70,18 +61,21 @@ def _write_manifest(self): f.write('#EXT-X-ENDLIST\n') - def segment(self, webvtt, output='', seconds=SECONDS, mpegts=MPEGTS): - """Segments the captions based on a number of seconds.""" - if isinstance(webvtt, str): - # if a string is supplied we parse the file - captions = WebVTT().read(webvtt).captions - elif not self._validate_webvtt(webvtt): - raise InvalidCaptionsError('The captions provided are invalid') - else: - # we expect to have a webvtt object - captions = webvtt.captions - - self._total_segments = 0 if not captions else int(ceil(captions[-1].end_in_seconds / seconds)) + def segment( + self, + webvtt_path: str, + output: str, + seconds: int = SECONDS, + mpegts: int = MPEGTS + ): + """Segment the WebVTT based on a number of seconds.""" + captions = WebVTT.read(webvtt_path).captions + + self._total_segments = ( + 0 + if not captions else + 
int(ceil(captions[-1].end_in_seconds / seconds)) + ) self._output_folder = output self._seconds = seconds self._mpegts = mpegts diff --git a/webvtt/structures.py b/webvtt/structures.py index 6fbf9f4..7a92da2 100644 --- a/webvtt/structures.py +++ b/webvtt/structures.py @@ -1,27 +1,14 @@ import re import typing from datetime import datetime, time -from abc import ABC, abstractmethod from .errors import MalformedCaptionError -__all__ = ['Caption', 'Style'] - -class BlockItem(ABC): - @classmethod - @abstractmethod - def is_valid(cls, lines: typing.Sequence[str]) -> bool: - raise NotImplementedError - - @classmethod - @abstractmethod - def from_lines(cls, lines: typing.Sequence[str]) -> 'BlockItem': - raise NotImplementedError - - -class WebVTTCueBlock(BlockItem): - CUE_TIMINGS_PATTERN = re.compile(r'\s*((?:\d+:)?\d{2}:\d{2}.\d{3})\s*-->\s*((?:\d+:)?\d{2}:\d{2}.\d{3})') +class WebVTTCueBlock: + CUE_TIMINGS_PATTERN = re.compile( + r'\s*((?:\d+:)?\d{2}:\d{2}.\d{3})\s*-->\s*((?:\d+:)?\d{2}:\d{2}.\d{3})' + ) def __init__(self, identifier, start, end, payload): self.identifier = identifier @@ -47,6 +34,7 @@ def is_valid(cls, lines: typing.Sequence[str]) -> bool: @classmethod def from_lines(cls, lines: typing.Sequence[str]) -> 'WebVTTCueBlock': + identifier = None start = None end = None @@ -65,10 +53,10 @@ def from_lines(cls, lines: typing.Sequence[str]) -> 'WebVTTCueBlock': return cls(identifier, start, end, payload) -class WebVTTCommentBlock(BlockItem): +class WebVTTCommentBlock: COMMENT_PATTERN = re.compile(r'NOTE\s(.*?)\Z', re.DOTALL) - def __init__(self, text): + def __init__(self, text: str): self.text = text @classmethod @@ -81,7 +69,7 @@ def from_lines(cls, lines: typing.Sequence[str]) -> 'WebVTTCommentBlock': return cls(text=match.group(1).strip() if match else '') -class WebVTTStyleBlock(BlockItem): +class WebVTTStyleBlock: STYLE_PATTERN = re.compile(r'STYLE\s(.*?)\Z', re.DOTALL) def __init__(self, text): @@ -97,10 +85,18 @@ def from_lines(cls, lines: 
typing.Sequence[str]) -> 'WebVTTStyleBlock': return cls(text=match.group(1).strip() if match else '') -class SRTCueBlock(BlockItem): - CUE_TIMINGS_PATTERN = re.compile(r'\s*(\d+:\d{2}:\d{2},\d{3})\s*-->\s*(\d+:\d{2}:\d{2},\d{3})') +class SRTCueBlock: + CUE_TIMINGS_PATTERN = re.compile( + r'\s*(\d+:\d{2}:\d{2},\d{3})\s*-->\s*(\d+:\d{2}:\d{2},\d{3})' + ) - def __init__(self, index, start, end, payload): + def __init__( + self, + index: str, + start: time, + end: time, + payload: typing.Sequence[str] + ): self.index = index self.start = start self.end = end @@ -116,9 +112,11 @@ def is_valid(cls, lines: typing.Sequence[str]) -> bool: @classmethod def from_lines(cls, lines: typing.Sequence[str]) -> 'SRTCueBlock': + index = lines[0] match = re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) + assert match is not None start, end = map(lambda x: datetime.strptime(x, '%H:%M:%S,%f').time(), (match.group(1), match.group(2)) ) @@ -128,10 +126,17 @@ def from_lines(cls, lines: typing.Sequence[str]) -> 'SRTCueBlock': return cls(index, start, end, payload) -class SBVCueBlock(BlockItem): - CUE_TIMINGS_PATTERN = re.compile(r'\s*(\d+:\d{2}:\d{2}.\d{3}),(\d+:\d{2}:\d{2}.\d{3})') +class SBVCueBlock: + CUE_TIMINGS_PATTERN = re.compile( + r'\s*(\d+:\d{2}:\d{2}.\d{3}),(\d+:\d{2}:\d{2}.\d{3})' + ) - def __init__(self, start, end, payload): + def __init__( + self, + start: time, + end: time, + payload: typing.Sequence[str] + ): self.start = start self.end = end self.payload = payload @@ -147,6 +152,7 @@ def is_valid(cls, lines: typing.Sequence[str]) -> bool: @classmethod def from_lines(cls, lines: typing.Sequence[str]) -> 'SBVCueBlock': match = re.match(cls.CUE_TIMINGS_PATTERN, lines[0]) + assert match is not None start, end = map(lambda x: datetime.strptime(x, '%H:%M:%S.%f').time(), (match.group(1), match.group(2)) ) @@ -162,14 +168,21 @@ class Caption: def __init__(self, start: typing.Optional[typing.Union[str, time]] = None, end: typing.Optional[typing.Union[str, time]] = None, - text: 
typing.Optional[typing.Union[str, typing.List[str]]] = None, + text: typing.Optional[typing.Union[str, + typing.Sequence[str] + ]] = None, identifier: typing.Optional[str] = None ): + text = text or [] self.start = start or time() self.end = end or time() self.identifier = identifier - self.lines = text.splitlines() if isinstance(text, str) else text or [] - self.comments = [] + self.lines = (text.splitlines() + if isinstance(text, str) + else + list(text) + ) + self.comments: typing.List[str] = [] def __repr__(self): cleaned_text = self.text.replace('\n', '\\n') @@ -184,6 +197,12 @@ def __str__(self): cleaned_text = self.text.replace('\n', '\\n') return f'{self.start} {self.end} {cleaned_text}' + def __eq__(self, other): + return (self.start == other.start and + self.end == other.end and + self.raw_text == other.raw_text + ) + def add_line(self, line: str): self.lines.append(line) @@ -227,7 +246,9 @@ def text(self) -> str: @text.setter def text(self, value: str): if not isinstance(value, str): - raise AttributeError('String value expected but received {}.'.format(type(value))) + raise AttributeError( + f'String value expected but received {value}.' + ) self.lines = value.splitlines() @@ -261,7 +282,7 @@ def time_in_seconds(time_obj: time) -> int: class Style: def __init__(self, text: typing.Union[str, typing.List[str]]): self.lines = text.splitlines() if isinstance(text, str) else text - self.comments = [] + self.comments: typing.List[str] = [] @property def text(self): diff --git a/webvtt/webvtt.py b/webvtt/webvtt.py index a1d88a2..6776b85 100644 --- a/webvtt/webvtt.py +++ b/webvtt/webvtt.py @@ -1,12 +1,20 @@ import os import typing +import codecs +import warnings -from .parsers import WebVTTParser, SRTParser, SBVParser -from .writers import WebVTTWriter, SRTWriter +from . 
import writers +from .parsers import WebVTTParser, SRTParser, SBVParser, Parser from .structures import Caption, Style from .errors import MissingFilenameError -__all__ = ['WebVTT'] +CODEC_BOMS = { + 'utf-8-sig': codecs.BOM_UTF8, + 'utf-32-le': codecs.BOM_UTF32_LE, + 'utf-32-be': codecs.BOM_UTF32_BE, + 'utf-16-le': codecs.BOM_UTF16_LE, + 'utf-16-be': codecs.BOM_UTF16_BE +} class WebVTT: @@ -15,134 +23,211 @@ class WebVTT: To read WebVTT: - WebVTT().read('captions.vtt') + WebVTT.read('captions.vtt') - For other formats like SRT, use from_[format in lower case]: + For other formats: - WebVTT().from_srt('captions.srt') - - A list of all supported formats is available calling list_formats(). + WebVTT.from_srt('captions.srt') + WebVTT.from_sbv('captions.sbv') """ - def __init__(self, file='', captions=None, styles=None): + def __init__(self, + file: typing.Optional[str] = None, + captions: typing.Optional[typing.List[Caption]] = None, + styles: typing.Optional[typing.List[Style]] = None + ): self.file = file - self._captions = captions or [] - self._styles = styles + self.captions = captions or [] + self.styles = styles or [] def __len__(self): - return len(self._captions) + return len(self.captions) def __getitem__(self, index): - return self._captions[index] + return self.captions[index] def __repr__(self): - return '<%(cls)s file=%(file)s>' % { - 'cls': self.__class__.__name__, - 'file': self.file - } + return f'<{self.__class__.__name__} file={self.file}>' def __str__(self): - return '\n'.join([str(c) for c in self._captions]) + return '\n'.join(str(c) for c in self.captions) @classmethod - def from_srt(cls, file): - """Reads captions from a file in SubRip format.""" - return cls(file=file, captions=SRTParser.read(file)) + def read( + cls, + file: str, + encoding: typing.Optional[str] = None + ) -> 'WebVTT': + """Read a WebVTT captions file.""" + with cls._open_file(file, encoding=encoding) as f: + return cls.from_buffer(f) @classmethod - def from_sbv(cls, file): - 
"""Reads captions from a file in YouTube SBV format.""" - return cls(file=file, captions=SBVParser.read(file)) + def read_buffer(cls, buffer: typing.Iterator[str]) -> 'WebVTT': + """ + [DEPRECATED] Read WebVTT captions from a file-like object. - @classmethod - def read(cls, file): - """Reads a WebVTT captions file.""" - items = WebVTTParser.read(file) - return cls(file=file, - captions=[it for it in items if isinstance(it, Caption)], - styles=[it for it in items if isinstance(it, Style)] - ) + Such file-like object may be the return of an io.open call, + io.StringIO object, tempfile.TemporaryFile object, etc. + """ + warnings.warn( + 'Deprecated: use from_buffer instead.', + DeprecationWarning + ) + return cls.from_buffer(buffer) @classmethod - def read_buffer(cls, buffer: typing.IO[str]): - """Reads a WebVTT captions from a file-like object. + def from_buffer(cls, buffer: typing.Iterator[str]) -> 'WebVTT': + """ + Read WebVTT captions from a file-like object. + Such file-like object may be the return of an io.open call, - io.StringIO object, tempfile.TemporaryFile object, etc.""" - items = WebVTTParser.read_from_buffer(buffer) - return cls(captions=[it for it in items if isinstance(it, Caption)], - styles=[it for it in items if isinstance(it, Style)] + io.StringIO object, tempfile.TemporaryFile object, etc. 
+ """ + items = cls._parse_content(buffer, parser=WebVTTParser) + + return cls(file=getattr(buffer, 'name', None), + captions=items[0], + styles=items[1] ) - def _get_output_file(self, output, extension='vtt'): - if not output: + @classmethod + def from_srt( + cls, + file: str, + encoding: typing.Optional[str] = None + ) -> 'WebVTT': + """Read captions from a file in SubRip format.""" + with cls._open_file(file, encoding=encoding) as f: + return cls(file=f.name, + captions=cls._parse_content(f, parser=SRTParser)[0], + ) + + @classmethod + def from_sbv( + cls, + file: str, + encoding: typing.Optional[str] = None + ) -> 'WebVTT': + """Read captions from a file in YouTube SBV format.""" + with cls._open_file(file, encoding=encoding) as f: + return cls(file=f.name, + captions=cls._parse_content(f, parser=SBVParser)[0], + ) + + @classmethod + def from_string(cls, string: str) -> 'WebVTT': + return cls( + captions=cls._parse_content(string.splitlines(), + parser=WebVTTParser + )[0] + ) + + @classmethod + def _open_file( + cls, + file_path: str, + encoding: typing.Optional[str] = None + ) -> typing.IO: + return open( + file_path, + encoding=encoding or cls._detect_encoding(file_path) or 'utf-8' + ) + + @classmethod + def _parse_content( + cls, + content: typing.Iterable[str], + parser: typing.Type[Parser] + ) -> typing.Tuple[typing.List[Caption], typing.List[Style]]: + lines = [line.rstrip('\n\r') for line in content] + items = parser.parse(lines) + return (list(filter(lambda c: isinstance(c, Caption), items)), + list(filter(lambda s: isinstance(s, Style), items)) + ) + + @staticmethod + def _detect_encoding(file_path: str) -> typing.Optional[str]: + with open(file_path, mode='rb') as f: + first_bytes = f.read(4) + for encoding, bom in CODEC_BOMS.items(): + if first_bytes.startswith(bom): + return encoding + return None + + def _get_destination_file( + self, + destination_path: typing.Optional[str] = None, + extension: str = 'vtt' + ): + if not destination_path and not 
self.file: + raise MissingFilenameError + + assert self.file is not None + + destination_path = ( + destination_path or + f'{os.path.splitext(self.file)[0]}.{extension}' + ) + + target = os.path.join(os.getcwd(), destination_path) + if os.path.isdir(target): if not self.file: raise MissingFilenameError - # saving an original vtt file will overwrite the file - # and for files read from other formats will save as vtt - # with the same name and location - return os.path.splitext(self.file)[0] + '.' + extension - else: - target = os.path.join(os.getcwd(), output) - if os.path.isdir(target): - # if an output is provided and it is a directory - # the file will be saved in that location with the same name - filename = os.path.splitext(os.path.basename(self.file))[0] - return os.path.join(target, '{}.{}'.format(filename, extension)) - else: - if target[-3:].lower() != extension: - target += '.' + extension - # otherwise the file will be written in the specified location - return target - - def save(self, output=''): - """Save the document. - If no output is provided the file will be saved in the same location. Otherwise output - can determine a target directory or file. + + # store the file in specified directory + base_name = os.path.splitext(os.path.basename(self.file))[0] + new_filename = f'{base_name}.{extension}' + return os.path.join(target, new_filename) + + if target[-4:].lower() != f'.{extension}': + target = f'{target}.{extension}' + + # store the file in the specified full path + return target + + def save(self, output: typing.Optional[str] = None): """ - self.file = self._get_output_file(output) - with open(self.file, 'w', encoding='utf-8') as f: + Save the WebVTT captions to a file. + + If no output is provided the file will be saved in the same location. + Otherwise output can determine a target directory or file. 
+ """ + destination_file = self._get_destination_file(output) + with open(destination_file, 'w', encoding='utf-8') as f: self.write(f) + self.file = destination_file - def save_as_srt(self, output=''): - self.file = self._get_output_file(output, extension='srt') - with open(self.file, 'w', encoding='utf-8') as f: + def save_as_srt(self, output: typing.Optional[str] = None): + dest_file = self._get_destination_file(output, extension='srt') + with open(dest_file, 'w', encoding='utf-8') as f: self.write(f, format='srt') + self.file = dest_file - def write(self, f, format='vtt'): + def write(self, f: typing.IO[str], format: str = 'vtt'): if format == 'vtt': - WebVTTWriter().write(self._captions, f) + writers.write_vtt(f, self.captions) elif format == 'srt': - SRTWriter().write(self._captions, f) -# elif output_format == OutputFormat.SBV: -# SBVWriter().write(self._captions, f) - - @staticmethod - def list_formats(): - """Provides a list of supported formats that this class can read from.""" - return ('WebVTT (.vtt)', 'SubRip (.srt)', 'YouTube SBV (.sbv)') - - @property - def captions(self): - """Returns the list of captions.""" - return self._captions + writers.write_srt(f, self.captions) + else: + raise ValueError(f'Format {format!r} is not supported') @property def total_length(self): """Returns the total length of the captions.""" - if not self._captions: + if not self.captions: return 0 - return int(self._captions[-1].end_in_seconds) - int(self._captions[0].start_in_seconds) - - @property - def styles(self): - return self._styles + return ( + self.captions[-1].end_in_seconds - + self.captions[0].start_in_seconds + ) @property - def content(self): + def content(self) -> str: """ - Return webvtt content with webvtt formatting. + Return the webvtt capions as string. This property is useful in cases where the webvtt content is needed but no file saving on the system is required. 
""" - return WebVTTWriter().webvtt_content(self._captions) + return writers.webvtt_content(self.captions) diff --git a/webvtt/writers.py b/webvtt/writers.py index 71b911e..e15ec50 100644 --- a/webvtt/writers.py +++ b/webvtt/writers.py @@ -3,39 +3,32 @@ from .structures import Caption -class WebVTTWriter: - - def write(self, captions: typing.Iterable[Caption], f: typing.IO[str]): - f.write(self.webvtt_content(captions)) - - def webvtt_content(self, captions: typing.Iterable[Caption]) -> str: - """ - Return captions content with webvtt formatting. - """ - output = ['WEBVTT'] - for caption in captions: - output.extend([ - '', - *(identifier for identifier in {caption.identifier} if identifier), - f'{caption.start} --> {caption.end}', - *caption.lines +def write_vtt(f: typing.IO[str], captions: typing.Iterable[Caption]): + f.write(webvtt_content(captions)) + + +def webvtt_content(captions: typing.Iterable[Caption]) -> str: + """Return captions content with webvtt formatting.""" + output = ['WEBVTT'] + for caption in captions: + output.extend([ + '', + *(identifier for identifier in {caption.identifier} if identifier), + f'{caption.start} --> {caption.end}', + *caption.lines + ]) + return '\n'.join(output) + + +def write_srt(f: typing.IO[str], captions: typing.Iterable[Caption]): + output = [] + for index, caption in enumerate(captions, start=1): + output.extend([ + f'{index}', + '{} --> {}'.format(*map(lambda x: x.replace('.', ','), + (caption.start, caption.end)) + ), + *caption.lines, + '' ]) - return '\n'.join(output) - - -class SRTWriter: - - def write(self, captions, f: typing.IO[str]): - output = [] - for index, caption in enumerate(captions, start=1): - output.extend([ - f'{index}', - '{} --> {}'.format(*map(lambda x: x.replace('.', ','), (caption.start, caption.end))), - *caption.lines, - '' - ]) - f.write('\n'.join(output).rstrip()) - - -class SBVWriter(object): - pass + f.write('\n'.join(output).rstrip()) From fdbf129c514328ab1344565174ffb46162c8535d Mon Sep 
17 00:00:00 2001 From: Alejandro Mendez Date: Mon, 6 May 2024 18:01:51 +0200 Subject: [PATCH 03/16] Overall improvement --- LICENSE | 2 + setup.cfg | 6 +- setup.py | 19 +- tests/generic.py | 10 - .../captions_with_bom.vtt | 0 tests/{subtitles => samples}/comments.vtt | 0 tests/{subtitles => samples}/cue_tags.vtt | 0 tests/{subtitles => samples}/empty.vtt | 0 tests/{subtitles => samples}/invalid.vtt | 0 .../{subtitles => samples}/invalid_format.sbv | 0 .../invalid_format1.srt | 0 .../invalid_format2.srt | 0 .../invalid_format3.srt | 0 .../invalid_format4.srt | 0 .../invalid_timeframe.sbv | 0 .../invalid_timeframe.srt | 0 .../invalid_timeframe.vtt | 0 .../invalid_timeframe_in_cue_text.vtt | 0 .../metadata_headers.vtt | 0 .../metadata_headers_multiline.vtt | 0 .../missing_caption_text.sbv | 0 .../missing_caption_text.srt | 0 .../missing_caption_text.vtt | 0 .../missing_timeframe.sbv | 0 .../missing_timeframe.srt | 0 .../missing_timeframe.vtt | 0 .../netflix_chicas_del_cable.vtt | 0 tests/{subtitles => samples}/no_captions.vtt | 0 tests/{subtitles => samples}/one_caption.srt | 0 tests/{subtitles => samples}/one_caption.vtt | 0 tests/{subtitles => samples}/sample.sbv | 0 tests/{subtitles => samples}/sample.srt | 0 tests/{subtitles => samples}/sample.vtt | 0 tests/{subtitles => samples}/styles.vtt | 0 .../styles_with_comments.vtt | 0 tests/{subtitles => samples}/two_captions.sbv | 0 .../using_identifiers.vtt | 0 tests/{subtitles => samples}/youtube_dl.vtt | 0 tests/test_cli.py | 12 +- tests/test_models.py | 312 +++++ tests/test_sbv.py | 120 ++ tests/test_sbv_parser.py | 66 -- tests/test_segmenter.py | 483 +++++--- tests/test_srt.py | 169 ++- tests/test_srt_parser.py | 57 - tests/test_utils.py | 26 + tests/test_vtt.py | 433 +++++++ tests/test_webvtt.py | 1009 ++++++++++++----- tests/test_webvtt_parser.py | 194 ---- webvtt/__init__.py | 10 +- webvtt/cli.py | 18 +- webvtt/errors.py | 12 +- webvtt/models.py | 162 +++ webvtt/parsers.py | 161 --- webvtt/sbv.py | 121 ++ 
webvtt/segmenter.py | 209 ++-- webvtt/srt.py | 149 +++ webvtt/structures.py | 289 ----- webvtt/utils.py | 104 ++ webvtt/vtt.py | 276 +++++ webvtt/webvtt.py | 243 ++-- webvtt/writers.py | 34 - 62 files changed, 3184 insertions(+), 1522 deletions(-) delete mode 100644 tests/generic.py rename tests/{subtitles => samples}/captions_with_bom.vtt (100%) rename tests/{subtitles => samples}/comments.vtt (100%) rename tests/{subtitles => samples}/cue_tags.vtt (100%) rename tests/{subtitles => samples}/empty.vtt (100%) rename tests/{subtitles => samples}/invalid.vtt (100%) rename tests/{subtitles => samples}/invalid_format.sbv (100%) rename tests/{subtitles => samples}/invalid_format1.srt (100%) rename tests/{subtitles => samples}/invalid_format2.srt (100%) rename tests/{subtitles => samples}/invalid_format3.srt (100%) rename tests/{subtitles => samples}/invalid_format4.srt (100%) rename tests/{subtitles => samples}/invalid_timeframe.sbv (100%) rename tests/{subtitles => samples}/invalid_timeframe.srt (100%) rename tests/{subtitles => samples}/invalid_timeframe.vtt (100%) rename tests/{subtitles => samples}/invalid_timeframe_in_cue_text.vtt (100%) rename tests/{subtitles => samples}/metadata_headers.vtt (100%) rename tests/{subtitles => samples}/metadata_headers_multiline.vtt (100%) rename tests/{subtitles => samples}/missing_caption_text.sbv (100%) rename tests/{subtitles => samples}/missing_caption_text.srt (100%) rename tests/{subtitles => samples}/missing_caption_text.vtt (100%) rename tests/{subtitles => samples}/missing_timeframe.sbv (100%) rename tests/{subtitles => samples}/missing_timeframe.srt (100%) rename tests/{subtitles => samples}/missing_timeframe.vtt (100%) rename tests/{subtitles => samples}/netflix_chicas_del_cable.vtt (100%) rename tests/{subtitles => samples}/no_captions.vtt (100%) rename tests/{subtitles => samples}/one_caption.srt (100%) rename tests/{subtitles => samples}/one_caption.vtt (100%) rename tests/{subtitles => samples}/sample.sbv (100%) 
rename tests/{subtitles => samples}/sample.srt (100%) rename tests/{subtitles => samples}/sample.vtt (100%) rename tests/{subtitles => samples}/styles.vtt (100%) rename tests/{subtitles => samples}/styles_with_comments.vtt (100%) rename tests/{subtitles => samples}/two_captions.sbv (100%) rename tests/{subtitles => samples}/using_identifiers.vtt (100%) rename tests/{subtitles => samples}/youtube_dl.vtt (100%) create mode 100644 tests/test_models.py create mode 100644 tests/test_sbv.py delete mode 100644 tests/test_sbv_parser.py delete mode 100644 tests/test_srt_parser.py create mode 100644 tests/test_utils.py create mode 100644 tests/test_vtt.py delete mode 100644 tests/test_webvtt_parser.py create mode 100644 webvtt/models.py delete mode 100644 webvtt/parsers.py create mode 100644 webvtt/sbv.py create mode 100644 webvtt/srt.py delete mode 100644 webvtt/structures.py create mode 100644 webvtt/utils.py create mode 100644 webvtt/vtt.py delete mode 100644 webvtt/writers.py diff --git a/LICENSE b/LICENSE index 3a702d9..16ad61a 100644 --- a/LICENSE +++ b/LICENSE @@ -1,3 +1,5 @@ +MIT License + Copyright (c) 2016 Alejandro Mendez Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/setup.cfg b/setup.cfg index 2233487..0fdf338 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,9 @@ [metadata] description_file = README.rst +[flake8] +doctests = true + [tox:tox] envlist = codestyle,py,coverage @@ -9,7 +12,7 @@ source = webvtt branch = true [coverage:report] -fail_under = 97 +fail_under = 100 [testenv] deps = @@ -21,6 +24,7 @@ commands = [testenv:codestyle] deps = flake8 + flake8-docstrings mypy commands = flake8 webvtt setup.py diff --git a/setup.py b/setup.py index 3219c24..6bae26d 100644 --- a/setup.py +++ b/setup.py @@ -1,19 +1,22 @@ -import io -import re -from setuptools import setup, find_packages +"""webvtt-py setuptools configuration.""" -with io.open('README.rst', 'r', encoding='utf-8') as f: - readme = f.read() +import re +import 
pathlib -with io.open('webvtt/__init__.py', 'rt', encoding='utf-8') as f: - version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1) +from setuptools import setup, find_packages +version = ( + re.search( + r'__version__ = \'(.*?)\'', + pathlib.Path('webvtt/__init__.py').read_text() + ).group(1) + ) setup( name='webvtt-py', version=version, description='WebVTT reader, writer and segmenter', - long_description=readme, + long_description=pathlib.Path('README.rst').read_text(), author='Alejandro Mendez', author_email='amendez23@gmail.com', url='https://github.com/glut23/webvtt-py', diff --git a/tests/generic.py b/tests/generic.py deleted file mode 100644 index 0f84ae3..0000000 --- a/tests/generic.py +++ /dev/null @@ -1,10 +0,0 @@ -import os -import unittest - - -class GenericParserTestCase(unittest.TestCase): - - SUBTITLES_DIR = os.path.join(os.path.dirname(__file__), 'subtitles') - - def _get_file(self, filename): - return os.path.join(self.SUBTITLES_DIR, filename) diff --git a/tests/subtitles/captions_with_bom.vtt b/tests/samples/captions_with_bom.vtt similarity index 100% rename from tests/subtitles/captions_with_bom.vtt rename to tests/samples/captions_with_bom.vtt diff --git a/tests/subtitles/comments.vtt b/tests/samples/comments.vtt similarity index 100% rename from tests/subtitles/comments.vtt rename to tests/samples/comments.vtt diff --git a/tests/subtitles/cue_tags.vtt b/tests/samples/cue_tags.vtt similarity index 100% rename from tests/subtitles/cue_tags.vtt rename to tests/samples/cue_tags.vtt diff --git a/tests/subtitles/empty.vtt b/tests/samples/empty.vtt similarity index 100% rename from tests/subtitles/empty.vtt rename to tests/samples/empty.vtt diff --git a/tests/subtitles/invalid.vtt b/tests/samples/invalid.vtt similarity index 100% rename from tests/subtitles/invalid.vtt rename to tests/samples/invalid.vtt diff --git a/tests/subtitles/invalid_format.sbv b/tests/samples/invalid_format.sbv similarity index 100% rename from 
tests/subtitles/invalid_format.sbv rename to tests/samples/invalid_format.sbv diff --git a/tests/subtitles/invalid_format1.srt b/tests/samples/invalid_format1.srt similarity index 100% rename from tests/subtitles/invalid_format1.srt rename to tests/samples/invalid_format1.srt diff --git a/tests/subtitles/invalid_format2.srt b/tests/samples/invalid_format2.srt similarity index 100% rename from tests/subtitles/invalid_format2.srt rename to tests/samples/invalid_format2.srt diff --git a/tests/subtitles/invalid_format3.srt b/tests/samples/invalid_format3.srt similarity index 100% rename from tests/subtitles/invalid_format3.srt rename to tests/samples/invalid_format3.srt diff --git a/tests/subtitles/invalid_format4.srt b/tests/samples/invalid_format4.srt similarity index 100% rename from tests/subtitles/invalid_format4.srt rename to tests/samples/invalid_format4.srt diff --git a/tests/subtitles/invalid_timeframe.sbv b/tests/samples/invalid_timeframe.sbv similarity index 100% rename from tests/subtitles/invalid_timeframe.sbv rename to tests/samples/invalid_timeframe.sbv diff --git a/tests/subtitles/invalid_timeframe.srt b/tests/samples/invalid_timeframe.srt similarity index 100% rename from tests/subtitles/invalid_timeframe.srt rename to tests/samples/invalid_timeframe.srt diff --git a/tests/subtitles/invalid_timeframe.vtt b/tests/samples/invalid_timeframe.vtt similarity index 100% rename from tests/subtitles/invalid_timeframe.vtt rename to tests/samples/invalid_timeframe.vtt diff --git a/tests/subtitles/invalid_timeframe_in_cue_text.vtt b/tests/samples/invalid_timeframe_in_cue_text.vtt similarity index 100% rename from tests/subtitles/invalid_timeframe_in_cue_text.vtt rename to tests/samples/invalid_timeframe_in_cue_text.vtt diff --git a/tests/subtitles/metadata_headers.vtt b/tests/samples/metadata_headers.vtt similarity index 100% rename from tests/subtitles/metadata_headers.vtt rename to tests/samples/metadata_headers.vtt diff --git 
a/tests/subtitles/metadata_headers_multiline.vtt b/tests/samples/metadata_headers_multiline.vtt similarity index 100% rename from tests/subtitles/metadata_headers_multiline.vtt rename to tests/samples/metadata_headers_multiline.vtt diff --git a/tests/subtitles/missing_caption_text.sbv b/tests/samples/missing_caption_text.sbv similarity index 100% rename from tests/subtitles/missing_caption_text.sbv rename to tests/samples/missing_caption_text.sbv diff --git a/tests/subtitles/missing_caption_text.srt b/tests/samples/missing_caption_text.srt similarity index 100% rename from tests/subtitles/missing_caption_text.srt rename to tests/samples/missing_caption_text.srt diff --git a/tests/subtitles/missing_caption_text.vtt b/tests/samples/missing_caption_text.vtt similarity index 100% rename from tests/subtitles/missing_caption_text.vtt rename to tests/samples/missing_caption_text.vtt diff --git a/tests/subtitles/missing_timeframe.sbv b/tests/samples/missing_timeframe.sbv similarity index 100% rename from tests/subtitles/missing_timeframe.sbv rename to tests/samples/missing_timeframe.sbv diff --git a/tests/subtitles/missing_timeframe.srt b/tests/samples/missing_timeframe.srt similarity index 100% rename from tests/subtitles/missing_timeframe.srt rename to tests/samples/missing_timeframe.srt diff --git a/tests/subtitles/missing_timeframe.vtt b/tests/samples/missing_timeframe.vtt similarity index 100% rename from tests/subtitles/missing_timeframe.vtt rename to tests/samples/missing_timeframe.vtt diff --git a/tests/subtitles/netflix_chicas_del_cable.vtt b/tests/samples/netflix_chicas_del_cable.vtt similarity index 100% rename from tests/subtitles/netflix_chicas_del_cable.vtt rename to tests/samples/netflix_chicas_del_cable.vtt diff --git a/tests/subtitles/no_captions.vtt b/tests/samples/no_captions.vtt similarity index 100% rename from tests/subtitles/no_captions.vtt rename to tests/samples/no_captions.vtt diff --git a/tests/subtitles/one_caption.srt 
b/tests/samples/one_caption.srt similarity index 100% rename from tests/subtitles/one_caption.srt rename to tests/samples/one_caption.srt diff --git a/tests/subtitles/one_caption.vtt b/tests/samples/one_caption.vtt similarity index 100% rename from tests/subtitles/one_caption.vtt rename to tests/samples/one_caption.vtt diff --git a/tests/subtitles/sample.sbv b/tests/samples/sample.sbv similarity index 100% rename from tests/subtitles/sample.sbv rename to tests/samples/sample.sbv diff --git a/tests/subtitles/sample.srt b/tests/samples/sample.srt similarity index 100% rename from tests/subtitles/sample.srt rename to tests/samples/sample.srt diff --git a/tests/subtitles/sample.vtt b/tests/samples/sample.vtt similarity index 100% rename from tests/subtitles/sample.vtt rename to tests/samples/sample.vtt diff --git a/tests/subtitles/styles.vtt b/tests/samples/styles.vtt similarity index 100% rename from tests/subtitles/styles.vtt rename to tests/samples/styles.vtt diff --git a/tests/subtitles/styles_with_comments.vtt b/tests/samples/styles_with_comments.vtt similarity index 100% rename from tests/subtitles/styles_with_comments.vtt rename to tests/samples/styles_with_comments.vtt diff --git a/tests/subtitles/two_captions.sbv b/tests/samples/two_captions.sbv similarity index 100% rename from tests/subtitles/two_captions.sbv rename to tests/samples/two_captions.sbv diff --git a/tests/subtitles/using_identifiers.vtt b/tests/samples/using_identifiers.vtt similarity index 100% rename from tests/subtitles/using_identifiers.vtt rename to tests/samples/using_identifiers.vtt diff --git a/tests/subtitles/youtube_dl.vtt b/tests/samples/youtube_dl.vtt similarity index 100% rename from tests/subtitles/youtube_dl.vtt rename to tests/samples/youtube_dl.vtt diff --git a/tests/test_cli.py b/tests/test_cli.py index bb31353..ca42cf7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -5,19 +5,18 @@ import textwrap from webvtt.cli import main -import subprocess + class 
CLITestCase(unittest.TestCase): def test_cli(self): - vtt_file = pathlib.Path(__file__).resolve().parent / 'subtitles' / 'sample.vtt' + vtt_file = ( + pathlib.Path(__file__).resolve().parent + / 'samples' / 'sample.vtt' + ) with tempfile.TemporaryDirectory() as temp_dir: - - # result = subprocess.run(['python', 'webvtt/cli.py', *['segment', str(vtt_file.resolve()), '-o', temp_dir]], capture_output=True, - # text=True) - main(['segment', str(vtt_file.resolve()), '-o', temp_dir]) _, dirs, files = next(os.walk(temp_dir)) @@ -181,4 +180,3 @@ def test_cli(self): Caption text #16 ''' ).lstrip()) - diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..d36ebfd --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,312 @@ +import unittest +from datetime import time + +from webvtt import Caption, Style +from webvtt.errors import MalformedCaptionError + + +class TestCaption(unittest.TestCase): + + def test_instantiation(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + self.assertEqual(caption.start, '00:00:07.000') + self.assertEqual(caption.end, '00:00:11.890') + self.assertEqual(caption.text, 'Hello test!') + self.assertEqual(caption.identifier, 'A test caption') + + def test_timestamp_accept_time(self): + caption = Caption( + start=time(hour=0, minute=0, second=7, microsecond=0), + end=time(hour=0, minute=0, second=11, microsecond=890000), + text='Hello test!', + identifier='A test caption' + ) + self.assertEqual(caption.start, '00:00:07.000') + self.assertEqual(caption.end, '00:00:11.890') + + def test_identifier_is_optional(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + ) + self.assertIsNone(caption.identifier) + + def test_multi_lines(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!\nThis is the second line', + identifier='A test caption' + ) + 
self.assertEqual(caption.text, 'Hello test!\nThis is the second line') + self.assertListEqual( + caption.lines, + ['Hello test!', 'This is the second line'] + ) + + def test_multi_lines_accepts_list(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text=['Hello test!', 'This is the second line'], + identifier='A test caption' + ) + self.assertEqual(caption.text, 'Hello test!\nThis is the second line') + self.assertListEqual( + caption.lines, + ['Hello test!', 'This is the second line'] + ) + + def test_cuetags(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text=[ + 'Hello test!', + 'This is the second line' + ], + identifier='A test caption' + ) + self.assertEqual(caption.text, 'Hello test!\nThis is the second line') + self.assertEqual( + caption.raw_text, + 'Hello test!\n' + 'This is the second line' + ) + + def test_in_seconds(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text=['Hello test!', 'This is the second line'], + identifier='A test caption' + ) + self.assertEqual(caption.start_in_seconds, 7) + self.assertEqual(caption.end_in_seconds, 11) + + def test_wrong_start_timestamp(self): + self.assertRaises( + MalformedCaptionError, + Caption, + start='1234', + end='00:00:11.890', + text='Hello Test!' + ) + + def test_wrong_type_start_timestamp(self): + self.assertRaises( + TypeError, + Caption, + start=1234, + end='00:00:11.890', + text='Hello Test!' + ) + + def test_wrong_end_timestamp(self): + self.assertRaises( + MalformedCaptionError, + Caption, + start='00:00:07.000', + end='1234', + text='Hello Test!' + ) + + def test_wrong_type_end_timestamp(self): + self.assertRaises( + TypeError, + Caption, + start='00:00:07.000', + end=1234, + text='Hello Test!' 
+ ) + + def test_equality(self): + caption1 = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + + caption2 = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + + self.assertTrue(caption1 == caption2) + + caption1 = Caption( + start='00:00:02.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + + caption2 = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + + self.assertFalse(caption1 == caption2) + + self.assertFalse( + Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) == 1234 + ) + + def test_repr(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + + self.assertEqual( + repr(caption), + "" + ) + + def test_str(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + + self.assertEqual( + str(caption), + '00:00:07.000 00:00:11.890 Hello test!' 
+ ) + + def test_accept_comments(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + caption.comments.append('One comment') + caption.comments.append('Another comment') + + self.assertListEqual( + caption.comments, + ['One comment', 'Another comment'] + ) + + def test_timestamp_update(self): + c = Caption('00:00:00.500', '00:00:07.000') + c.start = '00:00:01.750' + c.end = '00:00:08.250' + + self.assertEqual(c.start, '00:00:01.750') + self.assertEqual(c.end, '00:00:08.250') + + def test_timestamp_format(self): + c = Caption('01:02:03.400', '02:03:04.500') + self.assertEqual(c.start, '01:02:03.400') + self.assertEqual(c.end, '02:03:04.500') + + c = Caption('02:03.400', '03:04.500') + self.assertEqual(c.start, '00:02:03.400') + self.assertEqual(c.end, '00:03:04.500') + + def test_update_text(self): + c = Caption(text='Caption line #1') + c.text = 'Caption line #1 updated' + self.assertEqual( + c.text, + 'Caption line #1 updated' + ) + + def test_update_text_multiline(self): + c = Caption(text='Caption line #1') + c.text = 'Caption line #1\nCaption line #2' + + self.assertEqual( + len(c.lines), + 2 + ) + + self.assertEqual( + c.text, + 'Caption line #1\nCaption line #2' + ) + + def test_update_text_wrong_type(self): + c = Caption(text='Caption line #1') + + self.assertRaises( + AttributeError, + setattr, + c, + 'text', + 123 + ) + + def test_manipulate_lines(self): + c = Caption(text=['Caption line #1', 'Caption line #2']) + c.lines[0] = 'Caption line #1 updated' + self.assertEqual( + c.lines[0], + 'Caption line #1 updated' + ) + + def test_malformed_start_timestamp(self): + self.assertRaises( + MalformedCaptionError, + Caption, + '01:00' + ) + + +class TestStyle(unittest.TestCase): + + def test_instantiation(self): + style = Style(text='::cue(b) {\ncolor: peachpuff;\n}') + self.assertEqual(style.text, '::cue(b) {\ncolor: peachpuff;\n}') + self.assertListEqual( + style.lines, + ['::cue(b) 
{', 'color: peachpuff;', '}'] + ) + + def test_text_accept_list_of_strings(self): + style = Style(text=['::cue(b) {', 'color: peachpuff;', '}']) + self.assertEqual(style.text, '::cue(b) {\ncolor: peachpuff;\n}') + self.assertListEqual( + style.lines, + ['::cue(b) {', 'color: peachpuff;', '}'] + ) + + def test_accept_comments(self): + style = Style(text='::cue(b) {\ncolor: peachpuff;\n}') + style.comments.append('One comment') + style.comments.append('Another comment') + + self.assertListEqual( + style.comments, + ['One comment', 'Another comment'] + ) + + def test_get_text(self): + style = Style(['::cue(b) {', ' color: peachpuff;', '}']) + self.assertEqual( + style.text, + '::cue(b) {\n color: peachpuff;\n}' + ) diff --git a/tests/test_sbv.py b/tests/test_sbv.py new file mode 100644 index 0000000..79d9dff --- /dev/null +++ b/tests/test_sbv.py @@ -0,0 +1,120 @@ +import unittest +import textwrap +from datetime import time + +from webvtt import sbv +from webvtt.errors import MalformedFileError +from webvtt.models import Caption + + +class TestSBVCueBlock(unittest.TestCase): + + def test_is_valid(self): + self.assertTrue(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500,00:00:07.000 + Caption #1 + ''').strip().split('\n')) + ) + + self.assertTrue(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500,00:00:07.000 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n')) + ) + + self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500 --> 00:00:07.000 + Caption #1 + ''').strip().split('\n')) + ) + + self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 1 + 00:00:00.500,00:00:07.000 + Caption #1 + ''').strip().split('\n')) + ) + + self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 1 + 00:00:00.500,00:00:07.000 + ''').strip().split('\n')) + ) + + self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 1 + Caption #1 + ''').strip().split('\n')) + ) + + 
self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + Caption #1 + ''').strip().split('\n')) + ) + + self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500,00:00:07.000 + ''').strip().split('\n')) + ) + + def test_from_lines(self): + cue_block = sbv.SBVCueBlock.from_lines(textwrap.dedent(''' + 00:00:00.500,00:00:07.000 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n') + ) + self.assertEqual( + cue_block.start, + time(hour=0, minute=0, second=0, microsecond=500000) + ) + self.assertEqual( + cue_block.end, + time(hour=0, minute=0, second=7, microsecond=0) + ) + self.assertEqual( + cue_block.payload, + ['Caption #1 line 1', 'Caption #1 line 2'] + ) + + +class TestSBVModule(unittest.TestCase): + + def test_parse_invalid_format(self): + self.assertRaises( + MalformedFileError, + sbv.parse, + textwrap.dedent(''' + 1 + 00:00:00.500,00:00:07.000 + Caption text #1 + + 00:00:07.000,00:00:11.890 + Caption text #2 + ''').strip().split('\n') + ) + + def test_parse_captions(self): + captions = sbv.parse( + textwrap.dedent(''' + 00:00:00.500,00:00:07.000 + Caption #1 + + 00:00:07.000,00:00:11.890 + Caption #2 line 1 + Caption #2 line 2 + ''').strip().split('\n') + ) + self.assertEqual(len(captions), 2) + self.assertIsInstance(captions[0], Caption) + self.assertIsInstance(captions[1], Caption) + self.assertEqual( + str(captions[0]), + '00:00:00.500 00:00:07.000 Caption #1' + ) + self.assertEqual( + str(captions[1]), + r'00:00:07.000 00:00:11.890 Caption #2 line 1\n' + 'Caption #2 line 2' + ) diff --git a/tests/test_sbv_parser.py b/tests/test_sbv_parser.py deleted file mode 100644 index 77cc7bb..0000000 --- a/tests/test_sbv_parser.py +++ /dev/null @@ -1,66 +0,0 @@ -import webvtt - -from .generic import GenericParserTestCase - - -class SBVParserTestCase(GenericParserTestCase): - - def test_sbv_parse_empty_file(self): - self.assertRaises( - webvtt.errors.MalformedFileError, - webvtt.from_sbv, - self._get_file('empty.vtt') # We reuse 
this file as it is empty and serves the purpose. - ) - - def test_sbv_invalid_format(self): - self.assertRaises( - webvtt.errors.MalformedFileError, - webvtt.from_sbv, - self._get_file('invalid_format.sbv') - ) - - def test_sbv_total_length(self): - self.assertEqual( - webvtt.from_sbv(self._get_file('sample.sbv')).total_length, - 16 - ) - - def test_sbv_parse_captions(self): - self.assertEqual( - len(webvtt.from_srt(self._get_file('sample.srt')).captions), - 5 - ) - - def test_sbv_missing_timeframe_line(self): - self.assertEqual(len(webvtt.from_sbv(self._get_file('missing_timeframe.sbv')).captions), 4) - - def test_sbv_missing_caption_text(self): - self.assertTrue(webvtt.from_sbv(self._get_file('missing_caption_text.sbv')).captions) - - def test_sbv_invalid_timestamp(self): - self.assertEqual(len(webvtt.from_sbv(self._get_file('invalid_timeframe.sbv')).captions), 4) - - def test_sbv_timestamps_format(self): - vtt = webvtt.from_sbv(self._get_file('sample.sbv')) - self.assertEqual(vtt.captions[1].start, '00:00:11.378') - self.assertEqual(vtt.captions[1].end, '00:00:12.305') - - def test_sbv_timestamps_in_seconds(self): - vtt = webvtt.from_sbv(self._get_file('sample.sbv')) - self.assertEqual(vtt.captions[1].start_in_seconds, 11) - self.assertEqual(vtt.captions[1].end_in_seconds, 12) - - def test_sbv_get_caption_text(self): - vtt = webvtt.from_sbv(self._get_file('sample.sbv')) - self.assertEqual(vtt.captions[1].text, 'Caption text #2') - - def test_sbv_get_caption_text_multiline(self): - vtt = webvtt.from_sbv(self._get_file('sample.sbv')) - self.assertEqual( - vtt.captions[2].text, - 'Caption text #3 (line 1)\nCaption text #3 (line 2)' - ) - self.assertListEqual( - vtt.captions[2].lines, - ['Caption text #3 (line 1)', 'Caption text #3 (line 2)'] - ) diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py index b5ce88e..3d290a9 100644 --- a/tests/test_segmenter.py +++ b/tests/test_segmenter.py @@ -1,179 +1,330 @@ import os import unittest -from shutil import 
rmtree +import tempfile +import pathlib +import textwrap -from webvtt import WebVTTSegmenter -from webvtt import WebVTT +from webvtt import segmenter -BASE_DIR = os.path.dirname(__file__) -SUBTITLES_DIR = os.path.join(BASE_DIR, 'subtitles') -OUTPUT_DIR = os.path.join(BASE_DIR, 'output') +PATH_TO_SAMPLES = pathlib.Path(__file__).resolve().parent / 'samples' -class TestWebVTTSegmenter(unittest.TestCase): +class TestSegmenter(unittest.TestCase): def setUp(self): - self.segmenter = WebVTTSegmenter() + self.temp_dir = tempfile.TemporaryDirectory() def tearDown(self): - if os.path.exists(OUTPUT_DIR): - rmtree(OUTPUT_DIR) - - @staticmethod - def _get_file(filename: str) -> str: - return os.path.join(SUBTITLES_DIR, filename) - - def test_total_segments(self): - # segment with default 10 seconds - self.segmenter.segment(self._get_file('sample.vtt'), OUTPUT_DIR) - self.assertEqual(self.segmenter.total_segments, 7) - - # segment with custom 30 seconds - self.segmenter.segment(self._get_file('sample.vtt'), OUTPUT_DIR, 30) - self.assertEqual(self.segmenter.total_segments, 3) - - def test_output_folder_is_created(self): - self.assertFalse(os.path.exists(OUTPUT_DIR)) - self.segmenter.segment(self._get_file('sample.vtt'), OUTPUT_DIR) - self.assertTrue(os.path.exists(OUTPUT_DIR)) - - def test_segmentation_files_exist(self): - self.segmenter.segment(self._get_file('sample.vtt'), OUTPUT_DIR) - for i in range(7): - self.assertTrue( - os.path.exists(os.path.join(OUTPUT_DIR, - f'fileSequence{i}.webvtt' - ) - ) + self.temp_dir.cleanup() + + def test_segmentation_with_defaults(self): + segmenter.segment(PATH_TO_SAMPLES / 'sample.vtt', self.temp_dir.name) + + _, dirs, files = next(os.walk(self.temp_dir.name)) + + self.assertEqual(len(dirs), 0) + self.assertEqual(len(files), 8) + + for expected_file in ('prog_index.m3u8', + 'fileSequence0.webvtt', + 'fileSequence1.webvtt', + 'fileSequence2.webvtt', + 'fileSequence3.webvtt', + 'fileSequence4.webvtt', + 'fileSequence5.webvtt', + 
'fileSequence6.webvtt', + ): + self.assertIn(expected_file, files) + + output_path = pathlib.Path(self.temp_dir.name) + + self.assertEqual( + (output_path / 'prog_index.m3u8').read_text(), + textwrap.dedent( + ''' + #EXTM3U + #EXT-X-TARGETDURATION:10 + #EXT-X-VERSION:3 + #EXT-X-PLAYLIST-TYPE:VOD + #EXTINF:30.00000 + fileSequence0.webvtt + #EXTINF:30.00000 + fileSequence1.webvtt + #EXTINF:30.00000 + fileSequence2.webvtt + #EXTINF:30.00000 + fileSequence3.webvtt + #EXTINF:30.00000 + fileSequence4.webvtt + #EXTINF:30.00000 + fileSequence5.webvtt + #EXTINF:30.00000 + fileSequence6.webvtt + #EXT-X-ENDLIST + ''' + ).lstrip()) + self.assertEqual( + (output_path / 'fileSequence0.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + ''' + ).lstrip()) + self.assertEqual( + (output_path / 'fileSequence1.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + + 00:00:11.890 --> 00:00:16.320 + Caption text #3 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + ''' + ).lstrip()) + self.assertEqual( + (output_path / 'fileSequence2.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + + 00:00:21.580 --> 00:00:23.880 + Caption text #5 + + 00:00:23.880 --> 00:00:27.280 + Caption text #6 + + 00:00:27.280 --> 00:00:30.280 + Caption text #7 + ''' + ).lstrip()) + self.assertEqual( + (output_path / 'fileSequence3.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:27.280 --> 00:00:30.280 + Caption text #7 + + 00:00:30.280 --> 00:00:36.510 + Caption text #8 + + 00:00:36.510 --> 00:00:38.870 + Caption text #9 + + 00:00:38.870 --> 00:00:45.000 + Caption text 
#10 + ''' + ).lstrip()) + self.assertEqual( + (output_path / 'fileSequence4.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:38.870 --> 00:00:45.000 + Caption text #10 + + 00:00:45.000 --> 00:00:47.000 + Caption text #11 + + 00:00:47.000 --> 00:00:50.970 + Caption text #12 + ''' + ).lstrip()) + self.assertEqual( + (output_path / 'fileSequence5.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:47.000 --> 00:00:50.970 + Caption text #12 + + 00:00:50.970 --> 00:00:54.440 + Caption text #13 + + 00:00:54.440 --> 00:00:58.600 + Caption text #14 + + 00:00:58.600 --> 00:01:01.350 + Caption text #15 + ''' + ).lstrip()) + self.assertEqual( + (output_path / 'fileSequence6.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:58.600 --> 00:01:01.350 + Caption text #15 + + 00:01:01.350 --> 00:01:04.300 + Caption text #16 + ''' + ).lstrip()) + + def test_segmentation_with_custom_values(self): + segmenter.segment( + webvtt_path=PATH_TO_SAMPLES / 'sample.vtt', + output=self.temp_dir.name, + seconds=30, + mpegts=800000 ) - self.assertTrue(os.path.exists(os.path.join(OUTPUT_DIR, - 'prog_index.m3u8' - ))) - - def test_segmentation(self): - filepath = self._get_file('sample.vtt') - self.segmenter.segment(filepath, OUTPUT_DIR) - - webvtt = WebVTT.read(filepath) - - # segment 1 should have caption 1 and 2 - self.assertEqual(len(self.segmenter.segments[0]), 2) - self.assertIn(webvtt.captions[0], self.segmenter.segments[0]) - self.assertIn(webvtt.captions[1], self.segmenter.segments[0]) - # segment 2 should have caption 2 again (overlap), 3 and 4 - self.assertEqual(len(self.segmenter.segments[1]), 3) - self.assertIn(webvtt.captions[2], self.segmenter.segments[1]) - self.assertIn(webvtt.captions[3], self.segmenter.segments[1]) - # segment 3 should have caption 4 again (overlap), 5, 6 and 7 - 
self.assertEqual(len(self.segmenter.segments[2]), 4) - self.assertIn(webvtt.captions[3], self.segmenter.segments[2]) - self.assertIn(webvtt.captions[4], self.segmenter.segments[2]) - self.assertIn(webvtt.captions[5], self.segmenter.segments[2]) - self.assertIn(webvtt.captions[6], self.segmenter.segments[2]) - # segment 4 should have caption 7 again (overlap), 8, 9 and 10 - self.assertEqual(len(self.segmenter.segments[3]), 4) - self.assertIn(webvtt.captions[6], self.segmenter.segments[3]) - self.assertIn(webvtt.captions[7], self.segmenter.segments[3]) - self.assertIn(webvtt.captions[8], self.segmenter.segments[3]) - self.assertIn(webvtt.captions[9], self.segmenter.segments[3]) - # segment 5 should have caption 10 again (overlap), 11 and 12 - self.assertEqual(len(self.segmenter.segments[4]), 3) - self.assertIn(webvtt.captions[9], self.segmenter.segments[4]) - self.assertIn(webvtt.captions[10], self.segmenter.segments[4]) - self.assertIn(webvtt.captions[11], self.segmenter.segments[4]) - # segment 6 should have caption 12 again (overlap), 13, 14 and 15 - self.assertEqual(len(self.segmenter.segments[5]), 4) - self.assertIn(webvtt.captions[11], self.segmenter.segments[5]) - self.assertIn(webvtt.captions[12], self.segmenter.segments[5]) - self.assertIn(webvtt.captions[13], self.segmenter.segments[5]) - self.assertIn(webvtt.captions[14], self.segmenter.segments[5]) - # segment 7 should have caption 15 again (overlap) and 16 - self.assertEqual(len(self.segmenter.segments[6]), 2) - self.assertIn(webvtt.captions[14], self.segmenter.segments[6]) - self.assertIn(webvtt.captions[15], self.segmenter.segments[6]) - - def test_segment_content(self): - self.segmenter.segment(self._get_file('sample.vtt'), OUTPUT_DIR, 10) - - with open( - os.path.join(OUTPUT_DIR, 'fileSequence0.webvtt'), - 'r', - encoding='utf-8' - ) as f: - lines = [line.rstrip() for line in f.readlines()] - - expected_lines = [ - 'WEBVTT', - 'X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000', - '', - '00:00:00.500 
--> 00:00:07.000', - 'Caption text #1', - '', - '00:00:07.000 --> 00:00:11.890', - 'Caption text #2' - ] - - self.assertListEqual(lines, expected_lines) - - def test_manifest_content(self): - self.segmenter.segment(self._get_file('sample.vtt'), OUTPUT_DIR, 10) - - with open( - os.path.join(OUTPUT_DIR, 'prog_index.m3u8'), - 'r', - encoding='utf-8' - ) as f: - lines = [line.rstrip() for line in f.readlines()] - - expected_lines = [ - '#EXTM3U', - '#EXT-X-TARGETDURATION:{}'.format(self.segmenter.seconds), - '#EXT-X-VERSION:3', - '#EXT-X-PLAYLIST-TYPE:VOD', - ] - - for i in range(7): - expected_lines.extend([ - '#EXTINF:30.00000', - 'fileSequence{}.webvtt'.format(i) - ]) - - expected_lines.append('#EXT-X-ENDLIST') - - for index, line in enumerate(expected_lines): - self.assertEqual(lines[index], line) - - def test_customize_mpegts(self): - self.segmenter.segment(self._get_file('sample.vtt'), - OUTPUT_DIR, - mpegts=800000 - ) - - with open( - os.path.join(OUTPUT_DIR, 'fileSequence0.webvtt'), - 'r', - encoding='utf-8' - ) as f: - lines = f.readlines() - self.assertIn('MPEGTS:800000', lines[1]) - - def test_segment_from_file(self): - self.segmenter.segment( - os.path.join(SUBTITLES_DIR, 'sample.vtt'), OUTPUT_DIR - ), - self.assertEqual(self.segmenter.total_segments, 7) + + _, dirs, files = next(os.walk(self.temp_dir.name)) + + self.assertEqual(len(dirs), 0) + self.assertEqual(len(files), 4) + + for expected_file in ('prog_index.m3u8', + 'fileSequence0.webvtt', + 'fileSequence1.webvtt', + 'fileSequence2.webvtt', + ): + self.assertIn(expected_file, files) + + output_path = pathlib.Path(self.temp_dir.name) + + self.assertEqual( + (output_path / 'prog_index.m3u8').read_text(), + textwrap.dedent( + ''' + #EXTM3U + #EXT-X-TARGETDURATION:30 + #EXT-X-VERSION:3 + #EXT-X-PLAYLIST-TYPE:VOD + #EXTINF:30.00000 + fileSequence0.webvtt + #EXTINF:30.00000 + fileSequence1.webvtt + #EXTINF:30.00000 + fileSequence2.webvtt + #EXT-X-ENDLIST + ''' + ).lstrip()) + self.assertEqual( + 
(output_path / 'fileSequence0.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:800000,LOCAL:00:00:00.000 + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + + 00:00:11.890 --> 00:00:16.320 + Caption text #3 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + + 00:00:21.580 --> 00:00:23.880 + Caption text #5 + + 00:00:23.880 --> 00:00:27.280 + Caption text #6 + + 00:00:27.280 --> 00:00:30.280 + Caption text #7 + ''' + ).lstrip()) + self.assertEqual( + (output_path / 'fileSequence1.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:800000,LOCAL:00:00:00.000 + + 00:00:27.280 --> 00:00:30.280 + Caption text #7 + + 00:00:30.280 --> 00:00:36.510 + Caption text #8 + + 00:00:36.510 --> 00:00:38.870 + Caption text #9 + + 00:00:38.870 --> 00:00:45.000 + Caption text #10 + + 00:00:45.000 --> 00:00:47.000 + Caption text #11 + + 00:00:47.000 --> 00:00:50.970 + Caption text #12 + + 00:00:50.970 --> 00:00:54.440 + Caption text #13 + + 00:00:54.440 --> 00:00:58.600 + Caption text #14 + + 00:00:58.600 --> 00:01:01.350 + Caption text #15 + ''' + ).lstrip()) + self.assertEqual( + (output_path / 'fileSequence2.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:800000,LOCAL:00:00:00.000 + + 00:00:58.600 --> 00:01:01.350 + Caption text #15 + + 00:01:01.350 --> 00:01:04.300 + Caption text #16 + ''' + ).lstrip()) def test_segment_with_no_captions(self): - self.segmenter.segment( - os.path.join(SUBTITLES_DIR, 'no_captions.vtt'), OUTPUT_DIR - ), - self.assertEqual(self.segmenter.total_segments, 0) - - def test_total_segments_readonly(self): - self.assertRaises( - AttributeError, - setattr, - WebVTTSegmenter(), - 'total_segments', - 5 - ) + segmenter.segment( + webvtt_path=PATH_TO_SAMPLES / 'no_captions.vtt', + output=self.temp_dir.name + ) + + _, dirs, files = next(os.walk(self.temp_dir.name)) + + self.assertEqual(len(dirs), 0) + 
self.assertEqual(len(files), 1) + self.assertIn('prog_index.m3u8', files) + + self.assertEqual( + (pathlib.Path(self.temp_dir.name) / 'prog_index.m3u8').read_text(), + textwrap.dedent( + ''' + #EXTM3U + #EXT-X-TARGETDURATION:10 + #EXT-X-VERSION:3 + #EXT-X-PLAYLIST-TYPE:VOD + #EXT-X-ENDLIST + ''' + ).lstrip()) + diff --git a/tests/test_srt.py b/tests/test_srt.py index 70476af..f7f7577 100644 --- a/tests/test_srt.py +++ b/tests/test_srt.py @@ -1,34 +1,161 @@ -import os -from shutil import rmtree, copy +import unittest +import io +import textwrap +from datetime import time -import webvtt +from webvtt import srt +from webvtt.errors import MalformedFileError +from webvtt.models import Caption -from .generic import GenericParserTestCase +class TestSRTCueBlock(unittest.TestCase): -BASE_DIR = os.path.dirname(__file__) -OUTPUT_DIR = os.path.join(BASE_DIR, 'output') + def test_is_valid(self): + self.assertTrue(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + 1 + 00:00:00,500 --> 00:00:07,000 + Caption #1 + ''').strip().split('\n')) + ) + self.assertTrue(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + 1 + 00:00:00,500 --> 00:00:07,000 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n')) + ) -class SRTCaptionsTestCase(GenericParserTestCase): + self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + 00:00:00,500 --> 00:00:07,000 + Caption #1 + ''').strip().split('\n')) + ) - def setUp(self): - os.makedirs(OUTPUT_DIR) + self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + 1 + 00:00:00.500 --> 00:00:07.000 + Caption #1 + ''').strip().split('\n')) + ) - def tearDown(self): - if os.path.exists(OUTPUT_DIR): - rmtree(OUTPUT_DIR) + self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + 1 + 00:00:00,500 --> 00:00:07,000 + ''').strip().split('\n')) + ) - def test_convert_from_srt_to_vtt_and_back_gives_same_file(self): - copy(self._get_file('sample.srt'), OUTPUT_DIR) + self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + 1 + Caption #1 + 
''').strip().split('\n')) + ) - vtt = webvtt.from_srt(os.path.join(OUTPUT_DIR, 'sample.srt')) - vtt.save_as_srt(os.path.join(OUTPUT_DIR, 'sample_converted.srt')) + self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + Caption #1 + ''').strip().split('\n')) + ) - with open(os.path.join(OUTPUT_DIR, 'sample.srt'), 'r', encoding='utf-8') as f: - original = f.read() + self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + 00:00:00,500 --> 00:00:07,000 + ''').strip().split('\n')) + ) - with open(os.path.join(OUTPUT_DIR, 'sample_converted.srt'), 'r', encoding='utf-8') as f: - converted = f.read() + def test_from_lines(self): + cue_block = srt.SRTCueBlock.from_lines(textwrap.dedent(''' + 1 + 00:00:00,500 --> 00:00:07,000 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n') + ) + self.assertEqual(cue_block.index, '1') + self.assertEqual( + cue_block.start, + time(hour=0, minute=0, second=0, microsecond=500000) + ) + self.assertEqual( + cue_block.end, + time(hour=0, minute=0, second=7, microsecond=0) + ) + self.assertEqual( + cue_block.payload, + ['Caption #1 line 1', 'Caption #1 line 2'] + ) - self.assertEqual(original, converted) + +class TestSRTModule(unittest.TestCase): + + def test_parse_invalid_format(self): + self.assertRaises( + MalformedFileError, + srt.parse, + textwrap.dedent(''' + 00:00:00,500 --> 00:00:07,000 + Caption text #1 + + 00:00:07,000 --> 00:00:11,890 + Caption text #2 + ''').strip().split('\n') + ) + + def test_parse_captions(self): + captions = srt.parse( + textwrap.dedent(''' + 1 + 00:00:00,500 --> 00:00:07,000 + Caption #1 + + 2 + 00:00:07,000 --> 00:00:11,890 + Caption #2 line 1 + Caption #2 line 2 + ''').strip().split('\n') + ) + self.assertEqual(len(captions), 2) + self.assertIsInstance(captions[0], Caption) + self.assertIsInstance(captions[1], Caption) + self.assertEqual( + str(captions[0]), + '00:00:00.500 00:00:07.000 Caption #1' + ) + self.assertEqual( + str(captions[1]), + r'00:00:07.000 00:00:11.890 Caption 
#2 line 1\n' + 'Caption #2 line 2' + ) + + def test_write(self): + out = io.StringIO() + captions = [ + Caption(start='00:00:00.500', + end='00:00:07.000', + text='Caption #1' + ), + Caption(start='00:00:07.000', + end='00:00:11.890', + text=['Caption #2 line 1', + 'Caption #2 line 2' + ] + ) + ] + captions[0].comments.append('Comment for the first caption') + captions[1].comments.append('Comment for the second caption') + + srt.write(out, captions) + + out.seek(0) + + self.assertEqual( + out.read(), + textwrap.dedent(''' + 1 + 00:00:00,500 --> 00:00:07,000 + Caption #1 + + 2 + 00:00:07,000 --> 00:00:11,890 + Caption #2 line 1 + Caption #2 line 2 + ''').strip() + ) diff --git a/tests/test_srt_parser.py b/tests/test_srt_parser.py deleted file mode 100644 index e09e703..0000000 --- a/tests/test_srt_parser.py +++ /dev/null @@ -1,57 +0,0 @@ -import webvtt - -from .generic import GenericParserTestCase - - -class SRTParserTestCase(GenericParserTestCase): - - def test_srt_parse_empty_file(self): - self.assertRaises( - webvtt.errors.MalformedFileError, - webvtt.from_srt, - self._get_file('empty.vtt') # We reuse this file as it is empty and serves the purpose. 
- ) - - def test_srt_invalid_format(self): - for i in range(1, 5): - self.assertRaises( - webvtt.errors.MalformedFileError, - webvtt.from_srt, - self._get_file('invalid_format{}.srt'.format(i)) - ) - - def test_srt_total_length(self): - self.assertEqual( - webvtt.from_srt(self._get_file('sample.srt')).total_length, - 23 - ) - - def test_srt_parse_captions(self): - self.assertTrue(webvtt.from_srt(self._get_file('sample.srt')).captions) - - def test_srt_missing_timeframe_line(self): - self.assertEqual(len(webvtt.from_srt(self._get_file('missing_timeframe.srt')).captions), 4) - - def test_srt_empty_caption_text(self): - self.assertTrue(webvtt.from_srt(self._get_file('missing_caption_text.srt')).captions) - - def test_srt_empty_gets_removed(self): - captions = webvtt.from_srt(self._get_file('missing_caption_text.srt')).captions - self.assertEqual(len(captions), 4) - - def test_srt_invalid_timestamp(self): - self.assertEqual(len(webvtt.from_srt(self._get_file('invalid_timeframe.srt')).captions), 4) - - def test_srt_timestamps_format(self): - vtt = webvtt.from_srt(self._get_file('sample.srt')) - self.assertEqual(vtt.captions[2].start, '00:00:11.890') - self.assertEqual(vtt.captions[2].end, '00:00:16.320') - - def test_srt_parse_get_caption_data(self): - vtt = webvtt.from_srt(self._get_file('one_caption.srt')) - self.assertEqual(vtt.captions[0].start_in_seconds, 0) - self.assertEqual(vtt.captions[0].start, '00:00:00.500') - self.assertEqual(vtt.captions[0].end_in_seconds, 7) - self.assertEqual(vtt.captions[0].end, '00:00:07.000') - self.assertEqual(vtt.captions[0].lines[0], 'Caption text #1') - self.assertEqual(len(vtt.captions[0].lines), 1) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..23fc1a2 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,26 @@ +import unittest +import tempfile + +from webvtt.utils import FileWrapper, CODEC_BOMS + + +class TestUtils(unittest.TestCase): + + def test_open_file(self): + with 
tempfile.NamedTemporaryFile('w') as f: + f.write('Hello test!') + f.flush() + with FileWrapper.open(f.name) as fw: + self.assertEqual(fw.file.read(), 'Hello test!') + self.assertEqual(fw.file.encoding, 'utf-8') + + def test_open_file_with_bom(self): + for encoding, bom in CODEC_BOMS.items(): + with tempfile.NamedTemporaryFile('wb') as f: + f.write(bom) + f.write('Hello test'.encode(encoding)) + f.flush() + + with FileWrapper.open(f.name) as fw: + self.assertEqual(fw.file.read(), 'Hello test') + self.assertEqual(fw.file.encoding, encoding) diff --git a/tests/test_vtt.py b/tests/test_vtt.py new file mode 100644 index 0000000..3f7840d --- /dev/null +++ b/tests/test_vtt.py @@ -0,0 +1,433 @@ +import unittest +import textwrap +import io + +from webvtt import vtt +from webvtt.errors import MalformedFileError +from webvtt.models import Caption, Style + + +class TestWebVTTCueBlock(unittest.TestCase): + + def test_is_valid(self): + self.assertTrue( + vtt.WebVTTCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500 --> 00:00:07.000 + Caption #1 + ''').strip().split('\n')) + ) + + self.assertTrue( + vtt.WebVTTCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500 --> 00:00:07.000 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n')) + ) + + self.assertTrue( + vtt.WebVTTCueBlock.is_valid(textwrap.dedent(''' + identifier + 00:00:00.500 --> 00:00:07.000 + Caption #1 + ''').strip().split('\n')) + ) + + self.assertTrue( + vtt.WebVTTCueBlock.is_valid(textwrap.dedent(''' + identifier + 00:00:00.500 --> 00:00:07.000 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500 00:00:07.000 + Caption #1 line 1 + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500 --> 00:00:07.000 + ''').strip().split('\n')) + ) + + def test_from_lines(self): + cue_block = vtt.WebVTTCueBlock.from_lines(textwrap.dedent(''' + identifier 
+ 00:00:00.500 --> 00:00:07.000 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n') + ) + self.assertEqual(cue_block.identifier, 'identifier') + self.assertEqual(cue_block.start, '00:00:00.500') + self.assertEqual(cue_block.end, '00:00:07.000') + self.assertListEqual( + cue_block.payload, + ['Caption #1 line 1', 'Caption #1 line 2'] + ) + + +class TestWebVTTCommentBlock(unittest.TestCase): + + def test_is_valid(self): + self.assertTrue( + vtt.WebVTTCommentBlock.is_valid(textwrap.dedent(''' + NOTE This is a one line comment + ''').strip().split('\n')) + ) + + self.assertTrue( + vtt.WebVTTCommentBlock.is_valid(textwrap.dedent(''' + NOTE + This is a another one line comment + ''').strip().split('\n')) + ) + + self.assertTrue( + vtt.WebVTTCommentBlock.is_valid(textwrap.dedent(''' + NOTE + This is a multi-line comment + taking two lines + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTCommentBlock.is_valid(textwrap.dedent(''' + This is not a comment + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTCommentBlock.is_valid(textwrap.dedent(''' + # This is not a comment + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTCommentBlock.is_valid(textwrap.dedent(''' + // This is not a comment + ''').strip().split('\n')) + ) + + def test_from_lines(self): + comment = vtt.WebVTTCommentBlock.from_lines(textwrap.dedent(''' + NOTE This is a one line comment + ''').strip().split('\n') + ) + self.assertEqual(comment.text, 'This is a one line comment') + + comment = vtt.WebVTTCommentBlock.from_lines(textwrap.dedent(''' + NOTE + This is a multi-line comment + taking two lines + ''').strip().split('\n') + ) + self.assertEqual( + comment.text, + 'This is a multi-line comment\ntaking two lines' + ) + + +class TestWebVTTStyleBlock(unittest.TestCase): + + def test_is_valid(self): + self.assertTrue( + vtt.WebVTTStyleBlock.is_valid(textwrap.dedent(''' + STYLE + ::cue { + background-image: linear-gradient(to bottom, dimgray, 
lightgray); + color: papayawhip; + } + ''').strip().split('\n')) + ) + + self.assertTrue( + vtt.WebVTTStyleBlock.is_valid(textwrap.dedent(''' + STYLE + ::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + ::cue(b) { + color: peachpuff; + } + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTStyleBlock.is_valid(textwrap.dedent(''' + STYLE + ::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + + ::cue(b) { + color: peachpuff; + } + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTStyleBlock.is_valid(textwrap.dedent(''' + STYLE + ::cue--> { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTStyleBlock.is_valid(textwrap.dedent(''' + ::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + ''').strip().split('\n')) + ) + + def test_from_lines(self): + style = vtt.WebVTTStyleBlock.from_lines(textwrap.dedent(''' + STYLE + ::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + ::cue(b) { + color: peachpuff; + } + ''').strip().split('\n') + ) + self.assertEqual( + style.text, + textwrap.dedent(''' + ::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + ::cue(b) { + color: peachpuff; + } + ''').strip() + ) + + +class TestVTTModule(unittest.TestCase): + + def test_parse_invalid_format(self): + self.assertRaises( + MalformedFileError, + vtt.parse, + textwrap.dedent(''' + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + ''').strip().split('\n') + ) + + def test_parse_captions(self): + captions, styles = vtt.parse( + textwrap.dedent(''' + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 line 1 
+ Caption text #2 line 2 + ''').strip().split('\n') + ) + self.assertEqual(len(captions), 2) + self.assertEqual(len(styles), 0) + self.assertIsInstance(captions[0], Caption) + self.assertIsInstance(captions[1], Caption) + self.assertEqual( + str(captions[0]), + '00:00:00.500 00:00:07.000 Caption text #1' + ) + self.assertEqual( + str(captions[1]), + r'00:00:07.000 00:00:11.890 Caption text #2 line 1\n' + 'Caption text #2 line 2' + ) + + def test_parse_styles(self): + captions, styles = vtt.parse( + textwrap.dedent(''' + WEBVTT + + STYLE + ::cue { + color: white; + } + + STYLE + ::cue(.important) { + color: red; + font-weight: bold; + } + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + ''').strip().split('\n') + ) + self.assertEqual(len(captions), 1) + self.assertEqual(len(styles), 2) + self.assertIsInstance(styles[0], Style) + self.assertIsInstance(styles[1], Style) + self.assertEqual( + str(styles[0].text), + textwrap.dedent(''' + ::cue { + color: white; + } + ''').strip() + ) + self.assertEqual( + str(styles[1].text), + textwrap.dedent(''' + ::cue(.important) { + color: red; + font-weight: bold; + } + ''').strip(), + ) + + def test_parse_content(self): + captions, styles = vtt.parse( + textwrap.dedent(''' + WEBVTT + + NOTE Comment for the style + + STYLE + ::cue { + color: white; + } + + NOTE Comment for the first caption + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + NOTE + Comment for the second caption + that is very long + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 line 1 + Caption text #2 line 2 + + NOTE end of file + ''').strip().split('\n') + ) + self.assertEqual(len(captions), 2) + self.assertEqual(len(styles), 1) + self.assertIsInstance(captions[0], Caption) + self.assertIsInstance(captions[1], Caption) + self.assertIsInstance(styles[0], Style) + self.assertEqual( + str(captions[0]), + '00:00:00.500 00:00:07.000 Caption text #1' + ) + self.assertEqual( + str(captions[1]), + r'00:00:07.000 00:00:11.890 Caption text #2 line 1\n' + 
'Caption text #2 line 2' + ) + self.assertEqual( + str(styles[0].text), + textwrap.dedent(''' + ::cue { + color: white; + } + ''').strip() + ) + self.assertEqual( + styles[0].comments, + ['Comment for the style'] + ) + self.assertEqual( + captions[0].comments, + ['Comment for the first caption'] + ) + self.assertEqual( + captions[1].comments, + ['Comment for the second caption\nthat is very long', + 'end of file' + ] + ) + + def test_write(self): + out = io.StringIO() + captions = [ + Caption(start='00:00:00.500', + end='00:00:07.000', + text='Caption #1' + ), + Caption(start='00:00:07.000', + end='00:00:11.890', + text=['Caption #2 line 1', + 'Caption #2 line 2' + ] + ) + ] + captions[0].comments.append('Comment for the first caption') + captions[1].comments.append('Comment for the second caption') + + vtt.write(out, captions) + + out.seek(0) + + self.assertEqual( + out.read(), + textwrap.dedent(''' + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption #1 + + 00:00:07.000 --> 00:00:11.890 + Caption #2 line 1 + Caption #2 line 2 + ''').strip() + ) + + def test_to_str(self): + captions = [ + Caption(start='00:00:00.500', + end='00:00:07.000', + text='Caption #1' + ), + Caption(start='00:00:07.000', + end='00:00:11.890', + text=['Caption #2 line 1', + 'Caption #2 line 2' + ] + ) + ] + captions[0].comments.append('Comment for the first caption') + captions[1].comments.append('Comment for the second caption') + + self.assertEqual( + vtt.to_str(captions), + textwrap.dedent(''' + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption #1 + + 00:00:07.000 --> 00:00:11.890 + Caption #2 line 1 + Caption #2 line 2 + ''').strip() + ) diff --git a/tests/test_webvtt.py b/tests/test_webvtt.py index 9bf50d4..adbe9df 100644 --- a/tests/test_webvtt.py +++ b/tests/test_webvtt.py @@ -1,252 +1,261 @@ +import unittest import os import io import textwrap -from shutil import rmtree, copy +import warnings +import tempfile +import pathlib import webvtt -from webvtt.structures import Caption, 
Style -from .generic import GenericParserTestCase +from webvtt.models import Caption, Style from webvtt.errors import MalformedFileError +PATH_TO_SAMPLES = pathlib.Path(__file__).resolve().parent / 'samples' -BASE_DIR = os.path.dirname(__file__) -OUTPUT_DIR = os.path.join(BASE_DIR, 'output') +class TestWebVTT(unittest.TestCase): -class WebVTTTestCase(GenericParserTestCase): + def test_from_string(self): + vtt = webvtt.WebVTT.from_string(textwrap.dedent(""" + WEBVTT - def setUp(self): - os.makedirs(OUTPUT_DIR) + 00:00:00.500 --> 00:00:07.000 + Caption text #1 - def tearDown(self): - rmtree(OUTPUT_DIR) + 00:00:07.000 --> 00:00:11.890 + Caption text #2 line 1 + Caption text #2 line 2 - def test_create_caption(self): - caption = Caption(start='00:00:00.500', - end='00:00:07.900', - text=['Caption test line 1', 'Caption test line 2'] - ) - self.assertEqual(caption.start, '00:00:00.500') - self.assertEqual(caption.start_in_seconds, 0) - self.assertEqual(caption.end, '00:00:07.900') - self.assertEqual(caption.end_in_seconds, 7) - self.assertEqual(caption.lines, ['Caption test line 1', 'Caption test line 2']) + 00:00:11.890 --> 00:00:16.320 + Caption text #3 - def test_create_caption_with_text(self): - caption = Caption(start='00:00:00.500', - end='00:00:07.900', - text='Caption test line 1\nCaption test line 2' - ) - self.assertEqual(caption.start, '00:00:00.500') - self.assertEqual(caption.start_in_seconds, 0) - self.assertEqual(caption.end, '00:00:07.900') - self.assertEqual(caption.end_in_seconds, 7) - self.assertEqual(caption.lines, ['Caption test line 1', 'Caption test line 2']) + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + """).strip()) + self.assertEqual(len(vtt), 4) + self.assertEqual( + str(vtt[0]), + '00:00:00.500 00:00:07.000 Caption text #1' + ) + self.assertEqual( + str(vtt[1]), + r'00:00:07.000 00:00:11.890 Caption text #2 line 1\n' + 'Caption text #2 line 2' + ) + self.assertEqual( + str(vtt[2]), + '00:00:11.890 00:00:16.320 Caption text #3' + ) + 
self.assertEqual( + str(vtt[3]), + '00:00:16.320 00:00:21.580 Caption text #4' + ) def test_write_captions(self): - copy(self._get_file('one_caption.vtt'), OUTPUT_DIR) - out = io.StringIO() - vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'one_caption.vtt')) + vtt = webvtt.read(PATH_TO_SAMPLES / 'one_caption.vtt') new_caption = Caption(start='00:00:07.000', end='00:00:11.890', - text=['New caption text line1', 'New caption text line2'] + text=['New caption text line1', + 'New caption text line2' + ] ) vtt.captions.append(new_caption) vtt.write(out) out.seek(0) - self.assertListEqual([line for line in out], - [ - 'WEBVTT\n', - '\n', - '00:00:00.500 --> 00:00:07.000\n', - 'Caption text #1\n', - '\n', - '00:00:07.000 --> 00:00:11.890\n', - 'New caption text line1\n', - 'New caption text line2' - ] - ) + self.assertEqual( + out.read(), + textwrap.dedent(''' + WEBVTT - def test_save_captions(self): - copy(self._get_file('one_caption.vtt'), OUTPUT_DIR) + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + New caption text line1 + New caption text line2 + ''').strip() + ) - vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'one_caption.vtt')) + def test_write_captions_in_srt(self): + out = io.StringIO() + vtt = webvtt.read(PATH_TO_SAMPLES / 'one_caption.vtt') new_caption = Caption(start='00:00:07.000', end='00:00:11.890', - text=['New caption text line1', 'New caption text line2'] + text=['New caption text line1', + 'New caption text line2' + ] ) vtt.captions.append(new_caption) - vtt.save() - - with open(os.path.join(OUTPUT_DIR, 'one_caption.vtt'), 'r', encoding='utf-8') as f: - self.assertListEqual([line for line in f], - [ - 'WEBVTT\n', - '\n', - '00:00:00.500 --> 00:00:07.000\n', - 'Caption text #1\n', - '\n', - '00:00:07.000 --> 00:00:11.890\n', - 'New caption text line1\n', - 'New caption text line2' - ] - ) - - def test_srt_conversion(self): - copy(self._get_file('one_caption.srt'), OUTPUT_DIR) + vtt.write(out, format='srt') - vtt = 
webvtt.from_srt(os.path.join(OUTPUT_DIR, 'one_caption.srt')) - vtt.save() + out.seek(0) + self.assertEqual( + out.read(), + textwrap.dedent(''' + 1 + 00:00:00,500 --> 00:00:07,000 + Caption text #1 + + 2 + 00:00:07,000 --> 00:00:11,890 + New caption text line1 + New caption text line2 + ''').strip() + ) + + def test_write_captions_in_unsupported_format(self): + self.assertRaises( + ValueError, + webvtt.WebVTT().write, + io.StringIO(), + format='ttt' + ) - self.assertTrue(os.path.exists(os.path.join(OUTPUT_DIR, 'one_caption.vtt'))) + def test_save_captions(self): + with tempfile.NamedTemporaryFile('w', suffix='.vtt') as f: + f.write((PATH_TO_SAMPLES / 'one_caption.vtt').read_text()) + f.flush() + + vtt = webvtt.read(f.name) + new_caption = Caption(start='00:00:07.000', + end='00:00:11.890', + text=['New caption text line1', + 'New caption text line2' + ] + ) + vtt.captions.append(new_caption) + vtt.save() + f.flush() + + self.assertEqual( + pathlib.Path(f.name).read_text(), + textwrap.dedent(''' + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + New caption text line1 + New caption text line2 + ''').strip() + ) - with open(os.path.join(OUTPUT_DIR, 'one_caption.vtt'), 'r', encoding='utf-8') as f: - self.assertListEqual([line for line in f], - [ - 'WEBVTT\n', - '\n', - '00:00:00.500 --> 00:00:07.000\n', - 'Caption text #1', - ] - ) + def test_srt_conversion(self): + with tempfile.TemporaryDirectory() as td: + with open(pathlib.Path(td) / 'one_caption.srt', 'w') as f: + f.write((PATH_TO_SAMPLES / 'one_caption.srt').read_text()) + + webvtt.from_srt( + pathlib.Path(td) / 'one_caption.srt' + ).save() + + self.assertTrue( + os.path.exists(pathlib.Path(td) / 'one_caption.vtt') + ) + self.assertEqual( + (pathlib.Path(td) / 'one_caption.vtt').read_text(), + textwrap.dedent(''' + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + ''').strip() + ) def test_sbv_conversion(self): - copy(self._get_file('two_captions.sbv'), 
OUTPUT_DIR) - - vtt = webvtt.from_sbv(os.path.join(OUTPUT_DIR, 'two_captions.sbv')) - vtt.save() - - self.assertTrue(os.path.exists(os.path.join(OUTPUT_DIR, 'two_captions.vtt'))) - - with open(os.path.join(OUTPUT_DIR, 'two_captions.vtt'), 'r', encoding='utf-8') as f: - self.assertListEqual([line for line in f], - [ - 'WEBVTT\n', - '\n', - '00:00:00.378 --> 00:00:11.378\n', - 'Caption text #1\n', - '\n', - '00:00:11.378 --> 00:00:12.305\n', - 'Caption text #2 (line 1)\n', - 'Caption text #2 (line 2)', - ] - ) + with tempfile.TemporaryDirectory() as td: + with open(pathlib.Path(td) / 'two_captions.sbv', 'w') as f: + f.write( + (PATH_TO_SAMPLES / 'two_captions.sbv').read_text() + ) + + webvtt.from_sbv( + pathlib.Path(td) / 'two_captions.sbv' + ).save() + + self.assertTrue( + os.path.exists(pathlib.Path(td) / 'two_captions.vtt') + ) + self.assertEqual( + (pathlib.Path(td) / 'two_captions.vtt').read_text(), + textwrap.dedent(''' + WEBVTT + + 00:00:00.378 --> 00:00:11.378 + Caption text #1 + + 00:00:11.378 --> 00:00:12.305 + Caption text #2 (line 1) + Caption text #2 (line 2) + ''').strip() + ) def test_save_to_other_location(self): - target_path = os.path.join(OUTPUT_DIR, 'test_folder') - os.makedirs(target_path) + with tempfile.TemporaryDirectory() as td: + webvtt.read( + PATH_TO_SAMPLES / 'one_caption.vtt' + ).save(td) - webvtt.read(self._get_file('one_caption.vtt')).save(target_path) - self.assertTrue(os.path.exists(os.path.join(target_path, 'one_caption.vtt'))) + self.assertTrue( + os.path.exists(pathlib.Path(td) / 'one_caption.vtt') + ) def test_save_specific_filename(self): - target_path = os.path.join(OUTPUT_DIR, 'test_folder') - os.makedirs(target_path) - output_file = os.path.join(target_path, 'custom_name.vtt') + with tempfile.TemporaryDirectory() as td: + webvtt.read( + PATH_TO_SAMPLES / 'one_caption.vtt' + ).save( + pathlib.Path(td) / 'one_caption_new.vtt' + ) - webvtt.read(self._get_file('one_caption.vtt')).save(output_file) - 
self.assertTrue(os.path.exists(output_file)) + self.assertTrue( + os.path.exists(pathlib.Path(td) / 'one_caption_new.vtt') + ) def test_save_specific_filename_no_extension(self): - target_path = os.path.join(OUTPUT_DIR, 'test_folder') - os.makedirs(target_path) - output_file = os.path.join(target_path, 'custom_name') - - webvtt.read(self._get_file('one_caption.vtt')).save(output_file) - self.assertTrue(os.path.exists(os.path.join(target_path, 'custom_name.vtt'))) - - def test_caption_timestamp_update(self): - c = Caption('00:00:00.500', '00:00:07.000') - c.start = '00:00:01.750' - c.end = '00:00:08.250' - - self.assertEqual(c.start, '00:00:01.750') - self.assertEqual(c.end, '00:00:08.250') - - def test_caption_timestamp_format(self): - c = Caption('01:02:03.400', '02:03:04.500') - self.assertEqual(c.start, '01:02:03.400') - self.assertEqual(c.end, '02:03:04.500') - - c = Caption('02:03.400', '03:04.500') - self.assertEqual(c.start, '00:02:03.400') - self.assertEqual(c.end, '00:03:04.500') - - def test_caption_text(self): - c = Caption(text=['Caption line #1', 'Caption line #2']) - self.assertEqual( - c.text, - 'Caption line #1\nCaption line #2' - ) - - def test_caption_receive_text(self): - c = Caption(text='Caption line #1\nCaption line #2') - - self.assertEqual( - len(c.lines), - 2 - ) - self.assertEqual( - c.text, - 'Caption line #1\nCaption line #2' - ) - - def test_update_text(self): - c = Caption(text='Caption line #1') - c.text = 'Caption line #1 updated' - self.assertEqual( - c.text, - 'Caption line #1 updated' - ) - - def test_update_text_multiline(self): - c = Caption(text='Caption line #1') - c.text = 'Caption line #1\nCaption line #2' - - self.assertEqual( - len(c.lines), - 2 - ) - - self.assertEqual( - c.text, - 'Caption line #1\nCaption line #2' - ) - - def test_update_text_wrong_type(self): - c = Caption(text='Caption line #1') - - self.assertRaises( - AttributeError, - setattr, - c, - 'text', - 123 - ) - - def test_manipulate_lines(self): - c = 
Caption(text=['Caption line #1', 'Caption line #2']) - c.lines[0] = 'Caption line #1 updated' - self.assertEqual( - c.lines[0], - 'Caption line #1 updated' - ) - - def test_read_file_buffer(self): - with open(self._get_file('sample.vtt'), 'r', encoding='utf-8') as f: - vtt = webvtt.from_buffer(f) - self.assertIsInstance(vtt.captions, list) + with tempfile.TemporaryDirectory() as td: + webvtt.read( + PATH_TO_SAMPLES / 'one_caption.vtt' + ).save( + pathlib.Path(td) / 'one_caption_new' + ) + + self.assertTrue( + os.path.exists(pathlib.Path(td) / 'one_caption_new.vtt') + ) + + def test_from_buffer(self): + with open(PATH_TO_SAMPLES / 'sample.vtt', 'r', encoding='utf-8') as f: + self.assertIsInstance( + webvtt.from_buffer(f).captions, + list + ) + + def test_deprecated_read_buffer(self): + with open(PATH_TO_SAMPLES / 'sample.vtt', 'r', encoding='utf-8') as f: + with warnings.catch_warnings(record=True) as warns: + warnings.simplefilter('always') + vtt = webvtt.read_buffer(f) + + self.assertIsInstance(vtt.captions, list) + self.assertEqual( + 'Deprecated: use from_buffer instead.', + str(warns[-1].message) + ) def test_read_memory_buffer(self): - with open(self._get_file('sample.vtt'), 'r', encoding='utf-8') as f: - payload = f.read() + buffer = io.StringIO( + (PATH_TO_SAMPLES / 'sample.vtt').read_text() + ) - buffer = io.StringIO(payload) - vtt = webvtt.from_buffer(buffer) - self.assertIsInstance(vtt.captions, list) + self.assertIsInstance( + webvtt.from_buffer(buffer).captions, + list + ) def test_read_memory_buffer_carriage_return(self): """https://github.com/glut23/webvtt-py/issues/29""" @@ -262,8 +271,11 @@ def test_read_memory_buffer_carriage_return(self): 00:00:11.890 --> 00:00:16.320\r Caption text #3\r ''')) - vtt = webvtt.from_buffer(buffer) - self.assertEqual(len(vtt.captions), 3) + + self.assertEqual( + len(webvtt.from_buffer(buffer).captions), + 3 + ) def test_read_malformed_buffer(self): malformed_payloads = ['', 'MOCK MALFORMED CONTENT'] @@ -273,27 +285,31 
@@ def test_read_malformed_buffer(self): webvtt.from_buffer(buffer) def test_captions(self): - vtt = webvtt.read(self._get_file('sample.vtt')) - self.assertIsInstance(vtt.captions, list) + captions = webvtt.read(PATH_TO_SAMPLES / 'sample.vtt').captions + self.assertIsInstance( + captions, + list + ) + self.assertEqual(len(captions), 16) def test_sequence_iteration(self): - vtt = webvtt.read(self._get_file('sample.vtt')) + vtt = webvtt.read(PATH_TO_SAMPLES / 'sample.vtt') self.assertIsInstance(vtt[0], Caption) self.assertEqual(len(vtt), len(vtt.captions)) def test_save_no_filename(self): - vtt = webvtt.WebVTT() self.assertRaises( webvtt.errors.MissingFilenameError, - vtt.save + webvtt.WebVTT().save ) - def test_malformed_start_timestamp(self): - self.assertRaises( - webvtt.errors.MalformedCaptionError, - Caption, - '01:00' - ) + def test_save_with_path_to_dir_no_filename(self): + with tempfile.TemporaryDirectory() as td: + self.assertRaises( + webvtt.errors.MissingFilenameError, + webvtt.WebVTT().save, + td + ) def test_set_styles_from_text(self): style = Style('::cue(b) {\n color: peachpuff;\n}') @@ -302,111 +318,492 @@ def test_set_styles_from_text(self): ['::cue(b) {', ' color: peachpuff;', '}'] ) - def test_get_styles_as_text(self): - style = Style(['::cue(b) {', ' color: peachpuff;', '}']) - self.assertEqual( - style.text, - '::cue(b) {\n color: peachpuff;\n}' - ) - def test_save_identifiers(self): - copy(self._get_file('using_identifiers.vtt'), OUTPUT_DIR) - - vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'using_identifiers.vtt')) - vtt.save(os.path.join(OUTPUT_DIR, 'new_using_identifiers.vtt')) - - with open(os.path.join(OUTPUT_DIR, 'new_using_identifiers.vtt'), 'r', encoding='utf-8') as f: - lines = [line.rstrip() for line in f.readlines()] - - expected_lines = [ - 'WEBVTT', - '', - '00:00:00.500 --> 00:00:07.000', - 'Caption text #1', - '', - 'second caption', - '00:00:07.000 --> 00:00:11.890', - 'Caption text #2', - '', - '00:00:11.890 --> 00:00:16.320', - 
'Caption text #3', - '', - '4', - '00:00:16.320 --> 00:00:21.580', - 'Caption text #4', - '', - '00:00:21.580 --> 00:00:23.880', - 'Caption text #5', - '', - '00:00:23.880 --> 00:00:27.280', - 'Caption text #6' - ] - - self.assertListEqual(lines, expected_lines) + with tempfile.NamedTemporaryFile('w', suffix='.vtt') as f: + webvtt.read( + PATH_TO_SAMPLES / 'using_identifiers.vtt' + ).save( + f.name + ) + + self.assertListEqual( + pathlib.Path(f.name).read_text().splitlines(), + [ + 'WEBVTT', + '', + '00:00:00.500 --> 00:00:07.000', + 'Caption text #1', + '', + 'second caption', + '00:00:07.000 --> 00:00:11.890', + 'Caption text #2', + '', + '00:00:11.890 --> 00:00:16.320', + 'Caption text #3', + '', + '4', + '00:00:16.320 --> 00:00:21.580', + 'Caption text #4', + '', + '00:00:21.580 --> 00:00:23.880', + 'Caption text #5', + '', + '00:00:23.880 --> 00:00:27.280', + 'Caption text #6' + ] + ) def test_save_updated_identifiers(self): - copy(self._get_file('using_identifiers.vtt'), OUTPUT_DIR) - - vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'using_identifiers.vtt')) + vtt = webvtt.read(PATH_TO_SAMPLES / 'using_identifiers.vtt') vtt.captions[0].identifier = 'first caption' vtt.captions[1].identifier = None vtt.captions[3].identifier = '44' - last_caption = Caption('00:00:27.280', '00:00:29.200', 'Caption text #7') + last_caption = Caption(start='00:00:27.280', + end='00:00:29.200', + text='Caption text #7' + ) last_caption.identifier = 'last caption' vtt.captions.append(last_caption) - vtt.save(os.path.join(OUTPUT_DIR, 'new_using_identifiers.vtt')) - - with open(os.path.join(OUTPUT_DIR, 'new_using_identifiers.vtt'), 'r', encoding='utf-8') as f: - lines = [line.rstrip() for line in f.readlines()] - - expected_lines = [ - 'WEBVTT', - '', - 'first caption', - '00:00:00.500 --> 00:00:07.000', - 'Caption text #1', - '', - '00:00:07.000 --> 00:00:11.890', - 'Caption text #2', - '', - '00:00:11.890 --> 00:00:16.320', - 'Caption text #3', - '', - '44', - '00:00:16.320 --> 
00:00:21.580', - 'Caption text #4', - '', - '00:00:21.580 --> 00:00:23.880', - 'Caption text #5', - '', - '00:00:23.880 --> 00:00:27.280', - 'Caption text #6', - '', - 'last caption', - '00:00:27.280 --> 00:00:29.200', - 'Caption text #7' - ] - - self.assertListEqual(lines, expected_lines) + + with tempfile.NamedTemporaryFile('w', suffix='.vtt') as f: + vtt.save(f.name) + + self.assertListEqual( + pathlib.Path(f.name).read_text().splitlines(), + [ + 'WEBVTT', + '', + 'first caption', + '00:00:00.500 --> 00:00:07.000', + 'Caption text #1', + '', + '00:00:07.000 --> 00:00:11.890', + 'Caption text #2', + '', + '00:00:11.890 --> 00:00:16.320', + 'Caption text #3', + '', + '44', + '00:00:16.320 --> 00:00:21.580', + 'Caption text #4', + '', + '00:00:21.580 --> 00:00:23.880', + 'Caption text #5', + '', + '00:00:23.880 --> 00:00:27.280', + 'Caption text #6', + '', + 'last caption', + '00:00:27.280 --> 00:00:29.200', + 'Caption text #7' + ] + ) def test_content_formatting(self): """ Verify that content property returns the correctly formatted webvtt. 
""" captions = [ - Caption('00:00:00.500', '00:00:07.000', ['Caption test line 1', 'Caption test line 2']), - Caption('00:00:08.000', '00:00:15.000', ['Caption test line 3', 'Caption test line 4']), - ] - expected_content = textwrap.dedent(""" - WEBVTT + Caption(start='00:00:00.500', + end='00:00:07.000', + text=['Caption test line 1', 'Caption test line 2'] + ), + Caption(start='00:00:08.000', + end='00:00:15.000', + text=['Caption test line 3', 'Caption test line 4'] + ), + ] - 00:00:00.500 --> 00:00:07.000 - Caption test line 1 - Caption test line 2 + self.assertEqual( + webvtt.WebVTT(captions=captions).content, + textwrap.dedent(""" + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption test line 1 + Caption test line 2 + + 00:00:08.000 --> 00:00:15.000 + Caption test line 3 + Caption test line 4 + """).strip() + ) + + def test_repr(self): + test_file = PATH_TO_SAMPLES / 'sample.vtt' + self.assertEqual( + repr(webvtt.read(test_file)), + f'' + ) - 00:00:08.000 --> 00:00:15.000 - Caption test line 3 - Caption test line 4 + def test_str(self): + self.assertEqual( + str(webvtt.read(PATH_TO_SAMPLES / 'sample.vtt')), + textwrap.dedent(""" + 00:00:00.500 00:00:07.000 Caption text #1 + 00:00:07.000 00:00:11.890 Caption text #2 + 00:00:11.890 00:00:16.320 Caption text #3 + 00:00:16.320 00:00:21.580 Caption text #4 + 00:00:21.580 00:00:23.880 Caption text #5 + 00:00:23.880 00:00:27.280 Caption text #6 + 00:00:27.280 00:00:30.280 Caption text #7 + 00:00:30.280 00:00:36.510 Caption text #8 + 00:00:36.510 00:00:38.870 Caption text #9 + 00:00:38.870 00:00:45.000 Caption text #10 + 00:00:45.000 00:00:47.000 Caption text #11 + 00:00:47.000 00:00:50.970 Caption text #12 + 00:00:50.970 00:00:54.440 Caption text #13 + 00:00:54.440 00:00:58.600 Caption text #14 + 00:00:58.600 00:01:01.350 Caption text #15 + 00:01:01.350 00:01:04.300 Caption text #16 """).strip() - vtt = webvtt.WebVTT(captions=captions) - self.assertEqual(expected_content, vtt.content) + ) + + def 
test_parse_invalid_file(self): + self.assertRaises( + MalformedFileError, + webvtt.read, + PATH_TO_SAMPLES / 'invalid.vtt' + ) + + def test_file_not_found(self): + self.assertRaises( + FileNotFoundError, + webvtt.read, + 'nowhere' + ) + + def test_total_length(self): + self.assertEqual( + webvtt.read(PATH_TO_SAMPLES / 'sample.vtt').total_length, + 64 + ) + + def test_total_length_no_captions(self): + self.assertEqual( + webvtt.WebVTT().total_length, + 0 + ) + + def test_parse_empty_file(self): + self.assertRaises( + MalformedFileError, + webvtt.read, + PATH_TO_SAMPLES / 'empty.vtt' + ) + + def test_parse_invalid_timeframe_line(self): + good_captions = len( + webvtt.read(PATH_TO_SAMPLES / 'invalid_timeframe.vtt').captions + ) + self.assertEqual(good_captions, 6) + + def test_parse_invalid_timeframe_in_cue_text(self): + vtt = webvtt.read( + PATH_TO_SAMPLES / 'invalid_timeframe_in_cue_text.vtt' + ) + self.assertEqual(2, len(vtt.captions)) + self.assertEqual('Caption text #3', vtt.captions[1].text) + + def test_parse_get_caption_data(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'one_caption.vtt') + self.assertEqual(vtt.captions[0].start_in_seconds, 0) + self.assertEqual(vtt.captions[0].start, '00:00:00.500') + self.assertEqual(vtt.captions[0].end_in_seconds, 7) + self.assertEqual(vtt.captions[0].end, '00:00:07.000') + self.assertEqual(vtt.captions[0].lines[0], 'Caption text #1') + self.assertEqual(len(vtt.captions[0].lines), 1) + + def test_caption_without_timeframe(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'missing_timeframe.vtt') + self.assertEqual(len(vtt.captions), 6) + + def test_caption_without_cue_text(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'missing_caption_text.vtt') + self.assertEqual(len(vtt.captions), 4) + + def test_timestamps_format(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'sample.vtt') + self.assertEqual(vtt.captions[2].start, '00:00:11.890') + self.assertEqual(vtt.captions[2].end, '00:00:16.320') + + def test_parse_timestamp(self): + 
self.assertEqual( + Caption(start='02:03:11.890').start_in_seconds, + 7391 + ) + + def test_captions_attribute(self): + self.assertListEqual(webvtt.WebVTT().captions, []) + + def test_metadata_headers(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'metadata_headers.vtt') + self.assertEqual(len(vtt.captions), 2) + + def test_metadata_headers_multiline(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'metadata_headers_multiline.vtt') + self.assertEqual(len(vtt.captions), 2) + + def test_parse_identifiers(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'using_identifiers.vtt') + self.assertEqual(len(vtt.captions), 6) + + self.assertEqual(vtt.captions[1].identifier, 'second caption') + self.assertEqual(vtt.captions[2].identifier, None) + self.assertEqual(vtt.captions[3].identifier, '4') + + def test_parse_comments(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'comments.vtt') + self.assertEqual(len(vtt.captions), 3) + self.assertListEqual( + vtt.captions[0].lines, + ['- Ta en kopp varmt te.', + '- Det Ƥr inte varmt.'] + ) + self.assertListEqual( + vtt.captions[0].comments, + ['This translation was done by Kyle so that\n' + 'some friends can watch it with their parents.' 
+ ] + ) + self.assertListEqual( + vtt.captions[1].comments, + [] + ) + self.assertEqual( + vtt.captions[2].text, + '- Ta en kopp' + ) + self.assertListEqual( + vtt.captions[2].comments, + ['This last line may not translate well.'] + ) + + def test_parse_styles(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'styles.vtt') + self.assertEqual(len(vtt.captions), 1) + self.assertEqual( + vtt.styles[0].text, + '::cue {\n background-image: linear-gradient(to bottom, ' + 'dimgray, lightgray);\n color: papayawhip;\n}' + ) + + def test_parse_styles_with_comments(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'styles_with_comments.vtt') + self.assertEqual(len(vtt.captions), 1) + self.assertEqual(len(vtt.styles), 2) + self.assertEqual( + vtt.styles[0].comments, + ['This is the first style block'] + ) + self.assertEqual( + vtt.styles[0].text, + '::cue {\n' + ' background-image: linear-gradient(to bottom, dimgray, ' + 'lightgray);\n' + ' color: papayawhip;\n' + '}' + ) + self.assertEqual( + vtt.styles[1].comments, + ['This is the second block of styles', + 'Multiline comment for the same\nsecond block of styles' + ] + ) + self.assertEqual( + vtt.styles[1].text, + '::cue(b) {\n' + ' color: peachpuff;\n' + '}' + ) + + def test_clean_cue_tags(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'cue_tags.vtt') + self.assertEqual( + vtt.captions[1].text, + 'Like a big-a pizza pie' + ) + self.assertEqual( + vtt.captions[2].text, + 'That\'s amore' + ) + + def test_parse_captions_with_bom(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'captions_with_bom.vtt') + self.assertEqual(len(vtt.captions), 4) + + def test_empty_lines_are_not_included_in_result(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'netflix_chicas_del_cable.vtt') + self.assertEqual(vtt.captions[0].text, "[Alba] En 1928,") + self.assertEqual( + vtt.captions[-2].text, + "Diez aƱos no son suficientes\npara olvidarte..." 
class TestParseSRT(unittest.TestCase):
    """Tests for reading SubRip (SRT) samples through ``webvtt.from_srt``."""

    def test_parse_empty_file(self):
        self.assertRaises(
            MalformedFileError,
            webvtt.from_srt,
            # We reuse this file as it is empty and serves the purpose.
            PATH_TO_SAMPLES / 'empty.vtt'
        )

    def test_invalid_format(self):
        for i in range(1, 5):
            sample = PATH_TO_SAMPLES / f'invalid_format{i}.srt'
            # subTest reports every failing sample by name instead of
            # stopping at the first one.
            with self.subTest(sample=sample.name):
                self.assertRaises(
                    MalformedFileError,
                    webvtt.from_srt,
                    sample
                )

    def test_total_length(self):
        self.assertEqual(
            webvtt.from_srt(PATH_TO_SAMPLES / 'sample.srt').total_length,
            23
        )

    def test_parse_captions(self):
        self.assertTrue(
            webvtt.from_srt(PATH_TO_SAMPLES / 'sample.srt').captions
        )

    def test_missing_timeframe_line(self):
        self.assertEqual(
            len(webvtt.from_srt(
                PATH_TO_SAMPLES / 'missing_timeframe.srt'
            ).captions),
            4
        )

    def test_empty_caption_text(self):
        self.assertTrue(
            webvtt.from_srt(
                PATH_TO_SAMPLES / 'missing_caption_text.srt'
            ).captions
        )

    def test_empty_gets_removed(self):
        captions = webvtt.from_srt(
            PATH_TO_SAMPLES / 'missing_caption_text.srt'
        ).captions
        self.assertEqual(len(captions), 4)

    def test_invalid_timestamp(self):
        self.assertEqual(
            len(webvtt.from_srt(
                PATH_TO_SAMPLES / 'invalid_timeframe.srt'
            ).captions),
            4
        )

    def test_timestamps_format(self):
        vtt = webvtt.from_srt(PATH_TO_SAMPLES / 'sample.srt')
        self.assertEqual(vtt.captions[2].start, '00:00:11.890')
        self.assertEqual(vtt.captions[2].end, '00:00:16.320')

    def test_parse_get_caption_data(self):
        vtt = webvtt.from_srt(PATH_TO_SAMPLES / 'one_caption.srt')
        self.assertEqual(vtt.captions[0].start_in_seconds, 0)
        self.assertEqual(vtt.captions[0].start, '00:00:00.500')
        self.assertEqual(vtt.captions[0].end_in_seconds, 7)
        self.assertEqual(vtt.captions[0].end, '00:00:07.000')
        self.assertEqual(vtt.captions[0].lines[0], 'Caption text #1')
        self.assertEqual(len(vtt.captions[0].lines), 1)

    def test_convert_from_srt_to_vtt_and_back_gives_same_file(self):
        # This round trip exercises SRT only, so it lives with the SRT tests.
        with tempfile.NamedTemporaryFile('w', suffix='.srt') as f:
            webvtt.from_srt(
                PATH_TO_SAMPLES / 'sample.srt'
            ).save_as_srt(f.name)

            # Read both files with an explicit encoding so the comparison
            # does not depend on the locale of the machine running it.
            self.assertEqual(
                pathlib.Path(
                    PATH_TO_SAMPLES / 'sample.srt'
                ).read_text(encoding='utf-8'),
                pathlib.Path(f.name).read_text(encoding='utf-8')
            )


class TestParseSBV(unittest.TestCase):
    """Tests for reading YouTube SBV samples through ``webvtt.from_sbv``."""

    def test_parse_empty_file(self):
        self.assertRaises(
            MalformedFileError,
            webvtt.from_sbv,
            # We reuse this file as it is empty and serves the purpose.
            PATH_TO_SAMPLES / 'empty.vtt'
        )

    def test_invalid_format(self):
        self.assertRaises(
            MalformedFileError,
            webvtt.from_sbv,
            PATH_TO_SAMPLES / 'invalid_format.sbv'
        )

    def test_total_length(self):
        self.assertEqual(
            webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv').total_length,
            16
        )

    def test_parse_captions(self):
        self.assertEqual(
            len(webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv').captions),
            5
        )

    def test_missing_timeframe_line(self):
        self.assertEqual(
            len(webvtt.from_sbv(
                PATH_TO_SAMPLES / 'missing_timeframe.sbv'
            ).captions),
            4
        )

    def test_missing_caption_text(self):
        self.assertTrue(
            webvtt.from_sbv(
                PATH_TO_SAMPLES / 'missing_caption_text.sbv'
            ).captions
        )

    def test_invalid_timestamp(self):
        self.assertEqual(
            len(webvtt.from_sbv(
                PATH_TO_SAMPLES / 'invalid_timeframe.sbv'
            ).captions),
            4
        )

    def test_timestamps_format(self):
        vtt = webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv')
        self.assertEqual(vtt.captions[1].start, '00:00:11.378')
        self.assertEqual(vtt.captions[1].end, '00:00:12.305')

    def test_timestamps_in_seconds(self):
        vtt = webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv')
        self.assertEqual(vtt.captions[1].start_in_seconds, 11)
        self.assertEqual(vtt.captions[1].end_in_seconds, 12)

    def test_get_caption_text(self):
        vtt = webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv')
        self.assertEqual(vtt.captions[1].text, 'Caption text #2')

    def test_get_caption_text_multiline(self):
        vtt = webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv')
        self.assertEqual(
            vtt.captions[2].text,
            'Caption text #3 (line 1)\nCaption text #3 (line 2)'
        )
        self.assertListEqual(
            vtt.captions[2].lines,
            ['Caption text #3 (line 1)', 'Caption text #3 (line 2)']
        )
test_webvtt_parse_get_captions(self): - self.assertEqual( - len(webvtt.read(self._get_file('sample.vtt')).captions), - 16 - ) - - def test_webvtt_parse_invalid_timeframe_line(self): - self.assertEqual(len(webvtt.read(self._get_file('invalid_timeframe.vtt')).captions), 6) - - def test_webvtt_parse_invalid_timeframe_in_cue_text(self): - vtt = webvtt.read(self._get_file('invalid_timeframe_in_cue_text.vtt')) - self.assertEqual(2, len(vtt.captions)) - self.assertEqual('Caption text #3', vtt.captions[1].text) - - def test_webvtt_parse_get_caption_data(self): - vtt = webvtt.read(self._get_file('one_caption.vtt')) - self.assertEqual(vtt.captions[0].start_in_seconds, 0) - self.assertEqual(vtt.captions[0].start, '00:00:00.500') - self.assertEqual(vtt.captions[0].end_in_seconds, 7) - self.assertEqual(vtt.captions[0].end, '00:00:07.000') - self.assertEqual(vtt.captions[0].lines[0], 'Caption text #1') - self.assertEqual(len(vtt.captions[0].lines), 1) - - def test_webvtt_caption_without_timeframe(self): - vtt = webvtt.read(self._get_file('missing_timeframe.vtt')) - self.assertEqual(len(vtt.captions), 6) - - def test_webvtt_caption_without_cue_text(self): - vtt = webvtt.read(self._get_file('missing_caption_text.vtt')) - self.assertEqual(len(vtt.captions), 4) - - def test_webvtt_timestamps_format(self): - vtt = webvtt.read(self._get_file('sample.vtt')) - self.assertEqual(vtt.captions[2].start, '00:00:11.890') - self.assertEqual(vtt.captions[2].end, '00:00:16.320') - - def test_parse_timestamp(self): - caption = Caption(start='02:03:11.890') - self.assertEqual( - caption.start_in_seconds, - 7391 - ) - - def test_captions_attribute(self): - self.assertListEqual([], webvtt.WebVTT().captions) - - def test_metadata_headers(self): - vtt = webvtt.read(self._get_file('metadata_headers.vtt')) - self.assertEqual(len(vtt.captions), 2) - - def test_metadata_headers_multiline(self): - vtt = webvtt.read(self._get_file('metadata_headers_multiline.vtt')) - self.assertEqual(len(vtt.captions), 2) - 
- def test_parse_identifiers(self): - vtt = webvtt.read(self._get_file('using_identifiers.vtt')) - self.assertEqual(len(vtt.captions), 6) - - self.assertEqual(vtt.captions[1].identifier, 'second caption') - self.assertEqual(vtt.captions[2].identifier, None) - self.assertEqual(vtt.captions[3].identifier, '4') - - def test_parse_with_comments(self): - vtt = webvtt.read(self._get_file('comments.vtt')) - self.assertEqual(len(vtt.captions), 3) - self.assertListEqual( - vtt.captions[0].lines, - ['- Ta en kopp varmt te.', - '- Det Ƥr inte varmt.'] - ) - self.assertEqual( - vtt.captions[2].text, - '- Ta en kopp' - ) - - def test_parse_styles(self): - vtt = webvtt.read(self._get_file('styles.vtt')) - self.assertEqual(len(vtt.captions), 1) - self.assertEqual( - vtt.styles[0].text, - '::cue {\n background-image: linear-gradient(to bottom, dimgray, lightgray);\n color: papayawhip;\n}' - ) - - def test_parse_styles_with_comments(self): - vtt = webvtt.read(self._get_file('styles_with_comments.vtt')) - self.assertEqual(len(vtt.captions), 1) - self.assertEqual(len(vtt.styles), 2) - self.assertEqual( - vtt.styles[0].comments, - ['This is the first style block'] - ) - self.assertEqual( - vtt.styles[1].comments, - ['This is the second block of styles', - 'Multiline comment for the same\nsecond block of styles' - ] - ) - - def test_clean_cue_tags(self): - vtt = webvtt.read(self._get_file('cue_tags.vtt')) - self.assertEqual( - vtt.captions[1].text, - 'Like a big-a pizza pie' - ) - self.assertEqual( - vtt.captions[2].text, - 'That\'s amore' - ) - - def test_parse_captions_with_bom(self): - vtt = webvtt.read(self._get_file('captions_with_bom.vtt')) - self.assertEqual(len(vtt.captions), 4) - - def test_empty_lines_are_not_included_in_result(self): - vtt = webvtt.read(self._get_file('netflix_chicas_del_cable.vtt')) - self.assertEqual(vtt.captions[0].text, "[Alba] En 1928,") - self.assertEqual( - vtt.captions[-2].text, - "Diez aƱos no son suficientes\npara olvidarte..." 
- ) - - def test_can_parse_youtube_dl_files(self): - vtt = webvtt.read(self._get_file('youtube_dl.vtt')) - self.assertEqual( - "this will happen is I'm telling\n ", - vtt.captions[2].text - ) diff --git a/webvtt/__init__.py b/webvtt/__init__.py index 45719af..33fa037 100644 --- a/webvtt/__init__.py +++ b/webvtt/__init__.py @@ -1,14 +1,16 @@ +"""Main webvtt package.""" + __version__ = '0.5.0' from .webvtt import WebVTT -from .segmenter import WebVTTSegmenter -from .structures import Caption, Style # noqa +from . import segmenter +from .models import Caption, Style # noqa -__all__ = ['WebVTT', 'WebVTTSegmenter', 'Caption', 'Style'] +__all__ = ['WebVTT', 'Caption', 'Style'] read = WebVTT.read read_buffer = WebVTT.read_buffer from_buffer = WebVTT.from_buffer from_srt = WebVTT.from_srt from_sbv = WebVTT.from_sbv -segment = WebVTTSegmenter().segment +segment = segmenter.segment diff --git a/webvtt/cli.py b/webvtt/cli.py index 9c722b2..a50e7d0 100644 --- a/webvtt/cli.py +++ b/webvtt/cli.py @@ -1,12 +1,16 @@ +"""CLI module.""" + import argparse import typing -from . import WebVTTSegmenter +from . import segmenter def main(argv: typing.Optional[typing.Sequence] = None): """ - Segment WebVTT from command line. + Segment WebVTT file from command line. + + :param argv: command line arguments """ arguments = argparse.ArgumentParser( description='Segment WebVTT files.' 
class Caption:
    """
    Representation of a caption.

    Timestamps are kept internally as `datetime.time` objects and are
    exposed through `start` and `end` as ``HH:MM:SS.mmm`` strings.
    """

    CUE_TEXT_TAGS = re.compile('<.*?>')

    def __init__(self,
                 start: typing.Optional[typing.Union[str, time]] = None,
                 end: typing.Optional[typing.Union[str, time]] = None,
                 text: typing.Optional[typing.Union[str,
                                                    typing.Sequence[str]
                                                    ]] = None,
                 identifier: typing.Optional[str] = None
                 ):
        """
        Initialize.

        :param start: start time of the caption
        :param end: end time of the caption
        :param text: the text of the caption, either as a single string
            (split on line breaks) or as a sequence of lines
        :param identifier: optional identifier
        """
        text = text or []
        self.start = start or time()
        self.end = end or time()
        self.identifier = identifier
        # Always keep a private list so later changes to the caller's
        # sequence do not leak into the caption.
        self.lines = (text.splitlines()
                      if isinstance(text, str)
                      else
                      list(text)
                      )
        self.comments: typing.List[str] = []

    def __repr__(self):
        """Return the string representation of the caption."""
        cleaned_text = self.text.replace('\n', '\\n')
        return (f'<{self.__class__.__name__} '
                f'start={self.start!r} '
                f'end={self.end!r} '
                f'text={cleaned_text!r} '
                f'identifier={self.identifier!r}>'
                )

    def __str__(self):
        """Return a readable representation of the caption."""
        cleaned_text = self.text.replace('\n', '\\n')
        return f'{self.start} {self.end} {cleaned_text}'

    def __eq__(self, other):
        """
        Compare equality with another object.

        Two captions are equal when start, end, raw text and identifier
        match. For any other type `NotImplemented` is returned so Python
        can try the reflected comparison; ``caption == other`` still
        evaluates to ``False`` in that case.
        """
        if not isinstance(other, type(self)):
            return NotImplemented

        return (self.start == other.start and
                self.end == other.end and
                self.raw_text == other.raw_text and
                self.identifier == other.identifier
                )

    @property
    def start(self) -> str:
        """Return the start time of the caption."""
        return self.format_timestamp(self._start)

    @start.setter
    def start(self, value: typing.Union[str, time]):
        """Set the start time of the caption."""
        self._start = self.parse_timestamp(value)

    @property
    def end(self) -> str:
        """Return the end time of the caption."""
        return self.format_timestamp(self._end)

    @end.setter
    def end(self, value: typing.Union[str, time]):
        """Set the end time of the caption."""
        self._end = self.parse_timestamp(value)

    @property
    def start_in_seconds(self) -> int:
        """Return the start time of the caption in whole seconds."""
        return self.time_in_seconds(self._start)

    @property
    def end_in_seconds(self) -> int:
        """Return the end time of the caption in whole seconds."""
        return self.time_in_seconds(self._end)

    @property
    def raw_text(self) -> str:
        """Return the text of the caption (including cue tags)."""
        return '\n'.join(self.lines)

    @property
    def text(self) -> str:
        """Return the text of the caption (without cue tags)."""
        return re.sub(self.CUE_TEXT_TAGS, '', self.raw_text)

    @text.setter
    def text(self, value: str):
        """
        Set the text of the captions.

        :param value: the new text, split into lines on line breaks
        :raises AttributeError: if the value is not a string
        """
        if not isinstance(value, str):
            raise AttributeError(
                f'String value expected but received {value}.'
            )

        self.lines = value.splitlines()

    @staticmethod
    def parse_timestamp(value: typing.Union[str, time]) -> time:
        """
        Return timestamp as time object if in string format.

        :param value: a `time` object (returned unchanged) or a string in
            ``HH:MM:SS.fff`` or ``MM:SS.fff`` form
        :returns: the timestamp as a `time` object
        :raises MalformedCaptionError: if the string cannot be parsed
        :raises TypeError: if the value is neither a string nor a `time`
        """
        if isinstance(value, time):
            return value

        if not isinstance(value, str):
            raise TypeError(f'The type {type(value)} is not supported')

        # Strings of 11 or more characters are read with an hours field,
        # shorter ones as minutes and seconds only.
        time_format = '%H:%M:%S.%f' if len(value) >= 11 else '%M:%S.%f'
        try:
            return datetime.strptime(value, time_format).time()
        except ValueError as error:
            # Keep the original parsing error attached as the cause.
            raise MalformedCaptionError(
                f'Invalid timestamp: {value}'
            ) from error

    @staticmethod
    def format_timestamp(time_obj: time) -> str:
        """Format timestamp in string format (``HH:MM:SS.mmm``)."""
        milliseconds = time_obj.microsecond // 1000
        return f'{time_obj.strftime("%H:%M:%S")}.{milliseconds:03d}'

    @staticmethod
    def time_in_seconds(time_obj: time) -> int:
        """
        Return the time in whole seconds.

        The sub-second part is dropped: `time.microsecond` is always
        below one million, so it never adds a full second.
        """
        return (time_obj.hour * 3600 +
                time_obj.minute * 60 +
                time_obj.second
                )


class Style:
    """Representation of a style."""

    def __init__(self, text: typing.Union[str, typing.List[str]]):
        """Initialize.

        :param text: the style text, either as a single string (split on
            line breaks) or as a list of lines
        """
        # Copy list input, as `Caption` does, so the style does not share
        # state with the caller's list.
        self.lines = text.splitlines() if isinstance(text, str) else list(text)
        self.comments: typing.List[str] = []

    @property
    def text(self):
        """Return the text of the style."""
        return '\n'.join(self.lines)
- """ - - @classmethod - def validate(cls, lines: typing.Sequence[str]) -> bool: - return bool(lines and lines[0].startswith('WEBVTT')) - - @classmethod - def parse_content(cls, - lines: typing.Sequence[str] - ) -> typing.List[typing.Union[Style, Caption]]: - items: typing.List[typing.Union[Caption, Style]] = [] - comments: typing.List[WebVTTCommentBlock] = [] - - for block_lines in cls.iter_blocks_of_lines(lines): - if WebVTTCueBlock.is_valid(block_lines): - cue_block = WebVTTCueBlock.from_lines(block_lines) - caption = Caption(cue_block.start, - cue_block.end, - cue_block.payload, - cue_block.identifier - ) - - if comments: - caption.comments = [comment.text for comment in comments] - comments = [] - items.append(caption) - - elif WebVTTCommentBlock.is_valid(block_lines): - comments.append(WebVTTCommentBlock.from_lines(block_lines)) - - elif WebVTTStyleBlock.is_valid(block_lines): - style = Style(WebVTTStyleBlock.from_lines(block_lines).text) - if comments: - style.comments = [comment.text for comment in comments] - comments = [] - items.append(style) - - if comments and items: - items[-1].comments = [comment.text for comment in comments] - - return items - - -class SRTParser(Parser): - """ - SubRip SRT parser. - """ - - @classmethod - def validate(cls, lines: typing.Sequence[str]) -> bool: - return bool( - len(lines) >= 3 and - lines[0].isdigit() and - '-->' in lines[1] and - lines[2].strip() - ) - - @classmethod - def parse_content( - cls, - lines: typing.Sequence[str] - ) -> typing.List[Caption]: - captions: typing.List[Caption] = [] - - for block_lines in cls.iter_blocks_of_lines(lines): - if not SRTCueBlock.is_valid(block_lines): - continue - - cue_block = SRTCueBlock.from_lines(block_lines) - captions.append(Caption(cue_block.start, - cue_block.end, - cue_block.payload - )) - - return captions - - -class SBVParser(Parser): - """ - YouTube SBV parser. 
class SBVCueBlock:
    """Representation of a cue timing block."""

    # The dot is escaped so only a literal '.' separates seconds from
    # milliseconds.
    CUE_TIMINGS_PATTERN = re.compile(
        r'\s*(\d+:\d{2}:\d{2}\.\d{3}),(\d+:\d{2}:\d{2}\.\d{3})'
    )
    _TIMESTAMP_FORMAT = '%H:%M:%S.%f'

    def __init__(
            self,
            start: time,
            end: time,
            payload: typing.Sequence[str]
            ):
        """
        Initialize.

        :param start: start time
        :param end: end time
        :param payload: caption text
        """
        self.start = start
        self.end = end
        self.payload = payload

    @classmethod
    def _parse_timings(
            cls,
            line: str
            ) -> typing.Optional[typing.Tuple[time, time]]:
        """
        Extract the start and end times from a cue timings line.

        :param line: the line expected to hold ``start,end`` timestamps
        :returns: ``(start, end)`` or None when the line does not match
            the pattern or holds an out-of-range value (e.g. 99 seconds)
        """
        match = cls.CUE_TIMINGS_PATTERN.match(line)
        if match is None:
            return None

        try:
            start, end = (
                datetime.strptime(value, cls._TIMESTAMP_FORMAT).time()
                for value in match.groups()
            )
        except ValueError:
            # The pattern only checks the shape; strptime checks ranges.
            return None

        return start, end

    @classmethod
    def is_valid(
            cls,
            lines: typing.Sequence[str]
            ) -> bool:
        """
        Validate the lines for a match of a cue time block.

        A block is valid when its first line holds parseable cue timings
        and its second line is not blank.

        :param lines: the lines to be validated
        :returns: true for a matching cue time block
        """
        return bool(
            len(lines) >= 2 and
            cls._parse_timings(lines[0]) is not None and
            lines[1].strip()
        )

    @classmethod
    def from_lines(
            cls,
            lines: typing.Sequence[str]
            ) -> 'SBVCueBlock':
        """
        Create a `SBVCueBlock` from lines of text.

        :param lines: the lines of text
        :returns: `SBVCueBlock` instance
        :raises ValueError: if the first line holds no valid cue timings
        """
        timings = cls._parse_timings(lines[0])
        if timings is None:
            # Explicit check instead of `assert`, which is stripped when
            # Python runs with -O.
            raise ValueError(f'Invalid SBV cue timings: {lines[0]!r}')

        start, end = timings
        payload = lines[1:]

        return cls(start, end, payload)


def parse(lines: typing.Sequence[str]) -> typing.List['Caption']:
    """
    Parse SBV captions from lines of text.

    :param lines: lines of text
    :returns: list of `Caption` objects
    :raises MalformedFileError: if the content is not valid SBV
    """
    if not _is_valid_content(lines):
        raise MalformedFileError('Invalid format')

    return _parse_captions(lines)


def _is_valid_content(lines: typing.Sequence[str]) -> bool:
    """
    Validate lines of text for valid SBV content.

    The content is accepted when its first block of lines is a valid
    cue block.

    :param lines: lines of text
    :returns: true for a valid SBV content
    """
    if len(lines) < 2:
        return False

    # NOTE(review): presumably `iter_blocks_of_lines` yields nothing for
    # input made only of blank lines; the default keeps that case from
    # escaping as StopIteration.
    first_block = next(utils.iter_blocks_of_lines(lines), None)
    return bool(first_block and SBVCueBlock.is_valid(first_block))


def _parse_captions(lines: typing.Sequence[str]) -> typing.List['Caption']:
    """
    Parse captions from the text.

    Blocks that are not valid cue blocks are skipped.

    :param lines: lines of text
    :returns: list of `Caption` objects
    """
    captions = []

    for block_lines in utils.iter_blocks_of_lines(lines):
        if not SBVCueBlock.is_valid(block_lines):
            continue

        cue_block = SBVCueBlock.from_lines(block_lines)
        captions.append(Caption(cue_block.start,
                                cue_block.end,
                                cue_block.payload
                                ))

    return captions
DEFAULT_MPEGTS = 900000
DEFAULT_SECONDS = 10  # default number of seconds per segment


def segment(
        webvtt_path: str,
        output: str,
        seconds: int = DEFAULT_SECONDS,
        mpegts: int = DEFAULT_MPEGTS
        ):
    """
    Segment a WebVTT captions file.

    Reads the captions, creates the destination folder if needed, and
    writes one file per segment plus a manifest.

    :param webvtt_path: the path to the file
    :param output: the path to the destination folder
    :param seconds: the number of seconds for each segment
    :param mpegts: value for the MPEG-TS
    """
    captions = WebVTT.read(webvtt_path).captions

    output_folder = pathlib.Path(output)
    os.makedirs(output_folder, exist_ok=True)

    segments = slice_segments(captions, seconds)
    write_segments(output_folder, segments, mpegts)
    write_manifest(output_folder, segments, seconds)


def slice_segments(
        captions: typing.Sequence['Caption'],
        seconds: int
        ) -> typing.List[typing.List['Caption']]:
    """
    Slice segments of captions based on seconds per segment.

    A caption is placed in the segment where it starts and in every
    following segment up to the one where it ends.

    :param captions: the captions
    :param seconds: seconds per segment
    :returns: list of lists of `Caption` objects
    :raises ValueError: if `seconds` is not a positive number
    """
    if seconds <= 0:
        raise ValueError(
            f'Seconds per segment must be positive, received {seconds}.'
        )

    if not captions:
        return []

    # First and last segment index touched by each caption. The last index
    # is clamped so a caption whose end precedes its start still lands in
    # its starting segment.
    spans = []
    for caption in captions:
        first = floor(caption.start_in_seconds / seconds)
        last = max(first, floor(caption.end_in_seconds / seconds))
        spans.append((first, last))

    # Size the result from the highest index actually used. Deriving it
    # from the end of the last caption alone fell one segment short when
    # that caption ended exactly on a segment boundary, and was wrong
    # whenever an earlier caption finished later than the last one.
    total_segments = max(last for _, last in spans) + 1

    segments: typing.List[typing.List['Caption']] = [
        [] for _ in range(total_segments)
    ]

    for caption, (first, last) in zip(captions, spans):
        for index in range(first, last + 1):
            segments[index].append(caption)

    return segments


def write_segments(
        output_folder: pathlib.Path,
        segments: typing.Iterable[typing.Iterable['Caption']],
        mpegts: int
        ):
    """
    Write the segments to the output folder.

    Each segment is written to ``fileSequence<index>.webvtt``.

    :param output_folder: folder where the segment files will be stored
    :param segments: the segments of `Caption` objects
    :param mpegts: value for the MPEG-TS
    """
    for index, segment in enumerate(segments):
        segment_file = output_folder / f'fileSequence{index}.webvtt'

        with open(segment_file, 'w', encoding='utf-8') as f:
            f.write('WEBVTT\n')
            f.write(f'X-TIMESTAMP-MAP=MPEGTS:{mpegts},'
                    'LOCAL:00:00:00.000\n'
                    )

            for caption in segment:
                f.write(f'\n{caption.start} --> {caption.end}\n')
                f.writelines(f'{line}\n' for line in caption.lines)


def write_manifest(
        output_folder: pathlib.Path,
        segments: typing.Iterable[typing.Iterable['Caption']],
        seconds: int
        ):
    """
    Write the manifest in the output folder.

    The manifest is written to ``prog_index.m3u8`` and lists one entry
    per segment.

    :param output_folder: folder where the manifest will be stored
    :param segments: the segments of `Caption` objects
    :param seconds: the seconds per segment
    """
    manifest_file = output_folder / 'prog_index.m3u8'
    with open(manifest_file, 'w', encoding='utf-8') as f:
        f.write('#EXTM3U\n')
        f.write(f'#EXT-X-TARGETDURATION:{seconds}\n')
        f.write('#EXT-X-VERSION:3\n')
        f.write('#EXT-X-PLAYLIST-TYPE:VOD\n')

        for index, _ in enumerate(segments):
            # NOTE(review): the duration is fixed at 30 regardless of
            # `seconds`; kept as is because existing output may be
            # compared against it. Confirm against the HLS rule that
            # EXTINF must not exceed the target duration.
            f.write('#EXTINF:30.00000\n')
            f.write(f'fileSequence{index}.webvtt\n')

        f.write('#EXT-X-ENDLIST\n')
( - 0 - if not captions else - int(ceil(captions[-1].end_in_seconds / seconds)) - ) - self._output_folder = output - self._seconds = seconds - self._mpegts = mpegts - - output_folder = os.path.join(os.getcwd(), output) - if not os.path.exists(output_folder): - os.makedirs(output_folder) - - self._slice_segments(captions) - self._write_segments() - self._write_manifest() - - @property - def seconds(self): - """Returns the number of seconds used for segmenting captions.""" - return self._seconds - - @property - def total_segments(self): - """Returns the total of segments.""" - return self._total_segments - - @property - def segments(self): - """Return the list of segments.""" - return self._segments + manifest_file = output_folder / 'prog_index.m3u8' + with open(manifest_file, 'w', encoding='utf-8') as f: + f.write('#EXTM3U\n') + f.write(f'#EXT-X-TARGETDURATION:{seconds}\n') + f.write('#EXT-X-VERSION:3\n') + f.write('#EXT-X-PLAYLIST-TYPE:VOD\n') + + for index, _ in enumerate(segments): + f.write('#EXTINF:30.00000\n') + f.write(f'fileSequence{index}.webvtt\n') + + f.write('#EXT-X-ENDLIST\n') diff --git a/webvtt/srt.py b/webvtt/srt.py new file mode 100644 index 0000000..3dc70c4 --- /dev/null +++ b/webvtt/srt.py @@ -0,0 +1,149 @@ +"""SRT format module.""" + +import typing +import re +from datetime import datetime, time + +from .models import Caption +from .errors import MalformedFileError +from . import utils + + +class SRTCueBlock: + """Representation of a cue timing block.""" + + CUE_TIMINGS_PATTERN = re.compile( + r'\s*(\d+:\d{2}:\d{2},\d{3})\s*-->\s*(\d+:\d{2}:\d{2},\d{3})' + ) + + def __init__( + self, + index: str, + start: time, + end: time, + payload: typing.Sequence[str] + ): + """ + Initialize. 
+ + :param start: start time + :param end: end time + :param payload: caption text + """ + self.index = index + self.start = start + self.end = end + self.payload = payload + + @classmethod + def is_valid( + cls, + lines: typing.Sequence[str] + ) -> bool: + """ + Validate the lines for a match of a cue time block. + + :param lines: the lines to be validated + :returns: true for a matching cue time block + """ + return bool( + len(lines) >= 3 and + lines[0].isdigit() and + re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) + ) + + @classmethod + def from_lines( + cls, + lines: typing.Sequence[str] + ) -> 'SRTCueBlock': + """ + Create a `SRTCueBlock` from lines of text. + + :param lines: the lines of text + :returns: `SRTCueBlock` instance + """ + index = lines[0] + + match = re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) + assert match is not None + start, end = map(lambda x: datetime.strptime(x, '%H:%M:%S,%f').time(), + (match.group(1), match.group(2)) + ) + + payload = lines[2:] + + return cls(index, start, end, payload) + + +def parse(lines: typing.Sequence[str]) -> typing.List[Caption]: + """ + Parse SRT captions from lines of text. + + :param lines: lines of text + :returns: list of `Caption` objects + """ + if not is_valid_content(lines): + raise MalformedFileError('Invalid format') + + return parse_captions(lines) + + +def is_valid_content(lines: typing.Sequence[str]) -> bool: + """ + Validate lines of text for valid SRT content. + + :param lines: lines of text + :returns: true for valid SRT content + """ + return bool( + len(lines) >= 3 and + lines[0].isdigit() and + '-->' in lines[1] and + lines[2].strip() + ) + + +def parse_captions(lines: typing.Sequence[str]) -> typing.List[Caption]: + """ + Parse captions from the text.
+ + :param lines: lines of text + :returns: list of `Caption` objects + """ + captions: typing.List[Caption] = [] + + for block_lines in utils.iter_blocks_of_lines(lines): + if not SRTCueBlock.is_valid(block_lines): + continue + + cue_block = SRTCueBlock.from_lines(block_lines) + captions.append(Caption(cue_block.start, + cue_block.end, + cue_block.payload + )) + + return captions + + +def write( + f: typing.IO[str], + captions: typing.Iterable[Caption] + ): + """ + Write captions to an output. + + :param f: file or file-like object + :param captions: Iterable of `Caption` objects + """ + output = [] + for index, caption in enumerate(captions, start=1): + output.extend([ + f'{index}', + '{} --> {}'.format(*map(lambda x: x.replace('.', ','), + (caption.start, caption.end)) + ), + *caption.lines, + '' + ]) + f.write('\n'.join(output).rstrip()) diff --git a/webvtt/structures.py b/webvtt/structures.py deleted file mode 100644 index 7a92da2..0000000 --- a/webvtt/structures.py +++ /dev/null @@ -1,289 +0,0 @@ -import re -import typing -from datetime import datetime, time - -from .errors import MalformedCaptionError - - -class WebVTTCueBlock: - CUE_TIMINGS_PATTERN = re.compile( - r'\s*((?:\d+:)?\d{2}:\d{2}.\d{3})\s*-->\s*((?:\d+:)?\d{2}:\d{2}.\d{3})' - ) - - def __init__(self, identifier, start, end, payload): - self.identifier = identifier - self.start = start - self.end = end - self.payload = payload - - @classmethod - def is_valid(cls, lines: typing.Sequence[str]) -> bool: - return bool( - ( - len(lines) >= 2 and - re.match(cls.CUE_TIMINGS_PATTERN, lines[0]) and - "-->" not in lines[1] - ) or - ( - len(lines) >= 3 and - "-->" not in lines[0] and - re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) and - "-->" not in lines[2] - ) - ) - - @classmethod - def from_lines(cls, lines: typing.Sequence[str]) -> 'WebVTTCueBlock': - - identifier = None - start = None - end = None - payload = [] - - for line in lines: - timing_match = re.match(cls.CUE_TIMINGS_PATTERN, line) - if 
timing_match: - start = timing_match.group(1) - end = timing_match.group(2) - elif not start: - identifier = line - else: - payload.append(line) - - return cls(identifier, start, end, payload) - - -class WebVTTCommentBlock: - COMMENT_PATTERN = re.compile(r'NOTE\s(.*?)\Z', re.DOTALL) - - def __init__(self, text: str): - self.text = text - - @classmethod - def is_valid(cls, lines: typing.Sequence[str]) -> bool: - return lines[0].startswith('NOTE') - - @classmethod - def from_lines(cls, lines: typing.Sequence[str]) -> 'WebVTTCommentBlock': - match = cls.COMMENT_PATTERN.match('\n'.join(lines)) - return cls(text=match.group(1).strip() if match else '') - - -class WebVTTStyleBlock: - STYLE_PATTERN = re.compile(r'STYLE\s(.*?)\Z', re.DOTALL) - - def __init__(self, text): - self.text = text - - @classmethod - def is_valid(cls, lines: typing.Sequence[str]) -> bool: - return lines[0].startswith('STYLE') - - @classmethod - def from_lines(cls, lines: typing.Sequence[str]) -> 'WebVTTStyleBlock': - match = cls.STYLE_PATTERN.match('\n'.join(lines)) - return cls(text=match.group(1).strip() if match else '') - - -class SRTCueBlock: - CUE_TIMINGS_PATTERN = re.compile( - r'\s*(\d+:\d{2}:\d{2},\d{3})\s*-->\s*(\d+:\d{2}:\d{2},\d{3})' - ) - - def __init__( - self, - index: str, - start: time, - end: time, - payload: typing.Sequence[str] - ): - self.index = index - self.start = start - self.end = end - self.payload = payload - - @classmethod - def is_valid(cls, lines: typing.Sequence[str]) -> bool: - return bool( - len(lines) >= 3 and - lines[0].isdigit() and - re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) - ) - - @classmethod - def from_lines(cls, lines: typing.Sequence[str]) -> 'SRTCueBlock': - - index = lines[0] - - match = re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) - assert match is not None - start, end = map(lambda x: datetime.strptime(x, '%H:%M:%S,%f').time(), - (match.group(1), match.group(2)) - ) - - payload = lines[2:] - - return cls(index, start, end, payload) - - -class 
SBVCueBlock: - CUE_TIMINGS_PATTERN = re.compile( - r'\s*(\d+:\d{2}:\d{2}.\d{3}),(\d+:\d{2}:\d{2}.\d{3})' - ) - - def __init__( - self, - start: time, - end: time, - payload: typing.Sequence[str] - ): - self.start = start - self.end = end - self.payload = payload - - @classmethod - def is_valid(cls, lines: typing.Sequence[str]) -> bool: - return bool( - len(lines) >= 2 and - re.match(cls.CUE_TIMINGS_PATTERN, lines[0]) and - lines[1].strip() - ) - - @classmethod - def from_lines(cls, lines: typing.Sequence[str]) -> 'SBVCueBlock': - match = re.match(cls.CUE_TIMINGS_PATTERN, lines[0]) - assert match is not None - start, end = map(lambda x: datetime.strptime(x, '%H:%M:%S.%f').time(), - (match.group(1), match.group(2)) - ) - - payload = lines[1:] - - return cls(start, end, payload) - - -class Caption: - CUE_TEXT_TAGS = re.compile('<.*?>') - - def __init__(self, - start: typing.Optional[typing.Union[str, time]] = None, - end: typing.Optional[typing.Union[str, time]] = None, - text: typing.Optional[typing.Union[str, - typing.Sequence[str] - ]] = None, - identifier: typing.Optional[str] = None - ): - text = text or [] - self.start = start or time() - self.end = end or time() - self.identifier = identifier - self.lines = (text.splitlines() - if isinstance(text, str) - else - list(text) - ) - self.comments: typing.List[str] = [] - - def __repr__(self): - cleaned_text = self.text.replace('\n', '\\n') - return (f'<{self.__class__.__name__} ' - f'start={self.start!r} ' - f'end={self.end!r} ' - f'text={cleaned_text!r} ' - f'identifier={self.identifier!r}>' - ) - - def __str__(self): - cleaned_text = self.text.replace('\n', '\\n') - return f'{self.start} {self.end} {cleaned_text}' - - def __eq__(self, other): - return (self.start == other.start and - self.end == other.end and - self.raw_text == other.raw_text - ) - - def add_line(self, line: str): - self.lines.append(line) - - def stream(self): - yield from self.lines - - @property - def start(self): - return 
self.format_timestamp(self._start) - - @start.setter - def start(self, value): - self._start = self.parse_timestamp(value) - - @property - def end(self): - return self.format_timestamp(self._end) - - @end.setter - def end(self, value): - self._end = self.parse_timestamp(value) - - @property - def start_in_seconds(self): - return self.time_in_seconds(self._start) - - @property - def end_in_seconds(self): - return self.time_in_seconds(self._end) - - @property - def raw_text(self) -> str: - """Returns the captions lines as a text (may include cue tags)""" - return '\n'.join(self.lines) - - @property - def text(self) -> str: - """Returns the captions lines as a text (without cue tags)""" - return re.sub(self.CUE_TEXT_TAGS, '', self.raw_text) - - @text.setter - def text(self, value: str): - if not isinstance(value, str): - raise AttributeError( - f'String value expected but received {value}.' - ) - - self.lines = value.splitlines() - - @staticmethod - def parse_timestamp(value: typing.Union[str, time]): - if isinstance(value, str): - time_format = '%H:%M:%S.%f' if len(value) >= 11 else '%M:%S.%f' - try: - return datetime.strptime(value, time_format).time() - except ValueError: - raise MalformedCaptionError(f'Invalid timestamp: {value}') - elif isinstance(value, time): - return value - - raise AttributeError(f'The type {type(value)} is not supported') - - @staticmethod - def format_timestamp(time_obj: time) -> str: - microseconds = int(time_obj.microsecond / 1000) - return f'{time_obj.strftime("%H:%M:%S")}.{microseconds:03d}' - - @staticmethod - def time_in_seconds(time_obj: time) -> int: - return (time_obj.hour * 3600 + - time_obj.minute * 60 + - time_obj.second + - time_obj.microsecond // 1_000_000 - ) - - -class Style: - def __init__(self, text: typing.Union[str, typing.List[str]]): - self.lines = text.splitlines() if isinstance(text, str) else text - self.comments: typing.List[str] = [] - - @property - def text(self): - return '\n'.join(self.lines) diff --git 
a/webvtt/utils.py b/webvtt/utils.py new file mode 100644 index 0000000..28da8c8 --- /dev/null +++ b/webvtt/utils.py @@ -0,0 +1,104 @@ +"""Utils module.""" + +import typing +import codecs + +CODEC_BOMS = { + 'utf-8': codecs.BOM_UTF8, + 'utf-32-le': codecs.BOM_UTF32_LE, + 'utf-32-be': codecs.BOM_UTF32_BE, + 'utf-16-le': codecs.BOM_UTF16_LE, + 'utf-16-be': codecs.BOM_UTF16_BE +} + + +class FileWrapper: + """File handling with built-in support for the Byte Order Mark (BOM).""" + + def __init__( + self, + file_path: str, + mode: typing.Optional[str] = None, + encoding: typing.Optional[str] = None + ): + """ + Initialize. + + :param file_path: path to the file + :param mode: mode in which the file is opened + :param encoding: name of the encoding used to decode the file + """ + self.file_path = file_path + self.mode = mode or 'r' + self.bom_encoding = self.detect_bom_encoding(file_path) + self.encoding = (self.bom_encoding or + encoding or + 'utf-8' + ) + + @classmethod + def open( + cls, + file_path: str, + mode: typing.Optional[str] = None, + encoding: typing.Optional[str] = None + ) -> 'FileWrapper': + """ + Open a file. + + :param file_path: path to the file + :param mode: mode in which the file is opened + :param encoding: name of the encoding used to decode the file + """ + return cls(file_path, mode, encoding) + + def __enter__(self): + """Enter context.""" + self.file = open( + file=self.file_path, + mode=self.mode, + encoding=self.encoding + ) + if self.bom_encoding: + self.file.seek(len(CODEC_BOMS[self.bom_encoding])) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit context.""" + self.file.close() + + @staticmethod + def detect_bom_encoding(file_path: str) -> typing.Optional[str]: + """ + Detect the encoding of a file based on the presence of the BOM. + + :param file_path: path to the file + :returns: the encoding if BOM is found or None.
+ """ + with open(file_path, mode='rb') as f: + first_bytes = f.read(4) + for encoding, bom in CODEC_BOMS.items(): + if first_bytes.startswith(bom): + return encoding + return None + + +def iter_blocks_of_lines( + lines: typing.Iterable[str] + ) -> typing.Generator[typing.List[str], None, None]: + """ + Iterate blocks of text. + + :param lines: lines of text. + """ + current_text_block = [] + + for line in lines: + if line.strip(): + current_text_block.append(line) + elif current_text_block: + yield current_text_block + current_text_block = [] + + if current_text_block: + yield current_text_block diff --git a/webvtt/vtt.py b/webvtt/vtt.py new file mode 100644 index 0000000..945d46f --- /dev/null +++ b/webvtt/vtt.py @@ -0,0 +1,276 @@ +"""VTT format module.""" + +import re +import typing +from .errors import MalformedFileError +from .models import Caption, Style +from . import utils + + +class WebVTTCueBlock: + """Representation of a cue timing block.""" + + CUE_TIMINGS_PATTERN = re.compile( + r'\s*((?:\d+:)?\d{2}:\d{2}.\d{3})\s*-->\s*((?:\d+:)?\d{2}:\d{2}.\d{3})' + ) + + def __init__( + self, + identifier, + start, + end, + payload + ): + """ + Initialize. + + :param start: start time + :param end: end time + :param payload: caption text + """ + self.identifier = identifier + self.start = start + self.end = end + self.payload = payload + + @classmethod + def is_valid( + cls, + lines: typing.Sequence[str] + ) -> bool: + """ + Validate the lines for a match of a cue time block. 
+ + :param lines: the lines to be validated + :returns: true for a matching cue time block + """ + return bool( + ( + len(lines) >= 2 and + re.match(cls.CUE_TIMINGS_PATTERN, lines[0]) and + "-->" not in lines[1] + ) or + ( + len(lines) >= 3 and + "-->" not in lines[0] and + re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) and + "-->" not in lines[2] + ) + ) + + @classmethod + def from_lines( + cls, + lines: typing.Iterable[str] + ) -> 'WebVTTCueBlock': + """ + Create a `WebVTTCueBlock` from lines of text. + + :param lines: the lines of text + :returns: `WebVTTCueBlock` instance + """ + identifier = None + start = None + end = None + payload = [] + + for line in lines: + timing_match = re.match(cls.CUE_TIMINGS_PATTERN, line) + if timing_match: + start = timing_match.group(1) + end = timing_match.group(2) + elif not start: + identifier = line + else: + payload.append(line) + + return cls(identifier, start, end, payload) + + +class WebVTTCommentBlock: + """Representation of a comment block.""" + + COMMENT_PATTERN = re.compile(r'NOTE\s(.*?)\Z', re.DOTALL) + + def __init__(self, text: str): + """ + Initialize. + + :param text: comment text + """ + self.text = text + + @classmethod + def is_valid( + cls, + lines: typing.Sequence[str] + ) -> bool: + """ + Validate the lines for a match of a comment block. + + :param lines: the lines to be validated + :returns: true for a matching comment block + """ + return bool(lines and lines[0].startswith('NOTE')) + + @classmethod + def from_lines( + cls, + lines: typing.Iterable[str] + ) -> 'WebVTTCommentBlock': + """ + Create a `WebVTTCommentBlock` from lines of text. 
+ + :param lines: the lines of text + :returns: `WebVTTCommentBlock` instance + """ + match = cls.COMMENT_PATTERN.match('\n'.join(lines)) + return cls(text=match.group(1).strip() if match else '') + + +class WebVTTStyleBlock: + """Representation of a style block.""" + + STYLE_PATTERN = re.compile(r'STYLE\s(.*?)\Z', re.DOTALL) + + def __init__(self, text: str): + """ + Initialize. + + :param text: style text + """ + self.text = text + + @classmethod + def is_valid( + cls, + lines: typing.Sequence[str] + ) -> bool: + """ + Validate the lines for a match of a style block. + + :param lines: the lines to be validated + :returns: true for a matching style block + """ + return (len(lines) >= 2 and + lines[0] == 'STYLE' and + not any(line.strip() == '' or '-->' in line for line in lines) + ) + + @classmethod + def from_lines( + cls, + lines: typing.Iterable[str] + ) -> 'WebVTTStyleBlock': + """ + Create a `WebVTTStyleBlock` from lines of text. + + :param lines: the lines of text + :returns: `WebVTTStyleBlock` instance + """ + match = cls.STYLE_PATTERN.match('\n'.join(lines)) + return cls(text=match.group(1).strip() if match else '') + + +def parse( + lines: typing.Sequence[str] + ) -> typing.Tuple[typing.List[Caption], typing.List[Style]]: + """ + Parse VTT captions from lines of text. + + :param lines: lines of text + :returns: tuple of a list of `Caption` objects and a list of `Style` + objects + """ + if not is_valid_content(lines): + raise MalformedFileError('Invalid format') + + items = parse_captions(lines) + return ([item for item in items if isinstance(item, Caption)], + [item for item in items if isinstance(item, Style)] + ) + + +def is_valid_content(lines: typing.Sequence[str]) -> bool: + """ + Validate lines of text for valid VTT content. 
+ + :param lines: lines of text + :returns: true for a valid VTT content + """ + return bool(lines and lines[0].startswith('WEBVTT')) + + +def parse_captions( + lines: typing.Sequence[str] + ) -> typing.List[typing.Union[Caption, Style]]: + """ + Parse captions from the text. + + :param lines: lines of text + :returns: list of `Caption` and `Style` objects in the order they + appear + """ + items: typing.List[typing.Union[Caption, Style]] = [] + comments: typing.List[WebVTTCommentBlock] = [] + + for block_lines in utils.iter_blocks_of_lines(lines): + if WebVTTCueBlock.is_valid(block_lines): + cue_block = WebVTTCueBlock.from_lines(block_lines) + caption = Caption(cue_block.start, + cue_block.end, + cue_block.payload, + cue_block.identifier + ) + + if comments: + caption.comments = [comment.text for comment in comments] + comments = [] + items.append(caption) + + elif WebVTTCommentBlock.is_valid(block_lines): + comments.append(WebVTTCommentBlock.from_lines(block_lines)) + + elif WebVTTStyleBlock.is_valid(block_lines): + style = Style(WebVTTStyleBlock.from_lines(block_lines).text) + if comments: + style.comments = [comment.text for comment in comments] + comments = [] + items.append(style) + + if comments and items: + items[-1].comments.extend( + [comment.text for comment in comments] + ) + + return items + + +def write( + f: typing.IO[str], + captions: typing.Iterable[Caption] + ): + """ + Write captions to an output. + + :param f: file or file-like object + :param captions: Iterable of `Caption` objects + """ + f.write(to_str(captions)) + + +def to_str(captions: typing.Iterable[Caption]) -> str: + """ + Convert captions to a string with webvtt format. + + :returns: String of the captions with WebVTT format.
+ """ + output = ['WEBVTT'] + for caption in captions: + output.extend([ + '', + *(identifier for identifier in {caption.identifier} if identifier), + f'{caption.start} --> {caption.end}', + *caption.lines + ]) + return '\n'.join(output) diff --git a/webvtt/webvtt.py b/webvtt/webvtt.py index 6776b85..9675226 100644 --- a/webvtt/webvtt.py +++ b/webvtt/webvtt.py @@ -1,21 +1,15 @@ +"""WebVTT module.""" + import os import typing -import codecs import warnings -from . import writers -from .parsers import WebVTTParser, SRTParser, SBVParser, Parser -from .structures import Caption, Style +from . import vtt, utils +from . import srt +from . import sbv +from .models import Caption, Style from .errors import MissingFilenameError -CODEC_BOMS = { - 'utf-8-sig': codecs.BOM_UTF8, - 'utf-32-le': codecs.BOM_UTF32_LE, - 'utf-32-be': codecs.BOM_UTF32_BE, - 'utf-16-le': codecs.BOM_UTF16_LE, - 'utf-16-be': codecs.BOM_UTF16_BE -} - class WebVTT: """ @@ -31,25 +25,40 @@ class WebVTT: WebVTT.from_sbv('captions.sbv') """ - def __init__(self, - file: typing.Optional[str] = None, - captions: typing.Optional[typing.List[Caption]] = None, - styles: typing.Optional[typing.List[Style]] = None - ): + def __init__( + self, + file: typing.Optional[str] = None, + captions: typing.Optional[typing.List[Caption]] = None, + styles: typing.Optional[typing.List[Style]] = None, + bom: bool = False + ): + """ + Initialize. + + :param file: the path of the WebVTT file + :param captions: the list of captions + :param styles: the list of styles + :param bom: include Byte Order Mark. Default is not to include it. 
+ """ self.file = file self.captions = captions or [] self.styles = styles or [] + self._bom_encoding = None def __len__(self): + """Return the number of captions.""" return len(self.captions) def __getitem__(self, index): + """Return a caption by index.""" return self.captions[index] def __repr__(self): + """Return the string representation of the WebVTT file.""" return f'<{self.__class__.__name__} file={self.file}>' def __str__(self): + """Return a readable representation of the WebVTT content.""" return '\n'.join(str(c) for c in self.captions) @classmethod @@ -58,17 +67,33 @@ def read( file: str, encoding: typing.Optional[str] = None ) -> 'WebVTT': - """Read a WebVTT captions file.""" - with cls._open_file(file, encoding=encoding) as f: - return cls.from_buffer(f) + """ + Read a WebVTT captions file. + + :param file: the file path + :param encoding: encoding of the file + :returns: a `WebVTT` instance + """ + with utils.FileWrapper.open(file, encoding=encoding) as fw: + instance = cls.from_buffer(fw.file) + instance._bom_encoding = fw.bom_encoding + return instance @classmethod - def read_buffer(cls, buffer: typing.Iterator[str]) -> 'WebVTT': + def read_buffer( + cls, + buffer: typing.Iterator[str] + ) -> 'WebVTT': """ - [DEPRECATED] Read WebVTT captions from a file-like object. + Read WebVTT captions from a file-like object. + + This method is DEPRECATED. Use from_buffer instead. Such file-like object may be the return of an io.open call, io.StringIO object, tempfile.TemporaryFile object, etc. + + :param buffer: the file-like object to read captions from + :returns: a `WebVTT` instance """ warnings.warn( 'Deprecated: use from_buffer instead.', @@ -77,19 +102,26 @@ def read_buffer(cls, buffer: typing.Iterator[str]) -> 'WebVTT': return cls.from_buffer(buffer) @classmethod - def from_buffer(cls, buffer: typing.Iterator[str]) -> 'WebVTT': + def from_buffer( + cls, + buffer: typing.Iterator[str] + ) -> 'WebVTT': """ Read WebVTT captions from a file-like object. 
Such file-like object may be the return of an io.open call, io.StringIO object, tempfile.TemporaryFile object, etc. + + :param buffer: the file-like object to read captions from + :returns: a `WebVTT` instance """ - items = cls._parse_content(buffer, parser=WebVTTParser) + captions, styles = vtt.parse(cls._get_lines(buffer)) - return cls(file=getattr(buffer, 'name', None), - captions=items[0], - styles=items[1] - ) + return cls( + file=getattr(buffer, 'name', None), + captions=captions, + styles=styles + ) @classmethod def from_srt( @@ -97,11 +129,18 @@ def from_srt( file: str, encoding: typing.Optional[str] = None ) -> 'WebVTT': - """Read captions from a file in SubRip format.""" - with cls._open_file(file, encoding=encoding) as f: - return cls(file=f.name, - captions=cls._parse_content(f, parser=SRTParser)[0], - ) + """ + Read captions from a file in SubRip format. + + :param file: the file path + :param encoding: encoding of the file + :returns: a `WebVTT` instance + """ + with utils.FileWrapper.open(file, encoding=encoding) as fw: + return cls( + file=fw.file.name, + captions=srt.parse(cls._get_lines(fw.file)) + ) @classmethod def from_sbv( @@ -109,67 +148,67 @@ def from_sbv( file: str, encoding: typing.Optional[str] = None ) -> 'WebVTT': - """Read captions from a file in YouTube SBV format.""" - with cls._open_file(file, encoding=encoding) as f: - return cls(file=f.name, - captions=cls._parse_content(f, parser=SBVParser)[0], - ) + """ + Read captions from a file in YouTube SBV format. 
- @classmethod - def from_string(cls, string: str) -> 'WebVTT': - return cls( - captions=cls._parse_content(string.splitlines(), - parser=WebVTTParser - )[0] + :param file: the file path + :param encoding: encoding of the file + :returns: a `WebVTT` instance + """ + with utils.FileWrapper.open(file, encoding=encoding) as fw: + return cls( + file=fw.file.name, + captions=sbv.parse(cls._get_lines(fw.file)), ) @classmethod - def _open_file( - cls, - file_path: str, - encoding: typing.Optional[str] = None - ) -> typing.IO: - return open( - file_path, - encoding=encoding or cls._detect_encoding(file_path) or 'utf-8' - ) + def from_string(cls, string: str) -> 'WebVTT': + """ + Read captions from a string. - @classmethod - def _parse_content( - cls, - content: typing.Iterable[str], - parser: typing.Type[Parser] - ) -> typing.Tuple[typing.List[Caption], typing.List[Style]]: - lines = [line.rstrip('\n\r') for line in content] - items = parser.parse(lines) - return (list(filter(lambda c: isinstance(c, Caption), items)), - list(filter(lambda s: isinstance(s, Style), items)) - ) + :param string: the captions in a string + :returns: a `WebVTT` instance + """ + captions, styles = vtt.parse(cls._get_lines(string.splitlines())) + return cls( + captions=captions, + styles=styles + ) @staticmethod - def _detect_encoding(file_path: str) -> typing.Optional[str]: - with open(file_path, mode='rb') as f: - first_bytes = f.read(4) - for encoding, bom in CODEC_BOMS.items(): - if first_bytes.startswith(bom): - return encoding - return None + def _get_lines(lines: typing.Iterable[str]) -> typing.List[str]: + """ + Return cleaned lines from an iterable of lines. + + :param lines: iterable of lines + :returns: a list of cleaned lines + """ + return [line.rstrip('\n\r') for line in lines] def _get_destination_file( self, destination_path: typing.Optional[str] = None, extension: str = 'vtt' - ): + ) -> str: + """ + Return the destination file based on the provided params. 
+ + :param destination_path: optional destination path + :param extension: the extension of the file + :returns: the destination file + + :raises MissingFilenameError: if destination path cannot be determined + """ if not destination_path and not self.file: raise MissingFilenameError - assert self.file is not None - - destination_path = ( - destination_path or + if not destination_path and self.file: + destination_path = ( f'{os.path.splitext(self.file)[0]}.{extension}' ) + assert destination_path is not None + target = os.path.join(os.getcwd(), destination_path) if os.path.isdir(target): if not self.file: @@ -186,31 +225,57 @@ def _get_destination_file( # store the file in the specified full path return target - def save(self, output: typing.Optional[str] = None): + def save( + self, + output: typing.Optional[str] = None + ): """ Save the WebVTT captions to a file. - If no output is provided the file will be saved in the same location. - Otherwise output can determine a target directory or file. + :param output: destination path of the file + + :raises MissingFilenameError: if output cannot be determined """ destination_file = self._get_destination_file(output) with open(destination_file, 'w', encoding='utf-8') as f: - self.write(f) + vtt.write(f, self.captions) self.file = destination_file - def save_as_srt(self, output: typing.Optional[str] = None): + def save_as_srt( + self, + output: typing.Optional[str] = None + ): + """ + Save the WebVTT captions to a file in SubRip format. 
+ + :param output: destination path of the file + + :raises MissingFilenameError: if output cannot be determined + """ dest_file = self._get_destination_file(output, extension='srt') with open(dest_file, 'w', encoding='utf-8') as f: - self.write(f, format='srt') + srt.write(f, self.captions) self.file = dest_file - def write(self, f: typing.IO[str], format: str = 'vtt'): + def write( + self, + f: typing.IO[str], + format: str = 'vtt' + ): + """ + Save the WebVTT captions to a file-like object. + + :param f: destination file-like object + :param format: the format to use (`vtt` or `srt`) + + :raises ValueError: if the format is not supported + """ if format == 'vtt': - writers.write_vtt(f, self.captions) + return vtt.write(f, self.captions) elif format == 'srt': - writers.write_srt(f, self.captions) - else: - raise ValueError(f'Format {format!r} is not supported') + return srt.write(f, self.captions) + + raise ValueError(f'Format {format} is not supported.') @property def total_length(self): @@ -228,6 +293,6 @@ def content(self) -> str: Return the webvtt capions as string. This property is useful in cases where the webvtt content is needed - but no file saving on the system is required. + but no file-like destination is required. Storage in DB for instance.
""" - return writers.webvtt_content(self.captions) + return vtt.to_str(self.captions) diff --git a/webvtt/writers.py b/webvtt/writers.py deleted file mode 100644 index e15ec50..0000000 --- a/webvtt/writers.py +++ /dev/null @@ -1,34 +0,0 @@ -import typing - -from .structures import Caption - - -def write_vtt(f: typing.IO[str], captions: typing.Iterable[Caption]): - f.write(webvtt_content(captions)) - - -def webvtt_content(captions: typing.Iterable[Caption]) -> str: - """Return captions content with webvtt formatting.""" - output = ['WEBVTT'] - for caption in captions: - output.extend([ - '', - *(identifier for identifier in {caption.identifier} if identifier), - f'{caption.start} --> {caption.end}', - *caption.lines - ]) - return '\n'.join(output) - - -def write_srt(f: typing.IO[str], captions: typing.Iterable[Caption]): - output = [] - for index, caption in enumerate(captions, start=1): - output.extend([ - f'{index}', - '{} --> {}'.format(*map(lambda x: x.replace('.', ','), - (caption.start, caption.end)) - ), - *caption.lines, - '' - ]) - f.write('\n'.join(output).rstrip()) From b38fe2559ec3b9ba55b1ea6f16acd6306e951fd6 Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Tue, 7 May 2024 10:50:42 +0200 Subject: [PATCH 04/16] Add BOM handling in save --- tests/test_webvtt.py | 145 ++++++++++++++++++++++++++++++++++++++++++- webvtt/webvtt.py | 30 ++++++--- 2 files changed, 167 insertions(+), 8 deletions(-) diff --git a/tests/test_webvtt.py b/tests/test_webvtt.py index adbe9df..dd0b685 100644 --- a/tests/test_webvtt.py +++ b/tests/test_webvtt.py @@ -8,6 +8,7 @@ import webvtt from webvtt.models import Caption, Style +from webvtt.utils import CODEC_BOMS from webvtt.errors import MalformedFileError PATH_TO_SAMPLES = pathlib.Path(__file__).resolve().parent / 'samples' @@ -433,7 +434,7 @@ def test_repr(self): test_file = PATH_TO_SAMPLES / 'sample.vtt' self.assertEqual( repr(webvtt.read(test_file)), - f'' + f"" ) def test_str(self): @@ -807,3 +808,145 @@ def 
test_convert_from_srt_to_vtt_and_back_gives_same_file(self): pathlib.Path(PATH_TO_SAMPLES / 'sample.srt').read_text(), pathlib.Path(f.name).read_text() ) + + def test_save_file_with_bom(self): + with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f: + webvtt.read( + PATH_TO_SAMPLES / 'one_caption.vtt' + ).save(f.name, add_bom=True) + self.assertEqual( + f.read(), + textwrap.dedent(f''' + {CODEC_BOMS['utf-8'].decode()}WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + ''').strip() + ) + + def test_save_file_with_bom_keeps_bom(self): + with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f: + webvtt.read( + PATH_TO_SAMPLES / 'captions_with_bom.vtt' + ).save(f.name) + self.assertEqual( + f.read(), + textwrap.dedent(f''' + {CODEC_BOMS['utf-8'].decode()}WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + + 00:00:11.890 --> 00:00:16.320 + Caption text #3 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + ''').strip() + ) + + def test_save_file_with_bom_removes_bom_if_requested(self): + with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f: + webvtt.read( + PATH_TO_SAMPLES / 'captions_with_bom.vtt' + ).save(f.name, add_bom=False) + self.assertEqual( + f.read(), + textwrap.dedent(f''' + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + + 00:00:11.890 --> 00:00:16.320 + Caption text #3 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + ''').strip() + ) + + def test_save_file_with_encoding(self): + with tempfile.NamedTemporaryFile('rb', suffix='.vtt') as f: + webvtt.read( + PATH_TO_SAMPLES / 'one_caption.vtt' + ).save(f.name, + encoding='utf-32-le' + ) + self.assertEqual( + f.read().decode('utf-32-le'), + textwrap.dedent(''' + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + ''').strip() + ) + + def test_save_file_with_encoding_and_bom(self): + with tempfile.NamedTemporaryFile('rb', suffix='.vtt') as f: + 
webvtt.read( + PATH_TO_SAMPLES / 'one_caption.vtt' + ).save(f.name, + encoding='utf-32-le', + add_bom=True + ) + self.assertEqual( + f.read().decode('utf-32-le'), + textwrap.dedent(f''' + {CODEC_BOMS['utf-32-le'].decode('utf-32-le')}WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + ''').strip() + ) + + def test_save_new_file_utf_8_default_encoding_no_bom(self): + with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f: + vtt = webvtt.WebVTT() + vtt.captions.append( + Caption(start='00:00:00.500', + end='00:00:07.000', + text='Caption text #1' + ) + ) + vtt.save(f.name) + self.assertEqual(vtt.encoding, 'utf-8') + self.assertEqual( + f.read(), + textwrap.dedent(f''' + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + ''').strip() + ) + + def test_save_new_file_utf_8_default_encoding_with_bom(self): + with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f: + vtt = webvtt.WebVTT() + vtt.captions.append( + Caption(start='00:00:00.500', + end='00:00:07.000', + text='Caption text #1' + ) + ) + vtt.save(f.name, + add_bom=True + ) + self.assertEqual(vtt.encoding, 'utf-8') + self.assertEqual( + f.read(), + textwrap.dedent(f''' + {CODEC_BOMS['utf-8'].decode()}WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + ''').strip() + ) diff --git a/webvtt/webvtt.py b/webvtt/webvtt.py index 9675226..9b69647 100644 --- a/webvtt/webvtt.py +++ b/webvtt/webvtt.py @@ -10,6 +10,8 @@ from .models import Caption, Style from .errors import MissingFilenameError +DEFAULT_ENCODING = 'utf-8' + class WebVTT: """ @@ -30,7 +32,6 @@ def __init__( file: typing.Optional[str] = None, captions: typing.Optional[typing.List[Caption]] = None, styles: typing.Optional[typing.List[Style]] = None, - bom: bool = False ): """ Initialize. @@ -38,12 +39,12 @@ def __init__( :param file: the path of the WebVTT file :param captions: the list of captions :param styles: the list of styles - :param bom: include Byte Order Mark. Default is not to include it. 
""" self.file = file self.captions = captions or [] self.styles = styles or [] - self._bom_encoding = None + self._has_bom = False + self.encoding = DEFAULT_ENCODING def __len__(self): """Return the number of captions.""" @@ -55,7 +56,9 @@ def __getitem__(self, index): def __repr__(self): """Return the string representation of the WebVTT file.""" - return f'<{self.__class__.__name__} file={self.file}>' + return (f'<{self.__class__.__name__} file={self.file!r} ' + f'encoding={self.encoding!r}>' + ) def __str__(self): """Return a readable representation of the WebVTT content.""" @@ -76,7 +79,9 @@ def read( """ with utils.FileWrapper.open(file, encoding=encoding) as fw: instance = cls.from_buffer(fw.file) - instance._bom_encoding = fw.bom_encoding + if fw.bom_encoding: + instance.encoding = fw.bom_encoding + instance._has_bom = True return instance @classmethod @@ -227,17 +232,28 @@ def _get_destination_file( def save( self, - output: typing.Optional[str] = None + output: typing.Optional[str] = None, + encoding: typing.Optional[str] = None, + add_bom: typing.Optional[bool] = None ): """ Save the WebVTT captions to a file. 
:param output: destination path of the file + :param encoding: encoding of the file (defaults to UTF-8) + :param add_bom: save the file with Byte Order Mark :raises MissingFilenameError: if output cannot be determined """ destination_file = self._get_destination_file(output) - with open(destination_file, 'w', encoding='utf-8') as f: + encoding = encoding or self.encoding + if add_bom is None and self._has_bom: + add_bom = True + + with open(destination_file, 'w', encoding=encoding) as f: + if add_bom and encoding in utils.CODEC_BOMS: + f.write(utils.CODEC_BOMS[encoding].decode(encoding)) + vtt.write(f, self.captions) self.file = destination_file From 5c67727618836a1bc673dc0261e6f792936d0602 Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Tue, 7 May 2024 12:07:39 +0200 Subject: [PATCH 05/16] Add Radon and fix complexity --- setup.cfg | 2 ++ webvtt/vtt.py | 57 +++++++++++++++++++++++++++++---------------------- 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/setup.cfg b/setup.cfg index 0fdf338..4833f7c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,6 +3,7 @@ description_file = README.rst [flake8] doctests = true +radon-max-cc=10 [tox:tox] envlist = codestyle,py,coverage @@ -25,6 +26,7 @@ commands = deps = flake8 flake8-docstrings + radon mypy commands = flake8 webvtt setup.py diff --git a/webvtt/vtt.py b/webvtt/vtt.py index 945d46f..75ad52b 100644 --- a/webvtt/vtt.py +++ b/webvtt/vtt.py @@ -185,7 +185,7 @@ def parse( if not is_valid_content(lines): raise MalformedFileError('Invalid format') - items = parse_captions(lines) + items = parse_items(lines) return ([item for item in items if isinstance(item, Caption)], [item for item in items if isinstance(item, Style)] ) @@ -201,43 +201,27 @@ def is_valid_content(lines: typing.Sequence[str]) -> bool: return bool(lines and lines[0].startswith('WEBVTT')) -def parse_captions( +def parse_items( lines: typing.Sequence[str] ) -> typing.List[typing.Union[Caption, Style]]: """ - Parse captions from the text. 
+ Parse items from the text. :param lines: lines of text - :returns: tuple of a list of `Caption` objects and a list of `Style` - objects + :returns: a list of `Caption` objects or `Style` objects """ items: typing.List[typing.Union[Caption, Style]] = [] comments: typing.List[WebVTTCommentBlock] = [] for block_lines in utils.iter_blocks_of_lines(lines): - if WebVTTCueBlock.is_valid(block_lines): - cue_block = WebVTTCueBlock.from_lines(block_lines) - caption = Caption(cue_block.start, - cue_block.end, - cue_block.payload, - cue_block.identifier - ) - - if comments: - caption.comments = [comment.text for comment in comments] - comments = [] - items.append(caption) - + item = parse_item(block_lines) + if item: + item.comments = [comment.text for comment in comments] + comments = [] + items.append(item) elif WebVTTCommentBlock.is_valid(block_lines): comments.append(WebVTTCommentBlock.from_lines(block_lines)) - elif WebVTTStyleBlock.is_valid(block_lines): - style = Style(WebVTTStyleBlock.from_lines(block_lines).text) - if comments: - style.comments = [comment.text for comment in comments] - comments = [] - items.append(style) - if comments and items: items[-1].comments.extend( [comment.text for comment in comments] @@ -246,6 +230,29 @@ def parse_captions( return items +def parse_item( + lines: typing.Sequence[str] + ) -> typing.Union[Caption, Style, None]: + """ + Parse an item from lines of text. 
+ + :param lines: lines of text + :returns: An item (Caption or Style) if found, otherwise None + """ + if WebVTTCueBlock.is_valid(lines): + cue_block = WebVTTCueBlock.from_lines(lines) + return Caption(cue_block.start, + cue_block.end, + cue_block.payload, + cue_block.identifier + ) + + if WebVTTStyleBlock.is_valid(lines): + return Style(WebVTTStyleBlock.from_lines(lines).text) + + return None + + def write( f: typing.IO[str], captions: typing.Iterable[Caption] From 6659a6d7754d871230d6ade2d7fad56eed84a61b Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Wed, 8 May 2024 13:06:16 +0200 Subject: [PATCH 06/16] Add slice captions --- tests/test_webvtt.py | 65 ++++++++++++++++++++++++++++++++++++++++++++ webvtt/models.py | 31 +++++++++------------ webvtt/utils.py | 18 ++++++++++++ webvtt/webvtt.py | 40 +++++++++++++++++++++------ 4 files changed, 128 insertions(+), 26 deletions(-) diff --git a/tests/test_webvtt.py b/tests/test_webvtt.py index dd0b685..fed1315 100644 --- a/tests/test_webvtt.py +++ b/tests/test_webvtt.py @@ -5,6 +5,7 @@ import warnings import tempfile import pathlib +from datetime import time import webvtt from webvtt.models import Caption, Style @@ -950,3 +951,67 @@ def test_save_new_file_utf_8_default_encoding_with_bom(self): Caption text #1 ''').strip() ) + + def test_iter_slice(self): + vtt = webvtt.read( + PATH_TO_SAMPLES / 'sample.vtt' + ) + slice_of_captions = vtt.iter_slice(start_time=time(second=11), + end_time=time(second=27) + ) + for expected_caption in (vtt.captions[2], + vtt.captions[3], + vtt.captions[4] + ): + self.assertIs(expected_caption, next(slice_of_captions)) + + with self.assertRaises(StopIteration): + next(slice_of_captions) + + def test_iter_slice_timestamp_string(self): + vtt = webvtt.read( + PATH_TO_SAMPLES / 'sample.vtt' + ) + slice_of_captions = vtt.iter_slice(start_time='00:00:11.000', + end_time='00:00:27.000' + ) + for expected_caption in (vtt.captions[2], + vtt.captions[3], + vtt.captions[4] + ): + 
self.assertIs(expected_caption, next(slice_of_captions)) + + with self.assertRaises(StopIteration): + next(slice_of_captions) + + def test_iter_slice_no_start_time(self): + vtt = webvtt.read( + PATH_TO_SAMPLES / 'sample.vtt' + ) + slice_of_captions = vtt.iter_slice(end_time=time(second=27)) + for expected_caption in (vtt.captions[0], + vtt.captions[1], + vtt.captions[2], + vtt.captions[3], + vtt.captions[4] + ): + self.assertIs(expected_caption, next(slice_of_captions)) + + with self.assertRaises(StopIteration): + next(slice_of_captions) + + def test_iter_slice_no_end_time(self): + vtt = webvtt.read( + PATH_TO_SAMPLES / 'sample.vtt' + ) + slice_of_captions = vtt.iter_slice(start_time=time(second=47)) + for expected_caption in (vtt.captions[11], + vtt.captions[12], + vtt.captions[13], + vtt.captions[14], + vtt.captions[15] + ): + self.assertIs(expected_caption, next(slice_of_captions)) + + with self.assertRaises(StopIteration): + next(slice_of_captions) diff --git a/webvtt/models.py b/webvtt/models.py index 4bf1b33..bc62d56 100644 --- a/webvtt/models.py +++ b/webvtt/models.py @@ -2,8 +2,9 @@ import re import typing -from datetime import datetime, time +from datetime import time +from . 
import utils from .errors import MalformedCaptionError @@ -68,32 +69,32 @@ def __eq__(self, other): @property def start(self): """Return the start time of the caption.""" - return self.format_timestamp(self._start) + return self.format_timestamp(self.start_time) @start.setter def start(self, value: typing.Union[str, time]): """Set the start time of the caption.""" - self._start = self.parse_timestamp(value) + self.start_time = self.parse_timestamp(value) @property def end(self): """Return the end time of the caption.""" - return self.format_timestamp(self._end) + return self.format_timestamp(self.end_time) @end.setter def end(self, value: typing.Union[str, time]): """Set the end time of the caption.""" - self._end = self.parse_timestamp(value) + self.end_time = self.parse_timestamp(value) @property def start_in_seconds(self) -> int: """Return the start time of the caption in seconds.""" - return self.time_in_seconds(self._start) + return self.time_in_seconds(self.start_time) @property def end_in_seconds(self): """Return the end time of the caption in seconds.""" - return self.time_in_seconds(self._end) + return self.time_in_seconds(self.end_time) @property def raw_text(self) -> str: @@ -117,17 +118,11 @@ def text(self, value: str): @staticmethod def parse_timestamp(value: typing.Union[str, time]) -> time: - """Return timestamp as time object if in string format.""" - if isinstance(value, str): - time_format = '%H:%M:%S.%f' if len(value) >= 11 else '%M:%S.%f' - try: - return datetime.strptime(value, time_format).time() - except ValueError: - raise MalformedCaptionError(f'Invalid timestamp: {value}') - elif isinstance(value, time): - return value - - raise TypeError(f'The type {type(value)} is not supported') + """Parse the provided value as timestamp.""" + try: + return utils.parse_timestamp(value) + except ValueError: + raise MalformedCaptionError(f'Invalid timestamp: {value}') @staticmethod def format_timestamp(time_obj: time) -> str: diff --git a/webvtt/utils.py 
b/webvtt/utils.py index 28da8c8..8447628 100644 --- a/webvtt/utils.py +++ b/webvtt/utils.py @@ -2,6 +2,7 @@ import typing import codecs +from datetime import datetime, time CODEC_BOMS = { 'utf-8': codecs.BOM_UTF8, @@ -102,3 +103,20 @@ def iter_blocks_of_lines( if current_text_block: yield current_text_block + + +def parse_timestamp(value: typing.Union[str, time]) -> time: + """ + Return timestamp as time object if in string format. + + :param value: value to be parsed as timestamp + :raises ValueError: if the value cannot be parsed as timestamp + :raises TypeError: when the type of the value provided is not supported + """ + if isinstance(value, str): + time_format = '%H:%M:%S.%f' if len(value) >= 11 else '%M:%S.%f' + return datetime.strptime(value, time_format).time() + elif isinstance(value, time): + return value + + raise TypeError(f'The type {type(value)} is not supported') diff --git a/webvtt/webvtt.py b/webvtt/webvtt.py index 9b69647..3c01e47 100644 --- a/webvtt/webvtt.py +++ b/webvtt/webvtt.py @@ -3,6 +3,7 @@ import os import typing import warnings +from datetime import time from . import vtt, utils from . import srt @@ -240,38 +241,39 @@ def save( Save the WebVTT captions to a file. 
:param output: destination path of the file - :param encoding: encoding of the file (defaults to UTF-8) + :param encoding: encoding of the file :param add_bom: save the file with Byte Order Mark :raises MissingFilenameError: if output cannot be determined """ - destination_file = self._get_destination_file(output) + self.file = self._get_destination_file(output) encoding = encoding or self.encoding + if add_bom is None and self._has_bom: add_bom = True - with open(destination_file, 'w', encoding=encoding) as f: + with open(self.file, 'w', encoding=encoding) as f: if add_bom and encoding in utils.CODEC_BOMS: f.write(utils.CODEC_BOMS[encoding].decode(encoding)) vtt.write(f, self.captions) - self.file = destination_file def save_as_srt( self, - output: typing.Optional[str] = None + output: typing.Optional[str] = None, + encoding: typing.Optional[str] = DEFAULT_ENCODING ): """ Save the WebVTT captions to a file in SubRip format. :param output: destination path of the file + :param encoding: encoding of the file :raises MissingFilenameError: if output cannot be determined """ - dest_file = self._get_destination_file(output, extension='srt') - with open(dest_file, 'w', encoding='utf-8') as f: + self.file = self._get_destination_file(output, extension='srt') + with open(self.file, 'w', encoding=encoding) as f: srt.write(f, self.captions) - self.file = dest_file def write( self, @@ -293,6 +295,28 @@ def write( raise ValueError(f'Format {format} is not supported.') + def iter_slice( + self, + start_time: typing.Optional[typing.Union[str, time]] = None, + end_time: typing.Optional[typing.Union[str, time]] = None + ) -> typing.Generator[Caption, None, None]: + """ + Iterate a slice of the captions based on a time range. 
+ + :param start_time: start time of the range + :param end_time: end time of the range + :returns: generator of Captions + """ + start_time = utils.parse_timestamp(start_time) if start_time else None + end_time = utils.parse_timestamp(end_time) if end_time else None + + for caption in self.captions: + if ( + (not start_time or caption.start_time >= start_time) and + (not end_time or caption.end_time <= end_time) + ): + yield caption + @property def total_length(self): """Returns the total length of the captions.""" From 6a92fd3fb428dd2367c9d06e7a406d0aa9c0655b Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Thu, 9 May 2024 13:17:09 +0200 Subject: [PATCH 07/16] Add styles and comments in output file --- tests/samples/comments.vtt | 4 +- tests/samples/styles_with_comments.vtt | 2 +- tests/test_vtt.py | 147 ++++++++++++++++-- tests/test_webvtt.py | 203 +++++++++++++++++++++++-- webvtt/vtt.py | 159 ++++++++++++++++--- webvtt/webvtt.py | 48 ++++-- 6 files changed, 508 insertions(+), 55 deletions(-) diff --git a/tests/samples/comments.vtt b/tests/samples/comments.vtt index 3847e02..8cd3032 100644 --- a/tests/samples/comments.vtt +++ b/tests/samples/comments.vtt @@ -18,4 +18,6 @@ NOTE This last line may not translate well. 
3 00:02:25.000 --> 00:02:30.000 -- Ta en kopp \ No newline at end of file +- Ta en kopp + +NOTE end of file \ No newline at end of file diff --git a/tests/samples/styles_with_comments.vtt b/tests/samples/styles_with_comments.vtt index 2a69de0..4ae93ac 100644 --- a/tests/samples/styles_with_comments.vtt +++ b/tests/samples/styles_with_comments.vtt @@ -1,6 +1,6 @@ WEBVTT -NOTE This is the first style block +NOTE Sample of comments with styles STYLE ::cue { diff --git a/tests/test_vtt.py b/tests/test_vtt.py index 3f7840d..2abc944 100644 --- a/tests/test_vtt.py +++ b/tests/test_vtt.py @@ -234,7 +234,7 @@ def test_parse_invalid_format(self): ) def test_parse_captions(self): - captions, styles = vtt.parse( + output = vtt.parse( textwrap.dedent(''' WEBVTT @@ -246,6 +246,8 @@ def test_parse_captions(self): Caption text #2 line 2 ''').strip().split('\n') ) + captions = output.captions + styles = output.styles self.assertEqual(len(captions), 2) self.assertEqual(len(styles), 0) self.assertIsInstance(captions[0], Caption) @@ -261,7 +263,7 @@ def test_parse_captions(self): ) def test_parse_styles(self): - captions, styles = vtt.parse( + output = vtt.parse( textwrap.dedent(''' WEBVTT @@ -280,6 +282,8 @@ def test_parse_styles(self): Caption text #1 ''').strip().split('\n') ) + captions = output.captions + styles = output.styles self.assertEqual(len(captions), 1) self.assertEqual(len(styles), 2) self.assertIsInstance(styles[0], Style) @@ -303,11 +307,22 @@ def test_parse_styles(self): ) def test_parse_content(self): - captions, styles = vtt.parse( + output = vtt.parse( textwrap.dedent(''' WEBVTT - NOTE Comment for the style + NOTE This is a testing sample + + NOTE We can see two header comments, a style + comment and finally a footer comments + + STYLE + ::cue { + background-image: linear-gradient(to bottom, dimgray); + color: papayawhip; + } + + NOTE the following style needs review STYLE ::cue { @@ -327,14 +342,19 @@ def test_parse_content(self): Caption text #2 line 1 Caption 
text #2 line 2 + NOTE Copyright 2024 + NOTE end of file ''').strip().split('\n') ) + captions = output.captions + styles = output.styles self.assertEqual(len(captions), 2) - self.assertEqual(len(styles), 1) + self.assertEqual(len(styles), 2) self.assertIsInstance(captions[0], Caption) self.assertIsInstance(captions[1], Caption) self.assertIsInstance(styles[0], Style) + self.assertIsInstance(styles[1], Style) self.assertEqual( str(captions[0]), '00:00:00.500 00:00:07.000 Caption text #1' @@ -346,6 +366,15 @@ def test_parse_content(self): ) self.assertEqual( str(styles[0].text), + textwrap.dedent(''' + ::cue { + background-image: linear-gradient(to bottom, dimgray); + color: papayawhip; + } + ''').strip() + ) + self.assertEqual( + str(styles[1].text), textwrap.dedent(''' ::cue { color: white; @@ -354,7 +383,11 @@ def test_parse_content(self): ) self.assertEqual( styles[0].comments, - ['Comment for the style'] + [] + ) + self.assertEqual( + styles[1].comments, + ['the following style needs review'] ) self.assertEqual( captions[0].comments, @@ -362,10 +395,19 @@ def test_parse_content(self): ) self.assertEqual( captions[1].comments, - ['Comment for the second caption\nthat is very long', - 'end of file' - ] - ) + ['Comment for the second caption\nthat is very long'] + ) + self.assertListEqual(output.header_comments, + ['This is a testing sample', + 'We can see two header comments, a style\n' + 'comment and finally a footer comments' + ] + ) + self.assertListEqual(output.footer_comments, + ['Copyright 2024', + 'end of file' + ] + ) def test_write(self): out = io.StringIO() @@ -381,10 +423,25 @@ def test_write(self): ] ) ] + styles = [ + Style('::cue(b) {\n color: peachpuff;\n}'), + Style('::cue {\n color: papayawhip;\n}') + ] captions[0].comments.append('Comment for the first caption') captions[1].comments.append('Comment for the second caption') + styles[1].comments.append( + 'Comment for the second style\nwith two lines' + ) + header_comments = ['header comment', 
'begin of the file'] + footer_comments = ['footer comment', 'end of file'] - vtt.write(out, captions) + vtt.write( + out, + captions, + styles, + header_comments, + footer_comments + ) out.seek(0) @@ -393,12 +450,38 @@ def test_write(self): textwrap.dedent(''' WEBVTT + NOTE header comment + + NOTE begin of the file + + STYLE + ::cue(b) { + color: peachpuff; + } + + NOTE + Comment for the second style + with two lines + + STYLE + ::cue { + color: papayawhip; + } + + NOTE Comment for the first caption + 00:00:00.500 --> 00:00:07.000 Caption #1 + NOTE Comment for the second caption + 00:00:07.000 --> 00:00:11.890 Caption #2 line 1 Caption #2 line 2 + + NOTE footer comment + + NOTE end of file ''').strip() ) @@ -415,19 +498,59 @@ def test_to_str(self): ] ) ] + styles = [ + Style('::cue(b) {\n color: peachpuff;\n}'), + Style('::cue {\n color: papayawhip;\n}') + ] captions[0].comments.append('Comment for the first caption') captions[1].comments.append('Comment for the second caption') + styles[1].comments.append( + 'Comment for the second style\nwith two lines' + ) + header_comments = ['header comment', 'begin of the file'] + footer_comments = ['footer comment', 'end of file'] self.assertEqual( - vtt.to_str(captions), + vtt.to_str( + captions, + styles, + header_comments, + footer_comments + ), textwrap.dedent(''' WEBVTT + NOTE header comment + + NOTE begin of the file + + STYLE + ::cue(b) { + color: peachpuff; + } + + NOTE + Comment for the second style + with two lines + + STYLE + ::cue { + color: papayawhip; + } + + NOTE Comment for the first caption + 00:00:00.500 --> 00:00:07.000 Caption #1 + NOTE Comment for the second caption + 00:00:07.000 --> 00:00:11.890 Caption #2 line 1 Caption #2 line 2 + + NOTE footer comment + + NOTE end of file ''').strip() ) diff --git a/tests/test_webvtt.py b/tests/test_webvtt.py index fed1315..4a0dae6 100644 --- a/tests/test_webvtt.py +++ b/tests/test_webvtt.py @@ -561,25 +561,33 @@ def test_parse_comments(self): vtt.captions[0].lines, 
['- Ta en kopp varmt te.', '- Det Ƥr inte varmt.'] - ) + ) self.assertListEqual( vtt.captions[0].comments, - ['This translation was done by Kyle so that\n' - 'some friends can watch it with their parents.' - ] - ) + [] + ) self.assertListEqual( vtt.captions[1].comments, [] - ) + ) self.assertEqual( vtt.captions[2].text, '- Ta en kopp' - ) + ) self.assertListEqual( vtt.captions[2].comments, ['This last line may not translate well.'] - ) + ) + self.assertListEqual( + vtt.header_comments, + ['This translation was done by Kyle so that\n' + 'some friends can watch it with their parents.' + ] + ) + self.assertListEqual( + vtt.footer_comments, + ['end of file'] + ) def test_parse_styles(self): vtt = webvtt.read(PATH_TO_SAMPLES / 'styles.vtt') @@ -596,7 +604,7 @@ def test_parse_styles_with_comments(self): self.assertEqual(len(vtt.styles), 2) self.assertEqual( vtt.styles[0].comments, - ['This is the first style block'] + [] ) self.assertEqual( vtt.styles[0].text, @@ -618,6 +626,183 @@ def test_parse_styles_with_comments(self): ' color: peachpuff;\n' '}' ) + self.assertListEqual( + vtt.header_comments, + ['Sample of comments with styles'] + ) + self.assertListEqual( + vtt.footer_comments, + [] + ) + + def test_multiple_comments_everywhere(self): + vtt = webvtt.WebVTT.from_string(textwrap.dedent(""" + WEBVTT + + NOTE Test file + + NOTE this file is testing multiple comments + in different places + + STYLE + ::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + + NOTE this style uses nice color + + NOTE check how it looks before proceeding + + STYLE + ::cue(b) { + color: peachpuff; + } + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + NOTE next caption has two lines + + NOTE needs review + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 line 1 + Caption text #2 line 2 + + 00:00:11.890 --> 00:00:16.320 + Caption text #3 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + + NOTE Copyright 2024 + + NOTE this is the end of 
the file + """).strip()) + + self.assertListEqual( + vtt.header_comments, + ['Test file', + 'this file is testing multiple comments\nin different places' + ] + ) + self.assertListEqual( + vtt.footer_comments, + ['Copyright 2024', + 'this is the end of the file' + ] + ) + self.assertListEqual( + vtt.captions[0].comments, + [] + ) + self.assertListEqual( + vtt.captions[1].comments, + ['next caption has two lines', + 'needs review' + ] + ) + self.assertListEqual( + vtt.captions[2].comments, + [] + ) + self.assertListEqual( + vtt.captions[3].comments, + [] + ) + self.assertListEqual( + vtt.styles[0].comments, + [] + ) + self.assertListEqual( + vtt.styles[1].comments, + ['this style uses nice color', + 'check how it looks before proceeding' + ] + ) + + def test_comments_in_new_file(self): + with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f: + vtt = webvtt.WebVTT() + vtt.header_comments.append('This is a header comment') + vtt.header_comments.append( + 'where we can see a\ntwo line comment' + ) + vtt.styles.append( + Style('::cue(b) {\n color: peachpuff;\n}') + ) + style = Style('::cue {\n color: papayawhip;\n}') + style.comments.append('Another style to test\nthe look and feel') + style.comments.append('Please check') + vtt.styles.append(style) + vtt.captions.append( + Caption(start='00:00:00.500', + end='00:00:07.000', + text='Caption #1', + ) + ) + caption = Caption(start='00:00:07.000', + end='00:00:11.890', + text='Caption #2' + ) + caption.comments.append( + 'Second caption may be a bit off\nand needs checking' + ) + caption.comments.append('Confirm if it displays correctly') + vtt.captions.append(caption) + vtt.footer_comments.append('This is a footer comment') + vtt.footer_comments.append( + 'where we can also see a\ntwo line comment' + ) + + vtt.save(f.name) + self.assertEqual( + f.read(), + textwrap.dedent(''' + WEBVTT + + NOTE This is a header comment + + NOTE + where we can see a + two line comment + + STYLE + ::cue(b) { + color: peachpuff; + } + + 
NOTE + Another style to test + the look and feel + + NOTE Please check + + STYLE + ::cue { + color: papayawhip; + } + + 00:00:00.500 --> 00:00:07.000 + Caption #1 + + NOTE + Second caption may be a bit off + and needs checking + + NOTE Confirm if it displays correctly + + 00:00:07.000 --> 00:00:11.890 + Caption #2 + + NOTE This is a footer comment + + NOTE + where we can also see a + two line comment + ''').strip() + ) def test_clean_cue_tags(self): vtt = webvtt.read(PATH_TO_SAMPLES / 'cue_tags.vtt') diff --git a/webvtt/vtt.py b/webvtt/vtt.py index 75ad52b..459ac3e 100644 --- a/webvtt/vtt.py +++ b/webvtt/vtt.py @@ -2,11 +2,42 @@ import re import typing +from dataclasses import dataclass + from .errors import MalformedFileError from .models import Caption, Style from . import utils +@dataclass +class ParserOutput: + """Output of parser.""" + + styles: typing.List[Style] + captions: typing.List[Caption] + header_comments: typing.List[str] + footer_comments: typing.List[str] + + @classmethod + def from_data( + cls, + data: typing.Mapping[str, typing.Any] + ) -> 'ParserOutput': + """ + Return a `ParserOutput` instance from the provided data. + + :param data: data from the parser + :returns: an instance of `ParserOutput` + """ + items = data.get('items', []) + return cls( + captions=[it for it in items if isinstance(it, Caption)], + styles=[it for it in items if isinstance(it, Style)], + header_comments=data.get('header_comments', []), + footer_comments=data.get('footer_comments', []) + ) + + class WebVTTCueBlock: """Representation of a cue timing block.""" @@ -86,6 +117,21 @@ def from_lines( return cls(identifier, start, end, payload) + @staticmethod + def format_lines(caption: Caption) -> typing.List[str]: + """ + Return the lines for a cue block. 
+ + :param caption: the `Caption` instance + :returns: list of lines for a cue block + """ + return [ + '', + *(identifier for identifier in {caption.identifier} if identifier), + f'{caption.start} --> {caption.end}', + *caption.lines + ] + class WebVTTCommentBlock: """Representation of a comment block.""" @@ -127,6 +173,21 @@ def from_lines( match = cls.COMMENT_PATTERN.match('\n'.join(lines)) return cls(text=match.group(1).strip() if match else '') + @staticmethod + def format_lines(lines: str) -> typing.List[str]: + """ + Return the lines for a comment block. + + :param lines: comment lines + :returns: list of lines for a comment block + """ + list_of_lines = lines.split('\n') + + if len(list_of_lines) == 1: + return [f'NOTE {lines}'] + + return ['NOTE', *list_of_lines] + class WebVTTStyleBlock: """Representation of a style block.""" @@ -171,24 +232,30 @@ def from_lines( match = cls.STYLE_PATTERN.match('\n'.join(lines)) return cls(text=match.group(1).strip() if match else '') + @staticmethod + def format_lines(lines: typing.List[str]) -> typing.List[str]: + """ + Return the lines for a style block. + + :param lines: style lines + :returns: list of lines for a style block + """ + return ['STYLE', *lines] + def parse( lines: typing.Sequence[str] - ) -> typing.Tuple[typing.List[Caption], typing.List[Style]]: + ) -> ParserOutput: """ Parse VTT captions from lines of text. 
:param lines: lines of text - :returns: tuple of a list of `Caption` objects and a list of `Style` - objects + :returns: object `ParserOutput` with all parsed items """ if not is_valid_content(lines): raise MalformedFileError('Invalid format') - items = parse_items(lines) - return ([item for item in items if isinstance(item, Caption)], - [item for item in items if isinstance(item, Style)] - ) + return parse_items(lines) def is_valid_content(lines: typing.Sequence[str]) -> bool: @@ -203,13 +270,14 @@ def is_valid_content(lines: typing.Sequence[str]) -> bool: def parse_items( lines: typing.Sequence[str] - ) -> typing.List[typing.Union[Caption, Style]]: + ) -> ParserOutput: """ Parse items from the text. :param lines: lines of text - :returns: a list of `Caption` objects or `Style` objects + :returns: an object `ParserOutput` with all parsed items """ + header_comments: typing.List[str] = [] items: typing.List[typing.Union[Caption, Style]] = [] comments: typing.List[WebVTTCommentBlock] = [] @@ -222,12 +290,15 @@ def parse_items( elif WebVTTCommentBlock.is_valid(block_lines): comments.append(WebVTTCommentBlock.from_lines(block_lines)) - if comments and items: - items[-1].comments.extend( - [comment.text for comment in comments] - ) + if items: + header_comments, items[0].comments = items[0].comments, header_comments - return items + return ParserOutput.from_data( + {'items': items, + 'header_comments': header_comments, + 'footer_comments': [comment.text for comment in comments] + } + ) def parse_item( @@ -255,29 +326,75 @@ def parse_item( def write( f: typing.IO[str], - captions: typing.Iterable[Caption] + captions: typing.Iterable[Caption], + styles: typing.Iterable[Style], + header_comments: typing.Iterable[str], + footer_comments: typing.Iterable[str] ): """ Write captions to an output. 
:param f: file or file-like object :param captions: Iterable of `Caption` objects + :param styles: Iterable of `Style` objects + :param header_comments: the comments for the header + :param footer_comments: the comments for the footer """ - f.write(to_str(captions)) + f.write( + to_str(captions, + styles, + header_comments, + footer_comments + ) + ) -def to_str(captions: typing.Iterable[Caption]) -> str: +def to_str( + captions: typing.Iterable[Caption], + styles: typing.Iterable[Style], + header_comments: typing.Iterable[str], + footer_comments: typing.Iterable[str] + ) -> str: """ Convert captions to a string with webvtt format. - :returns: String of the captions with WebVTT format. + :param captions: the iterable of `Caption` objects + :param styles: the iterable of `Style` objects + :param header_comments: the comments for the header + :param footer_comments: the comments for the footer + :returns: String of the content in WebVTT format. """ output = ['WEBVTT'] + + for comment in header_comments: + output.extend([ + '', + *WebVTTCommentBlock.format_lines(comment) + ]) + + for style in styles: + for comment in style.comments: + output.extend([ + '', + *WebVTTCommentBlock.format_lines(comment) + ]) + output.extend([ + '', + *WebVTTStyleBlock.format_lines(style.lines) + ]) + for caption in captions: + for comment in caption.comments: + output.extend([ + '', + *WebVTTCommentBlock.format_lines(comment) + ]) + output.extend(WebVTTCueBlock.format_lines(caption)) + + for comment in footer_comments: output.extend([ '', - *(identifier for identifier in {caption.identifier} if identifier), - f'{caption.start} --> {caption.end}', - *caption.lines + *WebVTTCommentBlock.format_lines(comment) ]) + return '\n'.join(output) diff --git a/webvtt/webvtt.py b/webvtt/webvtt.py index 3c01e47..1b5110f 100644 --- a/webvtt/webvtt.py +++ b/webvtt/webvtt.py @@ -33,6 +33,8 @@ def __init__( file: typing.Optional[str] = None, captions: typing.Optional[typing.List[Caption]] = None, styles: 
typing.Optional[typing.List[Style]] = None, + header_comments: typing.Optional[typing.List[str]] = None, + footer_comments: typing.Optional[typing.List[str]] = None, ): """ Initialize. @@ -40,10 +42,14 @@ def __init__( :param file: the path of the WebVTT file :param captions: the list of captions :param styles: the list of styles + :param header_comments: list of comments for the start of the file + :param footer_comments: list of comments for the bottom of the file """ self.file = file self.captions = captions or [] self.styles = styles or [] + self.header_comments = header_comments or [] + self.footer_comments = footer_comments or [] self._has_bom = False self.encoding = DEFAULT_ENCODING @@ -121,12 +127,14 @@ def from_buffer( :param buffer: the file-like object to read captions from :returns: a `WebVTT` instance """ - captions, styles = vtt.parse(cls._get_lines(buffer)) + output = vtt.parse(cls._get_lines(buffer)) return cls( file=getattr(buffer, 'name', None), - captions=captions, - styles=styles + captions=output.captions, + styles=output.styles, + header_comments=output.header_comments, + footer_comments=output.footer_comments ) @classmethod @@ -175,11 +183,13 @@ def from_string(cls, string: str) -> 'WebVTT': :param string: the captions in a string :returns: a `WebVTT` instance """ - captions, styles = vtt.parse(cls._get_lines(string.splitlines())) + output = vtt.parse(cls._get_lines(string.splitlines())) return cls( - captions=captions, - styles=styles - ) + captions=output.captions, + styles=output.styles, + header_comments=output.header_comments, + footer_comments=output.footer_comments + ) @staticmethod def _get_lines(lines: typing.Iterable[str]) -> typing.List[str]: @@ -256,7 +266,13 @@ def save( if add_bom and encoding in utils.CODEC_BOMS: f.write(utils.CODEC_BOMS[encoding].decode(encoding)) - vtt.write(f, self.captions) + vtt.write( + f, + self.captions, + self.styles, + self.header_comments, + self.footer_comments + ) def save_as_srt( self, @@ -289,8 
+305,13 @@ def write( :raises MissingFilenameError: if output cannot be determined """ if format == 'vtt': - return vtt.write(f, self.captions) - elif format == 'srt': + return vtt.write(f, + self.captions, + self.styles, + self.header_comments, + self.footer_comments + ) + if format == 'srt': return srt.write(f, self.captions) raise ValueError(f'Format {format} is not supported.') @@ -335,4 +356,9 @@ def content(self) -> str: This property is useful in cases where the webvtt content is needed but no file-like destination is required. Storage in DB for instance. """ - return vtt.to_str(self.captions) + return vtt.to_str( + self.captions, + self.styles, + self.header_comments, + self.footer_comments + ) From 9dfed13fc858efc3a49a138026d06bce3906a68c Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Fri, 10 May 2024 18:04:25 +0200 Subject: [PATCH 08/16] Add timestamp implementation --- tests/test_models.py | 190 +++++++++++++++++++++++++++++++++------- tests/test_sbv.py | 37 ++++++-- tests/test_segmenter.py | 40 ++++++--- tests/test_srt.py | 5 +- tests/test_vtt.py | 8 +- tests/test_webvtt.py | 104 ++++++++++------------ webvtt/models.py | 139 +++++++++++++++++++++-------- webvtt/sbv.py | 12 +-- webvtt/srt.py | 13 ++- webvtt/utils.py | 18 ---- webvtt/webvtt.py | 15 ++-- 11 files changed, 384 insertions(+), 197 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index d36ebfd..2210e57 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,10 +1,139 @@ import unittest -from datetime import time -from webvtt import Caption, Style +from webvtt.models import Timestamp, Caption, Style from webvtt.errors import MalformedCaptionError +class TestTimestamp(unittest.TestCase): + + def test_instantiation(self): + timestamp = Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ) + self.assertEqual(timestamp.hours, 1) + self.assertEqual(timestamp.minutes, 12) + self.assertEqual(timestamp.seconds, 23) + 
self.assertEqual(timestamp.milliseconds, 500) + + def test_from_string(self): + timestamp = Timestamp.from_string('01:24:11.670') + self.assertEqual(timestamp.hours, 1) + self.assertEqual(timestamp.minutes, 24) + self.assertEqual(timestamp.seconds, 11) + self.assertEqual(timestamp.milliseconds, 670) + + def test_from_string_single_digits(self): + timestamp = Timestamp.from_string('1:2:7.670') + self.assertEqual(timestamp.hours, 1) + self.assertEqual(timestamp.minutes, 2) + self.assertEqual(timestamp.seconds, 7) + self.assertEqual(timestamp.milliseconds, 670) + + def test_from_string_missing_hours(self): + timestamp = Timestamp.from_string('24:11.670') + self.assertEqual(timestamp.hours, 0) + self.assertEqual(timestamp.minutes, 24) + self.assertEqual(timestamp.seconds, 11) + self.assertEqual(timestamp.milliseconds, 670) + + def test_from_string_wrong_minutes(self): + with self.assertRaises(MalformedCaptionError): + Timestamp.from_string('01:76:11.670') + + def test_from_string_wrong_seconds(self): + with self.assertRaises(MalformedCaptionError): + Timestamp.from_string('01:24:87.670') + + def test_from_string_wrong_type(self): + with self.assertRaises(MalformedCaptionError): + Timestamp.from_string(1234) + + def test_from_string_wrong_value(self): + with self.assertRaises(MalformedCaptionError): + Timestamp.from_string('01:24:11:670') + + def test_to_tuple(self): + self.assertEqual( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ).to_tuple(), + (1, 12, 23, 500) + ) + + def test_equality(self): + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ) == + Timestamp.from_string('01:12:23.500') + ) + + def test_not_equality(self): + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ) != + Timestamp.from_string('01:12:23.600') + ) + + def test_greater_than(self): + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=600 + ) > + 
Timestamp.from_string('01:12:23.500') + ) + + def test_less_than(self): + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ) < + Timestamp.from_string('01:12:23.600') + ) + + def test_greater_or_equal_than(self): + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ) >= + Timestamp.from_string('01:12:23.500') + ) + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=600 + ) >= + Timestamp.from_string('01:12:23.500') + ) + + def test_less_or_equal_than(self): + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ) <= + Timestamp.from_string('01:12:23.500') + ) + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ) <= + Timestamp.from_string('01:12:23.600') + ) + + def test_repr(self): + timestamp = Timestamp( + hours=1, + minutes=12, + seconds=45, + milliseconds=320 + ) + + self.assertEqual( + repr(timestamp), + '' + ) + + class TestCaption(unittest.TestCase): def test_instantiation(self): @@ -19,15 +148,14 @@ def test_instantiation(self): self.assertEqual(caption.text, 'Hello test!') self.assertEqual(caption.identifier, 'A test caption') - def test_timestamp_accept_time(self): - caption = Caption( - start=time(hour=0, minute=0, second=7, microsecond=0), - end=time(hour=0, minute=0, second=11, microsecond=890000), - text='Hello test!', - identifier='A test caption' - ) - self.assertEqual(caption.start, '00:00:07.000') - self.assertEqual(caption.end, '00:00:11.890') + def test_timestamp_wrong_type(self): + with self.assertRaises(MalformedCaptionError): + Caption( + start=1234, + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) def test_identifier_is_optional(self): caption = Caption( @@ -97,16 +225,16 @@ def test_wrong_start_timestamp(self): start='1234', end='00:00:11.890', text='Hello Test!' 
- ) + ) def test_wrong_type_start_timestamp(self): self.assertRaises( - TypeError, + MalformedCaptionError, Caption, start=1234, end='00:00:11.890', text='Hello Test!' - ) + ) def test_wrong_end_timestamp(self): self.assertRaises( @@ -115,16 +243,16 @@ def test_wrong_end_timestamp(self): start='00:00:07.000', end='1234', text='Hello Test!' - ) + ) def test_wrong_type_end_timestamp(self): self.assertRaises( - TypeError, + MalformedCaptionError, Caption, start='00:00:07.000', end=1234, text='Hello Test!' - ) + ) def test_equality(self): caption1 = Caption( @@ -165,8 +293,8 @@ def test_equality(self): end='00:00:11.890', text='Hello test!', identifier='A test caption' - ) == 1234 - ) + ) == 1234 + ) def test_repr(self): caption = Caption( @@ -174,7 +302,7 @@ def test_repr(self): end='00:00:11.890', text='Hello test!', identifier='A test caption' - ) + ) self.assertEqual( repr(caption), @@ -188,7 +316,7 @@ def test_str(self): end='00:00:11.890', text='Hello test!', identifier='A test caption' - ) + ) self.assertEqual( str(caption), @@ -201,14 +329,14 @@ def test_accept_comments(self): end='00:00:11.890', text='Hello test!', identifier='A test caption' - ) + ) caption.comments.append('One comment') caption.comments.append('Another comment') self.assertListEqual( caption.comments, ['One comment', 'Another comment'] - ) + ) def test_timestamp_update(self): c = Caption('00:00:00.500', '00:00:07.000') @@ -233,7 +361,7 @@ def test_update_text(self): self.assertEqual( c.text, 'Caption line #1 updated' - ) + ) def test_update_text_multiline(self): c = Caption(text='Caption line #1') @@ -242,12 +370,12 @@ def test_update_text_multiline(self): self.assertEqual( len(c.lines), 2 - ) + ) self.assertEqual( c.text, 'Caption line #1\nCaption line #2' - ) + ) def test_update_text_wrong_type(self): c = Caption(text='Caption line #1') @@ -258,7 +386,7 @@ def test_update_text_wrong_type(self): c, 'text', 123 - ) + ) def test_manipulate_lines(self): c = Caption(text=['Caption line #1', 
'Caption line #2']) @@ -266,14 +394,14 @@ def test_manipulate_lines(self): self.assertEqual( c.lines[0], 'Caption line #1 updated' - ) + ) def test_malformed_start_timestamp(self): self.assertRaises( MalformedCaptionError, Caption, '01:00' - ) + ) class TestStyle(unittest.TestCase): @@ -302,11 +430,11 @@ def test_accept_comments(self): self.assertListEqual( style.comments, ['One comment', 'Another comment'] - ) + ) def test_get_text(self): style = Style(['::cue(b) {', ' color: peachpuff;', '}']) self.assertEqual( style.text, '::cue(b) {\n color: peachpuff;\n}' - ) + ) diff --git a/tests/test_sbv.py b/tests/test_sbv.py index 79d9dff..d78a9c3 100644 --- a/tests/test_sbv.py +++ b/tests/test_sbv.py @@ -1,6 +1,5 @@ import unittest import textwrap -from datetime import time from webvtt import sbv from webvtt.errors import MalformedFileError @@ -16,6 +15,12 @@ def test_is_valid(self): ''').strip().split('\n')) ) + self.assertTrue(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 0:0:0.500,0:0:7.000 + Caption #1 + ''').strip().split('\n')) + ) + self.assertTrue(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' 00:00:00.500,00:00:07.000 Caption #1 line 1 @@ -64,19 +69,35 @@ def test_from_lines(self): Caption #1 line 1 Caption #1 line 2 ''').strip().split('\n') - ) + ) self.assertEqual( cue_block.start, - time(hour=0, minute=0, second=0, microsecond=500000) - ) + '00:00:00.500' + ) self.assertEqual( cue_block.end, - time(hour=0, minute=0, second=7, microsecond=0) - ) + '00:00:07.000' + ) self.assertEqual( cue_block.payload, ['Caption #1 line 1', 'Caption #1 line 2'] - ) + ) + + def test_from_lines_shorter_timestamps(self): + cue_block = sbv.SBVCueBlock.from_lines(textwrap.dedent(''' + 0:1:2.500,0:1:03.800 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n') + ) + self.assertEqual( + cue_block.start, + '0:1:2.500' + ) + self.assertEqual( + cue_block.end, + '0:1:03.800' + ) class TestSBVModule(unittest.TestCase): @@ -93,7 +114,7 @@ def test_parse_invalid_format(self): 
00:00:07.000,00:00:11.890 Caption text #2 ''').strip().split('\n') - ) + ) def test_parse_captions(self): captions = sbv.parse( diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py index 3d290a9..f6ed5b4 100644 --- a/tests/test_segmenter.py +++ b/tests/test_segmenter.py @@ -62,7 +62,8 @@ def test_segmentation_with_defaults(self): fileSequence6.webvtt #EXT-X-ENDLIST ''' - ).lstrip()) + ).lstrip() + ) self.assertEqual( (output_path / 'fileSequence0.webvtt').read_text(), textwrap.dedent( @@ -76,7 +77,8 @@ def test_segmentation_with_defaults(self): 00:00:07.000 --> 00:00:11.890 Caption text #2 ''' - ).lstrip()) + ).lstrip() + ) self.assertEqual( (output_path / 'fileSequence1.webvtt').read_text(), textwrap.dedent( @@ -93,7 +95,8 @@ def test_segmentation_with_defaults(self): 00:00:16.320 --> 00:00:21.580 Caption text #4 ''' - ).lstrip()) + ).lstrip() + ) self.assertEqual( (output_path / 'fileSequence2.webvtt').read_text(), textwrap.dedent( @@ -113,7 +116,8 @@ def test_segmentation_with_defaults(self): 00:00:27.280 --> 00:00:30.280 Caption text #7 ''' - ).lstrip()) + ).lstrip() + ) self.assertEqual( (output_path / 'fileSequence3.webvtt').read_text(), textwrap.dedent( @@ -133,7 +137,8 @@ def test_segmentation_with_defaults(self): 00:00:38.870 --> 00:00:45.000 Caption text #10 ''' - ).lstrip()) + ).lstrip() + ) self.assertEqual( (output_path / 'fileSequence4.webvtt').read_text(), textwrap.dedent( @@ -150,7 +155,8 @@ def test_segmentation_with_defaults(self): 00:00:47.000 --> 00:00:50.970 Caption text #12 ''' - ).lstrip()) + ).lstrip() + ) self.assertEqual( (output_path / 'fileSequence5.webvtt').read_text(), textwrap.dedent( @@ -170,7 +176,8 @@ def test_segmentation_with_defaults(self): 00:00:58.600 --> 00:01:01.350 Caption text #15 ''' - ).lstrip()) + ).lstrip() + ) self.assertEqual( (output_path / 'fileSequence6.webvtt').read_text(), textwrap.dedent( @@ -184,7 +191,8 @@ def test_segmentation_with_defaults(self): 00:01:01.350 --> 00:01:04.300 Caption text #16 ''' 
- ).lstrip()) + ).lstrip() + ) def test_segmentation_with_custom_values(self): segmenter.segment( @@ -224,7 +232,8 @@ def test_segmentation_with_custom_values(self): fileSequence2.webvtt #EXT-X-ENDLIST ''' - ).lstrip()) + ).lstrip() + ) self.assertEqual( (output_path / 'fileSequence0.webvtt').read_text(), textwrap.dedent( @@ -253,7 +262,8 @@ def test_segmentation_with_custom_values(self): 00:00:27.280 --> 00:00:30.280 Caption text #7 ''' - ).lstrip()) + ).lstrip() + ) self.assertEqual( (output_path / 'fileSequence1.webvtt').read_text(), textwrap.dedent( @@ -288,7 +298,8 @@ def test_segmentation_with_custom_values(self): 00:00:58.600 --> 00:01:01.350 Caption text #15 ''' - ).lstrip()) + ).lstrip() + ) self.assertEqual( (output_path / 'fileSequence2.webvtt').read_text(), textwrap.dedent( @@ -302,7 +313,8 @@ def test_segmentation_with_custom_values(self): 00:01:01.350 --> 00:01:04.300 Caption text #16 ''' - ).lstrip()) + ).lstrip() + ) def test_segment_with_no_captions(self): segmenter.segment( @@ -326,5 +338,5 @@ def test_segment_with_no_captions(self): #EXT-X-PLAYLIST-TYPE:VOD #EXT-X-ENDLIST ''' - ).lstrip()) - + ).lstrip() + ) diff --git a/tests/test_srt.py b/tests/test_srt.py index f7f7577..e6a2dd4 100644 --- a/tests/test_srt.py +++ b/tests/test_srt.py @@ -1,7 +1,6 @@ import unittest import io import textwrap -from datetime import time from webvtt import srt from webvtt.errors import MalformedFileError @@ -72,11 +71,11 @@ def test_from_lines(self): self.assertEqual(cue_block.index, '1') self.assertEqual( cue_block.start, - time(hour=0, minute=0, second=0, microsecond=500000) + '00:00:00,500' ) self.assertEqual( cue_block.end, - time(hour=0, minute=0, second=7, microsecond=0) + '00:00:07,000' ) self.assertEqual( cue_block.payload, diff --git a/tests/test_vtt.py b/tests/test_vtt.py index 2abc944..214b71c 100644 --- a/tests/test_vtt.py +++ b/tests/test_vtt.py @@ -372,7 +372,7 @@ def test_parse_content(self): color: papayawhip; } ''').strip() - ) + ) self.assertEqual( 
str(styles[1].text), textwrap.dedent(''' @@ -380,7 +380,7 @@ def test_parse_content(self): color: white; } ''').strip() - ) + ) self.assertEqual( styles[0].comments, [] @@ -426,7 +426,7 @@ def test_write(self): styles = [ Style('::cue(b) {\n color: peachpuff;\n}'), Style('::cue {\n color: papayawhip;\n}') - ] + ] captions[0].comments.append('Comment for the first caption') captions[1].comments.append('Comment for the second caption') styles[1].comments.append( @@ -501,7 +501,7 @@ def test_to_str(self): styles = [ Style('::cue(b) {\n color: peachpuff;\n}'), Style('::cue {\n color: papayawhip;\n}') - ] + ] captions[0].comments.append('Comment for the first caption') captions[1].comments.append('Comment for the second caption') styles[1].comments.append( diff --git a/tests/test_webvtt.py b/tests/test_webvtt.py index 4a0dae6..3111e2f 100644 --- a/tests/test_webvtt.py +++ b/tests/test_webvtt.py @@ -5,7 +5,6 @@ import warnings import tempfile import pathlib -from datetime import time import webvtt from webvtt.models import Caption, Style @@ -33,7 +32,8 @@ def test_from_string(self): 00:00:16.320 --> 00:00:21.580 Caption text #4 - """).strip()) + """).strip() + ) self.assertEqual(len(vtt), 4) self.assertEqual( str(vtt[0]), @@ -114,7 +114,7 @@ def test_write_captions_in_unsupported_format(self): webvtt.WebVTT().write, io.StringIO(), format='ttt' - ) + ) def test_save_captions(self): with tempfile.NamedTemporaryFile('w', suffix='.vtt') as f: @@ -144,7 +144,7 @@ def test_save_captions(self): New caption text line1 New caption text line2 ''').strip() - ) + ) def test_srt_conversion(self): with tempfile.TemporaryDirectory() as td: @@ -272,7 +272,8 @@ def test_read_memory_buffer_carriage_return(self): \r 00:00:11.890 --> 00:00:16.320\r Caption text #3\r - ''')) + ''') + ) self.assertEqual( len(webvtt.from_buffer(buffer).captions), @@ -303,7 +304,7 @@ def test_save_no_filename(self): self.assertRaises( webvtt.errors.MissingFilenameError, webvtt.WebVTT().save - ) + ) def 
test_save_with_path_to_dir_no_filename(self): with tempfile.TemporaryDirectory() as td: @@ -311,14 +312,14 @@ def test_save_with_path_to_dir_no_filename(self): webvtt.errors.MissingFilenameError, webvtt.WebVTT().save, td - ) + ) def test_set_styles_from_text(self): style = Style('::cue(b) {\n color: peachpuff;\n}') self.assertListEqual( style.lines, ['::cue(b) {', ' color: peachpuff;', '}'] - ) + ) def test_save_identifiers(self): with tempfile.NamedTemporaryFile('w', suffix='.vtt') as f: @@ -398,8 +399,8 @@ def test_save_updated_identifiers(self): 'last caption', '00:00:27.280 --> 00:00:29.200', 'Caption text #7' - ] - ) + ] + ) def test_content_formatting(self): """ @@ -436,7 +437,7 @@ def test_repr(self): self.assertEqual( repr(webvtt.read(test_file)), f"" - ) + ) def test_str(self): self.assertEqual( @@ -459,33 +460,33 @@ def test_str(self): 00:00:58.600 00:01:01.350 Caption text #15 00:01:01.350 00:01:04.300 Caption text #16 """).strip() - ) + ) def test_parse_invalid_file(self): self.assertRaises( MalformedFileError, webvtt.read, PATH_TO_SAMPLES / 'invalid.vtt' - ) + ) def test_file_not_found(self): self.assertRaises( FileNotFoundError, webvtt.read, 'nowhere' - ) + ) def test_total_length(self): self.assertEqual( webvtt.read(PATH_TO_SAMPLES / 'sample.vtt').total_length, 64 - ) + ) def test_total_length_no_captions(self): self.assertEqual( webvtt.WebVTT().total_length, 0 - ) + ) def test_parse_empty_file(self): self.assertRaises( @@ -533,7 +534,7 @@ def test_parse_timestamp(self): self.assertEqual( Caption(start='02:03:11.890').start_in_seconds, 7391 - ) + ) def test_captions_attribute(self): self.assertListEqual(webvtt.WebVTT().captions, []) @@ -596,7 +597,7 @@ def test_parse_styles(self): vtt.styles[0].text, '::cue {\n background-image: linear-gradient(to bottom, ' 'dimgray, lightgray);\n color: papayawhip;\n}' - ) + ) def test_parse_styles_with_comments(self): vtt = webvtt.read(PATH_TO_SAMPLES / 'styles_with_comments.vtt') @@ -625,7 +626,7 @@ def 
test_parse_styles_with_comments(self): '::cue(b) {\n' ' color: peachpuff;\n' '}' - ) + ) self.assertListEqual( vtt.header_comments, ['Sample of comments with styles'] @@ -679,7 +680,8 @@ def test_multiple_comments_everywhere(self): NOTE Copyright 2024 NOTE this is the end of the file - """).strip()) + """).strip() + ) self.assertListEqual( vtt.header_comments, @@ -731,7 +733,7 @@ def test_comments_in_new_file(self): ) vtt.styles.append( Style('::cue(b) {\n color: peachpuff;\n}') - ) + ) style = Style('::cue {\n color: papayawhip;\n}') style.comments.append('Another style to test\nthe look and feel') style.comments.append('Please check') @@ -741,7 +743,7 @@ def test_comments_in_new_file(self): end='00:00:07.000', text='Caption #1', ) - ) + ) caption = Caption(start='00:00:07.000', end='00:00:11.890', text='Caption #2' @@ -809,11 +811,11 @@ def test_clean_cue_tags(self): self.assertEqual( vtt.captions[1].text, 'Like a big-a pizza pie' - ) + ) self.assertEqual( vtt.captions[2].text, 'That\'s amore' - ) + ) def test_parse_captions_with_bom(self): vtt = webvtt.read(PATH_TO_SAMPLES / 'captions_with_bom.vtt') @@ -825,14 +827,14 @@ def test_empty_lines_are_not_included_in_result(self): self.assertEqual( vtt.captions[-2].text, "Diez aƱos no son suficientes\npara olvidarte..." - ) + ) def test_can_parse_youtube_dl_files(self): vtt = webvtt.read(PATH_TO_SAMPLES / 'youtube_dl.vtt') self.assertEqual( "this will happen is I'm telling\n ", vtt.captions[2].text - ) + ) class TestParseSRT(unittest.TestCase): @@ -843,7 +845,7 @@ def test_parse_empty_file(self): webvtt.from_srt, # We reuse this file as it is empty and serves the purpose. 
PATH_TO_SAMPLES / 'empty.vtt' - ) + ) def test_invalid_format(self): for i in range(1, 5): @@ -851,13 +853,13 @@ def test_invalid_format(self): MalformedFileError, webvtt.from_srt, PATH_TO_SAMPLES / f'invalid_format{i}.srt' - ) + ) def test_total_length(self): self.assertEqual( webvtt.from_srt(PATH_TO_SAMPLES / 'sample.srt').total_length, 23 - ) + ) def test_parse_captions(self): self.assertTrue( @@ -915,26 +917,26 @@ def test_parse_empty_file(self): webvtt.from_sbv, # We reuse this file as it is empty and serves the purpose. PATH_TO_SAMPLES / 'empty.vtt' - ) + ) def test_invalid_format(self): self.assertRaises( MalformedFileError, webvtt.from_sbv, PATH_TO_SAMPLES / 'invalid_format.sbv' - ) + ) def test_total_length(self): self.assertEqual( webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv').total_length, 16 - ) + ) def test_parse_captions(self): self.assertEqual( len(webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv').captions), 5 - ) + ) def test_missing_timeframe_line(self): self.assertEqual( @@ -978,11 +980,11 @@ def test_get_caption_text_multiline(self): self.assertEqual( vtt.captions[2].text, 'Caption text #3 (line 1)\nCaption text #3 (line 2)' - ) + ) self.assertListEqual( vtt.captions[2].lines, ['Caption text #3 (line 1)', 'Caption text #3 (line 2)'] - ) + ) def test_convert_from_srt_to_vtt_and_back_gives_same_file(self): with tempfile.NamedTemporaryFile('w', suffix='.srt') as f: @@ -1032,7 +1034,7 @@ def test_save_file_with_bom_keeps_bom(self): 00:00:16.320 --> 00:00:21.580 Caption text #4 ''').strip() - ) + ) def test_save_file_with_bom_removes_bom_if_requested(self): with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f: @@ -1056,7 +1058,7 @@ def test_save_file_with_bom_removes_bom_if_requested(self): 00:00:16.320 --> 00:00:21.580 Caption text #4 ''').strip() - ) + ) def test_save_file_with_encoding(self): with tempfile.NamedTemporaryFile('rb', suffix='.vtt') as f: @@ -1073,7 +1075,7 @@ def test_save_file_with_encoding(self): 00:00:00.500 --> 00:00:07.000 
Caption text #1 ''').strip() - ) + ) def test_save_file_with_encoding_and_bom(self): with tempfile.NamedTemporaryFile('rb', suffix='.vtt') as f: @@ -1091,7 +1093,7 @@ def test_save_file_with_encoding_and_bom(self): 00:00:00.500 --> 00:00:07.000 Caption text #1 ''').strip() - ) + ) def test_save_new_file_utf_8_default_encoding_no_bom(self): with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f: @@ -1141,24 +1143,8 @@ def test_iter_slice(self): vtt = webvtt.read( PATH_TO_SAMPLES / 'sample.vtt' ) - slice_of_captions = vtt.iter_slice(start_time=time(second=11), - end_time=time(second=27) - ) - for expected_caption in (vtt.captions[2], - vtt.captions[3], - vtt.captions[4] - ): - self.assertIs(expected_caption, next(slice_of_captions)) - - with self.assertRaises(StopIteration): - next(slice_of_captions) - - def test_iter_slice_timestamp_string(self): - vtt = webvtt.read( - PATH_TO_SAMPLES / 'sample.vtt' - ) - slice_of_captions = vtt.iter_slice(start_time='00:00:11.000', - end_time='00:00:27.000' + slice_of_captions = vtt.iter_slice(start='00:00:11.000', + end='00:00:27.000' ) for expected_caption in (vtt.captions[2], vtt.captions[3], @@ -1173,7 +1159,7 @@ def test_iter_slice_no_start_time(self): vtt = webvtt.read( PATH_TO_SAMPLES / 'sample.vtt' ) - slice_of_captions = vtt.iter_slice(end_time=time(second=27)) + slice_of_captions = vtt.iter_slice(end='00:00:27.000') for expected_caption in (vtt.captions[0], vtt.captions[1], vtt.captions[2], @@ -1189,7 +1175,7 @@ def test_iter_slice_no_end_time(self): vtt = webvtt.read( PATH_TO_SAMPLES / 'sample.vtt' ) - slice_of_captions = vtt.iter_slice(start_time=time(second=47)) + slice_of_captions = vtt.iter_slice(start='00:00:47.000') for expected_caption in (vtt.captions[11], vtt.captions[12], vtt.captions[13], diff --git a/webvtt/models.py b/webvtt/models.py index bc62d56..31be12f 100644 --- a/webvtt/models.py +++ b/webvtt/models.py @@ -2,20 +2,108 @@ import re import typing -from datetime import time -from . 
import utils from .errors import MalformedCaptionError +class Timestamp: + """Representation of a timestamp.""" + + PATTERN = re.compile(r'(?:(\d{1,2}):)?(\d{1,2}):(\d{1,2})\.(\d{3})') + + def __init__( + self, + hours: int = 0, + minutes: int = 0, + seconds: int = 0, + milliseconds: int = 0 + ): + """Initialize.""" + self.hours = hours + self.minutes = minutes + self.seconds = seconds + self.milliseconds = milliseconds + + def __str__(self): + """Return the string representation of the timestamp.""" + return ( + f'{self.hours:02d}:{self.minutes:02d}:{self.seconds:02d}' + f'.{self.milliseconds:03d}' + ) + + def to_tuple(self) -> typing.Tuple[int, int, int, int]: + """Return the timestamp in tuple form.""" + return self.hours, self.minutes, self.seconds, self.milliseconds + + def __repr__(self): + """Return the string representation of the caption.""" + return (f'<{self.__class__.__name__} ' + f'hours={self.hours} ' + f'minutes={self.minutes} ' + f'seconds={self.seconds} ' + f'milliseconds={self.milliseconds}>' + ) + + def __eq__(self, other): + """Compare equality with other object.""" + return self.to_tuple() == other.to_tuple() + + def __ne__(self, other): + """Compare a not equality with other object.""" + return self.to_tuple() != other.to_tuple() + + def __gt__(self, other): + """Compare greater than with other object.""" + return self.to_tuple() > other.to_tuple() + + def __lt__(self, other): + """Compare less than with other object.""" + return self.to_tuple() < other.to_tuple() + + def __ge__(self, other): + """Compare greater or equal with other object.""" + return self.to_tuple() >= other.to_tuple() + + def __le__(self, other): + """Compare less or equal with other object.""" + return self.to_tuple() <= other.to_tuple() + + @classmethod + def from_string(cls, value: str) -> 'Timestamp': + """Return a `Timestamp` instance from a string value.""" + if type(value) is not str: + raise MalformedCaptionError(f'Invalid timestamp {value!r}') + + match = 
re.match(cls.PATTERN, value) + if not match: + raise MalformedCaptionError(f'Invalid timestamp {value!r}') + + hours = int(match.group(1) or 0) + minutes = int(match.group(2)) + seconds = int(match.group(3)) + milliseconds = int(match.group(4)) + + if minutes > 59 or seconds > 59: + raise MalformedCaptionError(f'Invalid timestamp {value!r}') + + return cls(hours, minutes, seconds, milliseconds) + + def in_seconds(self) -> int: + """Return the timestamp in seconds.""" + return (self.hours * 3600 + + self.minutes * 60 + + self.seconds + ) + + class Caption: """Representation of a caption.""" CUE_TEXT_TAGS = re.compile('<.*?>') def __init__(self, - start: typing.Optional[typing.Union[str, time]] = None, - end: typing.Optional[typing.Union[str, time]] = None, + start: typing.Optional[str] = None, + end: typing.Optional[str] = None, text: typing.Optional[typing.Union[str, typing.Sequence[str] ]] = None, @@ -30,8 +118,8 @@ def __init__(self, :param identifier: optional identifier """ text = text or [] - self.start = start or time() - self.end = end or time() + self.start = start or '00:00:00.000' + self.end = end or '00:00:00.000' self.identifier = identifier self.lines = (text.splitlines() if isinstance(text, str) @@ -69,32 +157,32 @@ def __eq__(self, other): @property def start(self): """Return the start time of the caption.""" - return self.format_timestamp(self.start_time) + return str(self.start_time) @start.setter - def start(self, value: typing.Union[str, time]): + def start(self, value: str): """Set the start time of the caption.""" - self.start_time = self.parse_timestamp(value) + self.start_time = Timestamp.from_string(value) @property def end(self): """Return the end time of the caption.""" - return self.format_timestamp(self.end_time) + return str(self.end_time) @end.setter - def end(self, value: typing.Union[str, time]): + def end(self, value: str): """Set the end time of the caption.""" - self.end_time = self.parse_timestamp(value) + self.end_time = 
Timestamp.from_string(value) @property def start_in_seconds(self) -> int: """Return the start time of the caption in seconds.""" - return self.time_in_seconds(self.start_time) + return self.start_time.in_seconds() @property def end_in_seconds(self): """Return the end time of the caption in seconds.""" - return self.time_in_seconds(self.end_time) + return self.end_time.in_seconds() @property def raw_text(self) -> str: @@ -116,29 +204,6 @@ def text(self, value: str): self.lines = value.splitlines() - @staticmethod - def parse_timestamp(value: typing.Union[str, time]) -> time: - """Parse the provided value as timestamp.""" - try: - return utils.parse_timestamp(value) - except ValueError: - raise MalformedCaptionError(f'Invalid timestamp: {value}') - - @staticmethod - def format_timestamp(time_obj: time) -> str: - """Format timestamp in string format.""" - microseconds = int(time_obj.microsecond / 1000) - return f'{time_obj.strftime("%H:%M:%S")}.{microseconds:03d}' - - @staticmethod - def time_in_seconds(time_obj: time) -> int: - """Return the time in seconds.""" - return (time_obj.hour * 3600 + - time_obj.minute * 60 + - time_obj.second + - time_obj.microsecond // 1_000_000 - ) - class Style: """Representation of a style.""" diff --git a/webvtt/sbv.py b/webvtt/sbv.py index 0d2cef2..be9c1bf 100644 --- a/webvtt/sbv.py +++ b/webvtt/sbv.py @@ -2,7 +2,6 @@ import typing import re -from datetime import datetime, time from . 
import utils from .models import Caption @@ -13,13 +12,13 @@ class SBVCueBlock: """Representation of a cue timing block.""" CUE_TIMINGS_PATTERN = re.compile( - r'\s*(\d+:\d{2}:\d{2}.\d{3}),(\d+:\d{2}:\d{2}.\d{3})' + r'\s*(\d{1,2}:\d{1,2}:\d{1,2}.\d{3}),(\d{1,2}:\d{1,2}:\d{1,2}.\d{3})' ) def __init__( self, - start: time, - end: time, + start: str, + end: str, payload: typing.Sequence[str] ): """ @@ -63,13 +62,10 @@ def from_lines( """ match = re.match(cls.CUE_TIMINGS_PATTERN, lines[0]) assert match is not None - start, end = map(lambda x: datetime.strptime(x, '%H:%M:%S.%f').time(), - (match.group(1), match.group(2)) - ) payload = lines[1:] - return cls(start, end, payload) + return cls(match.group(1), match.group(2), payload) def parse(lines: typing.Sequence[str]) -> typing.List[Caption]: diff --git a/webvtt/srt.py b/webvtt/srt.py index 3dc70c4..0e0637d 100644 --- a/webvtt/srt.py +++ b/webvtt/srt.py @@ -2,7 +2,6 @@ import typing import re -from datetime import datetime, time from .models import Caption from .errors import MalformedFileError @@ -19,8 +18,8 @@ class SRTCueBlock: def __init__( self, index: str, - start: time, - end: time, + start: str, + end: str, payload: typing.Sequence[str] ): """ @@ -67,13 +66,10 @@ def from_lines( match = re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) assert match is not None - start, end = map(lambda x: datetime.strptime(x, '%H:%M:%S,%f').time(), - (match.group(1), match.group(2)) - ) payload = lines[2:] - return cls(index, start, end, payload) + return cls(index, match.group(1), match.group(2), payload) def parse(lines: typing.Sequence[str]) -> typing.List[Caption]: @@ -118,6 +114,9 @@ def parse_captions(lines: typing.Sequence[str]) -> typing.List[Caption]: continue cue_block = SRTCueBlock.from_lines(block_lines) + cue_block.start, cue_block.end = map( + lambda x: x.replace(',', '.'), (cue_block.start, cue_block.end)) + captions.append(Caption(cue_block.start, cue_block.end, cue_block.payload diff --git a/webvtt/utils.py 
b/webvtt/utils.py index 8447628..28da8c8 100644 --- a/webvtt/utils.py +++ b/webvtt/utils.py @@ -2,7 +2,6 @@ import typing import codecs -from datetime import datetime, time CODEC_BOMS = { 'utf-8': codecs.BOM_UTF8, @@ -103,20 +102,3 @@ def iter_blocks_of_lines( if current_text_block: yield current_text_block - - -def parse_timestamp(value: typing.Union[str, time]) -> time: - """ - Return timestamp as time object if in string format. - - :param value: value to be parsed as timestamp - :raises ValueError: if the value cannot be parsed as timestamp - :raises TypeError: when the type of the value provided is not supported - """ - if isinstance(value, str): - time_format = '%H:%M:%S.%f' if len(value) >= 11 else '%M:%S.%f' - return datetime.strptime(value, time_format).time() - elif isinstance(value, time): - return value - - raise TypeError(f'The type {type(value)} is not supported') diff --git a/webvtt/webvtt.py b/webvtt/webvtt.py index 1b5110f..499b6df 100644 --- a/webvtt/webvtt.py +++ b/webvtt/webvtt.py @@ -3,12 +3,11 @@ import os import typing import warnings -from datetime import time from . import vtt, utils from . import srt from . import sbv -from .models import Caption, Style +from .models import Caption, Style, Timestamp from .errors import MissingFilenameError DEFAULT_ENCODING = 'utf-8' @@ -318,18 +317,18 @@ def write( def iter_slice( self, - start_time: typing.Optional[typing.Union[str, time]] = None, - end_time: typing.Optional[typing.Union[str, time]] = None + start: typing.Optional[str] = None, + end: typing.Optional[str] = None ) -> typing.Generator[Caption, None, None]: """ Iterate a slice of the captions based on a time range. 
- :param start_time: start time of the range - :param end_time: end time of the range + :param start: start timestamp of the range + :param end: end timestamp of the range :returns: generator of Captions """ - start_time = utils.parse_timestamp(start_time) if start_time else None - end_time = utils.parse_timestamp(end_time) if end_time else None + start_time = Timestamp.from_string(start) if start else None + end_time = Timestamp.from_string(end) if end else None for caption in self.captions: if ( From 6f67240fe016ec8e742ff84c236cb48cbf2714ff Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Tue, 14 May 2024 10:57:51 +0200 Subject: [PATCH 09/16] Update list of supported Python versions --- .travis.yml | 6 +++--- setup.cfg | 6 +++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 9854868..16002a5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,11 +1,11 @@ language: python python: + - 3.12 + - 3.11 + - 3.10 - 3.9 - 3.8 - 3.7 - - 3.6 - - 3.5 - - 3.4 install: pip install -U tox-travis diff --git a/setup.cfg b/setup.cfg index 4833f7c..4d79365 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,7 +6,11 @@ doctests = true radon-max-cc=10 [tox:tox] -envlist = codestyle,py,coverage +envlist = + codestyle + py{312, 311, 310, 39, 38, 37} + coverage +isolated_build = True [coverage:run] source = webvtt From 076b8cbe390c0d5399cf1fe4c0e00e2f6a8a7eea Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Tue, 14 May 2024 11:54:05 +0200 Subject: [PATCH 10/16] Add github workflow --- .github/workflows/ci.yml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..4efceb6 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,28 @@ +name: CI + +on: + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: 
Checkout Repository + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.12 + + - name: Install tox + run: pip install tox + + - name: Run tox + run: tox From 19b5f7dcb7f63138ff23b77acb4583b9ec300e35 Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Tue, 14 May 2024 11:59:50 +0200 Subject: [PATCH 11/16] Update list of python versions in workflow --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4efceb6..4c15ceb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: 3.12 + python-version: [3.7, 3.8, 3.9, 3.10, 3.11, 3.12] - name: Install tox run: pip install tox From 94c962acb0cfa085e0251da5cae2eba1b622f720 Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Tue, 14 May 2024 12:02:25 +0200 Subject: [PATCH 12/16] Fix --- .github/workflows/ci.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4c15ceb..68fb6b2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,14 +12,18 @@ jobs: build: runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.7, 3.8, 3.9, 3.10, 3.11, 3.12] + steps: - name: Checkout Repository uses: actions/checkout@v2 - - name: Set up Python + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: - python-version: [3.7, 3.8, 3.9, 3.10, 3.11, 3.12] + python-version: ${{ matrix.python-version }} - name: Install tox run: pip install tox From e89615c1beb5ebda90c5e6dbfbe2e0c930e68810 Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Tue, 14 May 2024 12:08:58 +0200 Subject: [PATCH 13/16] Test --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml 
b/.github/workflows/ci.yml index 68fb6b2..4a05610 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: - python-version: [3.7, 3.8, 3.9, 3.10, 3.11, 3.12] + python-version: [3.7, 3.8, 3.9, 3.11, 3.12] steps: - name: Checkout Repository From 030de5cecae496db7311b291cc59e90db52313fe Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Tue, 14 May 2024 12:12:31 +0200 Subject: [PATCH 14/16] Remove python versions from tox to be handled by github workflow --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 4d79365..12ccc6a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,7 @@ radon-max-cc=10 [tox:tox] envlist = codestyle - py{312, 311, 310, 39, 38, 37} + py coverage isolated_build = True From 5776eb0d5a377af5eed4aefc481e5729913aa892 Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Tue, 14 May 2024 12:16:02 +0200 Subject: [PATCH 15/16] Add python 3.10 --- .github/workflows/ci.yml | 2 +- .travis.yml | 12 ------------ 2 files changed, 1 insertion(+), 13 deletions(-) delete mode 100644 .travis.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4a05610..68fb6b2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: - python-version: [3.7, 3.8, 3.9, 3.11, 3.12] + python-version: [3.7, 3.8, 3.9, 3.10, 3.11, 3.12] steps: - name: Checkout Repository diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 16002a5..0000000 --- a/.travis.yml +++ /dev/null @@ -1,12 +0,0 @@ -language: python -python: - - 3.12 - - 3.11 - - 3.10 - - 3.9 - - 3.8 - - 3.7 - -install: pip install -U tox-travis - -script: tox From 8505d06c4d4ab709e8282db7613514532eedf685 Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Tue, 14 May 2024 12:18:46 +0200 Subject: [PATCH 16/16] Fix --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 68fb6b2..54e5841 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: - python-version: [3.7, 3.8, 3.9, 3.10, 3.11, 3.12] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] steps: - name: Checkout Repository