diff --git a/docs/source/changelog/changelog_3.0.rst b/docs/source/changelog/changelog_3.0.rst
index 438ecc11..8025ac35 100644
--- a/docs/source/changelog/changelog_3.0.rst
+++ b/docs/source/changelog/changelog_3.0.rst
@@ -5,6 +5,24 @@
 3.0 Changelog
 *************
 
+3.1.0
+-----
+
+- Fixed a bug where cutoffs were not properly modelled
+- Added a filter during subset creation to exclude utterances with cutoffs from smaller subsets
+- Added the ability to specify HMM topologies for phones
+- Fixed issues caused by validators not cleaning up temporary files and databases
+- Added support for default and nonnative dictionaries generated from other dictionaries
+- Restricted initial training rounds to exclude default and nonnative dictionaries
+- Changed phone clustering to avoid mixing silence and non-silence phones
+- Optimized TextGrid export
+- Improved memory management when collecting alignments
+
+3.0.8
+-----
+
+- Fixed a compatibility issue with models trained under version 1.0 and earlier
+
 3.0.7
 -----
 
diff --git a/montreal_forced_aligner/abc.py b/montreal_forced_aligner/abc.py
index f814ebb0..96fcc4df 100644
--- a/montreal_forced_aligner/abc.py
+++ b/montreal_forced_aligner/abc.py
@@ -316,7 +316,7 @@ def db_engine(self) -> sqlalchemy.engine.Engine:
         self._db_engine = self.construct_engine()
         return self._db_engine
 
-    def get_next_primary_key(self, database_table: MfaSqlBase):
+    def get_next_primary_key(self, database_table):
         with self.session() as session:
             pk = session.query(sqlalchemy.func.max(database_table.id)).scalar()
             if not pk:
@@ -634,7 +634,8 @@ def parse_args(
             unknown_dict[name] = val
         for name, param_type in param_types.items():
             if (name.endswith("_directory") and name != "audio_directory") or (
-                name.endswith("_path") and name not in {"rules_path", "phone_groups_path"}
+                name.endswith("_path")
+                and name not in {"rules_path", "phone_groups_path", "topology_path"}
             ):
                 continue
             if args is not None and name in args and args[name] is not None:
diff --git a/montreal_forced_aligner/acoustic_modeling/lda.py b/montreal_forced_aligner/acoustic_modeling/lda.py
index 279b1c60..fa9c8c09 100644
--- a/montreal_forced_aligner/acoustic_modeling/lda.py
+++ b/montreal_forced_aligner/acoustic_modeling/lda.py
@@ -97,6 +97,8 @@ def _run(self):
             ]
             for dict_id in job.dictionary_ids:
                 ali_path = job.construct_path(self.working_directory, "ali", "ark", dict_id)
+                if not ali_path.exists():
+                    continue
                 lda_logger.debug(f"Processing {ali_path}")
                 feat_path = job.construct_path(
                     job.corpus.current_subset_directory, "feats", "scp", dictionary_id=dict_id
@@ -164,6 +166,8 @@ def _run(self) -> typing.Generator[int]:
             ]
             for dict_id in job.dictionary_ids:
                 ali_path = job.construct_path(self.working_directory, "ali", "ark", dict_id)
+                if not ali_path.exists():
+                    continue
                 lda_logger.debug(f"Processing {ali_path}")
                 feature_archive = job.construct_feature_archive(self.working_directory, dict_id)
                 alignment_archive = AlignmentArchive(ali_path)
diff --git a/montreal_forced_aligner/acoustic_modeling/monophone.py b/montreal_forced_aligner/acoustic_modeling/monophone.py
index 84fff727..06d68b72 100644
--- a/montreal_forced_aligner/acoustic_modeling/monophone.py
+++ b/montreal_forced_aligner/acoustic_modeling/monophone.py
@@ -75,7 +75,7 @@ def _run(self):
             num_error = 0
             tot_like = 0.0
             tot_t = 0.0
-            for d in job.dictionaries:
+            for d in job.training_dictionaries:
                 dict_id = d.id
                 train_logger.debug(f"Aligning for dictionary {d.name} ({d.id})")
                 train_logger.debug(f"Aligning with model: {self.model_path}")
@@ -302,14 +302,22 @@
def _trainer_initialization(self) -> None: tree_path = self.working_directory.joinpath("tree") init_log_path = self.working_log_directory.joinpath("init.log") job = self.jobs[0] - dict_id = job.dictionary_ids[0] - feature_archive = job.construct_feature_archive(self.working_directory, dict_id) feats = [] with kalpy_logger("kalpy.train", init_log_path) as train_logger: - for i, (_, mat) in enumerate(feature_archive): - if i > 10: + dict_index = 0 + while len(feats) < 10: + try: + dict_id = job.dictionary_ids[dict_index] + except IndexError: break - feats.append(mat) + feature_archive = job.construct_feature_archive(self.working_directory, dict_id) + for i, (_, mat) in enumerate(feature_archive): + if i > 10: + break + feats.append(mat) + dict_index += 1 + if not feats: + raise Exception("Could not initialize monophone model due to lack of features") shared_phones = self.worker.shared_phones_set_symbols() topo = read_topology(self.worker.topo_path) gmm_init_mono(topo, feats, shared_phones, str(self.model_path), str(tree_path)) diff --git a/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py b/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py index 3650c4f0..0903b3cd 100644 --- a/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py +++ b/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py @@ -312,8 +312,12 @@ def setup(self): previous_directory = self.previous_aligner.working_directory for j in self.jobs: for p in j.construct_path_dictionary(previous_directory, "ali", "ark").values(): + if not p.exists(): + continue shutil.copy(p, wf.working_directory.joinpath(p.name)) for p in j.construct_path_dictionary(previous_directory, "words", "ark").values(): + if not p.exists(): + continue shutil.copy(p, wf.working_directory.joinpath(p.name)) for f in ["final.mdl", "final.alimdl", "lda.mat", "tree"]: p = previous_directory.joinpath(f) @@ -384,6 +388,12 @@ def train_pronunciation_probabilities(self) -> None: ) with mfa_open(silence_info_path, "r") as f: data = json.load(f) + for k, v in data.items(): + if v is None: + if "correction" in k: + data[k] = 1.0 + else: + data[k] = 0.5 if self.silence_probabilities: d.silence_probability = data["silence_probability"] d.initial_silence_probability = data["initial_silence_probability"] diff --git a/montreal_forced_aligner/acoustic_modeling/sat.py b/montreal_forced_aligner/acoustic_modeling/sat.py index 73cb97eb..3da3d2b0 100644 --- a/montreal_forced_aligner/acoustic_modeling/sat.py +++ b/montreal_forced_aligner/acoustic_modeling/sat.py @@ -81,13 +81,15 @@ def _run(self): .filter(Job.id == self.job_name) .first() ) - for d in job.dictionaries: + for d in job.training_dictionaries: train_logger.debug(f"Accumulating stats for dictionary {d.name} ({d.id})") train_logger.debug(f"Accumulating stats for model: {self.model_path}") dict_id = d.id accumulator = TwoFeatsStatsAccumulator(self.model_path) ali_path = job.construct_path(self.working_directory, "ali", "ark", dict_id) + if not ali_path.exists(): + continue fmllr_path = job.construct_path( job.corpus.current_subset_directory, "trans", "scp", dict_id ) diff --git a/montreal_forced_aligner/acoustic_modeling/trainer.py b/montreal_forced_aligner/acoustic_modeling/trainer.py index 6808738b..c663e0d3 100644 --- a/montreal_forced_aligner/acoustic_modeling/trainer.py +++ b/montreal_forced_aligner/acoustic_modeling/trainer.py @@ -100,6 +100,8 @@ def _run(self) -> typing.Generator[typing.Tuple[int, str]]: transition_model, acoustic_model 
= read_gmm_model(self.model_path)
             for dict_id in job.dictionary_ids:
                 ali_path = job.construct_path(self.working_directory, "ali", "ark", dict_id)
+                if not ali_path.exists():
+                    continue
                 transition_accs = DoubleVector(transition_model.NumTransitionIds() + 1)
                 alignment_archive = AlignmentArchive(ali_path)
                 for alignment in alignment_archive:
@@ -523,6 +525,8 @@ def quality_check_subset(self):
                 self.working_directory, "temp_ali", "ark"
             )
             for dict_id, ali_path in ali_paths.items():
+                if not ali_path.exists():
+                    continue
                 new_path = temp_ali_paths[dict_id]
                 write_specifier = generate_write_specifier(new_path)
                 writer = Int32VectorWriter(write_specifier)
@@ -577,15 +581,20 @@ def train(self) -> None:
                     self.current_acoustic_model = AcousticModel(
                         previous.exported_model_path, self.working_directory
                     )
-                self.align()
-                with self.session() as session:
-                    session.query(WordInterval).delete()
-                    session.query(PhoneInterval).delete()
-                    session.commit()
-                self.collect_alignments()
-                self.analyze_alignments()
-                if self.current_subset != 0:
-                    self.quality_check_subset()
+                if (
+                    not self.current_workflow.done
+                    or not self.current_workflow.working_directory.exists()
+                ):
+                    logger.debug(f"Running {self.current_aligner.identifier} alignments")
+                    self.align()
+                    with self.session() as session:
+                        session.query(WordInterval).delete()
+                        session.query(PhoneInterval).delete()
+                        session.commit()
+                    self.collect_alignments()
+                    self.analyze_alignments()
+                    if self.current_subset != 0:
+                        self.quality_check_subset()
 
             self.set_current_workflow(trainer.identifier)
             if trainer.identifier.startswith("pronunciation_probabilities"):
@@ -721,7 +730,6 @@ def align_options(self) -> MetaDict:
             options = self.current_aligner.align_options
         else:
             options = super().align_options
-        options["boost_silence"] = max(1.25, options["boost_silence"])
         return options
 
     def align(self) -> None:
diff --git a/montreal_forced_aligner/acoustic_modeling/triphone.py b/montreal_forced_aligner/acoustic_modeling/triphone.py
index 8c146554..fce20bfb 100644
--- a/montreal_forced_aligner/acoustic_modeling/triphone.py
+++ b/montreal_forced_aligner/acoustic_modeling/triphone.py
@@ -94,10 +94,12 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int]]:
             train_logger.debug(f"Previous model path: {self.align_model_path}")
             train_logger.debug(f"Model path: {self.model_path}")
             train_logger.debug(f"Tree path: {self.tree_path}")
-            for d in job.dictionaries:
+            for d in job.training_dictionaries:
                 dict_id = d.id
                 train_logger.debug(f"Converting alignments for {d.name}")
                 ali_path = self.ali_paths[dict_id]
+                if not ali_path.exists():
+                    continue
                 new_ali_path = self.new_ali_paths[dict_id]
                 train_logger.debug(f"Old alignments: {ali_path}")
                 train_logger.debug(f"New alignments: {new_ali_path}")
@@ -159,12 +161,14 @@ def _run(self):
                 .filter(Phone.phone_type.in_([PhoneType.silence, PhoneType.oov]))
                 .order_by(Phone.mapping_id)
             ]
-            for d in job.dictionaries:
+            for d in job.training_dictionaries:
                 train_logger.debug(f"Accumulating stats for dictionary {d.name} ({d.id})")
                 train_logger.debug(f"Accumulating stats for model: {self.model_path}")
                 dict_id = d.id
                 feature_archive = job.construct_feature_archive(self.working_directory, dict_id)
                 ali_path = job.construct_path(self.working_directory, "ali", "ark", dict_id)
+                if not ali_path.exists():
+                    continue
                 train_logger.debug("Feature Archive information:")
                 train_logger.debug(f"File: {feature_archive.file_name}")
                 train_logger.debug(f"CMVN: {feature_archive.cmvn_read_specifier}")
@@ -397,8 +401,29 @@ def _setup_tree(self, init_from_previous=False,
initial_mix_up=True) -> None: train_logger.debug(f"Phone sets: {phone_sets}") questions = automatically_obtain_questions(tree_stats, phone_sets, [1], 1) train_logger.debug(f"Automatically obtained {len(questions)} questions") - for v in self.worker.extra_questions_mapping.values(): - questions.append(sorted([self.phone_mapping[x] for x in v])) + train_logger.debug("Automatic questions:") + for q_set in questions: + train_logger.debug(", ".join([self.reversed_phone_mapping[x] for x in q_set])) + + # Remove questions containing silence and other phones + train_logger.debug("Filtering the following sets for containing silence phone:") + silence_phone_id = self.phone_mapping[self.optional_silence_phone] + silence_sets = [ + x for x in questions if silence_phone_id in x and x != [silence_phone_id] + ] + for q_set in silence_sets: + train_logger.debug(", ".join([self.reversed_phone_mapping[x] for x in q_set])) + questions = [ + x for x in questions if silence_phone_id not in x or x == [silence_phone_id] + ] + + extra_questions = self.worker.extra_questions_mapping + if extra_questions: + train_logger.debug(f"Adding {len(extra_questions)} questions") + train_logger.debug("Extra questions:") + for v in self.worker.extra_questions_mapping.values(): + questions.append(sorted([self.phone_mapping[x] for x in v])) + train_logger.debug(", ".join(v)) train_logger.debug(f"{len(questions)} total questions") build_tree( diff --git a/montreal_forced_aligner/alignment/base.py b/montreal_forced_aligner/alignment/base.py index a17baf5f..f060bc3f 100644 --- a/montreal_forced_aligner/alignment/base.py +++ b/montreal_forced_aligner/alignment/base.py @@ -6,6 +6,7 @@ import functools import io import logging +import math import multiprocessing as mp import os import re @@ -77,9 +78,7 @@ mfa_open, ) from montreal_forced_aligner.textgrid import ( - construct_output_path, - construct_output_tiers, - export_textgrid, + construct_textgrid_output, output_textgrid_writing_errors, ) from montreal_forced_aligner.utils import log_kaldi_errors, run_kaldi_function @@ -424,18 +423,26 @@ def compute_pronunciation_probabilities(self): """ begin = time.time() - dictionary_counters = { - dict_id: PronunciationProbabilityCounter() - for dict_id in self.dictionary_lookup.values() - } + with self.session() as session: + dictionary_counters = { + dict_id: PronunciationProbabilityCounter() + for dict_id, in session.query(Dictionary.id).filter(Dictionary.name != "default") + } logger.info("Generating pronunciations...") arguments = self.generate_pronunciations_arguments() for result in run_kaldi_function( GeneratePronunciationsFunction, arguments, total_count=self.num_current_utterances ): - dict_id, utterance_counter = result - dictionary_counters[dict_id].add_counts(utterance_counter) - + try: + dict_id, utterance_counter = result + dictionary_counters[dict_id].add_counts(utterance_counter) + except Exception: + import sys + import traceback + + exc_type, exc_value, exc_traceback = sys.exc_info() + print("\n".join(traceback.format_exception(exc_type, exc_value, exc_traceback))) + raise initial_key = ("", "") final_key = ("", "") lambda_2 = 2 @@ -450,16 +457,27 @@ def compute_pronunciation_probabilities(self): ) as log_file, self.session() as session: session.query(Pronunciation).update({"count": 0}) session.commit() - dictionaries = session.query(Dictionary.id) + dictionaries = session.query(Dictionary) dictionary_mappings = [] - for (d_id,) in dictionaries: + for d in dictionaries: + d_id = d.id + if d_id not in dictionary_counters: + 
continue counter = dictionary_counters[d_id] - log_file.write(f"For {d_id}:\n") + log_file.write(f"For {d.name}:\n") words = ( session.query(Word.word) .filter(Word.dictionary_id == d_id) - .filter(Word.word_type != WordType.silence) - .filter(Word.count > 0) + .filter( + sqlalchemy.or_( + sqlalchemy.and_( + Word.word_type.in_(WordType.speech_types()), Word.count > 0 + ), + Word.word.in_( + [d.cutoff_word, d.oov_word, d.laughter_word, d.bracketed_word] + ), + ) + ) ) pronunciations = ( session.query( @@ -470,8 +488,16 @@ def compute_pronunciation_probabilities(self): ) .join(Pronunciation.word) .filter(Word.dictionary_id == d_id) - .filter(Word.word_type != WordType.silence) - .filter(Word.count > 0) + .filter( + sqlalchemy.or_( + sqlalchemy.and_( + Word.word_type.in_(WordType.speech_types()), Word.count > 0 + ), + Word.word.in_( + [d.cutoff_word, d.oov_word, d.laughter_word, d.bracketed_word] + ), + ) + ) ) pron_mapping = {} pronunciations = [ @@ -479,6 +505,7 @@ def compute_pronunciation_probabilities(self): for w, p, p_id, generated in pronunciations if not generated or p in counter.word_pronunciation_counts[w] ] + pronunciations.append((d.cutoff_word, "cutoff_model", None)) for w, p, p_id in pronunciations: pron_mapping[(w, p)] = {"id": p_id} if w in {initial_key[0], final_key[0], self.silence_word}: @@ -546,8 +573,28 @@ def compute_pronunciation_probabilities(self): (non_silence_count + lambda_3) / (bar_count_non_silence_wp[(w, p)] + lambda_3) ) - session.bulk_update_mappings(Pronunciation, pron_mapping.values()) + cutoff_model = pron_mapping.pop((d.cutoff_word, "cutoff_model")) + bulk_update(session, Pronunciation, list(pron_mapping.values())) session.flush() + cutoff_not_model = pron_mapping[(d.cutoff_word, "spn")] + cutoff_query = ( + session.query(Pronunciation.id, Pronunciation.pronunciation) + .join(Pronunciation.word) + .filter(Word.word != d.cutoff_word) + .filter(Word.word.like(f"{d.cutoff_word[:-1]}%")) + ) + cutoff_mappings = [] + for pron_id, pron in cutoff_query: + if pron == d.cutoff_word: + data = dict(cutoff_not_model) + else: + data = dict(cutoff_model) + data["id"] = pron_id + cutoff_mappings.append(data) + if cutoff_mappings: + bulk_update(session, Pronunciation, cutoff_mappings) + session.flush() + initial_silence_count = counter.silence_before_counts[initial_key] + ( silence_probability * lambda_2 ) @@ -915,6 +962,11 @@ def collect_alignments(self) -> None: cursor.copy_from(phone_buf, PhoneInterval.__tablename__, sep=",", null="") phone_buf.truncate(0) phone_buf.seek(0) + conn.commit() + cursor.close() + conn.close() + conn = self.db_engine.raw_connection() + cursor = conn.cursor() if config.USE_POSTGRES: if word_buf.tell() != 0: @@ -1129,21 +1181,31 @@ def export_textgrids( self.collect_alignments() begin = time.time() error_dict = {} - - with self.session() as session: - files = ( - session.query( - File.id, - File.name, - File.relative_path, - SoundFile.duration, - TextFile.text_file_path, - ) - .join(File.sound_file) - .join(File.text_file) - ).all() with tqdm(total=self.num_files, disable=config.QUIET) as pbar: if config.USE_MP and config.NUM_JOBS > 1: + with self.session() as session: + files_per_job = math.ceil(self.num_files / len(self.jobs)) + file_batches = [{}] + query = ( + session.query( + File.id, + File.name, + File.relative_path, + SoundFile.duration, + TextFile.text_file_path, + ) + .join(File.sound_file) + .join(File.text_file) + ) + for file_id, file_name, relative_path, file_duration, text_file_path in query: + if len(file_batches[-1]) >= 
files_per_job: + file_batches.append({}) + file_batches[-1][file_id] = ( + file_name, + relative_path, + file_duration, + text_file_path, + ) stopped = mp.Event() finished_adding = mp.Event() @@ -1167,8 +1229,8 @@ def export_textgrids( export_proc.start() export_procs.append(export_proc) try: - for args in files: - for_write_queue.put(args) + for batch in file_batches: + for_write_queue.put(batch) time.sleep(1) finished_adding.set() while True: @@ -1197,30 +1259,37 @@ def export_textgrids( else: logger.debug("Not using multiprocessing for TextGrid export") - for file_id, name, relative_path, duration, text_file_path in files: - output_path = construct_output_path( - name, - relative_path, - self.export_output_directory, - text_file_path, - output_format, - ) - - data = construct_output_tiers( - session, - file_id, - workflow, - config.CLEANUP_TEXTGRIDS, - self.clitic_marker, - include_original_text, - ) - export_textgrid( - data, - output_path, - duration, - self.export_frame_shift, - output_format=output_format, + with self.session() as session: + file_batch = {} + query = ( + session.query( + File.id, + File.name, + File.relative_path, + SoundFile.duration, + TextFile.text_file_path, + ) + .join(File.sound_file) + .join(File.text_file) ) + for file_id, file_name, relative_path, file_duration, text_file_path in query: + file_batch[file_id] = ( + file_name, + relative_path, + file_duration, + text_file_path, + ) + for _ in construct_textgrid_output( + session, + file_batch, + workflow, + config.CLEANUP_TEXTGRIDS, + self.clitic_marker, + self.export_output_directory, + self.export_frame_shift, + output_format, + include_original_text, + ): pbar.update(1) if error_dict: @@ -1321,7 +1390,7 @@ def evaluate_alignments( Directory to save results, if not specified, it will be saved in the log directory comparison_source: :class:`~montreal_forced_aligner.data.WorkflowType` Workflow to compare to the reference intervals, defaults to :attr:`~montreal_forced_aligner.data.WorkflowType.alignment` - comparison_source: :class:`~montreal_forced_aligner.data.WorkflowType` + reference_source: :class:`~montreal_forced_aligner.data.WorkflowType` Workflow to use as the reference intervals, defaults to :attr:`~montreal_forced_aligner.data.WorkflowType.reference` """ diff --git a/montreal_forced_aligner/alignment/multiprocessing.py b/montreal_forced_aligner/alignment/multiprocessing.py index 22f4a6be..2b963385 100644 --- a/montreal_forced_aligner/alignment/multiprocessing.py +++ b/montreal_forced_aligner/alignment/multiprocessing.py @@ -63,11 +63,7 @@ ) from montreal_forced_aligner.exceptions import AlignmentCollectionError, AlignmentExportError from montreal_forced_aligner.helper import mfa_open, split_phone_position -from montreal_forced_aligner.textgrid import ( - construct_output_path, - construct_output_tiers, - export_textgrid, -) +from montreal_forced_aligner.textgrid import construct_textgrid_output from montreal_forced_aligner.utils import thread_logger if TYPE_CHECKING: @@ -418,7 +414,7 @@ def _run(self): text_column = Utterance.normalized_character_text else: text_column = Utterance.normalized_text - for d in job.dictionaries: + for d in job.training_dictionaries: begin = time.time() if self.lexicon_compilers and d.id in self.lexicon_compilers: lexicon = self.lexicon_compilers[d.id] @@ -477,7 +473,7 @@ def __init__(self, args: AccStatsArguments): self.working_directory = args.working_directory self.model_path = args.model_path - def _run(self) -> typing.Generator[typing.Tuple[int, int]]: + def 
_run(self) -> None: """Run the function""" with self.session() as session, thread_logger( "kalpy.train", self.log_path, job_name=self.job_name @@ -488,7 +484,7 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int]]: .filter(Job.id == self.job_name) .first() ) - for d in job.dictionaries: + for d in job.training_dictionaries: train_logger.debug(f"Accumulating stats for dictionary {d.name} ({d.id})") train_logger.debug(f"Accumulating stats for model: {self.model_path}") dict_id = d.id @@ -496,6 +492,8 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int]]: feature_archive = job.construct_feature_archive(self.working_directory, dict_id) ali_path = job.construct_path(self.working_directory, "ali", "ark", dict_id) + if not ali_path.exists(): + continue alignment_archive = AlignmentArchive(ali_path) train_logger.debug("Feature Archive information:") train_logger.debug(f"CMVN: {feature_archive.cmvn_read_specifier}") @@ -565,7 +563,7 @@ def _run(self) -> None: **align_options, ) aligner.boost_silence(boost_silence, silence_phones) - for d in job.dictionaries: + for d in job.training_dictionaries: align_logger.debug(f"Aligning for dictionary {d.name} ({d.id})") align_logger.debug(f"Aligning with model: {aligner.acoustic_model_path}") dict_id = d.id @@ -1114,21 +1112,10 @@ def _run(self) -> None: silence_words = session.query(Word.word).filter(Word.word_type == WordType.silence) self.silence_words.update(x for x, in silence_words) - for d in job.dictionaries: + for d in job.training_dictionaries: ali_path = job.construct_path(workflow.working_directory, "ali", "ark", d.id) if not os.path.exists(ali_path): continue - - utts = ( - session.query(Utterance.id, Utterance.normalized_text) - .join(Utterance.speaker) - .filter(Utterance.job_id == self.job_name) - .filter(Speaker.dictionary_id == d.id) - ) - utterance_texts = {} - for u_id, text in utts: - utterance_texts[u_id] = text - if self.lexicon_compilers and d.id in self.lexicon_compilers: lexicon_compiler = self.lexicon_compilers[d.id] else: @@ -1142,11 +1129,15 @@ def _run(self) -> None: ) utterance = int(alignment.utterance_id.split("-")[-1]) ctm = lexicon_compiler.phones_to_pronunciations(alignment.words, intervals) - word_pronunciations = [(x.label, x.pronunciation) for x in ctm.word_intervals] - # word_pronunciations = [ - # x if x[1] != OOV_PHONE else (OOV_WORD, OOV_PHONE) - # for x in word_pronunciations - # ] + word_pronunciations = [] + for wi in ctm.word_intervals: + label = wi.label + pronunciation = wi.pronunciation + if label.startswith(d.cutoff_word[:-1]): + label = d.cutoff_word + if pronunciation != d.oov_phone: + pronunciation = "cutoff_model" + word_pronunciations.append((label, pronunciation)) if self.for_g2p: phones = [] for i, x in enumerate(word_pronunciations): @@ -1256,6 +1247,8 @@ def _run(self) -> None: lexicon_compiler = d.lexicon_compiler if self.transcription: lat_path = job.construct_path(workflow.working_directory, "lat", "ark", d.id) + if not lat_path.exists(): + continue transcription_archive = TranscriptionArchive( lat_path, acoustic_scale=self.score_options["acoustic_scale"] @@ -1304,6 +1297,8 @@ def _run(self) -> None: self.callback((utterance, d.id, ctm)) else: ali_path = job.construct_path(workflow.working_directory, "ali", "ark", d.id) + if not ali_path.exists(): + continue words_path = job.construct_path( workflow.working_directory, "words", "ark", d.id ) @@ -1505,13 +1500,7 @@ def run(self) -> None: ) while True: try: - ( - file_id, - name, - relative_path, - duration, - text_file_path, - ) = 
self.for_write_queue.get(timeout=1) + (file_batch) = self.for_write_queue.get(timeout=1) except Empty: if self.finished_adding.is_set(): self.finished_processing.set() @@ -1521,25 +1510,18 @@ def run(self) -> None: if self.stopped.is_set(): continue try: - output_path = construct_output_path( - name, - relative_path, - self.output_directory, - text_file_path, - self.output_format, - ) - data = construct_output_tiers( + for output_path in construct_textgrid_output( session, - file_id, + file_batch, workflow, self.cleanup_textgrids, self.clitic_marker, + self.output_directory, + self.export_frame_shift, + self.output_format, self.include_original_text, - ) - export_textgrid( - data, output_path, duration, self.export_frame_shift, self.output_format - ) - self.return_queue.put(1) + ): + self.return_queue.put(1) except Exception: exc_type, exc_value, exc_traceback = sys.exc_info() self.return_queue.put( diff --git a/montreal_forced_aligner/command_line/train_acoustic_model.py b/montreal_forced_aligner/command_line/train_acoustic_model.py index e040a2c8..f16930ca 100644 --- a/montreal_forced_aligner/command_line/train_acoustic_model.py +++ b/montreal_forced_aligner/command_line/train_acoustic_model.py @@ -81,6 +81,13 @@ "https://github.com/MontrealCorpusTools/mfa-models/tree/main/config/acoustic/rules for examples.", type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path), ) +@click.option( + "--topology_path", + "topology_path", + help="Path to yaml file defining topologies. See " + "https://github.com/MontrealCorpusTools/mfa-models/tree/main/config/acoustic/topologies for examples.", + type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path), +) @click.option( "--output_format", help="Format for aligned output files (default is long_textgrid).", diff --git a/montreal_forced_aligner/command_line/utils.py b/montreal_forced_aligner/command_line/utils.py index b02024f9..01d684ad 100644 --- a/montreal_forced_aligner/command_line/utils.py +++ b/montreal_forced_aligner/command_line/utils.py @@ -21,7 +21,6 @@ FileArgumentNotFoundError, ModelExtensionError, ModelTypeNotSupportedError, - NoDefaultSpeakerDictionaryError, PretrainedModelNotFoundError, ) from montreal_forced_aligner.helper import mfa_open @@ -145,7 +144,7 @@ def common_options(f: typing.Callable) -> typing.Callable: return functools.reduce(lambda x, opt: opt(x), options, f) -def validate_model_arg(name: str, model_type: str) -> Path: +def validate_model_arg(name: str, model_type: str) -> typing.Union[Path, str]: """ Validate pretrained model name argument @@ -181,6 +180,8 @@ def validate_model_arg(name: str, model_type: str) -> Path: model_class = MODEL_TYPES[model_type] if name in available_models: name = model_class.get_pretrained_path(name) + elif model_type == "dictionary" and str(name).lower() in {"default", "nonnative"}: + return name else: if isinstance(name, str): name = Path(name) @@ -193,8 +194,6 @@ def validate_model_arg(name: str, model_type: str) -> Path: paths = sorted(set(data.values())) for path in paths: validate_model_arg(path, "dictionary") - if "default" not in data: - raise click.BadParameter(str(NoDefaultSpeakerDictionaryError())) else: if name.exists(): if name.suffix: @@ -410,7 +409,7 @@ def check_server() -> None: logger.debug(f"pg_ctl stdout: {stdout}") logger.debug(f"pg_ctl stderr: {stderr}") if "no server running" in stdout: - raise DatabaseError() + raise DatabaseError(f"stdout: {stdout}\nstderr: {stderr}") def start_server() -> None: diff --git 
a/montreal_forced_aligner/command_line/validate.py b/montreal_forced_aligner/command_line/validate.py index 43eb3e8c..b65fdc4e 100644 --- a/montreal_forced_aligner/command_line/validate.py +++ b/montreal_forced_aligner/command_line/validate.py @@ -139,6 +139,7 @@ def validate_corpus_cli(context, **kwargs) -> None: validator.dirty = True raise finally: + config.FINAL_CLEAN = True validator.cleanup() @@ -196,4 +197,5 @@ def validate_dictionary_cli(*args, **kwargs) -> None: validator.dirty = True raise finally: + config.FINAL_CLEAN = True validator.cleanup() diff --git a/montreal_forced_aligner/corpus/acoustic_corpus.py b/montreal_forced_aligner/corpus/acoustic_corpus.py index 30a9e573..0044d449 100644 --- a/montreal_forced_aligner/corpus/acoustic_corpus.py +++ b/montreal_forced_aligner/corpus/acoustic_corpus.py @@ -1111,6 +1111,7 @@ def load_corpus(self) -> None: self.write_lexicon_information() logger.debug(f"Wrote lexicon information in {time.time() - begin:.3f} seconds") else: + self.load_phone_topologies() self.load_phone_groups() self.load_lexicon_compilers() diff --git a/montreal_forced_aligner/corpus/base.py b/montreal_forced_aligner/corpus/base.py index 48198d76..33ae6aa9 100644 --- a/montreal_forced_aligner/corpus/base.py +++ b/montreal_forced_aligner/corpus/base.py @@ -4,6 +4,7 @@ import collections import logging import os +import re import threading import time import typing @@ -1182,38 +1183,68 @@ def create_subset(self, subset: int) -> None: Number of utterances to include in subset """ logger.info(f"Creating subset directory with {subset} utterances...") - multiword_pattern = r"\s\S+\s" + if hasattr(self, "cutoff_word") and hasattr(self, "brackets"): + initial_brackets = re.escape("".join(x[0] for x in self.brackets)) + final_brackets = re.escape("".join(x[1] for x in self.brackets)) + cutoff_identifier = re.sub( + rf"[{initial_brackets}{final_brackets}]", "", self.cutoff_word + ) + cutoff_pattern = f"[{initial_brackets}]({cutoff_identifier}|hes)" + else: + cutoff_pattern = "<(cutoff|hes)" + + def add_filters(query): + multiword_pattern = r"\s\S+\s" + filtered = ( + query.filter( + Utterance.normalized_text.op("~")(multiword_pattern) + if config.USE_POSTGRES + else Utterance.normalized_text.regexp_match(multiword_pattern) + ) + .filter(Utterance.ignored == False) # noqa + .filter( + sqlalchemy.or_( + Utterance.duration_deviation == None, # noqa + Utterance.duration_deviation < 10, + ) + ) + ) + if subset <= 25000: + filtered = filtered.filter( + sqlalchemy.not_( + Utterance.normalized_text.op("~")(cutoff_pattern) + if config.USE_POSTGRES + else Utterance.normalized_text.regexp_match(cutoff_pattern) + ) + ) + + return filtered + with self.session() as session: begin = time.time() session.query(Utterance).filter(Utterance.in_subset == True).update( # noqa {Utterance.in_subset: False} ) session.commit() - dictionary_lookup = {k: v for k, v in session.query(Dictionary.name, Dictionary.id)} + dictionary_query = session.query(Dictionary.name, Dictionary.id).filter( + Dictionary.name != "default" + ) + if subset <= 25000: + dictionary_query = dictionary_query.filter(Dictionary.name != "nonnative") + dictionary_lookup = {k: v for k, v in dictionary_query} num_dictionaries = len(dictionary_lookup) if num_dictionaries > 1: subsets_per_dictionary = {} utts_per_dictionary = {} subsetted = 0 for dict_name, dict_id in dictionary_lookup.items(): - num_utts = ( + base_query = ( session.query(Utterance) .join(Utterance.speaker) - .filter(Speaker.dictionary_id == dict_id) - .filter( - 
Utterance.normalized_text.op("~")(multiword_pattern) - if config.USE_POSTGRES - else Utterance.normalized_text.regexp_match(multiword_pattern) - ) - .filter(Utterance.ignored == False) # noqa - .filter( - sqlalchemy.or_( - Utterance.duration_deviation == None, # noqa - Utterance.duration_deviation < 10, - ) - ) # noqa - .count() + .filter(Speaker.dictionary_id == dict_id) # noqa ) + base_query = add_filters(base_query) + num_utts = base_query.count() utts_per_dictionary[dict_name] = num_utts if num_utts < int(subset / num_dictionaries): subsets_per_dictionary[dict_name] = num_utts @@ -1240,42 +1271,22 @@ def create_subset(self, subset: int) -> None: larger_subset_num = int(subset_per_dictionary * 10) speaker_ids = None average_duration = ( - session.query(sqlalchemy.func.avg(Utterance.duration)) - .join(Utterance.speaker) - .filter(Speaker.dictionary_id == dict_id) - .filter( - Utterance.normalized_text.op("~")(multiword_pattern) - if config.USE_POSTGRES - else Utterance.normalized_text.regexp_match(multiword_pattern) + add_filters( + session.query(sqlalchemy.func.avg(Utterance.duration)) + .join(Utterance.speaker) + .filter(Speaker.dictionary_id == dict_id) ) - .filter(Utterance.ignored == False) # noqa - .filter( - sqlalchemy.or_( - Utterance.duration_deviation == None, # noqa - Utterance.duration_deviation < 10, - ) - ) # noqa ).first()[0] for utt_count_cutoff in [30, 15, 5]: sq = ( - session.query( - Speaker.id.label("speaker_id"), - sqlalchemy.func.count(Utterance.id).label("utt_count"), - ) - .join(Utterance.speaker) - .filter(Speaker.dictionary_id == dict_id) - .filter( - Utterance.normalized_text.op("~")(multiword_pattern) - if config.USE_POSTGRES - else Utterance.normalized_text.regexp_match(multiword_pattern) - ) - .filter(Utterance.ignored == False) # noqa - .filter( - sqlalchemy.or_( - Utterance.duration_deviation == None, # noqa - Utterance.duration_deviation < 10, + add_filters( + session.query( + Speaker.id.label("speaker_id"), + sqlalchemy.func.count(Utterance.id).label("utt_count"), ) - ) # noqa + .join(Utterance.speaker) + .filter(Speaker.dictionary_id == dict_id) + ) .filter(Utterance.duration <= average_duration) .group_by(Speaker.id.label("speaker_id")) .subquery() @@ -1286,32 +1297,21 @@ def create_subset(self, subset: int) -> None: ) ).first()[0] if total_speaker_utterances >= subset_per_dictionary: - speaker_ids = ( + speaker_ids = [ x for x, in session.query(sq.c.speaker_id).filter( sq.c.utt_count >= utt_count_cutoff ) - ) + ] break if num_utts > larger_subset_num: larger_subset_query = ( session.query(Utterance.id) .join(Utterance.speaker) - .filter(Speaker.dictionary_id == dict_id) - .filter( - Utterance.normalized_text.op("~")(multiword_pattern) - if config.USE_POSTGRES - else Utterance.normalized_text.regexp_match(multiword_pattern) - ) - .filter(Utterance.ignored == False) # noqa - .filter( - sqlalchemy.or_( - Utterance.duration_deviation == None, # noqa - Utterance.duration_deviation < 10, - ) - ) # noqa + .filter(Speaker.dictionary_id == dict_id) # noqa ) - if speaker_ids is not None: + larger_subset_query = add_filters(larger_subset_query) + if speaker_ids: larger_subset_query = larger_subset_query.filter( Speaker.id.in_(speaker_ids) ) @@ -1333,7 +1333,8 @@ def create_subset(self, subset: int) -> None: ) session.execute(query) - # Remove speakers with less than 5 utterances from subset, can't estimate speaker transforms well for low utterance counts + # Remove speakers with less than 5 utterances from subset, + # can't estimate speaker transforms well for 
low utterance counts sq = ( session.query( Utterance.speaker_id.label("speaker_id"), @@ -1343,9 +1344,9 @@ def create_subset(self, subset: int) -> None: .group_by(Utterance.speaker_id.label("speaker_id")) .subquery() ) - speaker_ids = ( + speaker_ids = [ x for x, in session.query(sq.c.speaker_id).filter(sq.c.utt_count < 5) - ) + ] session.query(Utterance).filter( Utterance.speaker_id.in_(speaker_ids) ).update({Utterance.in_subset: False}) @@ -1355,21 +1356,10 @@ def create_subset(self, subset: int) -> None: larger_subset_query = ( session.query(Utterance.id) .join(Utterance.speaker) - .filter(Speaker.dictionary_id == dict_id) - .filter( - Utterance.normalized_text.op("~")(multiword_pattern) - if config.USE_POSTGRES - else Utterance.normalized_text.regexp_match(multiword_pattern) - ) - .filter(Utterance.ignored == False) # noqa - .filter( - sqlalchemy.or_( - Utterance.duration_deviation == None, # noqa - Utterance.duration_deviation < 10, - ) - ) # noqa + .filter(Speaker.dictionary_id == dict_id) # noqa ) - if speaker_ids is not None: + larger_subset_query = add_filters(larger_subset_query) + if speaker_ids: larger_subset_query = larger_subset_query.filter( Speaker.id.in_(speaker_ids) ) @@ -1431,27 +1421,15 @@ def create_subset(self, subset: int) -> None: ).first()[0] remaining = subset_per_dictionary - total_speaker_utterances if remaining > 0: - speaker_ids = (x for x, in session.query(sq.c.speaker_id)) + speaker_ids = [x for x, in session.query(sq.c.speaker_id)] larger_subset_query = ( session.query(Utterance.id) .join(Utterance.speaker) - .filter(Speaker.dictionary_id == dict_id) - .filter( - Utterance.normalized_text.op("~")(multiword_pattern) - if config.USE_POSTGRES - else Utterance.normalized_text.regexp_match(multiword_pattern) - ) - .filter(Utterance.ignored == False) # noqa - .filter(Utterance.in_subset == False) # noqa - .filter( - sqlalchemy.or_( - Utterance.duration_deviation == None, # noqa - Utterance.duration_deviation < 10, - ) - ) # noqa + .filter(Speaker.dictionary_id == dict_id) # noqa ) - if speaker_ids is not None: + larger_subset_query = add_filters(larger_subset_query) + if speaker_ids: larger_subset_query = larger_subset_query.filter( Speaker.id.in_(speaker_ids) ) @@ -1478,13 +1456,7 @@ def create_subset(self, subset: int) -> None: if subset < self.num_utterances: # Get all shorter utterances that are not one word long larger_subset_query = ( - session.query(Utterance.id) - .filter( - Utterance.normalized_text.op("~")(multiword_pattern) - if config.USE_POSTGRES - else Utterance.normalized_text.regexp_match(multiword_pattern) - ) - .filter(Utterance.ignored == False) # noqa + add_filters(session.query(Utterance.id)) .order_by(Utterance.duration) .limit(larger_subset_num) ) diff --git a/montreal_forced_aligner/corpus/features.py b/montreal_forced_aligner/corpus/features.py index 8bcd6c5c..335d809a 100644 --- a/montreal_forced_aligner/corpus/features.py +++ b/montreal_forced_aligner/corpus/features.py @@ -511,6 +511,8 @@ def _run(self) -> None: **self.fmllr_options, ) ali_path = job.construct_path(self.working_directory, "ali", "ark", dict_id) + if not ali_path.exists(): + continue fmllr_logger.debug(f"Alignment path: {ali_path}") alignment_archive = AlignmentArchive(ali_path) temp_trans_path = job.construct_path( @@ -594,6 +596,7 @@ def __init__( self, feature_type: str = "mfcc", use_energy: bool = True, + raw_energy: bool = False, frame_shift: int = 10, frame_length: int = 25, snip_edges: bool = False, @@ -644,6 +647,7 @@ def __init__( # MFCC options self.use_energy 
= use_energy
+        self.raw_energy = raw_energy
         self.low_frequency = low_frequency
         self.high_frequency = high_frequency
         self.sample_frequency = sample_frequency
@@ -783,6 +787,7 @@ def mfcc_options(self) -> MetaDict:
         else:
             options = {
                 "use_energy": self.use_energy,
+                "raw_energy": self.raw_energy,
                 "dither": self.dither,
                 "energy_floor": self.energy_floor,
                 "num_coefficients": self.num_coefficients,
diff --git a/montreal_forced_aligner/db.py b/montreal_forced_aligner/db.py
index 9e1c2675..374d03bf 100644
--- a/montreal_forced_aligner/db.py
+++ b/montreal_forced_aligner/db.py
@@ -112,7 +112,7 @@ def full_load_utterance(session: sqlalchemy.orm.Session, utterance_id: int):
 def bulk_update(
     session: sqlalchemy.orm.Session,
     table: MfaSqlBase,
-    values: typing.List[typing.Dict[str, typing.Any]],
+    values: typing.Collection[typing.Dict[str, typing.Any]],
     id_field=None,
 ) -> None:
     """
@@ -353,7 +353,7 @@ class Dictionary(MfaSqlBase):
 
     id = Column(Integer, primary_key=True, autoincrement=True)
     name = Column(String(50), nullable=False)
-    path = Column(PathType, unique=True)
+    path = Column(PathType)
     rules_applied = Column(Boolean, default=False)
     phone_set_type = Column(Enum(PhoneSetType), nullable=True)
     root_temp_directory = Column(PathType, nullable=True)
@@ -1114,7 +1114,7 @@ def save(
         save_transcription: bool
             Flag for whether the hypothesized transcription text should be saved instead of the default text
         """
-        from montreal_forced_aligner.alignment.multiprocessing import construct_output_path
+        from montreal_forced_aligner.textgrid import construct_output_path
 
         utterance_count = len(self.utterances)
         if output_format is None:  # Saving directly
@@ -2024,6 +2024,14 @@ def __str__(self):
     def has_dictionaries(self) -> bool:
         return len(self.dictionaries) > 0
 
+    @property
+    def training_dictionaries(self) -> typing.List[int]:
+        if self.corpus.current_subset == 0:
+            return self.dictionaries
+        if self.corpus.current_subset <= 25000:
+            return [x for x in self.dictionaries if x.name not in {"default", "nonnative"}]
+        return [x for x in self.dictionaries if x.name not in {"default"}]
+
     @property
     def dictionary_ids(self) -> typing.List[int]:
         return [x.id for x in self.dictionaries]
diff --git a/montreal_forced_aligner/dictionary/mixins.py b/montreal_forced_aligner/dictionary/mixins.py
index 65f61ded..3d13bc1a 100644
--- a/montreal_forced_aligner/dictionary/mixins.py
+++ b/montreal_forced_aligner/dictionary/mixins.py
@@ -6,7 +6,7 @@
 import os
 import re
 import typing
-from collections import Counter
+from collections import Counter, defaultdict
 from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
@@ -27,7 +27,7 @@
 
 DEFAULT_QUOTE_MARKERS = list("“„\"”〝〟″「」『』‚ʻʿ‘′'")
 DEFAULT_CLITIC_MARKERS = list("'’‘")
-DEFAULT_COMPOUND_MARKERS = list("-/")
+DEFAULT_COMPOUND_MARKERS = list("-‑/")
 DEFAULT_BRACKETS = [("<", ">"), ("[", "]"), ("{", "}"), ("(", ")"), ("＜", "＞")]
 
 __all__ = ["DictionaryMixin", "TemporaryDictionaryMixin"]
@@ -190,6 +190,7 @@ def __init__(
         self.bracket_sanitize_regex = None
         self.use_cutoff_model = use_cutoff_model
         self._phone_groups = {}
+        self._topologies = {}
 
     @property
     def tokenizer(self):
@@ -491,6 +492,12 @@ def kaldi_non_silence_phones(self) -> List[str]:
 
     @property
     def phone_groups(self) -> typing.Dict[str, typing.List[str]]:
+        if (
+            not self._phone_groups
+            and getattr(self, "phone_groups_path", None)
+            and hasattr(self, "load_phone_groups")
+        ):
+            self.load_phone_groups()
         if not self._phone_groups:
             for p in sorted(self.non_silence_phones):
                 base_phone =
self.get_base_phone(p)
@@ -665,8 +672,14 @@ def _write_topo(self) -> None:
         """
         Write the topo file to the temporary directory
         """
-
+        if (
+            not self._topologies
+            and getattr(self, "topology_path", None)
+            and hasattr(self, "load_phone_topologies")
+        ):
+            self.load_phone_topologies()
         sil_transp = 1 / (self.num_silence_states - 1)
+        topo_groups = defaultdict(set)
 
         silence_lines = [
             "<TopologyEntry>",
             "<ForPhones>",
@@ -694,9 +707,16 @@
         silence_topo_string = "\n".join(silence_lines)
 
         topo_sections = [silence_topo_string]
-        topo_phones = self._get_grouped_phones()
-        for phone_list in topo_phones.values():
+        for k, v in self._topologies.items():
+            min_states = v.get("min_states", 1)
+            max_states = v.get("max_states", self.num_non_silence_states)
+            topo_groups[(min_states, max_states)].add(k)
+        for phone in self.non_silence_phones:
+            if phone not in self._topologies:
+                topo_groups[(1, self.num_non_silence_states)].add(phone)
+
+        for (min_states, max_states), phone_list in topo_groups.items():
             if not phone_list:
                 continue
             non_silence_lines = [
                 "<TopologyEntry>",
                 "<ForPhones>",
                 " ".join(
                     str(self.phone_mapping[x]) for x in phone_list
                 ),
                 "</ForPhones>",
             ]
-            # num_states = state_mapping[phone_type]
-            num_states = self.num_non_silence_states
+            num_states = max_states
             for i in range(num_states):
                 if i == 0:  # Initial non_silence state
-                    transition_probability = 1 / self.num_non_silence_states
-                    transition_string = " ".join(
-                        f"<Transition> {x} {transition_probability}"
-                        for x in range(1, self.num_non_silence_states + 1)
-                    )
+                    if min_states == max_states:
+                        transition_string = f"<Transition> {i} 0.5 <Transition> {i+1} 0.5"
+                    else:
+                        transition_probability = 1 / max_states
+                        transition_string = " ".join(
+                            f"<Transition> {x} {transition_probability}"
+                            for x in range(min_states, max_states + 1)
+                        )
                 non_silence_lines.append(
                     f"<State> {i} <PdfClass> {i} {transition_string} </State>"
                 )
diff --git a/montreal_forced_aligner/dictionary/multispeaker.py b/montreal_forced_aligner/dictionary/multispeaker.py
index 7dd9addb..23b3b14b 100644
--- a/montreal_forced_aligner/dictionary/multispeaker.py
+++ b/montreal_forced_aligner/dictionary/multispeaker.py
@@ -38,7 +38,11 @@
     bulk_update,
 )
 from montreal_forced_aligner.dictionary.mixins import TemporaryDictionaryMixin
-from montreal_forced_aligner.exceptions import DictionaryError, DictionaryFileError
+from montreal_forced_aligner.exceptions import (
+    DictionaryError,
+    DictionaryFileError,
+    PhoneGroupTopologyMismatchError,
+)
 from montreal_forced_aligner.helper import comma_join, mfa_open, split_phone_position
 from montreal_forced_aligner.models import DictionaryModel
 from montreal_forced_aligner.utils import parse_dictionary_file
@@ -86,6 +90,7 @@ def __init__(
         dictionary_path: typing.Union[str, Path] = None,
         rules_path: typing.Union[str, Path] = None,
         phone_groups_path: typing.Union[str, Path] = None,
+        topology_path: typing.Union[str, Path] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -109,8 +114,11 @@ def __init__(
             rules_path = Path(rules_path)
         if isinstance(phone_groups_path, str):
             phone_groups_path = Path(phone_groups_path)
+        if isinstance(topology_path, str):
+            topology_path = Path(topology_path)
         self.rules_path = rules_path
         self.phone_groups_path = phone_groups_path
+        self.topology_path = topology_path
         self.lexicon_compilers: typing.Dict[int, typing.Union[LexiconCompiler, G2PCompiler]] = {}
         self._tokenizers = {}
@@ -165,6 +173,25 @@ def load_phone_groups(self) -> None:
                     self._phone_groups[k] = sorted(
                         set(x for x in v if x in self.non_silence_phones)
                     )
+            errors = []
+            for phone_group in self._phone_groups.values():
+                topos = set()
+                for phone in phone_group:
+                    min_states = 1
+ max_states = self.num_non_silence_states + if phone in self._topologies: + min_states = self._topologies[phone].get("min_states", 1) + max_states = self._topologies[phone].get( + "max_states", self.num_non_silence_states + ) + topos.add((min_states, max_states)) + if len(topos) > 1: + errors.append((phone_group, topos)) + if errors: + raise PhoneGroupTopologyMismatchError( + errors, self.phone_groups_path, self.topology_path + ) + found_phones = set() for phones in self._phone_groups.values(): found_phones.update(phones) @@ -177,6 +204,33 @@ def load_phone_groups(self) -> None: else: logger.debug("All phones were included in phone groups") + def load_phone_topologies(self) -> None: + """ + Load phone topologies from the dictionary's groups file path + """ + if self.topology_path is not None and self.topology_path.exists(): + with mfa_open(self.topology_path) as f: + self._topologies = { + k: v + for k, v in yaml.load(f, Loader=yaml.Loader).items() + if k in self.non_silence_phones + } + found_phones = set(self._topologies.keys()) + missing_phones = self.non_silence_phones - found_phones + if missing_phones: + logger.debug( + f"The following phones will use the default topology (min states = 1, max states = {self.num_non_silence_states}): " + f"{comma_join(sorted(missing_phones))}" + ) + else: + logger.debug("All phones were included in topology config") + logger.debug("The following phones will use custom topologies: ") + for k, v in self._topologies.items(): + min_states = self._topologies[k].get("min_states", 1) + max_states = self._topologies[k].get("max_states", self.num_non_silence_states) + logger.debug(f"{k}: min states = {min_states}, max states = {max_states}") + assert min_states <= max_states + @property def speaker_mapping(self) -> typing.Dict[str, int]: """Mapping of speakers to dictionaries""" @@ -349,6 +403,7 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]: ) if default: self._default_dictionary_id = dictionary_id + word_primary_key = 1 pronunciation_primary_key = 1 word_objs = [] @@ -357,6 +412,7 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]: phone_counts = collections.Counter() graphemes = set(self.clitic_markers + self.compound_markers) clitic_cleanup_regex = None + has_nonnative_speakers = False if len(self.clitic_markers) >= 1: other_clitic_markers = self.clitic_markers[1:] if other_clitic_markers: @@ -370,10 +426,54 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]: dictionary_model, speakers, ) in self.dictionary_model.load_dictionary_paths().values(): + if dictionary_model == "nonnative": + if not has_nonnative_speakers: + dialect = "nonnative" + if dialect not in dialect_id_cache: + dialect_obj = Dialect(name=dialect) + session.add(dialect_obj) + session.flush() + dialect_id_cache[dialect] = dialect_obj.id + dictionary = Dictionary( + name=dialect, + dialect_id=dialect_id_cache[dialect], + path="", + phone_set_type=self.phone_set_type, + root_temp_directory=self.dictionary_output_directory, + position_dependent_phones=self.position_dependent_phones, + clitic_marker=self.clitic_marker + if self.clitic_marker is not None + else "", + default=False, + use_g2p=False, + max_disambiguation_symbol=0, + silence_word=self.silence_word, + oov_word=self.oov_word, + bracketed_word=self.bracketed_word, + cutoff_word=self.cutoff_word, + laughter_word=self.laughter_word, + optional_silence_phone=self.optional_silence_phone, + oov_phone=self.oov_phone, + ) + session.add(dictionary) + 
session.flush() + dictionary_id_cache[dialect] = dictionary.id + for speaker in speakers: + if speaker not in self._speaker_ids: + speaker_objs.append( + { + "id": self._current_speaker_index, + "name": speaker, + "dictionary_id": dictionary_id_cache[dialect], + } + ) + self._speaker_ids[speaker] = self._current_speaker_index + self._current_speaker_index += 1 + has_nonnative_speakers = True + continue if dictionary_model.path not in dictionary_id_cache and not self.use_g2p: word_cache = {} pronunciation_cache = set() - subsequences = set() if self.phone_set_type not in auto_set: if ( self.phone_set_type != dictionary_model.phone_set_type @@ -543,10 +643,6 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]: pronunciation_primary_key += 1 pronunciation_cache.add((word, pron_string)) phone_counts.update(pron) - pron = pron[:-1] - while pron: - subsequences.add(tuple(pron)) - pron = pron[:-1] for w, wt in special_words.items(): if w in specials_found: @@ -650,13 +746,219 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]: conn.execute(sqlalchemy.insert(Grapheme.__table__), grapheme_objs) session.commit() + self.create_default_dictionary() + if has_nonnative_speakers: + self.create_nonnative_dictionary() if pron_objs: self.calculate_disambiguation() self.load_phonological_rules() self.calculate_phone_mapping() + self.load_phone_topologies() self.load_phone_groups() return graphemes, phone_counts + def create_default_dictionary(self): + if self._default_dictionary_id is not None: # Already a default dictionary + return + with self.session() as session: + dictionary = session.query(Dictionary).filter_by(default=True).first() + if dictionary is not None: # Already a default dictionary + self._default_dictionary_id = dictionary.id + return + word_objs = [] + pron_objs = [] + dialect = session.query(Dialect).filter_by(name="unknown").first() + if dialect is None: + dialect = Dialect(name="unknown") + session.add(dialect) + session.flush() + dictionary = Dictionary( + name="default", + dialect_id=dialect.id, + path="", + phone_set_type=self.phone_set_type, + root_temp_directory=self.dictionary_output_directory, + position_dependent_phones=self.position_dependent_phones, + clitic_marker=self.clitic_marker if self.clitic_marker is not None else "", + default=True, + use_g2p=False, + max_disambiguation_symbol=0, + silence_word=self.silence_word, + oov_word=self.oov_word, + bracketed_word=self.bracketed_word, + cutoff_word=self.cutoff_word, + laughter_word=self.laughter_word, + optional_silence_phone=self.optional_silence_phone, + oov_phone=self.oov_phone, + ) + session.add(dictionary) + session.commit() + + special_words = { + self.silence_word: WordType.silence, + self.oov_word: WordType.oov, + self.bracketed_word: WordType.bracketed, + self.cutoff_word: WordType.cutoff, + self.laughter_word: WordType.laughter, + } + self._default_dictionary_id = dictionary.id + self.dictionary_lookup[dictionary.name] = dictionary.id + self._words_mappings[dictionary.id] = {} + word_primary_key = self.get_next_primary_key(Word) + pronunciation_primary_key = self.get_next_primary_key(Pronunciation) + current_index = 0 + word_cache = {} + for w, w_type in special_words.items(): + if w_type is WordType.silence: + pron = self.optional_silence_phone + else: + pron = self.oov_phone + word_objs.append( + { + "id": word_primary_key, + "mapping_id": current_index, + "word": w, + "word_type": w_type, + "dictionary_id": dictionary.id, + } + ) + 
self._words_mappings[dictionary.id][w] = current_index
+            current_index += 1
+
+            pron_objs.append(
+                {
+                    "id": pronunciation_primary_key,
+                    "pronunciation": pron,
+                    "probability": 1.0,
+                    "disambiguation": None,
+                    "silence_after_probability": None,
+                    "silence_before_correction": None,
+                    "non_silence_before_correction": None,
+                    "word_id": word_primary_key,
+                }
+            )
+            word_primary_key += 1
+            pronunciation_primary_key += 1
+
+            query = (
+                session.query(
+                    Word.word,
+                    Word.word_type,
+                    Pronunciation.pronunciation,
+                )
+                .join(Pronunciation.word)
+                .filter(~Word.word.in_(special_words.keys()))
+                .distinct()
+                .order_by(Word.word, Pronunciation.pronunciation)
+            )
+            for word, word_type, pronunciation in query:
+                if word not in word_cache:
+                    word_objs.append(
+                        {
+                            "id": word_primary_key,
+                            "mapping_id": current_index,
+                            "word": word,
+                            "word_type": word_type,
+                            "dictionary_id": dictionary.id,
+                        }
+                    )
+                    self._words_mappings[dictionary.id][word] = current_index
+                    current_index += 1
+                    word_cache[word] = word_primary_key
+                    word_primary_key += 1
+                pron_objs.append(
+                    {
+                        "id": pronunciation_primary_key,
+                        "pronunciation": pronunciation,
+                        "probability": None,
+                        "disambiguation": None,
+                        "silence_after_probability": None,
+                        "silence_before_correction": None,
+                        "non_silence_before_correction": None,
+                        "word_id": word_cache[word],
+                    }
+                )
+                pronunciation_primary_key += 1
+            for s in ["#0", "<s>", "</s>"]:
+                word_objs.append(
+                    {
+                        "id": word_primary_key,
+                        "word": s,
+                        "dictionary_id": dictionary.id,
+                        "mapping_id": current_index,
+                        "word_type": WordType.disambiguation,
+                    }
+                )
+                self._words_mappings[dictionary.id][s] = current_index
+                word_primary_key += 1
+                current_index += 1
+        with self.session() as session:
+            with session.bind.begin() as conn:
+                if word_objs:
+                    conn.execute(sqlalchemy.insert(Word.__table__), word_objs)
+                if pron_objs:
+                    conn.execute(sqlalchemy.insert(Pronunciation.__table__), pron_objs)
+            session.commit()
+
+    def create_nonnative_dictionary(self):
+        with self.session() as session:
+            dictionary = session.query(Dictionary).filter_by(name="nonnative").first()
+            word_objs = []
+            pron_objs = []
+
+            self.dictionary_lookup[dictionary.name] = dictionary.id
+            self._words_mappings[dictionary.id] = {}
+            word_primary_key = self.get_next_primary_key(Word)
+            pronunciation_primary_key = self.get_next_primary_key(Pronunciation)
+            word_cache = {}
+            query = (
+                session.query(
+                    Word.word,
+                    Word.mapping_id,
+                    Word.word_type,
+                    Pronunciation.pronunciation,
+                )
+                .join(Word.pronunciations, isouter=True)
+                .filter(Word.dictionary_id == self._default_dictionary_id)
+                .distinct()
+                .order_by(Word.mapping_id)
+            )
+            for word, mapping_id, word_type, pronunciation in query:
+                if word not in word_cache:
+                    word_objs.append(
+                        {
+                            "id": word_primary_key,
+                            "mapping_id": mapping_id,
+                            "word": word,
+                            "word_type": word_type,
+                            "dictionary_id": dictionary.id,
+                        }
+                    )
+                    self._words_mappings[dictionary.id][word] = mapping_id
+                    word_cache[word] = word_primary_key
+                    word_primary_key += 1
+                if pronunciation is not None:
+                    pron_objs.append(
+                        {
+                            "id": pronunciation_primary_key,
+                            "pronunciation": pronunciation,
+                            "probability": None,
+                            "disambiguation": None,
+                            "silence_after_probability": None,
+                            "silence_before_correction": None,
+                            "non_silence_before_correction": None,
+                            "word_id": word_cache[word],
+                        }
+                    )
+                    pronunciation_primary_key += 1
+        with self.session() as session:
+            with session.bind.begin() as conn:
+                if word_objs:
+                    conn.execute(sqlalchemy.insert(Word.__table__), word_objs)
+                if pron_objs:
conn.execute(sqlalchemy.insert(Pronunciation.__table__), pron_objs)
+            session.commit()
+
     def calculate_disambiguation(self) -> None:
         """Calculate the number of disambiguation symbols necessary for the dictionary"""
         with self.session() as session:
diff --git a/montreal_forced_aligner/exceptions.py b/montreal_forced_aligner/exceptions.py
index aab12922..f82dea66 100644
--- a/montreal_forced_aligner/exceptions.py
+++ b/montreal_forced_aligner/exceptions.py
@@ -306,6 +306,34 @@ def __init__(self, missing_phones: Collection[str]):
         self.message_lines.append(comma_join(missing_phones))
 
 
+class PhoneGroupTopologyMismatchError(DictionaryError):
+    """
+    Exception class for when phones in the same phone group have mismatched topologies
+
+    Parameters
+    ----------
+    error_topologies: List[Tuple[List[str], List[Tuple[int, int]]]]
+        Phone groups paired with the conflicting topologies found for their phones
+    """
+
+    def __init__(
+        self,
+        error_topologies: typing.List[
+            typing.Tuple[typing.List[str], typing.List[typing.Tuple[int, int]]]
+        ],
+        phone_groups_path: typing.Union[Path, str],
+        topologies_path: typing.Union[Path, str],
+    ):
+        super().__init__("There were multiple topologies found for phones in the same group: ")
+        for k, v in error_topologies:
+            v = [f"(min_states = {x[0]}, max_states = {x[1]})" for x in v]
+            self.message_lines.append(f"{comma_join(k)}: {comma_join(v)}")
+        self.message_lines.append(
+            f"Please update {phone_groups_path} or {topologies_path} to "
+            f"ensure that phone groups have a single set of minimum and maximum states"
+        )
+
+
 class NoDefaultSpeakerDictionaryError(DictionaryError):
     """
     Exception class for errors in creating MultispeakerDictionary objects
diff --git a/montreal_forced_aligner/models.py b/montreal_forced_aligner/models.py
index 124d3659..087a7bf9 100644
--- a/montreal_forced_aligner/models.py
+++ b/montreal_forced_aligner/models.py
@@ -647,7 +647,7 @@ def meta(self) -> MetaDict:
             self._meta["other_noise_phone"] = "sp"
         if "phone_set_type" not in self._meta:
             self._meta["phone_set_type"] = "UNKNOWN"
-        if "language" not in self._meta:
+        if "language" not in self._meta or self._meta["version"] <= "3.0":
             self._meta["language"] = "unknown"
         self._meta["phones"] = set(self._meta.get("phones", []))
         if (
@@ -1589,7 +1589,9 @@ def load_dictionary_paths(self) -> Dict[str, Tuple[DictionaryModel, typing.Set[s
                 data = yaml.load(f, Loader=yaml.Loader)
                 for speaker, path in data.items():
                     if path not in mapping:
-                        mapping[path] = (DictionaryModel(path), set())
+                        if path != "nonnative":
+                            path = DictionaryModel(path)
+                        mapping[path] = (path, set())
                     mapping[path][1].add(speaker)
         else:
             mapping[str(self.path)] = (self, {"default"})
diff --git a/montreal_forced_aligner/textgrid.py b/montreal_forced_aligner/textgrid.py
index 140d63f6..f804c5bb 100644
--- a/montreal_forced_aligner/textgrid.py
+++ b/montreal_forced_aligner/textgrid.py
@@ -13,6 +13,7 @@
 from pathlib import Path
 from typing import Dict, List
 
+import sqlalchemy
 from praatio import textgrid as tgio
 from praatio.data_classes.interval_tier import Interval
 from praatio.utilities import utils as tgio_utils
@@ -40,7 +41,7 @@
 __all__ = [
     "process_ctm_line",
     "export_textgrid",
-    "construct_output_tiers",
+    "construct_textgrid_output",
     "construct_output_path",
     "output_textgrid_writing_errors",
 ]
@@ -275,87 +276,150 @@ def parse_aligned_textgrid(
     return data
 
 
-def construct_output_tiers(
+def construct_textgrid_output(
     session: Session,
-    file_id: int,
+    file_batch: typing.Dict[int, typing.Tuple],
     workflow: CorpusWorkflow,
     cleanup_textgrids: bool,
     clitic_marker: str,
diff --git a/montreal_forced_aligner/models.py b/montreal_forced_aligner/models.py
index 124d3659..087a7bf9 100644
--- a/montreal_forced_aligner/models.py
+++ b/montreal_forced_aligner/models.py
@@ -647,7 +647,7 @@ def meta(self) -> MetaDict:
             self._meta["other_noise_phone"] = "sp"
         if "phone_set_type" not in self._meta:
             self._meta["phone_set_type"] = "UNKNOWN"
-        if "language" not in self._meta:
+        if "language" not in self._meta or self._meta["version"] <= "3.0":
             self._meta["language"] = "unknown"
         self._meta["phones"] = set(self._meta.get("phones", []))
         if (
@@ -1589,7 +1589,9 @@ def load_dictionary_paths(self) -> Dict[str, Tuple[DictionaryModel, typing.Set[s
                 data = yaml.load(f, Loader=yaml.Loader)
                 for speaker, path in data.items():
                     if path not in mapping:
-                        mapping[path] = (DictionaryModel(path), set())
+                        if path != "nonnative":
+                            path = DictionaryModel(path)
+                        mapping[path] = (path, set())
                     mapping[path][1].add(speaker)
             else:
                 mapping[str(self.path)] = (self, {"default"})
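With this change to load_dictionary_paths, the literal value "nonnative" in a speaker-to-dictionary mapping is passed through as a name instead of being wrapped in a DictionaryModel. A minimal sketch of parsing such a mapping; the file contents are an invented example, not taken from the MFA docs:

    import yaml

    # Assumed speaker dictionary YAML: most speakers point at a dictionary
    # file on disk, while L2 speakers are routed to the generated
    # "nonnative" dictionary described in the 3.1.0 changelog.
    speaker_dict_yaml = """
    default: /data/dictionaries/english_us_mfa.dict
    speaker_l2_01: nonnative
    speaker_l2_02: nonnative
    """

    mapping = {}
    for speaker, path in yaml.safe_load(speaker_dict_yaml).items():
        if path not in mapping:
            # In MFA, paths other than "nonnative" become DictionaryModel
            # objects at this point; strings stand in for them here.
            mapping[path] = (path, set())
        mapping[path][1].add(speaker)
    print(mapping)
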
data[speaker_name]["phones"].append(CtmInterval(p_begin, p_end, phone)) - return data + pi_current_file_id = None + wi_current_file_id = None + u_current_file_id = None + word_data = [] + phone_data = [] + utterance_data = [] + + def process_phone_data(): + for beg, end, p, speaker_name in phone_data: + if speaker_name not in data: + data[speaker_name] = {"words": [], "phones": []} + if include_original_text: + data[speaker_name]["utterances"] = [] + data[speaker_name]["phones"].append(CtmInterval(beg, end, p)) + + def process_word_data(): + for beg, end, w, speaker_name in word_data: + if ( + cleanup_textgrids + and data[speaker_name]["words"] + and beg - data[speaker_name]["words"][-1].end < 0.02 + and clitic_marker + and ( + data[speaker_name]["words"][-1].label.endswith(clitic_marker) + or w.startswith(clitic_marker) + ) + ): + data[speaker_name]["words"][-1].end = end + data[speaker_name]["words"][-1].label += w + else: + data[speaker_name]["words"].append(CtmInterval(beg, end, w)) + + def process_utterance_data(): + for beg, end, u, speaker_name in utterance_data: + data[speaker_name]["utterances"].append(CtmInterval(beg, end, u)) + + while True: + data = {} + for pi_begin, pi_end, phone, pi_speaker_name, pi_file_id in phone_intervals: + if pi_current_file_id is None: + pi_current_file_id = pi_file_id + if pi_file_id != pi_current_file_id: + process_phone_data() + phone_data = [(pi_begin, pi_end, phone, pi_speaker_name)] + current_file_id = pi_current_file_id + pi_current_file_id = pi_file_id + break + phone_data.append((pi_begin, pi_end, phone, pi_speaker_name)) + else: + if phone_data: + process_phone_data() + current_file_id = pi_current_file_id + else: + break + for wi_begin, wi_end, word, wi_speaker_name, wi_file_id in word_intervals: + if wi_current_file_id is None: + wi_current_file_id = wi_file_id + if wi_file_id != wi_current_file_id: + process_word_data() + word_data = [(wi_begin, wi_end, word, wi_speaker_name)] + wi_current_file_id = wi_file_id + break + word_data.append((wi_begin, wi_end, word, wi_speaker_name)) + else: + if word_data: + process_word_data() + if include_original_text: + for u_begin, u_end, text, u_speaker_name, u_file_id in utterances: + if u_current_file_id is None: + u_current_file_id = u_file_id + if u_file_id != u_current_file_id: + process_utterance_data() + utterance_data = [(u_begin, u_end, text, u_speaker_name)] + u_current_file_id = u_file_id + break + utterance_data.append((u_begin, u_end, text, u_speaker_name)) + else: + if utterance_data: + process_utterance_data() + + file_name, relative_path, file_duration, text_file_path = file_batch[current_file_id] + output_path = construct_output_path( + file_name, relative_path, output_directory, text_file_path, output_format + ) + export_textgrid(data, output_path, file_duration, frame_shift, output_format) + yield output_path + word_data = [] + phone_data = [] + utterance_data = [] def construct_output_path( diff --git a/montreal_forced_aligner/tokenization/japanese.py b/montreal_forced_aligner/tokenization/japanese.py index beae7db2..551e9393 100644 --- a/montreal_forced_aligner/tokenization/japanese.py +++ b/montreal_forced_aligner/tokenization/japanese.py @@ -16,7 +16,7 @@ class JapaneseTokenizer: def __init__(self, ignore_case: bool = True): self.ignore_case = ignore_case resource_dir = pathlib.Path(__file__).parent.joinpath("resources") - config_path = resource_dir.joinpath("japanese", "sudachi_config.json") + config_path = str(resource_dir.joinpath("japanese", "sudachi_config.json")) try: 
            self.tokenizer = sudachipy.Dictionary(dict="full", config_path=config_path).create(
                mode=sudachipy.SplitMode.B
diff --git a/montreal_forced_aligner/tokenization/simple.py b/montreal_forced_aligner/tokenization/simple.py
index a7bd355d..74abd9a6 100644
--- a/montreal_forced_aligner/tokenization/simple.py
+++ b/montreal_forced_aligner/tokenization/simple.py
@@ -129,6 +129,7 @@ def __init__(
         initial_clitic_regex: typing.Optional[re.Pattern],
         final_clitic_regex: typing.Optional[re.Pattern],
         compound_regex: typing.Optional[re.Pattern],
+        cutoff_regex: typing.Optional[re.Pattern],
         non_speech_regexes: typing.Dict[str, re.Pattern],
         oov_word: typing.Optional[str] = None,
         grapheme_set: typing.Optional[typing.Collection[str]] = None,
@@ -136,6 +137,7 @@ def __init__(
         self.word_table = word_table
         self.clitic_marker = clitic_marker
         self.compound_regex = compound_regex
+        self.cutoff_regex = cutoff_regex
         self.oov_word = oov_word
         self.specials_set = {self.oov_word, "<s>", "</s>"}
         if not grapheme_set:
@@ -172,6 +174,8 @@ def to_str(self, normalized_text: str) -> str:
         if self.word_table and self.word_table.member(normalized_text):
             return normalized_text
+        if self.cutoff_regex.match(normalized_text):
+            return normalized_text
         for word, regex in self.non_speech_regexes.items():
             if regex.match(normalized_text):
                 return word
         return normalized_text
@@ -195,7 +199,8 @@ def split_clitics(
         """
         split = []
         if self.compound_regex is not None:
-            s = self.compound_regex.split(item)
+            s = [x for x in self.compound_regex.split(item) if x]
+
         else:
             s = [item]
         if self.word_table is None:
@@ -292,6 +297,8 @@ def __call__(
         """
         if self.word_table and self.word_table.member(item):
             return [item]
+        if self.cutoff_regex.match(item):
+            return [item]
         for regex in self.non_speech_regexes.values():
             if regex.match(item):
                 return [item]
@@ -325,7 +332,10 @@ def __init__(
         self.laughter_word = laughter_word
         self.oov_word = oov_word
         self.bracketed_word = bracketed_word
-        self.cutoff_word = cutoff_word
+
+        initial_brackets = re.escape("".join(x[0] for x in self.brackets))
+        final_brackets = re.escape("".join(x[1] for x in self.brackets))
+        self.cutoff_identifier = re.sub(rf"[{initial_brackets}{final_brackets}]", "", cutoff_word)
         self.ignore_case = ignore_case
         self.use_g2p = use_g2p
         self.clitic_set = set()
@@ -372,6 +382,7 @@ def __init__(
             self.initial_clitic_regex,
             self.final_clitic_regex,
             self.compound_regex,
+            self.cutoff_regex,
             self.non_speech_regexes,
             self.oov_word,
             self.grapheme_set,
@@ -396,15 +407,19 @@ def _compile_regexes(self) -> None:
             if "-" in self.compound_markers:
                 extra = "-"
                 compound_markers = [x for x in compound_markers if x != "-"]
-            self.compound_regex = re.compile(rf"(?<=\w)[{extra}{''.join(compound_markers)}](?=\w)")
+            self.compound_regex = re.compile(
+                rf"(?<=\w)[{extra}{''.join(compound_markers)}](?:$|(?=\w))"
+            )
         if self.brackets:
-            left_brackets = [x[0] for x in self.brackets]
-            right_brackets = [x[1] for x in self.brackets]
-            self.bracket_regex = re.compile(
-                rf"[{re.escape(''.join(left_brackets))}].*?[{re.escape(''.join(right_brackets))}]+"
+            left_brackets = re.escape("".join(x[0] for x in self.brackets))
+            right_brackets = re.escape("".join(x[1] for x in self.brackets))
+            self.cutoff_regex = re.compile(
+                rf"^[{left_brackets}]({self.cutoff_identifier}|hes(itation)?)([-_](?P<word>[^{right_brackets}]+))?[{right_brackets}]$",
+                flags=re.IGNORECASE,
            )
+            self.bracket_regex = re.compile(rf"[{left_brackets}].*?[{right_brackets}]+")
             self.laughter_regex = re.compile(
-                rf"[{re.escape(''.join(left_brackets))}](laugh(ing|ter)?|lachen|lg)[{re.escape(''.join(right_brackets))}]+",
+                rf"[{left_brackets}](laugh(ing|ter)?|lachen|lg)[{right_brackets}]+",
                 flags=re.IGNORECASE,
             )
         all_punctuation = set()
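A runnable sketch of the cutoff pattern compiled above, assuming the bracket pairs ("<", ">") and ("[", "]") and the default cutoff word "<cutoff>" (so cutoff_identifier is "cutoff"); the named group captures the interrupted word in forms like "<cutoff-word>":

    import re

    left_brackets = re.escape("<[")
    right_brackets = re.escape(">]")
    cutoff_identifier = "cutoff"
    cutoff_regex = re.compile(
        rf"^[{left_brackets}]({cutoff_identifier}|hes(itation)?)([-_](?P<word>[^{right_brackets}]+))?[{right_brackets}]$",
        flags=re.IGNORECASE,
    )

    for token in ["<cutoff>", "<cutoff-really>", "[hes]", "(cutoff)", "word"]:
        match = cutoff_regex.match(token)
        print(token, bool(match), match and match.group("word"))
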
rf"[{re.escape(''.join(left_brackets))}](laugh(ing|ter)?|lachen|lg)[{re.escape(''.join(right_brackets))}]+", + rf"[{left_brackets}](laugh(ing|ter)?|lachen|lg)[{right_brackets}]+", flags=re.IGNORECASE, ) all_punctuation = set() diff --git a/tests/conftest.py b/tests/conftest.py index e1d0d53c..9e4a62f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -996,11 +996,11 @@ def multispeaker_dictionary_config_path(generated_dir, basic_dict_path, english_ @pytest.fixture(scope="session") -def mfa_speaker_dict_path(generated_dir, english_uk_mfa_dictionary, english_us_mfa_dictionary): +def mfa_speaker_dict_path(generated_dir, english_uk_mfa_dictionary, english_us_mfa_reduced_dict): path = generated_dir.joinpath("test_multispeaker_mfa_dictionary.yaml") with mfa_open(path, "w") as f: yaml.dump( - {"default": english_us_mfa_dictionary, "speaker": english_uk_mfa_dictionary}, + {"default": english_us_mfa_reduced_dict, "speaker": english_us_mfa_reduced_dict}, f, Dumper=yaml.Dumper, allow_unicode=True, @@ -1008,6 +1008,30 @@ def mfa_speaker_dict_path(generated_dir, english_uk_mfa_dictionary, english_us_m return path +@pytest.fixture(scope="session") +def english_mfa_phone_groups_path(config_directory): + path = config_directory.joinpath("acoustic", "english_mfa_phone_groups.yaml") + return path + + +@pytest.fixture(scope="session") +def english_mfa_rules_path(config_directory): + path = config_directory.joinpath("acoustic", "english_mfa_rules.yaml") + return path + + +@pytest.fixture(scope="session") +def english_mfa_topology_path(config_directory): + path = config_directory.joinpath("acoustic", "english_mfa_topology.yaml") + return path + + +@pytest.fixture(scope="session") +def bad_topology_path(config_directory): + path = config_directory.joinpath("acoustic", "bad_topology.yaml") + return path + + @pytest.fixture(scope="session") def test_align_config(): return {"beam": 100, "retry_beam": 400} diff --git a/tests/data/configs/acoustic/bad_topology.yaml b/tests/data/configs/acoustic/bad_topology.yaml new file mode 100644 index 00000000..048869a9 --- /dev/null +++ b/tests/data/configs/acoustic/bad_topology.yaml @@ -0,0 +1,4 @@ +b: + max_states: 3 +bʲ: + max_states: 1 diff --git a/tests/data/configs/acoustic/english_mfa_phone_groups.yaml b/tests/data/configs/acoustic/english_mfa_phone_groups.yaml new file mode 100644 index 00000000..52a5e88b --- /dev/null +++ b/tests/data/configs/acoustic/english_mfa_phone_groups.yaml @@ -0,0 +1,170 @@ +- + - p + - pʷ + - pʰ + - pʲ +- + - kp +- + - b + - bʲ +- + - ɡb +- + - f + - fʷ + - fʲ +- + - v + - vʷ + - vʲ +- + - θ +- + - t̪ +- + - ð +- + - d̪ +- + - t + - tʷ + - tʰ + - tʲ +- + - ʈ + - ʈʲ + - ʈʷ +- + - ʔ +- + - d + - dʲ +- + - ɖ + - ɖʲ +- + - ɾ + - ɾʲ +- + - tʃ +- + - dʒ +- + - ʃ +- + - ʒ +- + - s +- + - z +- + - ɹ +- + - m +- + - mʲ +- + - m̩ +- + - ɱ +- + - n +- + - n̩ +- + - ɲ +- + - ɾ̃ +- + - ŋ +- + - l +- + - ɫ +- + - ɫ̩ +- + - ʎ +- + - ɟ + - ɟʷ +- + - ɡ + - ɡʷ +- + - c + - cʷ + - cʰ +- + - k + - kʷ + - kʰ +- + - ç +- + - h +- + - ɐ +- + - ə +- + - ɜː + - ɜ +- + - ɝ +- + - ɚ +- + - ʊ +- + - ɪ +- + - ɑ + - ɑː +- + - ɒ + - ɒː +- + - ɔ +- + - aː + - a +- + - æ +- + - aj +- + - aw +- + - i + - iː +- + - j +- + - ɛː + - ɛ +- + - e + - eː +- + - ej +- + - ʉ + - ʉː +- + - uː + - u +- + - w +- + - ʋ +- + - ɔj +- + - ow +- + - əw +- + - o + - oː diff --git a/tests/data/configs/acoustic/english_mfa_rules.yaml b/tests/data/configs/acoustic/english_mfa_rules.yaml new file mode 100644 index 00000000..9b8379f0 --- /dev/null +++ 
diff --git a/tests/data/configs/acoustic/english_mfa_rules.yaml b/tests/data/configs/acoustic/english_mfa_rules.yaml
new file mode 100644
index 00000000..9b8379f0
--- /dev/null
+++ b/tests/data/configs/acoustic/english_mfa_rules.yaml
@@ -0,0 +1,383 @@
+dialects:
+  nonnative:
+  - following_context: ''
+    preceding_context: ''
+    replacement: z
+    segment: 'ð'
+  - following_context: ''
+    preceding_context: ''
+    replacement: s
+    segment: 'θ'
+  nigeria:
+  - following_context: $
+    preceding_context: ''
+    replacement: p s
+    segment: 'b z'
+  - following_context: $
+    preceding_context: ''
+    replacement: t s
+    segment: 'd z'
+  - following_context: $
+    preceding_context: ''
+    replacement: k s
+    segment: 'ɡ z'
+  - following_context: $
+    preceding_context: ''
+    replacement: s
+    segment: 'z'
+  - following_context: ''
+    preceding_context: '^'
+    replacement: ''
+    segment: 'ç'
+  - following_context: ''
+    preceding_context: '^'
+    replacement: ''
+    segment: 'h'
+  - following_context: '$'
+    preceding_context: 'ŋ'
+    replacement: ''
+    segment: 'ɡ'
+  uk:
+  - following_context: ''
+    preceding_context: (ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+    replacement: ʔ
+    segment: 't[ʲʷ]?'
+  - following_context: ʉː?
+    preceding_context: ''
+    replacement: tʃ
+    segment: 'tʲ'
+  - following_context: ʉː?
+    preceding_context: ''
+    replacement: dʒ
+    segment: 'dʲ'
+  - following_context: ''
+    preceding_context: '^'
+    replacement: ''
+    segment: 'ç'
+  - following_context: '' # mitten glottalized syllabic
+    preceding_context: ''
+    replacement: ʔ n̩
+    segment: t ə n
+  - following_context: '[^ʊɔɝaɔɛɜeuoæɐɪəɚɑʉɒi].*' # syllabic n
+    preceding_context: ''
+    replacement: n̩
+    segment: ə n
+  - following_context: $ # syllabic n
+    preceding_context: ''
+    replacement: n̩
+    segment: ə n
+  - following_context: $ # syllabic m
+    preceding_context: ''
+    replacement: m̩
+    segment: ə m
+  - following_context: '[^ʊɔɝaɔɛɜeuoæɐɪəɚɑʉɒi].*' # syllabic m
+    preceding_context: ''
+    replacement: m̩
+    segment: ə m
+  - following_context: $ # syllabic l
+    preceding_context: ''
+    replacement: ɫ̩
+    segment: ə ɫ
+  - following_context: '[^ʊɔɝaɔɛɜeuoæɐɪəɚɑʉɒi].*' # syllabic l
+    preceding_context: ''
+    replacement: ɫ̩
+    segment: ə ɫ
+  india:
+  - following_context: ''
+    preceding_context: ''
+    replacement: ɾ
+    segment: 'ɹ'
+  - following_context: ''
+    preceding_context: ''
+    replacement: a
+    segment: 'ɒ'
+  - following_context: ''
+    preceding_context: ''
+    replacement: a
+    segment: 'ɑ'
+  - following_context: ''
+    preceding_context: ''
+    replacement: aː
+    segment: 'ɒː'
+  - following_context: ''
+    preceding_context: ''
+    replacement: aː
+    segment: 'ɑː'
+  - following_context: ''
+    preceding_context: ''
+    replacement: f
+    segment: 'pʰ'
+  - following_context: ''
+    preceding_context: ''
+    replacement: dʒ
+    segment: 'z'
+  - following_context: ''
+    preceding_context: ''
+    replacement: dʒ
+    segment: 'ʒ'
+  - following_context: ''
+    preceding_context: ''
+    replacement: z
+    segment: 'ʒ'
+  - following_context: ''
+    preceding_context: ''
+    replacement: ʃ
+    segment: 'ʒ'
+  us:
+  - following_context: (ɪ|ə|ɚ|i) # t/d flapping
+    preceding_context: (ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+    replacement: ɾ
+    segment: '[td]'
+  - following_context: (ɪ|ə|ɚ|i) # t/d flapping
+    preceding_context: ɹ
+    replacement: ɾ
+    segment: '[td]'
+  - following_context: (ɪ|ə|ɚ|i) # t/d flapping
+    preceding_context: ɫ
+    replacement: ɾ
+    segment: '[td]'
+  - following_context: (ɪ|ə|ɚ|i) # t/d flapping
+    preceding_context: (ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+    replacement: ɾʲ
+    segment: '[td]ʲ'
+  - following_context: (ɪ|ə|ɚ|i) # t/d flapping
+    preceding_context: ɹ
+    replacement: ɾʲ
+    segment: '[td]ʲ'
+  - following_context: (ɪ|ə|ɚ|i) # t/d flapping
+    preceding_context: '(ɫ|ɫ̩)'
+    replacement: ɾʲ
+    segment: '[td]ʲ'
+  - following_context: (ɪ|ə|ɚ|i) # nasal flapping
+    preceding_context: (ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+    replacement: ɾ̃
+    segment: (ɲ|n)
+  - following_context: (ɪ|ə|ɚ|i) # nasal flapping
+    preceding_context: (ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+    replacement: ɾ̃
+    segment: (ɲ|n) [td][ʲʷ]?
+  - following_context: $ # caught-cot merger
+    preceding_context: ''
+    replacement: ɑː
+    segment: ɒː
+  - following_context: '[^ɹ]' # caught-cot merger
+    preceding_context: ''
+    replacement: ɑː
+    segment: ɒː
+  - following_context: $ # caught-cot merger
+    preceding_context: ''
+    replacement: ɑ
+    segment: ɒ
+  - following_context: '[^ɹ]' # caught-cot merger
+    preceding_context: ''
+    replacement: ɑ
+    segment: ɒ
+  - following_context: $ # t/d flapping
+    preceding_context: (ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+    replacement: ɾ
+    segment: d
+  - following_context: '' # mitten glottalized syllabic
+    preceding_context: ''
+    replacement: ʔ n̩
+    segment: t ə n
+  - following_context: '[^ʊɔɝaɔɛɜeuoæɐɪəɚɑʉɒi].*' # syllabic n
+    preceding_context: ''
+    replacement: n̩
+    segment: ə n
+  - following_context: $ # syllabic n
+    preceding_context: ''
+    replacement: n̩
+    segment: ə n
+  - following_context: $ # syllabic m
+    preceding_context: ''
+    replacement: m̩
+    segment: ə m
+  - following_context: '[^ʊɔɝaɔɛɜeuoæɐɪəɚɑʉɒi].*' # syllabic m
+    preceding_context: ''
+    replacement: m̩
+    segment: ə m
+  - following_context: $ # syllabic l
+    preceding_context: ''
+    replacement: ɫ̩
+    segment: ə ɫ
+  - following_context: '[^ʊɔɝaɔɛɜeuoæɐɪəɚɑʉɒi].*' # syllabic l
+    preceding_context: ''
+    replacement: ɫ̩
+    segment: ə ɫ
+rules:
+- following_context: .*(ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+  preceding_context: ^
+  replacement: ''
+  segment: ə
+- following_context: $
+  preceding_context: '[sʃn]'
+  replacement: ''
+  segment: '[tʈ]'
+- following_context: $
+  preceding_context: '[zʒn]'
+  replacement: ''
+  segment: '[dɖ]'
+- following_context: ''
+  preceding_context: 'n'
+  replacement: ''
+  segment: '[dɖ]'
+- following_context: ə|ɚ
+  preceding_context: ''
+  replacement: j
+  segment: i
+- following_context: ə|ɚ
+  preceding_context: ''
+  replacement: w
+  segment: '[ʉu]'
+- following_context: '$'
+  preceding_context: ''
+  replacement: t̪
+  segment: '[tʈ] θ'
+- following_context: ''
+  preceding_context: ''
+  replacement: t̪
+  segment: 'θ [tʈ]'
+- following_context: '$'
+  preceding_context: ''
+  replacement: d̪
+  segment: 'ð [ɖd]'
+- following_context: $ # t deletion
+  preceding_context: (ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+  replacement: ''
+  segment: '[tʈ]'
+- following_context: ''
+  preceding_context: ''
+  replacement: d̪
+  segment: '[ɖd] ð'
+- following_context: ''
+  preceding_context: ''
+  replacement: d̪
+  segment: ð
+- following_context: ''
+  preceding_context: ''
+  replacement: t̪
+  segment: θ
+- following_context: '[tʈpkcsʃf][ʲʷ]?'
+  preceding_context: ''
+  replacement: s
+  segment: z
+- following_context: '[tʈpkcsʃf][ʲʷ]?'
+  preceding_context: ''
+  replacement: t
+  segment: d
+- following_context: '[tʈpkcsʃf][ʲʷ]?'
+  preceding_context: ''
+  replacement: p
+  segment: b
+- following_context: '[tpkcsʃf][ʲʷ]?'
+  preceding_context: ''
+  replacement: k
+  segment: ɡ
+- following_context: '[tʈpkcsʃf][ʲʷ]?'
+  preceding_context: ''
+  replacement: tʃ
+  segment: dʒ
+- following_context: '[vʋf]ʲ?'
+  preceding_context: ''
+  replacement: ɱ
+  segment: m
+- following_context: '[vʋf]ʲ?'
+  preceding_context: ''
+  replacement: ɱ
+  segment: 'n'
+- following_context: '[tʈdɖ]$'
+  preceding_context: m
+  replacement: ''
+  segment: '[pb]'
+- following_context: '[sz]$'
+  preceding_context: 'n'
+  replacement: ''
+  segment: '[tʈdɖ]'
+- following_context: '[s]$'
+  preceding_context: ''
+  replacement: k
+  segment: '[sʃ] k'
+- following_context: '[s]$'
+  preceding_context: ''
+  replacement: ''
+  segment: '[sʃ] t'
+- following_context: $ # ask metathesis
+  preceding_context: ''
+  replacement: k s
+  segment: s k
+- following_context: '[dɖʈtcɟɡk][ʲʷ]?'
+  preceding_context: ''
+  replacement: ''
+  segment: p
+- following_context: '[pbcɟɡk][ʲʷ]?'
+  preceding_context: ''
+  replacement: ''
+  segment: '[tʈ]'
+- following_context: '[pbcɟɡk][ʲʷ]?'
+  preceding_context: ''
+  replacement: ''
+  segment: d
+- following_context: '[dɖtʈcɟɡk][ʲʷ]?'
+  preceding_context: ''
+  replacement: ''
+  segment: b
+- following_context: '[dɖtʈpb][ʲʷ]?'
+  preceding_context: ''
+  replacement: ''
+  segment: k
+- following_context: '[dɖtʈpb][ʲʷ]?'
+  preceding_context: ''
+  replacement: ''
+  segment: ɡ
+- following_context: ([tʈpkc][ʲʷ]?)? ɹ
+  preceding_context: ''
+  replacement: ʃ
+  segment: s
+- following_context: 'ɹ'
+  preceding_context: ''
+  replacement: tʃ
+  segment: '[tʈ][ʲʷ]?'
+- following_context: ''
+  preceding_context: ''
+  replacement: tʃ
+  segment: '[tʈ][ʲʷ]? ɹ'
+- following_context: ''
+  preceding_context: ''
+  replacement: dʒ
+  segment: '[dɖ][ʲʷ]? ɹ'
+- following_context: 'ɹ'
+  preceding_context: ''
+  replacement: dʒ
+  segment: '[dɖ][ʲʷ]?'
+- following_context: ə n
+  preceding_context: ''
+  replacement: ʔ
+  segment: '[tʈ]'
+- following_context: $ # ing/in' variation
+  preceding_context: ɪ
+  replacement: 'n'
+  segment: ŋ
+- following_context: z$ # ing/in' variation
+  preceding_context: ɪ
+  replacement: 'n'
+  segment: ŋ
+- following_context: '' # schwa deletion
+  preceding_context: ''
+  replacement: l ə
+  segment: 'ə l ə'
+- following_context: '' # schwa deletion
+  preceding_context: ''
+  replacement: ʎ ə
+  segment: 'ə ʎ ə'
+- following_context: '' # schwa deletion
+  preceding_context: ''
+  replacement: n ə
+  segment: 'ə n ə'
+- following_context: '' # schwa deletion
+  preceding_context: ''
+  replacement: m ə
+  segment: 'ə m ə'
+- following_context: '' # schwa deletion
+  preceding_context: ''
+  replacement: ɹ ə
+  segment: 'ə ɹ ə'
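Each rule above rewrites a space-separated pronunciation when segment occurs between preceding_context and following_context. A hedged sketch of applying the US t/d flapping rule with plain regex substitution; MFA's actual rule application lives in its phonological rule models, and the example pronunciation is invented:

    import re

    vowels = "(ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?"
    # t/d flapping: [td] between a vowel and (ɪ|ə|ɚ|i) becomes the flap ɾ.
    flap = re.compile(rf"(?P<pre>{vowels}) [td] (?P<post>(ɪ|ə|ɚ|i))")

    pronunciation = "b ɐ t ɚ"
    print(flap.sub(r"\g<pre> ɾ \g<post>", pronunciation))
    # b ɐ ɾ ɚ
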
diff --git a/tests/data/configs/acoustic/english_mfa_topology.yaml b/tests/data/configs/acoustic/english_mfa_topology.yaml
new file mode 100644
index 00000000..18af2846
--- /dev/null
+++ b/tests/data/configs/acoustic/english_mfa_topology.yaml
@@ -0,0 +1,46 @@
+ɾ:
+  max_states: 1
+  min_states: 1
+ɾʲ:
+  max_states: 1
+  min_states: 1
+ɾ̃:
+  max_states: 1
+  min_states: 1
+ʔ:
+  max_states: 1
+  min_states: 1
+ə:
+  max_states: 3
+ɚ:
+  max_states: 3
+ɪ:
+  max_states: 3
+e:
+  max_states: 3
+eː:
+  max_states: 3
+ɛ:
+  max_states: 3
+ɛː:
+  max_states: 3
+ɐ:
+  max_states: 3
+i:
+  max_states: 3
+iː:
+  max_states: 3
+o:
+  max_states: 3
+oː:
+  max_states: 3
+u:
+  max_states: 3
+uː:
+  max_states: 3
+ɝ:
+  max_states: 3
+j:
+  max_states: 3
+w:
+  max_states: 3
diff --git a/tests/test_commandline_train.py b/tests/test_commandline_train.py
index f42c44e1..9e63b749 100644
--- a/tests/test_commandline_train.py
+++ b/tests/test_commandline_train.py
@@ -4,6 +4,7 @@
 import pytest
 
 from montreal_forced_aligner.command_line.mfa import mfa_cli
+from montreal_forced_aligner.exceptions import PhoneGroupTopologyMismatchError
 
 
 @pytest.mark.skip("Inconsistent failing on CI")
@@ -56,11 +57,45 @@ def test_train_and_align_basic_speaker_dict(
     temp_dir,
     basic_train_config_path,
     textgrid_output_model_path,
+    english_mfa_phone_groups_path,
+    english_mfa_rules_path,
+    english_mfa_topology_path,
+    bad_topology_path,
     db_setup,
 ):
     if os.path.exists(textgrid_output_model_path):
         os.remove(textgrid_output_model_path)
     output_directory = generated_dir.joinpath("ipa speaker output")
+    with pytest.raises(PhoneGroupTopologyMismatchError):
+        command = [
+            "train",
+            multilingual_ipa_tg_corpus_dir,
+            mfa_speaker_dict_path,
+            textgrid_output_model_path,
+            "--config_path",
+            basic_train_config_path,
+            "-q",
+            "--clean",
+            "--no_debug",
+            "--output_directory",
+            output_directory,
+            "--single_speaker",
+            "--phone_groups_path",
+            english_mfa_phone_groups_path,
+            "--rules_path",
+            english_mfa_rules_path,
+            "--topology_path",
+            bad_topology_path,
+        ]
+        command = [str(x) for x in command]
+        result = click.testing.CliRunner(mix_stderr=False).invoke(
+            mfa_cli, command, catch_exceptions=True
+        )
+        print(result.stdout)
+        print(result.stderr)
+        if result.exception:
+            print(result.exc_info)
+            raise result.exception
     command = [
         "train",
         multilingual_ipa_tg_corpus_dir,
@@ -74,6 +109,12 @@ def test_train_and_align_basic_speaker_dict(
         "--output_directory",
         output_directory,
         "--single_speaker",
+        "--phone_groups_path",
+        english_mfa_phone_groups_path,
+        "--rules_path",
+        english_mfa_rules_path,
+        "--topology_path",
+        english_mfa_topology_path,
     ]
     command = [str(x) for x in command]
     result = click.testing.CliRunner(mix_stderr=False).invoke(