Skip to content

Commit

Permalink
Extend from_buffer support #32
Browse files Browse the repository at this point in the history
  • Loading branch information
glut23 committed May 30, 2024
1 parent a7dee05 commit 63fc0d5
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 13 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ History
0.5.1 [Unreleased]
------------------

* Added voice span support (#55)
* Extended from_buffer support to allow BytesIO and also other format conversions (#32)
* Fixed save SRT to not include cue tags, thanks to `@lilaboc <https://github.com/lilaboc>`_ (#56)
* Fixed saved caption to include a line break after the last caption as per standard (#49)

0.5.0 (15-05-2024)
------------------
Expand Down
35 changes: 35 additions & 0 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Reading WebVTT caption files
print(caption.start) # start timestamp in text format
print(caption.end) # end timestamp in text format
print(caption.text) # caption text
print(caption.voice) # voice span if present
# you can also iterate over the lines of a particular caption
for line in vtt[0].lines:
Expand Down Expand Up @@ -48,6 +49,40 @@ Reading WebVTT caption files from a file-like object
print(caption.text)
Reading WebVTT caption files from a BytesIO object
--------------------------------------------------

.. code-block:: python
import webvtt
from io import BytesIO
with open('captions.vtt', 'rb') as f:
buffer = BytesIO(f.read())
for caption in webvtt.from_buffer(buffer):
print(caption.start)
print(caption.end)
print(caption.text)
Reading caption files in other formats from a BytesIO object
------------------------------------------------------------

.. code-block:: python
import webvtt
from io import BytesIO
with open('captions.srt', 'rb') as f:
buffer = BytesIO(f.read())
# formats supported: vtt, srt, sbv
for caption in webvtt.from_buffer(buffer, format='srt'):
print(caption.start)
print(caption.end)
print(caption.text)
Reading WebVTT captions from a string
-------------------------------------

Expand Down
155 changes: 151 additions & 4 deletions tests/test_webvtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,29 @@ def test_save_specific_filename_no_extension(self):

def test_from_buffer(self):
with open(PATH_TO_SAMPLES / 'sample.vtt', 'r', encoding='utf-8') as f:
self.assertIsInstance(
webvtt.from_buffer(f).captions,
list
)
vtt = webvtt.from_buffer(f)
self.assertEqual(len(vtt), 16)
self.assertEqual(
str(vtt),
textwrap.dedent('''
00:00:00.500 00:00:07.000 Caption text #1
00:00:07.000 00:00:11.890 Caption text #2
00:00:11.890 00:00:16.320 Caption text #3
00:00:16.320 00:00:21.580 Caption text #4
00:00:21.580 00:00:23.880 Caption text #5
00:00:23.880 00:00:27.280 Caption text #6
00:00:27.280 00:00:30.280 Caption text #7
00:00:30.280 00:00:36.510 Caption text #8
00:00:36.510 00:00:38.870 Caption text #9
00:00:38.870 00:00:45.000 Caption text #10
00:00:45.000 00:00:47.000 Caption text #11
00:00:47.000 00:00:50.970 Caption text #12
00:00:50.970 00:00:54.440 Caption text #13
00:00:54.440 00:00:58.600 Caption text #14
00:00:58.600 00:01:01.350 Caption text #15
00:01:01.350 00:01:04.300 Caption text #16
'''
).strip())

def test_deprecated_read_buffer(self):
with open(PATH_TO_SAMPLES / 'sample.vtt', 'r', encoding='utf-8') as f:
Expand Down Expand Up @@ -311,6 +330,134 @@ def test_read_malformed_buffer(self):
with self.assertRaises(MalformedFileError):
webvtt.from_buffer(buffer)

def test_read_buffer_for_vtt_content(self):
buffer = io.StringIO(textwrap.dedent('''\
WEBVTT\r
\r
00:00:00.500 --> 00:00:07.000\r
Caption text #1\r
\r
00:00:07.000 --> 00:00:11.890\r
Caption text #2\r
\r
00:00:11.890 --> 00:00:16.320\r
Caption text #3\r
''')
)
vtt = webvtt.from_buffer(buffer, format='vtt')
self.assertEqual(len(vtt), 3)

self.assertEqual(
str(vtt[0]),
'00:00:00.500 00:00:07.000 Caption text #1'
)
self.assertEqual(
str(vtt[1]),
'00:00:07.000 00:00:11.890 Caption text #2'
)
self.assertEqual(
str(vtt[2]),
'00:00:11.890 00:00:16.320 Caption text #3'
)

def test_read_buffer_for_srt_content(self):
buffer = io.StringIO(textwrap.dedent('''\
0\r
00:00:00,500 --> 00:00:07,000\r
Caption text #1\r
\r
1\r
00:00:07,000 --> 00:00:11,890\r
Caption text #2\r
\r
2\r
00:00:11,890 --> 00:00:16,320\r
Caption text #3\r
''')
)
vtt = webvtt.from_buffer(buffer, format='srt')
self.assertEqual(len(vtt), 3)

self.assertEqual(
str(vtt[0]),
'00:00:00.500 00:00:07.000 Caption text #1'
)
self.assertEqual(
str(vtt[1]),
'00:00:07.000 00:00:11.890 Caption text #2'
)
self.assertEqual(
str(vtt[2]),
'00:00:11.890 00:00:16.320 Caption text #3'
)

def test_read_buffer_for_sbv_content(self):
buffer = io.StringIO(textwrap.dedent('''\
00:00:00.500,00:00:07.000\r
Caption text #1\r
\r
00:00:07.000,00:00:11.890\r
Caption text #2\r
\r
00:00:11.890,00:00:16.320\r
Caption text #3\r
''')
)
vtt = webvtt.from_buffer(buffer, format='sbv')
self.assertEqual(len(vtt), 3)

self.assertEqual(
str(vtt[0]),
'00:00:00.500 00:00:07.000 Caption text #1'
)
self.assertEqual(
str(vtt[1]),
'00:00:07.000 00:00:11.890 Caption text #2'
)
self.assertEqual(
str(vtt[2]),
'00:00:11.890 00:00:16.320 Caption text #3'
)

def test_read_buffer_unsupported_format(self):
self.assertRaises(
ValueError,
webvtt.from_buffer,
io.StringIO(),
format='ttt'
)

def test_read_bytesio_buffer_for_srt_content(self):
buffer = io.BytesIO(textwrap.dedent('''\
0\r
00:00:00,500 --> 00:00:07,000\r
Caption text #1\r
\r
1\r
00:00:07,000 --> 00:00:11,890\r
Caption text #2\r
\r
2\r
00:00:11,890 --> 00:00:16,320\r
Caption text #3\r
''').encode('utf-8')
)
vtt = webvtt.from_buffer(buffer, format='srt')
self.assertEqual(len(vtt), 3)

self.assertEqual(
str(vtt[0]),
'00:00:00.500 00:00:07.000 Caption text #1'
)
self.assertEqual(
str(vtt[1]),
'00:00:07.000 00:00:11.890 Caption text #2'
)
self.assertEqual(
str(vtt[2]),
'00:00:11.890 00:00:16.320 Caption text #3'
)

def test_captions(self):
captions = webvtt.read(PATH_TO_SAMPLES / 'sample.vtt').captions
self.assertIsInstance(
Expand Down
39 changes: 30 additions & 9 deletions webvtt/webvtt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""WebVTT module."""

import os
import io
import typing
import warnings
from functools import partial

from . import vtt, utils
from . import srt
Expand Down Expand Up @@ -115,7 +117,8 @@ def read_buffer(
@classmethod
def from_buffer(
cls,
buffer: typing.Iterator[str]
buffer: typing.Union[typing.Iterable[str], io.BytesIO],
format: str = 'vtt'
) -> 'WebVTT':
"""
Read WebVTT captions from a file-like object.
Expand All @@ -124,17 +127,35 @@ def from_buffer(
io.StringIO object, tempfile.TemporaryFile object, etc.
:param buffer: the file-like object to read captions from
:param format: the format of the data (vtt, srt or sbv)
:returns: a `WebVTT` instance
"""
output = vtt.parse(cls._get_lines(buffer))
if isinstance(buffer, io.BytesIO):
buffer = (line.decode('utf-8') for line in buffer)

return cls(
file=getattr(buffer, 'name', None),
captions=output.captions,
styles=output.styles,
header_comments=output.header_comments,
footer_comments=output.footer_comments
)
_cls = partial(cls, file=getattr(buffer, 'name', None))

if format == 'vtt':
output = vtt.parse(cls._get_lines(buffer))

return _cls(
captions=output.captions,
styles=output.styles,
header_comments=output.header_comments,
footer_comments=output.footer_comments
)

if format == 'srt':
return _cls(
captions=srt.parse(cls._get_lines(buffer))
)

if format == 'sbv':
return _cls(
captions=sbv.parse(cls._get_lines(buffer))
)

raise ValueError(f'Format {format} is not supported.')

@classmethod
def from_srt(
Expand Down

0 comments on commit 63fc0d5

Please sign in to comment.