-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtranscriber.py
178 lines (156 loc) · 6.56 KB
/
transcriber.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import io
from typing import List, Tuple
import pyaudio
import wave
from openai import OpenAI
import asyncio
from threading import Thread, Event
import webrtcvad
from utils import transcription_concat
import tempfile
# Capture settings: 16-bit integer, single-channel PCM sampled at 16 kHz and
# read in fixed 30 ms frames (webrtcvad only accepts 10, 20 or 30 ms frames).
FORMAT = pyaudio.paInt16  # Audio sample format
CHANNELS = 1  # Mono audio
RATE = 16000  # Sample rate in Hz
CHUNK_DURATION_MS = 30  # Frame duration in milliseconds
CHUNK = RATE * CHUNK_DURATION_MS // 1000  # Samples per frame (480 here)

# Minimum amount of buffered audio (ms) that must accumulate before a
# non-speech frame triggers an intermediate transcription request.
# Overridable through the environment.
MIN_TRANSCRIPTION_SIZE_MS = int(
    os.getenv('UTTERTYPE_MIN_TRANSCRIPTION_SIZE_MS', 1500)
)
class AudioTranscriber:
    """Record microphone audio and turn it into text.

    Audio is read in ``CHUNK``-sized frames on a background thread. Whenever a
    non-speech frame (per webrtcvad) arrives while at least
    ``MIN_TRANSCRIPTION_SIZE_MS`` of audio is buffered, the buffered frames are
    sent for transcription on their own thread, so long recordings are
    transcribed incrementally. When the recording stops, the remaining frames
    are transcribed and all pieces are joined, in recording order, with
    ``transcription_concat``. The result is put on ``self.transcriptions`` as
    ``(text, audio_duration_ms)``.

    Subclasses must implement :meth:`transcribe_audio`.
    """

    def __init__(self):
        self.audio = pyaudio.PyAudio()
        self.recording_finished = Event()  # Threading event to end recording
        self.recording_finished.set()  # Initialize as finished
        self._record_thread = None  # Background recording Thread, if any
        self.frames = []  # Raw PCM frames not yet sent for transcription
        self.audio_duration = 0  # Total recorded audio in ms
        self.rolling_transcriptions: List[Tuple[int, str]] = []  # (idx, transcription)
        self.rolling_requests: List[Thread] = []  # list of pending requests
        self.event_loop = asyncio.get_event_loop()
        self.vad = webrtcvad.Vad(1)  # Voice Activity Detector, mode can be 0 to 3
        self.transcriptions = asyncio.Queue()

    def start_recording(self):
        """Start recording audio from the microphone without blocking.

        Call :meth:`stop_recording` to end the recording and queue the
        combined transcription.
        """
        # Create the stop event here, in the caller's thread. If the worker
        # created it, a stop_recording() that ran before the worker started
        # would set the old event, and this recording would never end.
        self.recording_finished = Event()

        def _record():
            stream = self.audio.open(
                format=FORMAT,
                channels=CHANNELS,
                rate=RATE,
                input=True,
                frames_per_buffer=CHUNK,
            )
            intermediate_idx = 0
            try:
                # Keep recording until interrupted.
                while not self.recording_finished.is_set():
                    # Don't raise on input overflow: an uncaught exception
                    # here would end this thread and silently stop recording.
                    data = stream.read(CHUNK, exception_on_overflow=False)
                    self.audio_duration += CHUNK_DURATION_MS
                    is_speech = self.vad.is_speech(data, RATE)
                    current_audio_duration = len(self.frames) * CHUNK_DURATION_MS
                    if (
                        not is_speech
                        and current_audio_duration >= MIN_TRANSCRIPTION_SIZE_MS
                    ):
                        # Pause after enough audio: transcribe what we have so
                        # far in the background and start a new segment.
                        rolling_request = Thread(
                            target=self._intermediate_transcription,
                            args=(intermediate_idx, self._frames_to_wav()),
                        )
                        self.frames = []
                        self.rolling_requests.append(rolling_request)
                        rolling_request.start()
                        intermediate_idx += 1
                    self.frames.append(data)
            finally:
                # Release the audio device even if reading fails.
                stream.stop_stream()
                stream.close()

        # Start recording in a new non-blocking thread.
        self._record_thread = Thread(target=_record)
        self._record_thread.start()

    def stop_recording(self):
        """Stop the recording, queue its final transcription, and reset state."""
        self.recording_finished.set()
        # Wait for the recorder to exit, so it can no longer change `frames`
        # or start new rolling requests while the result is assembled.
        if self._record_thread is not None:
            self._record_thread.join()
            self._record_thread = None
        try:
            self._finish_transcription()
        finally:
            # Always reset, so a failure can't leak audio into the next recording.
            self.frames = []
            self.audio_duration = 0
            self.rolling_requests = []
            self.rolling_transcriptions = []

    def _intermediate_transcription(self, idx, audio):
        """Transcribe one intermediate segment and record it under ``idx``."""
        intermediate_transcription = self.transcribe_audio(audio)
        self.rolling_transcriptions.append((idx, intermediate_transcription))

    def _finish_transcription(self):
        """Transcribe the remaining frames and queue the combined result.

        Intermediate results are merged in recording order, and
        ``(combined_text, audio_duration_ms)`` is put on ``self.transcriptions``
        via the event loop (this may run outside the loop's thread).
        """
        # Last transcription. Skip the call when nothing is buffered.
        transcription = (
            self.transcribe_audio(self._frames_to_wav()) if self.frames else ""
        )
        for request in self.rolling_requests:  # Wait for rolling requests
            request.join()
        # Intermediate segments use indices 0..len(rolling_requests)-1, so the
        # final segment goes last even if some intermediate request failed.
        self.rolling_transcriptions.append((len(self.rolling_requests), transcription))
        # Requests finish in arbitrary order. Sort in place by segment index.
        self.rolling_transcriptions.sort(key=lambda item: item[0])
        transcriptions = [text for _, text in self.rolling_transcriptions]
        self.event_loop.call_soon_threadsafe(  # Put final combined result in finished queue
            self.transcriptions.put_nowait,
            (transcription_concat(transcriptions), self.audio_duration),
        )

    def _frames_to_wav(self):
        """Return an in-memory WAV file (``io.BytesIO``) of the buffered frames.

        The buffer is named ``tmp.wav`` and is left positioned at its end.
        """
        buffer = io.BytesIO()
        buffer.name = "tmp.wav"
        with wave.open(buffer, "wb") as wf:
            wf.setnchannels(CHANNELS)
            wf.setsampwidth(self.audio.get_sample_size(FORMAT))
            wf.setframerate(RATE)
            wf.writeframes(b"".join(self.frames))
        return buffer

    def transcribe_audio(self, audio: io.BytesIO) -> str:
        """Return the transcript of a WAV buffer. Implemented by subclasses."""
        raise NotImplementedError("Please use a subclass of AudioTranscriber")

    async def get_transcriptions(self):
        """
        Asynchronously get transcriptions from the queue.
        Returns (transcription string, audio duration in ms).
        """
        while True:
            transcription = await self.transcriptions.get()
            yield transcription
            self.transcriptions.task_done()
