From 72270376f347835d48d28c3486e694a8c6fa2aed Mon Sep 17 00:00:00 2001 From: miro Date: Wed, 23 Oct 2024 15:34:38 +0100 Subject: [PATCH] feat:b64 previously it was allowed to "upload" audio as b64 and the transcription was directly injected for utterance handling; this commit adds support for returning the transcript instead of injecting it --- ovos_dinkum_listener/plugins.py | 21 +++++++---- ovos_dinkum_listener/service.py | 62 +++++++++++++++++++++++++-------- 2 files changed, 61 insertions(+), 22 deletions(-) diff --git a/ovos_dinkum_listener/plugins.py b/ovos_dinkum_listener/plugins.py index 3ef83bd..23d0272 100644 --- a/ovos_dinkum_listener/plugins.py +++ b/ovos_dinkum_listener/plugins.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, List, Tuple +from typing import Any, Dict, Optional, List, Tuple, Union from ovos_config.config import Configuration from ovos_plugin_manager.stt import OVOSSTTFactory @@ -56,18 +56,25 @@ def create_streaming_thread(self): return FakeStreamThread(self.queue, self.lang, self.engine, sample_rate, sample_width) - def transcribe(self, audio: Optional = None, + def transcribe(self, audio: Optional[Union[bytes, AudioData]] = None, lang: Optional[str] = None) -> List[Tuple[str, float]]: """transcribe audio data to a list of possible transcriptions and respective confidences""" # plugins expect AudioData objects - audiod = AudioData(audio or self.stream.buffer.read(), - sample_rate=self.stream.sample_rate, - sample_width=self.stream.sample_width) - transcripts = self.engine.transcribe(audiod, lang) if audio is None: + audiod = AudioData(self.stream.buffer.read(), + sample_rate=self.stream.sample_rate, + sample_width=self.stream.sample_width) self.stream.buffer.clear() - return transcripts + elif isinstance(audio, bytes): + audiod = AudioData(audio, + sample_rate=self.stream.sample_rate, + sample_width=self.stream.sample_width) + elif isinstance(audio, AudioData): + audiod = audio + else: + raise ValueError(f"'audio' must be 
'bytes' or 'AudioData', got '{type(audio)}'") + return self.engine.transcribe(audiod, lang) def load_stt_module(config: Dict[str, Any] = None) -> StreamingSTT: diff --git a/ovos_dinkum_listener/service.py b/ovos_dinkum_listener/service.py index a85c7ec..cfe685f 100644 --- a/ovos_dinkum_listener/service.py +++ b/ovos_dinkum_listener/service.py @@ -12,18 +12,18 @@ import base64 import json import subprocess -import time import wave +from distutils.spawn import find_executable from enum import Enum from hashlib import md5 from os.path import dirname from pathlib import Path from tempfile import NamedTemporaryFile from threading import Thread, RLock, Event -from typing import List, Tuple +from typing import List, Tuple, Optional import speech_recognition as sr -from distutils.spawn import find_executable +import time from ovos_bus_client import MessageBusClient from ovos_bus_client.message import Message from ovos_bus_client.session import SessionManager @@ -31,6 +31,8 @@ from ovos_config.locations import get_xdg_data_save_path from ovos_plugin_manager.microphone import OVOSMicrophoneFactory from ovos_plugin_manager.stt import get_stt_lang_configs, get_stt_supported_langs, get_stt_module_configs +from ovos_plugin_manager.templates.stt import STT +from ovos_plugin_manager.templates.vad import VADEngine from ovos_plugin_manager.utils.tts_cache import hash_sentence from ovos_plugin_manager.vad import OVOSVADFactory from ovos_plugin_manager.vad import get_vad_configs @@ -38,11 +40,12 @@ from ovos_utils.log import LOG, log_deprecation from ovos_utils.process_utils import ProcessStatus, StatusCallbackMap, ProcessState +from ovos_dinkum_listener._util import _TemplateFilenameFormatter from ovos_dinkum_listener.plugins import load_stt_module, load_fallback_stt from ovos_dinkum_listener.transformers import AudioTransformersService from ovos_dinkum_listener.voice_loop import DinkumVoiceLoop, ListeningMode, ListeningState from ovos_dinkum_listener.voice_loop.hotwords import 
HotwordContainer -from ovos_dinkum_listener._util import _TemplateFilenameFormatter + try: from ovos_backend_client.api import DatasetApi except ImportError: @@ -150,7 +153,12 @@ class OVOSDinkumVoiceService(Thread): def __init__(self, on_ready=on_ready, on_error=on_error, on_stopping=on_stopping, on_alive=on_alive, on_started=on_started, watchdog=lambda: None, mic=None, - bus=None, validate_source=True, *args, **kwargs): + bus=None, validate_source=True, + stt: Optional[STT] = None, + fallback_stt: Optional[STT] = None, + vad: Optional[VADEngine] = None, + disable_fallback: bool = False, + *args, **kwargs): """ watchdog: (callable) function to call periodically indicating operational status. @@ -186,9 +194,14 @@ def __init__(self, on_ready=on_ready, on_error=on_error, self.mic = mic or OVOSMicrophoneFactory.create(microphone_config) self.hotwords = HotwordContainer(self.bus) - self.vad = OVOSVADFactory.create() - self.stt = load_stt_module() - self.fallback_stt = load_fallback_stt() + self.vad = vad or OVOSVADFactory.create() + self.stt = stt or load_stt_module() + self.disable_fallback = disable_fallback + self.disable_reload = stt is None + if disable_fallback: + self.fallback_stt = None + else: + self.fallback_stt = fallback_stt or load_fallback_stt() self.transformers = AudioTransformersService(self.bus, self.config) self._load_lock = RLock() @@ -374,6 +387,7 @@ def register_event_handlers(self): self.bus.on('recognizer_loop:sleep', self._handle_sleep) self.bus.on('recognizer_loop:wake_up', self._handle_wake_up) + self.bus.on('recognizer_loop:b64_transcribe', self._handle_b64_transcribe) self.bus.on('recognizer_loop:b64_audio', self._handle_b64_audio) self.bus.on('recognizer_loop:record_stop', self._handle_stop_recording) self.bus.on('recognizer_loop:state.set', self._handle_change_state) @@ -671,15 +685,15 @@ def __normtranscripts(self, transcripts: List[Tuple[str, float]]) -> List[str]: ] hallucinations = self.config.get("hallucination_list", 
default_hallucinations) \ if self.config.get("filter_hallucinations", True) else [] - utts = [u[0].lstrip(" \"'").strip(" \"'") for u in transcripts] + utts = [u[0].lstrip(" \"'").strip(" \"'") for u in transcripts if u[0]] filtered_hutts = [u for u in utts if u and u.lower() not in hallucinations] - hutts = [u for u in utts if u and u not in filtered_hutts] + hutts = [u for u in utts if u not in filtered_hutts] if hutts: LOG.debug(f"Filtered hallucinations: {hutts}") return filtered_hutts def _stt_text(self, transcripts: List[Tuple[str, float]], stt_context: dict): - utts = self.__normtranscripts(transcripts) + utts = self.__normtranscripts(transcripts) if transcripts else [] LOG.debug(f"STT: {utts}") if utts: lang = stt_context.get("lang") or Configuration().get("lang", "en-us") @@ -922,8 +936,25 @@ def _handle_sound_played(self, message: Message): if self.voice_loop.state == ListeningState.CONFIRMATION: self.voice_loop.state = ListeningState.BEFORE_COMMAND + def _handle_b64_transcribe(self, message: Message): + """ transcribe base64 encoded audio and return result via message""" + LOG.debug("Handling Base64 STT request") + b64audio = message.data["audio"] + lang = message.data.get("lang", self.voice_loop.stt.lang) + + wav_data = base64.b64decode(b64audio) + + self.voice_loop.stt.stream_start() + audio = bytes2audiodata(wav_data) + utterances = self.voice_loop.stt.transcribe(audio, lang) + self.voice_loop.stt.stream_stop() + + LOG.debug(f"transcripts: {utterances}") + self.bus.emit(message.response({"transcriptions": utterances, "lang": lang})) + def _handle_b64_audio(self, message: Message): - """ transcribe base64 encoded audio """ + """ transcribe base64 encoded audio and inject result into bus""" + LOG.debug("Handling Base64 Incoming Audio") b64audio = message.data["audio"] lang = message.data.get("lang", self.voice_loop.stt.lang) @@ -1055,7 +1086,7 @@ def reload_configuration(self): Configuration object reports a change """ if self._config_hash() == 
self._applied_config_hash: - LOG.info(f"No relevant configuration changed") + LOG.debug(f"No relevant configuration changed") return LOG.info("Reloading changed configuration") if not self._load_lock.acquire(timeout=30): @@ -1071,7 +1102,7 @@ def reload_configuration(self): # Configuration changed, update status and reload self.status.set_alive() - if new_hash['stt'] != self._applied_config_hash['stt']: + if not self.disable_reload and new_hash['stt'] != self._applied_config_hash['stt']: LOG.info(f"Reloading STT") if self.stt: LOG.debug(f"old={self.stt.__class__}: {self.stt.config}") @@ -1083,7 +1114,8 @@ def reload_configuration(self): if self.stt: LOG.debug(f"new={self.stt.__class__}: {self.stt.config}") - if new_hash['fallback'] != self._applied_config_hash['fallback']: + if not self.disable_reload and not self.disable_fallback and new_hash['fallback'] != \ + self._applied_config_hash['fallback']: LOG.info(f"Reloading Fallback STT") if self.fallback_stt: LOG.debug(f"old={self.fallback_stt.__class__}: "