Skip to content

Commit

Permalink
Merge pull request #13 from Pikurrot/save-srt
Browse files Browse the repository at this point in the history
Allow saving as .srt for subtitles
  • Loading branch information
Pikurrot authored Feb 10, 2024
2 parents 13135b7 + 0017ef5 commit dee7b70
Showing 1 changed file with 60 additions and 7 deletions.
67 changes: 60 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ def save_alignments_to_json(alignment_dict, save_dir, name="alignments.json"):
json.dump(alignment_dict, f, indent=4)
return json_path

def save_alignments_to_srt(subtitles_dict, save_dir, name="subtitles.srt"):
srt_path = os.path.join(save_dir, name)
print(f"Saving subtitles to {srt_path}...")
with open(srt_path, "w", encoding="utf-8") as f:
for sub in subtitles_dict:
f.write(f"{sub['number']}\n{sub['start']} --> {sub['end']}\n{sub['text']}\n\n")
return srt_path

def load_and_save_audio(audio_path, micro_audio, save_audio, save_dir):
if micro_audio:
print("Saving micro audio...")
Expand Down Expand Up @@ -114,6 +122,42 @@ def format_alignments(alignments):
formatted_transcription.append(formatted_line)
return "\n\n".join(formatted_transcription)

def alignments2subtitles(subtitles, max_line_length=40):
def sec2timesrt(sec):
# Convert seconds to HH:MM:SS,mmm format
hours, remainder = divmod(sec, 3600)
minutes, seconds = divmod(remainder, 60)
milliseconds = int((seconds - int(seconds)) * 1000)
return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02},{milliseconds:03}"

def split_text(text, max_line_length):
# Split text into multiple lines based on max line length
lines = []
while text:
if len(text) <= max_line_length:
lines.append(text)
break
else:
# Find the nearest space before max_line_length
split_at = text.rfind(" ", 0, max_line_length + 1)
if split_at == -1: # No spaces found, force split
split_at = max_line_length
lines.append(text[:split_at])
text = text[split_at+1:] # Skip the space
return "\n".join(lines)

converted_subtitles = []
for index, sub in enumerate(subtitles, start=1):
converted_sub = {
"number": index,
"start": sec2timesrt(sub["start"]),
"end": sec2timesrt(sub["end"]),
"text": split_text(sub["text"], max_line_length)
}
converted_subtitles.append(converted_sub)

return converted_subtitles

def release_whisper():
global g_model, g_params
del g_model
Expand Down Expand Up @@ -168,7 +212,8 @@ def transcribe_whisperx(
save_transcription: bool,
save_alignments: bool,
save_in_subfolder: bool,
preserve_name: bool
preserve_name: bool,
alignments_format: str
) -> Tuple[str, str, str, str]:

print("Inputs received. Starting...")
Expand Down Expand Up @@ -205,7 +250,8 @@ def transcribe_custom(
save_transcription: bool,
save_alignments: bool,
save_in_subfolder: bool,
preserve_name: bool
preserve_name: bool,
alignments_format: str
) -> Tuple[str, str, str, str]:

