diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..54e5841 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,32 @@ +name: CI + +on: + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + build: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] + + steps: + - name: Checkout Repository + uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install tox + run: pip install tox + + - name: Run tox + run: tox diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 9854868..0000000 --- a/.travis.yml +++ /dev/null @@ -1,12 +0,0 @@ -language: python -python: - - 3.9 - - 3.8 - - 3.7 - - 3.6 - - 3.5 - - 3.4 - -install: pip install -U tox-travis - -script: tox diff --git a/LICENSE b/LICENSE index 3a702d9..16ad61a 100644 --- a/LICENSE +++ b/LICENSE @@ -1,3 +1,5 @@ +MIT License + Copyright (c) 2016 Alejandro Mendez Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/setup.cfg b/setup.cfg index 5aef279..12ccc6a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,43 @@ [metadata] -description-file = README.rst +description_file = README.rst + +[flake8] +doctests = true +radon-max-cc=10 + +[tox:tox] +envlist = + codestyle + py + coverage +isolated_build = True + +[coverage:run] +source = webvtt +branch = true + +[coverage:report] +fail_under = 100 + +[testenv] +deps = + coverage +description = run the tests and provide coverage metrics +commands = + coverage run -m unittest discover + +[testenv:codestyle] +deps = + flake8 + flake8-docstrings + radon + mypy +commands = + flake8 webvtt setup.py + mypy webvtt + +[testenv:coverage] +deps = coverage +commands = + coverage html --fail-under=0 + coverage report diff --git a/setup.py b/setup.py index 92384b3..6bae26d 100644 --- a/setup.py +++ b/setup.py @@ -1,43 +1,40 @@ -import io -import re -from setuptools import setup, find_packages +"""webvtt-py setuptools configuration.""" -with io.open('README.rst', 'r', encoding='utf-8') as f: - readme = f.read() +import re +import pathlib -with io.open('webvtt/__init__.py', 'rt', encoding='utf-8') as f: - version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1) +from setuptools import setup, find_packages +version = ( + re.search( + r'__version__ = \'(.*?)\'', + pathlib.Path('webvtt/__init__.py').read_text() + ).group(1) + ) setup( name='webvtt-py', version=version, description='WebVTT reader, writer and segmenter', - long_description=readme, + long_description=pathlib.Path('README.rst').read_text(), author='Alejandro Mendez', author_email='amendez23@gmail.com', url='https://github.com/glut23/webvtt-py', packages=find_packages('.', exclude=['tests']), include_package_data=True, - install_requires=[ - 'docopt' - ], entry_points={ 'console_scripts': [ 'webvtt=webvtt.cli:main' ] }, license='MIT', - python_requires='>=3.4', + python_requires='>=3.7', classifiers=[ 'Development Status :: 3 - Alpha', 'Intended Audience :: Developers', 'License :: OSI Approved :: MIT License', 'Programming Language :: Python', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', diff --git a/tests/generic.py b/tests/generic.py deleted file mode 100644 index 0f84ae3..0000000 --- a/tests/generic.py +++ /dev/null @@ -1,10 +0,0 @@ -import os -import unittest - - -class GenericParserTestCase(unittest.TestCase): - - SUBTITLES_DIR = os.path.join(os.path.dirname(__file__), 'subtitles') - - def _get_file(self, filename): - return os.path.join(self.SUBTITLES_DIR, filename) diff --git a/tests/subtitles/captions_with_bom.vtt b/tests/samples/captions_with_bom.vtt similarity index 100% rename from tests/subtitles/captions_with_bom.vtt rename to tests/samples/captions_with_bom.vtt diff --git a/tests/subtitles/comments.vtt b/tests/samples/comments.vtt similarity index 92% rename from tests/subtitles/comments.vtt rename to tests/samples/comments.vtt index 3847e02..8cd3032 100644 --- a/tests/subtitles/comments.vtt +++ b/tests/samples/comments.vtt @@ -18,4 +18,6 @@ NOTE This last line may not translate well. 3 00:02:25.000 --> 00:02:30.000 -- Ta en kopp \ No newline at end of file +- Ta en kopp + +NOTE end of file \ No newline at end of file diff --git a/tests/subtitles/cue_tags.vtt b/tests/samples/cue_tags.vtt similarity index 100% rename from tests/subtitles/cue_tags.vtt rename to tests/samples/cue_tags.vtt diff --git a/tests/subtitles/empty.vtt b/tests/samples/empty.vtt similarity index 100% rename from tests/subtitles/empty.vtt rename to tests/samples/empty.vtt diff --git a/tests/subtitles/invalid.vtt b/tests/samples/invalid.vtt similarity index 100% rename from tests/subtitles/invalid.vtt rename to tests/samples/invalid.vtt diff --git a/tests/subtitles/invalid_format.sbv b/tests/samples/invalid_format.sbv similarity index 100% rename from tests/subtitles/invalid_format.sbv rename to tests/samples/invalid_format.sbv diff --git a/tests/subtitles/invalid_format1.srt b/tests/samples/invalid_format1.srt similarity index 100% rename from tests/subtitles/invalid_format1.srt rename to tests/samples/invalid_format1.srt diff --git a/tests/subtitles/invalid_format2.srt b/tests/samples/invalid_format2.srt similarity index 100% rename from tests/subtitles/invalid_format2.srt rename to tests/samples/invalid_format2.srt diff --git a/tests/subtitles/invalid_format3.srt b/tests/samples/invalid_format3.srt similarity index 100% rename from tests/subtitles/invalid_format3.srt rename to tests/samples/invalid_format3.srt diff --git a/tests/subtitles/invalid_format4.srt b/tests/samples/invalid_format4.srt similarity index 100% rename from tests/subtitles/invalid_format4.srt rename to tests/samples/invalid_format4.srt diff --git a/tests/subtitles/invalid_timeframe.sbv b/tests/samples/invalid_timeframe.sbv similarity index 100% rename from tests/subtitles/invalid_timeframe.sbv rename to tests/samples/invalid_timeframe.sbv diff --git a/tests/subtitles/invalid_timeframe.srt b/tests/samples/invalid_timeframe.srt similarity index 100% rename from tests/subtitles/invalid_timeframe.srt rename to tests/samples/invalid_timeframe.srt diff --git a/tests/subtitles/invalid_timeframe.vtt b/tests/samples/invalid_timeframe.vtt similarity index 100% rename from tests/subtitles/invalid_timeframe.vtt rename to tests/samples/invalid_timeframe.vtt diff --git a/tests/subtitles/invalid_timeframe_in_cue_text.vtt b/tests/samples/invalid_timeframe_in_cue_text.vtt similarity index 100% rename from tests/subtitles/invalid_timeframe_in_cue_text.vtt rename to tests/samples/invalid_timeframe_in_cue_text.vtt diff --git a/tests/subtitles/metadata_headers.vtt b/tests/samples/metadata_headers.vtt similarity index 100% rename from tests/subtitles/metadata_headers.vtt rename to tests/samples/metadata_headers.vtt diff --git a/tests/subtitles/metadata_headers_multiline.vtt b/tests/samples/metadata_headers_multiline.vtt similarity index 100% rename from tests/subtitles/metadata_headers_multiline.vtt rename to tests/samples/metadata_headers_multiline.vtt diff --git a/tests/subtitles/missing_caption_text.sbv b/tests/samples/missing_caption_text.sbv similarity index 100% rename from tests/subtitles/missing_caption_text.sbv rename to tests/samples/missing_caption_text.sbv diff --git a/tests/subtitles/missing_caption_text.srt b/tests/samples/missing_caption_text.srt similarity index 100% rename from tests/subtitles/missing_caption_text.srt rename to tests/samples/missing_caption_text.srt diff --git a/tests/subtitles/missing_caption_text.vtt b/tests/samples/missing_caption_text.vtt similarity index 100% rename from tests/subtitles/missing_caption_text.vtt rename to tests/samples/missing_caption_text.vtt diff --git a/tests/subtitles/missing_timeframe.sbv b/tests/samples/missing_timeframe.sbv similarity index 100% rename from tests/subtitles/missing_timeframe.sbv rename to tests/samples/missing_timeframe.sbv diff --git a/tests/subtitles/missing_timeframe.srt b/tests/samples/missing_timeframe.srt similarity index 100% rename from tests/subtitles/missing_timeframe.srt rename to tests/samples/missing_timeframe.srt diff --git a/tests/subtitles/missing_timeframe.vtt b/tests/samples/missing_timeframe.vtt similarity index 100% rename from tests/subtitles/missing_timeframe.vtt rename to tests/samples/missing_timeframe.vtt diff --git a/tests/subtitles/netflix_chicas_del_cable.vtt b/tests/samples/netflix_chicas_del_cable.vtt similarity index 100% rename from tests/subtitles/netflix_chicas_del_cable.vtt rename to tests/samples/netflix_chicas_del_cable.vtt diff --git a/tests/subtitles/no_captions.vtt b/tests/samples/no_captions.vtt similarity index 100% rename from tests/subtitles/no_captions.vtt rename to tests/samples/no_captions.vtt diff --git a/tests/subtitles/one_caption.srt b/tests/samples/one_caption.srt similarity index 100% rename from tests/subtitles/one_caption.srt rename to tests/samples/one_caption.srt diff --git a/tests/subtitles/one_caption.vtt b/tests/samples/one_caption.vtt similarity index 100% rename from tests/subtitles/one_caption.vtt rename to tests/samples/one_caption.vtt diff --git a/tests/subtitles/sample.sbv b/tests/samples/sample.sbv similarity index 100% rename from tests/subtitles/sample.sbv rename to tests/samples/sample.sbv diff --git a/tests/subtitles/sample.srt b/tests/samples/sample.srt similarity index 100% rename from tests/subtitles/sample.srt rename to tests/samples/sample.srt diff --git a/tests/subtitles/sample.vtt b/tests/samples/sample.vtt similarity index 100% rename from tests/subtitles/sample.vtt rename to tests/samples/sample.vtt diff --git a/tests/subtitles/styles.vtt b/tests/samples/styles.vtt similarity index 100% rename from tests/subtitles/styles.vtt rename to tests/samples/styles.vtt diff --git a/tests/samples/styles_with_comments.vtt b/tests/samples/styles_with_comments.vtt new file mode 100644 index 0000000..4ae93ac --- /dev/null +++ b/tests/samples/styles_with_comments.vtt @@ -0,0 +1,23 @@ +WEBVTT + +NOTE Sample of comments with styles + +STYLE +::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; +} + +NOTE This is the second block of styles + +NOTE +Multiline comment for the same +second block of styles + +STYLE +::cue(b) { + color: peachpuff; +} + +00:00:00.000 --> 00:00:10.000 +- Hello world. \ No newline at end of file diff --git a/tests/subtitles/two_captions.sbv b/tests/samples/two_captions.sbv similarity index 100% rename from tests/subtitles/two_captions.sbv rename to tests/samples/two_captions.sbv diff --git a/tests/subtitles/using_identifiers.vtt b/tests/samples/using_identifiers.vtt similarity index 100% rename from tests/subtitles/using_identifiers.vtt rename to tests/samples/using_identifiers.vtt diff --git a/tests/subtitles/youtube_dl.vtt b/tests/samples/youtube_dl.vtt similarity index 100% rename from tests/subtitles/youtube_dl.vtt rename to tests/samples/youtube_dl.vtt diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..ca42cf7 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,182 @@ +import unittest +import tempfile +import os +import pathlib +import textwrap + +from webvtt.cli import main + + +class CLITestCase(unittest.TestCase): + + def test_cli(self): + vtt_file = ( + pathlib.Path(__file__).resolve().parent + / 'samples' / 'sample.vtt' + ) + + with tempfile.TemporaryDirectory() as temp_dir: + + main(['segment', str(vtt_file.resolve()), '-o', temp_dir]) + _, dirs, files = next(os.walk(temp_dir)) + + self.assertEqual(len(dirs), 0) + self.assertEqual(len(files), 8) + for expected_file in ('prog_index.m3u8', + 'fileSequence0.webvtt', + 'fileSequence1.webvtt', + 'fileSequence2.webvtt', + 'fileSequence3.webvtt', + 'fileSequence4.webvtt', + 'fileSequence5.webvtt', + 'fileSequence6.webvtt', + ): + self.assertIn(expected_file, files) + + self.assertEqual( + (pathlib.Path(temp_dir) / 'prog_index.m3u8').read_text(), + textwrap.dedent( + ''' + #EXTM3U + #EXT-X-TARGETDURATION:10 + #EXT-X-VERSION:3 + #EXT-X-PLAYLIST-TYPE:VOD + #EXTINF:30.00000 + fileSequence0.webvtt + #EXTINF:30.00000 + fileSequence1.webvtt + #EXTINF:30.00000 + fileSequence2.webvtt + #EXTINF:30.00000 + fileSequence3.webvtt + #EXTINF:30.00000 + fileSequence4.webvtt + #EXTINF:30.00000 + fileSequence5.webvtt + #EXTINF:30.00000 + fileSequence6.webvtt + #EXT-X-ENDLIST + ''' + ).lstrip()) + self.assertEqual( + (pathlib.Path(temp_dir) / 'fileSequence0.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + ''' + ).lstrip()) + self.assertEqual( + (pathlib.Path(temp_dir) / 'fileSequence1.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + + 00:00:11.890 --> 00:00:16.320 + Caption text #3 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + ''' + ).lstrip()) + self.assertEqual( + (pathlib.Path(temp_dir) / 'fileSequence2.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + + 00:00:21.580 --> 00:00:23.880 + Caption text #5 + + 00:00:23.880 --> 00:00:27.280 + Caption text #6 + + 00:00:27.280 --> 00:00:30.280 + Caption text #7 + ''' + ).lstrip()) + self.assertEqual( + (pathlib.Path(temp_dir) / 'fileSequence3.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:27.280 --> 00:00:30.280 + Caption text #7 + + 00:00:30.280 --> 00:00:36.510 + Caption text #8 + + 00:00:36.510 --> 00:00:38.870 + Caption text #9 + + 00:00:38.870 --> 00:00:45.000 + Caption text #10 + ''' + ).lstrip()) + self.assertEqual( + (pathlib.Path(temp_dir) / 'fileSequence4.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:38.870 --> 00:00:45.000 + Caption text #10 + + 00:00:45.000 --> 00:00:47.000 + Caption text #11 + + 00:00:47.000 --> 00:00:50.970 + Caption text #12 + ''' + ).lstrip()) + self.assertEqual( + (pathlib.Path(temp_dir) / 'fileSequence5.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:47.000 --> 00:00:50.970 + Caption text #12 + + 00:00:50.970 --> 00:00:54.440 + Caption text #13 + + 00:00:54.440 --> 00:00:58.600 + Caption text #14 + + 00:00:58.600 --> 00:01:01.350 + Caption text #15 + ''' + ).lstrip()) + self.assertEqual( + (pathlib.Path(temp_dir) / 'fileSequence6.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:58.600 --> 00:01:01.350 + Caption text #15 + + 00:01:01.350 --> 00:01:04.300 + Caption text #16 + ''' + ).lstrip()) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..2210e57 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,440 @@ +import unittest + +from webvtt.models import Timestamp, Caption, Style +from webvtt.errors import MalformedCaptionError + + +class TestTimestamp(unittest.TestCase): + + def test_instantiation(self): + timestamp = Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ) + self.assertEqual(timestamp.hours, 1) + self.assertEqual(timestamp.minutes, 12) + self.assertEqual(timestamp.seconds, 23) + self.assertEqual(timestamp.milliseconds, 500) + + def test_from_string(self): + timestamp = Timestamp.from_string('01:24:11.670') + self.assertEqual(timestamp.hours, 1) + self.assertEqual(timestamp.minutes, 24) + self.assertEqual(timestamp.seconds, 11) + self.assertEqual(timestamp.milliseconds, 670) + + def test_from_string_single_digits(self): + timestamp = Timestamp.from_string('1:2:7.670') + self.assertEqual(timestamp.hours, 1) + self.assertEqual(timestamp.minutes, 2) + self.assertEqual(timestamp.seconds, 7) + self.assertEqual(timestamp.milliseconds, 670) + + def test_from_string_missing_hours(self): + timestamp = Timestamp.from_string('24:11.670') + self.assertEqual(timestamp.hours, 0) + self.assertEqual(timestamp.minutes, 24) + self.assertEqual(timestamp.seconds, 11) + self.assertEqual(timestamp.milliseconds, 670) + + def test_from_string_wrong_minutes(self): + with self.assertRaises(MalformedCaptionError): + Timestamp.from_string('01:76:11.670') + + def test_from_string_wrong_seconds(self): + with self.assertRaises(MalformedCaptionError): + Timestamp.from_string('01:24:87.670') + + def test_from_string_wrong_type(self): + with self.assertRaises(MalformedCaptionError): + Timestamp.from_string(1234) + + def test_from_string_wrong_value(self): + with self.assertRaises(MalformedCaptionError): + Timestamp.from_string('01:24:11:670') + + def test_to_tuple(self): + self.assertEqual( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ).to_tuple(), + (1, 12, 23, 500) + ) + + def test_equality(self): + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ) == + Timestamp.from_string('01:12:23.500') + ) + + def test_not_equality(self): + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ) != + Timestamp.from_string('01:12:23.600') + ) + + def test_greater_than(self): + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=600 + ) > + Timestamp.from_string('01:12:23.500') + ) + + def test_less_than(self): + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ) < + Timestamp.from_string('01:12:23.600') + ) + + def test_greater_or_equal_than(self): + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ) >= + Timestamp.from_string('01:12:23.500') + ) + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=600 + ) >= + Timestamp.from_string('01:12:23.500') + ) + + def test_less_or_equal_than(self): + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ) <= + Timestamp.from_string('01:12:23.500') + ) + self.assertTrue( + Timestamp( + hours=1, minutes=12, seconds=23, milliseconds=500 + ) <= + Timestamp.from_string('01:12:23.600') + ) + + def test_repr(self): + timestamp = Timestamp( + hours=1, + minutes=12, + seconds=45, + milliseconds=320 + ) + + self.assertEqual( + repr(timestamp), + '' + ) + + +class TestCaption(unittest.TestCase): + + def test_instantiation(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + self.assertEqual(caption.start, '00:00:07.000') + self.assertEqual(caption.end, '00:00:11.890') + self.assertEqual(caption.text, 'Hello test!') + self.assertEqual(caption.identifier, 'A test caption') + + def test_timestamp_wrong_type(self): + with self.assertRaises(MalformedCaptionError): + Caption( + start=1234, + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + + def test_identifier_is_optional(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + ) + self.assertIsNone(caption.identifier) + + def test_multi_lines(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!\nThis is the second line', + identifier='A test caption' + ) + self.assertEqual(caption.text, 'Hello test!\nThis is the second line') + self.assertListEqual( + caption.lines, + ['Hello test!', 'This is the second line'] + ) + + def test_multi_lines_accepts_list(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text=['Hello test!', 'This is the second line'], + identifier='A test caption' + ) + self.assertEqual(caption.text, 'Hello test!\nThis is the second line') + self.assertListEqual( + caption.lines, + ['Hello test!', 'This is the second line'] + ) + + def test_cuetags(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text=[ + 'Hello test!', + 'This is the second line' + ], + identifier='A test caption' + ) + self.assertEqual(caption.text, 'Hello test!\nThis is the second line') + self.assertEqual( + caption.raw_text, + 'Hello test!\n' + 'This is the second line' + ) + + def test_in_seconds(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text=['Hello test!', 'This is the second line'], + identifier='A test caption' + ) + self.assertEqual(caption.start_in_seconds, 7) + self.assertEqual(caption.end_in_seconds, 11) + + def test_wrong_start_timestamp(self): + self.assertRaises( + MalformedCaptionError, + Caption, + start='1234', + end='00:00:11.890', + text='Hello Test!' + ) + + def test_wrong_type_start_timestamp(self): + self.assertRaises( + MalformedCaptionError, + Caption, + start=1234, + end='00:00:11.890', + text='Hello Test!' + ) + + def test_wrong_end_timestamp(self): + self.assertRaises( + MalformedCaptionError, + Caption, + start='00:00:07.000', + end='1234', + text='Hello Test!' + ) + + def test_wrong_type_end_timestamp(self): + self.assertRaises( + MalformedCaptionError, + Caption, + start='00:00:07.000', + end=1234, + text='Hello Test!' + ) + + def test_equality(self): + caption1 = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + + caption2 = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + + self.assertTrue(caption1 == caption2) + + caption1 = Caption( + start='00:00:02.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + + caption2 = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + + self.assertFalse(caption1 == caption2) + + self.assertFalse( + Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) == 1234 + ) + + def test_repr(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + + self.assertEqual( + repr(caption), + "" + ) + + def test_str(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + + self.assertEqual( + str(caption), + '00:00:07.000 00:00:11.890 Hello test!' + ) + + def test_accept_comments(self): + caption = Caption( + start='00:00:07.000', + end='00:00:11.890', + text='Hello test!', + identifier='A test caption' + ) + caption.comments.append('One comment') + caption.comments.append('Another comment') + + self.assertListEqual( + caption.comments, + ['One comment', 'Another comment'] + ) + + def test_timestamp_update(self): + c = Caption('00:00:00.500', '00:00:07.000') + c.start = '00:00:01.750' + c.end = '00:00:08.250' + + self.assertEqual(c.start, '00:00:01.750') + self.assertEqual(c.end, '00:00:08.250') + + def test_timestamp_format(self): + c = Caption('01:02:03.400', '02:03:04.500') + self.assertEqual(c.start, '01:02:03.400') + self.assertEqual(c.end, '02:03:04.500') + + c = Caption('02:03.400', '03:04.500') + self.assertEqual(c.start, '00:02:03.400') + self.assertEqual(c.end, '00:03:04.500') + + def test_update_text(self): + c = Caption(text='Caption line #1') + c.text = 'Caption line #1 updated' + self.assertEqual( + c.text, + 'Caption line #1 updated' + ) + + def test_update_text_multiline(self): + c = Caption(text='Caption line #1') + c.text = 'Caption line #1\nCaption line #2' + + self.assertEqual( + len(c.lines), + 2 + ) + + self.assertEqual( + c.text, + 'Caption line #1\nCaption line #2' + ) + + def test_update_text_wrong_type(self): + c = Caption(text='Caption line #1') + + self.assertRaises( + AttributeError, + setattr, + c, + 'text', + 123 + ) + + def test_manipulate_lines(self): + c = Caption(text=['Caption line #1', 'Caption line #2']) + c.lines[0] = 'Caption line #1 updated' + self.assertEqual( + c.lines[0], + 'Caption line #1 updated' + ) + + def test_malformed_start_timestamp(self): + self.assertRaises( + MalformedCaptionError, + Caption, + '01:00' + ) + + +class TestStyle(unittest.TestCase): + + def test_instantiation(self): + style = Style(text='::cue(b) {\ncolor: peachpuff;\n}') + self.assertEqual(style.text, '::cue(b) {\ncolor: peachpuff;\n}') + self.assertListEqual( + style.lines, + ['::cue(b) {', 'color: peachpuff;', '}'] + ) + + def test_text_accept_list_of_strings(self): + style = Style(text=['::cue(b) {', 'color: peachpuff;', '}']) + self.assertEqual(style.text, '::cue(b) {\ncolor: peachpuff;\n}') + self.assertListEqual( + style.lines, + ['::cue(b) {', 'color: peachpuff;', '}'] + ) + + def test_accept_comments(self): + style = Style(text='::cue(b) {\ncolor: peachpuff;\n}') + style.comments.append('One comment') + style.comments.append('Another comment') + + self.assertListEqual( + style.comments, + ['One comment', 'Another comment'] + ) + + def test_get_text(self): + style = Style(['::cue(b) {', ' color: peachpuff;', '}']) + self.assertEqual( + style.text, + '::cue(b) {\n color: peachpuff;\n}' + ) diff --git a/tests/test_sbv.py b/tests/test_sbv.py new file mode 100644 index 0000000..d78a9c3 --- /dev/null +++ b/tests/test_sbv.py @@ -0,0 +1,141 @@ +import unittest +import textwrap + +from webvtt import sbv +from webvtt.errors import MalformedFileError +from webvtt.models import Caption + + +class TestSBVCueBlock(unittest.TestCase): + + def test_is_valid(self): + self.assertTrue(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500,00:00:07.000 + Caption #1 + ''').strip().split('\n')) + ) + + self.assertTrue(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 0:0:0.500,0:0:7.000 + Caption #1 + ''').strip().split('\n')) + ) + + self.assertTrue(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500,00:00:07.000 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n')) + ) + + self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500 --> 00:00:07.000 + Caption #1 + ''').strip().split('\n')) + ) + + self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 1 + 00:00:00.500,00:00:07.000 + Caption #1 + ''').strip().split('\n')) + ) + + self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 1 + 00:00:00.500,00:00:07.000 + ''').strip().split('\n')) + ) + + self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 1 + Caption #1 + ''').strip().split('\n')) + ) + + self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + Caption #1 + ''').strip().split('\n')) + ) + + self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500,00:00:07.000 + ''').strip().split('\n')) + ) + + def test_from_lines(self): + cue_block = sbv.SBVCueBlock.from_lines(textwrap.dedent(''' + 00:00:00.500,00:00:07.000 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n') + ) + self.assertEqual( + cue_block.start, + '00:00:00.500' + ) + self.assertEqual( + cue_block.end, + '00:00:07.000' + ) + self.assertEqual( + cue_block.payload, + ['Caption #1 line 1', 'Caption #1 line 2'] + ) + + def test_from_lines_shorter_timestamps(self): + cue_block = sbv.SBVCueBlock.from_lines(textwrap.dedent(''' + 0:1:2.500,0:1:03.800 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n') + ) + self.assertEqual( + cue_block.start, + '0:1:2.500' + ) + self.assertEqual( + cue_block.end, + '0:1:03.800' + ) + + +class TestSBVModule(unittest.TestCase): + + def test_parse_invalid_format(self): + self.assertRaises( + MalformedFileError, + sbv.parse, + textwrap.dedent(''' + 1 + 00:00:00.500,00:00:07.000 + Caption text #1 + + 00:00:07.000,00:00:11.890 + Caption text #2 + ''').strip().split('\n') + ) + + def test_parse_captions(self): + captions = sbv.parse( + textwrap.dedent(''' + 00:00:00.500,00:00:07.000 + Caption #1 + + 00:00:07.000,00:00:11.890 + Caption #2 line 1 + Caption #2 line 2 + ''').strip().split('\n') + ) + self.assertEqual(len(captions), 2) + self.assertIsInstance(captions[0], Caption) + self.assertIsInstance(captions[1], Caption) + self.assertEqual( + str(captions[0]), + '00:00:00.500 00:00:07.000 Caption #1' + ) + self.assertEqual( + str(captions[1]), + r'00:00:07.000 00:00:11.890 Caption #2 line 1\n' + 'Caption #2 line 2' + ) diff --git a/tests/test_sbv_parser.py b/tests/test_sbv_parser.py deleted file mode 100644 index 0a6f899..0000000 --- a/tests/test_sbv_parser.py +++ /dev/null @@ -1,74 +0,0 @@ -import webvtt - -from .generic import GenericParserTestCase - - -class SBVParserTestCase(GenericParserTestCase): - - def test_sbv_parse_empty_file(self): - self.assertRaises( - webvtt.errors.MalformedFileError, - webvtt.from_sbv, - self._get_file('empty.vtt') # We reuse this file as it is empty and serves the purpose. - ) - - def test_sbv_invalid_format(self): - self.assertRaises( - webvtt.errors.MalformedFileError, - webvtt.from_sbv, - self._get_file('invalid_format.sbv') - ) - - def test_sbv_total_length(self): - self.assertEqual( - webvtt.from_sbv(self._get_file('sample.sbv')).total_length, - 16 - ) - - def test_sbv_parse_captions(self): - self.assertEqual( - len(webvtt.from_srt(self._get_file('sample.srt')).captions), - 5 - ) - - def test_sbv_missing_timeframe_line(self): - self.assertRaises( - webvtt.errors.MalformedCaptionError, - webvtt.from_sbv, - self._get_file('missing_timeframe.sbv') - ) - - def test_sbv_missing_caption_text(self): - self.assertTrue(webvtt.from_sbv(self._get_file('missing_caption_text.sbv')).captions) - - def test_sbv_invalid_timestamp(self): - self.assertRaises( - webvtt.errors.MalformedCaptionError, - webvtt.from_sbv, - self._get_file('invalid_timeframe.sbv') - ) - - def test_sbv_timestamps_format(self): - vtt = webvtt.from_sbv(self._get_file('sample.sbv')) - self.assertEqual(vtt.captions[1].start, '00:00:11.378') - self.assertEqual(vtt.captions[1].end, '00:00:12.305') - - def test_sbv_timestamps_in_seconds(self): - vtt = webvtt.from_sbv(self._get_file('sample.sbv')) - self.assertEqual(vtt.captions[1].start_in_seconds, 11.378) - self.assertEqual(vtt.captions[1].end_in_seconds, 12.305) - - def test_sbv_get_caption_text(self): - vtt = webvtt.from_sbv(self._get_file('sample.sbv')) - self.assertEqual(vtt.captions[1].text, 'Caption text #2') - - def test_sbv_get_caption_text_multiline(self): - vtt = webvtt.from_sbv(self._get_file('sample.sbv')) - self.assertEqual( - vtt.captions[2].text, - 'Caption text #3 (line 1)\nCaption text #3 (line 2)' - ) - self.assertListEqual( - vtt.captions[2].lines, - ['Caption text #3 (line 1)', 'Caption text #3 (line 2)'] - ) diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py index d82a520..f6ed5b4 100644 --- a/tests/test_segmenter.py +++ b/tests/test_segmenter.py @@ -1,180 +1,342 @@ import os import unittest -from shutil import rmtree +import tempfile +import pathlib +import textwrap -from webvtt import WebVTTSegmenter, Caption -from webvtt.errors import InvalidCaptionsError -from webvtt import WebVTT +from webvtt import segmenter -BASE_DIR = os.path.dirname(__file__) -SUBTITLES_DIR = os.path.join(BASE_DIR, 'subtitles') -OUTPUT_DIR = os.path.join(BASE_DIR, 'output') +PATH_TO_SAMPLES = pathlib.Path(__file__).resolve().parent / 'samples' -class WebVTTSegmenterTestCase(unittest.TestCase): +class TestSegmenter(unittest.TestCase): def setUp(self): - self.segmenter = WebVTTSegmenter() + self.temp_dir = tempfile.TemporaryDirectory() def tearDown(self): - if os.path.exists(OUTPUT_DIR): - rmtree(OUTPUT_DIR) - - def _parse_captions(self, filename): - self.webvtt = WebVTT().read(os.path.join(SUBTITLES_DIR, filename)) - - def test_invalid_captions(self): - self.assertRaises( - FileNotFoundError, - self.segmenter.segment, - 'text' - ) - - self.assertRaises( - InvalidCaptionsError, - self.segmenter.segment, - 10 - ) - - def test_single_invalid_caption(self): - self.assertRaises( - InvalidCaptionsError, - self.segmenter.segment, - [Caption(), Caption(), 'text', Caption()] - ) - - def test_total_segments(self): - # segment with default 10 seconds - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR) - self.assertEqual(self.segmenter.total_segments, 7) - - # segment with custom 30 seconds - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR, 30) - self.assertEqual(self.segmenter.total_segments, 3) - - def test_output_folder_is_created(self): - self.assertFalse(os.path.exists(OUTPUT_DIR)) - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR) - self.assertTrue(os.path.exists(OUTPUT_DIR)) - - def test_segmentation_files_exist(self): - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR) - for i in range(7): - self.assertTrue( - os.path.exists(os.path.join(OUTPUT_DIR, 'fileSequence{}.webvtt'.format(i))) + self.temp_dir.cleanup() + + def test_segmentation_with_defaults(self): + segmenter.segment(PATH_TO_SAMPLES / 'sample.vtt', self.temp_dir.name) + + _, dirs, files = next(os.walk(self.temp_dir.name)) + + self.assertEqual(len(dirs), 0) + self.assertEqual(len(files), 8) + + for expected_file in ('prog_index.m3u8', + 'fileSequence0.webvtt', + 'fileSequence1.webvtt', + 'fileSequence2.webvtt', + 'fileSequence3.webvtt', + 'fileSequence4.webvtt', + 'fileSequence5.webvtt', + 'fileSequence6.webvtt', + ): + self.assertIn(expected_file, files) + + output_path = pathlib.Path(self.temp_dir.name) + + self.assertEqual( + (output_path / 'prog_index.m3u8').read_text(), + textwrap.dedent( + ''' + #EXTM3U + #EXT-X-TARGETDURATION:10 + #EXT-X-VERSION:3 + #EXT-X-PLAYLIST-TYPE:VOD + #EXTINF:30.00000 + fileSequence0.webvtt + #EXTINF:30.00000 + fileSequence1.webvtt + #EXTINF:30.00000 + fileSequence2.webvtt + #EXTINF:30.00000 + fileSequence3.webvtt + #EXTINF:30.00000 + fileSequence4.webvtt + #EXTINF:30.00000 + fileSequence5.webvtt + #EXTINF:30.00000 + fileSequence6.webvtt + #EXT-X-ENDLIST + ''' + ).lstrip() + ) + self.assertEqual( + (output_path / 'fileSequence0.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + ''' + ).lstrip() + ) + self.assertEqual( + (output_path / 'fileSequence1.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + + 00:00:11.890 --> 00:00:16.320 + Caption text #3 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + ''' + ).lstrip() + ) + self.assertEqual( + (output_path / 'fileSequence2.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + + 00:00:21.580 --> 00:00:23.880 + Caption text #5 + + 00:00:23.880 --> 00:00:27.280 + Caption text #6 + + 00:00:27.280 --> 00:00:30.280 + Caption text #7 + ''' + ).lstrip() + ) + self.assertEqual( + (output_path / 'fileSequence3.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:27.280 --> 00:00:30.280 + Caption text #7 + + 00:00:30.280 --> 00:00:36.510 + Caption text #8 + + 00:00:36.510 --> 00:00:38.870 + Caption text #9 + + 00:00:38.870 --> 00:00:45.000 + Caption text #10 + ''' + ).lstrip() + ) + self.assertEqual( + (output_path / 'fileSequence4.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:38.870 --> 00:00:45.000 + Caption text #10 + + 00:00:45.000 --> 00:00:47.000 + Caption text #11 + + 00:00:47.000 --> 00:00:50.970 + Caption text #12 + ''' + ).lstrip() + ) + self.assertEqual( + (output_path / 'fileSequence5.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:47.000 --> 00:00:50.970 + Caption text #12 + + 00:00:50.970 --> 00:00:54.440 + Caption text #13 + + 00:00:54.440 --> 00:00:58.600 + Caption text #14 + + 00:00:58.600 --> 00:01:01.350 + Caption text #15 + ''' + ).lstrip() + ) + self.assertEqual( + (output_path / 'fileSequence6.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000 + + 00:00:58.600 --> 00:01:01.350 + Caption text #15 + + 00:01:01.350 --> 00:01:04.300 + Caption text #16 + ''' + ).lstrip() + ) + + def test_segmentation_with_custom_values(self): + segmenter.segment( + webvtt_path=PATH_TO_SAMPLES / 'sample.vtt', + output=self.temp_dir.name, + seconds=30, + mpegts=800000 + ) + + _, dirs, files = next(os.walk(self.temp_dir.name)) + + self.assertEqual(len(dirs), 0) + self.assertEqual(len(files), 4) + + for expected_file in ('prog_index.m3u8', + 'fileSequence0.webvtt', + 'fileSequence1.webvtt', + 'fileSequence2.webvtt', + ): + self.assertIn(expected_file, files) + + output_path = pathlib.Path(self.temp_dir.name) + + self.assertEqual( + (output_path / 'prog_index.m3u8').read_text(), + textwrap.dedent( + ''' + #EXTM3U + #EXT-X-TARGETDURATION:30 + #EXT-X-VERSION:3 + #EXT-X-PLAYLIST-TYPE:VOD + #EXTINF:30.00000 + fileSequence0.webvtt + #EXTINF:30.00000 + fileSequence1.webvtt + #EXTINF:30.00000 + fileSequence2.webvtt + #EXT-X-ENDLIST + ''' + ).lstrip() + ) + self.assertEqual( + (output_path / 'fileSequence0.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:800000,LOCAL:00:00:00.000 + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + + 00:00:11.890 --> 00:00:16.320 + Caption text #3 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + + 00:00:21.580 --> 00:00:23.880 + Caption text #5 + + 00:00:23.880 --> 00:00:27.280 + Caption text #6 + + 00:00:27.280 --> 00:00:30.280 + Caption text #7 + ''' + ).lstrip() + ) + self.assertEqual( + (output_path / 'fileSequence1.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:800000,LOCAL:00:00:00.000 + + 00:00:27.280 --> 00:00:30.280 + Caption text #7 + + 00:00:30.280 --> 00:00:36.510 + Caption text #8 + + 00:00:36.510 --> 00:00:38.870 + Caption text #9 + + 00:00:38.870 --> 00:00:45.000 + Caption text #10 + + 00:00:45.000 --> 00:00:47.000 + Caption text #11 + + 00:00:47.000 --> 00:00:50.970 + Caption text #12 + + 00:00:50.970 --> 00:00:54.440 + Caption text #13 + + 00:00:54.440 --> 00:00:58.600 + Caption text #14 + + 00:00:58.600 --> 00:01:01.350 + Caption text #15 + ''' + ).lstrip() + ) + self.assertEqual( + (output_path / 'fileSequence2.webvtt').read_text(), + textwrap.dedent( + ''' + WEBVTT + X-TIMESTAMP-MAP=MPEGTS:800000,LOCAL:00:00:00.000 + + 00:00:58.600 --> 00:01:01.350 + Caption text #15 + + 00:01:01.350 --> 00:01:04.300 + Caption text #16 + ''' + ).lstrip() ) - self.assertTrue(os.path.exists(os.path.join(OUTPUT_DIR, 'prog_index.m3u8'))) - - def test_segmentation(self): - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR) - - # segment 1 should have caption 1 and 2 - self.assertEqual(len(self.segmenter.segments[0]), 2) - self.assertIn(self.webvtt.captions[0], self.segmenter.segments[0]) - self.assertIn(self.webvtt.captions[1], self.segmenter.segments[0]) - # segment 2 should have caption 2 again (overlap), 3 and 4 - self.assertEqual(len(self.segmenter.segments[1]), 3) - self.assertIn(self.webvtt.captions[2], self.segmenter.segments[1]) - self.assertIn(self.webvtt.captions[3], self.segmenter.segments[1]) - # segment 3 should have caption 4 again (overlap), 5, 6 and 7 - self.assertEqual(len(self.segmenter.segments[2]), 4) - self.assertIn(self.webvtt.captions[3], self.segmenter.segments[2]) - self.assertIn(self.webvtt.captions[4], self.segmenter.segments[2]) - self.assertIn(self.webvtt.captions[5], self.segmenter.segments[2]) - self.assertIn(self.webvtt.captions[6], self.segmenter.segments[2]) - # segment 4 should have caption 7 again (overlap), 8, 9 and 10 - self.assertEqual(len(self.segmenter.segments[3]), 4) - self.assertIn(self.webvtt.captions[6], self.segmenter.segments[3]) - self.assertIn(self.webvtt.captions[7], self.segmenter.segments[3]) - self.assertIn(self.webvtt.captions[8], self.segmenter.segments[3]) - self.assertIn(self.webvtt.captions[9], self.segmenter.segments[3]) - # segment 5 should have caption 10 again (overlap), 11 and 12 - self.assertEqual(len(self.segmenter.segments[4]), 3) - self.assertIn(self.webvtt.captions[9], self.segmenter.segments[4]) - self.assertIn(self.webvtt.captions[10], self.segmenter.segments[4]) - self.assertIn(self.webvtt.captions[11], self.segmenter.segments[4]) - # segment 6 should have caption 12 again (overlap), 13, 14 and 15 - self.assertEqual(len(self.segmenter.segments[5]), 4) - self.assertIn(self.webvtt.captions[11], self.segmenter.segments[5]) - self.assertIn(self.webvtt.captions[12], self.segmenter.segments[5]) - self.assertIn(self.webvtt.captions[13], self.segmenter.segments[5]) - self.assertIn(self.webvtt.captions[14], self.segmenter.segments[5]) - # segment 7 should have caption 15 again (overlap) and 16 - self.assertEqual(len(self.segmenter.segments[6]), 2) - self.assertIn(self.webvtt.captions[14], self.segmenter.segments[6]) - self.assertIn(self.webvtt.captions[15], self.segmenter.segments[6]) - - def test_segment_content(self): - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR, 10) - - with open(os.path.join(OUTPUT_DIR, 'fileSequence0.webvtt'), 'r', encoding='utf-8') as f: - lines = [line.rstrip() for line in f.readlines()] - - expected_lines = [ - 'WEBVTT', - 'X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000', - '', - '00:00:00.500 --> 00:00:07.000', - 'Caption text #1', - '', - '00:00:07.000 --> 00:00:11.890', - 'Caption text #2' - ] - - self.assertListEqual(lines, expected_lines) - - def test_manifest_content(self): - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR, 10) - - with open(os.path.join(OUTPUT_DIR, 'prog_index.m3u8'), 'r', encoding='utf-8') as f: - lines = [line.rstrip() for line in f.readlines()] - - expected_lines = [ - '#EXTM3U', - '#EXT-X-TARGETDURATION:{}'.format(self.segmenter.seconds), - '#EXT-X-VERSION:3', - '#EXT-X-PLAYLIST-TYPE:VOD', - ] - - for i in range(7): - expected_lines.extend([ - '#EXTINF:30.00000', - 'fileSequence{}.webvtt'.format(i) - ]) - - expected_lines.append('#EXT-X-ENDLIST') - - for index, line in enumerate(expected_lines): - self.assertEqual(lines[index], line) - - def test_customize_mpegts(self): - self._parse_captions('sample.vtt') - self.segmenter.segment(self.webvtt, OUTPUT_DIR, mpegts=800000) - - with open(os.path.join(OUTPUT_DIR, 'fileSequence0.webvtt'), 'r', encoding='utf-8') as f: - lines = f.readlines() - self.assertIn('MPEGTS:800000', lines[1]) - - def test_segment_from_file(self): - self.segmenter.segment(os.path.join(SUBTITLES_DIR, 'sample.vtt'), OUTPUT_DIR), - self.assertEqual(self.segmenter.total_segments, 7) def test_segment_with_no_captions(self): - self.segmenter.segment(os.path.join(SUBTITLES_DIR, 'no_captions.vtt'), OUTPUT_DIR), - self.assertEqual(self.segmenter.total_segments, 0) - - def test_total_segments_readonly(self): - self.assertRaises( - AttributeError, - setattr, - WebVTTSegmenter(), - 'total_segments', - 5 - ) + segmenter.segment( + webvtt_path=PATH_TO_SAMPLES / 'no_captions.vtt', + output=self.temp_dir.name + ) + + _, dirs, files = next(os.walk(self.temp_dir.name)) + + self.assertEqual(len(dirs), 0) + self.assertEqual(len(files), 1) + self.assertIn('prog_index.m3u8', files) + + self.assertEqual( + (pathlib.Path(self.temp_dir.name) / 'prog_index.m3u8').read_text(), + textwrap.dedent( + ''' + #EXTM3U + #EXT-X-TARGETDURATION:10 + #EXT-X-VERSION:3 + #EXT-X-PLAYLIST-TYPE:VOD + #EXT-X-ENDLIST + ''' + ).lstrip() + ) diff --git a/tests/test_srt.py b/tests/test_srt.py index eed186d..e6a2dd4 100644 --- a/tests/test_srt.py +++ b/tests/test_srt.py @@ -1,35 +1,160 @@ -import os import unittest -from shutil import rmtree, copy +import io +import textwrap -import webvtt +from webvtt import srt +from webvtt.errors import MalformedFileError +from webvtt.models import Caption -from .generic import GenericParserTestCase +class TestSRTCueBlock(unittest.TestCase): -BASE_DIR = os.path.dirname(__file__) -OUTPUT_DIR = os.path.join(BASE_DIR, 'output') + def test_is_valid(self): + self.assertTrue(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + 1 + 00:00:00,500 --> 00:00:07,000 + Caption #1 + ''').strip().split('\n')) + ) + self.assertTrue(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + 1 + 00:00:00,500 --> 00:00:07,000 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n')) + ) -class SRTCaptionsTestCase(GenericParserTestCase): + self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + 00:00:00,500 --> 00:00:07,000 + Caption #1 + ''').strip().split('\n')) + ) - def setUp(self): - os.makedirs(OUTPUT_DIR) + self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + 1 + 00:00:00.500 --> 00:00:07.000 + Caption #1 + ''').strip().split('\n')) + ) - def tearDown(self): - if os.path.exists(OUTPUT_DIR): - rmtree(OUTPUT_DIR) + self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + 1 + 00:00:00,500 --> 00:00:07,000 + ''').strip().split('\n')) + ) - def test_convert_from_srt_to_vtt_and_back_gives_same_file(self): - copy(self._get_file('sample.srt'), OUTPUT_DIR) + self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + 1 + Caption #1 + ''').strip().split('\n')) + ) - vtt = webvtt.from_srt(os.path.join(OUTPUT_DIR, 'sample.srt')) - vtt.save_as_srt(os.path.join(OUTPUT_DIR, 'sample_converted.srt')) + self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + Caption #1 + ''').strip().split('\n')) + ) - with open(os.path.join(OUTPUT_DIR, 'sample.srt'), 'r', encoding='utf-8') as f: - original = f.read() + self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent(''' + 00:00:00,500 --> 00:00:07,000 + ''').strip().split('\n')) + ) - with open(os.path.join(OUTPUT_DIR, 'sample_converted.srt'), 'r', encoding='utf-8') as f: - converted = f.read() + def test_from_lines(self): + cue_block = srt.SRTCueBlock.from_lines(textwrap.dedent(''' + 1 + 00:00:00,500 --> 00:00:07,000 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n') + ) + self.assertEqual(cue_block.index, '1') + self.assertEqual( + cue_block.start, + '00:00:00,500' + ) + self.assertEqual( + cue_block.end, + '00:00:07,000' + ) + self.assertEqual( + cue_block.payload, + ['Caption #1 line 1', 'Caption #1 line 2'] + ) - self.assertEqual(original.strip(), converted.strip()) + +class TestSRTModule(unittest.TestCase): + + def test_parse_invalid_format(self): + self.assertRaises( + MalformedFileError, + srt.parse, + textwrap.dedent(''' + 00:00:00,500 --> 00:00:07,000 + Caption text #1 + + 00:00:07,000 --> 00:00:11,890 + Caption text #2 + ''').strip().split('\n') + ) + + def test_parse_captions(self): + captions = srt.parse( + textwrap.dedent(''' + 1 + 00:00:00,500 --> 00:00:07,000 + Caption #1 + + 2 + 00:00:07,000 --> 00:00:11,890 + Caption #2 line 1 + Caption #2 line 2 + ''').strip().split('\n') + ) + self.assertEqual(len(captions), 2) + self.assertIsInstance(captions[0], Caption) + self.assertIsInstance(captions[1], Caption) + self.assertEqual( + str(captions[0]), + '00:00:00.500 00:00:07.000 Caption #1' + ) + self.assertEqual( + str(captions[1]), + r'00:00:07.000 00:00:11.890 Caption #2 line 1\n' + 'Caption #2 line 2' + ) + + def test_write(self): + out = io.StringIO() + captions = [ + Caption(start='00:00:00.500', + end='00:00:07.000', + text='Caption #1' + ), + Caption(start='00:00:07.000', + end='00:00:11.890', + text=['Caption #2 line 1', + 'Caption #2 line 2' + ] + ) + ] + captions[0].comments.append('Comment for the first caption') + captions[1].comments.append('Comment for the second caption') + + srt.write(out, captions) + + out.seek(0) + + self.assertEqual( + out.read(), + textwrap.dedent(''' + 1 + 00:00:00,500 --> 00:00:07,000 + Caption #1 + + 2 + 00:00:07,000 --> 00:00:11,890 + Caption #2 line 1 + Caption #2 line 2 + ''').strip() + ) diff --git a/tests/test_srt_parser.py b/tests/test_srt_parser.py deleted file mode 100644 index 7885ec5..0000000 --- a/tests/test_srt_parser.py +++ /dev/null @@ -1,65 +0,0 @@ -import webvtt - -from .generic import GenericParserTestCase - - -class SRTParserTestCase(GenericParserTestCase): - - def test_srt_parse_empty_file(self): - self.assertRaises( - webvtt.errors.MalformedFileError, - webvtt.from_srt, - self._get_file('empty.vtt') # We reuse this file as it is empty and serves the purpose. - ) - - def test_srt_invalid_format(self): - for i in range(1, 5): - self.assertRaises( - webvtt.errors.MalformedFileError, - webvtt.from_srt, - self._get_file('invalid_format{}.srt'.format(i)) - ) - - def test_srt_total_length(self): - self.assertEqual( - webvtt.from_srt(self._get_file('sample.srt')).total_length, - 23 - ) - - def test_srt_parse_captions(self): - self.assertTrue(webvtt.from_srt(self._get_file('sample.srt')).captions) - - def test_srt_missing_timeframe_line(self): - self.assertRaises( - webvtt.errors.MalformedCaptionError, - webvtt.from_srt, - self._get_file('missing_timeframe.srt') - ) - - def test_srt_empty_caption_text(self): - self.assertTrue(webvtt.from_srt(self._get_file('missing_caption_text.srt')).captions) - - def test_srt_empty_gets_removed(self): - captions = webvtt.from_srt(self._get_file('missing_caption_text.srt')).captions - self.assertEqual(len(captions), 4) - - def test_srt_invalid_timestamp(self): - self.assertRaises( - webvtt.errors.MalformedCaptionError, - webvtt.from_srt, - self._get_file('invalid_timeframe.srt') - ) - - def test_srt_timestamps_format(self): - vtt = webvtt.from_srt(self._get_file('sample.srt')) - self.assertEqual(vtt.captions[2].start, '00:00:11.890') - self.assertEqual(vtt.captions[2].end, '00:00:16.320') - - def test_srt_parse_get_caption_data(self): - vtt = webvtt.from_srt(self._get_file('one_caption.srt')) - self.assertEqual(vtt.captions[0].start_in_seconds, 0.5) - self.assertEqual(vtt.captions[0].start, '00:00:00.500') - self.assertEqual(vtt.captions[0].end_in_seconds, 7) - self.assertEqual(vtt.captions[0].end, '00:00:07.000') - self.assertEqual(vtt.captions[0].lines[0], 'Caption text #1') - self.assertEqual(len(vtt.captions[0].lines), 1) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..23fc1a2 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,26 @@ +import unittest +import tempfile + +from webvtt.utils import FileWrapper, CODEC_BOMS + + +class TestUtils(unittest.TestCase): + + def test_open_file(self): + with tempfile.NamedTemporaryFile('w') as f: + f.write('Hello test!') + f.flush() + with FileWrapper.open(f.name) as fw: + self.assertEqual(fw.file.read(), 'Hello test!') + self.assertEqual(fw.file.encoding, 'utf-8') + + def test_open_file_with_bom(self): + for encoding, bom in CODEC_BOMS.items(): + with tempfile.NamedTemporaryFile('wb') as f: + f.write(bom) + f.write('Hello test'.encode(encoding)) + f.flush() + + with FileWrapper.open(f.name) as fw: + self.assertEqual(fw.file.read(), 'Hello test') + self.assertEqual(fw.file.encoding, encoding) diff --git a/tests/test_vtt.py b/tests/test_vtt.py new file mode 100644 index 0000000..214b71c --- /dev/null +++ b/tests/test_vtt.py @@ -0,0 +1,556 @@ +import unittest +import textwrap +import io + +from webvtt import vtt +from webvtt.errors import MalformedFileError +from webvtt.models import Caption, Style + + +class TestWebVTTCueBlock(unittest.TestCase): + + def test_is_valid(self): + self.assertTrue( + vtt.WebVTTCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500 --> 00:00:07.000 + Caption #1 + ''').strip().split('\n')) + ) + + self.assertTrue( + vtt.WebVTTCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500 --> 00:00:07.000 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n')) + ) + + self.assertTrue( + vtt.WebVTTCueBlock.is_valid(textwrap.dedent(''' + identifier + 00:00:00.500 --> 00:00:07.000 + Caption #1 + ''').strip().split('\n')) + ) + + self.assertTrue( + vtt.WebVTTCueBlock.is_valid(textwrap.dedent(''' + identifier + 00:00:00.500 --> 00:00:07.000 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500 00:00:07.000 + Caption #1 line 1 + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTCueBlock.is_valid(textwrap.dedent(''' + 00:00:00.500 --> 00:00:07.000 + ''').strip().split('\n')) + ) + + def test_from_lines(self): + cue_block = vtt.WebVTTCueBlock.from_lines(textwrap.dedent(''' + identifier + 00:00:00.500 --> 00:00:07.000 + Caption #1 line 1 + Caption #1 line 2 + ''').strip().split('\n') + ) + self.assertEqual(cue_block.identifier, 'identifier') + self.assertEqual(cue_block.start, '00:00:00.500') + self.assertEqual(cue_block.end, '00:00:07.000') + self.assertListEqual( + cue_block.payload, + ['Caption #1 line 1', 'Caption #1 line 2'] + ) + + +class TestWebVTTCommentBlock(unittest.TestCase): + + def test_is_valid(self): + self.assertTrue( + vtt.WebVTTCommentBlock.is_valid(textwrap.dedent(''' + NOTE This is a one line comment + ''').strip().split('\n')) + ) + + self.assertTrue( + vtt.WebVTTCommentBlock.is_valid(textwrap.dedent(''' + NOTE + This is a another one line comment + ''').strip().split('\n')) + ) + + self.assertTrue( + vtt.WebVTTCommentBlock.is_valid(textwrap.dedent(''' + NOTE + This is a multi-line comment + taking two lines + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTCommentBlock.is_valid(textwrap.dedent(''' + This is not a comment + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTCommentBlock.is_valid(textwrap.dedent(''' + # This is not a comment + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTCommentBlock.is_valid(textwrap.dedent(''' + // This is not a comment + ''').strip().split('\n')) + ) + + def test_from_lines(self): + comment = vtt.WebVTTCommentBlock.from_lines(textwrap.dedent(''' + NOTE This is a one line comment + ''').strip().split('\n') + ) + self.assertEqual(comment.text, 'This is a one line comment') + + comment = vtt.WebVTTCommentBlock.from_lines(textwrap.dedent(''' + NOTE + This is a multi-line comment + taking two lines + ''').strip().split('\n') + ) + self.assertEqual( + comment.text, + 'This is a multi-line comment\ntaking two lines' + ) + + +class TestWebVTTStyleBlock(unittest.TestCase): + + def test_is_valid(self): + self.assertTrue( + vtt.WebVTTStyleBlock.is_valid(textwrap.dedent(''' + STYLE + ::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + ''').strip().split('\n')) + ) + + self.assertTrue( + vtt.WebVTTStyleBlock.is_valid(textwrap.dedent(''' + STYLE + ::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + ::cue(b) { + color: peachpuff; + } + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTStyleBlock.is_valid(textwrap.dedent(''' + STYLE + ::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + + ::cue(b) { + color: peachpuff; + } + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTStyleBlock.is_valid(textwrap.dedent(''' + STYLE + ::cue--> { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + ''').strip().split('\n')) + ) + + self.assertFalse( + vtt.WebVTTStyleBlock.is_valid(textwrap.dedent(''' + ::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + ''').strip().split('\n')) + ) + + def test_from_lines(self): + style = vtt.WebVTTStyleBlock.from_lines(textwrap.dedent(''' + STYLE + ::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + ::cue(b) { + color: peachpuff; + } + ''').strip().split('\n') + ) + self.assertEqual( + style.text, + textwrap.dedent(''' + ::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + ::cue(b) { + color: peachpuff; + } + ''').strip() + ) + + +class TestVTTModule(unittest.TestCase): + + def test_parse_invalid_format(self): + self.assertRaises( + MalformedFileError, + vtt.parse, + textwrap.dedent(''' + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + ''').strip().split('\n') + ) + + def test_parse_captions(self): + output = vtt.parse( + textwrap.dedent(''' + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 line 1 + Caption text #2 line 2 + ''').strip().split('\n') + ) + captions = output.captions + styles = output.styles + self.assertEqual(len(captions), 2) + self.assertEqual(len(styles), 0) + self.assertIsInstance(captions[0], Caption) + self.assertIsInstance(captions[1], Caption) + self.assertEqual( + str(captions[0]), + '00:00:00.500 00:00:07.000 Caption text #1' + ) + self.assertEqual( + str(captions[1]), + r'00:00:07.000 00:00:11.890 Caption text #2 line 1\n' + 'Caption text #2 line 2' + ) + + def test_parse_styles(self): + output = vtt.parse( + textwrap.dedent(''' + WEBVTT + + STYLE + ::cue { + color: white; + } + + STYLE + ::cue(.important) { + color: red; + font-weight: bold; + } + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + ''').strip().split('\n') + ) + captions = output.captions + styles = output.styles + self.assertEqual(len(captions), 1) + self.assertEqual(len(styles), 2) + self.assertIsInstance(styles[0], Style) + self.assertIsInstance(styles[1], Style) + self.assertEqual( + str(styles[0].text), + textwrap.dedent(''' + ::cue { + color: white; + } + ''').strip() + ) + self.assertEqual( + str(styles[1].text), + textwrap.dedent(''' + ::cue(.important) { + color: red; + font-weight: bold; + } + ''').strip(), + ) + + def test_parse_content(self): + output = vtt.parse( + textwrap.dedent(''' + WEBVTT + + NOTE This is a testing sample + + NOTE We can see two header comments, a style + comment and finally a footer comments + + STYLE + ::cue { + background-image: linear-gradient(to bottom, dimgray); + color: papayawhip; + } + + NOTE the following style needs review + + STYLE + ::cue { + color: white; + } + + NOTE Comment for the first caption + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + NOTE + Comment for the second caption + that is very long + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 line 1 + Caption text #2 line 2 + + NOTE Copyright 2024 + + NOTE end of file + ''').strip().split('\n') + ) + captions = output.captions + styles = output.styles + self.assertEqual(len(captions), 2) + self.assertEqual(len(styles), 2) + self.assertIsInstance(captions[0], Caption) + self.assertIsInstance(captions[1], Caption) + self.assertIsInstance(styles[0], Style) + self.assertIsInstance(styles[1], Style) + self.assertEqual( + str(captions[0]), + '00:00:00.500 00:00:07.000 Caption text #1' + ) + self.assertEqual( + str(captions[1]), + r'00:00:07.000 00:00:11.890 Caption text #2 line 1\n' + 'Caption text #2 line 2' + ) + self.assertEqual( + str(styles[0].text), + textwrap.dedent(''' + ::cue { + background-image: linear-gradient(to bottom, dimgray); + color: papayawhip; + } + ''').strip() + ) + self.assertEqual( + str(styles[1].text), + textwrap.dedent(''' + ::cue { + color: white; + } + ''').strip() + ) + self.assertEqual( + styles[0].comments, + [] + ) + self.assertEqual( + styles[1].comments, + ['the following style needs review'] + ) + self.assertEqual( + captions[0].comments, + ['Comment for the first caption'] + ) + self.assertEqual( + captions[1].comments, + ['Comment for the second caption\nthat is very long'] + ) + self.assertListEqual(output.header_comments, + ['This is a testing sample', + 'We can see two header comments, a style\n' + 'comment and finally a footer comments' + ] + ) + self.assertListEqual(output.footer_comments, + ['Copyright 2024', + 'end of file' + ] + ) + + def test_write(self): + out = io.StringIO() + captions = [ + Caption(start='00:00:00.500', + end='00:00:07.000', + text='Caption #1' + ), + Caption(start='00:00:07.000', + end='00:00:11.890', + text=['Caption #2 line 1', + 'Caption #2 line 2' + ] + ) + ] + styles = [ + Style('::cue(b) {\n color: peachpuff;\n}'), + Style('::cue {\n color: papayawhip;\n}') + ] + captions[0].comments.append('Comment for the first caption') + captions[1].comments.append('Comment for the second caption') + styles[1].comments.append( + 'Comment for the second style\nwith two lines' + ) + header_comments = ['header comment', 'begin of the file'] + footer_comments = ['footer comment', 'end of file'] + + vtt.write( + out, + captions, + styles, + header_comments, + footer_comments + ) + + out.seek(0) + + self.assertEqual( + out.read(), + textwrap.dedent(''' + WEBVTT + + NOTE header comment + + NOTE begin of the file + + STYLE + ::cue(b) { + color: peachpuff; + } + + NOTE + Comment for the second style + with two lines + + STYLE + ::cue { + color: papayawhip; + } + + NOTE Comment for the first caption + + 00:00:00.500 --> 00:00:07.000 + Caption #1 + + NOTE Comment for the second caption + + 00:00:07.000 --> 00:00:11.890 + Caption #2 line 1 + Caption #2 line 2 + + NOTE footer comment + + NOTE end of file + ''').strip() + ) + + def test_to_str(self): + captions = [ + Caption(start='00:00:00.500', + end='00:00:07.000', + text='Caption #1' + ), + Caption(start='00:00:07.000', + end='00:00:11.890', + text=['Caption #2 line 1', + 'Caption #2 line 2' + ] + ) + ] + styles = [ + Style('::cue(b) {\n color: peachpuff;\n}'), + Style('::cue {\n color: papayawhip;\n}') + ] + captions[0].comments.append('Comment for the first caption') + captions[1].comments.append('Comment for the second caption') + styles[1].comments.append( + 'Comment for the second style\nwith two lines' + ) + header_comments = ['header comment', 'begin of the file'] + footer_comments = ['footer comment', 'end of file'] + + self.assertEqual( + vtt.to_str( + captions, + styles, + header_comments, + footer_comments + ), + textwrap.dedent(''' + WEBVTT + + NOTE header comment + + NOTE begin of the file + + STYLE + ::cue(b) { + color: peachpuff; + } + + NOTE + Comment for the second style + with two lines + + STYLE + ::cue { + color: papayawhip; + } + + NOTE Comment for the first caption + + 00:00:00.500 --> 00:00:07.000 + Caption #1 + + NOTE Comment for the second caption + + 00:00:07.000 --> 00:00:11.890 + Caption #2 line 1 + Caption #2 line 2 + + NOTE footer comment + + NOTE end of file + ''').strip() + ) diff --git a/tests/test_webvtt.py b/tests/test_webvtt.py index 9da3238..3111e2f 100644 --- a/tests/test_webvtt.py +++ b/tests/test_webvtt.py @@ -1,242 +1,263 @@ +import unittest import os import io import textwrap -from shutil import rmtree, copy +import warnings +import tempfile +import pathlib import webvtt -from webvtt.structures import Caption, Style -from .generic import GenericParserTestCase +from webvtt.models import Caption, Style +from webvtt.utils import CODEC_BOMS from webvtt.errors import MalformedFileError +PATH_TO_SAMPLES = pathlib.Path(__file__).resolve().parent / 'samples' -BASE_DIR = os.path.dirname(__file__) -OUTPUT_DIR = os.path.join(BASE_DIR, 'output') +class TestWebVTT(unittest.TestCase): -class WebVTTTestCase(GenericParserTestCase): + def test_from_string(self): + vtt = webvtt.WebVTT.from_string(textwrap.dedent(""" + WEBVTT - def tearDown(self): - if os.path.exists(OUTPUT_DIR): - rmtree(OUTPUT_DIR) + 00:00:00.500 --> 00:00:07.000 + Caption text #1 - def test_create_caption(self): - caption = Caption('00:00:00.500', '00:00:07.000', ['Caption test line 1', 'Caption test line 2']) - self.assertEqual(caption.start, '00:00:00.500') - self.assertEqual(caption.start_in_seconds, 0.5) - self.assertEqual(caption.end, '00:00:07.000') - self.assertEqual(caption.end_in_seconds, 7) - self.assertEqual(caption.lines, ['Caption test line 1', 'Caption test line 2']) + 00:00:07.000 --> 00:00:11.890 + Caption text #2 line 1 + Caption text #2 line 2 - def test_write_captions(self): - os.makedirs(OUTPUT_DIR) - copy(self._get_file('one_caption.vtt'), OUTPUT_DIR) + 00:00:11.890 --> 00:00:16.320 + Caption text #3 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + """).strip() + ) + self.assertEqual(len(vtt), 4) + self.assertEqual( + str(vtt[0]), + '00:00:00.500 00:00:07.000 Caption text #1' + ) + self.assertEqual( + str(vtt[1]), + r'00:00:07.000 00:00:11.890 Caption text #2 line 1\n' + 'Caption text #2 line 2' + ) + self.assertEqual( + str(vtt[2]), + '00:00:11.890 00:00:16.320 Caption text #3' + ) + self.assertEqual( + str(vtt[3]), + '00:00:16.320 00:00:21.580 Caption text #4' + ) + def test_write_captions(self): out = io.StringIO() - vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'one_caption.vtt')) - new_caption = Caption('00:00:07.000', '00:00:11.890', ['New caption text line1', 'New caption text line2']) + vtt = webvtt.read(PATH_TO_SAMPLES / 'one_caption.vtt') + new_caption = Caption(start='00:00:07.000', + end='00:00:11.890', + text=['New caption text line1', + 'New caption text line2' + ] + ) vtt.captions.append(new_caption) vtt.write(out) out.seek(0) - lines = [line.rstrip() for line in out.readlines()] - expected_lines = [ - 'WEBVTT', - '', - '00:00:00.500 --> 00:00:07.000', - 'Caption text #1', - '', - '00:00:07.000 --> 00:00:11.890', - 'New caption text line1', - 'New caption text line2' - ] + self.assertEqual( + out.read(), + textwrap.dedent(''' + WEBVTT - self.assertListEqual(lines, expected_lines) + 00:00:00.500 --> 00:00:07.000 + Caption text #1 - def test_save_captions(self): - os.makedirs(OUTPUT_DIR) - copy(self._get_file('one_caption.vtt'), OUTPUT_DIR) + 00:00:07.000 --> 00:00:11.890 + New caption text line1 + New caption text line2 + ''').strip() + ) - vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'one_caption.vtt')) - new_caption = Caption('00:00:07.000', '00:00:11.890', ['New caption text line1', 'New caption text line2']) + def test_write_captions_in_srt(self): + out = io.StringIO() + vtt = webvtt.read(PATH_TO_SAMPLES / 'one_caption.vtt') + new_caption = Caption(start='00:00:07.000', + end='00:00:11.890', + text=['New caption text line1', + 'New caption text line2' + ] + ) vtt.captions.append(new_caption) - vtt.save() - - with open(os.path.join(OUTPUT_DIR, 'one_caption.vtt'), 'r', encoding='utf-8') as f: - lines = [line.rstrip() for line in f.readlines()] + vtt.write(out, format='srt') - expected_lines = [ - 'WEBVTT', - '', - '00:00:00.500 --> 00:00:07.000', - 'Caption text #1', - '', - '00:00:07.000 --> 00:00:11.890', - 'New caption text line1', - 'New caption text line2' - ] + out.seek(0) + self.assertEqual( + out.read(), + textwrap.dedent(''' + 1 + 00:00:00,500 --> 00:00:07,000 + Caption text #1 + + 2 + 00:00:07,000 --> 00:00:11,890 + New caption text line1 + New caption text line2 + ''').strip() + ) + + def test_write_captions_in_unsupported_format(self): + self.assertRaises( + ValueError, + webvtt.WebVTT().write, + io.StringIO(), + format='ttt' + ) - self.assertListEqual(lines, expected_lines) + def test_save_captions(self): + with tempfile.NamedTemporaryFile('w', suffix='.vtt') as f: + f.write((PATH_TO_SAMPLES / 'one_caption.vtt').read_text()) + f.flush() + + vtt = webvtt.read(f.name) + new_caption = Caption(start='00:00:07.000', + end='00:00:11.890', + text=['New caption text line1', + 'New caption text line2' + ] + ) + vtt.captions.append(new_caption) + vtt.save() + f.flush() + + self.assertEqual( + pathlib.Path(f.name).read_text(), + textwrap.dedent(''' + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + New caption text line1 + New caption text line2 + ''').strip() + ) def test_srt_conversion(self): - os.makedirs(OUTPUT_DIR) - copy(self._get_file('one_caption.srt'), OUTPUT_DIR) - - vtt = webvtt.from_srt(os.path.join(OUTPUT_DIR, 'one_caption.srt')) - vtt.save() - - self.assertTrue(os.path.exists(os.path.join(OUTPUT_DIR, 'one_caption.vtt'))) - - with open(os.path.join(OUTPUT_DIR, 'one_caption.vtt'), 'r', encoding='utf-8') as f: - lines = [line.rstrip() for line in f.readlines()] - - expected_lines = [ - 'WEBVTT', - '', - '00:00:00.500 --> 00:00:07.000', - 'Caption text #1', - ] - - self.assertListEqual(lines, expected_lines) + with tempfile.TemporaryDirectory() as td: + with open(pathlib.Path(td) / 'one_caption.srt', 'w') as f: + f.write((PATH_TO_SAMPLES / 'one_caption.srt').read_text()) + + webvtt.from_srt( + pathlib.Path(td) / 'one_caption.srt' + ).save() + + self.assertTrue( + os.path.exists(pathlib.Path(td) / 'one_caption.vtt') + ) + self.assertEqual( + (pathlib.Path(td) / 'one_caption.vtt').read_text(), + textwrap.dedent(''' + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + ''').strip() + ) def test_sbv_conversion(self): - os.makedirs(OUTPUT_DIR) - copy(self._get_file('two_captions.sbv'), OUTPUT_DIR) - - vtt = webvtt.from_sbv(os.path.join(OUTPUT_DIR, 'two_captions.sbv')) - vtt.save() - - self.assertTrue(os.path.exists(os.path.join(OUTPUT_DIR, 'two_captions.vtt'))) - - with open(os.path.join(OUTPUT_DIR, 'two_captions.vtt'), 'r', encoding='utf-8') as f: - lines = [line.rstrip() for line in f.readlines()] - - expected_lines = [ - 'WEBVTT', - '', - '00:00:00.378 --> 00:00:11.378', - 'Caption text #1', - '', - '00:00:11.378 --> 00:00:12.305', - 'Caption text #2 (line 1)', - 'Caption text #2 (line 2)', - ] - - self.assertListEqual(lines, expected_lines) + with tempfile.TemporaryDirectory() as td: + with open(pathlib.Path(td) / 'two_captions.sbv', 'w') as f: + f.write( + (PATH_TO_SAMPLES / 'two_captions.sbv').read_text() + ) + + webvtt.from_sbv( + pathlib.Path(td) / 'two_captions.sbv' + ).save() + + self.assertTrue( + os.path.exists(pathlib.Path(td) / 'two_captions.vtt') + ) + self.assertEqual( + (pathlib.Path(td) / 'two_captions.vtt').read_text(), + textwrap.dedent(''' + WEBVTT + + 00:00:00.378 --> 00:00:11.378 + Caption text #1 + + 00:00:11.378 --> 00:00:12.305 + Caption text #2 (line 1) + Caption text #2 (line 2) + ''').strip() + ) def test_save_to_other_location(self): - target_path = os.path.join(OUTPUT_DIR, 'test_folder') - os.makedirs(target_path) + with tempfile.TemporaryDirectory() as td: + webvtt.read( + PATH_TO_SAMPLES / 'one_caption.vtt' + ).save(td) - webvtt.read(self._get_file('one_caption.vtt')).save(target_path) - self.assertTrue(os.path.exists(os.path.join(target_path, 'one_caption.vtt'))) + self.assertTrue( + os.path.exists(pathlib.Path(td) / 'one_caption.vtt') + ) def test_save_specific_filename(self): - target_path = os.path.join(OUTPUT_DIR, 'test_folder') - os.makedirs(target_path) - output_file = os.path.join(target_path, 'custom_name.vtt') + with tempfile.TemporaryDirectory() as td: + webvtt.read( + PATH_TO_SAMPLES / 'one_caption.vtt' + ).save( + pathlib.Path(td) / 'one_caption_new.vtt' + ) - webvtt.read(self._get_file('one_caption.vtt')).save(output_file) - self.assertTrue(os.path.exists(output_file)) + self.assertTrue( + os.path.exists(pathlib.Path(td) / 'one_caption_new.vtt') + ) def test_save_specific_filename_no_extension(self): - target_path = os.path.join(OUTPUT_DIR, 'test_folder') - os.makedirs(target_path) - output_file = os.path.join(target_path, 'custom_name') - - webvtt.read(self._get_file('one_caption.vtt')).save(output_file) - self.assertTrue(os.path.exists(os.path.join(target_path, 'custom_name.vtt'))) - - def test_caption_timestamp_update(self): - c = Caption('00:00:00.500', '00:00:07.000') - c.start = '00:00:01.750' - c.end = '00:00:08.250' - - self.assertEqual(c.start, '00:00:01.750') - self.assertEqual(c.end, '00:00:08.250') - - def test_caption_timestamp_format(self): - c = Caption('01:02:03.400', '02:03:04.500') - self.assertEqual(c.start, '01:02:03.400') - self.assertEqual(c.end, '02:03:04.500') - - c = Caption('02:03.400', '03:04.500') - self.assertEqual(c.start, '00:02:03.400') - self.assertEqual(c.end, '00:03:04.500') - - def test_caption_text(self): - c = Caption(text=['Caption line #1', 'Caption line #2']) - self.assertEqual( - c.text, - 'Caption line #1\nCaption line #2' - ) - - def test_caption_receive_text(self): - c = Caption(text='Caption line #1\nCaption line #2') - - self.assertEqual( - len(c.lines), - 2 - ) - self.assertEqual( - c.text, - 'Caption line #1\nCaption line #2' - ) - - def test_update_text(self): - c = Caption(text='Caption line #1') - c.text = 'Caption line #1 updated' - self.assertEqual( - c.text, - 'Caption line #1 updated' - ) - - def test_update_text_multiline(self): - c = Caption(text='Caption line #1') - c.text = 'Caption line #1\nCaption line #2' - - self.assertEqual( - len(c.lines), - 2 - ) - - self.assertEqual( - c.text, - 'Caption line #1\nCaption line #2' - ) - - def test_update_text_wrong_type(self): - c = Caption(text='Caption line #1') - - self.assertRaises( - AttributeError, - setattr, - c, - 'text', - 123 - ) - - def test_manipulate_lines(self): - c = Caption(text=['Caption line #1', 'Caption line #2']) - c.lines[0] = 'Caption line #1 updated' - self.assertEqual( - c.lines[0], - 'Caption line #1 updated' - ) - - def test_read_file_buffer(self): - with open(self._get_file('sample.vtt'), 'r', encoding='utf-8') as f: - vtt = webvtt.read_buffer(f) - self.assertIsInstance(vtt.captions, list) + with tempfile.TemporaryDirectory() as td: + webvtt.read( + PATH_TO_SAMPLES / 'one_caption.vtt' + ).save( + pathlib.Path(td) / 'one_caption_new' + ) + + self.assertTrue( + os.path.exists(pathlib.Path(td) / 'one_caption_new.vtt') + ) + + def test_from_buffer(self): + with open(PATH_TO_SAMPLES / 'sample.vtt', 'r', encoding='utf-8') as f: + self.assertIsInstance( + webvtt.from_buffer(f).captions, + list + ) + + def test_deprecated_read_buffer(self): + with open(PATH_TO_SAMPLES / 'sample.vtt', 'r', encoding='utf-8') as f: + with warnings.catch_warnings(record=True) as warns: + warnings.simplefilter('always') + vtt = webvtt.read_buffer(f) + + self.assertIsInstance(vtt.captions, list) + self.assertEqual( + 'Deprecated: use from_buffer instead.', + str(warns[-1].message) + ) def test_read_memory_buffer(self): - payload = '' - with open(self._get_file('sample.vtt'), 'r', encoding='utf-8') as f: - payload = f.read() + buffer = io.StringIO( + (PATH_TO_SAMPLES / 'sample.vtt').read_text() + ) - buffer = io.StringIO(payload) - vtt = webvtt.read_buffer(buffer) - self.assertIsInstance(vtt.captions, list) + self.assertIsInstance( + webvtt.from_buffer(buffer).captions, + list + ) def test_read_memory_buffer_carriage_return(self): """https://github.com/glut23/webvtt-py/issues/29""" @@ -251,167 +272,917 @@ def test_read_memory_buffer_carriage_return(self): \r 00:00:11.890 --> 00:00:16.320\r Caption text #3\r - ''')) - vtt = webvtt.read_buffer(buffer) - self.assertEqual(len(vtt.captions), 3) + ''') + ) + + self.assertEqual( + len(webvtt.from_buffer(buffer).captions), + 3 + ) def test_read_malformed_buffer(self): - malformed_payloads = ['', 'MOCK MELFORMED CONTENT'] + malformed_payloads = ['', 'MOCK MALFORMED CONTENT'] for payload in malformed_payloads: buffer = io.StringIO(payload) with self.assertRaises(MalformedFileError): - webvtt.read_buffer(buffer) - + webvtt.from_buffer(buffer) def test_captions(self): - vtt = webvtt.read(self._get_file('sample.vtt')) - self.assertIsInstance(vtt.captions, list) - - def test_captions_prevent_write(self): - vtt = webvtt.read(self._get_file('sample.vtt')) - self.assertRaises( - AttributeError, - setattr, - vtt, - 'captions', - [] - ) + captions = webvtt.read(PATH_TO_SAMPLES / 'sample.vtt').captions + self.assertIsInstance( + captions, + list + ) + self.assertEqual(len(captions), 16) def test_sequence_iteration(self): - vtt = webvtt.read(self._get_file('sample.vtt')) + vtt = webvtt.read(PATH_TO_SAMPLES / 'sample.vtt') self.assertIsInstance(vtt[0], Caption) self.assertEqual(len(vtt), len(vtt.captions)) def test_save_no_filename(self): - vtt = webvtt.WebVTT() self.assertRaises( webvtt.errors.MissingFilenameError, - vtt.save - ) + webvtt.WebVTT().save + ) - def test_malformed_start_timestamp(self): - self.assertRaises( - webvtt.errors.MalformedCaptionError, - Caption, - '01:00' - ) + def test_save_with_path_to_dir_no_filename(self): + with tempfile.TemporaryDirectory() as td: + self.assertRaises( + webvtt.errors.MissingFilenameError, + webvtt.WebVTT().save, + td + ) def test_set_styles_from_text(self): - style = Style() - style.text = '::cue(b) {\n color: peachpuff;\n}' + style = Style('::cue(b) {\n color: peachpuff;\n}') self.assertListEqual( style.lines, ['::cue(b) {', ' color: peachpuff;', '}'] - ) - - def test_get_styles_as_text(self): - style = Style() - style.lines = ['::cue(b) {', ' color: peachpuff;', '}'] - self.assertEqual( - style.text, - '::cue(b) {color: peachpuff;}' - ) + ) def test_save_identifiers(self): - os.makedirs(OUTPUT_DIR) - copy(self._get_file('using_identifiers.vtt'), OUTPUT_DIR) - - vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'using_identifiers.vtt')) - vtt.save(os.path.join(OUTPUT_DIR, 'new_using_identifiers.vtt')) - - with open(os.path.join(OUTPUT_DIR, 'new_using_identifiers.vtt'), 'r', encoding='utf-8') as f: - lines = [line.rstrip() for line in f.readlines()] - - expected_lines = [ - 'WEBVTT', - '', - '00:00:00.500 --> 00:00:07.000', - 'Caption text #1', - '', - 'second caption', - '00:00:07.000 --> 00:00:11.890', - 'Caption text #2', - '', - '00:00:11.890 --> 00:00:16.320', - 'Caption text #3', - '', - '4', - '00:00:16.320 --> 00:00:21.580', - 'Caption text #4', - '', - '00:00:21.580 --> 00:00:23.880', - 'Caption text #5', - '', - '00:00:23.880 --> 00:00:27.280', - 'Caption text #6' - ] - - self.assertListEqual(lines, expected_lines) + with tempfile.NamedTemporaryFile('w', suffix='.vtt') as f: + webvtt.read( + PATH_TO_SAMPLES / 'using_identifiers.vtt' + ).save( + f.name + ) + + self.assertListEqual( + pathlib.Path(f.name).read_text().splitlines(), + [ + 'WEBVTT', + '', + '00:00:00.500 --> 00:00:07.000', + 'Caption text #1', + '', + 'second caption', + '00:00:07.000 --> 00:00:11.890', + 'Caption text #2', + '', + '00:00:11.890 --> 00:00:16.320', + 'Caption text #3', + '', + '4', + '00:00:16.320 --> 00:00:21.580', + 'Caption text #4', + '', + '00:00:21.580 --> 00:00:23.880', + 'Caption text #5', + '', + '00:00:23.880 --> 00:00:27.280', + 'Caption text #6' + ] + ) def test_save_updated_identifiers(self): - os.makedirs(OUTPUT_DIR) - copy(self._get_file('using_identifiers.vtt'), OUTPUT_DIR) - - vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'using_identifiers.vtt')) + vtt = webvtt.read(PATH_TO_SAMPLES / 'using_identifiers.vtt') vtt.captions[0].identifier = 'first caption' vtt.captions[1].identifier = None vtt.captions[3].identifier = '44' - last_caption = Caption('00:00:27.280', '00:00:29.200', 'Caption text #7') + last_caption = Caption(start='00:00:27.280', + end='00:00:29.200', + text='Caption text #7' + ) last_caption.identifier = 'last caption' vtt.captions.append(last_caption) - vtt.save(os.path.join(OUTPUT_DIR, 'new_using_identifiers.vtt')) - - with open(os.path.join(OUTPUT_DIR, 'new_using_identifiers.vtt'), 'r', encoding='utf-8') as f: - lines = [line.rstrip() for line in f.readlines()] - - expected_lines = [ - 'WEBVTT', - '', - 'first caption', - '00:00:00.500 --> 00:00:07.000', - 'Caption text #1', - '', - '00:00:07.000 --> 00:00:11.890', - 'Caption text #2', - '', - '00:00:11.890 --> 00:00:16.320', - 'Caption text #3', - '', - '44', - '00:00:16.320 --> 00:00:21.580', - 'Caption text #4', - '', - '00:00:21.580 --> 00:00:23.880', - 'Caption text #5', - '', - '00:00:23.880 --> 00:00:27.280', - 'Caption text #6', - '', - 'last caption', - '00:00:27.280 --> 00:00:29.200', - 'Caption text #7' - ] - - self.assertListEqual(lines, expected_lines) + + with tempfile.NamedTemporaryFile('w', suffix='.vtt') as f: + vtt.save(f.name) + + self.assertListEqual( + pathlib.Path(f.name).read_text().splitlines(), + [ + 'WEBVTT', + '', + 'first caption', + '00:00:00.500 --> 00:00:07.000', + 'Caption text #1', + '', + '00:00:07.000 --> 00:00:11.890', + 'Caption text #2', + '', + '00:00:11.890 --> 00:00:16.320', + 'Caption text #3', + '', + '44', + '00:00:16.320 --> 00:00:21.580', + 'Caption text #4', + '', + '00:00:21.580 --> 00:00:23.880', + 'Caption text #5', + '', + '00:00:23.880 --> 00:00:27.280', + 'Caption text #6', + '', + 'last caption', + '00:00:27.280 --> 00:00:29.200', + 'Caption text #7' + ] + ) def test_content_formatting(self): """ Verify that content property returns the correctly formatted webvtt. """ captions = [ - Caption('00:00:00.500', '00:00:07.000', ['Caption test line 1', 'Caption test line 2']), - Caption('00:00:08.000', '00:00:15.000', ['Caption test line 3', 'Caption test line 4']), - ] - expected_content = textwrap.dedent("""\ - WEBVTT + Caption(start='00:00:00.500', + end='00:00:07.000', + text=['Caption test line 1', 'Caption test line 2'] + ), + Caption(start='00:00:08.000', + end='00:00:15.000', + text=['Caption test line 3', 'Caption test line 4'] + ), + ] - 00:00:00.500 --> 00:00:07.000 - Caption test line 1 - Caption test line 2 - - 00:00:08.000 --> 00:00:15.000 - Caption test line 3 - Caption test line 4 - """).strip() - vtt = webvtt.WebVTT(captions=captions) - self.assertEqual(expected_content, vtt.content) + self.assertEqual( + webvtt.WebVTT(captions=captions).content, + textwrap.dedent(""" + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption test line 1 + Caption test line 2 + + 00:00:08.000 --> 00:00:15.000 + Caption test line 3 + Caption test line 4 + """).strip() + ) + + def test_repr(self): + test_file = PATH_TO_SAMPLES / 'sample.vtt' + self.assertEqual( + repr(webvtt.read(test_file)), + f"" + ) + + def test_str(self): + self.assertEqual( + str(webvtt.read(PATH_TO_SAMPLES / 'sample.vtt')), + textwrap.dedent(""" + 00:00:00.500 00:00:07.000 Caption text #1 + 00:00:07.000 00:00:11.890 Caption text #2 + 00:00:11.890 00:00:16.320 Caption text #3 + 00:00:16.320 00:00:21.580 Caption text #4 + 00:00:21.580 00:00:23.880 Caption text #5 + 00:00:23.880 00:00:27.280 Caption text #6 + 00:00:27.280 00:00:30.280 Caption text #7 + 00:00:30.280 00:00:36.510 Caption text #8 + 00:00:36.510 00:00:38.870 Caption text #9 + 00:00:38.870 00:00:45.000 Caption text #10 + 00:00:45.000 00:00:47.000 Caption text #11 + 00:00:47.000 00:00:50.970 Caption text #12 + 00:00:50.970 00:00:54.440 Caption text #13 + 00:00:54.440 00:00:58.600 Caption text #14 + 00:00:58.600 00:01:01.350 Caption text #15 + 00:01:01.350 00:01:04.300 Caption text #16 + """).strip() + ) + + def test_parse_invalid_file(self): + self.assertRaises( + MalformedFileError, + webvtt.read, + PATH_TO_SAMPLES / 'invalid.vtt' + ) + + def test_file_not_found(self): + self.assertRaises( + FileNotFoundError, + webvtt.read, + 'nowhere' + ) + + def test_total_length(self): + self.assertEqual( + webvtt.read(PATH_TO_SAMPLES / 'sample.vtt').total_length, + 64 + ) + + def test_total_length_no_captions(self): + self.assertEqual( + webvtt.WebVTT().total_length, + 0 + ) + + def test_parse_empty_file(self): + self.assertRaises( + MalformedFileError, + webvtt.read, + PATH_TO_SAMPLES / 'empty.vtt' + ) + + def test_parse_invalid_timeframe_line(self): + good_captions = len( + webvtt.read(PATH_TO_SAMPLES / 'invalid_timeframe.vtt').captions + ) + self.assertEqual(good_captions, 6) + + def test_parse_invalid_timeframe_in_cue_text(self): + vtt = webvtt.read( + PATH_TO_SAMPLES / 'invalid_timeframe_in_cue_text.vtt' + ) + self.assertEqual(2, len(vtt.captions)) + self.assertEqual('Caption text #3', vtt.captions[1].text) + + def test_parse_get_caption_data(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'one_caption.vtt') + self.assertEqual(vtt.captions[0].start_in_seconds, 0) + self.assertEqual(vtt.captions[0].start, '00:00:00.500') + self.assertEqual(vtt.captions[0].end_in_seconds, 7) + self.assertEqual(vtt.captions[0].end, '00:00:07.000') + self.assertEqual(vtt.captions[0].lines[0], 'Caption text #1') + self.assertEqual(len(vtt.captions[0].lines), 1) + + def test_caption_without_timeframe(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'missing_timeframe.vtt') + self.assertEqual(len(vtt.captions), 6) + + def test_caption_without_cue_text(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'missing_caption_text.vtt') + self.assertEqual(len(vtt.captions), 4) + + def test_timestamps_format(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'sample.vtt') + self.assertEqual(vtt.captions[2].start, '00:00:11.890') + self.assertEqual(vtt.captions[2].end, '00:00:16.320') + + def test_parse_timestamp(self): + self.assertEqual( + Caption(start='02:03:11.890').start_in_seconds, + 7391 + ) + + def test_captions_attribute(self): + self.assertListEqual(webvtt.WebVTT().captions, []) + + def test_metadata_headers(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'metadata_headers.vtt') + self.assertEqual(len(vtt.captions), 2) + + def test_metadata_headers_multiline(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'metadata_headers_multiline.vtt') + self.assertEqual(len(vtt.captions), 2) + + def test_parse_identifiers(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'using_identifiers.vtt') + self.assertEqual(len(vtt.captions), 6) + + self.assertEqual(vtt.captions[1].identifier, 'second caption') + self.assertEqual(vtt.captions[2].identifier, None) + self.assertEqual(vtt.captions[3].identifier, '4') + + def test_parse_comments(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'comments.vtt') + self.assertEqual(len(vtt.captions), 3) + self.assertListEqual( + vtt.captions[0].lines, + ['- Ta en kopp varmt te.', + '- Det är inte varmt.'] + ) + self.assertListEqual( + vtt.captions[0].comments, + [] + ) + self.assertListEqual( + vtt.captions[1].comments, + [] + ) + self.assertEqual( + vtt.captions[2].text, + '- Ta en kopp' + ) + self.assertListEqual( + vtt.captions[2].comments, + ['This last line may not translate well.'] + ) + self.assertListEqual( + vtt.header_comments, + ['This translation was done by Kyle so that\n' + 'some friends can watch it with their parents.' + ] + ) + self.assertListEqual( + vtt.footer_comments, + ['end of file'] + ) + + def test_parse_styles(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'styles.vtt') + self.assertEqual(len(vtt.captions), 1) + self.assertEqual( + vtt.styles[0].text, + '::cue {\n background-image: linear-gradient(to bottom, ' + 'dimgray, lightgray);\n color: papayawhip;\n}' + ) + + def test_parse_styles_with_comments(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'styles_with_comments.vtt') + self.assertEqual(len(vtt.captions), 1) + self.assertEqual(len(vtt.styles), 2) + self.assertEqual( + vtt.styles[0].comments, + [] + ) + self.assertEqual( + vtt.styles[0].text, + '::cue {\n' + ' background-image: linear-gradient(to bottom, dimgray, ' + 'lightgray);\n' + ' color: papayawhip;\n' + '}' + ) + self.assertEqual( + vtt.styles[1].comments, + ['This is the second block of styles', + 'Multiline comment for the same\nsecond block of styles' + ] + ) + self.assertEqual( + vtt.styles[1].text, + '::cue(b) {\n' + ' color: peachpuff;\n' + '}' + ) + self.assertListEqual( + vtt.header_comments, + ['Sample of comments with styles'] + ) + self.assertListEqual( + vtt.footer_comments, + [] + ) + + def test_multiple_comments_everywhere(self): + vtt = webvtt.WebVTT.from_string(textwrap.dedent(""" + WEBVTT + + NOTE Test file + + NOTE this file is testing multiple comments + in different places + + STYLE + ::cue { + background-image: linear-gradient(to bottom, dimgray, lightgray); + color: papayawhip; + } + + NOTE this style uses nice color + + NOTE check how it looks before proceeding + + STYLE + ::cue(b) { + color: peachpuff; + } + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + NOTE next caption has two lines + + NOTE needs review + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 line 1 + Caption text #2 line 2 + + 00:00:11.890 --> 00:00:16.320 + Caption text #3 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + + NOTE Copyright 2024 + + NOTE this is the end of the file + """).strip() + ) + + self.assertListEqual( + vtt.header_comments, + ['Test file', + 'this file is testing multiple comments\nin different places' + ] + ) + self.assertListEqual( + vtt.footer_comments, + ['Copyright 2024', + 'this is the end of the file' + ] + ) + self.assertListEqual( + vtt.captions[0].comments, + [] + ) + self.assertListEqual( + vtt.captions[1].comments, + ['next caption has two lines', + 'needs review' + ] + ) + self.assertListEqual( + vtt.captions[2].comments, + [] + ) + self.assertListEqual( + vtt.captions[3].comments, + [] + ) + self.assertListEqual( + vtt.styles[0].comments, + [] + ) + self.assertListEqual( + vtt.styles[1].comments, + ['this style uses nice color', + 'check how it looks before proceeding' + ] + ) + + def test_comments_in_new_file(self): + with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f: + vtt = webvtt.WebVTT() + vtt.header_comments.append('This is a header comment') + vtt.header_comments.append( + 'where we can see a\ntwo line comment' + ) + vtt.styles.append( + Style('::cue(b) {\n color: peachpuff;\n}') + ) + style = Style('::cue {\n color: papayawhip;\n}') + style.comments.append('Another style to test\nthe look and feel') + style.comments.append('Please check') + vtt.styles.append(style) + vtt.captions.append( + Caption(start='00:00:00.500', + end='00:00:07.000', + text='Caption #1', + ) + ) + caption = Caption(start='00:00:07.000', + end='00:00:11.890', + text='Caption #2' + ) + caption.comments.append( + 'Second caption may be a bit off\nand needs checking' + ) + caption.comments.append('Confirm if it displays correctly') + vtt.captions.append(caption) + vtt.footer_comments.append('This is a footer comment') + vtt.footer_comments.append( + 'where we can also see a\ntwo line comment' + ) + + vtt.save(f.name) + self.assertEqual( + f.read(), + textwrap.dedent(''' + WEBVTT + + NOTE This is a header comment + + NOTE + where we can see a + two line comment + + STYLE + ::cue(b) { + color: peachpuff; + } + + NOTE + Another style to test + the look and feel + + NOTE Please check + + STYLE + ::cue { + color: papayawhip; + } + + 00:00:00.500 --> 00:00:07.000 + Caption #1 + + NOTE + Second caption may be a bit off + and needs checking + + NOTE Confirm if it displays correctly + + 00:00:07.000 --> 00:00:11.890 + Caption #2 + + NOTE This is a footer comment + + NOTE + where we can also see a + two line comment + ''').strip() + ) + + def test_clean_cue_tags(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'cue_tags.vtt') + self.assertEqual( + vtt.captions[1].text, + 'Like a big-a pizza pie' + ) + self.assertEqual( + vtt.captions[2].text, + 'That\'s amore' + ) + + def test_parse_captions_with_bom(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'captions_with_bom.vtt') + self.assertEqual(len(vtt.captions), 4) + + def test_empty_lines_are_not_included_in_result(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'netflix_chicas_del_cable.vtt') + self.assertEqual(vtt.captions[0].text, "[Alba] En 1928,") + self.assertEqual( + vtt.captions[-2].text, + "Diez años no son suficientes\npara olvidarte..." + ) + + def test_can_parse_youtube_dl_files(self): + vtt = webvtt.read(PATH_TO_SAMPLES / 'youtube_dl.vtt') + self.assertEqual( + "this will happen is I'm telling\n ", + vtt.captions[2].text + ) + + +class TestParseSRT(unittest.TestCase): + + def test_parse_empty_file(self): + self.assertRaises( + webvtt.errors.MalformedFileError, + webvtt.from_srt, + # We reuse this file as it is empty and serves the purpose. + PATH_TO_SAMPLES / 'empty.vtt' + ) + + def test_invalid_format(self): + for i in range(1, 5): + self.assertRaises( + MalformedFileError, + webvtt.from_srt, + PATH_TO_SAMPLES / f'invalid_format{i}.srt' + ) + + def test_total_length(self): + self.assertEqual( + webvtt.from_srt(PATH_TO_SAMPLES / 'sample.srt').total_length, + 23 + ) + + def test_parse_captions(self): + self.assertTrue( + webvtt.from_srt(PATH_TO_SAMPLES / 'sample.srt').captions + ) + + def test_missing_timeframe_line(self): + self.assertEqual( + len(webvtt.from_srt( + PATH_TO_SAMPLES / 'missing_timeframe.srt').captions + ), + 4 + ) + + def test_empty_caption_text(self): + self.assertTrue( + webvtt.from_srt( + PATH_TO_SAMPLES / 'missing_caption_text.srt').captions + ) + + def test_empty_gets_removed(self): + captions = webvtt.from_srt( + PATH_TO_SAMPLES / 'missing_caption_text.srt' + ).captions + self.assertEqual(len(captions), 4) + + def test_invalid_timestamp(self): + self.assertEqual( + len(webvtt.from_srt( + PATH_TO_SAMPLES / 'invalid_timeframe.srt' + ).captions), + 4 + ) + + def test_timestamps_format(self): + vtt = webvtt.from_srt(PATH_TO_SAMPLES / 'sample.srt') + self.assertEqual(vtt.captions[2].start, '00:00:11.890') + self.assertEqual(vtt.captions[2].end, '00:00:16.320') + + def test_parse_get_caption_data(self): + vtt = webvtt.from_srt(PATH_TO_SAMPLES / 'one_caption.srt') + self.assertEqual(vtt.captions[0].start_in_seconds, 0) + self.assertEqual(vtt.captions[0].start, '00:00:00.500') + self.assertEqual(vtt.captions[0].end_in_seconds, 7) + self.assertEqual(vtt.captions[0].end, '00:00:07.000') + self.assertEqual(vtt.captions[0].lines[0], 'Caption text #1') + self.assertEqual(len(vtt.captions[0].lines), 1) + + +class TestParseSBV(unittest.TestCase): + + def test_parse_empty_file(self): + self.assertRaises( + MalformedFileError, + webvtt.from_sbv, + # We reuse this file as it is empty and serves the purpose. + PATH_TO_SAMPLES / 'empty.vtt' + ) + + def test_invalid_format(self): + self.assertRaises( + MalformedFileError, + webvtt.from_sbv, + PATH_TO_SAMPLES / 'invalid_format.sbv' + ) + + def test_total_length(self): + self.assertEqual( + webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv').total_length, + 16 + ) + + def test_parse_captions(self): + self.assertEqual( + len(webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv').captions), + 5 + ) + + def test_missing_timeframe_line(self): + self.assertEqual( + len(webvtt.from_sbv( + PATH_TO_SAMPLES / 'missing_timeframe.sbv' + ).captions), + 4 + ) + + def test_missing_caption_text(self): + self.assertTrue( + webvtt.from_sbv( + PATH_TO_SAMPLES / 'missing_caption_text.sbv' + ).captions + ) + + def test_invalid_timestamp(self): + self.assertEqual( + len(webvtt.from_sbv( + PATH_TO_SAMPLES / 'invalid_timeframe.sbv' + ).captions), + 4 + ) + + def test_timestamps_format(self): + vtt = webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv') + self.assertEqual(vtt.captions[1].start, '00:00:11.378') + self.assertEqual(vtt.captions[1].end, '00:00:12.305') + + def test_timestamps_in_seconds(self): + vtt = webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv') + self.assertEqual(vtt.captions[1].start_in_seconds, 11) + self.assertEqual(vtt.captions[1].end_in_seconds, 12) + + def test_get_caption_text(self): + vtt = webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv') + self.assertEqual(vtt.captions[1].text, 'Caption text #2') + + def test_get_caption_text_multiline(self): + vtt = webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv') + self.assertEqual( + vtt.captions[2].text, + 'Caption text #3 (line 1)\nCaption text #3 (line 2)' + ) + self.assertListEqual( + vtt.captions[2].lines, + ['Caption text #3 (line 1)', 'Caption text #3 (line 2)'] + ) + + def test_convert_from_srt_to_vtt_and_back_gives_same_file(self): + with tempfile.NamedTemporaryFile('w', suffix='.srt') as f: + webvtt.from_srt( + PATH_TO_SAMPLES / 'sample.srt' + ).save_as_srt(f.name) + + self.assertEqual( + pathlib.Path(PATH_TO_SAMPLES / 'sample.srt').read_text(), + pathlib.Path(f.name).read_text() + ) + + def test_save_file_with_bom(self): + with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f: + webvtt.read( + PATH_TO_SAMPLES / 'one_caption.vtt' + ).save(f.name, add_bom=True) + self.assertEqual( + f.read(), + textwrap.dedent(f''' + {CODEC_BOMS['utf-8'].decode()}WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + ''').strip() + ) + + def test_save_file_with_bom_keeps_bom(self): + with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f: + webvtt.read( + PATH_TO_SAMPLES / 'captions_with_bom.vtt' + ).save(f.name) + self.assertEqual( + f.read(), + textwrap.dedent(f''' + {CODEC_BOMS['utf-8'].decode()}WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + + 00:00:11.890 --> 00:00:16.320 + Caption text #3 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + ''').strip() + ) + + def test_save_file_with_bom_removes_bom_if_requested(self): + with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f: + webvtt.read( + PATH_TO_SAMPLES / 'captions_with_bom.vtt' + ).save(f.name, add_bom=False) + self.assertEqual( + f.read(), + textwrap.dedent(f''' + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + + 00:00:07.000 --> 00:00:11.890 + Caption text #2 + + 00:00:11.890 --> 00:00:16.320 + Caption text #3 + + 00:00:16.320 --> 00:00:21.580 + Caption text #4 + ''').strip() + ) + + def test_save_file_with_encoding(self): + with tempfile.NamedTemporaryFile('rb', suffix='.vtt') as f: + webvtt.read( + PATH_TO_SAMPLES / 'one_caption.vtt' + ).save(f.name, + encoding='utf-32-le' + ) + self.assertEqual( + f.read().decode('utf-32-le'), + textwrap.dedent(''' + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + ''').strip() + ) + + def test_save_file_with_encoding_and_bom(self): + with tempfile.NamedTemporaryFile('rb', suffix='.vtt') as f: + webvtt.read( + PATH_TO_SAMPLES / 'one_caption.vtt' + ).save(f.name, + encoding='utf-32-le', + add_bom=True + ) + self.assertEqual( + f.read().decode('utf-32-le'), + textwrap.dedent(f''' + {CODEC_BOMS['utf-32-le'].decode('utf-32-le')}WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + ''').strip() + ) + + def test_save_new_file_utf_8_default_encoding_no_bom(self): + with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f: + vtt = webvtt.WebVTT() + vtt.captions.append( + Caption(start='00:00:00.500', + end='00:00:07.000', + text='Caption text #1' + ) + ) + vtt.save(f.name) + self.assertEqual(vtt.encoding, 'utf-8') + self.assertEqual( + f.read(), + textwrap.dedent(f''' + WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + ''').strip() + ) + + def test_save_new_file_utf_8_default_encoding_with_bom(self): + with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f: + vtt = webvtt.WebVTT() + vtt.captions.append( + Caption(start='00:00:00.500', + end='00:00:07.000', + text='Caption text #1' + ) + ) + vtt.save(f.name, + add_bom=True + ) + self.assertEqual(vtt.encoding, 'utf-8') + self.assertEqual( + f.read(), + textwrap.dedent(f''' + {CODEC_BOMS['utf-8'].decode()}WEBVTT + + 00:00:00.500 --> 00:00:07.000 + Caption text #1 + ''').strip() + ) + + def test_iter_slice(self): + vtt = webvtt.read( + PATH_TO_SAMPLES / 'sample.vtt' + ) + slice_of_captions = vtt.iter_slice(start='00:00:11.000', + end='00:00:27.000' + ) + for expected_caption in (vtt.captions[2], + vtt.captions[3], + vtt.captions[4] + ): + self.assertIs(expected_caption, next(slice_of_captions)) + + with self.assertRaises(StopIteration): + next(slice_of_captions) + + def test_iter_slice_no_start_time(self): + vtt = webvtt.read( + PATH_TO_SAMPLES / 'sample.vtt' + ) + slice_of_captions = vtt.iter_slice(end='00:00:27.000') + for expected_caption in (vtt.captions[0], + vtt.captions[1], + vtt.captions[2], + vtt.captions[3], + vtt.captions[4] + ): + self.assertIs(expected_caption, next(slice_of_captions)) + + with self.assertRaises(StopIteration): + next(slice_of_captions) + + def test_iter_slice_no_end_time(self): + vtt = webvtt.read( + PATH_TO_SAMPLES / 'sample.vtt' + ) + slice_of_captions = vtt.iter_slice(start='00:00:47.000') + for expected_caption in (vtt.captions[11], + vtt.captions[12], + vtt.captions[13], + vtt.captions[14], + vtt.captions[15] + ): + self.assertIs(expected_caption, next(slice_of_captions)) + + with self.assertRaises(StopIteration): + next(slice_of_captions) diff --git a/tests/test_webvtt_parser.py b/tests/test_webvtt_parser.py deleted file mode 100644 index a26741a..0000000 --- a/tests/test_webvtt_parser.py +++ /dev/null @@ -1,172 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -from .generic import GenericParserTestCase - -import webvtt -from webvtt.parsers import WebVTTParser -from webvtt.structures import Caption -from webvtt.errors import MalformedFileError, MalformedCaptionError - - -class WebVTTParserTestCase(GenericParserTestCase): - - def test_webvtt_parse_invalid_file(self): - self.assertRaises( - MalformedFileError, - webvtt.read, - self._get_file('invalid.vtt') - ) - - def test_webvtt_captions_not_found(self): - self.assertRaises( - FileNotFoundError, - webvtt.read, - 'some_file' - ) - - def test_webvtt_total_length(self): - self.assertEqual( - webvtt.read(self._get_file('sample.vtt')).total_length, - 64 - ) - - def test_webvtt_total_length_no_parser(self): - self.assertEqual( - webvtt.WebVTT().total_length, - 0 - ) - - def test_webvtt__parse_captions(self): - self.assertTrue(webvtt.read(self._get_file('sample.vtt')).captions) - - def test_webvtt_parse_empty_file(self): - self.assertRaises( - MalformedFileError, - webvtt.read, - self._get_file('empty.vtt') - ) - - def test_webvtt_parse_get_captions(self): - self.assertEqual( - len(webvtt.read(self._get_file('sample.vtt')).captions), - 16 - ) - - def test_webvtt_parse_invalid_timeframe_line(self): - self.assertRaises( - MalformedCaptionError, - webvtt.read, - self._get_file('invalid_timeframe.vtt') - ) - - def test_webvtt_parse_invalid_timeframe_in_cue_text(self): - vtt = webvtt.read(self._get_file('invalid_timeframe_in_cue_text.vtt')) - self.assertEqual(4, len(vtt.captions)) - self.assertEqual('', vtt.captions[1].text) - - def test_webvtt_parse_get_caption_data(self): - vtt = webvtt.read(self._get_file('one_caption.vtt')) - self.assertEqual(vtt.captions[0].start_in_seconds, 0.5) - self.assertEqual(vtt.captions[0].start, '00:00:00.500') - self.assertEqual(vtt.captions[0].end_in_seconds, 7) - self.assertEqual(vtt.captions[0].end, '00:00:07.000') - self.assertEqual(vtt.captions[0].lines[0], 'Caption text #1') - self.assertEqual(len(vtt.captions[0].lines), 1) - - def test_webvtt_caption_without_timeframe(self): - self.assertRaises( - MalformedCaptionError, - webvtt.read, - self._get_file('missing_timeframe.vtt') - ) - - def test_webvtt_caption_without_cue_text(self): - vtt = webvtt.read(self._get_file('missing_caption_text.vtt')) - self.assertEqual(len(vtt.captions), 5) - - def test_webvtt_timestamps_format(self): - vtt = webvtt.read(self._get_file('sample.vtt')) - self.assertEqual(vtt.captions[2].start, '00:00:11.890') - self.assertEqual(vtt.captions[2].end, '00:00:16.320') - - def test_parse_timestamp(self): - caption = Caption(start='02:03:11.890') - self.assertEqual( - caption.start_in_seconds, - 7391.89 - ) - - def test_captions_attribute(self): - self.assertListEqual([], webvtt.WebVTT().captions) - - def test_webvtt_timestamp_format(self): - self.assertTrue(WebVTTParser()._validate_timeframe_line('00:00:00.000 --> 00:00:00.000')) - self.assertTrue(WebVTTParser()._validate_timeframe_line('00:00.000 --> 00:00.000')) - - def test_metadata_headers(self): - vtt = webvtt.read(self._get_file('metadata_headers.vtt')) - self.assertEqual(len(vtt.captions), 2) - - def test_metadata_headers_multiline(self): - vtt = webvtt.read(self._get_file('metadata_headers_multiline.vtt')) - self.assertEqual(len(vtt.captions), 2) - - def test_parse_identifiers(self): - vtt = webvtt.read(self._get_file('using_identifiers.vtt')) - self.assertEqual(len(vtt.captions), 6) - - self.assertEqual(vtt.captions[1].identifier, 'second caption') - self.assertEqual(vtt.captions[2].identifier, None) - self.assertEqual(vtt.captions[3].identifier, '4') - - def test_parse_with_comments(self): - vtt = webvtt.read(self._get_file('comments.vtt')) - self.assertEqual(len(vtt.captions), 3) - self.assertListEqual( - vtt.captions[0].lines, - ['- Ta en kopp varmt te.', - '- Det är inte varmt.'] - ) - self.assertEqual( - vtt.captions[2].text, - '- Ta en kopp' - ) - - def test_parse_styles(self): - vtt = webvtt.read(self._get_file('styles.vtt')) - self.assertEqual(len(vtt.captions), 1) - self.assertEqual( - vtt.styles[0].text, - '::cue {background-image: linear-gradient(to bottom, dimgray, lightgray);color: papayawhip;}' - ) - - def test_clean_cue_tags(self): - vtt = webvtt.read(self._get_file('cue_tags.vtt')) - self.assertEqual( - vtt.captions[1].text, - 'Like a big-a pizza pie' - ) - self.assertEqual( - vtt.captions[2].text, - 'That\'s amore' - ) - - def test_parse_captions_with_bom(self): - vtt = webvtt.read(self._get_file('captions_with_bom.vtt')) - self.assertEqual(len(vtt.captions), 4) - - def test_empty_lines_are_not_included_in_result(self): - vtt = webvtt.read(self._get_file('netflix_chicas_del_cable.vtt')) - self.assertEqual(vtt.captions[0].text, "[Alba] En 1928,") - self.assertEqual( - vtt.captions[-2].text, - "Diez años no son suficientes\npara olvidarte..." - ) - - def test_can_parse_youtube_dl_files(self): - vtt = webvtt.read(self._get_file('youtube_dl.vtt')) - self.assertEqual( - "this will happen is I'm telling", - vtt.captions[2].text - ) diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 06fa4ff..0000000 --- a/tox.ini +++ /dev/null @@ -1,18 +0,0 @@ -[tox] -envlist = py34, py35, py36, py37, py38, py39 - -[travis] -python = - 3.9: py39 - 3.8: py38 - 3.7: py37 - 3.6: py36 - 3.5: py35 - 3.4: py34 - -[testenv] -setenv = - PYTHONPATH = {toxinidir} -deps = pytest -commands = - pytest diff --git a/webvtt/__init__.py b/webvtt/__init__.py index b6239f4..33fa037 100644 --- a/webvtt/__init__.py +++ b/webvtt/__init__.py @@ -1,15 +1,16 @@ -__version__ = '0.4.6' +"""Main webvtt package.""" -from .webvtt import * -from .segmenter import * -from .structures import * -from .errors import * +__version__ = '0.5.0' -__all__ = webvtt.__all__ + segmenter.__all__ + structures.__all__ + errors.__all__ +from .webvtt import WebVTT +from . import segmenter +from .models import Caption, Style # noqa + +__all__ = ['WebVTT', 'Caption', 'Style'] read = WebVTT.read read_buffer = WebVTT.read_buffer +from_buffer = WebVTT.from_buffer from_srt = WebVTT.from_srt from_sbv = WebVTT.from_sbv -list_formats = WebVTT.list_formats -segment = WebVTTSegmenter().segment +segment = segmenter.segment diff --git a/webvtt/cli.py b/webvtt/cli.py index ad8ebbf..a50e7d0 100644 --- a/webvtt/cli.py +++ b/webvtt/cli.py @@ -1,47 +1,59 @@ -""" -Usage: - webvtt segment [--target-duration=SECONDS] [--mpegts=OFFSET] [--output=] - webvtt -h | --help - webvtt --version - -Options: - -h --help Show this screen. - --version Show version. - --target-duration=SECONDS Target duration of each segment in seconds [default: 10]. - --mpegts=OFFSET Presentation timestamp value [default: 900000]. - --output= Output to directory [default: ./]. - -Examples: - webvtt segment captions.vtt --output destination/directory -""" - -from docopt import docopt - -from . import WebVTTSegmenter, __version__ - - -def main(): - """Main entry point for CLI commands.""" - options = docopt(__doc__, version=__version__) - if options['segment']: - segment( - options[''], - options['--output'], - options['--target-duration'], - options['--mpegts'], +"""CLI module.""" + +import argparse +import typing + +from . import segmenter + + +def main(argv: typing.Optional[typing.Sequence] = None): + """ + Segment WebVTT file from command line. + + :param argv: command line arguments + """ + arguments = argparse.ArgumentParser( + description='Segment WebVTT files.' + ) + arguments.add_argument( + 'command', + choices=['segment'], + help='command to perform' + ) + arguments.add_argument( + 'file', + metavar='PATH', + help='WebVTT file' + ) + arguments.add_argument( + '-o', '--output', + metavar='PATH', + help='output directory' + ) + arguments.add_argument( + '-d', '--target-duration', + metavar='NUMBER', + type=int, + default=segmenter.DEFAULT_SECONDS, + help='target duration of each segment in seconds, default: 10' + ) + arguments.add_argument( + '-m', '--mpegts', + metavar='NUMBER', + type=int, + default=segmenter.DEFAULT_MPEGTS, + help='presentation timestamp value, default: 900000' ) + args = arguments.parse_args(argv) -def segment(f, output, target_duration, mpegts): - """Segment command.""" - try: - target_duration = int(target_duration) - except ValueError: - exit('Error: Invalid target duration.') + segmenter.segment( + args.file, + args.output, + args.target_duration, + args.mpegts + ) - try: - mpegts = int(mpegts) - except ValueError: - exit('Error: Invalid MPEGTS value.') - WebVTTSegmenter().segment(f, output, target_duration, mpegts) \ No newline at end of file +if __name__ == '__main__': + main() # pragma: no cover diff --git a/webvtt/errors.py b/webvtt/errors.py index 9ee549b..f8628c9 100644 --- a/webvtt/errors.py +++ b/webvtt/errors.py @@ -1,18 +1,13 @@ - -__all__ = ['MalformedFileError', 'MalformedCaptionError', 'InvalidCaptionsError', 'MissingFilenameError'] +"""Errors module.""" class MalformedFileError(Exception): - """Error raised when the file is not well formatted""" + """File is not in the right format.""" class MalformedCaptionError(Exception): - """Error raised when a caption is not well formatted""" - - -class InvalidCaptionsError(Exception): - """Error raised when passing wrong captions to the segmenter""" + """Caption not in the right format.""" class MissingFilenameError(Exception): - """Error raised when saving a file without filename.""" + """Missing a filename when saving to disk.""" diff --git a/webvtt/models.py b/webvtt/models.py new file mode 100644 index 0000000..31be12f --- /dev/null +++ b/webvtt/models.py @@ -0,0 +1,222 @@ +"""Models module.""" + +import re +import typing + +from .errors import MalformedCaptionError + + +class Timestamp: + """Representation of a timestamp.""" + + PATTERN = re.compile(r'(?:(\d{1,2}):)?(\d{1,2}):(\d{1,2})\.(\d{3})') + + def __init__( + self, + hours: int = 0, + minutes: int = 0, + seconds: int = 0, + milliseconds: int = 0 + ): + """Initialize.""" + self.hours = hours + self.minutes = minutes + self.seconds = seconds + self.milliseconds = milliseconds + + def __str__(self): + """Return the string representation of the timestamp.""" + return ( + f'{self.hours:02d}:{self.minutes:02d}:{self.seconds:02d}' + f'.{self.milliseconds:03d}' + ) + + def to_tuple(self) -> typing.Tuple[int, int, int, int]: + """Return the timestamp in tuple form.""" + return self.hours, self.minutes, self.seconds, self.milliseconds + + def __repr__(self): + """Return the string representation of the caption.""" + return (f'<{self.__class__.__name__} ' + f'hours={self.hours} ' + f'minutes={self.minutes} ' + f'seconds={self.seconds} ' + f'milliseconds={self.milliseconds}>' + ) + + def __eq__(self, other): + """Compare equality with other object.""" + return self.to_tuple() == other.to_tuple() + + def __ne__(self, other): + """Compare a not equality with other object.""" + return self.to_tuple() != other.to_tuple() + + def __gt__(self, other): + """Compare greater than with other object.""" + return self.to_tuple() > other.to_tuple() + + def __lt__(self, other): + """Compare less than with other object.""" + return self.to_tuple() < other.to_tuple() + + def __ge__(self, other): + """Compare greater or equal with other object.""" + return self.to_tuple() >= other.to_tuple() + + def __le__(self, other): + """Compare less or equal with other object.""" + return self.to_tuple() <= other.to_tuple() + + @classmethod + def from_string(cls, value: str) -> 'Timestamp': + """Return a `Timestamp` instance from a string value.""" + if type(value) is not str: + raise MalformedCaptionError(f'Invalid timestamp {value!r}') + + match = re.match(cls.PATTERN, value) + if not match: + raise MalformedCaptionError(f'Invalid timestamp {value!r}') + + hours = int(match.group(1) or 0) + minutes = int(match.group(2)) + seconds = int(match.group(3)) + milliseconds = int(match.group(4)) + + if minutes > 59 or seconds > 59: + raise MalformedCaptionError(f'Invalid timestamp {value!r}') + + return cls(hours, minutes, seconds, milliseconds) + + def in_seconds(self) -> int: + """Return the timestamp in seconds.""" + return (self.hours * 3600 + + self.minutes * 60 + + self.seconds + ) + + +class Caption: + """Representation of a caption.""" + + CUE_TEXT_TAGS = re.compile('<.*?>') + + def __init__(self, + start: typing.Optional[str] = None, + end: typing.Optional[str] = None, + text: typing.Optional[typing.Union[str, + typing.Sequence[str] + ]] = None, + identifier: typing.Optional[str] = None + ): + """ + Initialize. + + :param start: start time of the caption + :param end: end time of the caption + :param text: the text of the caption + :param identifier: optional identifier + """ + text = text or [] + self.start = start or '00:00:00.000' + self.end = end or '00:00:00.000' + self.identifier = identifier + self.lines = (text.splitlines() + if isinstance(text, str) + else + list(text) + ) + self.comments: typing.List[str] = [] + + def __repr__(self): + """Return the string representation of the caption.""" + cleaned_text = self.text.replace('\n', '\\n') + return (f'<{self.__class__.__name__} ' + f'start={self.start!r} ' + f'end={self.end!r} ' + f'text={cleaned_text!r} ' + f'identifier={self.identifier!r}>' + ) + + def __str__(self): + """Return a readable representation of the caption.""" + cleaned_text = self.text.replace('\n', '\\n') + return f'{self.start} {self.end} {cleaned_text}' + + def __eq__(self, other): + """Compare equality with another object.""" + if not isinstance(other, type(self)): + return False + + return (self.start == other.start and + self.end == other.end and + self.raw_text == other.raw_text and + self.identifier == other.identifier + ) + + @property + def start(self): + """Return the start time of the caption.""" + return str(self.start_time) + + @start.setter + def start(self, value: str): + """Set the start time of the caption.""" + self.start_time = Timestamp.from_string(value) + + @property + def end(self): + """Return the end time of the caption.""" + return str(self.end_time) + + @end.setter + def end(self, value: str): + """Set the end time of the caption.""" + self.end_time = Timestamp.from_string(value) + + @property + def start_in_seconds(self) -> int: + """Return the start time of the caption in seconds.""" + return self.start_time.in_seconds() + + @property + def end_in_seconds(self): + """Return the end time of the caption in seconds.""" + return self.end_time.in_seconds() + + @property + def raw_text(self) -> str: + """Return the text of the caption (including cue tags).""" + return '\n'.join(self.lines) + + @property + def text(self) -> str: + """Return the text of the caption (without cue tags).""" + return re.sub(self.CUE_TEXT_TAGS, '', self.raw_text) + + @text.setter + def text(self, value: str): + """Set the text of the captions.""" + if not isinstance(value, str): + raise AttributeError( + f'String value expected but received {value}.' + ) + + self.lines = value.splitlines() + + +class Style: + """Representation of a style.""" + + def __init__(self, text: typing.Union[str, typing.List[str]]): + """Initialize. + + :param: text: the style text + """ + self.lines = text.splitlines() if isinstance(text, str) else text + self.comments: typing.List[str] = [] + + @property + def text(self): + """Return the text of the style.""" + return '\n'.join(self.lines) diff --git a/webvtt/parsers.py b/webvtt/parsers.py deleted file mode 100644 index 3a978ca..0000000 --- a/webvtt/parsers.py +++ /dev/null @@ -1,293 +0,0 @@ -import re -import os -import codecs - -from .errors import MalformedFileError, MalformedCaptionError -from .structures import Block, Style, Caption - - -class TextBasedParser(object): - """ - Parser for plain text caption files. - This is a generic class, do not use directly. - """ - - TIMEFRAME_LINE_PATTERN = '' - PARSER_OPTIONS = {} - - def __init__(self, parse_options=None): - self.captions = [] - self.parse_options = parse_options or {} - - def read(self, file): - """Reads the captions file.""" - content = self._get_content_from_file(file_path=file) - self._validate(content) - self._parse(content) - - return self - - def read_from_buffer(self, buffer): - content = self._read_content_lines(buffer) - self._validate(content) - self._parse(content) - - return self - - def _get_content_from_file(self, file_path): - encoding = self._read_file_encoding(file_path) - with open(file_path, encoding=encoding) as f: - return self._read_content_lines(f) - - def _read_file_encoding(self, file_path): - first_bytes = min(32, os.path.getsize(file_path)) - with open(file_path, 'rb') as f: - raw = f.read(first_bytes) - - if raw.startswith(codecs.BOM_UTF8): - return 'utf-8-sig' - else: - return 'utf-8' - - def _read_content_lines(self, file_obj): - - lines = [line.rstrip('\n\r') for line in file_obj.readlines()] - - if not lines: - raise MalformedFileError('The file is empty.') - - return lines - - def _read_content(self, file): - return self._get_content_from_file(file_path=file) - - def _parse_timeframe_line(self, line): - """Parse timeframe line and return start and end timestamps.""" - tf = self._validate_timeframe_line(line) - if not tf: - raise MalformedCaptionError('Invalid time format') - - return tf.group(1), tf.group(2) - - def _validate_timeframe_line(self, line): - return re.match(self.TIMEFRAME_LINE_PATTERN, line) - - def _is_timeframe_line(self, line): - """ - This method returns True if the line contains the timeframes. - To be implemented by child classes. - """ - raise NotImplementedError - - def _validate(self, lines): - """ - Validates the format of the parsed file. - To be implemented by child classes. - """ - raise NotImplementedError - - def _should_skip_line(self, line, index, caption): - """ - This method returns True for a line that should be skipped. - Implement in child classes if needed. - """ - return False - - def _parse(self, lines): - self.captions = [] - c = None - - for index, line in enumerate(lines): - if self._is_timeframe_line(line): - try: - start, end = self._parse_timeframe_line(line) - except MalformedCaptionError as e: - raise MalformedCaptionError('{} in line {}'.format(e, index + 1)) - c = Caption(start, end) - elif self._should_skip_line(line, index, c): # allow child classes to skip lines based on the content - continue - elif line: - if c is None: - raise MalformedCaptionError( - 'Caption missing timeframe in line {}.'.format(index + 1)) - else: - c.add_line(line) - else: - if c is None: - continue - if not c.lines: - if self.PARSER_OPTIONS.get('ignore_empty_captions', False): - c = None - continue - raise MalformedCaptionError('Caption missing text in line {}.'.format(index + 1)) - - self.captions.append(c) - c = None - - if c is not None and c.lines: - self.captions.append(c) - - -class SRTParser(TextBasedParser): - """ - SRT parser. - """ - - TIMEFRAME_LINE_PATTERN = re.compile(r'\s*(\d+:\d{2}:\d{2},\d{3})\s*-->\s*(\d+:\d{2}:\d{2},\d{3})') - - PARSER_OPTIONS = { - 'ignore_empty_captions': True - } - - def _validate(self, lines): - if len(lines) < 2 or lines[0] != '1' or not self._validate_timeframe_line(lines[1]): - raise MalformedFileError('The file does not have a valid format.') - - def _is_timeframe_line(self, line): - return '-->' in line - - def _should_skip_line(self, line, index, caption): - return caption is None and line.isdigit() - - -class WebVTTParser(TextBasedParser): - """ - WebVTT parser. - """ - - TIMEFRAME_LINE_PATTERN = re.compile(r'\s*((?:\d+:)?\d{2}:\d{2}.\d{3})\s*-->\s*((?:\d+:)?\d{2}:\d{2}.\d{3})') - COMMENT_PATTERN = re.compile(r'NOTE(?:\s.+|$)') - STYLE_PATTERN = re.compile(r'STYLE[ \t]*$') - - def __init__(self): - super().__init__() - self.styles = [] - - def _compute_blocks(self, lines): - blocks = [] - - for index, line in enumerate(lines, start=1): - if line: - if not blocks: - blocks.append(Block(index)) - if not blocks[-1].lines: - if not line.strip(): - continue - blocks[-1].line_number = index - blocks[-1].lines.append(line) - else: - blocks.append(Block(index)) - - # filter out empty blocks and skip signature - return list(filter(lambda x: x.lines, blocks))[1:] - - def _parse_cue_block(self, block): - caption = Caption() - cue_timings = None - additional_blocks = None - - for line_number, line in enumerate(block.lines): - if self._is_cue_timings_line(line): - if cue_timings is None: - try: - cue_timings = self._parse_timeframe_line(line) - except MalformedCaptionError as e: - raise MalformedCaptionError( - '{} in line {}'.format(e, block.line_number + line_number)) - else: - additional_blocks = self._compute_blocks( - ['WEBVTT', ''] + block.lines[line_number:] - ) - break - elif line_number == 0: - caption.identifier = line - else: - caption.add_line(line) - - caption.start = cue_timings[0] - caption.end = cue_timings[1] - return caption, additional_blocks - - def _parse(self, lines): - self.captions = [] - blocks = self._compute_blocks(lines) - self._parse_blocks(blocks) - - def _is_empty(self, block): - is_empty = True - - for line in block.lines: - if line.strip() != "": - is_empty = False - - return is_empty - - def _parse_blocks(self, blocks): - for block in blocks: - # skip empty blocks - if self._is_empty(block): - continue - - if self._is_cue_block(block): - caption, additional_blocks = self._parse_cue_block(block) - self.captions.append(caption) - - if additional_blocks: - self._parse_blocks(additional_blocks) - - elif self._is_comment_block(block): - continue - elif self._is_style_block(block): - if self.captions: - raise MalformedFileError( - 'Style block defined after the first cue in line {}.' - .format(block.line_number)) - style = Style() - style.lines = block.lines[1:] - self.styles.append(style) - else: - if len(block.lines) == 1: - raise MalformedCaptionError( - 'Standalone cue identifier in line {}.'.format(block.line_number)) - else: - raise MalformedCaptionError( - 'Missing timing cue in line {}.'.format(block.line_number+1)) - - def _validate(self, lines): - if not re.match('WEBVTT', lines[0]): - raise MalformedFileError('The file does not have a valid format') - - def _is_cue_timings_line(self, line): - return '-->' in line - - def _is_cue_block(self, block): - """Returns True if it is a cue block - (one of the two first lines being a cue timing line)""" - return any(map(self._is_cue_timings_line, block.lines[:2])) - - def _is_comment_block(self, block): - """Returns True if it is a comment block""" - return re.match(self.COMMENT_PATTERN, block.lines[0]) - - def _is_style_block(self, block): - """Returns True if it is a style block""" - return re.match(self.STYLE_PATTERN, block.lines[0]) - - -class SBVParser(TextBasedParser): - """ - YouTube SBV parser. - """ - - TIMEFRAME_LINE_PATTERN = re.compile(r'\s*(\d+:\d{2}:\d{2}.\d{3}),(\d+:\d{2}:\d{2}.\d{3})') - - PARSER_OPTIONS = { - 'ignore_empty_captions': True - } - - def _validate(self, lines): - if not self._validate_timeframe_line(lines[0]): - raise MalformedFileError('The file does not have a valid format') - - def _is_timeframe_line(self, line): - return self._validate_timeframe_line(line) diff --git a/webvtt/sbv.py b/webvtt/sbv.py new file mode 100644 index 0000000..be9c1bf --- /dev/null +++ b/webvtt/sbv.py @@ -0,0 +1,117 @@ +"""SBV format module.""" + +import typing +import re + +from . import utils +from .models import Caption +from .errors import MalformedFileError + + +class SBVCueBlock: + """Representation of a cue timing block.""" + + CUE_TIMINGS_PATTERN = re.compile( + r'\s*(\d{1,2}:\d{1,2}:\d{1,2}.\d{3}),(\d{1,2}:\d{1,2}:\d{1,2}.\d{3})' + ) + + def __init__( + self, + start: str, + end: str, + payload: typing.Sequence[str] + ): + """ + Initialize. + + :param start: start time + :param end: end time + :param payload: caption text + """ + self.start = start + self.end = end + self.payload = payload + + @classmethod + def is_valid( + cls, + lines: typing.Sequence[str] + ) -> bool: + """ + Validate the lines for a match of a cue time block. + + :param lines: the lines to be validated + :returns: true for a matching cue time block + """ + return bool( + len(lines) >= 2 and + re.match(cls.CUE_TIMINGS_PATTERN, lines[0]) and + lines[1].strip() + ) + + @classmethod + def from_lines( + cls, + lines: typing.Sequence[str] + ) -> 'SBVCueBlock': + """ + Create a `SBVCueBlock` from lines of text. + + :param lines: the lines of text + :returns: `SBVCueBlock` instance + """ + match = re.match(cls.CUE_TIMINGS_PATTERN, lines[0]) + assert match is not None + + payload = lines[1:] + + return cls(match.group(1), match.group(2), payload) + + +def parse(lines: typing.Sequence[str]) -> typing.List[Caption]: + """ + Parse SBV captions from lines of text. + + :param lines: lines of text + :returns: list of `Caption` objects + """ + if not _is_valid_content(lines): + raise MalformedFileError('Invalid format') + + return _parse_captions(lines) + + +def _is_valid_content(lines: typing.Sequence[str]) -> bool: + """ + Validate lines of text for valid SBV content. + + :param lines: lines of text + :returns: true for a valid SBV content + """ + if len(lines) < 2: + return False + + first_block = next(utils.iter_blocks_of_lines(lines)) + return bool(first_block and SBVCueBlock.is_valid(first_block)) + + +def _parse_captions(lines: typing.Sequence[str]) -> typing.List[Caption]: + """ + Parse captions from the text. + + :param lines: lines of text + :returns: list of `Caption` objects + """ + captions = [] + + for block_lines in utils.iter_blocks_of_lines(lines): + if not SBVCueBlock.is_valid(block_lines): + continue + + cue_block = SBVCueBlock.from_lines(block_lines) + captions.append(Caption(cue_block.start, + cue_block.end, + cue_block.payload + )) + + return captions diff --git a/webvtt/segmenter.py b/webvtt/segmenter.py index 9378ad6..a159fd8 100644 --- a/webvtt/segmenter.py +++ b/webvtt/segmenter.py @@ -1,110 +1,121 @@ +"""Segmenter module.""" + +import typing import os +import pathlib from math import ceil, floor -from .errors import InvalidCaptionsError -from .webvtt import WebVTT -from .structures import Caption +from .webvtt import WebVTT, Caption + +DEFAULT_MPEGTS = 900000 +DEFAULT_SECONDS = 10 # default number of seconds per segment + + +def segment( + webvtt_path: str, + output: str, + seconds: int = DEFAULT_SECONDS, + mpegts: int = DEFAULT_MPEGTS + ): + """ + Segment a WebVTT captions file. + + :param webvtt_path: the path to the file + :param output: the path to the destination folder + :param seconds: the number of seconds for each segment + :param mpegts: value for the MPEG-TS + """ + captions = WebVTT.read(webvtt_path).captions + + output_folder = pathlib.Path(output) + os.makedirs(output_folder, exist_ok=True) + + segments = slice_segments(captions, seconds) + write_segments(output_folder, segments, mpegts) + write_manifest(output_folder, segments, seconds) -MPEGTS = 900000 -SECONDS = 10 # default number of seconds per segment -__all__ = ['WebVTTSegmenter'] +def slice_segments( + captions: typing.Sequence[Caption], + seconds: int + ) -> typing.List[typing.List[Caption]]: + """ + Slice segments of captions based on seconds per segment. + :param captions: the captions + :param seconds: seconds per segment + :returns: list of lists of `Caption` objects + """ + total_segments = ( + 0 + if not captions else + int(ceil(captions[-1].end_in_seconds / seconds)) + ) + + segments: typing.List[typing.List[Caption]] = [ + [] for _ in range(total_segments) + ] + + for c in captions: + segment_index_start = floor(c.start_in_seconds / seconds) + segments[segment_index_start].append(c) + + # Also include a caption in other segments based on the end time. + segment_index_end = floor(c.end_in_seconds / seconds) + if segment_index_end > segment_index_start: + for i in range(segment_index_start + 1, segment_index_end + 1): + segments[i].append(c) + + return segments + + +def write_segments( + output_folder: pathlib.Path, + segments: typing.Iterable[typing.Iterable[Caption]], + mpegts: int + ): + """ + Write the segments to the output folder. -class WebVTTSegmenter(object): + :param output_folder: folder where the segment files will be stored + :param segments: the segments of `Caption` objects + :param mpegts: value for the MPEG-TS """ - Provides segmentation of WebVTT captions for HTTP Live Streaming (HLS). + for index, segment in enumerate(segments): + segment_file = output_folder / f'fileSequence{index}.webvtt' + + with open(segment_file, 'w', encoding='utf-8') as f: + f.write('WEBVTT\n') + f.write(f'X-TIMESTAMP-MAP=MPEGTS:{mpegts},' + 'LOCAL:00:00:00.000\n' + ) + + for caption in segment: + f.write('\n{} --> {}\n'.format(caption.start, caption.end)) + f.writelines(f'{line}\n' for line in caption.lines) + + +def write_manifest( + output_folder: pathlib.Path, + segments: typing.Iterable[typing.Iterable[Caption]], + seconds: int + ): + """ + Write the manifest in the output folder. + + :param output_folder: folder where the manifest will be stored + :param segments: the segments of `Caption` objects + :param seconds: the seconds per segment """ - def __init__(self): - self._total_segments = 0 - self._output_folder = '' - self._seconds = 0 - self._mpegts = 0 - self._segments = [] - - def _validate_webvtt(self, webvtt): - # Validates that the captions is a list and all the captions are instances of Caption. - if not isinstance(webvtt, WebVTT): - return False - for c in webvtt.captions: - if not isinstance(c, Caption): - return False - return True - - def _slice_segments(self, captions): - self._segments = [[] for _ in range(self.total_segments)] - - for c in captions: - segment_index_start = floor(c.start_in_seconds / self.seconds) - self.segments[segment_index_start].append(c) - - # Also include a caption in other segments based on the end time. - segment_index_end = floor(c.end_in_seconds / self.seconds) - if segment_index_end > segment_index_start: - for i in range(segment_index_start + 1, segment_index_end + 1): - self.segments[i].append(c) - - def _write_segments(self): - for index in range(self.total_segments): - segment_file = os.path.join(self._output_folder, 'fileSequence{}.webvtt'.format(index)) - - with open(segment_file, 'w', encoding='utf-8') as f: - f.write('WEBVTT\n') - f.write('X-TIMESTAMP-MAP=MPEGTS:{},LOCAL:00:00:00.000\n'.format(self._mpegts)) - - for caption in self.segments[index]: - f.write('\n{} --> {}\n'.format(caption.start, caption.end)) - f.writelines(['{}\n'.format(l) for l in caption.lines]) - - def _write_manifest(self): - manifest_file = os.path.join(self._output_folder, 'prog_index.m3u8') - with open(manifest_file, 'w', encoding='utf-8') as f: - f.write('#EXTM3U\n') - f.write('#EXT-X-TARGETDURATION:{}\n'.format(self.seconds)) - f.write('#EXT-X-VERSION:3\n') - f.write('#EXT-X-PLAYLIST-TYPE:VOD\n') - - for i in range(self.total_segments): - f.write('#EXTINF:30.00000\n') - f.write('fileSequence{}.webvtt\n'.format(i)) - - f.write('#EXT-X-ENDLIST\n') - - def segment(self, webvtt, output='', seconds=SECONDS, mpegts=MPEGTS): - """Segments the captions based on a number of seconds.""" - if isinstance(webvtt, str): - # if a string is supplied we parse the file - captions = WebVTT().read(webvtt).captions - elif not self._validate_webvtt(webvtt): - raise InvalidCaptionsError('The captions provided are invalid') - else: - # we expect to have a webvtt object - captions = webvtt.captions - - self._total_segments = 0 if not captions else int(ceil(captions[-1].end_in_seconds / seconds)) - self._output_folder = output - self._seconds = seconds - self._mpegts = mpegts - - output_folder = os.path.join(os.getcwd(), output) - if not os.path.exists(output_folder): - os.makedirs(output_folder) - - self._slice_segments(captions) - self._write_segments() - self._write_manifest() - - @property - def seconds(self): - """Returns the number of seconds used for segmenting captions.""" - return self._seconds - - @property - def total_segments(self): - """Returns the total of segments.""" - return self._total_segments - - @property - def segments(self): - """Return the list of segments.""" - return self._segments + manifest_file = output_folder / 'prog_index.m3u8' + with open(manifest_file, 'w', encoding='utf-8') as f: + f.write('#EXTM3U\n') + f.write(f'#EXT-X-TARGETDURATION:{seconds}\n') + f.write('#EXT-X-VERSION:3\n') + f.write('#EXT-X-PLAYLIST-TYPE:VOD\n') + + for index, _ in enumerate(segments): + f.write('#EXTINF:30.00000\n') + f.write(f'fileSequence{index}.webvtt\n') + + f.write('#EXT-X-ENDLIST\n') diff --git a/webvtt/srt.py b/webvtt/srt.py new file mode 100644 index 0000000..0e0637d --- /dev/null +++ b/webvtt/srt.py @@ -0,0 +1,148 @@ +"""SRT format module.""" + +import typing +import re + +from .models import Caption +from .errors import MalformedFileError +from . import utils + + +class SRTCueBlock: + """Representation of a cue timing block.""" + + CUE_TIMINGS_PATTERN = re.compile( + r'\s*(\d+:\d{2}:\d{2},\d{3})\s*-->\s*(\d+:\d{2}:\d{2},\d{3})' + ) + + def __init__( + self, + index: str, + start: str, + end: str, + payload: typing.Sequence[str] + ): + """ + Initialize. + + :param start: start time + :param end: end time + :param payload: caption text + """ + self.index = index + self.start = start + self.end = end + self.payload = payload + + @classmethod + def is_valid( + cls, + lines: typing.Sequence[str] + ) -> bool: + """ + Validate the lines for a match of a cue time block. + + :param lines: the lines to be validated + :returns: true for a matching cue time block + """ + return bool( + len(lines) >= 3 and + lines[0].isdigit() and + re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) + ) + + @classmethod + def from_lines( + cls, + lines: typing.Sequence[str] + ) -> 'SRTCueBlock': + """ + Create a `SRTCueBlock` from lines of text. + + :param lines: the lines of text + :returns: `SRTCueBlock` instance + """ + index = lines[0] + + match = re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) + assert match is not None + + payload = lines[2:] + + return cls(index, match.group(1), match.group(2), payload) + + +def parse(lines: typing.Sequence[str]) -> typing.List[Caption]: + """ + Parse SRT captions from lines of text. + + :param lines: lines of text + :returns: list of `Caption` objects + """ + if not is_valid_content(lines): + raise MalformedFileError('Invalid format') + + return parse_captions(lines) + + +def is_valid_content(lines: typing.Sequence[str]) -> bool: + """ + Validate lines of text for valid SBV content. + + :param lines: lines of text + :returns: true for a valid SBV content + """ + return bool( + len(lines) >= 3 and + lines[0].isdigit() and + '-->' in lines[1] and + lines[2].strip() + ) + + +def parse_captions(lines: typing.Sequence[str]) -> typing.List[Caption]: + """ + Parse captions from the text. + + :param lines: lines of text + :returns: list of `Caption` objects + """ + captions: typing.List[Caption] = [] + + for block_lines in utils.iter_blocks_of_lines(lines): + if not SRTCueBlock.is_valid(block_lines): + continue + + cue_block = SRTCueBlock.from_lines(block_lines) + cue_block.start, cue_block.end = map( + lambda x: x.replace(',', '.'), (cue_block.start, cue_block.end)) + + captions.append(Caption(cue_block.start, + cue_block.end, + cue_block.payload + )) + + return captions + + +def write( + f: typing.IO[str], + captions: typing.Iterable[Caption] + ): + """ + Write captions to an output. + + :param f: file or file-like object + :param captions: Iterable of `Caption` objects + """ + output = [] + for index, caption in enumerate(captions, start=1): + output.extend([ + f'{index}', + '{} --> {}'.format(*map(lambda x: x.replace('.', ','), + (caption.start, caption.end)) + ), + *caption.lines, + '' + ]) + f.write('\n'.join(output).rstrip()) diff --git a/webvtt/structures.py b/webvtt/structures.py deleted file mode 100644 index c4a576d..0000000 --- a/webvtt/structures.py +++ /dev/null @@ -1,135 +0,0 @@ -import re - -from .errors import MalformedCaptionError - -TIMESTAMP_PATTERN = re.compile(r'(\d+)?:?(\d{2}):(\d{2})[.,](\d{3})') - -__all__ = ['Caption'] - - -class Caption(object): - - CUE_TEXT_TAGS = re.compile('<.*?>') - - """ - Represents a caption. - """ - def __init__(self, start='00:00:00.000', end='00:00:00.000', text=None): - self.start = start - self.end = end - self.identifier = None - - # If lines is a string convert to a list - if text and isinstance(text, str): - text = text.splitlines() - - self._lines = text or [] - - def __repr__(self): - return '<%(cls)s start=%(start)s end=%(end)s text=%(text)s>' % { - 'cls': self.__class__.__name__, - 'start': self.start, - 'end': self.end, - 'text': self.text.replace('\n', '\\n') - } - - def __str__(self): - return '%(start)s %(end)s %(text)s' % { - 'start': self.start, - 'end': self.end, - 'text': self.text.replace('\n', '\\n') - } - - def add_line(self, line): - self.lines.append(line) - - def _to_seconds(self, hours, minutes, seconds, milliseconds): - return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000 - - def _parse_timestamp(self, timestamp): - res = re.match(TIMESTAMP_PATTERN, timestamp) - if not res: - raise MalformedCaptionError('Invalid timestamp: {}'.format(timestamp)) - - values = list(map(lambda x: int(x) if x else 0, res.groups())) - return self._to_seconds(*values) - - def _to_timestamp(self, total_seconds): - hours = int(total_seconds / 3600) - minutes = int(total_seconds / 60 - hours * 60) - seconds = total_seconds - hours * 3600 - minutes * 60 - return '{:02d}:{:02d}:{:06.3f}'.format(hours, minutes, seconds) - - def _clean_cue_tags(self, text): - return re.sub(self.CUE_TEXT_TAGS, '', text) - - @property - def start_in_seconds(self): - return self._start - - @property - def end_in_seconds(self): - return self._end - - @property - def start(self): - return self._to_timestamp(self._start) - - @start.setter - def start(self, value): - self._start = self._parse_timestamp(value) - - @property - def end(self): - return self._to_timestamp(self._end) - - @end.setter - def end(self, value): - self._end = self._parse_timestamp(value) - - @property - def lines(self): - return self._lines - - @property - def text(self): - """Returns the captions lines as a text (without cue tags)""" - return self._clean_cue_tags(self.raw_text) - - @property - def raw_text(self): - """Returns the captions lines as a text (may include cue tags)""" - return '\n'.join(self.lines) - - @text.setter - def text(self, value): - if not isinstance(value, str): - raise AttributeError('String value expected but received {}.'.format(type(value))) - - self._lines = value.splitlines() - - -class GenericBlock(object): - """Generic class that defines a data structure holding an array of lines""" - def __init__(self): - self.lines = [] - - -class Block(GenericBlock): - def __init__(self, line_number): - super().__init__() - self.line_number = line_number - - -class Style(GenericBlock): - - @property - def text(self): - """Returns the style lines as a text""" - return ''.join(map(lambda x: x.strip(), self.lines)) - - @text.setter - def text(self, value): - if type(value) != str: - raise TypeError('The text value must be a string.') - self.lines = value.split('\n') diff --git a/webvtt/utils.py b/webvtt/utils.py new file mode 100644 index 0000000..28da8c8 --- /dev/null +++ b/webvtt/utils.py @@ -0,0 +1,104 @@ +"""Utils module.""" + +import typing +import codecs + +CODEC_BOMS = { + 'utf-8': codecs.BOM_UTF8, + 'utf-32-le': codecs.BOM_UTF32_LE, + 'utf-32-be': codecs.BOM_UTF32_BE, + 'utf-16-le': codecs.BOM_UTF16_LE, + 'utf-16-be': codecs.BOM_UTF16_BE +} + + +class FileWrapper: + """File handling functionality with built-in support for Byte OrderMark.""" + + def __init__( + self, + file_path: str, + mode: typing.Optional[str] = None, + encoding: typing.Optional[str] = None + ): + """ + Initialize. + + :param file_path: path to the file + :param mode: mode in which the file is opened + :param encoding: name of the encoding used to decode the file + """ + self.file_path = file_path + self.mode = mode or 'r' + self.bom_encoding = self.detect_bom_encoding(file_path) + self.encoding = (self.bom_encoding or + encoding or + 'utf-8' + ) + + @classmethod + def open( + cls, + file_path: str, + mode: typing.Optional[str] = None, + encoding: typing.Optional[str] = None + ) -> 'FileWrapper': + """ + Open a file. + + :param file_path: path to the file + :param mode: mode in which the file is opened + :param encoding: name of the encoding used to decode the file + """ + return cls(file_path, mode, encoding) + + def __enter__(self): + """Enter context.""" + self.file = open( + file=self.file_path, + mode=self.mode, + encoding=self.encoding + ) + if self.bom_encoding: + self.file.seek(len(CODEC_BOMS[self.bom_encoding])) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit context.""" + self.file.close() + + @staticmethod + def detect_bom_encoding(file_path: str) -> typing.Optional[str]: + """ + Detect the encoding of a file based on the presence of the BOM. + + :param file_path: path to the file + :returns: the encoding if BOM is found or None. + """ + with open(file_path, mode='rb') as f: + first_bytes = f.read(4) + for encoding, bom in CODEC_BOMS.items(): + if first_bytes.startswith(bom): + return encoding + return None + + +def iter_blocks_of_lines( + lines: typing.Iterable[str] + ) -> typing.Generator[typing.List[str], None, None]: + """ + Iterate blocks of text. + + :param lines: lines of text. + """ + current_text_block = [] + + for line in lines: + if line.strip(): + current_text_block.append(line) + elif current_text_block: + yield current_text_block + current_text_block = [] + + if current_text_block: + yield current_text_block diff --git a/webvtt/vtt.py b/webvtt/vtt.py new file mode 100644 index 0000000..459ac3e --- /dev/null +++ b/webvtt/vtt.py @@ -0,0 +1,400 @@ +"""VTT format module.""" + +import re +import typing +from dataclasses import dataclass + +from .errors import MalformedFileError +from .models import Caption, Style +from . import utils + + +@dataclass +class ParserOutput: + """Output of parser.""" + + styles: typing.List[Style] + captions: typing.List[Caption] + header_comments: typing.List[str] + footer_comments: typing.List[str] + + @classmethod + def from_data( + cls, + data: typing.Mapping[str, typing.Any] + ) -> 'ParserOutput': + """ + Return a `ParserOutput` instance from the provided data. + + :param data: data from the parser + :returns: an instance of `ParserOutput` + """ + items = data.get('items', []) + return cls( + captions=[it for it in items if isinstance(it, Caption)], + styles=[it for it in items if isinstance(it, Style)], + header_comments=data.get('header_comments', []), + footer_comments=data.get('footer_comments', []) + ) + + +class WebVTTCueBlock: + """Representation of a cue timing block.""" + + CUE_TIMINGS_PATTERN = re.compile( + r'\s*((?:\d+:)?\d{2}:\d{2}.\d{3})\s*-->\s*((?:\d+:)?\d{2}:\d{2}.\d{3})' + ) + + def __init__( + self, + identifier, + start, + end, + payload + ): + """ + Initialize. + + :param start: start time + :param end: end time + :param payload: caption text + """ + self.identifier = identifier + self.start = start + self.end = end + self.payload = payload + + @classmethod + def is_valid( + cls, + lines: typing.Sequence[str] + ) -> bool: + """ + Validate the lines for a match of a cue time block. + + :param lines: the lines to be validated + :returns: true for a matching cue time block + """ + return bool( + ( + len(lines) >= 2 and + re.match(cls.CUE_TIMINGS_PATTERN, lines[0]) and + "-->" not in lines[1] + ) or + ( + len(lines) >= 3 and + "-->" not in lines[0] and + re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) and + "-->" not in lines[2] + ) + ) + + @classmethod + def from_lines( + cls, + lines: typing.Iterable[str] + ) -> 'WebVTTCueBlock': + """ + Create a `WebVTTCueBlock` from lines of text. + + :param lines: the lines of text + :returns: `WebVTTCueBlock` instance + """ + identifier = None + start = None + end = None + payload = [] + + for line in lines: + timing_match = re.match(cls.CUE_TIMINGS_PATTERN, line) + if timing_match: + start = timing_match.group(1) + end = timing_match.group(2) + elif not start: + identifier = line + else: + payload.append(line) + + return cls(identifier, start, end, payload) + + @staticmethod + def format_lines(caption: Caption) -> typing.List[str]: + """ + Return the lines for a cue block. + + :param caption: the `Caption` instance + :returns: list of lines for a cue block + """ + return [ + '', + *(identifier for identifier in {caption.identifier} if identifier), + f'{caption.start} --> {caption.end}', + *caption.lines + ] + + +class WebVTTCommentBlock: + """Representation of a comment block.""" + + COMMENT_PATTERN = re.compile(r'NOTE\s(.*?)\Z', re.DOTALL) + + def __init__(self, text: str): + """ + Initialize. + + :param text: comment text + """ + self.text = text + + @classmethod + def is_valid( + cls, + lines: typing.Sequence[str] + ) -> bool: + """ + Validate the lines for a match of a comment block. + + :param lines: the lines to be validated + :returns: true for a matching comment block + """ + return bool(lines and lines[0].startswith('NOTE')) + + @classmethod + def from_lines( + cls, + lines: typing.Iterable[str] + ) -> 'WebVTTCommentBlock': + """ + Create a `WebVTTCommentBlock` from lines of text. + + :param lines: the lines of text + :returns: `WebVTTCommentBlock` instance + """ + match = cls.COMMENT_PATTERN.match('\n'.join(lines)) + return cls(text=match.group(1).strip() if match else '') + + @staticmethod + def format_lines(lines: str) -> typing.List[str]: + """ + Return the lines for a comment block. + + :param lines: comment lines + :returns: list of lines for a comment block + """ + list_of_lines = lines.split('\n') + + if len(list_of_lines) == 1: + return [f'NOTE {lines}'] + + return ['NOTE', *list_of_lines] + + +class WebVTTStyleBlock: + """Representation of a style block.""" + + STYLE_PATTERN = re.compile(r'STYLE\s(.*?)\Z', re.DOTALL) + + def __init__(self, text: str): + """ + Initialize. + + :param text: style text + """ + self.text = text + + @classmethod + def is_valid( + cls, + lines: typing.Sequence[str] + ) -> bool: + """ + Validate the lines for a match of a style block. + + :param lines: the lines to be validated + :returns: true for a matching style block + """ + return (len(lines) >= 2 and + lines[0] == 'STYLE' and + not any(line.strip() == '' or '-->' in line for line in lines) + ) + + @classmethod + def from_lines( + cls, + lines: typing.Iterable[str] + ) -> 'WebVTTStyleBlock': + """ + Create a `WebVTTStyleBlock` from lines of text. + + :param lines: the lines of text + :returns: `WebVTTStyleBlock` instance + """ + match = cls.STYLE_PATTERN.match('\n'.join(lines)) + return cls(text=match.group(1).strip() if match else '') + + @staticmethod + def format_lines(lines: typing.List[str]) -> typing.List[str]: + """ + Return the lines for a style block. + + :param lines: style lines + :returns: list of lines for a style block + """ + return ['STYLE', *lines] + + +def parse( + lines: typing.Sequence[str] + ) -> ParserOutput: + """ + Parse VTT captions from lines of text. + + :param lines: lines of text + :returns: object `ParserOutput` with all parsed items + """ + if not is_valid_content(lines): + raise MalformedFileError('Invalid format') + + return parse_items(lines) + + +def is_valid_content(lines: typing.Sequence[str]) -> bool: + """ + Validate lines of text for valid VTT content. + + :param lines: lines of text + :returns: true for a valid VTT content + """ + return bool(lines and lines[0].startswith('WEBVTT')) + + +def parse_items( + lines: typing.Sequence[str] + ) -> ParserOutput: + """ + Parse items from the text. + + :param lines: lines of text + :returns: an object `ParserOutput` with all parsed items + """ + header_comments: typing.List[str] = [] + items: typing.List[typing.Union[Caption, Style]] = [] + comments: typing.List[WebVTTCommentBlock] = [] + + for block_lines in utils.iter_blocks_of_lines(lines): + item = parse_item(block_lines) + if item: + item.comments = [comment.text for comment in comments] + comments = [] + items.append(item) + elif WebVTTCommentBlock.is_valid(block_lines): + comments.append(WebVTTCommentBlock.from_lines(block_lines)) + + if items: + header_comments, items[0].comments = items[0].comments, header_comments + + return ParserOutput.from_data( + {'items': items, + 'header_comments': header_comments, + 'footer_comments': [comment.text for comment in comments] + } + ) + + +def parse_item( + lines: typing.Sequence[str] + ) -> typing.Union[Caption, Style, None]: + """ + Parse an item from lines of text. + + :param lines: lines of text + :returns: An item (Caption or Style) if found, otherwise None + """ + if WebVTTCueBlock.is_valid(lines): + cue_block = WebVTTCueBlock.from_lines(lines) + return Caption(cue_block.start, + cue_block.end, + cue_block.payload, + cue_block.identifier + ) + + if WebVTTStyleBlock.is_valid(lines): + return Style(WebVTTStyleBlock.from_lines(lines).text) + + return None + + +def write( + f: typing.IO[str], + captions: typing.Iterable[Caption], + styles: typing.Iterable[Style], + header_comments: typing.Iterable[str], + footer_comments: typing.Iterable[str] + ): + """ + Write captions to an output. + + :param f: file or file-like object + :param captions: Iterable of `Caption` objects + :param styles: Iterable of `Style` objects + :param header_comments: the comments for the header + :param footer_comments: the comments for the footer + """ + f.write( + to_str(captions, + styles, + header_comments, + footer_comments + ) + ) + + +def to_str( + captions: typing.Iterable[Caption], + styles: typing.Iterable[Style], + header_comments: typing.Iterable[str], + footer_comments: typing.Iterable[str] + ) -> str: + """ + Convert captions to a string with webvtt format. + + :param captions: the iterable of `Caption` objects + :param styles: the iterable of `Style` objects + :param header_comments: the comments for the header + :param footer_comments: the comments for the footer + :returns: String of the content in WebVTT format. + """ + output = ['WEBVTT'] + + for comment in header_comments: + output.extend([ + '', + *WebVTTCommentBlock.format_lines(comment) + ]) + + for style in styles: + for comment in style.comments: + output.extend([ + '', + *WebVTTCommentBlock.format_lines(comment) + ]) + output.extend([ + '', + *WebVTTStyleBlock.format_lines(style.lines) + ]) + + for caption in captions: + for comment in caption.comments: + output.extend([ + '', + *WebVTTCommentBlock.format_lines(comment) + ]) + output.extend(WebVTTCueBlock.format_lines(caption)) + + for comment in footer_comments: + output.extend([ + '', + *WebVTTCommentBlock.format_lines(comment) + ]) + + return '\n'.join(output) diff --git a/webvtt/webvtt.py b/webvtt/webvtt.py index adec7c9..499b6df 100644 --- a/webvtt/webvtt.py +++ b/webvtt/webvtt.py @@ -1,143 +1,363 @@ +"""WebVTT module.""" + import os +import typing +import warnings -from .parsers import WebVTTParser, SRTParser, SBVParser -from .writers import WebVTTWriter, SRTWriter +from . import vtt, utils +from . import srt +from . import sbv +from .models import Caption, Style, Timestamp from .errors import MissingFilenameError -__all__ = ['WebVTT'] +DEFAULT_ENCODING = 'utf-8' -class WebVTT(object): +class WebVTT: """ Parse captions in WebVTT format and also from other formats like SRT. To read WebVTT: - WebVTT().read('captions.vtt') - - For other formats like SRT, use from_[format in lower case]: + WebVTT.read('captions.vtt') - WebVTT().from_srt('captions.srt') + For other formats: - A list of all supported formats is available calling list_formats(). + WebVTT.from_srt('captions.srt') + WebVTT.from_sbv('captions.sbv') """ - def __init__(self, file='', captions=None, styles=None): + def __init__( + self, + file: typing.Optional[str] = None, + captions: typing.Optional[typing.List[Caption]] = None, + styles: typing.Optional[typing.List[Style]] = None, + header_comments: typing.Optional[typing.List[str]] = None, + footer_comments: typing.Optional[typing.List[str]] = None, + ): + """ + Initialize. + + :param file: the path of the WebVTT file + :param captions: the list of captions + :param styles: the list of styles + :param header_comments: list of comments for the start of the file + :param footer_comments: list of comments for the bottom of the file + """ self.file = file - self._captions = captions or [] - self._styles = styles + self.captions = captions or [] + self.styles = styles or [] + self.header_comments = header_comments or [] + self.footer_comments = footer_comments or [] + self._has_bom = False + self.encoding = DEFAULT_ENCODING def __len__(self): - return len(self._captions) + """Return the number of captions.""" + return len(self.captions) def __getitem__(self, index): - return self._captions[index] + """Return a caption by index.""" + return self.captions[index] def __repr__(self): - return '<%(cls)s file=%(file)s>' % { - 'cls': self.__class__.__name__, - 'file': self.file - } + """Return the string representation of the WebVTT file.""" + return (f'<{self.__class__.__name__} file={self.file!r} ' + f'encoding={self.encoding!r}>' + ) def __str__(self): - return '\n'.join([str(c) for c in self._captions]) + """Return a readable representation of the WebVTT content.""" + return '\n'.join(str(c) for c in self.captions) @classmethod - def from_srt(cls, file): - """Reads captions from a file in SubRip format.""" - parser = SRTParser().read(file) - return cls(file=file, captions=parser.captions) + def read( + cls, + file: str, + encoding: typing.Optional[str] = None + ) -> 'WebVTT': + """ + Read a WebVTT captions file. - @classmethod - def from_sbv(cls, file): - """Reads captions from a file in YouTube SBV format.""" - parser = SBVParser().read(file) - return cls(file=file, captions=parser.captions) + :param file: the file path + :param encoding: encoding of the file + :returns: a `WebVTT` instance + """ + with utils.FileWrapper.open(file, encoding=encoding) as fw: + instance = cls.from_buffer(fw.file) + if fw.bom_encoding: + instance.encoding = fw.bom_encoding + instance._has_bom = True + return instance @classmethod - def read(cls, file): - """Reads a WebVTT captions file.""" - parser = WebVTTParser().read(file) - return cls(file=file, captions=parser.captions, styles=parser.styles) + def read_buffer( + cls, + buffer: typing.Iterator[str] + ) -> 'WebVTT': + """ + Read WebVTT captions from a file-like object. + + This method is DEPRECATED. Use from_buffer instead. + + Such file-like object may be the return of an io.open call, + io.StringIO object, tempfile.TemporaryFile object, etc. + + :param buffer: the file-like object to read captions from + :returns: a `WebVTT` instance + """ + warnings.warn( + 'Deprecated: use from_buffer instead.', + DeprecationWarning + ) + return cls.from_buffer(buffer) @classmethod - def read_buffer(cls, buffer): - """Reads a WebVTT captions from a file-like object. + def from_buffer( + cls, + buffer: typing.Iterator[str] + ) -> 'WebVTT': + """ + Read WebVTT captions from a file-like object. + Such file-like object may be the return of an io.open call, - io.StringIO object, tempfile.TemporaryFile object, etc.""" - parser = WebVTTParser().read_from_buffer(buffer) - return cls(captions=parser.captions, styles=parser.styles) + io.StringIO object, tempfile.TemporaryFile object, etc. + + :param buffer: the file-like object to read captions from + :returns: a `WebVTT` instance + """ + output = vtt.parse(cls._get_lines(buffer)) + + return cls( + file=getattr(buffer, 'name', None), + captions=output.captions, + styles=output.styles, + header_comments=output.header_comments, + footer_comments=output.footer_comments + ) - def _get_output_file(self, output, extension='vtt'): - if not output: + @classmethod + def from_srt( + cls, + file: str, + encoding: typing.Optional[str] = None + ) -> 'WebVTT': + """ + Read captions from a file in SubRip format. + + :param file: the file path + :param encoding: encoding of the file + :returns: a `WebVTT` instance + """ + with utils.FileWrapper.open(file, encoding=encoding) as fw: + return cls( + file=fw.file.name, + captions=srt.parse(cls._get_lines(fw.file)) + ) + + @classmethod + def from_sbv( + cls, + file: str, + encoding: typing.Optional[str] = None + ) -> 'WebVTT': + """ + Read captions from a file in YouTube SBV format. + + :param file: the file path + :param encoding: encoding of the file + :returns: a `WebVTT` instance + """ + with utils.FileWrapper.open(file, encoding=encoding) as fw: + return cls( + file=fw.file.name, + captions=sbv.parse(cls._get_lines(fw.file)), + ) + + @classmethod + def from_string(cls, string: str) -> 'WebVTT': + """ + Read captions from a string. + + :param string: the captions in a string + :returns: a `WebVTT` instance + """ + output = vtt.parse(cls._get_lines(string.splitlines())) + return cls( + captions=output.captions, + styles=output.styles, + header_comments=output.header_comments, + footer_comments=output.footer_comments + ) + + @staticmethod + def _get_lines(lines: typing.Iterable[str]) -> typing.List[str]: + """ + Return cleaned lines from an iterable of lines. + + :param lines: iterable of lines + :returns: a list of cleaned lines + """ + return [line.rstrip('\n\r') for line in lines] + + def _get_destination_file( + self, + destination_path: typing.Optional[str] = None, + extension: str = 'vtt' + ) -> str: + """ + Return the destination file based on the provided params. + + :param destination_path: optional destination path + :param extension: the extension of the file + :returns: the destination file + + :raises MissingFilenameError: if destination path cannot be determined + """ + if not destination_path and not self.file: + raise MissingFilenameError + + if not destination_path and self.file: + destination_path = ( + f'{os.path.splitext(self.file)[0]}.{extension}' + ) + + assert destination_path is not None + + target = os.path.join(os.getcwd(), destination_path) + if os.path.isdir(target): if not self.file: raise MissingFilenameError - # saving an original vtt file will overwrite the file - # and for files read from other formats will save as vtt - # with the same name and location - return os.path.splitext(self.file)[0] + '.' + extension - else: - target = os.path.join(os.getcwd(), output) - if os.path.isdir(target): - # if an output is provided and it is a directory - # the file will be saved in that location with the same name - filename = os.path.splitext(os.path.basename(self.file))[0] - return os.path.join(target, '{}.{}'.format(filename, extension)) - else: - if target[-3:].lower() != extension: - target += '.' + extension - # otherwise the file will be written in the specified location - return target - - def save(self, output=''): - """Save the document. - If no output is provided the file will be saved in the same location. Otherwise output - can determine a target directory or file. - """ - self.file = self._get_output_file(output) - with open(self.file, 'w', encoding='utf-8') as f: - self.write(f) - - def save_as_srt(self, output=''): - self.file = self._get_output_file(output, extension='srt') - with open(self.file, 'w', encoding='utf-8') as f: - self.write(f, format='srt') - - def write(self, f, format='vtt'): + + # store the file in specified directory + base_name = os.path.splitext(os.path.basename(self.file))[0] + new_filename = f'{base_name}.{extension}' + return os.path.join(target, new_filename) + + if target[-4:].lower() != f'.{extension}': + target = f'{target}.{extension}' + + # store the file in the specified full path + return target + + def save( + self, + output: typing.Optional[str] = None, + encoding: typing.Optional[str] = None, + add_bom: typing.Optional[bool] = None + ): + """ + Save the WebVTT captions to a file. + + :param output: destination path of the file + :param encoding: encoding of the file + :param add_bom: save the file with Byte Order Mark + + :raises MissingFilenameError: if output cannot be determined + """ + self.file = self._get_destination_file(output) + encoding = encoding or self.encoding + + if add_bom is None and self._has_bom: + add_bom = True + + with open(self.file, 'w', encoding=encoding) as f: + if add_bom and encoding in utils.CODEC_BOMS: + f.write(utils.CODEC_BOMS[encoding].decode(encoding)) + + vtt.write( + f, + self.captions, + self.styles, + self.header_comments, + self.footer_comments + ) + + def save_as_srt( + self, + output: typing.Optional[str] = None, + encoding: typing.Optional[str] = DEFAULT_ENCODING + ): + """ + Save the WebVTT captions to a file in SubRip format. + + :param output: destination path of the file + :param encoding: encoding of the file + + :raises MissingFilenameError: if output cannot be determined + """ + self.file = self._get_destination_file(output, extension='srt') + with open(self.file, 'w', encoding=encoding) as f: + srt.write(f, self.captions) + + def write( + self, + f: typing.IO[str], + format: str = 'vtt' + ): + """ + Save the WebVTT captions to a file-like object. + + :param f: destination file-like object + :param format: the format to use (`vtt` or `srt`) + + :raises MissingFilenameError: if output cannot be determined + """ if format == 'vtt': - WebVTTWriter().write(self._captions, f) - elif format == 'srt': - SRTWriter().write(self._captions, f) -# elif output_format == OutputFormat.SBV: -# SBVWriter().write(self._captions, f) + return vtt.write(f, + self.captions, + self.styles, + self.header_comments, + self.footer_comments + ) + if format == 'srt': + return srt.write(f, self.captions) - @staticmethod - def list_formats(): - """Provides a list of supported formats that this class can read from.""" - return ('WebVTT (.vtt)', 'SubRip (.srt)', 'YouTube SBV (.sbv)') + raise ValueError(f'Format {format} is not supported.') - @property - def captions(self): - """Returns the list of captions.""" - return self._captions + def iter_slice( + self, + start: typing.Optional[str] = None, + end: typing.Optional[str] = None + ) -> typing.Generator[Caption, None, None]: + """ + Iterate a slice of the captions based on a time range. + + :param start: start timestamp of the range + :param end: end timestamp of the range + :returns: generator of Captions + """ + start_time = Timestamp.from_string(start) if start else None + end_time = Timestamp.from_string(end) if end else None + + for caption in self.captions: + if ( + (not start_time or caption.start_time >= start_time) and + (not end_time or caption.end_time <= end_time) + ): + yield caption @property def total_length(self): """Returns the total length of the captions.""" - if not self._captions: + if not self.captions: return 0 - return int(self._captions[-1].end_in_seconds) - int(self._captions[0].start_in_seconds) - - @property - def styles(self): - return self._styles + return ( + self.captions[-1].end_in_seconds - + self.captions[0].start_in_seconds + ) @property - def content(self): + def content(self) -> str: """ - Return webvtt content with webvtt formatting. + Return the webvtt capions as string. This property is useful in cases where the webvtt content is needed - but no file saving on the system is required. + but no file-like destination is required. Storage in DB for instance. """ - return WebVTTWriter().webvtt_content(self._captions) + return vtt.to_str( + self.captions, + self.styles, + self.header_comments, + self.footer_comments + ) diff --git a/webvtt/writers.py b/webvtt/writers.py deleted file mode 100644 index 5ec551b..0000000 --- a/webvtt/writers.py +++ /dev/null @@ -1,41 +0,0 @@ - -class WebVTTWriter(object): - - def write(self, captions, f): - f.write(self.webvtt_content(captions)) - - def webvtt_content(self, captions): - """ - Return captions content with webvtt formatting. - """ - output = ["WEBVTT"] - for caption in captions: - output.append("") - if caption.identifier: - output.append(caption.identifier) - output.append('{} --> {}'.format(caption.start, caption.end)) - output.extend(caption.lines) - return '\n'.join(output) - - -class SRTWriter(object): - - def write(self, captions, f): - for line_number, caption in enumerate(captions, start=1): - f.write('{}\n'.format(line_number)) - f.write('{} --> {}\n'.format(self._to_srt_timestamp(caption.start_in_seconds), - self._to_srt_timestamp(caption.end_in_seconds))) - f.writelines(['{}\n'.format(l) for l in caption.lines]) - f.write('\n') - - def _to_srt_timestamp(self, total_seconds): - hours = int(total_seconds / 3600) - minutes = int(total_seconds / 60 - hours * 60) - seconds = int(total_seconds - hours * 3600 - minutes * 60) - milliseconds = round((total_seconds - seconds - hours * 3600 - minutes * 60)*1000) - - return '{:02d}:{:02d}:{:02d},{:03d}'.format(hours, minutes, seconds, milliseconds) - - -class SBVWriter(object): - pass