class WhisperAPITranscriber(AudioTranscriber):
    """Transcriber that sends audio to an OpenAI-compatible transcription API."""

    def __init__(self, base_url, model_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_name = model_name
        self.client = OpenAI(base_url=base_url)

    @classmethod
    def create(cls, *args, **kwargs):
        """Build an instance configured from environment variables.

        ``OPENAI_BASE_URL`` (default ``https://api.openai.com/v1``) selects the
        endpoint. ``OPENAI_MODEL_NAME`` (default ``whisper-1``) selects the
        model. Extra arguments are accepted for compatibility and ignored.
        """
        base_url = os.getenv('OPENAI_BASE_URL', 'https://api.openai.com/v1')
        model_name = os.getenv('OPENAI_MODEL_NAME', 'whisper-1')
        # Use `cls`, not a hard-coded class, so subclasses get their own type.
        return cls(base_url, model_name)

    def transcribe_audio(self, audio: io.BytesIO) -> str:
        """Send a WAV buffer to the API and return the transcript text.

        Requests English, plain-text output with an engineering-speech prompt.
        Any error is printed and an empty string is returned, so one failed
        request doesn't abort the whole recording.
        """
        try:
            transcription = self.client.audio.transcriptions.create(
                model=self.model_name,
                file=audio,
                response_format="text",
                language="en",
                prompt="The following is normal speech or technical speech from an engineer.",
            )
            return transcription
        except Exception as e:
            print(f"Encountered Error: {e}")
            return ""
class WhisperLocalMLXTranscriber(AudioTranscriber):
    """Transcriber that runs Whisper locally through ``lightning_whisper_mlx``."""

    def __init__(self, model_type="distil-medium.en", *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Imported here rather than at module level, presumably so the package
        # is only required when this transcriber is actually used.
        from lightning_whisper_mlx import LightningWhisperMLX

        self.model = LightningWhisperMLX(model_type)

    def transcribe_audio(self, audio: io.BytesIO) -> str:
        """Transcribe a WAV buffer with the local model.

        The model is given a file path, so the buffer is written to a
        temporary ``.wav`` file, which is removed afterwards whether or not
        transcription succeeds. Errors are printed and ``""`` is returned.
        """
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmpfile:
                tmp_path = tmpfile.name
                tmpfile.write(audio.getvalue())
            # The file is closed (and so fully flushed) before the model opens
            # it by name.
            return self.model.transcribe(tmp_path)["text"]
        except Exception as e:
            print(f"Encountered Error: {e}")
            return ""
        finally:
            # Remove the temp file even when transcription raised.
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass