diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..54e5841
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,32 @@
+name: CI
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ branches:
+ - master
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+
+ strategy:
+ matrix:
+ python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12']
+
+ steps:
+ - name: Checkout Repository
+ uses: actions/checkout@v2
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install tox
+ run: pip install tox
+
+ - name: Run tox
+ run: tox
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index 9854868..0000000
--- a/.travis.yml
+++ /dev/null
@@ -1,12 +0,0 @@
-language: python
-python:
- - 3.9
- - 3.8
- - 3.7
- - 3.6
- - 3.5
- - 3.4
-
-install: pip install -U tox-travis
-
-script: tox
diff --git a/LICENSE b/LICENSE
index 3a702d9..16ad61a 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,3 +1,5 @@
+MIT License
+
Copyright (c) 2016 Alejandro Mendez
Permission is hereby granted, free of charge, to any person obtaining a copy
diff --git a/setup.cfg b/setup.cfg
index 5aef279..12ccc6a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,2 +1,43 @@
[metadata]
-description-file = README.rst
+description_file = README.rst
+
+[flake8]
+doctests = true
+radon-max-cc=10
+
+[tox:tox]
+envlist =
+ codestyle
+ py
+ coverage
+isolated_build = True
+
+[coverage:run]
+source = webvtt
+branch = true
+
+[coverage:report]
+fail_under = 100
+
+[testenv]
+deps =
+ coverage
+description = run the tests and provide coverage metrics
+commands =
+ coverage run -m unittest discover
+
+[testenv:codestyle]
+deps =
+ flake8
+ flake8-docstrings
+ radon
+ mypy
+commands =
+ flake8 webvtt setup.py
+ mypy webvtt
+
+[testenv:coverage]
+deps = coverage
+commands =
+ coverage html --fail-under=0
+ coverage report
diff --git a/setup.py b/setup.py
index 92384b3..6bae26d 100644
--- a/setup.py
+++ b/setup.py
@@ -1,43 +1,40 @@
-import io
-import re
-from setuptools import setup, find_packages
+"""webvtt-py setuptools configuration."""
-with io.open('README.rst', 'r', encoding='utf-8') as f:
- readme = f.read()
+import re
+import pathlib
-with io.open('webvtt/__init__.py', 'rt', encoding='utf-8') as f:
- version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1)
+from setuptools import setup, find_packages
+version = (
+ re.search(
+ r'__version__ = \'(.*?)\'',
+ pathlib.Path('webvtt/__init__.py').read_text()
+ ).group(1)
+ )
setup(
name='webvtt-py',
version=version,
description='WebVTT reader, writer and segmenter',
- long_description=readme,
+ long_description=pathlib.Path('README.rst').read_text(),
author='Alejandro Mendez',
author_email='amendez23@gmail.com',
url='https://github.com/glut23/webvtt-py',
packages=find_packages('.', exclude=['tests']),
include_package_data=True,
- install_requires=[
- 'docopt'
- ],
entry_points={
'console_scripts': [
'webvtt=webvtt.cli:main'
]
},
license='MIT',
- python_requires='>=3.4',
+ python_requires='>=3.7',
classifiers=[
'Development Status :: 3 - Alpha',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: 3.4',
- 'Programming Language :: Python :: 3.5',
- 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
diff --git a/tests/generic.py b/tests/generic.py
deleted file mode 100644
index 0f84ae3..0000000
--- a/tests/generic.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import os
-import unittest
-
-
-class GenericParserTestCase(unittest.TestCase):
-
- SUBTITLES_DIR = os.path.join(os.path.dirname(__file__), 'subtitles')
-
- def _get_file(self, filename):
- return os.path.join(self.SUBTITLES_DIR, filename)
diff --git a/tests/subtitles/captions_with_bom.vtt b/tests/samples/captions_with_bom.vtt
similarity index 100%
rename from tests/subtitles/captions_with_bom.vtt
rename to tests/samples/captions_with_bom.vtt
diff --git a/tests/subtitles/comments.vtt b/tests/samples/comments.vtt
similarity index 92%
rename from tests/subtitles/comments.vtt
rename to tests/samples/comments.vtt
index 3847e02..8cd3032 100644
--- a/tests/subtitles/comments.vtt
+++ b/tests/samples/comments.vtt
@@ -18,4 +18,6 @@ NOTE This last line may not translate well.
3
00:02:25.000 --> 00:02:30.000
-- Ta en kopp
\ No newline at end of file
+- Ta en kopp
+
+NOTE end of file
\ No newline at end of file
diff --git a/tests/subtitles/cue_tags.vtt b/tests/samples/cue_tags.vtt
similarity index 100%
rename from tests/subtitles/cue_tags.vtt
rename to tests/samples/cue_tags.vtt
diff --git a/tests/subtitles/empty.vtt b/tests/samples/empty.vtt
similarity index 100%
rename from tests/subtitles/empty.vtt
rename to tests/samples/empty.vtt
diff --git a/tests/subtitles/invalid.vtt b/tests/samples/invalid.vtt
similarity index 100%
rename from tests/subtitles/invalid.vtt
rename to tests/samples/invalid.vtt
diff --git a/tests/subtitles/invalid_format.sbv b/tests/samples/invalid_format.sbv
similarity index 100%
rename from tests/subtitles/invalid_format.sbv
rename to tests/samples/invalid_format.sbv
diff --git a/tests/subtitles/invalid_format1.srt b/tests/samples/invalid_format1.srt
similarity index 100%
rename from tests/subtitles/invalid_format1.srt
rename to tests/samples/invalid_format1.srt
diff --git a/tests/subtitles/invalid_format2.srt b/tests/samples/invalid_format2.srt
similarity index 100%
rename from tests/subtitles/invalid_format2.srt
rename to tests/samples/invalid_format2.srt
diff --git a/tests/subtitles/invalid_format3.srt b/tests/samples/invalid_format3.srt
similarity index 100%
rename from tests/subtitles/invalid_format3.srt
rename to tests/samples/invalid_format3.srt
diff --git a/tests/subtitles/invalid_format4.srt b/tests/samples/invalid_format4.srt
similarity index 100%
rename from tests/subtitles/invalid_format4.srt
rename to tests/samples/invalid_format4.srt
diff --git a/tests/subtitles/invalid_timeframe.sbv b/tests/samples/invalid_timeframe.sbv
similarity index 100%
rename from tests/subtitles/invalid_timeframe.sbv
rename to tests/samples/invalid_timeframe.sbv
diff --git a/tests/subtitles/invalid_timeframe.srt b/tests/samples/invalid_timeframe.srt
similarity index 100%
rename from tests/subtitles/invalid_timeframe.srt
rename to tests/samples/invalid_timeframe.srt
diff --git a/tests/subtitles/invalid_timeframe.vtt b/tests/samples/invalid_timeframe.vtt
similarity index 100%
rename from tests/subtitles/invalid_timeframe.vtt
rename to tests/samples/invalid_timeframe.vtt
diff --git a/tests/subtitles/invalid_timeframe_in_cue_text.vtt b/tests/samples/invalid_timeframe_in_cue_text.vtt
similarity index 100%
rename from tests/subtitles/invalid_timeframe_in_cue_text.vtt
rename to tests/samples/invalid_timeframe_in_cue_text.vtt
diff --git a/tests/subtitles/metadata_headers.vtt b/tests/samples/metadata_headers.vtt
similarity index 100%
rename from tests/subtitles/metadata_headers.vtt
rename to tests/samples/metadata_headers.vtt
diff --git a/tests/subtitles/metadata_headers_multiline.vtt b/tests/samples/metadata_headers_multiline.vtt
similarity index 100%
rename from tests/subtitles/metadata_headers_multiline.vtt
rename to tests/samples/metadata_headers_multiline.vtt
diff --git a/tests/subtitles/missing_caption_text.sbv b/tests/samples/missing_caption_text.sbv
similarity index 100%
rename from tests/subtitles/missing_caption_text.sbv
rename to tests/samples/missing_caption_text.sbv
diff --git a/tests/subtitles/missing_caption_text.srt b/tests/samples/missing_caption_text.srt
similarity index 100%
rename from tests/subtitles/missing_caption_text.srt
rename to tests/samples/missing_caption_text.srt
diff --git a/tests/subtitles/missing_caption_text.vtt b/tests/samples/missing_caption_text.vtt
similarity index 100%
rename from tests/subtitles/missing_caption_text.vtt
rename to tests/samples/missing_caption_text.vtt
diff --git a/tests/subtitles/missing_timeframe.sbv b/tests/samples/missing_timeframe.sbv
similarity index 100%
rename from tests/subtitles/missing_timeframe.sbv
rename to tests/samples/missing_timeframe.sbv
diff --git a/tests/subtitles/missing_timeframe.srt b/tests/samples/missing_timeframe.srt
similarity index 100%
rename from tests/subtitles/missing_timeframe.srt
rename to tests/samples/missing_timeframe.srt
diff --git a/tests/subtitles/missing_timeframe.vtt b/tests/samples/missing_timeframe.vtt
similarity index 100%
rename from tests/subtitles/missing_timeframe.vtt
rename to tests/samples/missing_timeframe.vtt
diff --git a/tests/subtitles/netflix_chicas_del_cable.vtt b/tests/samples/netflix_chicas_del_cable.vtt
similarity index 100%
rename from tests/subtitles/netflix_chicas_del_cable.vtt
rename to tests/samples/netflix_chicas_del_cable.vtt
diff --git a/tests/subtitles/no_captions.vtt b/tests/samples/no_captions.vtt
similarity index 100%
rename from tests/subtitles/no_captions.vtt
rename to tests/samples/no_captions.vtt
diff --git a/tests/subtitles/one_caption.srt b/tests/samples/one_caption.srt
similarity index 100%
rename from tests/subtitles/one_caption.srt
rename to tests/samples/one_caption.srt
diff --git a/tests/subtitles/one_caption.vtt b/tests/samples/one_caption.vtt
similarity index 100%
rename from tests/subtitles/one_caption.vtt
rename to tests/samples/one_caption.vtt
diff --git a/tests/subtitles/sample.sbv b/tests/samples/sample.sbv
similarity index 100%
rename from tests/subtitles/sample.sbv
rename to tests/samples/sample.sbv
diff --git a/tests/subtitles/sample.srt b/tests/samples/sample.srt
similarity index 100%
rename from tests/subtitles/sample.srt
rename to tests/samples/sample.srt
diff --git a/tests/subtitles/sample.vtt b/tests/samples/sample.vtt
similarity index 100%
rename from tests/subtitles/sample.vtt
rename to tests/samples/sample.vtt
diff --git a/tests/subtitles/styles.vtt b/tests/samples/styles.vtt
similarity index 100%
rename from tests/subtitles/styles.vtt
rename to tests/samples/styles.vtt
diff --git a/tests/samples/styles_with_comments.vtt b/tests/samples/styles_with_comments.vtt
new file mode 100644
index 0000000..4ae93ac
--- /dev/null
+++ b/tests/samples/styles_with_comments.vtt
@@ -0,0 +1,23 @@
+WEBVTT
+
+NOTE Sample of comments with styles
+
+STYLE
+::cue {
+ background-image: linear-gradient(to bottom, dimgray, lightgray);
+ color: papayawhip;
+}
+
+NOTE This is the second block of styles
+
+NOTE
+Multiline comment for the same
+second block of styles
+
+STYLE
+::cue(b) {
+ color: peachpuff;
+}
+
+00:00:00.000 --> 00:00:10.000
+- Hello world.
\ No newline at end of file
diff --git a/tests/subtitles/two_captions.sbv b/tests/samples/two_captions.sbv
similarity index 100%
rename from tests/subtitles/two_captions.sbv
rename to tests/samples/two_captions.sbv
diff --git a/tests/subtitles/using_identifiers.vtt b/tests/samples/using_identifiers.vtt
similarity index 100%
rename from tests/subtitles/using_identifiers.vtt
rename to tests/samples/using_identifiers.vtt
diff --git a/tests/subtitles/youtube_dl.vtt b/tests/samples/youtube_dl.vtt
similarity index 100%
rename from tests/subtitles/youtube_dl.vtt
rename to tests/samples/youtube_dl.vtt
diff --git a/tests/test_cli.py b/tests/test_cli.py
new file mode 100644
index 0000000..ca42cf7
--- /dev/null
+++ b/tests/test_cli.py
@@ -0,0 +1,182 @@
+import unittest
+import tempfile
+import os
+import pathlib
+import textwrap
+
+from webvtt.cli import main
+
+
+class CLITestCase(unittest.TestCase):
+
+ def test_cli(self):
+ vtt_file = (
+ pathlib.Path(__file__).resolve().parent
+ / 'samples' / 'sample.vtt'
+ )
+
+ with tempfile.TemporaryDirectory() as temp_dir:
+
+ main(['segment', str(vtt_file.resolve()), '-o', temp_dir])
+ _, dirs, files = next(os.walk(temp_dir))
+
+ self.assertEqual(len(dirs), 0)
+ self.assertEqual(len(files), 8)
+ for expected_file in ('prog_index.m3u8',
+ 'fileSequence0.webvtt',
+ 'fileSequence1.webvtt',
+ 'fileSequence2.webvtt',
+ 'fileSequence3.webvtt',
+ 'fileSequence4.webvtt',
+ 'fileSequence5.webvtt',
+ 'fileSequence6.webvtt',
+ ):
+ self.assertIn(expected_file, files)
+
+ self.assertEqual(
+ (pathlib.Path(temp_dir) / 'prog_index.m3u8').read_text(),
+ textwrap.dedent(
+ '''
+ #EXTM3U
+ #EXT-X-TARGETDURATION:10
+ #EXT-X-VERSION:3
+ #EXT-X-PLAYLIST-TYPE:VOD
+ #EXTINF:30.00000
+ fileSequence0.webvtt
+ #EXTINF:30.00000
+ fileSequence1.webvtt
+ #EXTINF:30.00000
+ fileSequence2.webvtt
+ #EXTINF:30.00000
+ fileSequence3.webvtt
+ #EXTINF:30.00000
+ fileSequence4.webvtt
+ #EXTINF:30.00000
+ fileSequence5.webvtt
+ #EXTINF:30.00000
+ fileSequence6.webvtt
+ #EXT-X-ENDLIST
+ '''
+ ).lstrip())
+ self.assertEqual(
+ (pathlib.Path(temp_dir) / 'fileSequence0.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+
+ 00:00:07.000 --> 00:00:11.890
+ Caption text #2
+ '''
+ ).lstrip())
+ self.assertEqual(
+ (pathlib.Path(temp_dir) / 'fileSequence1.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000
+
+ 00:00:07.000 --> 00:00:11.890
+ Caption text #2
+
+ 00:00:11.890 --> 00:00:16.320
+ Caption text #3
+
+ 00:00:16.320 --> 00:00:21.580
+ Caption text #4
+ '''
+ ).lstrip())
+ self.assertEqual(
+ (pathlib.Path(temp_dir) / 'fileSequence2.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000
+
+ 00:00:16.320 --> 00:00:21.580
+ Caption text #4
+
+ 00:00:21.580 --> 00:00:23.880
+ Caption text #5
+
+ 00:00:23.880 --> 00:00:27.280
+ Caption text #6
+
+ 00:00:27.280 --> 00:00:30.280
+ Caption text #7
+ '''
+ ).lstrip())
+ self.assertEqual(
+ (pathlib.Path(temp_dir) / 'fileSequence3.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000
+
+ 00:00:27.280 --> 00:00:30.280
+ Caption text #7
+
+ 00:00:30.280 --> 00:00:36.510
+ Caption text #8
+
+ 00:00:36.510 --> 00:00:38.870
+ Caption text #9
+
+ 00:00:38.870 --> 00:00:45.000
+ Caption text #10
+ '''
+ ).lstrip())
+ self.assertEqual(
+ (pathlib.Path(temp_dir) / 'fileSequence4.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000
+
+ 00:00:38.870 --> 00:00:45.000
+ Caption text #10
+
+ 00:00:45.000 --> 00:00:47.000
+ Caption text #11
+
+ 00:00:47.000 --> 00:00:50.970
+ Caption text #12
+ '''
+ ).lstrip())
+ self.assertEqual(
+ (pathlib.Path(temp_dir) / 'fileSequence5.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000
+
+ 00:00:47.000 --> 00:00:50.970
+ Caption text #12
+
+ 00:00:50.970 --> 00:00:54.440
+ Caption text #13
+
+ 00:00:54.440 --> 00:00:58.600
+ Caption text #14
+
+ 00:00:58.600 --> 00:01:01.350
+ Caption text #15
+ '''
+ ).lstrip())
+ self.assertEqual(
+ (pathlib.Path(temp_dir) / 'fileSequence6.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000
+
+ 00:00:58.600 --> 00:01:01.350
+ Caption text #15
+
+ 00:01:01.350 --> 00:01:04.300
+ Caption text #16
+ '''
+ ).lstrip())
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 0000000..2210e57
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,440 @@
+import unittest
+
+from webvtt.models import Timestamp, Caption, Style
+from webvtt.errors import MalformedCaptionError
+
+
+class TestTimestamp(unittest.TestCase):
+
+ def test_instantiation(self):
+ timestamp = Timestamp(
+ hours=1, minutes=12, seconds=23, milliseconds=500
+ )
+ self.assertEqual(timestamp.hours, 1)
+ self.assertEqual(timestamp.minutes, 12)
+ self.assertEqual(timestamp.seconds, 23)
+ self.assertEqual(timestamp.milliseconds, 500)
+
+ def test_from_string(self):
+ timestamp = Timestamp.from_string('01:24:11.670')
+ self.assertEqual(timestamp.hours, 1)
+ self.assertEqual(timestamp.minutes, 24)
+ self.assertEqual(timestamp.seconds, 11)
+ self.assertEqual(timestamp.milliseconds, 670)
+
+ def test_from_string_single_digits(self):
+ timestamp = Timestamp.from_string('1:2:7.670')
+ self.assertEqual(timestamp.hours, 1)
+ self.assertEqual(timestamp.minutes, 2)
+ self.assertEqual(timestamp.seconds, 7)
+ self.assertEqual(timestamp.milliseconds, 670)
+
+ def test_from_string_missing_hours(self):
+ timestamp = Timestamp.from_string('24:11.670')
+ self.assertEqual(timestamp.hours, 0)
+ self.assertEqual(timestamp.minutes, 24)
+ self.assertEqual(timestamp.seconds, 11)
+ self.assertEqual(timestamp.milliseconds, 670)
+
+ def test_from_string_wrong_minutes(self):
+ with self.assertRaises(MalformedCaptionError):
+ Timestamp.from_string('01:76:11.670')
+
+ def test_from_string_wrong_seconds(self):
+ with self.assertRaises(MalformedCaptionError):
+ Timestamp.from_string('01:24:87.670')
+
+ def test_from_string_wrong_type(self):
+ with self.assertRaises(MalformedCaptionError):
+ Timestamp.from_string(1234)
+
+ def test_from_string_wrong_value(self):
+ with self.assertRaises(MalformedCaptionError):
+ Timestamp.from_string('01:24:11:670')
+
+ def test_to_tuple(self):
+ self.assertEqual(
+ Timestamp(
+ hours=1, minutes=12, seconds=23, milliseconds=500
+ ).to_tuple(),
+ (1, 12, 23, 500)
+ )
+
+ def test_equality(self):
+ self.assertTrue(
+ Timestamp(
+ hours=1, minutes=12, seconds=23, milliseconds=500
+ ) ==
+ Timestamp.from_string('01:12:23.500')
+ )
+
+ def test_not_equality(self):
+ self.assertTrue(
+ Timestamp(
+ hours=1, minutes=12, seconds=23, milliseconds=500
+ ) !=
+ Timestamp.from_string('01:12:23.600')
+ )
+
+ def test_greater_than(self):
+ self.assertTrue(
+ Timestamp(
+ hours=1, minutes=12, seconds=23, milliseconds=600
+ ) >
+ Timestamp.from_string('01:12:23.500')
+ )
+
+ def test_less_than(self):
+ self.assertTrue(
+ Timestamp(
+ hours=1, minutes=12, seconds=23, milliseconds=500
+ ) <
+ Timestamp.from_string('01:12:23.600')
+ )
+
+ def test_greater_or_equal_than(self):
+ self.assertTrue(
+ Timestamp(
+ hours=1, minutes=12, seconds=23, milliseconds=500
+ ) >=
+ Timestamp.from_string('01:12:23.500')
+ )
+ self.assertTrue(
+ Timestamp(
+ hours=1, minutes=12, seconds=23, milliseconds=600
+ ) >=
+ Timestamp.from_string('01:12:23.500')
+ )
+
+ def test_less_or_equal_than(self):
+ self.assertTrue(
+ Timestamp(
+ hours=1, minutes=12, seconds=23, milliseconds=500
+ ) <=
+ Timestamp.from_string('01:12:23.500')
+ )
+ self.assertTrue(
+ Timestamp(
+ hours=1, minutes=12, seconds=23, milliseconds=500
+ ) <=
+ Timestamp.from_string('01:12:23.600')
+ )
+
+ def test_repr(self):
+ timestamp = Timestamp(
+ hours=1,
+ minutes=12,
+ seconds=45,
+ milliseconds=320
+ )
+
+ self.assertEqual(
+ repr(timestamp),
+ ''
+ )
+
+
+class TestCaption(unittest.TestCase):
+
+ def test_instantiation(self):
+ caption = Caption(
+ start='00:00:07.000',
+ end='00:00:11.890',
+ text='Hello test!',
+ identifier='A test caption'
+ )
+ self.assertEqual(caption.start, '00:00:07.000')
+ self.assertEqual(caption.end, '00:00:11.890')
+ self.assertEqual(caption.text, 'Hello test!')
+ self.assertEqual(caption.identifier, 'A test caption')
+
+ def test_timestamp_wrong_type(self):
+ with self.assertRaises(MalformedCaptionError):
+ Caption(
+ start=1234,
+ end='00:00:11.890',
+ text='Hello test!',
+ identifier='A test caption'
+ )
+
+ def test_identifier_is_optional(self):
+ caption = Caption(
+ start='00:00:07.000',
+ end='00:00:11.890',
+ text='Hello test!',
+ )
+ self.assertIsNone(caption.identifier)
+
+ def test_multi_lines(self):
+ caption = Caption(
+ start='00:00:07.000',
+ end='00:00:11.890',
+ text='Hello test!\nThis is the second line',
+ identifier='A test caption'
+ )
+ self.assertEqual(caption.text, 'Hello test!\nThis is the second line')
+ self.assertListEqual(
+ caption.lines,
+ ['Hello test!', 'This is the second line']
+ )
+
+ def test_multi_lines_accepts_list(self):
+ caption = Caption(
+ start='00:00:07.000',
+ end='00:00:11.890',
+ text=['Hello test!', 'This is the second line'],
+ identifier='A test caption'
+ )
+ self.assertEqual(caption.text, 'Hello test!\nThis is the second line')
+ self.assertListEqual(
+ caption.lines,
+ ['Hello test!', 'This is the second line']
+ )
+
+ def test_cuetags(self):
+ caption = Caption(
+ start='00:00:07.000',
+ end='00:00:11.890',
+ text=[
+ 'Hello test!',
+ 'This is the second line'
+ ],
+ identifier='A test caption'
+ )
+ self.assertEqual(caption.text, 'Hello test!\nThis is the second line')
+ self.assertEqual(
+ caption.raw_text,
+ 'Hello test!\n'
+ 'This is the second line'
+ )
+
+ def test_in_seconds(self):
+ caption = Caption(
+ start='00:00:07.000',
+ end='00:00:11.890',
+ text=['Hello test!', 'This is the second line'],
+ identifier='A test caption'
+ )
+ self.assertEqual(caption.start_in_seconds, 7)
+ self.assertEqual(caption.end_in_seconds, 11)
+
+ def test_wrong_start_timestamp(self):
+ self.assertRaises(
+ MalformedCaptionError,
+ Caption,
+ start='1234',
+ end='00:00:11.890',
+ text='Hello Test!'
+ )
+
+ def test_wrong_type_start_timestamp(self):
+ self.assertRaises(
+ MalformedCaptionError,
+ Caption,
+ start=1234,
+ end='00:00:11.890',
+ text='Hello Test!'
+ )
+
+ def test_wrong_end_timestamp(self):
+ self.assertRaises(
+ MalformedCaptionError,
+ Caption,
+ start='00:00:07.000',
+ end='1234',
+ text='Hello Test!'
+ )
+
+ def test_wrong_type_end_timestamp(self):
+ self.assertRaises(
+ MalformedCaptionError,
+ Caption,
+ start='00:00:07.000',
+ end=1234,
+ text='Hello Test!'
+ )
+
+ def test_equality(self):
+ caption1 = Caption(
+ start='00:00:07.000',
+ end='00:00:11.890',
+ text='Hello test!',
+ identifier='A test caption'
+ )
+
+ caption2 = Caption(
+ start='00:00:07.000',
+ end='00:00:11.890',
+ text='Hello test!',
+ identifier='A test caption'
+ )
+
+ self.assertTrue(caption1 == caption2)
+
+ caption1 = Caption(
+ start='00:00:02.000',
+ end='00:00:11.890',
+ text='Hello test!',
+ identifier='A test caption'
+ )
+
+ caption2 = Caption(
+ start='00:00:07.000',
+ end='00:00:11.890',
+ text='Hello test!',
+ identifier='A test caption'
+ )
+
+ self.assertFalse(caption1 == caption2)
+
+ self.assertFalse(
+ Caption(
+ start='00:00:07.000',
+ end='00:00:11.890',
+ text='Hello test!',
+ identifier='A test caption'
+ ) == 1234
+ )
+
+ def test_repr(self):
+ caption = Caption(
+ start='00:00:07.000',
+ end='00:00:11.890',
+ text='Hello test!',
+ identifier='A test caption'
+ )
+
+ self.assertEqual(
+ repr(caption),
+ ""
+ )
+
+ def test_str(self):
+ caption = Caption(
+ start='00:00:07.000',
+ end='00:00:11.890',
+ text='Hello test!',
+ identifier='A test caption'
+ )
+
+ self.assertEqual(
+ str(caption),
+ '00:00:07.000 00:00:11.890 Hello test!'
+ )
+
+ def test_accept_comments(self):
+ caption = Caption(
+ start='00:00:07.000',
+ end='00:00:11.890',
+ text='Hello test!',
+ identifier='A test caption'
+ )
+ caption.comments.append('One comment')
+ caption.comments.append('Another comment')
+
+ self.assertListEqual(
+ caption.comments,
+ ['One comment', 'Another comment']
+ )
+
+ def test_timestamp_update(self):
+ c = Caption('00:00:00.500', '00:00:07.000')
+ c.start = '00:00:01.750'
+ c.end = '00:00:08.250'
+
+ self.assertEqual(c.start, '00:00:01.750')
+ self.assertEqual(c.end, '00:00:08.250')
+
+ def test_timestamp_format(self):
+ c = Caption('01:02:03.400', '02:03:04.500')
+ self.assertEqual(c.start, '01:02:03.400')
+ self.assertEqual(c.end, '02:03:04.500')
+
+ c = Caption('02:03.400', '03:04.500')
+ self.assertEqual(c.start, '00:02:03.400')
+ self.assertEqual(c.end, '00:03:04.500')
+
+ def test_update_text(self):
+ c = Caption(text='Caption line #1')
+ c.text = 'Caption line #1 updated'
+ self.assertEqual(
+ c.text,
+ 'Caption line #1 updated'
+ )
+
+ def test_update_text_multiline(self):
+ c = Caption(text='Caption line #1')
+ c.text = 'Caption line #1\nCaption line #2'
+
+ self.assertEqual(
+ len(c.lines),
+ 2
+ )
+
+ self.assertEqual(
+ c.text,
+ 'Caption line #1\nCaption line #2'
+ )
+
+ def test_update_text_wrong_type(self):
+ c = Caption(text='Caption line #1')
+
+ self.assertRaises(
+ AttributeError,
+ setattr,
+ c,
+ 'text',
+ 123
+ )
+
+ def test_manipulate_lines(self):
+ c = Caption(text=['Caption line #1', 'Caption line #2'])
+ c.lines[0] = 'Caption line #1 updated'
+ self.assertEqual(
+ c.lines[0],
+ 'Caption line #1 updated'
+ )
+
+ def test_malformed_start_timestamp(self):
+ self.assertRaises(
+ MalformedCaptionError,
+ Caption,
+ '01:00'
+ )
+
+
+class TestStyle(unittest.TestCase):
+
+ def test_instantiation(self):
+ style = Style(text='::cue(b) {\ncolor: peachpuff;\n}')
+ self.assertEqual(style.text, '::cue(b) {\ncolor: peachpuff;\n}')
+ self.assertListEqual(
+ style.lines,
+ ['::cue(b) {', 'color: peachpuff;', '}']
+ )
+
+ def test_text_accept_list_of_strings(self):
+ style = Style(text=['::cue(b) {', 'color: peachpuff;', '}'])
+ self.assertEqual(style.text, '::cue(b) {\ncolor: peachpuff;\n}')
+ self.assertListEqual(
+ style.lines,
+ ['::cue(b) {', 'color: peachpuff;', '}']
+ )
+
+ def test_accept_comments(self):
+ style = Style(text='::cue(b) {\ncolor: peachpuff;\n}')
+ style.comments.append('One comment')
+ style.comments.append('Another comment')
+
+ self.assertListEqual(
+ style.comments,
+ ['One comment', 'Another comment']
+ )
+
+ def test_get_text(self):
+ style = Style(['::cue(b) {', ' color: peachpuff;', '}'])
+ self.assertEqual(
+ style.text,
+ '::cue(b) {\n color: peachpuff;\n}'
+ )
diff --git a/tests/test_sbv.py b/tests/test_sbv.py
new file mode 100644
index 0000000..d78a9c3
--- /dev/null
+++ b/tests/test_sbv.py
@@ -0,0 +1,141 @@
+import unittest
+import textwrap
+
+from webvtt import sbv
+from webvtt.errors import MalformedFileError
+from webvtt.models import Caption
+
+
+class TestSBVCueBlock(unittest.TestCase):
+
+ def test_is_valid(self):
+ self.assertTrue(sbv.SBVCueBlock.is_valid(textwrap.dedent('''
+ 00:00:00.500,00:00:07.000
+ Caption #1
+ ''').strip().split('\n'))
+ )
+
+ self.assertTrue(sbv.SBVCueBlock.is_valid(textwrap.dedent('''
+ 0:0:0.500,0:0:7.000
+ Caption #1
+ ''').strip().split('\n'))
+ )
+
+ self.assertTrue(sbv.SBVCueBlock.is_valid(textwrap.dedent('''
+ 00:00:00.500,00:00:07.000
+ Caption #1 line 1
+ Caption #1 line 2
+ ''').strip().split('\n'))
+ )
+
+ self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent('''
+ 00:00:00.500 --> 00:00:07.000
+ Caption #1
+ ''').strip().split('\n'))
+ )
+
+ self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent('''
+ 1
+ 00:00:00.500,00:00:07.000
+ Caption #1
+ ''').strip().split('\n'))
+ )
+
+ self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent('''
+ 1
+ 00:00:00.500,00:00:07.000
+ ''').strip().split('\n'))
+ )
+
+ self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent('''
+ 1
+ Caption #1
+ ''').strip().split('\n'))
+ )
+
+ self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent('''
+ Caption #1
+ ''').strip().split('\n'))
+ )
+
+ self.assertFalse(sbv.SBVCueBlock.is_valid(textwrap.dedent('''
+ 00:00:00.500,00:00:07.000
+ ''').strip().split('\n'))
+ )
+
+ def test_from_lines(self):
+ cue_block = sbv.SBVCueBlock.from_lines(textwrap.dedent('''
+ 00:00:00.500,00:00:07.000
+ Caption #1 line 1
+ Caption #1 line 2
+ ''').strip().split('\n')
+ )
+ self.assertEqual(
+ cue_block.start,
+ '00:00:00.500'
+ )
+ self.assertEqual(
+ cue_block.end,
+ '00:00:07.000'
+ )
+ self.assertEqual(
+ cue_block.payload,
+ ['Caption #1 line 1', 'Caption #1 line 2']
+ )
+
+ def test_from_lines_shorter_timestamps(self):
+ cue_block = sbv.SBVCueBlock.from_lines(textwrap.dedent('''
+ 0:1:2.500,0:1:03.800
+ Caption #1 line 1
+ Caption #1 line 2
+ ''').strip().split('\n')
+ )
+ self.assertEqual(
+ cue_block.start,
+ '0:1:2.500'
+ )
+ self.assertEqual(
+ cue_block.end,
+ '0:1:03.800'
+ )
+
+
+class TestSBVModule(unittest.TestCase):
+
+ def test_parse_invalid_format(self):
+ self.assertRaises(
+ MalformedFileError,
+ sbv.parse,
+ textwrap.dedent('''
+ 1
+ 00:00:00.500,00:00:07.000
+ Caption text #1
+
+ 00:00:07.000,00:00:11.890
+ Caption text #2
+ ''').strip().split('\n')
+ )
+
+ def test_parse_captions(self):
+ captions = sbv.parse(
+ textwrap.dedent('''
+ 00:00:00.500,00:00:07.000
+ Caption #1
+
+ 00:00:07.000,00:00:11.890
+ Caption #2 line 1
+ Caption #2 line 2
+ ''').strip().split('\n')
+ )
+ self.assertEqual(len(captions), 2)
+ self.assertIsInstance(captions[0], Caption)
+ self.assertIsInstance(captions[1], Caption)
+ self.assertEqual(
+ str(captions[0]),
+ '00:00:00.500 00:00:07.000 Caption #1'
+ )
+ self.assertEqual(
+ str(captions[1]),
+ r'00:00:07.000 00:00:11.890 Caption #2 line 1\n'
+ 'Caption #2 line 2'
+ )
diff --git a/tests/test_sbv_parser.py b/tests/test_sbv_parser.py
deleted file mode 100644
index 0a6f899..0000000
--- a/tests/test_sbv_parser.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import webvtt
-
-from .generic import GenericParserTestCase
-
-
-class SBVParserTestCase(GenericParserTestCase):
-
- def test_sbv_parse_empty_file(self):
- self.assertRaises(
- webvtt.errors.MalformedFileError,
- webvtt.from_sbv,
- self._get_file('empty.vtt') # We reuse this file as it is empty and serves the purpose.
- )
-
- def test_sbv_invalid_format(self):
- self.assertRaises(
- webvtt.errors.MalformedFileError,
- webvtt.from_sbv,
- self._get_file('invalid_format.sbv')
- )
-
- def test_sbv_total_length(self):
- self.assertEqual(
- webvtt.from_sbv(self._get_file('sample.sbv')).total_length,
- 16
- )
-
- def test_sbv_parse_captions(self):
- self.assertEqual(
- len(webvtt.from_srt(self._get_file('sample.srt')).captions),
- 5
- )
-
- def test_sbv_missing_timeframe_line(self):
- self.assertRaises(
- webvtt.errors.MalformedCaptionError,
- webvtt.from_sbv,
- self._get_file('missing_timeframe.sbv')
- )
-
- def test_sbv_missing_caption_text(self):
- self.assertTrue(webvtt.from_sbv(self._get_file('missing_caption_text.sbv')).captions)
-
- def test_sbv_invalid_timestamp(self):
- self.assertRaises(
- webvtt.errors.MalformedCaptionError,
- webvtt.from_sbv,
- self._get_file('invalid_timeframe.sbv')
- )
-
- def test_sbv_timestamps_format(self):
- vtt = webvtt.from_sbv(self._get_file('sample.sbv'))
- self.assertEqual(vtt.captions[1].start, '00:00:11.378')
- self.assertEqual(vtt.captions[1].end, '00:00:12.305')
-
- def test_sbv_timestamps_in_seconds(self):
- vtt = webvtt.from_sbv(self._get_file('sample.sbv'))
- self.assertEqual(vtt.captions[1].start_in_seconds, 11.378)
- self.assertEqual(vtt.captions[1].end_in_seconds, 12.305)
-
- def test_sbv_get_caption_text(self):
- vtt = webvtt.from_sbv(self._get_file('sample.sbv'))
- self.assertEqual(vtt.captions[1].text, 'Caption text #2')
-
- def test_sbv_get_caption_text_multiline(self):
- vtt = webvtt.from_sbv(self._get_file('sample.sbv'))
- self.assertEqual(
- vtt.captions[2].text,
- 'Caption text #3 (line 1)\nCaption text #3 (line 2)'
- )
- self.assertListEqual(
- vtt.captions[2].lines,
- ['Caption text #3 (line 1)', 'Caption text #3 (line 2)']
- )
diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py
index d82a520..f6ed5b4 100644
--- a/tests/test_segmenter.py
+++ b/tests/test_segmenter.py
@@ -1,180 +1,342 @@
import os
import unittest
-from shutil import rmtree
+import tempfile
+import pathlib
+import textwrap
-from webvtt import WebVTTSegmenter, Caption
-from webvtt.errors import InvalidCaptionsError
-from webvtt import WebVTT
+from webvtt import segmenter
-BASE_DIR = os.path.dirname(__file__)
-SUBTITLES_DIR = os.path.join(BASE_DIR, 'subtitles')
-OUTPUT_DIR = os.path.join(BASE_DIR, 'output')
+PATH_TO_SAMPLES = pathlib.Path(__file__).resolve().parent / 'samples'
-class WebVTTSegmenterTestCase(unittest.TestCase):
+class TestSegmenter(unittest.TestCase):
def setUp(self):
- self.segmenter = WebVTTSegmenter()
+ self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
- if os.path.exists(OUTPUT_DIR):
- rmtree(OUTPUT_DIR)
-
- def _parse_captions(self, filename):
- self.webvtt = WebVTT().read(os.path.join(SUBTITLES_DIR, filename))
-
- def test_invalid_captions(self):
- self.assertRaises(
- FileNotFoundError,
- self.segmenter.segment,
- 'text'
- )
-
- self.assertRaises(
- InvalidCaptionsError,
- self.segmenter.segment,
- 10
- )
-
- def test_single_invalid_caption(self):
- self.assertRaises(
- InvalidCaptionsError,
- self.segmenter.segment,
- [Caption(), Caption(), 'text', Caption()]
- )
-
- def test_total_segments(self):
- # segment with default 10 seconds
- self._parse_captions('sample.vtt')
- self.segmenter.segment(self.webvtt, OUTPUT_DIR)
- self.assertEqual(self.segmenter.total_segments, 7)
-
- # segment with custom 30 seconds
- self._parse_captions('sample.vtt')
- self.segmenter.segment(self.webvtt, OUTPUT_DIR, 30)
- self.assertEqual(self.segmenter.total_segments, 3)
-
- def test_output_folder_is_created(self):
- self.assertFalse(os.path.exists(OUTPUT_DIR))
- self._parse_captions('sample.vtt')
- self.segmenter.segment(self.webvtt, OUTPUT_DIR)
- self.assertTrue(os.path.exists(OUTPUT_DIR))
-
- def test_segmentation_files_exist(self):
- self._parse_captions('sample.vtt')
- self.segmenter.segment(self.webvtt, OUTPUT_DIR)
- for i in range(7):
- self.assertTrue(
- os.path.exists(os.path.join(OUTPUT_DIR, 'fileSequence{}.webvtt'.format(i)))
+ self.temp_dir.cleanup()
+
+ def test_segmentation_with_defaults(self):
+ segmenter.segment(PATH_TO_SAMPLES / 'sample.vtt', self.temp_dir.name)
+
+ _, dirs, files = next(os.walk(self.temp_dir.name))
+
+ self.assertEqual(len(dirs), 0)
+ self.assertEqual(len(files), 8)
+
+ for expected_file in ('prog_index.m3u8',
+ 'fileSequence0.webvtt',
+ 'fileSequence1.webvtt',
+ 'fileSequence2.webvtt',
+ 'fileSequence3.webvtt',
+ 'fileSequence4.webvtt',
+ 'fileSequence5.webvtt',
+ 'fileSequence6.webvtt',
+ ):
+ self.assertIn(expected_file, files)
+
+ output_path = pathlib.Path(self.temp_dir.name)
+
+ self.assertEqual(
+ (output_path / 'prog_index.m3u8').read_text(),
+ textwrap.dedent(
+ '''
+ #EXTM3U
+ #EXT-X-TARGETDURATION:10
+ #EXT-X-VERSION:3
+ #EXT-X-PLAYLIST-TYPE:VOD
+ #EXTINF:30.00000
+ fileSequence0.webvtt
+ #EXTINF:30.00000
+ fileSequence1.webvtt
+ #EXTINF:30.00000
+ fileSequence2.webvtt
+ #EXTINF:30.00000
+ fileSequence3.webvtt
+ #EXTINF:30.00000
+ fileSequence4.webvtt
+ #EXTINF:30.00000
+ fileSequence5.webvtt
+ #EXTINF:30.00000
+ fileSequence6.webvtt
+ #EXT-X-ENDLIST
+ '''
+ ).lstrip()
+ )
+ self.assertEqual(
+ (output_path / 'fileSequence0.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+
+ 00:00:07.000 --> 00:00:11.890
+ Caption text #2
+ '''
+ ).lstrip()
+ )
+ self.assertEqual(
+ (output_path / 'fileSequence1.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000
+
+ 00:00:07.000 --> 00:00:11.890
+ Caption text #2
+
+ 00:00:11.890 --> 00:00:16.320
+ Caption text #3
+
+ 00:00:16.320 --> 00:00:21.580
+ Caption text #4
+ '''
+ ).lstrip()
+ )
+ self.assertEqual(
+ (output_path / 'fileSequence2.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000
+
+ 00:00:16.320 --> 00:00:21.580
+ Caption text #4
+
+ 00:00:21.580 --> 00:00:23.880
+ Caption text #5
+
+ 00:00:23.880 --> 00:00:27.280
+ Caption text #6
+
+ 00:00:27.280 --> 00:00:30.280
+ Caption text #7
+ '''
+ ).lstrip()
+ )
+ self.assertEqual(
+ (output_path / 'fileSequence3.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000
+
+ 00:00:27.280 --> 00:00:30.280
+ Caption text #7
+
+ 00:00:30.280 --> 00:00:36.510
+ Caption text #8
+
+ 00:00:36.510 --> 00:00:38.870
+ Caption text #9
+
+ 00:00:38.870 --> 00:00:45.000
+ Caption text #10
+ '''
+ ).lstrip()
+ )
+ self.assertEqual(
+ (output_path / 'fileSequence4.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000
+
+ 00:00:38.870 --> 00:00:45.000
+ Caption text #10
+
+ 00:00:45.000 --> 00:00:47.000
+ Caption text #11
+
+ 00:00:47.000 --> 00:00:50.970
+ Caption text #12
+ '''
+ ).lstrip()
+ )
+ self.assertEqual(
+ (output_path / 'fileSequence5.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000
+
+ 00:00:47.000 --> 00:00:50.970
+ Caption text #12
+
+ 00:00:50.970 --> 00:00:54.440
+ Caption text #13
+
+ 00:00:54.440 --> 00:00:58.600
+ Caption text #14
+
+ 00:00:58.600 --> 00:01:01.350
+ Caption text #15
+ '''
+ ).lstrip()
+ )
+ self.assertEqual(
+ (output_path / 'fileSequence6.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000
+
+ 00:00:58.600 --> 00:01:01.350
+ Caption text #15
+
+ 00:01:01.350 --> 00:01:04.300
+ Caption text #16
+ '''
+ ).lstrip()
+ )
+
+ def test_segmentation_with_custom_values(self):
+ segmenter.segment(
+ webvtt_path=PATH_TO_SAMPLES / 'sample.vtt',
+ output=self.temp_dir.name,
+ seconds=30,
+ mpegts=800000
+ )
+
+ _, dirs, files = next(os.walk(self.temp_dir.name))
+
+ self.assertEqual(len(dirs), 0)
+ self.assertEqual(len(files), 4)
+
+ for expected_file in ('prog_index.m3u8',
+ 'fileSequence0.webvtt',
+ 'fileSequence1.webvtt',
+ 'fileSequence2.webvtt',
+ ):
+ self.assertIn(expected_file, files)
+
+ output_path = pathlib.Path(self.temp_dir.name)
+
+ self.assertEqual(
+ (output_path / 'prog_index.m3u8').read_text(),
+ textwrap.dedent(
+ '''
+ #EXTM3U
+ #EXT-X-TARGETDURATION:30
+ #EXT-X-VERSION:3
+ #EXT-X-PLAYLIST-TYPE:VOD
+ #EXTINF:30.00000
+ fileSequence0.webvtt
+ #EXTINF:30.00000
+ fileSequence1.webvtt
+ #EXTINF:30.00000
+ fileSequence2.webvtt
+ #EXT-X-ENDLIST
+ '''
+ ).lstrip()
+ )
+ self.assertEqual(
+ (output_path / 'fileSequence0.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:800000,LOCAL:00:00:00.000
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+
+ 00:00:07.000 --> 00:00:11.890
+ Caption text #2
+
+ 00:00:11.890 --> 00:00:16.320
+ Caption text #3
+
+ 00:00:16.320 --> 00:00:21.580
+ Caption text #4
+
+ 00:00:21.580 --> 00:00:23.880
+ Caption text #5
+
+ 00:00:23.880 --> 00:00:27.280
+ Caption text #6
+
+ 00:00:27.280 --> 00:00:30.280
+ Caption text #7
+ '''
+ ).lstrip()
+ )
+ self.assertEqual(
+ (output_path / 'fileSequence1.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:800000,LOCAL:00:00:00.000
+
+ 00:00:27.280 --> 00:00:30.280
+ Caption text #7
+
+ 00:00:30.280 --> 00:00:36.510
+ Caption text #8
+
+ 00:00:36.510 --> 00:00:38.870
+ Caption text #9
+
+ 00:00:38.870 --> 00:00:45.000
+ Caption text #10
+
+ 00:00:45.000 --> 00:00:47.000
+ Caption text #11
+
+ 00:00:47.000 --> 00:00:50.970
+ Caption text #12
+
+ 00:00:50.970 --> 00:00:54.440
+ Caption text #13
+
+ 00:00:54.440 --> 00:00:58.600
+ Caption text #14
+
+ 00:00:58.600 --> 00:01:01.350
+ Caption text #15
+ '''
+ ).lstrip()
+ )
+ self.assertEqual(
+ (output_path / 'fileSequence2.webvtt').read_text(),
+ textwrap.dedent(
+ '''
+ WEBVTT
+ X-TIMESTAMP-MAP=MPEGTS:800000,LOCAL:00:00:00.000
+
+ 00:00:58.600 --> 00:01:01.350
+ Caption text #15
+
+ 00:01:01.350 --> 00:01:04.300
+ Caption text #16
+ '''
+ ).lstrip()
)
- self.assertTrue(os.path.exists(os.path.join(OUTPUT_DIR, 'prog_index.m3u8')))
-
- def test_segmentation(self):
- self._parse_captions('sample.vtt')
- self.segmenter.segment(self.webvtt, OUTPUT_DIR)
-
- # segment 1 should have caption 1 and 2
- self.assertEqual(len(self.segmenter.segments[0]), 2)
- self.assertIn(self.webvtt.captions[0], self.segmenter.segments[0])
- self.assertIn(self.webvtt.captions[1], self.segmenter.segments[0])
- # segment 2 should have caption 2 again (overlap), 3 and 4
- self.assertEqual(len(self.segmenter.segments[1]), 3)
- self.assertIn(self.webvtt.captions[2], self.segmenter.segments[1])
- self.assertIn(self.webvtt.captions[3], self.segmenter.segments[1])
- # segment 3 should have caption 4 again (overlap), 5, 6 and 7
- self.assertEqual(len(self.segmenter.segments[2]), 4)
- self.assertIn(self.webvtt.captions[3], self.segmenter.segments[2])
- self.assertIn(self.webvtt.captions[4], self.segmenter.segments[2])
- self.assertIn(self.webvtt.captions[5], self.segmenter.segments[2])
- self.assertIn(self.webvtt.captions[6], self.segmenter.segments[2])
- # segment 4 should have caption 7 again (overlap), 8, 9 and 10
- self.assertEqual(len(self.segmenter.segments[3]), 4)
- self.assertIn(self.webvtt.captions[6], self.segmenter.segments[3])
- self.assertIn(self.webvtt.captions[7], self.segmenter.segments[3])
- self.assertIn(self.webvtt.captions[8], self.segmenter.segments[3])
- self.assertIn(self.webvtt.captions[9], self.segmenter.segments[3])
- # segment 5 should have caption 10 again (overlap), 11 and 12
- self.assertEqual(len(self.segmenter.segments[4]), 3)
- self.assertIn(self.webvtt.captions[9], self.segmenter.segments[4])
- self.assertIn(self.webvtt.captions[10], self.segmenter.segments[4])
- self.assertIn(self.webvtt.captions[11], self.segmenter.segments[4])
- # segment 6 should have caption 12 again (overlap), 13, 14 and 15
- self.assertEqual(len(self.segmenter.segments[5]), 4)
- self.assertIn(self.webvtt.captions[11], self.segmenter.segments[5])
- self.assertIn(self.webvtt.captions[12], self.segmenter.segments[5])
- self.assertIn(self.webvtt.captions[13], self.segmenter.segments[5])
- self.assertIn(self.webvtt.captions[14], self.segmenter.segments[5])
- # segment 7 should have caption 15 again (overlap) and 16
- self.assertEqual(len(self.segmenter.segments[6]), 2)
- self.assertIn(self.webvtt.captions[14], self.segmenter.segments[6])
- self.assertIn(self.webvtt.captions[15], self.segmenter.segments[6])
-
- def test_segment_content(self):
- self._parse_captions('sample.vtt')
- self.segmenter.segment(self.webvtt, OUTPUT_DIR, 10)
-
- with open(os.path.join(OUTPUT_DIR, 'fileSequence0.webvtt'), 'r', encoding='utf-8') as f:
- lines = [line.rstrip() for line in f.readlines()]
-
- expected_lines = [
- 'WEBVTT',
- 'X-TIMESTAMP-MAP=MPEGTS:900000,LOCAL:00:00:00.000',
- '',
- '00:00:00.500 --> 00:00:07.000',
- 'Caption text #1',
- '',
- '00:00:07.000 --> 00:00:11.890',
- 'Caption text #2'
- ]
-
- self.assertListEqual(lines, expected_lines)
-
- def test_manifest_content(self):
- self._parse_captions('sample.vtt')
- self.segmenter.segment(self.webvtt, OUTPUT_DIR, 10)
-
- with open(os.path.join(OUTPUT_DIR, 'prog_index.m3u8'), 'r', encoding='utf-8') as f:
- lines = [line.rstrip() for line in f.readlines()]
-
- expected_lines = [
- '#EXTM3U',
- '#EXT-X-TARGETDURATION:{}'.format(self.segmenter.seconds),
- '#EXT-X-VERSION:3',
- '#EXT-X-PLAYLIST-TYPE:VOD',
- ]
-
- for i in range(7):
- expected_lines.extend([
- '#EXTINF:30.00000',
- 'fileSequence{}.webvtt'.format(i)
- ])
-
- expected_lines.append('#EXT-X-ENDLIST')
-
- for index, line in enumerate(expected_lines):
- self.assertEqual(lines[index], line)
-
- def test_customize_mpegts(self):
- self._parse_captions('sample.vtt')
- self.segmenter.segment(self.webvtt, OUTPUT_DIR, mpegts=800000)
-
- with open(os.path.join(OUTPUT_DIR, 'fileSequence0.webvtt'), 'r', encoding='utf-8') as f:
- lines = f.readlines()
- self.assertIn('MPEGTS:800000', lines[1])
-
- def test_segment_from_file(self):
- self.segmenter.segment(os.path.join(SUBTITLES_DIR, 'sample.vtt'), OUTPUT_DIR),
- self.assertEqual(self.segmenter.total_segments, 7)
def test_segment_with_no_captions(self):
- self.segmenter.segment(os.path.join(SUBTITLES_DIR, 'no_captions.vtt'), OUTPUT_DIR),
- self.assertEqual(self.segmenter.total_segments, 0)
-
- def test_total_segments_readonly(self):
- self.assertRaises(
- AttributeError,
- setattr,
- WebVTTSegmenter(),
- 'total_segments',
- 5
- )
+ segmenter.segment(
+ webvtt_path=PATH_TO_SAMPLES / 'no_captions.vtt',
+ output=self.temp_dir.name
+ )
+
+ _, dirs, files = next(os.walk(self.temp_dir.name))
+
+ self.assertEqual(len(dirs), 0)
+ self.assertEqual(len(files), 1)
+ self.assertIn('prog_index.m3u8', files)
+
+ self.assertEqual(
+ (pathlib.Path(self.temp_dir.name) / 'prog_index.m3u8').read_text(),
+ textwrap.dedent(
+ '''
+ #EXTM3U
+ #EXT-X-TARGETDURATION:10
+ #EXT-X-VERSION:3
+ #EXT-X-PLAYLIST-TYPE:VOD
+ #EXT-X-ENDLIST
+ '''
+ ).lstrip()
+ )
diff --git a/tests/test_srt.py b/tests/test_srt.py
index eed186d..e6a2dd4 100644
--- a/tests/test_srt.py
+++ b/tests/test_srt.py
@@ -1,35 +1,160 @@
-import os
import unittest
-from shutil import rmtree, copy
+import io
+import textwrap
-import webvtt
+from webvtt import srt
+from webvtt.errors import MalformedFileError
+from webvtt.models import Caption
-from .generic import GenericParserTestCase
+class TestSRTCueBlock(unittest.TestCase):
-BASE_DIR = os.path.dirname(__file__)
-OUTPUT_DIR = os.path.join(BASE_DIR, 'output')
+ def test_is_valid(self):
+ self.assertTrue(srt.SRTCueBlock.is_valid(textwrap.dedent('''
+ 1
+ 00:00:00,500 --> 00:00:07,000
+ Caption #1
+ ''').strip().split('\n'))
+ )
+ self.assertTrue(srt.SRTCueBlock.is_valid(textwrap.dedent('''
+ 1
+ 00:00:00,500 --> 00:00:07,000
+ Caption #1 line 1
+ Caption #1 line 2
+ ''').strip().split('\n'))
+ )
-class SRTCaptionsTestCase(GenericParserTestCase):
+ self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent('''
+ 00:00:00,500 --> 00:00:07,000
+ Caption #1
+ ''').strip().split('\n'))
+ )
- def setUp(self):
- os.makedirs(OUTPUT_DIR)
+ self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent('''
+ 1
+ 00:00:00.500 --> 00:00:07.000
+ Caption #1
+ ''').strip().split('\n'))
+ )
- def tearDown(self):
- if os.path.exists(OUTPUT_DIR):
- rmtree(OUTPUT_DIR)
+ self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent('''
+ 1
+ 00:00:00,500 --> 00:00:07,000
+ ''').strip().split('\n'))
+ )
- def test_convert_from_srt_to_vtt_and_back_gives_same_file(self):
- copy(self._get_file('sample.srt'), OUTPUT_DIR)
+ self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent('''
+ 1
+ Caption #1
+ ''').strip().split('\n'))
+ )
- vtt = webvtt.from_srt(os.path.join(OUTPUT_DIR, 'sample.srt'))
- vtt.save_as_srt(os.path.join(OUTPUT_DIR, 'sample_converted.srt'))
+ self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent('''
+ Caption #1
+ ''').strip().split('\n'))
+ )
- with open(os.path.join(OUTPUT_DIR, 'sample.srt'), 'r', encoding='utf-8') as f:
- original = f.read()
+ self.assertFalse(srt.SRTCueBlock.is_valid(textwrap.dedent('''
+ 00:00:00,500 --> 00:00:07,000
+ ''').strip().split('\n'))
+ )
- with open(os.path.join(OUTPUT_DIR, 'sample_converted.srt'), 'r', encoding='utf-8') as f:
- converted = f.read()
+ def test_from_lines(self):
+ cue_block = srt.SRTCueBlock.from_lines(textwrap.dedent('''
+ 1
+ 00:00:00,500 --> 00:00:07,000
+ Caption #1 line 1
+ Caption #1 line 2
+ ''').strip().split('\n')
+ )
+ self.assertEqual(cue_block.index, '1')
+ self.assertEqual(
+ cue_block.start,
+ '00:00:00,500'
+ )
+ self.assertEqual(
+ cue_block.end,
+ '00:00:07,000'
+ )
+ self.assertEqual(
+ cue_block.payload,
+ ['Caption #1 line 1', 'Caption #1 line 2']
+ )
- self.assertEqual(original.strip(), converted.strip())
+
+class TestSRTModule(unittest.TestCase):
+
+ def test_parse_invalid_format(self):
+ self.assertRaises(
+ MalformedFileError,
+ srt.parse,
+ textwrap.dedent('''
+ 00:00:00,500 --> 00:00:07,000
+ Caption text #1
+
+ 00:00:07,000 --> 00:00:11,890
+ Caption text #2
+ ''').strip().split('\n')
+ )
+
+ def test_parse_captions(self):
+ captions = srt.parse(
+ textwrap.dedent('''
+ 1
+ 00:00:00,500 --> 00:00:07,000
+ Caption #1
+
+ 2
+ 00:00:07,000 --> 00:00:11,890
+ Caption #2 line 1
+ Caption #2 line 2
+ ''').strip().split('\n')
+ )
+ self.assertEqual(len(captions), 2)
+ self.assertIsInstance(captions[0], Caption)
+ self.assertIsInstance(captions[1], Caption)
+ self.assertEqual(
+ str(captions[0]),
+ '00:00:00.500 00:00:07.000 Caption #1'
+ )
+ self.assertEqual(
+ str(captions[1]),
+ r'00:00:07.000 00:00:11.890 Caption #2 line 1\n'
+ 'Caption #2 line 2'
+ )
+
+ def test_write(self):
+ out = io.StringIO()
+ captions = [
+ Caption(start='00:00:00.500',
+ end='00:00:07.000',
+ text='Caption #1'
+ ),
+ Caption(start='00:00:07.000',
+ end='00:00:11.890',
+ text=['Caption #2 line 1',
+ 'Caption #2 line 2'
+ ]
+ )
+ ]
+ captions[0].comments.append('Comment for the first caption')
+ captions[1].comments.append('Comment for the second caption')
+
+ srt.write(out, captions)
+
+ out.seek(0)
+
+ self.assertEqual(
+ out.read(),
+ textwrap.dedent('''
+ 1
+ 00:00:00,500 --> 00:00:07,000
+ Caption #1
+
+ 2
+ 00:00:07,000 --> 00:00:11,890
+ Caption #2 line 1
+ Caption #2 line 2
+ ''').strip()
+ )
diff --git a/tests/test_srt_parser.py b/tests/test_srt_parser.py
deleted file mode 100644
index 7885ec5..0000000
--- a/tests/test_srt_parser.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import webvtt
-
-from .generic import GenericParserTestCase
-
-
-class SRTParserTestCase(GenericParserTestCase):
-
- def test_srt_parse_empty_file(self):
- self.assertRaises(
- webvtt.errors.MalformedFileError,
- webvtt.from_srt,
- self._get_file('empty.vtt') # We reuse this file as it is empty and serves the purpose.
- )
-
- def test_srt_invalid_format(self):
- for i in range(1, 5):
- self.assertRaises(
- webvtt.errors.MalformedFileError,
- webvtt.from_srt,
- self._get_file('invalid_format{}.srt'.format(i))
- )
-
- def test_srt_total_length(self):
- self.assertEqual(
- webvtt.from_srt(self._get_file('sample.srt')).total_length,
- 23
- )
-
- def test_srt_parse_captions(self):
- self.assertTrue(webvtt.from_srt(self._get_file('sample.srt')).captions)
-
- def test_srt_missing_timeframe_line(self):
- self.assertRaises(
- webvtt.errors.MalformedCaptionError,
- webvtt.from_srt,
- self._get_file('missing_timeframe.srt')
- )
-
- def test_srt_empty_caption_text(self):
- self.assertTrue(webvtt.from_srt(self._get_file('missing_caption_text.srt')).captions)
-
- def test_srt_empty_gets_removed(self):
- captions = webvtt.from_srt(self._get_file('missing_caption_text.srt')).captions
- self.assertEqual(len(captions), 4)
-
- def test_srt_invalid_timestamp(self):
- self.assertRaises(
- webvtt.errors.MalformedCaptionError,
- webvtt.from_srt,
- self._get_file('invalid_timeframe.srt')
- )
-
- def test_srt_timestamps_format(self):
- vtt = webvtt.from_srt(self._get_file('sample.srt'))
- self.assertEqual(vtt.captions[2].start, '00:00:11.890')
- self.assertEqual(vtt.captions[2].end, '00:00:16.320')
-
- def test_srt_parse_get_caption_data(self):
- vtt = webvtt.from_srt(self._get_file('one_caption.srt'))
- self.assertEqual(vtt.captions[0].start_in_seconds, 0.5)
- self.assertEqual(vtt.captions[0].start, '00:00:00.500')
- self.assertEqual(vtt.captions[0].end_in_seconds, 7)
- self.assertEqual(vtt.captions[0].end, '00:00:07.000')
- self.assertEqual(vtt.captions[0].lines[0], 'Caption text #1')
- self.assertEqual(len(vtt.captions[0].lines), 1)
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..23fc1a2
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,26 @@
+import unittest
+import tempfile
+
+from webvtt.utils import FileWrapper, CODEC_BOMS
+
+
+class TestUtils(unittest.TestCase):
+
+ def test_open_file(self):
+ with tempfile.NamedTemporaryFile('w') as f:
+ f.write('Hello test!')
+ f.flush()
+ with FileWrapper.open(f.name) as fw:
+ self.assertEqual(fw.file.read(), 'Hello test!')
+ self.assertEqual(fw.file.encoding, 'utf-8')
+
+ def test_open_file_with_bom(self):
+ for encoding, bom in CODEC_BOMS.items():
+ with tempfile.NamedTemporaryFile('wb') as f:
+ f.write(bom)
+ f.write('Hello test'.encode(encoding))
+ f.flush()
+
+ with FileWrapper.open(f.name) as fw:
+ self.assertEqual(fw.file.read(), 'Hello test')
+ self.assertEqual(fw.file.encoding, encoding)
diff --git a/tests/test_vtt.py b/tests/test_vtt.py
new file mode 100644
index 0000000..214b71c
--- /dev/null
+++ b/tests/test_vtt.py
@@ -0,0 +1,556 @@
+import unittest
+import textwrap
+import io
+
+from webvtt import vtt
+from webvtt.errors import MalformedFileError
+from webvtt.models import Caption, Style
+
+
+class TestWebVTTCueBlock(unittest.TestCase):
+
+ def test_is_valid(self):
+ self.assertTrue(
+ vtt.WebVTTCueBlock.is_valid(textwrap.dedent('''
+ 00:00:00.500 --> 00:00:07.000
+ Caption #1
+ ''').strip().split('\n'))
+ )
+
+ self.assertTrue(
+ vtt.WebVTTCueBlock.is_valid(textwrap.dedent('''
+ 00:00:00.500 --> 00:00:07.000
+ Caption #1 line 1
+ Caption #1 line 2
+ ''').strip().split('\n'))
+ )
+
+ self.assertTrue(
+ vtt.WebVTTCueBlock.is_valid(textwrap.dedent('''
+ identifier
+ 00:00:00.500 --> 00:00:07.000
+ Caption #1
+ ''').strip().split('\n'))
+ )
+
+ self.assertTrue(
+ vtt.WebVTTCueBlock.is_valid(textwrap.dedent('''
+ identifier
+ 00:00:00.500 --> 00:00:07.000
+ Caption #1 line 1
+ Caption #1 line 2
+ ''').strip().split('\n'))
+ )
+
+ self.assertFalse(
+ vtt.WebVTTCueBlock.is_valid(textwrap.dedent('''
+ 00:00:00.500 00:00:07.000
+ Caption #1 line 1
+ ''').strip().split('\n'))
+ )
+
+ self.assertFalse(
+ vtt.WebVTTCueBlock.is_valid(textwrap.dedent('''
+ 00:00:00.500 --> 00:00:07.000
+ ''').strip().split('\n'))
+ )
+
+ def test_from_lines(self):
+ cue_block = vtt.WebVTTCueBlock.from_lines(textwrap.dedent('''
+ identifier
+ 00:00:00.500 --> 00:00:07.000
+ Caption #1 line 1
+ Caption #1 line 2
+ ''').strip().split('\n')
+ )
+ self.assertEqual(cue_block.identifier, 'identifier')
+ self.assertEqual(cue_block.start, '00:00:00.500')
+ self.assertEqual(cue_block.end, '00:00:07.000')
+ self.assertListEqual(
+ cue_block.payload,
+ ['Caption #1 line 1', 'Caption #1 line 2']
+ )
+
+
+class TestWebVTTCommentBlock(unittest.TestCase):
+
+ def test_is_valid(self):
+ self.assertTrue(
+ vtt.WebVTTCommentBlock.is_valid(textwrap.dedent('''
+ NOTE This is a one line comment
+ ''').strip().split('\n'))
+ )
+
+ self.assertTrue(
+ vtt.WebVTTCommentBlock.is_valid(textwrap.dedent('''
+ NOTE
+ This is a another one line comment
+ ''').strip().split('\n'))
+ )
+
+ self.assertTrue(
+ vtt.WebVTTCommentBlock.is_valid(textwrap.dedent('''
+ NOTE
+ This is a multi-line comment
+ taking two lines
+ ''').strip().split('\n'))
+ )
+
+ self.assertFalse(
+ vtt.WebVTTCommentBlock.is_valid(textwrap.dedent('''
+ This is not a comment
+ ''').strip().split('\n'))
+ )
+
+ self.assertFalse(
+ vtt.WebVTTCommentBlock.is_valid(textwrap.dedent('''
+ # This is not a comment
+ ''').strip().split('\n'))
+ )
+
+ self.assertFalse(
+ vtt.WebVTTCommentBlock.is_valid(textwrap.dedent('''
+ // This is not a comment
+ ''').strip().split('\n'))
+ )
+
+ def test_from_lines(self):
+ comment = vtt.WebVTTCommentBlock.from_lines(textwrap.dedent('''
+ NOTE This is a one line comment
+ ''').strip().split('\n')
+ )
+ self.assertEqual(comment.text, 'This is a one line comment')
+
+ comment = vtt.WebVTTCommentBlock.from_lines(textwrap.dedent('''
+ NOTE
+ This is a multi-line comment
+ taking two lines
+ ''').strip().split('\n')
+ )
+ self.assertEqual(
+ comment.text,
+ 'This is a multi-line comment\ntaking two lines'
+ )
+
+
+class TestWebVTTStyleBlock(unittest.TestCase):
+
+ def test_is_valid(self):
+ self.assertTrue(
+ vtt.WebVTTStyleBlock.is_valid(textwrap.dedent('''
+ STYLE
+ ::cue {
+ background-image: linear-gradient(to bottom, dimgray, lightgray);
+ color: papayawhip;
+ }
+ ''').strip().split('\n'))
+ )
+
+ self.assertTrue(
+ vtt.WebVTTStyleBlock.is_valid(textwrap.dedent('''
+ STYLE
+ ::cue {
+ background-image: linear-gradient(to bottom, dimgray, lightgray);
+ color: papayawhip;
+ }
+ ::cue(b) {
+ color: peachpuff;
+ }
+ ''').strip().split('\n'))
+ )
+
+ self.assertFalse(
+ vtt.WebVTTStyleBlock.is_valid(textwrap.dedent('''
+ STYLE
+ ::cue {
+ background-image: linear-gradient(to bottom, dimgray, lightgray);
+ color: papayawhip;
+ }
+
+ ::cue(b) {
+ color: peachpuff;
+ }
+ ''').strip().split('\n'))
+ )
+
+ self.assertFalse(
+ vtt.WebVTTStyleBlock.is_valid(textwrap.dedent('''
+ STYLE
+ ::cue--> {
+ background-image: linear-gradient(to bottom, dimgray, lightgray);
+ color: papayawhip;
+ }
+ ''').strip().split('\n'))
+ )
+
+ self.assertFalse(
+ vtt.WebVTTStyleBlock.is_valid(textwrap.dedent('''
+ ::cue {
+ background-image: linear-gradient(to bottom, dimgray, lightgray);
+ color: papayawhip;
+ }
+ ''').strip().split('\n'))
+ )
+
+ def test_from_lines(self):
+ style = vtt.WebVTTStyleBlock.from_lines(textwrap.dedent('''
+ STYLE
+ ::cue {
+ background-image: linear-gradient(to bottom, dimgray, lightgray);
+ color: papayawhip;
+ }
+ ::cue(b) {
+ color: peachpuff;
+ }
+ ''').strip().split('\n')
+ )
+ self.assertEqual(
+ style.text,
+ textwrap.dedent('''
+ ::cue {
+ background-image: linear-gradient(to bottom, dimgray, lightgray);
+ color: papayawhip;
+ }
+ ::cue(b) {
+ color: peachpuff;
+ }
+ ''').strip()
+ )
+
+
+class TestVTTModule(unittest.TestCase):
+
+ def test_parse_invalid_format(self):
+ self.assertRaises(
+ MalformedFileError,
+ vtt.parse,
+ textwrap.dedent('''
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+
+ 00:00:07.000 --> 00:00:11.890
+ Caption text #2
+ ''').strip().split('\n')
+ )
+
+ def test_parse_captions(self):
+ output = vtt.parse(
+ textwrap.dedent('''
+ WEBVTT
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+
+ 00:00:07.000 --> 00:00:11.890
+ Caption text #2 line 1
+ Caption text #2 line 2
+ ''').strip().split('\n')
+ )
+ captions = output.captions
+ styles = output.styles
+ self.assertEqual(len(captions), 2)
+ self.assertEqual(len(styles), 0)
+ self.assertIsInstance(captions[0], Caption)
+ self.assertIsInstance(captions[1], Caption)
+ self.assertEqual(
+ str(captions[0]),
+ '00:00:00.500 00:00:07.000 Caption text #1'
+ )
+ self.assertEqual(
+ str(captions[1]),
+ r'00:00:07.000 00:00:11.890 Caption text #2 line 1\n'
+ 'Caption text #2 line 2'
+ )
+
+ def test_parse_styles(self):
+ output = vtt.parse(
+ textwrap.dedent('''
+ WEBVTT
+
+ STYLE
+ ::cue {
+ color: white;
+ }
+
+ STYLE
+ ::cue(.important) {
+ color: red;
+ font-weight: bold;
+ }
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+ ''').strip().split('\n')
+ )
+ captions = output.captions
+ styles = output.styles
+ self.assertEqual(len(captions), 1)
+ self.assertEqual(len(styles), 2)
+ self.assertIsInstance(styles[0], Style)
+ self.assertIsInstance(styles[1], Style)
+ self.assertEqual(
+ str(styles[0].text),
+ textwrap.dedent('''
+ ::cue {
+ color: white;
+ }
+ ''').strip()
+ )
+ self.assertEqual(
+ str(styles[1].text),
+ textwrap.dedent('''
+ ::cue(.important) {
+ color: red;
+ font-weight: bold;
+ }
+ ''').strip(),
+ )
+
+ def test_parse_content(self):
+ output = vtt.parse(
+ textwrap.dedent('''
+ WEBVTT
+
+ NOTE This is a testing sample
+
+ NOTE We can see two header comments, a style
+ comment and finally a footer comments
+
+ STYLE
+ ::cue {
+ background-image: linear-gradient(to bottom, dimgray);
+ color: papayawhip;
+ }
+
+ NOTE the following style needs review
+
+ STYLE
+ ::cue {
+ color: white;
+ }
+
+ NOTE Comment for the first caption
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+
+ NOTE
+ Comment for the second caption
+ that is very long
+
+ 00:00:07.000 --> 00:00:11.890
+ Caption text #2 line 1
+ Caption text #2 line 2
+
+ NOTE Copyright 2024
+
+ NOTE end of file
+ ''').strip().split('\n')
+ )
+ captions = output.captions
+ styles = output.styles
+ self.assertEqual(len(captions), 2)
+ self.assertEqual(len(styles), 2)
+ self.assertIsInstance(captions[0], Caption)
+ self.assertIsInstance(captions[1], Caption)
+ self.assertIsInstance(styles[0], Style)
+ self.assertIsInstance(styles[1], Style)
+ self.assertEqual(
+ str(captions[0]),
+ '00:00:00.500 00:00:07.000 Caption text #1'
+ )
+ self.assertEqual(
+ str(captions[1]),
+ r'00:00:07.000 00:00:11.890 Caption text #2 line 1\n'
+ 'Caption text #2 line 2'
+ )
+ self.assertEqual(
+ str(styles[0].text),
+ textwrap.dedent('''
+ ::cue {
+ background-image: linear-gradient(to bottom, dimgray);
+ color: papayawhip;
+ }
+ ''').strip()
+ )
+ self.assertEqual(
+ str(styles[1].text),
+ textwrap.dedent('''
+ ::cue {
+ color: white;
+ }
+ ''').strip()
+ )
+ self.assertEqual(
+ styles[0].comments,
+ []
+ )
+ self.assertEqual(
+ styles[1].comments,
+ ['the following style needs review']
+ )
+ self.assertEqual(
+ captions[0].comments,
+ ['Comment for the first caption']
+ )
+ self.assertEqual(
+ captions[1].comments,
+ ['Comment for the second caption\nthat is very long']
+ )
+ self.assertListEqual(output.header_comments,
+ ['This is a testing sample',
+ 'We can see two header comments, a style\n'
+ 'comment and finally a footer comments'
+ ]
+ )
+ self.assertListEqual(output.footer_comments,
+ ['Copyright 2024',
+ 'end of file'
+ ]
+ )
+
+ def test_write(self):
+ out = io.StringIO()
+ captions = [
+ Caption(start='00:00:00.500',
+ end='00:00:07.000',
+ text='Caption #1'
+ ),
+ Caption(start='00:00:07.000',
+ end='00:00:11.890',
+ text=['Caption #2 line 1',
+ 'Caption #2 line 2'
+ ]
+ )
+ ]
+ styles = [
+ Style('::cue(b) {\n color: peachpuff;\n}'),
+ Style('::cue {\n color: papayawhip;\n}')
+ ]
+ captions[0].comments.append('Comment for the first caption')
+ captions[1].comments.append('Comment for the second caption')
+ styles[1].comments.append(
+ 'Comment for the second style\nwith two lines'
+ )
+ header_comments = ['header comment', 'begin of the file']
+ footer_comments = ['footer comment', 'end of file']
+
+ vtt.write(
+ out,
+ captions,
+ styles,
+ header_comments,
+ footer_comments
+ )
+
+ out.seek(0)
+
+ self.assertEqual(
+ out.read(),
+ textwrap.dedent('''
+ WEBVTT
+
+ NOTE header comment
+
+ NOTE begin of the file
+
+ STYLE
+ ::cue(b) {
+ color: peachpuff;
+ }
+
+ NOTE
+ Comment for the second style
+ with two lines
+
+ STYLE
+ ::cue {
+ color: papayawhip;
+ }
+
+ NOTE Comment for the first caption
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption #1
+
+ NOTE Comment for the second caption
+
+ 00:00:07.000 --> 00:00:11.890
+ Caption #2 line 1
+ Caption #2 line 2
+
+ NOTE footer comment
+
+ NOTE end of file
+ ''').strip()
+ )
+
+ def test_to_str(self):
+ captions = [
+ Caption(start='00:00:00.500',
+ end='00:00:07.000',
+ text='Caption #1'
+ ),
+ Caption(start='00:00:07.000',
+ end='00:00:11.890',
+ text=['Caption #2 line 1',
+ 'Caption #2 line 2'
+ ]
+ )
+ ]
+ styles = [
+ Style('::cue(b) {\n color: peachpuff;\n}'),
+ Style('::cue {\n color: papayawhip;\n}')
+ ]
+ captions[0].comments.append('Comment for the first caption')
+ captions[1].comments.append('Comment for the second caption')
+ styles[1].comments.append(
+ 'Comment for the second style\nwith two lines'
+ )
+ header_comments = ['header comment', 'begin of the file']
+ footer_comments = ['footer comment', 'end of file']
+
+ self.assertEqual(
+ vtt.to_str(
+ captions,
+ styles,
+ header_comments,
+ footer_comments
+ ),
+ textwrap.dedent('''
+ WEBVTT
+
+ NOTE header comment
+
+ NOTE begin of the file
+
+ STYLE
+ ::cue(b) {
+ color: peachpuff;
+ }
+
+ NOTE
+ Comment for the second style
+ with two lines
+
+ STYLE
+ ::cue {
+ color: papayawhip;
+ }
+
+ NOTE Comment for the first caption
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption #1
+
+ NOTE Comment for the second caption
+
+ 00:00:07.000 --> 00:00:11.890
+ Caption #2 line 1
+ Caption #2 line 2
+
+ NOTE footer comment
+
+ NOTE end of file
+ ''').strip()
+ )
diff --git a/tests/test_webvtt.py b/tests/test_webvtt.py
index 9da3238..3111e2f 100644
--- a/tests/test_webvtt.py
+++ b/tests/test_webvtt.py
@@ -1,242 +1,263 @@
+import unittest
import os
import io
import textwrap
-from shutil import rmtree, copy
+import warnings
+import tempfile
+import pathlib
import webvtt
-from webvtt.structures import Caption, Style
-from .generic import GenericParserTestCase
+from webvtt.models import Caption, Style
+from webvtt.utils import CODEC_BOMS
from webvtt.errors import MalformedFileError
+PATH_TO_SAMPLES = pathlib.Path(__file__).resolve().parent / 'samples'
-BASE_DIR = os.path.dirname(__file__)
-OUTPUT_DIR = os.path.join(BASE_DIR, 'output')
+class TestWebVTT(unittest.TestCase):
-class WebVTTTestCase(GenericParserTestCase):
+ def test_from_string(self):
+ vtt = webvtt.WebVTT.from_string(textwrap.dedent("""
+ WEBVTT
- def tearDown(self):
- if os.path.exists(OUTPUT_DIR):
- rmtree(OUTPUT_DIR)
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
- def test_create_caption(self):
- caption = Caption('00:00:00.500', '00:00:07.000', ['Caption test line 1', 'Caption test line 2'])
- self.assertEqual(caption.start, '00:00:00.500')
- self.assertEqual(caption.start_in_seconds, 0.5)
- self.assertEqual(caption.end, '00:00:07.000')
- self.assertEqual(caption.end_in_seconds, 7)
- self.assertEqual(caption.lines, ['Caption test line 1', 'Caption test line 2'])
+ 00:00:07.000 --> 00:00:11.890
+ Caption text #2 line 1
+ Caption text #2 line 2
- def test_write_captions(self):
- os.makedirs(OUTPUT_DIR)
- copy(self._get_file('one_caption.vtt'), OUTPUT_DIR)
+ 00:00:11.890 --> 00:00:16.320
+ Caption text #3
+
+ 00:00:16.320 --> 00:00:21.580
+ Caption text #4
+ """).strip()
+ )
+ self.assertEqual(len(vtt), 4)
+ self.assertEqual(
+ str(vtt[0]),
+ '00:00:00.500 00:00:07.000 Caption text #1'
+ )
+ self.assertEqual(
+ str(vtt[1]),
+ r'00:00:07.000 00:00:11.890 Caption text #2 line 1\n'
+ 'Caption text #2 line 2'
+ )
+ self.assertEqual(
+ str(vtt[2]),
+ '00:00:11.890 00:00:16.320 Caption text #3'
+ )
+ self.assertEqual(
+ str(vtt[3]),
+ '00:00:16.320 00:00:21.580 Caption text #4'
+ )
+ def test_write_captions(self):
out = io.StringIO()
- vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'one_caption.vtt'))
- new_caption = Caption('00:00:07.000', '00:00:11.890', ['New caption text line1', 'New caption text line2'])
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'one_caption.vtt')
+ new_caption = Caption(start='00:00:07.000',
+ end='00:00:11.890',
+ text=['New caption text line1',
+ 'New caption text line2'
+ ]
+ )
vtt.captions.append(new_caption)
vtt.write(out)
out.seek(0)
- lines = [line.rstrip() for line in out.readlines()]
- expected_lines = [
- 'WEBVTT',
- '',
- '00:00:00.500 --> 00:00:07.000',
- 'Caption text #1',
- '',
- '00:00:07.000 --> 00:00:11.890',
- 'New caption text line1',
- 'New caption text line2'
- ]
+ self.assertEqual(
+ out.read(),
+ textwrap.dedent('''
+ WEBVTT
- self.assertListEqual(lines, expected_lines)
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
- def test_save_captions(self):
- os.makedirs(OUTPUT_DIR)
- copy(self._get_file('one_caption.vtt'), OUTPUT_DIR)
+ 00:00:07.000 --> 00:00:11.890
+ New caption text line1
+ New caption text line2
+ ''').strip()
+ )
- vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'one_caption.vtt'))
- new_caption = Caption('00:00:07.000', '00:00:11.890', ['New caption text line1', 'New caption text line2'])
+ def test_write_captions_in_srt(self):
+ out = io.StringIO()
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'one_caption.vtt')
+ new_caption = Caption(start='00:00:07.000',
+ end='00:00:11.890',
+ text=['New caption text line1',
+ 'New caption text line2'
+ ]
+ )
vtt.captions.append(new_caption)
- vtt.save()
-
- with open(os.path.join(OUTPUT_DIR, 'one_caption.vtt'), 'r', encoding='utf-8') as f:
- lines = [line.rstrip() for line in f.readlines()]
+ vtt.write(out, format='srt')
- expected_lines = [
- 'WEBVTT',
- '',
- '00:00:00.500 --> 00:00:07.000',
- 'Caption text #1',
- '',
- '00:00:07.000 --> 00:00:11.890',
- 'New caption text line1',
- 'New caption text line2'
- ]
+ out.seek(0)
+ self.assertEqual(
+ out.read(),
+ textwrap.dedent('''
+ 1
+ 00:00:00,500 --> 00:00:07,000
+ Caption text #1
+
+ 2
+ 00:00:07,000 --> 00:00:11,890
+ New caption text line1
+ New caption text line2
+ ''').strip()
+ )
+
+ def test_write_captions_in_unsupported_format(self):
+ self.assertRaises(
+ ValueError,
+ webvtt.WebVTT().write,
+ io.StringIO(),
+ format='ttt'
+ )
- self.assertListEqual(lines, expected_lines)
+ def test_save_captions(self):
+ with tempfile.NamedTemporaryFile('w', suffix='.vtt') as f:
+ f.write((PATH_TO_SAMPLES / 'one_caption.vtt').read_text())
+ f.flush()
+
+ vtt = webvtt.read(f.name)
+ new_caption = Caption(start='00:00:07.000',
+ end='00:00:11.890',
+ text=['New caption text line1',
+ 'New caption text line2'
+ ]
+ )
+ vtt.captions.append(new_caption)
+ vtt.save()
+ f.flush()
+
+ self.assertEqual(
+ pathlib.Path(f.name).read_text(),
+ textwrap.dedent('''
+ WEBVTT
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+
+ 00:00:07.000 --> 00:00:11.890
+ New caption text line1
+ New caption text line2
+ ''').strip()
+ )
def test_srt_conversion(self):
- os.makedirs(OUTPUT_DIR)
- copy(self._get_file('one_caption.srt'), OUTPUT_DIR)
-
- vtt = webvtt.from_srt(os.path.join(OUTPUT_DIR, 'one_caption.srt'))
- vtt.save()
-
- self.assertTrue(os.path.exists(os.path.join(OUTPUT_DIR, 'one_caption.vtt')))
-
- with open(os.path.join(OUTPUT_DIR, 'one_caption.vtt'), 'r', encoding='utf-8') as f:
- lines = [line.rstrip() for line in f.readlines()]
-
- expected_lines = [
- 'WEBVTT',
- '',
- '00:00:00.500 --> 00:00:07.000',
- 'Caption text #1',
- ]
-
- self.assertListEqual(lines, expected_lines)
+ with tempfile.TemporaryDirectory() as td:
+ with open(pathlib.Path(td) / 'one_caption.srt', 'w') as f:
+ f.write((PATH_TO_SAMPLES / 'one_caption.srt').read_text())
+
+ webvtt.from_srt(
+ pathlib.Path(td) / 'one_caption.srt'
+ ).save()
+
+ self.assertTrue(
+ os.path.exists(pathlib.Path(td) / 'one_caption.vtt')
+ )
+ self.assertEqual(
+ (pathlib.Path(td) / 'one_caption.vtt').read_text(),
+ textwrap.dedent('''
+ WEBVTT
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+ ''').strip()
+ )
def test_sbv_conversion(self):
- os.makedirs(OUTPUT_DIR)
- copy(self._get_file('two_captions.sbv'), OUTPUT_DIR)
-
- vtt = webvtt.from_sbv(os.path.join(OUTPUT_DIR, 'two_captions.sbv'))
- vtt.save()
-
- self.assertTrue(os.path.exists(os.path.join(OUTPUT_DIR, 'two_captions.vtt')))
-
- with open(os.path.join(OUTPUT_DIR, 'two_captions.vtt'), 'r', encoding='utf-8') as f:
- lines = [line.rstrip() for line in f.readlines()]
-
- expected_lines = [
- 'WEBVTT',
- '',
- '00:00:00.378 --> 00:00:11.378',
- 'Caption text #1',
- '',
- '00:00:11.378 --> 00:00:12.305',
- 'Caption text #2 (line 1)',
- 'Caption text #2 (line 2)',
- ]
-
- self.assertListEqual(lines, expected_lines)
+ with tempfile.TemporaryDirectory() as td:
+ with open(pathlib.Path(td) / 'two_captions.sbv', 'w') as f:
+ f.write(
+ (PATH_TO_SAMPLES / 'two_captions.sbv').read_text()
+ )
+
+ webvtt.from_sbv(
+ pathlib.Path(td) / 'two_captions.sbv'
+ ).save()
+
+ self.assertTrue(
+ os.path.exists(pathlib.Path(td) / 'two_captions.vtt')
+ )
+ self.assertEqual(
+ (pathlib.Path(td) / 'two_captions.vtt').read_text(),
+ textwrap.dedent('''
+ WEBVTT
+
+ 00:00:00.378 --> 00:00:11.378
+ Caption text #1
+
+ 00:00:11.378 --> 00:00:12.305
+ Caption text #2 (line 1)
+ Caption text #2 (line 2)
+ ''').strip()
+ )
def test_save_to_other_location(self):
- target_path = os.path.join(OUTPUT_DIR, 'test_folder')
- os.makedirs(target_path)
+ with tempfile.TemporaryDirectory() as td:
+ webvtt.read(
+ PATH_TO_SAMPLES / 'one_caption.vtt'
+ ).save(td)
- webvtt.read(self._get_file('one_caption.vtt')).save(target_path)
- self.assertTrue(os.path.exists(os.path.join(target_path, 'one_caption.vtt')))
+ self.assertTrue(
+ os.path.exists(pathlib.Path(td) / 'one_caption.vtt')
+ )
def test_save_specific_filename(self):
- target_path = os.path.join(OUTPUT_DIR, 'test_folder')
- os.makedirs(target_path)
- output_file = os.path.join(target_path, 'custom_name.vtt')
+ with tempfile.TemporaryDirectory() as td:
+ webvtt.read(
+ PATH_TO_SAMPLES / 'one_caption.vtt'
+ ).save(
+ pathlib.Path(td) / 'one_caption_new.vtt'
+ )
- webvtt.read(self._get_file('one_caption.vtt')).save(output_file)
- self.assertTrue(os.path.exists(output_file))
+ self.assertTrue(
+ os.path.exists(pathlib.Path(td) / 'one_caption_new.vtt')
+ )
def test_save_specific_filename_no_extension(self):
- target_path = os.path.join(OUTPUT_DIR, 'test_folder')
- os.makedirs(target_path)
- output_file = os.path.join(target_path, 'custom_name')
-
- webvtt.read(self._get_file('one_caption.vtt')).save(output_file)
- self.assertTrue(os.path.exists(os.path.join(target_path, 'custom_name.vtt')))
-
- def test_caption_timestamp_update(self):
- c = Caption('00:00:00.500', '00:00:07.000')
- c.start = '00:00:01.750'
- c.end = '00:00:08.250'
-
- self.assertEqual(c.start, '00:00:01.750')
- self.assertEqual(c.end, '00:00:08.250')
-
- def test_caption_timestamp_format(self):
- c = Caption('01:02:03.400', '02:03:04.500')
- self.assertEqual(c.start, '01:02:03.400')
- self.assertEqual(c.end, '02:03:04.500')
-
- c = Caption('02:03.400', '03:04.500')
- self.assertEqual(c.start, '00:02:03.400')
- self.assertEqual(c.end, '00:03:04.500')
-
- def test_caption_text(self):
- c = Caption(text=['Caption line #1', 'Caption line #2'])
- self.assertEqual(
- c.text,
- 'Caption line #1\nCaption line #2'
- )
-
- def test_caption_receive_text(self):
- c = Caption(text='Caption line #1\nCaption line #2')
-
- self.assertEqual(
- len(c.lines),
- 2
- )
- self.assertEqual(
- c.text,
- 'Caption line #1\nCaption line #2'
- )
-
- def test_update_text(self):
- c = Caption(text='Caption line #1')
- c.text = 'Caption line #1 updated'
- self.assertEqual(
- c.text,
- 'Caption line #1 updated'
- )
-
- def test_update_text_multiline(self):
- c = Caption(text='Caption line #1')
- c.text = 'Caption line #1\nCaption line #2'
-
- self.assertEqual(
- len(c.lines),
- 2
- )
-
- self.assertEqual(
- c.text,
- 'Caption line #1\nCaption line #2'
- )
-
- def test_update_text_wrong_type(self):
- c = Caption(text='Caption line #1')
-
- self.assertRaises(
- AttributeError,
- setattr,
- c,
- 'text',
- 123
- )
-
- def test_manipulate_lines(self):
- c = Caption(text=['Caption line #1', 'Caption line #2'])
- c.lines[0] = 'Caption line #1 updated'
- self.assertEqual(
- c.lines[0],
- 'Caption line #1 updated'
- )
-
- def test_read_file_buffer(self):
- with open(self._get_file('sample.vtt'), 'r', encoding='utf-8') as f:
- vtt = webvtt.read_buffer(f)
- self.assertIsInstance(vtt.captions, list)
+ with tempfile.TemporaryDirectory() as td:
+ webvtt.read(
+ PATH_TO_SAMPLES / 'one_caption.vtt'
+ ).save(
+ pathlib.Path(td) / 'one_caption_new'
+ )
+
+ self.assertTrue(
+ os.path.exists(pathlib.Path(td) / 'one_caption_new.vtt')
+ )
+
+ def test_from_buffer(self):
+ with open(PATH_TO_SAMPLES / 'sample.vtt', 'r', encoding='utf-8') as f:
+ self.assertIsInstance(
+ webvtt.from_buffer(f).captions,
+ list
+ )
+
+ def test_deprecated_read_buffer(self):
+ with open(PATH_TO_SAMPLES / 'sample.vtt', 'r', encoding='utf-8') as f:
+ with warnings.catch_warnings(record=True) as warns:
+ warnings.simplefilter('always')
+ vtt = webvtt.read_buffer(f)
+
+ self.assertIsInstance(vtt.captions, list)
+ self.assertEqual(
+ 'Deprecated: use from_buffer instead.',
+ str(warns[-1].message)
+ )
def test_read_memory_buffer(self):
- payload = ''
- with open(self._get_file('sample.vtt'), 'r', encoding='utf-8') as f:
- payload = f.read()
+ buffer = io.StringIO(
+ (PATH_TO_SAMPLES / 'sample.vtt').read_text()
+ )
- buffer = io.StringIO(payload)
- vtt = webvtt.read_buffer(buffer)
- self.assertIsInstance(vtt.captions, list)
+ self.assertIsInstance(
+ webvtt.from_buffer(buffer).captions,
+ list
+ )
def test_read_memory_buffer_carriage_return(self):
"""https://github.com/glut23/webvtt-py/issues/29"""
@@ -251,167 +272,917 @@ def test_read_memory_buffer_carriage_return(self):
\r
00:00:11.890 --> 00:00:16.320\r
Caption text #3\r
- '''))
- vtt = webvtt.read_buffer(buffer)
- self.assertEqual(len(vtt.captions), 3)
+ ''')
+ )
+
+ self.assertEqual(
+ len(webvtt.from_buffer(buffer).captions),
+ 3
+ )
def test_read_malformed_buffer(self):
- malformed_payloads = ['', 'MOCK MELFORMED CONTENT']
+ malformed_payloads = ['', 'MOCK MALFORMED CONTENT']
for payload in malformed_payloads:
buffer = io.StringIO(payload)
with self.assertRaises(MalformedFileError):
- webvtt.read_buffer(buffer)
-
+ webvtt.from_buffer(buffer)
def test_captions(self):
- vtt = webvtt.read(self._get_file('sample.vtt'))
- self.assertIsInstance(vtt.captions, list)
-
- def test_captions_prevent_write(self):
- vtt = webvtt.read(self._get_file('sample.vtt'))
- self.assertRaises(
- AttributeError,
- setattr,
- vtt,
- 'captions',
- []
- )
+ captions = webvtt.read(PATH_TO_SAMPLES / 'sample.vtt').captions
+ self.assertIsInstance(
+ captions,
+ list
+ )
+ self.assertEqual(len(captions), 16)
def test_sequence_iteration(self):
- vtt = webvtt.read(self._get_file('sample.vtt'))
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'sample.vtt')
self.assertIsInstance(vtt[0], Caption)
self.assertEqual(len(vtt), len(vtt.captions))
def test_save_no_filename(self):
- vtt = webvtt.WebVTT()
self.assertRaises(
webvtt.errors.MissingFilenameError,
- vtt.save
- )
+ webvtt.WebVTT().save
+ )
- def test_malformed_start_timestamp(self):
- self.assertRaises(
- webvtt.errors.MalformedCaptionError,
- Caption,
- '01:00'
- )
+ def test_save_with_path_to_dir_no_filename(self):
+ with tempfile.TemporaryDirectory() as td:
+ self.assertRaises(
+ webvtt.errors.MissingFilenameError,
+ webvtt.WebVTT().save,
+ td
+ )
def test_set_styles_from_text(self):
- style = Style()
- style.text = '::cue(b) {\n color: peachpuff;\n}'
+ style = Style('::cue(b) {\n color: peachpuff;\n}')
self.assertListEqual(
style.lines,
['::cue(b) {', ' color: peachpuff;', '}']
- )
-
- def test_get_styles_as_text(self):
- style = Style()
- style.lines = ['::cue(b) {', ' color: peachpuff;', '}']
- self.assertEqual(
- style.text,
- '::cue(b) {color: peachpuff;}'
- )
+ )
def test_save_identifiers(self):
- os.makedirs(OUTPUT_DIR)
- copy(self._get_file('using_identifiers.vtt'), OUTPUT_DIR)
-
- vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'using_identifiers.vtt'))
- vtt.save(os.path.join(OUTPUT_DIR, 'new_using_identifiers.vtt'))
-
- with open(os.path.join(OUTPUT_DIR, 'new_using_identifiers.vtt'), 'r', encoding='utf-8') as f:
- lines = [line.rstrip() for line in f.readlines()]
-
- expected_lines = [
- 'WEBVTT',
- '',
- '00:00:00.500 --> 00:00:07.000',
- 'Caption text #1',
- '',
- 'second caption',
- '00:00:07.000 --> 00:00:11.890',
- 'Caption text #2',
- '',
- '00:00:11.890 --> 00:00:16.320',
- 'Caption text #3',
- '',
- '4',
- '00:00:16.320 --> 00:00:21.580',
- 'Caption text #4',
- '',
- '00:00:21.580 --> 00:00:23.880',
- 'Caption text #5',
- '',
- '00:00:23.880 --> 00:00:27.280',
- 'Caption text #6'
- ]
-
- self.assertListEqual(lines, expected_lines)
+ with tempfile.NamedTemporaryFile('w', suffix='.vtt') as f:
+ webvtt.read(
+ PATH_TO_SAMPLES / 'using_identifiers.vtt'
+ ).save(
+ f.name
+ )
+
+ self.assertListEqual(
+ pathlib.Path(f.name).read_text().splitlines(),
+ [
+ 'WEBVTT',
+ '',
+ '00:00:00.500 --> 00:00:07.000',
+ 'Caption text #1',
+ '',
+ 'second caption',
+ '00:00:07.000 --> 00:00:11.890',
+ 'Caption text #2',
+ '',
+ '00:00:11.890 --> 00:00:16.320',
+ 'Caption text #3',
+ '',
+ '4',
+ '00:00:16.320 --> 00:00:21.580',
+ 'Caption text #4',
+ '',
+ '00:00:21.580 --> 00:00:23.880',
+ 'Caption text #5',
+ '',
+ '00:00:23.880 --> 00:00:27.280',
+ 'Caption text #6'
+ ]
+ )
def test_save_updated_identifiers(self):
- os.makedirs(OUTPUT_DIR)
- copy(self._get_file('using_identifiers.vtt'), OUTPUT_DIR)
-
- vtt = webvtt.read(os.path.join(OUTPUT_DIR, 'using_identifiers.vtt'))
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'using_identifiers.vtt')
vtt.captions[0].identifier = 'first caption'
vtt.captions[1].identifier = None
vtt.captions[3].identifier = '44'
- last_caption = Caption('00:00:27.280', '00:00:29.200', 'Caption text #7')
+ last_caption = Caption(start='00:00:27.280',
+ end='00:00:29.200',
+ text='Caption text #7'
+ )
last_caption.identifier = 'last caption'
vtt.captions.append(last_caption)
- vtt.save(os.path.join(OUTPUT_DIR, 'new_using_identifiers.vtt'))
-
- with open(os.path.join(OUTPUT_DIR, 'new_using_identifiers.vtt'), 'r', encoding='utf-8') as f:
- lines = [line.rstrip() for line in f.readlines()]
-
- expected_lines = [
- 'WEBVTT',
- '',
- 'first caption',
- '00:00:00.500 --> 00:00:07.000',
- 'Caption text #1',
- '',
- '00:00:07.000 --> 00:00:11.890',
- 'Caption text #2',
- '',
- '00:00:11.890 --> 00:00:16.320',
- 'Caption text #3',
- '',
- '44',
- '00:00:16.320 --> 00:00:21.580',
- 'Caption text #4',
- '',
- '00:00:21.580 --> 00:00:23.880',
- 'Caption text #5',
- '',
- '00:00:23.880 --> 00:00:27.280',
- 'Caption text #6',
- '',
- 'last caption',
- '00:00:27.280 --> 00:00:29.200',
- 'Caption text #7'
- ]
-
- self.assertListEqual(lines, expected_lines)
+
+ with tempfile.NamedTemporaryFile('w', suffix='.vtt') as f:
+ vtt.save(f.name)
+
+ self.assertListEqual(
+ pathlib.Path(f.name).read_text().splitlines(),
+ [
+ 'WEBVTT',
+ '',
+ 'first caption',
+ '00:00:00.500 --> 00:00:07.000',
+ 'Caption text #1',
+ '',
+ '00:00:07.000 --> 00:00:11.890',
+ 'Caption text #2',
+ '',
+ '00:00:11.890 --> 00:00:16.320',
+ 'Caption text #3',
+ '',
+ '44',
+ '00:00:16.320 --> 00:00:21.580',
+ 'Caption text #4',
+ '',
+ '00:00:21.580 --> 00:00:23.880',
+ 'Caption text #5',
+ '',
+ '00:00:23.880 --> 00:00:27.280',
+ 'Caption text #6',
+ '',
+ 'last caption',
+ '00:00:27.280 --> 00:00:29.200',
+ 'Caption text #7'
+ ]
+ )
def test_content_formatting(self):
"""
Verify that content property returns the correctly formatted webvtt.
"""
captions = [
- Caption('00:00:00.500', '00:00:07.000', ['Caption test line 1', 'Caption test line 2']),
- Caption('00:00:08.000', '00:00:15.000', ['Caption test line 3', 'Caption test line 4']),
- ]
- expected_content = textwrap.dedent("""\
- WEBVTT
+ Caption(start='00:00:00.500',
+ end='00:00:07.000',
+ text=['Caption test line 1', 'Caption test line 2']
+ ),
+ Caption(start='00:00:08.000',
+ end='00:00:15.000',
+ text=['Caption test line 3', 'Caption test line 4']
+ ),
+ ]
- 00:00:00.500 --> 00:00:07.000
- Caption test line 1
- Caption test line 2
-
- 00:00:08.000 --> 00:00:15.000
- Caption test line 3
- Caption test line 4
- """).strip()
- vtt = webvtt.WebVTT(captions=captions)
- self.assertEqual(expected_content, vtt.content)
+ self.assertEqual(
+ webvtt.WebVTT(captions=captions).content,
+ textwrap.dedent("""
+ WEBVTT
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption test line 1
+ Caption test line 2
+
+ 00:00:08.000 --> 00:00:15.000
+ Caption test line 3
+ Caption test line 4
+ """).strip()
+ )
+
+ def test_repr(self):
+ test_file = PATH_TO_SAMPLES / 'sample.vtt'
+ self.assertEqual(
+ repr(webvtt.read(test_file)),
+ f""
+ )
+
+ def test_str(self):
+ self.assertEqual(
+ str(webvtt.read(PATH_TO_SAMPLES / 'sample.vtt')),
+ textwrap.dedent("""
+ 00:00:00.500 00:00:07.000 Caption text #1
+ 00:00:07.000 00:00:11.890 Caption text #2
+ 00:00:11.890 00:00:16.320 Caption text #3
+ 00:00:16.320 00:00:21.580 Caption text #4
+ 00:00:21.580 00:00:23.880 Caption text #5
+ 00:00:23.880 00:00:27.280 Caption text #6
+ 00:00:27.280 00:00:30.280 Caption text #7
+ 00:00:30.280 00:00:36.510 Caption text #8
+ 00:00:36.510 00:00:38.870 Caption text #9
+ 00:00:38.870 00:00:45.000 Caption text #10
+ 00:00:45.000 00:00:47.000 Caption text #11
+ 00:00:47.000 00:00:50.970 Caption text #12
+ 00:00:50.970 00:00:54.440 Caption text #13
+ 00:00:54.440 00:00:58.600 Caption text #14
+ 00:00:58.600 00:01:01.350 Caption text #15
+ 00:01:01.350 00:01:04.300 Caption text #16
+ """).strip()
+ )
+
+ def test_parse_invalid_file(self):
+ self.assertRaises(
+ MalformedFileError,
+ webvtt.read,
+ PATH_TO_SAMPLES / 'invalid.vtt'
+ )
+
+ def test_file_not_found(self):
+ self.assertRaises(
+ FileNotFoundError,
+ webvtt.read,
+ 'nowhere'
+ )
+
+ def test_total_length(self):
+ self.assertEqual(
+ webvtt.read(PATH_TO_SAMPLES / 'sample.vtt').total_length,
+ 64
+ )
+
+ def test_total_length_no_captions(self):
+ self.assertEqual(
+ webvtt.WebVTT().total_length,
+ 0
+ )
+
+ def test_parse_empty_file(self):
+ self.assertRaises(
+ MalformedFileError,
+ webvtt.read,
+ PATH_TO_SAMPLES / 'empty.vtt'
+ )
+
+ def test_parse_invalid_timeframe_line(self):
+ good_captions = len(
+ webvtt.read(PATH_TO_SAMPLES / 'invalid_timeframe.vtt').captions
+ )
+ self.assertEqual(good_captions, 6)
+
+ def test_parse_invalid_timeframe_in_cue_text(self):
+ vtt = webvtt.read(
+ PATH_TO_SAMPLES / 'invalid_timeframe_in_cue_text.vtt'
+ )
+ self.assertEqual(2, len(vtt.captions))
+ self.assertEqual('Caption text #3', vtt.captions[1].text)
+
+ def test_parse_get_caption_data(self):
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'one_caption.vtt')
+ self.assertEqual(vtt.captions[0].start_in_seconds, 0)
+ self.assertEqual(vtt.captions[0].start, '00:00:00.500')
+ self.assertEqual(vtt.captions[0].end_in_seconds, 7)
+ self.assertEqual(vtt.captions[0].end, '00:00:07.000')
+ self.assertEqual(vtt.captions[0].lines[0], 'Caption text #1')
+ self.assertEqual(len(vtt.captions[0].lines), 1)
+
+ def test_caption_without_timeframe(self):
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'missing_timeframe.vtt')
+ self.assertEqual(len(vtt.captions), 6)
+
+ def test_caption_without_cue_text(self):
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'missing_caption_text.vtt')
+ self.assertEqual(len(vtt.captions), 4)
+
+ def test_timestamps_format(self):
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'sample.vtt')
+ self.assertEqual(vtt.captions[2].start, '00:00:11.890')
+ self.assertEqual(vtt.captions[2].end, '00:00:16.320')
+
+ def test_parse_timestamp(self):
+ self.assertEqual(
+ Caption(start='02:03:11.890').start_in_seconds,
+ 7391
+ )
+
+ def test_captions_attribute(self):
+ self.assertListEqual(webvtt.WebVTT().captions, [])
+
+ def test_metadata_headers(self):
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'metadata_headers.vtt')
+ self.assertEqual(len(vtt.captions), 2)
+
+ def test_metadata_headers_multiline(self):
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'metadata_headers_multiline.vtt')
+ self.assertEqual(len(vtt.captions), 2)
+
+ def test_parse_identifiers(self):
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'using_identifiers.vtt')
+ self.assertEqual(len(vtt.captions), 6)
+
+ self.assertEqual(vtt.captions[1].identifier, 'second caption')
+ self.assertEqual(vtt.captions[2].identifier, None)
+ self.assertEqual(vtt.captions[3].identifier, '4')
+
+ def test_parse_comments(self):
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'comments.vtt')
+ self.assertEqual(len(vtt.captions), 3)
+ self.assertListEqual(
+ vtt.captions[0].lines,
+ ['- Ta en kopp varmt te.',
+ '- Det är inte varmt.']
+ )
+ self.assertListEqual(
+ vtt.captions[0].comments,
+ []
+ )
+ self.assertListEqual(
+ vtt.captions[1].comments,
+ []
+ )
+ self.assertEqual(
+ vtt.captions[2].text,
+ '- Ta en kopp'
+ )
+ self.assertListEqual(
+ vtt.captions[2].comments,
+ ['This last line may not translate well.']
+ )
+ self.assertListEqual(
+ vtt.header_comments,
+ ['This translation was done by Kyle so that\n'
+ 'some friends can watch it with their parents.'
+ ]
+ )
+ self.assertListEqual(
+ vtt.footer_comments,
+ ['end of file']
+ )
+
+ def test_parse_styles(self):
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'styles.vtt')
+ self.assertEqual(len(vtt.captions), 1)
+ self.assertEqual(
+ vtt.styles[0].text,
+ '::cue {\n background-image: linear-gradient(to bottom, '
+ 'dimgray, lightgray);\n color: papayawhip;\n}'
+ )
+
+ def test_parse_styles_with_comments(self):
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'styles_with_comments.vtt')
+ self.assertEqual(len(vtt.captions), 1)
+ self.assertEqual(len(vtt.styles), 2)
+ self.assertEqual(
+ vtt.styles[0].comments,
+ []
+ )
+ self.assertEqual(
+ vtt.styles[0].text,
+ '::cue {\n'
+ ' background-image: linear-gradient(to bottom, dimgray, '
+ 'lightgray);\n'
+ ' color: papayawhip;\n'
+ '}'
+ )
+ self.assertEqual(
+ vtt.styles[1].comments,
+ ['This is the second block of styles',
+ 'Multiline comment for the same\nsecond block of styles'
+ ]
+ )
+ self.assertEqual(
+ vtt.styles[1].text,
+ '::cue(b) {\n'
+ ' color: peachpuff;\n'
+ '}'
+ )
+ self.assertListEqual(
+ vtt.header_comments,
+ ['Sample of comments with styles']
+ )
+ self.assertListEqual(
+ vtt.footer_comments,
+ []
+ )
+
+ def test_multiple_comments_everywhere(self):
+ vtt = webvtt.WebVTT.from_string(textwrap.dedent("""
+ WEBVTT
+
+ NOTE Test file
+
+ NOTE this file is testing multiple comments
+ in different places
+
+ STYLE
+ ::cue {
+ background-image: linear-gradient(to bottom, dimgray, lightgray);
+ color: papayawhip;
+ }
+
+ NOTE this style uses nice color
+
+ NOTE check how it looks before proceeding
+
+ STYLE
+ ::cue(b) {
+ color: peachpuff;
+ }
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+
+ NOTE next caption has two lines
+
+ NOTE needs review
+
+ 00:00:07.000 --> 00:00:11.890
+ Caption text #2 line 1
+ Caption text #2 line 2
+
+ 00:00:11.890 --> 00:00:16.320
+ Caption text #3
+
+ 00:00:16.320 --> 00:00:21.580
+ Caption text #4
+
+ NOTE Copyright 2024
+
+ NOTE this is the end of the file
+ """).strip()
+ )
+
+ self.assertListEqual(
+ vtt.header_comments,
+ ['Test file',
+ 'this file is testing multiple comments\nin different places'
+ ]
+ )
+ self.assertListEqual(
+ vtt.footer_comments,
+ ['Copyright 2024',
+ 'this is the end of the file'
+ ]
+ )
+ self.assertListEqual(
+ vtt.captions[0].comments,
+ []
+ )
+ self.assertListEqual(
+ vtt.captions[1].comments,
+ ['next caption has two lines',
+ 'needs review'
+ ]
+ )
+ self.assertListEqual(
+ vtt.captions[2].comments,
+ []
+ )
+ self.assertListEqual(
+ vtt.captions[3].comments,
+ []
+ )
+ self.assertListEqual(
+ vtt.styles[0].comments,
+ []
+ )
+ self.assertListEqual(
+ vtt.styles[1].comments,
+ ['this style uses nice color',
+ 'check how it looks before proceeding'
+ ]
+ )
+
+ def test_comments_in_new_file(self):
+ with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f:
+ vtt = webvtt.WebVTT()
+ vtt.header_comments.append('This is a header comment')
+ vtt.header_comments.append(
+ 'where we can see a\ntwo line comment'
+ )
+ vtt.styles.append(
+ Style('::cue(b) {\n color: peachpuff;\n}')
+ )
+ style = Style('::cue {\n color: papayawhip;\n}')
+ style.comments.append('Another style to test\nthe look and feel')
+ style.comments.append('Please check')
+ vtt.styles.append(style)
+ vtt.captions.append(
+ Caption(start='00:00:00.500',
+ end='00:00:07.000',
+ text='Caption #1',
+ )
+ )
+ caption = Caption(start='00:00:07.000',
+ end='00:00:11.890',
+ text='Caption #2'
+ )
+ caption.comments.append(
+ 'Second caption may be a bit off\nand needs checking'
+ )
+ caption.comments.append('Confirm if it displays correctly')
+ vtt.captions.append(caption)
+ vtt.footer_comments.append('This is a footer comment')
+ vtt.footer_comments.append(
+ 'where we can also see a\ntwo line comment'
+ )
+
+ vtt.save(f.name)
+ self.assertEqual(
+ f.read(),
+ textwrap.dedent('''
+ WEBVTT
+
+ NOTE This is a header comment
+
+ NOTE
+ where we can see a
+ two line comment
+
+ STYLE
+ ::cue(b) {
+ color: peachpuff;
+ }
+
+ NOTE
+ Another style to test
+ the look and feel
+
+ NOTE Please check
+
+ STYLE
+ ::cue {
+ color: papayawhip;
+ }
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption #1
+
+ NOTE
+ Second caption may be a bit off
+ and needs checking
+
+ NOTE Confirm if it displays correctly
+
+ 00:00:07.000 --> 00:00:11.890
+ Caption #2
+
+ NOTE This is a footer comment
+
+ NOTE
+ where we can also see a
+ two line comment
+ ''').strip()
+ )
+
+ def test_clean_cue_tags(self):
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'cue_tags.vtt')
+ self.assertEqual(
+ vtt.captions[1].text,
+ 'Like a big-a pizza pie'
+ )
+ self.assertEqual(
+ vtt.captions[2].text,
+ 'That\'s amore'
+ )
+
+ def test_parse_captions_with_bom(self):
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'captions_with_bom.vtt')
+ self.assertEqual(len(vtt.captions), 4)
+
+ def test_empty_lines_are_not_included_in_result(self):
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'netflix_chicas_del_cable.vtt')
+ self.assertEqual(vtt.captions[0].text, "[Alba] En 1928,")
+ self.assertEqual(
+ vtt.captions[-2].text,
+ "Diez años no son suficientes\npara olvidarte..."
+ )
+
+ def test_can_parse_youtube_dl_files(self):
+ vtt = webvtt.read(PATH_TO_SAMPLES / 'youtube_dl.vtt')
+ self.assertEqual(
+ "this will happen is I'm telling\n ",
+ vtt.captions[2].text
+ )
+
+
+class TestParseSRT(unittest.TestCase):
+
+ def test_parse_empty_file(self):
+ self.assertRaises(
+ webvtt.errors.MalformedFileError,
+ webvtt.from_srt,
+ # We reuse this file as it is empty and serves the purpose.
+ PATH_TO_SAMPLES / 'empty.vtt'
+ )
+
+ def test_invalid_format(self):
+ for i in range(1, 5):
+ self.assertRaises(
+ MalformedFileError,
+ webvtt.from_srt,
+ PATH_TO_SAMPLES / f'invalid_format{i}.srt'
+ )
+
+ def test_total_length(self):
+ self.assertEqual(
+ webvtt.from_srt(PATH_TO_SAMPLES / 'sample.srt').total_length,
+ 23
+ )
+
+ def test_parse_captions(self):
+ self.assertTrue(
+ webvtt.from_srt(PATH_TO_SAMPLES / 'sample.srt').captions
+ )
+
+ def test_missing_timeframe_line(self):
+ self.assertEqual(
+ len(webvtt.from_srt(
+ PATH_TO_SAMPLES / 'missing_timeframe.srt').captions
+ ),
+ 4
+ )
+
+ def test_empty_caption_text(self):
+ self.assertTrue(
+ webvtt.from_srt(
+ PATH_TO_SAMPLES / 'missing_caption_text.srt').captions
+ )
+
+ def test_empty_gets_removed(self):
+ captions = webvtt.from_srt(
+ PATH_TO_SAMPLES / 'missing_caption_text.srt'
+ ).captions
+ self.assertEqual(len(captions), 4)
+
+ def test_invalid_timestamp(self):
+ self.assertEqual(
+ len(webvtt.from_srt(
+ PATH_TO_SAMPLES / 'invalid_timeframe.srt'
+ ).captions),
+ 4
+ )
+
+ def test_timestamps_format(self):
+ vtt = webvtt.from_srt(PATH_TO_SAMPLES / 'sample.srt')
+ self.assertEqual(vtt.captions[2].start, '00:00:11.890')
+ self.assertEqual(vtt.captions[2].end, '00:00:16.320')
+
+ def test_parse_get_caption_data(self):
+ vtt = webvtt.from_srt(PATH_TO_SAMPLES / 'one_caption.srt')
+ self.assertEqual(vtt.captions[0].start_in_seconds, 0)
+ self.assertEqual(vtt.captions[0].start, '00:00:00.500')
+ self.assertEqual(vtt.captions[0].end_in_seconds, 7)
+ self.assertEqual(vtt.captions[0].end, '00:00:07.000')
+ self.assertEqual(vtt.captions[0].lines[0], 'Caption text #1')
+ self.assertEqual(len(vtt.captions[0].lines), 1)
+
+
+class TestParseSBV(unittest.TestCase):
+
+ def test_parse_empty_file(self):
+ self.assertRaises(
+ MalformedFileError,
+ webvtt.from_sbv,
+ # We reuse this file as it is empty and serves the purpose.
+ PATH_TO_SAMPLES / 'empty.vtt'
+ )
+
+ def test_invalid_format(self):
+ self.assertRaises(
+ MalformedFileError,
+ webvtt.from_sbv,
+ PATH_TO_SAMPLES / 'invalid_format.sbv'
+ )
+
+ def test_total_length(self):
+ self.assertEqual(
+ webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv').total_length,
+ 16
+ )
+
+ def test_parse_captions(self):
+ self.assertEqual(
+ len(webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv').captions),
+ 5
+ )
+
+ def test_missing_timeframe_line(self):
+ self.assertEqual(
+ len(webvtt.from_sbv(
+ PATH_TO_SAMPLES / 'missing_timeframe.sbv'
+ ).captions),
+ 4
+ )
+
+ def test_missing_caption_text(self):
+ self.assertTrue(
+ webvtt.from_sbv(
+ PATH_TO_SAMPLES / 'missing_caption_text.sbv'
+ ).captions
+ )
+
+ def test_invalid_timestamp(self):
+ self.assertEqual(
+ len(webvtt.from_sbv(
+ PATH_TO_SAMPLES / 'invalid_timeframe.sbv'
+ ).captions),
+ 4
+ )
+
+ def test_timestamps_format(self):
+ vtt = webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv')
+ self.assertEqual(vtt.captions[1].start, '00:00:11.378')
+ self.assertEqual(vtt.captions[1].end, '00:00:12.305')
+
+ def test_timestamps_in_seconds(self):
+ vtt = webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv')
+ self.assertEqual(vtt.captions[1].start_in_seconds, 11)
+ self.assertEqual(vtt.captions[1].end_in_seconds, 12)
+
+ def test_get_caption_text(self):
+ vtt = webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv')
+ self.assertEqual(vtt.captions[1].text, 'Caption text #2')
+
+ def test_get_caption_text_multiline(self):
+ vtt = webvtt.from_sbv(PATH_TO_SAMPLES / 'sample.sbv')
+ self.assertEqual(
+ vtt.captions[2].text,
+ 'Caption text #3 (line 1)\nCaption text #3 (line 2)'
+ )
+ self.assertListEqual(
+ vtt.captions[2].lines,
+ ['Caption text #3 (line 1)', 'Caption text #3 (line 2)']
+ )
+
+ def test_convert_from_srt_to_vtt_and_back_gives_same_file(self):
+ with tempfile.NamedTemporaryFile('w', suffix='.srt') as f:
+ webvtt.from_srt(
+ PATH_TO_SAMPLES / 'sample.srt'
+ ).save_as_srt(f.name)
+
+ self.assertEqual(
+ pathlib.Path(PATH_TO_SAMPLES / 'sample.srt').read_text(),
+ pathlib.Path(f.name).read_text()
+ )
+
+ def test_save_file_with_bom(self):
+ with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f:
+ webvtt.read(
+ PATH_TO_SAMPLES / 'one_caption.vtt'
+ ).save(f.name, add_bom=True)
+ self.assertEqual(
+ f.read(),
+ textwrap.dedent(f'''
+ {CODEC_BOMS['utf-8'].decode()}WEBVTT
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+ ''').strip()
+ )
+
+ def test_save_file_with_bom_keeps_bom(self):
+ with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f:
+ webvtt.read(
+ PATH_TO_SAMPLES / 'captions_with_bom.vtt'
+ ).save(f.name)
+ self.assertEqual(
+ f.read(),
+ textwrap.dedent(f'''
+ {CODEC_BOMS['utf-8'].decode()}WEBVTT
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+
+ 00:00:07.000 --> 00:00:11.890
+ Caption text #2
+
+ 00:00:11.890 --> 00:00:16.320
+ Caption text #3
+
+ 00:00:16.320 --> 00:00:21.580
+ Caption text #4
+ ''').strip()
+ )
+
+ def test_save_file_with_bom_removes_bom_if_requested(self):
+ with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f:
+ webvtt.read(
+ PATH_TO_SAMPLES / 'captions_with_bom.vtt'
+ ).save(f.name, add_bom=False)
+ self.assertEqual(
+ f.read(),
+ textwrap.dedent(f'''
+ WEBVTT
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+
+ 00:00:07.000 --> 00:00:11.890
+ Caption text #2
+
+ 00:00:11.890 --> 00:00:16.320
+ Caption text #3
+
+ 00:00:16.320 --> 00:00:21.580
+ Caption text #4
+ ''').strip()
+ )
+
+ def test_save_file_with_encoding(self):
+ with tempfile.NamedTemporaryFile('rb', suffix='.vtt') as f:
+ webvtt.read(
+ PATH_TO_SAMPLES / 'one_caption.vtt'
+ ).save(f.name,
+ encoding='utf-32-le'
+ )
+ self.assertEqual(
+ f.read().decode('utf-32-le'),
+ textwrap.dedent('''
+ WEBVTT
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+ ''').strip()
+ )
+
+ def test_save_file_with_encoding_and_bom(self):
+ with tempfile.NamedTemporaryFile('rb', suffix='.vtt') as f:
+ webvtt.read(
+ PATH_TO_SAMPLES / 'one_caption.vtt'
+ ).save(f.name,
+ encoding='utf-32-le',
+ add_bom=True
+ )
+ self.assertEqual(
+ f.read().decode('utf-32-le'),
+ textwrap.dedent(f'''
+ {CODEC_BOMS['utf-32-le'].decode('utf-32-le')}WEBVTT
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+ ''').strip()
+ )
+
+ def test_save_new_file_utf_8_default_encoding_no_bom(self):
+ with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f:
+ vtt = webvtt.WebVTT()
+ vtt.captions.append(
+ Caption(start='00:00:00.500',
+ end='00:00:07.000',
+ text='Caption text #1'
+ )
+ )
+ vtt.save(f.name)
+ self.assertEqual(vtt.encoding, 'utf-8')
+ self.assertEqual(
+ f.read(),
+ textwrap.dedent(f'''
+ WEBVTT
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+ ''').strip()
+ )
+
+ def test_save_new_file_utf_8_default_encoding_with_bom(self):
+ with tempfile.NamedTemporaryFile('r', suffix='.vtt') as f:
+ vtt = webvtt.WebVTT()
+ vtt.captions.append(
+ Caption(start='00:00:00.500',
+ end='00:00:07.000',
+ text='Caption text #1'
+ )
+ )
+ vtt.save(f.name,
+ add_bom=True
+ )
+ self.assertEqual(vtt.encoding, 'utf-8')
+ self.assertEqual(
+ f.read(),
+ textwrap.dedent(f'''
+ {CODEC_BOMS['utf-8'].decode()}WEBVTT
+
+ 00:00:00.500 --> 00:00:07.000
+ Caption text #1
+ ''').strip()
+ )
+
+ def test_iter_slice(self):
+ vtt = webvtt.read(
+ PATH_TO_SAMPLES / 'sample.vtt'
+ )
+ slice_of_captions = vtt.iter_slice(start='00:00:11.000',
+ end='00:00:27.000'
+ )
+ for expected_caption in (vtt.captions[2],
+ vtt.captions[3],
+ vtt.captions[4]
+ ):
+ self.assertIs(expected_caption, next(slice_of_captions))
+
+ with self.assertRaises(StopIteration):
+ next(slice_of_captions)
+
+ def test_iter_slice_no_start_time(self):
+ vtt = webvtt.read(
+ PATH_TO_SAMPLES / 'sample.vtt'
+ )
+ slice_of_captions = vtt.iter_slice(end='00:00:27.000')
+ for expected_caption in (vtt.captions[0],
+ vtt.captions[1],
+ vtt.captions[2],
+ vtt.captions[3],
+ vtt.captions[4]
+ ):
+ self.assertIs(expected_caption, next(slice_of_captions))
+
+ with self.assertRaises(StopIteration):
+ next(slice_of_captions)
+
+ def test_iter_slice_no_end_time(self):
+ vtt = webvtt.read(
+ PATH_TO_SAMPLES / 'sample.vtt'
+ )
+ slice_of_captions = vtt.iter_slice(start='00:00:47.000')
+ for expected_caption in (vtt.captions[11],
+ vtt.captions[12],
+ vtt.captions[13],
+ vtt.captions[14],
+ vtt.captions[15]
+ ):
+ self.assertIs(expected_caption, next(slice_of_captions))
+
+ with self.assertRaises(StopIteration):
+ next(slice_of_captions)
diff --git a/tests/test_webvtt_parser.py b/tests/test_webvtt_parser.py
deleted file mode 100644
index a26741a..0000000
--- a/tests/test_webvtt_parser.py
+++ /dev/null
@@ -1,172 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from .generic import GenericParserTestCase
-
-import webvtt
-from webvtt.parsers import WebVTTParser
-from webvtt.structures import Caption
-from webvtt.errors import MalformedFileError, MalformedCaptionError
-
-
-class WebVTTParserTestCase(GenericParserTestCase):
-
- def test_webvtt_parse_invalid_file(self):
- self.assertRaises(
- MalformedFileError,
- webvtt.read,
- self._get_file('invalid.vtt')
- )
-
- def test_webvtt_captions_not_found(self):
- self.assertRaises(
- FileNotFoundError,
- webvtt.read,
- 'some_file'
- )
-
- def test_webvtt_total_length(self):
- self.assertEqual(
- webvtt.read(self._get_file('sample.vtt')).total_length,
- 64
- )
-
- def test_webvtt_total_length_no_parser(self):
- self.assertEqual(
- webvtt.WebVTT().total_length,
- 0
- )
-
- def test_webvtt__parse_captions(self):
- self.assertTrue(webvtt.read(self._get_file('sample.vtt')).captions)
-
- def test_webvtt_parse_empty_file(self):
- self.assertRaises(
- MalformedFileError,
- webvtt.read,
- self._get_file('empty.vtt')
- )
-
- def test_webvtt_parse_get_captions(self):
- self.assertEqual(
- len(webvtt.read(self._get_file('sample.vtt')).captions),
- 16
- )
-
- def test_webvtt_parse_invalid_timeframe_line(self):
- self.assertRaises(
- MalformedCaptionError,
- webvtt.read,
- self._get_file('invalid_timeframe.vtt')
- )
-
- def test_webvtt_parse_invalid_timeframe_in_cue_text(self):
- vtt = webvtt.read(self._get_file('invalid_timeframe_in_cue_text.vtt'))
- self.assertEqual(4, len(vtt.captions))
- self.assertEqual('', vtt.captions[1].text)
-
- def test_webvtt_parse_get_caption_data(self):
- vtt = webvtt.read(self._get_file('one_caption.vtt'))
- self.assertEqual(vtt.captions[0].start_in_seconds, 0.5)
- self.assertEqual(vtt.captions[0].start, '00:00:00.500')
- self.assertEqual(vtt.captions[0].end_in_seconds, 7)
- self.assertEqual(vtt.captions[0].end, '00:00:07.000')
- self.assertEqual(vtt.captions[0].lines[0], 'Caption text #1')
- self.assertEqual(len(vtt.captions[0].lines), 1)
-
- def test_webvtt_caption_without_timeframe(self):
- self.assertRaises(
- MalformedCaptionError,
- webvtt.read,
- self._get_file('missing_timeframe.vtt')
- )
-
- def test_webvtt_caption_without_cue_text(self):
- vtt = webvtt.read(self._get_file('missing_caption_text.vtt'))
- self.assertEqual(len(vtt.captions), 5)
-
- def test_webvtt_timestamps_format(self):
- vtt = webvtt.read(self._get_file('sample.vtt'))
- self.assertEqual(vtt.captions[2].start, '00:00:11.890')
- self.assertEqual(vtt.captions[2].end, '00:00:16.320')
-
- def test_parse_timestamp(self):
- caption = Caption(start='02:03:11.890')
- self.assertEqual(
- caption.start_in_seconds,
- 7391.89
- )
-
- def test_captions_attribute(self):
- self.assertListEqual([], webvtt.WebVTT().captions)
-
- def test_webvtt_timestamp_format(self):
- self.assertTrue(WebVTTParser()._validate_timeframe_line('00:00:00.000 --> 00:00:00.000'))
- self.assertTrue(WebVTTParser()._validate_timeframe_line('00:00.000 --> 00:00.000'))
-
- def test_metadata_headers(self):
- vtt = webvtt.read(self._get_file('metadata_headers.vtt'))
- self.assertEqual(len(vtt.captions), 2)
-
- def test_metadata_headers_multiline(self):
- vtt = webvtt.read(self._get_file('metadata_headers_multiline.vtt'))
- self.assertEqual(len(vtt.captions), 2)
-
- def test_parse_identifiers(self):
- vtt = webvtt.read(self._get_file('using_identifiers.vtt'))
- self.assertEqual(len(vtt.captions), 6)
-
- self.assertEqual(vtt.captions[1].identifier, 'second caption')
- self.assertEqual(vtt.captions[2].identifier, None)
- self.assertEqual(vtt.captions[3].identifier, '4')
-
- def test_parse_with_comments(self):
- vtt = webvtt.read(self._get_file('comments.vtt'))
- self.assertEqual(len(vtt.captions), 3)
- self.assertListEqual(
- vtt.captions[0].lines,
- ['- Ta en kopp varmt te.',
- '- Det är inte varmt.']
- )
- self.assertEqual(
- vtt.captions[2].text,
- '- Ta en kopp'
- )
-
- def test_parse_styles(self):
- vtt = webvtt.read(self._get_file('styles.vtt'))
- self.assertEqual(len(vtt.captions), 1)
- self.assertEqual(
- vtt.styles[0].text,
- '::cue {background-image: linear-gradient(to bottom, dimgray, lightgray);color: papayawhip;}'
- )
-
- def test_clean_cue_tags(self):
- vtt = webvtt.read(self._get_file('cue_tags.vtt'))
- self.assertEqual(
- vtt.captions[1].text,
- 'Like a big-a pizza pie'
- )
- self.assertEqual(
- vtt.captions[2].text,
- 'That\'s amore'
- )
-
- def test_parse_captions_with_bom(self):
- vtt = webvtt.read(self._get_file('captions_with_bom.vtt'))
- self.assertEqual(len(vtt.captions), 4)
-
- def test_empty_lines_are_not_included_in_result(self):
- vtt = webvtt.read(self._get_file('netflix_chicas_del_cable.vtt'))
- self.assertEqual(vtt.captions[0].text, "[Alba] En 1928,")
- self.assertEqual(
- vtt.captions[-2].text,
- "Diez años no son suficientes\npara olvidarte..."
- )
-
- def test_can_parse_youtube_dl_files(self):
- vtt = webvtt.read(self._get_file('youtube_dl.vtt'))
- self.assertEqual(
- "this will happen is I'm telling",
- vtt.captions[2].text
- )
diff --git a/tox.ini b/tox.ini
deleted file mode 100644
index 06fa4ff..0000000
--- a/tox.ini
+++ /dev/null
@@ -1,18 +0,0 @@
-[tox]
-envlist = py34, py35, py36, py37, py38, py39
-
-[travis]
-python =
- 3.9: py39
- 3.8: py38
- 3.7: py37
- 3.6: py36
- 3.5: py35
- 3.4: py34
-
-[testenv]
-setenv =
- PYTHONPATH = {toxinidir}
-deps = pytest
-commands =
- pytest
diff --git a/webvtt/__init__.py b/webvtt/__init__.py
index b6239f4..33fa037 100644
--- a/webvtt/__init__.py
+++ b/webvtt/__init__.py
@@ -1,15 +1,16 @@
-__version__ = '0.4.6'
+"""Main webvtt package."""
-from .webvtt import *
-from .segmenter import *
-from .structures import *
-from .errors import *
+__version__ = '0.5.0'
-__all__ = webvtt.__all__ + segmenter.__all__ + structures.__all__ + errors.__all__
+from .webvtt import WebVTT
+from . import segmenter
+from .models import Caption, Style # noqa
+
+__all__ = ['WebVTT', 'Caption', 'Style']
read = WebVTT.read
read_buffer = WebVTT.read_buffer
+from_buffer = WebVTT.from_buffer
from_srt = WebVTT.from_srt
from_sbv = WebVTT.from_sbv
-list_formats = WebVTT.list_formats
-segment = WebVTTSegmenter().segment
+segment = segmenter.segment
diff --git a/webvtt/cli.py b/webvtt/cli.py
index ad8ebbf..a50e7d0 100644
--- a/webvtt/cli.py
+++ b/webvtt/cli.py
@@ -1,47 +1,59 @@
-"""
-Usage:
- webvtt segment [--target-duration=SECONDS] [--mpegts=OFFSET] [--output=]
- webvtt -h | --help
- webvtt --version
-
-Options:
- -h --help Show this screen.
- --version Show version.
- --target-duration=SECONDS Target duration of each segment in seconds [default: 10].
- --mpegts=OFFSET Presentation timestamp value [default: 900000].
- --output= Output to directory [default: ./].
-
-Examples:
- webvtt segment captions.vtt --output destination/directory
-"""
-
-from docopt import docopt
-
-from . import WebVTTSegmenter, __version__
-
-
-def main():
- """Main entry point for CLI commands."""
- options = docopt(__doc__, version=__version__)
- if options['segment']:
- segment(
- options[''],
- options['--output'],
- options['--target-duration'],
- options['--mpegts'],
+"""CLI module."""
+
+import argparse
+import typing
+
+from . import segmenter
+
+
+def main(argv: typing.Optional[typing.Sequence] = None):
+ """
+ Segment WebVTT file from command line.
+
+ :param argv: command line arguments
+ """
+ arguments = argparse.ArgumentParser(
+ description='Segment WebVTT files.'
+ )
+ arguments.add_argument(
+ 'command',
+ choices=['segment'],
+ help='command to perform'
+ )
+ arguments.add_argument(
+ 'file',
+ metavar='PATH',
+ help='WebVTT file'
+ )
+ arguments.add_argument(
+ '-o', '--output',
+ metavar='PATH',
+ help='output directory'
+ )
+ arguments.add_argument(
+ '-d', '--target-duration',
+ metavar='NUMBER',
+ type=int,
+ default=segmenter.DEFAULT_SECONDS,
+ help='target duration of each segment in seconds, default: 10'
+ )
+ arguments.add_argument(
+ '-m', '--mpegts',
+ metavar='NUMBER',
+ type=int,
+ default=segmenter.DEFAULT_MPEGTS,
+ help='presentation timestamp value, default: 900000'
)
+ args = arguments.parse_args(argv)
-def segment(f, output, target_duration, mpegts):
- """Segment command."""
- try:
- target_duration = int(target_duration)
- except ValueError:
- exit('Error: Invalid target duration.')
+ segmenter.segment(
+ args.file,
+ args.output,
+ args.target_duration,
+ args.mpegts
+ )
- try:
- mpegts = int(mpegts)
- except ValueError:
- exit('Error: Invalid MPEGTS value.')
- WebVTTSegmenter().segment(f, output, target_duration, mpegts)
\ No newline at end of file
+if __name__ == '__main__':
+ main() # pragma: no cover
diff --git a/webvtt/errors.py b/webvtt/errors.py
index 9ee549b..f8628c9 100644
--- a/webvtt/errors.py
+++ b/webvtt/errors.py
@@ -1,18 +1,13 @@
-
-__all__ = ['MalformedFileError', 'MalformedCaptionError', 'InvalidCaptionsError', 'MissingFilenameError']
+"""Errors module."""
class MalformedFileError(Exception):
- """Error raised when the file is not well formatted"""
+ """File is not in the right format."""
class MalformedCaptionError(Exception):
- """Error raised when a caption is not well formatted"""
-
-
-class InvalidCaptionsError(Exception):
- """Error raised when passing wrong captions to the segmenter"""
+ """Caption not in the right format."""
class MissingFilenameError(Exception):
- """Error raised when saving a file without filename."""
+ """Missing a filename when saving to disk."""
diff --git a/webvtt/models.py b/webvtt/models.py
new file mode 100644
index 0000000..31be12f
--- /dev/null
+++ b/webvtt/models.py
@@ -0,0 +1,222 @@
+"""Models module."""
+
+import re
+import typing
+
+from .errors import MalformedCaptionError
+
+
+class Timestamp:
+ """Representation of a timestamp."""
+
+ PATTERN = re.compile(r'(?:(\d{1,2}):)?(\d{1,2}):(\d{1,2})\.(\d{3})')
+
+ def __init__(
+ self,
+ hours: int = 0,
+ minutes: int = 0,
+ seconds: int = 0,
+ milliseconds: int = 0
+ ):
+ """Initialize."""
+ self.hours = hours
+ self.minutes = minutes
+ self.seconds = seconds
+ self.milliseconds = milliseconds
+
+ def __str__(self):
+ """Return the string representation of the timestamp."""
+ return (
+ f'{self.hours:02d}:{self.minutes:02d}:{self.seconds:02d}'
+ f'.{self.milliseconds:03d}'
+ )
+
+ def to_tuple(self) -> typing.Tuple[int, int, int, int]:
+ """Return the timestamp in tuple form."""
+ return self.hours, self.minutes, self.seconds, self.milliseconds
+
+ def __repr__(self):
+ """Return the string representation of the caption."""
+ return (f'<{self.__class__.__name__} '
+ f'hours={self.hours} '
+ f'minutes={self.minutes} '
+ f'seconds={self.seconds} '
+ f'milliseconds={self.milliseconds}>'
+ )
+
+ def __eq__(self, other):
+ """Compare equality with other object."""
+ return self.to_tuple() == other.to_tuple()
+
+ def __ne__(self, other):
+ """Compare a not equality with other object."""
+ return self.to_tuple() != other.to_tuple()
+
+ def __gt__(self, other):
+ """Compare greater than with other object."""
+ return self.to_tuple() > other.to_tuple()
+
+ def __lt__(self, other):
+ """Compare less than with other object."""
+ return self.to_tuple() < other.to_tuple()
+
+ def __ge__(self, other):
+ """Compare greater or equal with other object."""
+ return self.to_tuple() >= other.to_tuple()
+
+ def __le__(self, other):
+ """Compare less or equal with other object."""
+ return self.to_tuple() <= other.to_tuple()
+
+ @classmethod
+ def from_string(cls, value: str) -> 'Timestamp':
+ """Return a `Timestamp` instance from a string value."""
+ if type(value) is not str:
+ raise MalformedCaptionError(f'Invalid timestamp {value!r}')
+
+ match = re.match(cls.PATTERN, value)
+ if not match:
+ raise MalformedCaptionError(f'Invalid timestamp {value!r}')
+
+ hours = int(match.group(1) or 0)
+ minutes = int(match.group(2))
+ seconds = int(match.group(3))
+ milliseconds = int(match.group(4))
+
+ if minutes > 59 or seconds > 59:
+ raise MalformedCaptionError(f'Invalid timestamp {value!r}')
+
+ return cls(hours, minutes, seconds, milliseconds)
+
+ def in_seconds(self) -> int:
+ """Return the timestamp in seconds."""
+ return (self.hours * 3600 +
+ self.minutes * 60 +
+ self.seconds
+ )
+
+
+class Caption:
+ """Representation of a caption."""
+
+ CUE_TEXT_TAGS = re.compile('<.*?>')
+
+ def __init__(self,
+ start: typing.Optional[str] = None,
+ end: typing.Optional[str] = None,
+ text: typing.Optional[typing.Union[str,
+ typing.Sequence[str]
+ ]] = None,
+ identifier: typing.Optional[str] = None
+ ):
+ """
+ Initialize.
+
+ :param start: start time of the caption
+ :param end: end time of the caption
+ :param text: the text of the caption
+ :param identifier: optional identifier
+ """
+ text = text or []
+ self.start = start or '00:00:00.000'
+ self.end = end or '00:00:00.000'
+ self.identifier = identifier
+ self.lines = (text.splitlines()
+ if isinstance(text, str)
+ else
+ list(text)
+ )
+ self.comments: typing.List[str] = []
+
+ def __repr__(self):
+ """Return the string representation of the caption."""
+ cleaned_text = self.text.replace('\n', '\\n')
+ return (f'<{self.__class__.__name__} '
+ f'start={self.start!r} '
+ f'end={self.end!r} '
+ f'text={cleaned_text!r} '
+ f'identifier={self.identifier!r}>'
+ )
+
+ def __str__(self):
+ """Return a readable representation of the caption."""
+ cleaned_text = self.text.replace('\n', '\\n')
+ return f'{self.start} {self.end} {cleaned_text}'
+
+ def __eq__(self, other):
+ """Compare equality with another object."""
+ if not isinstance(other, type(self)):
+ return False
+
+ return (self.start == other.start and
+ self.end == other.end and
+ self.raw_text == other.raw_text and
+ self.identifier == other.identifier
+ )
+
+ @property
+ def start(self):
+ """Return the start time of the caption."""
+ return str(self.start_time)
+
+ @start.setter
+ def start(self, value: str):
+ """Set the start time of the caption."""
+ self.start_time = Timestamp.from_string(value)
+
+ @property
+ def end(self):
+ """Return the end time of the caption."""
+ return str(self.end_time)
+
+ @end.setter
+ def end(self, value: str):
+ """Set the end time of the caption."""
+ self.end_time = Timestamp.from_string(value)
+
+ @property
+ def start_in_seconds(self) -> int:
+ """Return the start time of the caption in seconds."""
+ return self.start_time.in_seconds()
+
+ @property
+ def end_in_seconds(self):
+ """Return the end time of the caption in seconds."""
+ return self.end_time.in_seconds()
+
+ @property
+ def raw_text(self) -> str:
+ """Return the text of the caption (including cue tags)."""
+ return '\n'.join(self.lines)
+
+ @property
+ def text(self) -> str:
+ """Return the text of the caption (without cue tags)."""
+ return re.sub(self.CUE_TEXT_TAGS, '', self.raw_text)
+
+ @text.setter
+ def text(self, value: str):
+ """Set the text of the captions."""
+ if not isinstance(value, str):
+ raise AttributeError(
+ f'String value expected but received {value}.'
+ )
+
+ self.lines = value.splitlines()
+
+
+class Style:
+ """Representation of a style."""
+
+ def __init__(self, text: typing.Union[str, typing.List[str]]):
+ """Initialize.
+
+ :param: text: the style text
+ """
+ self.lines = text.splitlines() if isinstance(text, str) else text
+ self.comments: typing.List[str] = []
+
+ @property
+ def text(self):
+ """Return the text of the style."""
+ return '\n'.join(self.lines)
diff --git a/webvtt/parsers.py b/webvtt/parsers.py
deleted file mode 100644
index 3a978ca..0000000
--- a/webvtt/parsers.py
+++ /dev/null
@@ -1,293 +0,0 @@
-import re
-import os
-import codecs
-
-from .errors import MalformedFileError, MalformedCaptionError
-from .structures import Block, Style, Caption
-
-
-class TextBasedParser(object):
- """
- Parser for plain text caption files.
- This is a generic class, do not use directly.
- """
-
- TIMEFRAME_LINE_PATTERN = ''
- PARSER_OPTIONS = {}
-
- def __init__(self, parse_options=None):
- self.captions = []
- self.parse_options = parse_options or {}
-
- def read(self, file):
- """Reads the captions file."""
- content = self._get_content_from_file(file_path=file)
- self._validate(content)
- self._parse(content)
-
- return self
-
- def read_from_buffer(self, buffer):
- content = self._read_content_lines(buffer)
- self._validate(content)
- self._parse(content)
-
- return self
-
- def _get_content_from_file(self, file_path):
- encoding = self._read_file_encoding(file_path)
- with open(file_path, encoding=encoding) as f:
- return self._read_content_lines(f)
-
- def _read_file_encoding(self, file_path):
- first_bytes = min(32, os.path.getsize(file_path))
- with open(file_path, 'rb') as f:
- raw = f.read(first_bytes)
-
- if raw.startswith(codecs.BOM_UTF8):
- return 'utf-8-sig'
- else:
- return 'utf-8'
-
- def _read_content_lines(self, file_obj):
-
- lines = [line.rstrip('\n\r') for line in file_obj.readlines()]
-
- if not lines:
- raise MalformedFileError('The file is empty.')
-
- return lines
-
- def _read_content(self, file):
- return self._get_content_from_file(file_path=file)
-
- def _parse_timeframe_line(self, line):
- """Parse timeframe line and return start and end timestamps."""
- tf = self._validate_timeframe_line(line)
- if not tf:
- raise MalformedCaptionError('Invalid time format')
-
- return tf.group(1), tf.group(2)
-
- def _validate_timeframe_line(self, line):
- return re.match(self.TIMEFRAME_LINE_PATTERN, line)
-
- def _is_timeframe_line(self, line):
- """
- This method returns True if the line contains the timeframes.
- To be implemented by child classes.
- """
- raise NotImplementedError
-
- def _validate(self, lines):
- """
- Validates the format of the parsed file.
- To be implemented by child classes.
- """
- raise NotImplementedError
-
- def _should_skip_line(self, line, index, caption):
- """
- This method returns True for a line that should be skipped.
- Implement in child classes if needed.
- """
- return False
-
- def _parse(self, lines):
- self.captions = []
- c = None
-
- for index, line in enumerate(lines):
- if self._is_timeframe_line(line):
- try:
- start, end = self._parse_timeframe_line(line)
- except MalformedCaptionError as e:
- raise MalformedCaptionError('{} in line {}'.format(e, index + 1))
- c = Caption(start, end)
- elif self._should_skip_line(line, index, c): # allow child classes to skip lines based on the content
- continue
- elif line:
- if c is None:
- raise MalformedCaptionError(
- 'Caption missing timeframe in line {}.'.format(index + 1))
- else:
- c.add_line(line)
- else:
- if c is None:
- continue
- if not c.lines:
- if self.PARSER_OPTIONS.get('ignore_empty_captions', False):
- c = None
- continue
- raise MalformedCaptionError('Caption missing text in line {}.'.format(index + 1))
-
- self.captions.append(c)
- c = None
-
- if c is not None and c.lines:
- self.captions.append(c)
-
-
-class SRTParser(TextBasedParser):
- """
- SRT parser.
- """
-
- TIMEFRAME_LINE_PATTERN = re.compile(r'\s*(\d+:\d{2}:\d{2},\d{3})\s*-->\s*(\d+:\d{2}:\d{2},\d{3})')
-
- PARSER_OPTIONS = {
- 'ignore_empty_captions': True
- }
-
- def _validate(self, lines):
- if len(lines) < 2 or lines[0] != '1' or not self._validate_timeframe_line(lines[1]):
- raise MalformedFileError('The file does not have a valid format.')
-
- def _is_timeframe_line(self, line):
- return '-->' in line
-
- def _should_skip_line(self, line, index, caption):
- return caption is None and line.isdigit()
-
-
-class WebVTTParser(TextBasedParser):
- """
- WebVTT parser.
- """
-
- TIMEFRAME_LINE_PATTERN = re.compile(r'\s*((?:\d+:)?\d{2}:\d{2}.\d{3})\s*-->\s*((?:\d+:)?\d{2}:\d{2}.\d{3})')
- COMMENT_PATTERN = re.compile(r'NOTE(?:\s.+|$)')
- STYLE_PATTERN = re.compile(r'STYLE[ \t]*$')
-
- def __init__(self):
- super().__init__()
- self.styles = []
-
- def _compute_blocks(self, lines):
- blocks = []
-
- for index, line in enumerate(lines, start=1):
- if line:
- if not blocks:
- blocks.append(Block(index))
- if not blocks[-1].lines:
- if not line.strip():
- continue
- blocks[-1].line_number = index
- blocks[-1].lines.append(line)
- else:
- blocks.append(Block(index))
-
- # filter out empty blocks and skip signature
- return list(filter(lambda x: x.lines, blocks))[1:]
-
- def _parse_cue_block(self, block):
- caption = Caption()
- cue_timings = None
- additional_blocks = None
-
- for line_number, line in enumerate(block.lines):
- if self._is_cue_timings_line(line):
- if cue_timings is None:
- try:
- cue_timings = self._parse_timeframe_line(line)
- except MalformedCaptionError as e:
- raise MalformedCaptionError(
- '{} in line {}'.format(e, block.line_number + line_number))
- else:
- additional_blocks = self._compute_blocks(
- ['WEBVTT', ''] + block.lines[line_number:]
- )
- break
- elif line_number == 0:
- caption.identifier = line
- else:
- caption.add_line(line)
-
- caption.start = cue_timings[0]
- caption.end = cue_timings[1]
- return caption, additional_blocks
-
- def _parse(self, lines):
- self.captions = []
- blocks = self._compute_blocks(lines)
- self._parse_blocks(blocks)
-
- def _is_empty(self, block):
- is_empty = True
-
- for line in block.lines:
- if line.strip() != "":
- is_empty = False
-
- return is_empty
-
- def _parse_blocks(self, blocks):
- for block in blocks:
- # skip empty blocks
- if self._is_empty(block):
- continue
-
- if self._is_cue_block(block):
- caption, additional_blocks = self._parse_cue_block(block)
- self.captions.append(caption)
-
- if additional_blocks:
- self._parse_blocks(additional_blocks)
-
- elif self._is_comment_block(block):
- continue
- elif self._is_style_block(block):
- if self.captions:
- raise MalformedFileError(
- 'Style block defined after the first cue in line {}.'
- .format(block.line_number))
- style = Style()
- style.lines = block.lines[1:]
- self.styles.append(style)
- else:
- if len(block.lines) == 1:
- raise MalformedCaptionError(
- 'Standalone cue identifier in line {}.'.format(block.line_number))
- else:
- raise MalformedCaptionError(
- 'Missing timing cue in line {}.'.format(block.line_number+1))
-
- def _validate(self, lines):
- if not re.match('WEBVTT', lines[0]):
- raise MalformedFileError('The file does not have a valid format')
-
- def _is_cue_timings_line(self, line):
- return '-->' in line
-
- def _is_cue_block(self, block):
- """Returns True if it is a cue block
- (one of the two first lines being a cue timing line)"""
- return any(map(self._is_cue_timings_line, block.lines[:2]))
-
- def _is_comment_block(self, block):
- """Returns True if it is a comment block"""
- return re.match(self.COMMENT_PATTERN, block.lines[0])
-
- def _is_style_block(self, block):
- """Returns True if it is a style block"""
- return re.match(self.STYLE_PATTERN, block.lines[0])
-
-
-class SBVParser(TextBasedParser):
- """
- YouTube SBV parser.
- """
-
- TIMEFRAME_LINE_PATTERN = re.compile(r'\s*(\d+:\d{2}:\d{2}.\d{3}),(\d+:\d{2}:\d{2}.\d{3})')
-
- PARSER_OPTIONS = {
- 'ignore_empty_captions': True
- }
-
- def _validate(self, lines):
- if not self._validate_timeframe_line(lines[0]):
- raise MalformedFileError('The file does not have a valid format')
-
- def _is_timeframe_line(self, line):
- return self._validate_timeframe_line(line)
diff --git a/webvtt/sbv.py b/webvtt/sbv.py
new file mode 100644
index 0000000..be9c1bf
--- /dev/null
+++ b/webvtt/sbv.py
@@ -0,0 +1,117 @@
+"""SBV format module."""
+
+import typing
+import re
+
+from . import utils
+from .models import Caption
+from .errors import MalformedFileError
+
+
+class SBVCueBlock:
+ """Representation of a cue timing block."""
+
+ CUE_TIMINGS_PATTERN = re.compile(
+ r'\s*(\d{1,2}:\d{1,2}:\d{1,2}.\d{3}),(\d{1,2}:\d{1,2}:\d{1,2}.\d{3})'
+ )
+
+ def __init__(
+ self,
+ start: str,
+ end: str,
+ payload: typing.Sequence[str]
+ ):
+ """
+ Initialize.
+
+ :param start: start time
+ :param end: end time
+ :param payload: caption text
+ """
+ self.start = start
+ self.end = end
+ self.payload = payload
+
+ @classmethod
+ def is_valid(
+ cls,
+ lines: typing.Sequence[str]
+ ) -> bool:
+ """
+ Validate the lines for a match of a cue time block.
+
+ :param lines: the lines to be validated
+ :returns: true for a matching cue time block
+ """
+ return bool(
+ len(lines) >= 2 and
+ re.match(cls.CUE_TIMINGS_PATTERN, lines[0]) and
+ lines[1].strip()
+ )
+
+ @classmethod
+ def from_lines(
+ cls,
+ lines: typing.Sequence[str]
+ ) -> 'SBVCueBlock':
+ """
+ Create a `SBVCueBlock` from lines of text.
+
+ :param lines: the lines of text
+ :returns: `SBVCueBlock` instance
+ """
+ match = re.match(cls.CUE_TIMINGS_PATTERN, lines[0])
+ assert match is not None
+
+ payload = lines[1:]
+
+ return cls(match.group(1), match.group(2), payload)
+
+
+def parse(lines: typing.Sequence[str]) -> typing.List[Caption]:
+ """
+ Parse SBV captions from lines of text.
+
+ :param lines: lines of text
+ :returns: list of `Caption` objects
+ """
+ if not _is_valid_content(lines):
+ raise MalformedFileError('Invalid format')
+
+ return _parse_captions(lines)
+
+
+def _is_valid_content(lines: typing.Sequence[str]) -> bool:
+ """
+ Validate lines of text for valid SBV content.
+
+ :param lines: lines of text
+ :returns: true for a valid SBV content
+ """
+ if len(lines) < 2:
+ return False
+
+ first_block = next(utils.iter_blocks_of_lines(lines))
+ return bool(first_block and SBVCueBlock.is_valid(first_block))
+
+
+def _parse_captions(lines: typing.Sequence[str]) -> typing.List[Caption]:
+ """
+ Parse captions from the text.
+
+ :param lines: lines of text
+ :returns: list of `Caption` objects
+ """
+ captions = []
+
+ for block_lines in utils.iter_blocks_of_lines(lines):
+ if not SBVCueBlock.is_valid(block_lines):
+ continue
+
+ cue_block = SBVCueBlock.from_lines(block_lines)
+ captions.append(Caption(cue_block.start,
+ cue_block.end,
+ cue_block.payload
+ ))
+
+ return captions
diff --git a/webvtt/segmenter.py b/webvtt/segmenter.py
index 9378ad6..a159fd8 100644
--- a/webvtt/segmenter.py
+++ b/webvtt/segmenter.py
@@ -1,110 +1,121 @@
+"""Segmenter module."""
+
+import typing
import os
+import pathlib
from math import ceil, floor
-from .errors import InvalidCaptionsError
-from .webvtt import WebVTT
-from .structures import Caption
+from .webvtt import WebVTT, Caption
+
+DEFAULT_MPEGTS = 900000
+DEFAULT_SECONDS = 10 # default number of seconds per segment
+
+
+def segment(
+ webvtt_path: str,
+ output: str,
+ seconds: int = DEFAULT_SECONDS,
+ mpegts: int = DEFAULT_MPEGTS
+ ):
+ """
+ Segment a WebVTT captions file.
+
+ :param webvtt_path: the path to the file
+ :param output: the path to the destination folder
+ :param seconds: the number of seconds for each segment
+ :param mpegts: value for the MPEG-TS
+ """
+ captions = WebVTT.read(webvtt_path).captions
+
+ output_folder = pathlib.Path(output)
+ os.makedirs(output_folder, exist_ok=True)
+
+ segments = slice_segments(captions, seconds)
+ write_segments(output_folder, segments, mpegts)
+ write_manifest(output_folder, segments, seconds)
-MPEGTS = 900000
-SECONDS = 10 # default number of seconds per segment
-__all__ = ['WebVTTSegmenter']
+def slice_segments(
+ captions: typing.Sequence[Caption],
+ seconds: int
+ ) -> typing.List[typing.List[Caption]]:
+ """
+ Slice segments of captions based on seconds per segment.
+ :param captions: the captions
+ :param seconds: seconds per segment
+ :returns: list of lists of `Caption` objects
+ """
+ total_segments = (
+ 0
+ if not captions else
+ int(ceil(captions[-1].end_in_seconds / seconds))
+ )
+
+ segments: typing.List[typing.List[Caption]] = [
+ [] for _ in range(total_segments)
+ ]
+
+ for c in captions:
+ segment_index_start = floor(c.start_in_seconds / seconds)
+ segments[segment_index_start].append(c)
+
+ # Also include a caption in other segments based on the end time.
+ segment_index_end = floor(c.end_in_seconds / seconds)
+ if segment_index_end > segment_index_start:
+ for i in range(segment_index_start + 1, segment_index_end + 1):
+ segments[i].append(c)
+
+ return segments
+
+
+def write_segments(
+ output_folder: pathlib.Path,
+ segments: typing.Iterable[typing.Iterable[Caption]],
+ mpegts: int
+ ):
+ """
+ Write the segments to the output folder.
-class WebVTTSegmenter(object):
+ :param output_folder: folder where the segment files will be stored
+ :param segments: the segments of `Caption` objects
+ :param mpegts: value for the MPEG-TS
"""
- Provides segmentation of WebVTT captions for HTTP Live Streaming (HLS).
+ for index, segment in enumerate(segments):
+ segment_file = output_folder / f'fileSequence{index}.webvtt'
+
+ with open(segment_file, 'w', encoding='utf-8') as f:
+ f.write('WEBVTT\n')
+ f.write(f'X-TIMESTAMP-MAP=MPEGTS:{mpegts},'
+ 'LOCAL:00:00:00.000\n'
+ )
+
+ for caption in segment:
+ f.write('\n{} --> {}\n'.format(caption.start, caption.end))
+ f.writelines(f'{line}\n' for line in caption.lines)
+
+
+def write_manifest(
+ output_folder: pathlib.Path,
+ segments: typing.Iterable[typing.Iterable[Caption]],
+ seconds: int
+ ):
+ """
+ Write the manifest in the output folder.
+
+ :param output_folder: folder where the manifest will be stored
+ :param segments: the segments of `Caption` objects
+ :param seconds: the seconds per segment
"""
- def __init__(self):
- self._total_segments = 0
- self._output_folder = ''
- self._seconds = 0
- self._mpegts = 0
- self._segments = []
-
- def _validate_webvtt(self, webvtt):
- # Validates that the captions is a list and all the captions are instances of Caption.
- if not isinstance(webvtt, WebVTT):
- return False
- for c in webvtt.captions:
- if not isinstance(c, Caption):
- return False
- return True
-
- def _slice_segments(self, captions):
- self._segments = [[] for _ in range(self.total_segments)]
-
- for c in captions:
- segment_index_start = floor(c.start_in_seconds / self.seconds)
- self.segments[segment_index_start].append(c)
-
- # Also include a caption in other segments based on the end time.
- segment_index_end = floor(c.end_in_seconds / self.seconds)
- if segment_index_end > segment_index_start:
- for i in range(segment_index_start + 1, segment_index_end + 1):
- self.segments[i].append(c)
-
- def _write_segments(self):
- for index in range(self.total_segments):
- segment_file = os.path.join(self._output_folder, 'fileSequence{}.webvtt'.format(index))
-
- with open(segment_file, 'w', encoding='utf-8') as f:
- f.write('WEBVTT\n')
- f.write('X-TIMESTAMP-MAP=MPEGTS:{},LOCAL:00:00:00.000\n'.format(self._mpegts))
-
- for caption in self.segments[index]:
- f.write('\n{} --> {}\n'.format(caption.start, caption.end))
- f.writelines(['{}\n'.format(l) for l in caption.lines])
-
- def _write_manifest(self):
- manifest_file = os.path.join(self._output_folder, 'prog_index.m3u8')
- with open(manifest_file, 'w', encoding='utf-8') as f:
- f.write('#EXTM3U\n')
- f.write('#EXT-X-TARGETDURATION:{}\n'.format(self.seconds))
- f.write('#EXT-X-VERSION:3\n')
- f.write('#EXT-X-PLAYLIST-TYPE:VOD\n')
-
- for i in range(self.total_segments):
- f.write('#EXTINF:30.00000\n')
- f.write('fileSequence{}.webvtt\n'.format(i))
-
- f.write('#EXT-X-ENDLIST\n')
-
- def segment(self, webvtt, output='', seconds=SECONDS, mpegts=MPEGTS):
- """Segments the captions based on a number of seconds."""
- if isinstance(webvtt, str):
- # if a string is supplied we parse the file
- captions = WebVTT().read(webvtt).captions
- elif not self._validate_webvtt(webvtt):
- raise InvalidCaptionsError('The captions provided are invalid')
- else:
- # we expect to have a webvtt object
- captions = webvtt.captions
-
- self._total_segments = 0 if not captions else int(ceil(captions[-1].end_in_seconds / seconds))
- self._output_folder = output
- self._seconds = seconds
- self._mpegts = mpegts
-
- output_folder = os.path.join(os.getcwd(), output)
- if not os.path.exists(output_folder):
- os.makedirs(output_folder)
-
- self._slice_segments(captions)
- self._write_segments()
- self._write_manifest()
-
- @property
- def seconds(self):
- """Returns the number of seconds used for segmenting captions."""
- return self._seconds
-
- @property
- def total_segments(self):
- """Returns the total of segments."""
- return self._total_segments
-
- @property
- def segments(self):
- """Return the list of segments."""
- return self._segments
+ manifest_file = output_folder / 'prog_index.m3u8'
+ with open(manifest_file, 'w', encoding='utf-8') as f:
+ f.write('#EXTM3U\n')
+ f.write(f'#EXT-X-TARGETDURATION:{seconds}\n')
+ f.write('#EXT-X-VERSION:3\n')
+ f.write('#EXT-X-PLAYLIST-TYPE:VOD\n')
+
+ for index, _ in enumerate(segments):
+ f.write('#EXTINF:30.00000\n')
+ f.write(f'fileSequence{index}.webvtt\n')
+
+ f.write('#EXT-X-ENDLIST\n')
diff --git a/webvtt/srt.py b/webvtt/srt.py
new file mode 100644
index 0000000..0e0637d
--- /dev/null
+++ b/webvtt/srt.py
@@ -0,0 +1,148 @@
+"""SRT format module."""
+
+import typing
+import re
+
+from .models import Caption
+from .errors import MalformedFileError
+from . import utils
+
+
+class SRTCueBlock:
+ """Representation of a cue timing block."""
+
+ CUE_TIMINGS_PATTERN = re.compile(
+ r'\s*(\d+:\d{2}:\d{2},\d{3})\s*-->\s*(\d+:\d{2}:\d{2},\d{3})'
+ )
+
+ def __init__(
+ self,
+ index: str,
+ start: str,
+ end: str,
+ payload: typing.Sequence[str]
+ ):
+ """
+ Initialize.
+
+ :param start: start time
+ :param end: end time
+ :param payload: caption text
+ """
+ self.index = index
+ self.start = start
+ self.end = end
+ self.payload = payload
+
+ @classmethod
+ def is_valid(
+ cls,
+ lines: typing.Sequence[str]
+ ) -> bool:
+ """
+ Validate the lines for a match of a cue time block.
+
+ :param lines: the lines to be validated
+ :returns: true for a matching cue time block
+ """
+ return bool(
+ len(lines) >= 3 and
+ lines[0].isdigit() and
+ re.match(cls.CUE_TIMINGS_PATTERN, lines[1])
+ )
+
+ @classmethod
+ def from_lines(
+ cls,
+ lines: typing.Sequence[str]
+ ) -> 'SRTCueBlock':
+ """
+ Create a `SRTCueBlock` from lines of text.
+
+ :param lines: the lines of text
+ :returns: `SRTCueBlock` instance
+ """
+ index = lines[0]
+
+ match = re.match(cls.CUE_TIMINGS_PATTERN, lines[1])
+ assert match is not None
+
+ payload = lines[2:]
+
+ return cls(index, match.group(1), match.group(2), payload)
+
+
+def parse(lines: typing.Sequence[str]) -> typing.List[Caption]:
+ """
+ Parse SRT captions from lines of text.
+
+ :param lines: lines of text
+ :returns: list of `Caption` objects
+ """
+ if not is_valid_content(lines):
+ raise MalformedFileError('Invalid format')
+
+ return parse_captions(lines)
+
+
+def is_valid_content(lines: typing.Sequence[str]) -> bool:
+ """
+ Validate lines of text for valid SBV content.
+
+ :param lines: lines of text
+ :returns: true for a valid SBV content
+ """
+ return bool(
+ len(lines) >= 3 and
+ lines[0].isdigit() and
+ '-->' in lines[1] and
+ lines[2].strip()
+ )
+
+
+def parse_captions(lines: typing.Sequence[str]) -> typing.List[Caption]:
+ """
+ Parse captions from the text.
+
+ :param lines: lines of text
+ :returns: list of `Caption` objects
+ """
+ captions: typing.List[Caption] = []
+
+ for block_lines in utils.iter_blocks_of_lines(lines):
+ if not SRTCueBlock.is_valid(block_lines):
+ continue
+
+ cue_block = SRTCueBlock.from_lines(block_lines)
+ cue_block.start, cue_block.end = map(
+ lambda x: x.replace(',', '.'), (cue_block.start, cue_block.end))
+
+ captions.append(Caption(cue_block.start,
+ cue_block.end,
+ cue_block.payload
+ ))
+
+ return captions
+
+
+def write(
+ f: typing.IO[str],
+ captions: typing.Iterable[Caption]
+ ):
+ """
+ Write captions to an output.
+
+ :param f: file or file-like object
+ :param captions: Iterable of `Caption` objects
+ """
+ output = []
+ for index, caption in enumerate(captions, start=1):
+ output.extend([
+ f'{index}',
+ '{} --> {}'.format(*map(lambda x: x.replace('.', ','),
+ (caption.start, caption.end))
+ ),
+ *caption.lines,
+ ''
+ ])
+ f.write('\n'.join(output).rstrip())
diff --git a/webvtt/structures.py b/webvtt/structures.py
deleted file mode 100644
index c4a576d..0000000
--- a/webvtt/structures.py
+++ /dev/null
@@ -1,135 +0,0 @@
-import re
-
-from .errors import MalformedCaptionError
-
-TIMESTAMP_PATTERN = re.compile(r'(\d+)?:?(\d{2}):(\d{2})[.,](\d{3})')
-
-__all__ = ['Caption']
-
-
-class Caption(object):
-
- CUE_TEXT_TAGS = re.compile('<.*?>')
-
- """
- Represents a caption.
- """
- def __init__(self, start='00:00:00.000', end='00:00:00.000', text=None):
- self.start = start
- self.end = end
- self.identifier = None
-
- # If lines is a string convert to a list
- if text and isinstance(text, str):
- text = text.splitlines()
-
- self._lines = text or []
-
- def __repr__(self):
- return '<%(cls)s start=%(start)s end=%(end)s text=%(text)s>' % {
- 'cls': self.__class__.__name__,
- 'start': self.start,
- 'end': self.end,
- 'text': self.text.replace('\n', '\\n')
- }
-
- def __str__(self):
- return '%(start)s %(end)s %(text)s' % {
- 'start': self.start,
- 'end': self.end,
- 'text': self.text.replace('\n', '\\n')
- }
-
- def add_line(self, line):
- self.lines.append(line)
-
- def _to_seconds(self, hours, minutes, seconds, milliseconds):
- return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000
-
- def _parse_timestamp(self, timestamp):
- res = re.match(TIMESTAMP_PATTERN, timestamp)
- if not res:
- raise MalformedCaptionError('Invalid timestamp: {}'.format(timestamp))
-
- values = list(map(lambda x: int(x) if x else 0, res.groups()))
- return self._to_seconds(*values)
-
- def _to_timestamp(self, total_seconds):
- hours = int(total_seconds / 3600)
- minutes = int(total_seconds / 60 - hours * 60)
- seconds = total_seconds - hours * 3600 - minutes * 60
- return '{:02d}:{:02d}:{:06.3f}'.format(hours, minutes, seconds)
-
- def _clean_cue_tags(self, text):
- return re.sub(self.CUE_TEXT_TAGS, '', text)
-
- @property
- def start_in_seconds(self):
- return self._start
-
- @property
- def end_in_seconds(self):
- return self._end
-
- @property
- def start(self):
- return self._to_timestamp(self._start)
-
- @start.setter
- def start(self, value):
- self._start = self._parse_timestamp(value)
-
- @property
- def end(self):
- return self._to_timestamp(self._end)
-
- @end.setter
- def end(self, value):
- self._end = self._parse_timestamp(value)
-
- @property
- def lines(self):
- return self._lines
-
- @property
- def text(self):
- """Returns the captions lines as a text (without cue tags)"""
- return self._clean_cue_tags(self.raw_text)
-
- @property
- def raw_text(self):
- """Returns the captions lines as a text (may include cue tags)"""
- return '\n'.join(self.lines)
-
- @text.setter
- def text(self, value):
- if not isinstance(value, str):
- raise AttributeError('String value expected but received {}.'.format(type(value)))
-
- self._lines = value.splitlines()
-
-
-class GenericBlock(object):
- """Generic class that defines a data structure holding an array of lines"""
- def __init__(self):
- self.lines = []
-
-
-class Block(GenericBlock):
- def __init__(self, line_number):
- super().__init__()
- self.line_number = line_number
-
-
-class Style(GenericBlock):
-
- @property
- def text(self):
- """Returns the style lines as a text"""
- return ''.join(map(lambda x: x.strip(), self.lines))
-
- @text.setter
- def text(self, value):
- if type(value) != str:
- raise TypeError('The text value must be a string.')
- self.lines = value.split('\n')
diff --git a/webvtt/utils.py b/webvtt/utils.py
new file mode 100644
index 0000000..28da8c8
--- /dev/null
+++ b/webvtt/utils.py
@@ -0,0 +1,104 @@
+"""Utils module."""
+
+import typing
+import codecs
+
+CODEC_BOMS = {
+ 'utf-8': codecs.BOM_UTF8,
+ 'utf-32-le': codecs.BOM_UTF32_LE,
+ 'utf-32-be': codecs.BOM_UTF32_BE,
+ 'utf-16-le': codecs.BOM_UTF16_LE,
+ 'utf-16-be': codecs.BOM_UTF16_BE
+}
+
+
+class FileWrapper:
+ """File handling functionality with built-in support for Byte OrderMark."""
+
+ def __init__(
+ self,
+ file_path: str,
+ mode: typing.Optional[str] = None,
+ encoding: typing.Optional[str] = None
+ ):
+ """
+ Initialize.
+
+ :param file_path: path to the file
+ :param mode: mode in which the file is opened
+ :param encoding: name of the encoding used to decode the file
+ """
+ self.file_path = file_path
+ self.mode = mode or 'r'
+ self.bom_encoding = self.detect_bom_encoding(file_path)
+ self.encoding = (self.bom_encoding or
+ encoding or
+ 'utf-8'
+ )
+
+ @classmethod
+ def open(
+ cls,
+ file_path: str,
+ mode: typing.Optional[str] = None,
+ encoding: typing.Optional[str] = None
+ ) -> 'FileWrapper':
+ """
+ Open a file.
+
+ :param file_path: path to the file
+ :param mode: mode in which the file is opened
+ :param encoding: name of the encoding used to decode the file
+ """
+ return cls(file_path, mode, encoding)
+
+ def __enter__(self):
+ """Enter context."""
+ self.file = open(
+ file=self.file_path,
+ mode=self.mode,
+ encoding=self.encoding
+ )
+ if self.bom_encoding:
+ self.file.seek(len(CODEC_BOMS[self.bom_encoding]))
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Exit context."""
+ self.file.close()
+
+ @staticmethod
+ def detect_bom_encoding(file_path: str) -> typing.Optional[str]:
+ """
+ Detect the encoding of a file based on the presence of the BOM.
+
+ :param file_path: path to the file
+ :returns: the encoding if BOM is found or None.
+ """
+ with open(file_path, mode='rb') as f:
+ first_bytes = f.read(4)
+ for encoding, bom in CODEC_BOMS.items():
+ if first_bytes.startswith(bom):
+ return encoding
+ return None
+
+
+def iter_blocks_of_lines(
+ lines: typing.Iterable[str]
+ ) -> typing.Generator[typing.List[str], None, None]:
+ """
+ Iterate blocks of text.
+
+ :param lines: lines of text.
+ """
+ current_text_block = []
+
+ for line in lines:
+ if line.strip():
+ current_text_block.append(line)
+ elif current_text_block:
+ yield current_text_block
+ current_text_block = []
+
+ if current_text_block:
+ yield current_text_block
diff --git a/webvtt/vtt.py b/webvtt/vtt.py
new file mode 100644
index 0000000..459ac3e
--- /dev/null
+++ b/webvtt/vtt.py
@@ -0,0 +1,400 @@
+"""VTT format module."""
+
+import re
+import typing
+from dataclasses import dataclass
+
+from .errors import MalformedFileError
+from .models import Caption, Style
+from . import utils
+
+
+@dataclass
+class ParserOutput:
+ """Output of parser."""
+
+ styles: typing.List[Style]
+ captions: typing.List[Caption]
+ header_comments: typing.List[str]
+ footer_comments: typing.List[str]
+
+ @classmethod
+ def from_data(
+ cls,
+ data: typing.Mapping[str, typing.Any]
+ ) -> 'ParserOutput':
+ """
+ Return a `ParserOutput` instance from the provided data.
+
+ :param data: data from the parser
+ :returns: an instance of `ParserOutput`
+ """
+ items = data.get('items', [])
+ return cls(
+ captions=[it for it in items if isinstance(it, Caption)],
+ styles=[it for it in items if isinstance(it, Style)],
+ header_comments=data.get('header_comments', []),
+ footer_comments=data.get('footer_comments', [])
+ )
+
+
+class WebVTTCueBlock:
+ """Representation of a cue timing block."""
+
+ CUE_TIMINGS_PATTERN = re.compile(
+ r'\s*((?:\d+:)?\d{2}:\d{2}.\d{3})\s*-->\s*((?:\d+:)?\d{2}:\d{2}.\d{3})'
+ )
+
+ def __init__(
+ self,
+ identifier,
+ start,
+ end,
+ payload
+ ):
+ """
+ Initialize.
+
+ :param start: start time
+ :param end: end time
+ :param payload: caption text
+ """
+ self.identifier = identifier
+ self.start = start
+ self.end = end
+ self.payload = payload
+
+ @classmethod
+ def is_valid(
+ cls,
+ lines: typing.Sequence[str]
+ ) -> bool:
+ """
+ Validate the lines for a match of a cue time block.
+
+ :param lines: the lines to be validated
+ :returns: true for a matching cue time block
+ """
+ return bool(
+ (
+ len(lines) >= 2 and
+ re.match(cls.CUE_TIMINGS_PATTERN, lines[0]) and
+ "-->" not in lines[1]
+ ) or
+ (
+ len(lines) >= 3 and
+ "-->" not in lines[0] and
+ re.match(cls.CUE_TIMINGS_PATTERN, lines[1]) and
+ "-->" not in lines[2]
+ )
+ )
+
+ @classmethod
+ def from_lines(
+ cls,
+ lines: typing.Iterable[str]
+ ) -> 'WebVTTCueBlock':
+ """
+ Create a `WebVTTCueBlock` from lines of text.
+
+ :param lines: the lines of text
+ :returns: `WebVTTCueBlock` instance
+ """
+ identifier = None
+ start = None
+ end = None
+ payload = []
+
+ for line in lines:
+ timing_match = re.match(cls.CUE_TIMINGS_PATTERN, line)
+ if timing_match:
+ start = timing_match.group(1)
+ end = timing_match.group(2)
+ elif not start:
+ identifier = line
+ else:
+ payload.append(line)
+
+ return cls(identifier, start, end, payload)
+
+ @staticmethod
+ def format_lines(caption: Caption) -> typing.List[str]:
+ """
+ Return the lines for a cue block.
+
+ :param caption: the `Caption` instance
+ :returns: list of lines for a cue block
+ """
+ return [
+ '',
+ *(identifier for identifier in {caption.identifier} if identifier),
+ f'{caption.start} --> {caption.end}',
+ *caption.lines
+ ]
+
+
+class WebVTTCommentBlock:
+ """Representation of a comment block."""
+
+ COMMENT_PATTERN = re.compile(r'NOTE\s(.*?)\Z', re.DOTALL)
+
+ def __init__(self, text: str):
+ """
+ Initialize.
+
+ :param text: comment text
+ """
+ self.text = text
+
+ @classmethod
+ def is_valid(
+ cls,
+ lines: typing.Sequence[str]
+ ) -> bool:
+ """
+ Validate the lines for a match of a comment block.
+
+ :param lines: the lines to be validated
+ :returns: true for a matching comment block
+ """
+ return bool(lines and lines[0].startswith('NOTE'))
+
+ @classmethod
+ def from_lines(
+ cls,
+ lines: typing.Iterable[str]
+ ) -> 'WebVTTCommentBlock':
+ """
+ Create a `WebVTTCommentBlock` from lines of text.
+
+ :param lines: the lines of text
+ :returns: `WebVTTCommentBlock` instance
+ """
+ match = cls.COMMENT_PATTERN.match('\n'.join(lines))
+ return cls(text=match.group(1).strip() if match else '')
+
+ @staticmethod
+ def format_lines(lines: str) -> typing.List[str]:
+ """
+ Return the lines for a comment block.
+
+ :param lines: comment lines
+ :returns: list of lines for a comment block
+ """
+ list_of_lines = lines.split('\n')
+
+ if len(list_of_lines) == 1:
+ return [f'NOTE {lines}']
+
+ return ['NOTE', *list_of_lines]
+
+
+class WebVTTStyleBlock:
+ """Representation of a style block."""
+
+ STYLE_PATTERN = re.compile(r'STYLE\s(.*?)\Z', re.DOTALL)
+
+ def __init__(self, text: str):
+ """
+ Initialize.
+
+ :param text: style text
+ """
+ self.text = text
+
+ @classmethod
+ def is_valid(
+ cls,
+ lines: typing.Sequence[str]
+ ) -> bool:
+ """
+ Validate the lines for a match of a style block.
+
+ :param lines: the lines to be validated
+ :returns: true for a matching style block
+ """
+ return (len(lines) >= 2 and
+ lines[0] == 'STYLE' and
+ not any(line.strip() == '' or '-->' in line for line in lines)
+ )
+
+ @classmethod
+ def from_lines(
+ cls,
+ lines: typing.Iterable[str]
+ ) -> 'WebVTTStyleBlock':
+ """
+ Create a `WebVTTStyleBlock` from lines of text.
+
+ :param lines: the lines of text
+ :returns: `WebVTTStyleBlock` instance
+ """
+ match = cls.STYLE_PATTERN.match('\n'.join(lines))
+ return cls(text=match.group(1).strip() if match else '')
+
+ @staticmethod
+ def format_lines(lines: typing.List[str]) -> typing.List[str]:
+ """
+ Return the lines for a style block.
+
+ :param lines: style lines
+ :returns: list of lines for a style block
+ """
+ return ['STYLE', *lines]
+
+
+def parse(
+ lines: typing.Sequence[str]
+ ) -> ParserOutput:
+ """
+ Parse VTT captions from lines of text.
+
+ :param lines: lines of text
+ :returns: object `ParserOutput` with all parsed items
+ """
+ if not is_valid_content(lines):
+ raise MalformedFileError('Invalid format')
+
+ return parse_items(lines)
+
+
+def is_valid_content(lines: typing.Sequence[str]) -> bool:
+ """
+ Validate lines of text for valid VTT content.
+
+ :param lines: lines of text
+ :returns: true for a valid VTT content
+ """
+ return bool(lines and lines[0].startswith('WEBVTT'))
+
+
+def parse_items(
+ lines: typing.Sequence[str]
+ ) -> ParserOutput:
+ """
+ Parse items from the text.
+
+ :param lines: lines of text
+ :returns: an object `ParserOutput` with all parsed items
+ """
+ header_comments: typing.List[str] = []
+ items: typing.List[typing.Union[Caption, Style]] = []
+ comments: typing.List[WebVTTCommentBlock] = []
+
+ for block_lines in utils.iter_blocks_of_lines(lines):
+ item = parse_item(block_lines)
+ if item:
+ item.comments = [comment.text for comment in comments]
+ comments = []
+ items.append(item)
+ elif WebVTTCommentBlock.is_valid(block_lines):
+ comments.append(WebVTTCommentBlock.from_lines(block_lines))
+
+ if items:
+ header_comments, items[0].comments = items[0].comments, header_comments
+
+ return ParserOutput.from_data(
+ {'items': items,
+ 'header_comments': header_comments,
+ 'footer_comments': [comment.text for comment in comments]
+ }
+ )
+
+
+def parse_item(
+ lines: typing.Sequence[str]
+ ) -> typing.Union[Caption, Style, None]:
+ """
+ Parse an item from lines of text.
+
+ :param lines: lines of text
+ :returns: An item (Caption or Style) if found, otherwise None
+ """
+ if WebVTTCueBlock.is_valid(lines):
+ cue_block = WebVTTCueBlock.from_lines(lines)
+ return Caption(cue_block.start,
+ cue_block.end,
+ cue_block.payload,
+ cue_block.identifier
+ )
+
+ if WebVTTStyleBlock.is_valid(lines):
+ return Style(WebVTTStyleBlock.from_lines(lines).text)
+
+ return None
+
+
+def write(
+ f: typing.IO[str],
+ captions: typing.Iterable[Caption],
+ styles: typing.Iterable[Style],
+ header_comments: typing.Iterable[str],
+ footer_comments: typing.Iterable[str]
+ ):
+ """
+ Write captions to an output.
+
+ :param f: file or file-like object
+ :param captions: Iterable of `Caption` objects
+ :param styles: Iterable of `Style` objects
+ :param header_comments: the comments for the header
+ :param footer_comments: the comments for the footer
+ """
+ f.write(
+ to_str(captions,
+ styles,
+ header_comments,
+ footer_comments
+ )
+ )
+
+
+def to_str(
+ captions: typing.Iterable[Caption],
+ styles: typing.Iterable[Style],
+ header_comments: typing.Iterable[str],
+ footer_comments: typing.Iterable[str]
+ ) -> str:
+ """
+ Convert captions to a string with webvtt format.
+
+ :param captions: the iterable of `Caption` objects
+ :param styles: the iterable of `Style` objects
+ :param header_comments: the comments for the header
+ :param footer_comments: the comments for the footer
+ :returns: String of the content in WebVTT format.
+ """
+ output = ['WEBVTT']
+
+ for comment in header_comments:
+ output.extend([
+ '',
+ *WebVTTCommentBlock.format_lines(comment)
+ ])
+
+ for style in styles:
+ for comment in style.comments:
+ output.extend([
+ '',
+ *WebVTTCommentBlock.format_lines(comment)
+ ])
+ output.extend([
+ '',
+ *WebVTTStyleBlock.format_lines(style.lines)
+ ])
+
+ for caption in captions:
+ for comment in caption.comments:
+ output.extend([
+ '',
+ *WebVTTCommentBlock.format_lines(comment)
+ ])
+ output.extend(WebVTTCueBlock.format_lines(caption))
+
+ for comment in footer_comments:
+ output.extend([
+ '',
+ *WebVTTCommentBlock.format_lines(comment)
+ ])
+
+ return '\n'.join(output)
diff --git a/webvtt/webvtt.py b/webvtt/webvtt.py
index adec7c9..499b6df 100644
--- a/webvtt/webvtt.py
+++ b/webvtt/webvtt.py
@@ -1,143 +1,363 @@
+"""WebVTT module."""
+
import os
+import typing
+import warnings
-from .parsers import WebVTTParser, SRTParser, SBVParser
-from .writers import WebVTTWriter, SRTWriter
+from . import vtt, utils
+from . import srt
+from . import sbv
+from .models import Caption, Style, Timestamp
from .errors import MissingFilenameError
-__all__ = ['WebVTT']
+DEFAULT_ENCODING = 'utf-8'
-class WebVTT(object):
+class WebVTT:
"""
Parse captions in WebVTT format and also from other formats like SRT.
To read WebVTT:
- WebVTT().read('captions.vtt')
-
- For other formats like SRT, use from_[format in lower case]:
+ WebVTT.read('captions.vtt')
- WebVTT().from_srt('captions.srt')
+ For other formats:
- A list of all supported formats is available calling list_formats().
+ WebVTT.from_srt('captions.srt')
+ WebVTT.from_sbv('captions.sbv')
"""
- def __init__(self, file='', captions=None, styles=None):
+ def __init__(
+ self,
+ file: typing.Optional[str] = None,
+ captions: typing.Optional[typing.List[Caption]] = None,
+ styles: typing.Optional[typing.List[Style]] = None,
+ header_comments: typing.Optional[typing.List[str]] = None,
+ footer_comments: typing.Optional[typing.List[str]] = None,
+ ):
+ """
+ Initialize.
+
+ :param file: the path of the WebVTT file
+ :param captions: the list of captions
+ :param styles: the list of styles
+ :param header_comments: list of comments for the start of the file
+ :param footer_comments: list of comments for the bottom of the file
+ """
self.file = file
- self._captions = captions or []
- self._styles = styles
+ self.captions = captions or []
+ self.styles = styles or []
+ self.header_comments = header_comments or []
+ self.footer_comments = footer_comments or []
+ self._has_bom = False
+ self.encoding = DEFAULT_ENCODING
def __len__(self):
- return len(self._captions)
+ """Return the number of captions."""
+ return len(self.captions)
def __getitem__(self, index):
- return self._captions[index]
+ """Return a caption by index."""
+ return self.captions[index]
def __repr__(self):
- return '<%(cls)s file=%(file)s>' % {
- 'cls': self.__class__.__name__,
- 'file': self.file
- }
+ """Return the string representation of the WebVTT file."""
+ return (f'<{self.__class__.__name__} file={self.file!r} '
+ f'encoding={self.encoding!r}>'
+ )
def __str__(self):
- return '\n'.join([str(c) for c in self._captions])
+ """Return a readable representation of the WebVTT content."""
+ return '\n'.join(str(c) for c in self.captions)
@classmethod
- def from_srt(cls, file):
- """Reads captions from a file in SubRip format."""
- parser = SRTParser().read(file)
- return cls(file=file, captions=parser.captions)
+ def read(
+ cls,
+ file: str,
+ encoding: typing.Optional[str] = None
+ ) -> 'WebVTT':
+ """
+ Read a WebVTT captions file.
- @classmethod
- def from_sbv(cls, file):
- """Reads captions from a file in YouTube SBV format."""
- parser = SBVParser().read(file)
- return cls(file=file, captions=parser.captions)
+ :param file: the file path
+ :param encoding: encoding of the file
+ :returns: a `WebVTT` instance
+ """
+ with utils.FileWrapper.open(file, encoding=encoding) as fw:
+ instance = cls.from_buffer(fw.file)
+ if fw.bom_encoding:
+ instance.encoding = fw.bom_encoding
+ instance._has_bom = True
+ return instance
@classmethod
- def read(cls, file):
- """Reads a WebVTT captions file."""
- parser = WebVTTParser().read(file)
- return cls(file=file, captions=parser.captions, styles=parser.styles)
+ def read_buffer(
+ cls,
+ buffer: typing.Iterator[str]
+ ) -> 'WebVTT':
+ """
+ Read WebVTT captions from a file-like object.
+
+ This method is DEPRECATED. Use from_buffer instead.
+
+ Such file-like object may be the return of an io.open call,
+ io.StringIO object, tempfile.TemporaryFile object, etc.
+
+ :param buffer: the file-like object to read captions from
+ :returns: a `WebVTT` instance
+ """
+ warnings.warn(
+ 'Deprecated: use from_buffer instead.',
+ DeprecationWarning
+ )
+ return cls.from_buffer(buffer)
@classmethod
- def read_buffer(cls, buffer):
- """Reads a WebVTT captions from a file-like object.
+ def from_buffer(
+ cls,
+ buffer: typing.Iterator[str]
+ ) -> 'WebVTT':
+ """
+ Read WebVTT captions from a file-like object.
+
Such file-like object may be the return of an io.open call,
- io.StringIO object, tempfile.TemporaryFile object, etc."""
- parser = WebVTTParser().read_from_buffer(buffer)
- return cls(captions=parser.captions, styles=parser.styles)
+ io.StringIO object, tempfile.TemporaryFile object, etc.
+
+ :param buffer: the file-like object to read captions from
+ :returns: a `WebVTT` instance
+ """
+ output = vtt.parse(cls._get_lines(buffer))
+
+ return cls(
+ file=getattr(buffer, 'name', None),
+ captions=output.captions,
+ styles=output.styles,
+ header_comments=output.header_comments,
+ footer_comments=output.footer_comments
+ )
- def _get_output_file(self, output, extension='vtt'):
- if not output:
+ @classmethod
+ def from_srt(
+ cls,
+ file: str,
+ encoding: typing.Optional[str] = None
+ ) -> 'WebVTT':
+ """
+ Read captions from a file in SubRip format.
+
+ :param file: the file path
+ :param encoding: encoding of the file
+ :returns: a `WebVTT` instance
+ """
+ with utils.FileWrapper.open(file, encoding=encoding) as fw:
+ return cls(
+ file=fw.file.name,
+ captions=srt.parse(cls._get_lines(fw.file))
+ )
+
+ @classmethod
+ def from_sbv(
+ cls,
+ file: str,
+ encoding: typing.Optional[str] = None
+ ) -> 'WebVTT':
+ """
+ Read captions from a file in YouTube SBV format.
+
+ :param file: the file path
+ :param encoding: encoding of the file
+ :returns: a `WebVTT` instance
+ """
+ with utils.FileWrapper.open(file, encoding=encoding) as fw:
+ return cls(
+ file=fw.file.name,
+ captions=sbv.parse(cls._get_lines(fw.file)),
+ )
+
+ @classmethod
+ def from_string(cls, string: str) -> 'WebVTT':
+ """
+ Read captions from a string.
+
+ :param string: the captions in a string
+ :returns: a `WebVTT` instance
+ """
+ output = vtt.parse(cls._get_lines(string.splitlines()))
+ return cls(
+ captions=output.captions,
+ styles=output.styles,
+ header_comments=output.header_comments,
+ footer_comments=output.footer_comments
+ )
+
+ @staticmethod
+ def _get_lines(lines: typing.Iterable[str]) -> typing.List[str]:
+ """
+ Return cleaned lines from an iterable of lines.
+
+ :param lines: iterable of lines
+ :returns: a list of cleaned lines
+ """
+ return [line.rstrip('\n\r') for line in lines]
+
+ def _get_destination_file(
+ self,
+ destination_path: typing.Optional[str] = None,
+ extension: str = 'vtt'
+ ) -> str:
+ """
+ Return the destination file based on the provided params.
+
+ :param destination_path: optional destination path
+ :param extension: the extension of the file
+ :returns: the destination file
+
+ :raises MissingFilenameError: if destination path cannot be determined
+ """
+ if not destination_path and not self.file:
+ raise MissingFilenameError
+
+ if not destination_path and self.file:
+ destination_path = (
+ f'{os.path.splitext(self.file)[0]}.{extension}'
+ )
+
+ assert destination_path is not None
+
+ target = os.path.join(os.getcwd(), destination_path)
+ if os.path.isdir(target):
if not self.file:
raise MissingFilenameError
- # saving an original vtt file will overwrite the file
- # and for files read from other formats will save as vtt
- # with the same name and location
- return os.path.splitext(self.file)[0] + '.' + extension
- else:
- target = os.path.join(os.getcwd(), output)
- if os.path.isdir(target):
- # if an output is provided and it is a directory
- # the file will be saved in that location with the same name
- filename = os.path.splitext(os.path.basename(self.file))[0]
- return os.path.join(target, '{}.{}'.format(filename, extension))
- else:
- if target[-3:].lower() != extension:
- target += '.' + extension
- # otherwise the file will be written in the specified location
- return target
-
- def save(self, output=''):
- """Save the document.
- If no output is provided the file will be saved in the same location. Otherwise output
- can determine a target directory or file.
- """
- self.file = self._get_output_file(output)
- with open(self.file, 'w', encoding='utf-8') as f:
- self.write(f)
-
- def save_as_srt(self, output=''):
- self.file = self._get_output_file(output, extension='srt')
- with open(self.file, 'w', encoding='utf-8') as f:
- self.write(f, format='srt')
-
- def write(self, f, format='vtt'):
+
+ # store the file in specified directory
+ base_name = os.path.splitext(os.path.basename(self.file))[0]
+ new_filename = f'{base_name}.{extension}'
+ return os.path.join(target, new_filename)
+
+ if target[-4:].lower() != f'.{extension}':
+ target = f'{target}.{extension}'
+
+ # store the file in the specified full path
+ return target
+
+ def save(
+ self,
+ output: typing.Optional[str] = None,
+ encoding: typing.Optional[str] = None,
+ add_bom: typing.Optional[bool] = None
+ ):
+ """
+ Save the WebVTT captions to a file.
+
+ :param output: destination path of the file
+ :param encoding: encoding of the file
+ :param add_bom: save the file with Byte Order Mark
+
+ :raises MissingFilenameError: if output cannot be determined
+ """
+ self.file = self._get_destination_file(output)
+ encoding = encoding or self.encoding
+
+ if add_bom is None and self._has_bom:
+ add_bom = True
+
+ with open(self.file, 'w', encoding=encoding) as f:
+ if add_bom and encoding in utils.CODEC_BOMS:
+ f.write(utils.CODEC_BOMS[encoding].decode(encoding))
+
+ vtt.write(
+ f,
+ self.captions,
+ self.styles,
+ self.header_comments,
+ self.footer_comments
+ )
+
+ def save_as_srt(
+ self,
+ output: typing.Optional[str] = None,
+ encoding: typing.Optional[str] = DEFAULT_ENCODING
+ ):
+ """
+ Save the WebVTT captions to a file in SubRip format.
+
+ :param output: destination path of the file
+ :param encoding: encoding of the file
+
+ :raises MissingFilenameError: if output cannot be determined
+ """
+ self.file = self._get_destination_file(output, extension='srt')
+ with open(self.file, 'w', encoding=encoding) as f:
+ srt.write(f, self.captions)
+
+ def write(
+ self,
+ f: typing.IO[str],
+ format: str = 'vtt'
+ ):
+ """
+ Save the WebVTT captions to a file-like object.
+
+ :param f: destination file-like object
+ :param format: the format to use (`vtt` or `srt`)
+
+ :raises MissingFilenameError: if output cannot be determined
+ """
if format == 'vtt':
- WebVTTWriter().write(self._captions, f)
- elif format == 'srt':
- SRTWriter().write(self._captions, f)
-# elif output_format == OutputFormat.SBV:
-# SBVWriter().write(self._captions, f)
+ return vtt.write(f,
+ self.captions,
+ self.styles,
+ self.header_comments,
+ self.footer_comments
+ )
+ if format == 'srt':
+ return srt.write(f, self.captions)
- @staticmethod
- def list_formats():
- """Provides a list of supported formats that this class can read from."""
- return ('WebVTT (.vtt)', 'SubRip (.srt)', 'YouTube SBV (.sbv)')
+ raise ValueError(f'Format {format} is not supported.')
- @property
- def captions(self):
- """Returns the list of captions."""
- return self._captions
+ def iter_slice(
+ self,
+ start: typing.Optional[str] = None,
+ end: typing.Optional[str] = None
+ ) -> typing.Generator[Caption, None, None]:
+ """
+ Iterate a slice of the captions based on a time range.
+
+ :param start: start timestamp of the range
+ :param end: end timestamp of the range
+ :returns: generator of Captions
+ """
+ start_time = Timestamp.from_string(start) if start else None
+ end_time = Timestamp.from_string(end) if end else None
+
+ for caption in self.captions:
+ if (
+ (not start_time or caption.start_time >= start_time) and
+ (not end_time or caption.end_time <= end_time)
+ ):
+ yield caption
@property
def total_length(self):
"""Returns the total length of the captions."""
- if not self._captions:
+ if not self.captions:
return 0
- return int(self._captions[-1].end_in_seconds) - int(self._captions[0].start_in_seconds)
-
- @property
- def styles(self):
- return self._styles
+ return (
+ self.captions[-1].end_in_seconds -
+ self.captions[0].start_in_seconds
+ )
@property
- def content(self):
+ def content(self) -> str:
"""
- Return webvtt content with webvtt formatting.
+ Return the webvtt capions as string.
This property is useful in cases where the webvtt content is needed
- but no file saving on the system is required.
+ but no file-like destination is required. Storage in DB for instance.
"""
- return WebVTTWriter().webvtt_content(self._captions)
+ return vtt.to_str(
+ self.captions,
+ self.styles,
+ self.header_comments,
+ self.footer_comments
+ )
diff --git a/webvtt/writers.py b/webvtt/writers.py
deleted file mode 100644
index 5ec551b..0000000
--- a/webvtt/writers.py
+++ /dev/null
@@ -1,41 +0,0 @@
-
-class WebVTTWriter(object):
-
- def write(self, captions, f):
- f.write(self.webvtt_content(captions))
-
- def webvtt_content(self, captions):
- """
- Return captions content with webvtt formatting.
- """
- output = ["WEBVTT"]
- for caption in captions:
- output.append("")
- if caption.identifier:
- output.append(caption.identifier)
- output.append('{} --> {}'.format(caption.start, caption.end))
- output.extend(caption.lines)
- return '\n'.join(output)
-
-
-class SRTWriter(object):
-
- def write(self, captions, f):
- for line_number, caption in enumerate(captions, start=1):
- f.write('{}\n'.format(line_number))
- f.write('{} --> {}\n'.format(self._to_srt_timestamp(caption.start_in_seconds),
- self._to_srt_timestamp(caption.end_in_seconds)))
- f.writelines(['{}\n'.format(l) for l in caption.lines])
- f.write('\n')
-
- def _to_srt_timestamp(self, total_seconds):
- hours = int(total_seconds / 3600)
- minutes = int(total_seconds / 60 - hours * 60)
- seconds = int(total_seconds - hours * 3600 - minutes * 60)
- milliseconds = round((total_seconds - seconds - hours * 3600 - minutes * 60)*1000)
-
- return '{:02d}:{:02d}:{:02d},{:03d}'.format(hours, minutes, seconds, milliseconds)
-
-
-class SBVWriter(object):
- pass