print("Inputs received. Starting...")
Expand Down Expand Up @@ -282,12 +328,17 @@ def _transcribe() -> Tuple[str, str]:
aligned_result = whisperx.align(result["segments"], g_model_a, g_model_a_metadata, audio, g_params["device"], return_char_alignments=False)
time_align = time.time() - time_align
if g_params["save_alignments"]:
align_format = g_params["alignments_format"].lower()
if g_params["preserve_name"]:
audio_name = os.path.basename(g_params["audio_path"]).split(".")[0]
save_name = f"{audio_name}_timestamps.json"
save_name = f"{audio_name}_timestamps." + align_format
else:
save_name = "_timestamps.json"
save_alignments_to_json(aligned_result, save_dir, save_name)
save_name = "timestamps." + align_format
if align_format == "json":
save_alignments_to_json(aligned_result, save_dir, save_name)
elif align_format == "srt":
subtitles = alignments2subtitles(aligned_result["segments"], max_line_length=50)
save_alignments_to_srt(subtitles, save_dir, save_name)
if g_params["release_memory"]:
release_align()
print("Done!")
Expand Down Expand Up @@ -343,6 +394,7 @@ def _transcribe() -> Tuple[str, str]:
save_root = gr.Textbox(label="Save Path", placeholder="outputs", lines=1)
save_in_subfolder = gr.Checkbox(value=True, label="Save in Subfolder", info="Save files in a subfolder \"YYYY-MM-DD/NNNN/\" in the \"Save Path\" folder. CAUTION: if unchecked, files may be overwritten.")
preserve_name = gr.Checkbox(value=False, label="Preserve Name", info="Preserve the original name of the audio file when saving. E.g. \"<audio_name>_transcription.txt\". Only works for uploaded audio.")
alignments_format = gr.Radio(["JSON", "SRT"], value="JSON", label="Alignments Format", interactive=True)
gr.Markdown("""### Optimizations""")
compute_type_select = gr.Radio(["int8", "float16", "float32"], value = "int8", label="Compute Type", info="int8 is fastest and requires less memory. float32 is more accurate (Your device may not support some data types). "+release_whisper_message)
batch_size_slider = gr.Slider(1, 128, value = 1, step=1, label="Batch Size", info="Larger batch sizes may be faster but require more memory.")
Expand Down Expand Up @@ -378,6 +430,7 @@ def _transcribe() -> Tuple[str, str]:
save_root2 = gr.Textbox(label="Save Path", placeholder="outputs", lines=1)
save_in_subfolder2 = gr.Checkbox(value=True, label="Save in Subfolder", info="Save files in a subfolder \"YYYY-MM-DD/NNNN/\" in the \"Save Path\" folder. CAUTION: if unchecked, files may be overwritten.")
preserve_name2 = gr.Checkbox(value=False, label="Preserve Name", info="Preserve the original name of the audio file when saving. E.g. \"<audio_name>_transcription.txt\". Only works for uploaded audio.")
alignments_format2 = gr.Radio(["JSON", "SRT"], value="JSON", label="Alignments Format", interactive=True)
gr.Markdown("""### Optimizations""")
compute_type_select2 = gr.Radio(["float16", "float32"], value = "float16", label="Compute Type", info="float16 is faster and requires less memory. float32 is more accurate (Your device may not support some data types). "+release_whisper_message)
batch_size_slider2 = gr.Slider(1, 128, value = 1, step=1, label="Batch Size", info="Larger batch sizes may be faster but require more memory.")
Expand All @@ -395,11 +448,11 @@ def _transcribe() -> Tuple[str, str]:


submit_button.click(transcribe_whisperx,
inputs=[model_select, audio_upload, audio_record, device_select, batch_size_slider, compute_type_select, language_select, chunk_size_slider, beam_size_slider, release_memory_checkbox, save_root, save_audio, save_transcription, save_alignments, save_in_subfolder, preserve_name],
inputs=[model_select, audio_upload, audio_record, device_select, batch_size_slider, compute_type_select, language_select, chunk_size_slider, beam_size_slider, release_memory_checkbox, save_root, save_audio, save_transcription, save_alignments, save_in_subfolder, preserve_name, alignments_format],
outputs=[transcription_output, alignments_output, time_transcribe, time_align])

submit_button2.click(transcribe_custom,
inputs=[model_select2, audio_upload2, audio_record2, device_select2, batch_size_slider2, compute_type_select2, language_select2, chunk_size_slider2, beam_size_slider2, release_memory_checkbox2, save_root2, save_audio2, save_transcription2, save_alignments2, save_in_subfolder2, preserve_name2],
inputs=[model_select2, audio_upload2, audio_record2, device_select2, batch_size_slider2, compute_type_select2, language_select2, chunk_size_slider2, beam_size_slider2, release_memory_checkbox2, save_root2, save_audio2, save_transcription2, save_alignments2, save_in_subfolder2, preserve_name2, alignments_format2],
outputs=[transcription_output2, alignments_output2, time_transcribe2, time_align2])

release_memory_button.click(release_memory_models)
Expand Down

0 comments on commit dee7b70

Please sign in to comment.