From 63fc0d526280424b36777b294862cb246ae23455 Mon Sep 17 00:00:00 2001 From: Alejandro Mendez Date: Thu, 30 May 2024 13:26:13 +0200 Subject: [PATCH] Extend from_buffer support #32 --- CHANGELOG.rst | 3 + docs/source/usage.rst | 35 ++++++++++ tests/test_webvtt.py | 155 ++++++++++++++++++++++++++++++++++++++++-- webvtt/webvtt.py | 39 ++++++++--- 4 files changed, 219 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5d4b573..6ed8a4d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,7 +4,10 @@ History 0.5.1 [Unreleased] ------------------ +* Added voice span support (#55) +* Extended from_buffer support to allow BytesIO and also other format conversions (#32) * Fixed save SRT to not include cue tags, thanks to `@lilaboc `_ (#56) +* Fixed saved caption to include a line break after the last caption as per standard (#49) 0.5.0 (15-05-2024) ------------------ diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 58ad3f8..08c9911 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -13,6 +13,7 @@ Reading WebVTT caption files print(caption.start) # start timestamp in text format print(caption.end) # end timestamp in text format print(caption.text) # caption text + print(caption.voice) # voice span if present # you can also iterate over the lines of a particular caption for line in vtt[0].lines: @@ -48,6 +49,40 @@ Reading WebVTT caption files from a file-like object print(caption.text) +Reading WebVTT caption files from a BytesIO object +-------------------------------------------------- + +.. code-block:: python + + import webvtt + from io import BytesIO + + with open('captions.vtt', 'rb') as f: + buffer = BytesIO(f.read()) + + for caption in webvtt.from_buffer(buffer): + print(caption.start) + print(caption.end) + print(caption.text) + + +Reading caption files in other formats from a BytesIO object +------------------------------------------------------------ + +.. 
code-block:: python + + import webvtt + from io import BytesIO + + with open('captions.srt', 'rb') as f: + buffer = BytesIO(f.read()) + + # formats supported: vtt, srt, sbv + for caption in webvtt.from_buffer(buffer, format='srt'): + print(caption.start) + print(caption.end) + print(caption.text) + Reading WebVTT captions from a string ------------------------------------- diff --git a/tests/test_webvtt.py b/tests/test_webvtt.py index 94ae6c5..2126d78 100644 --- a/tests/test_webvtt.py +++ b/tests/test_webvtt.py @@ -256,10 +256,29 @@ def test_save_specific_filename_no_extension(self): def test_from_buffer(self): with open(PATH_TO_SAMPLES / 'sample.vtt', 'r', encoding='utf-8') as f: - self.assertIsInstance( - webvtt.from_buffer(f).captions, - list - ) + vtt = webvtt.from_buffer(f) + self.assertEqual(len(vtt), 16) + self.assertEqual( + str(vtt), + textwrap.dedent(''' + 00:00:00.500 00:00:07.000 Caption text #1 + 00:00:07.000 00:00:11.890 Caption text #2 + 00:00:11.890 00:00:16.320 Caption text #3 + 00:00:16.320 00:00:21.580 Caption text #4 + 00:00:21.580 00:00:23.880 Caption text #5 + 00:00:23.880 00:00:27.280 Caption text #6 + 00:00:27.280 00:00:30.280 Caption text #7 + 00:00:30.280 00:00:36.510 Caption text #8 + 00:00:36.510 00:00:38.870 Caption text #9 + 00:00:38.870 00:00:45.000 Caption text #10 + 00:00:45.000 00:00:47.000 Caption text #11 + 00:00:47.000 00:00:50.970 Caption text #12 + 00:00:50.970 00:00:54.440 Caption text #13 + 00:00:54.440 00:00:58.600 Caption text #14 + 00:00:58.600 00:01:01.350 Caption text #15 + 00:01:01.350 00:01:04.300 Caption text #16 + ''' + ).strip()) def test_deprecated_read_buffer(self): with open(PATH_TO_SAMPLES / 'sample.vtt', 'r', encoding='utf-8') as f: @@ -311,6 +330,134 @@ def test_read_malformed_buffer(self): with self.assertRaises(MalformedFileError): webvtt.from_buffer(buffer) + def test_read_buffer_for_vtt_content(self): + buffer = io.StringIO(textwrap.dedent('''\ + WEBVTT\r + \r + 00:00:00.500 --> 00:00:07.000\r + Caption 
def test_read_buffer_unsupported_format(self):
    """Asking `from_buffer` for an unknown format must raise ValueError."""
    with self.assertRaises(ValueError):
        webvtt.from_buffer(io.StringIO(), format='ttt')
@classmethod
def from_buffer(
        cls,
        buffer: typing.Union[typing.Iterable[str], typing.Iterable[bytes]],
        format: str = 'vtt'
        ) -> 'WebVTT':
    """
    Read captions from a file-like object.

    The object only needs to yield lines when iterated: an open file,
    an io.StringIO object, an io.BytesIO object, a
    tempfile.TemporaryFile object, etc. Lines yielded as bytes (for
    example by io.BytesIO or by a file opened in binary mode) are
    decoded as UTF-8; lines yielded as text are used unchanged.

    :param buffer: the file-like object to read captions from
    :param format: the format of the data (vtt, srt or sbv)
    :returns: a `WebVTT` instance
    :raises ValueError: if `format` is not vtt, srt or sbv
    """
    # Reject an unknown format up front, before the buffer is touched.
    if format not in ('vtt', 'srt', 'sbv'):
        raise ValueError(f'Format {format} is not supported.')

    # Read the name from the object the caller passed in. It must be
    # taken before the buffer is wrapped below, because the decoding
    # generator has no `name` attribute.
    _cls = partial(cls, file=getattr(buffer, 'name', None))

    # Decode per line rather than per buffer type so that any binary
    # stream is supported, not only io.BytesIO instances.
    lines = cls._get_lines(
        line.decode('utf-8') if isinstance(line, bytes) else line
        for line in buffer
    )

    if format == 'vtt':
        output = vtt.parse(lines)

        return _cls(
            captions=output.captions,
            styles=output.styles,
            header_comments=output.header_comments,
            footer_comments=output.footer_comments
        )

    # Only captions are passed on for SRT and SBV input.
    parse = srt.parse if format == 'srt' else sbv.parse
    return _cls(captions=parse(lines))