diff --git a/docs/source/changelog/changelog_3.0.rst b/docs/source/changelog/changelog_3.0.rst
index 438ecc11..8025ac35 100644
--- a/docs/source/changelog/changelog_3.0.rst
+++ b/docs/source/changelog/changelog_3.0.rst
@@ -5,6 +5,24 @@
3.0 Changelog
*************
+3.1.0
+-----
+
+- Fixed a bug where cutoffs were not properly modelled
+- Added a filter during subset creation to exclude utterances containing cutoffs from smaller subsets
+- Added the ability to specify HMM topologies for phones (see the example topology configuration below)
+- Fixed issues caused by validators not cleaning up temporary files and databases
+- Added support for default and nonnative dictionaries generated from other dictionaries
+- Restricted initial training rounds to exclude default and nonnative dictionaries
+- Changed phone clustering so that silence and non-silence phones are no longer mixed
+- Optimized TextGrid export
+- Improved memory management when collecting alignments
+
+3.0.8
+-----
+
+- Fixed a compatibility issue with models trained under version 1.0 and earlier
+
3.0.7
-----
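Review aid: the 3.1.0 entry above adds per-phone HMM topology support (see the `--topology_path` option and `load_phone_topologies()` further down). A minimal sketch of such a config, assuming only the `min_states`/`max_states` keys that `load_phone_topologies()` reads; the phone labels and file name are illustrative:

```yaml
# topology.yaml (hypothetical example)
# Unlisted phones fall back to min_states = 1, max_states = num_non_silence_states
ow:
  min_states: 2
  max_states: 3
tʃ:
  min_states: 3
ɾ:
  max_states: 2
```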
diff --git a/montreal_forced_aligner/abc.py b/montreal_forced_aligner/abc.py
index f814ebb0..96fcc4df 100644
--- a/montreal_forced_aligner/abc.py
+++ b/montreal_forced_aligner/abc.py
@@ -316,7 +316,7 @@ def db_engine(self) -> sqlalchemy.engine.Engine:
self._db_engine = self.construct_engine()
return self._db_engine
- def get_next_primary_key(self, database_table: MfaSqlBase):
+ def get_next_primary_key(self, database_table):
with self.session() as session:
pk = session.query(sqlalchemy.func.max(database_table.id)).scalar()
if not pk:
@@ -634,7 +634,8 @@ def parse_args(
unknown_dict[name] = val
for name, param_type in param_types.items():
if (name.endswith("_directory") and name != "audio_directory") or (
- name.endswith("_path") and name not in {"rules_path", "phone_groups_path"}
+ name.endswith("_path")
+ and name not in {"rules_path", "phone_groups_path", "topology_path"}
):
continue
if args is not None and name in args and args[name] is not None:
diff --git a/montreal_forced_aligner/acoustic_modeling/lda.py b/montreal_forced_aligner/acoustic_modeling/lda.py
index 279b1c60..fa9c8c09 100644
--- a/montreal_forced_aligner/acoustic_modeling/lda.py
+++ b/montreal_forced_aligner/acoustic_modeling/lda.py
@@ -97,6 +97,8 @@ def _run(self):
]
for dict_id in job.dictionary_ids:
ali_path = job.construct_path(self.working_directory, "ali", "ark", dict_id)
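+            # Alignments can be missing for dictionaries that were excluded from this
+            # training round (e.g. the generated default/nonnative dictionaries)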
+ if not ali_path.exists():
+ continue
lda_logger.debug(f"Processing {ali_path}")
feat_path = job.construct_path(
job.corpus.current_subset_directory, "feats", "scp", dictionary_id=dict_id
@@ -164,6 +166,8 @@ def _run(self) -> typing.Generator[int]:
]
for dict_id in job.dictionary_ids:
ali_path = job.construct_path(self.working_directory, "ali", "ark", dict_id)
+ if not ali_path.exists():
+ continue
lda_logger.debug(f"Processing {ali_path}")
feature_archive = job.construct_feature_archive(self.working_directory, dict_id)
alignment_archive = AlignmentArchive(ali_path)
diff --git a/montreal_forced_aligner/acoustic_modeling/monophone.py b/montreal_forced_aligner/acoustic_modeling/monophone.py
index 84fff727..06d68b72 100644
--- a/montreal_forced_aligner/acoustic_modeling/monophone.py
+++ b/montreal_forced_aligner/acoustic_modeling/monophone.py
@@ -75,7 +75,7 @@ def _run(self):
num_error = 0
tot_like = 0.0
tot_t = 0.0
- for d in job.dictionaries:
+ for d in job.training_dictionaries:
dict_id = d.id
train_logger.debug(f"Aligning for dictionary {d.name} ({d.id})")
train_logger.debug(f"Aligning with model: {self.model_path}")
@@ -302,14 +302,22 @@ def _trainer_initialization(self) -> None:
tree_path = self.working_directory.joinpath("tree")
init_log_path = self.working_log_directory.joinpath("init.log")
job = self.jobs[0]
- dict_id = job.dictionary_ids[0]
- feature_archive = job.construct_feature_archive(self.working_directory, dict_id)
feats = []
with kalpy_logger("kalpy.train", init_log_path) as train_logger:
- for i, (_, mat) in enumerate(feature_archive):
- if i > 10:
+ dict_index = 0
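+            # Sample feature matrices for GMM initialization, moving to the next
+            # dictionary whenever the current one contributes no features in this subset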
+ while len(feats) < 10:
+ try:
+ dict_id = job.dictionary_ids[dict_index]
+ except IndexError:
break
- feats.append(mat)
+ feature_archive = job.construct_feature_archive(self.working_directory, dict_id)
+ for i, (_, mat) in enumerate(feature_archive):
+ if i > 10:
+ break
+ feats.append(mat)
+ dict_index += 1
+ if not feats:
+ raise Exception("Could not initialize monophone model due to lack of features")
shared_phones = self.worker.shared_phones_set_symbols()
topo = read_topology(self.worker.topo_path)
gmm_init_mono(topo, feats, shared_phones, str(self.model_path), str(tree_path))
diff --git a/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py b/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py
index 3650c4f0..0903b3cd 100644
--- a/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py
+++ b/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py
@@ -312,8 +312,12 @@ def setup(self):
previous_directory = self.previous_aligner.working_directory
for j in self.jobs:
for p in j.construct_path_dictionary(previous_directory, "ali", "ark").values():
+ if not p.exists():
+ continue
shutil.copy(p, wf.working_directory.joinpath(p.name))
for p in j.construct_path_dictionary(previous_directory, "words", "ark").values():
+ if not p.exists():
+ continue
shutil.copy(p, wf.working_directory.joinpath(p.name))
for f in ["final.mdl", "final.alimdl", "lda.mat", "tree"]:
p = previous_directory.joinpath(f)
@@ -384,6 +388,12 @@ def train_pronunciation_probabilities(self) -> None:
)
with mfa_open(silence_info_path, "r") as f:
data = json.load(f)
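+                    # Backfill missing silence statistics with neutral values: correction
+                    # factors are multiplicative, so they default to 1.0, while the
+                    # probabilities default to 0.5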
+ for k, v in data.items():
+ if v is None:
+ if "correction" in k:
+ data[k] = 1.0
+ else:
+ data[k] = 0.5
if self.silence_probabilities:
d.silence_probability = data["silence_probability"]
d.initial_silence_probability = data["initial_silence_probability"]
diff --git a/montreal_forced_aligner/acoustic_modeling/sat.py b/montreal_forced_aligner/acoustic_modeling/sat.py
index 73cb97eb..3da3d2b0 100644
--- a/montreal_forced_aligner/acoustic_modeling/sat.py
+++ b/montreal_forced_aligner/acoustic_modeling/sat.py
@@ -81,13 +81,15 @@ def _run(self):
.filter(Job.id == self.job_name)
.first()
)
- for d in job.dictionaries:
+ for d in job.training_dictionaries:
train_logger.debug(f"Accumulating stats for dictionary {d.name} ({d.id})")
train_logger.debug(f"Accumulating stats for model: {self.model_path}")
dict_id = d.id
accumulator = TwoFeatsStatsAccumulator(self.model_path)
ali_path = job.construct_path(self.working_directory, "ali", "ark", dict_id)
+ if not ali_path.exists():
+ continue
fmllr_path = job.construct_path(
job.corpus.current_subset_directory, "trans", "scp", dict_id
)
diff --git a/montreal_forced_aligner/acoustic_modeling/trainer.py b/montreal_forced_aligner/acoustic_modeling/trainer.py
index 6808738b..c663e0d3 100644
--- a/montreal_forced_aligner/acoustic_modeling/trainer.py
+++ b/montreal_forced_aligner/acoustic_modeling/trainer.py
@@ -100,6 +100,8 @@ def _run(self) -> typing.Generator[typing.Tuple[int, str]]:
transition_model, acoustic_model = read_gmm_model(self.model_path)
for dict_id in job.dictionary_ids:
ali_path = job.construct_path(self.working_directory, "ali", "ark", dict_id)
+ if not ali_path.exists():
+ continue
transition_accs = DoubleVector(transition_model.NumTransitionIds() + 1)
alignment_archive = AlignmentArchive(ali_path)
for alignment in alignment_archive:
@@ -523,6 +525,8 @@ def quality_check_subset(self):
self.working_directory, "temp_ali", "ark"
)
for dict_id, ali_path in ali_paths.items():
+ if not ali_path.exists():
+ continue
new_path = temp_ali_paths[dict_id]
write_specifier = generate_write_specifier(new_path)
writer = Int32VectorWriter(write_specifier)
@@ -577,15 +581,20 @@ def train(self) -> None:
self.current_acoustic_model = AcousticModel(
previous.exported_model_path, self.working_directory
)
- self.align()
- with self.session() as session:
- session.query(WordInterval).delete()
- session.query(PhoneInterval).delete()
- session.commit()
- self.collect_alignments()
- self.analyze_alignments()
- if self.current_subset != 0:
- self.quality_check_subset()
+ if (
+ not self.current_workflow.done
+ or not self.current_workflow.working_directory.exists()
+ ):
+                        logger.debug(f"Running {self.current_aligner.identifier} alignments")
+ self.align()
+ with self.session() as session:
+ session.query(WordInterval).delete()
+ session.query(PhoneInterval).delete()
+ session.commit()
+ self.collect_alignments()
+ self.analyze_alignments()
+ if self.current_subset != 0:
+ self.quality_check_subset()
self.set_current_workflow(trainer.identifier)
if trainer.identifier.startswith("pronunciation_probabilities"):
@@ -721,7 +730,6 @@ def align_options(self) -> MetaDict:
options = self.current_aligner.align_options
else:
options = super().align_options
- options["boost_silence"] = max(1.25, options["boost_silence"])
return options
def align(self) -> None:
diff --git a/montreal_forced_aligner/acoustic_modeling/triphone.py b/montreal_forced_aligner/acoustic_modeling/triphone.py
index 8c146554..fce20bfb 100644
--- a/montreal_forced_aligner/acoustic_modeling/triphone.py
+++ b/montreal_forced_aligner/acoustic_modeling/triphone.py
@@ -94,10 +94,12 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int]]:
train_logger.debug(f"Previous model path: {self.align_model_path}")
train_logger.debug(f"Model path: {self.model_path}")
train_logger.debug(f"Tree path: {self.tree_path}")
- for d in job.dictionaries:
+ for d in job.training_dictionaries:
dict_id = d.id
train_logger.debug(f"Converting alignments for {d.name}")
ali_path = self.ali_paths[dict_id]
+ if not ali_path.exists():
+ continue
new_ali_path = self.new_ali_paths[dict_id]
train_logger.debug(f"Old alignments: {ali_path}")
train_logger.debug(f"New alignments: {new_ali_path}")
@@ -159,12 +161,14 @@ def _run(self):
.filter(Phone.phone_type.in_([PhoneType.silence, PhoneType.oov]))
.order_by(Phone.mapping_id)
]
- for d in job.dictionaries:
+ for d in job.training_dictionaries:
train_logger.debug(f"Accumulating stats for dictionary {d.name} ({d.id})")
train_logger.debug(f"Accumulating stats for model: {self.model_path}")
dict_id = d.id
feature_archive = job.construct_feature_archive(self.working_directory, dict_id)
ali_path = job.construct_path(self.working_directory, "ali", "ark", dict_id)
+ if not ali_path.exists():
+ continue
train_logger.debug("Feature Archive information:")
train_logger.debug(f"File: {feature_archive.file_name}")
train_logger.debug(f"CMVN: {feature_archive.cmvn_read_specifier}")
@@ -397,8 +401,29 @@ def _setup_tree(self, init_from_previous=False, initial_mix_up=True) -> None:
train_logger.debug(f"Phone sets: {phone_sets}")
questions = automatically_obtain_questions(tree_stats, phone_sets, [1], 1)
train_logger.debug(f"Automatically obtained {len(questions)} questions")
- for v in self.worker.extra_questions_mapping.values():
- questions.append(sorted([self.phone_mapping[x] for x in v]))
+ train_logger.debug("Automatic questions:")
+ for q_set in questions:
+ train_logger.debug(", ".join([self.reversed_phone_mapping[x] for x in q_set]))
+
+ # Remove questions containing silence and other phones
+ train_logger.debug("Filtering the following sets for containing silence phone:")
+ silence_phone_id = self.phone_mapping[self.optional_silence_phone]
+ silence_sets = [
+ x for x in questions if silence_phone_id in x and x != [silence_phone_id]
+ ]
+ for q_set in silence_sets:
+ train_logger.debug(", ".join([self.reversed_phone_mapping[x] for x in q_set]))
+ questions = [
+ x for x in questions if silence_phone_id not in x or x == [silence_phone_id]
+ ]
+
+ extra_questions = self.worker.extra_questions_mapping
+ if extra_questions:
+ train_logger.debug(f"Adding {len(extra_questions)} questions")
+ train_logger.debug("Extra questions:")
+ for v in self.worker.extra_questions_mapping.values():
+ questions.append(sorted([self.phone_mapping[x] for x in v]))
+ train_logger.debug(", ".join(v))
train_logger.debug(f"{len(questions)} total questions")
build_tree(
diff --git a/montreal_forced_aligner/alignment/base.py b/montreal_forced_aligner/alignment/base.py
index a17baf5f..f060bc3f 100644
--- a/montreal_forced_aligner/alignment/base.py
+++ b/montreal_forced_aligner/alignment/base.py
@@ -6,6 +6,7 @@
import functools
import io
import logging
+import math
import multiprocessing as mp
import os
import re
@@ -77,9 +78,7 @@
mfa_open,
)
from montreal_forced_aligner.textgrid import (
- construct_output_path,
- construct_output_tiers,
- export_textgrid,
+ construct_textgrid_output,
output_textgrid_writing_errors,
)
from montreal_forced_aligner.utils import log_kaldi_errors, run_kaldi_function
@@ -424,18 +423,26 @@ def compute_pronunciation_probabilities(self):
"""
begin = time.time()
- dictionary_counters = {
- dict_id: PronunciationProbabilityCounter()
- for dict_id in self.dictionary_lookup.values()
- }
+ with self.session() as session:
+ dictionary_counters = {
+ dict_id: PronunciationProbabilityCounter()
+ for dict_id, in session.query(Dictionary.id).filter(Dictionary.name != "default")
+ }
logger.info("Generating pronunciations...")
arguments = self.generate_pronunciations_arguments()
for result in run_kaldi_function(
GeneratePronunciationsFunction, arguments, total_count=self.num_current_utterances
):
- dict_id, utterance_counter = result
- dictionary_counters[dict_id].add_counts(utterance_counter)
-
+            try:
+                dict_id, utterance_counter = result
+                dictionary_counters[dict_id].add_counts(utterance_counter)
+            except Exception:
+                import traceback
+
+                logger.error(traceback.format_exc())
+                raise
        initial_key = ("<s>", "")
        final_key = ("</s>", "")
lambda_2 = 2
@@ -450,16 +457,27 @@ def compute_pronunciation_probabilities(self):
) as log_file, self.session() as session:
session.query(Pronunciation).update({"count": 0})
session.commit()
- dictionaries = session.query(Dictionary.id)
+ dictionaries = session.query(Dictionary)
dictionary_mappings = []
- for (d_id,) in dictionaries:
+ for d in dictionaries:
+ d_id = d.id
+ if d_id not in dictionary_counters:
+ continue
counter = dictionary_counters[d_id]
- log_file.write(f"For {d_id}:\n")
+ log_file.write(f"For {d.name}:\n")
words = (
session.query(Word.word)
.filter(Word.dictionary_id == d_id)
- .filter(Word.word_type != WordType.silence)
- .filter(Word.count > 0)
+ .filter(
+ sqlalchemy.or_(
+ sqlalchemy.and_(
+ Word.word_type.in_(WordType.speech_types()), Word.count > 0
+ ),
+ Word.word.in_(
+ [d.cutoff_word, d.oov_word, d.laughter_word, d.bracketed_word]
+ ),
+ )
+ )
)
pronunciations = (
session.query(
@@ -470,8 +488,16 @@ def compute_pronunciation_probabilities(self):
)
.join(Pronunciation.word)
.filter(Word.dictionary_id == d_id)
- .filter(Word.word_type != WordType.silence)
- .filter(Word.count > 0)
+ .filter(
+ sqlalchemy.or_(
+ sqlalchemy.and_(
+ Word.word_type.in_(WordType.speech_types()), Word.count > 0
+ ),
+ Word.word.in_(
+ [d.cutoff_word, d.oov_word, d.laughter_word, d.bracketed_word]
+ ),
+ )
+ )
)
pron_mapping = {}
pronunciations = [
@@ -479,6 +505,7 @@ def compute_pronunciation_probabilities(self):
for w, p, p_id, generated in pronunciations
if not generated or p in counter.word_pronunciation_counts[w]
]
+ pronunciations.append((d.cutoff_word, "cutoff_model", None))
for w, p, p_id in pronunciations:
pron_mapping[(w, p)] = {"id": p_id}
if w in {initial_key[0], final_key[0], self.silence_word}:
@@ -546,8 +573,28 @@ def compute_pronunciation_probabilities(self):
(non_silence_count + lambda_3)
/ (bar_count_non_silence_wp[(w, p)] + lambda_3)
)
- session.bulk_update_mappings(Pronunciation, pron_mapping.values())
+ cutoff_model = pron_mapping.pop((d.cutoff_word, "cutoff_model"))
+ bulk_update(session, Pronunciation, list(pron_mapping.values()))
session.flush()
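+                # Copy the estimated probabilities for the generic cutoff word onto all
+                # of its variants (e.g. <cutoff-word>): plain spn pronunciations receive
+                # the non-model entry, all others the cutoff_model entry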
+ cutoff_not_model = pron_mapping[(d.cutoff_word, "spn")]
+ cutoff_query = (
+ session.query(Pronunciation.id, Pronunciation.pronunciation)
+ .join(Pronunciation.word)
+ .filter(Word.word != d.cutoff_word)
+ .filter(Word.word.like(f"{d.cutoff_word[:-1]}%"))
+ )
+ cutoff_mappings = []
+ for pron_id, pron in cutoff_query:
+                    if pron == d.oov_phone:
+ data = dict(cutoff_not_model)
+ else:
+ data = dict(cutoff_model)
+ data["id"] = pron_id
+ cutoff_mappings.append(data)
+ if cutoff_mappings:
+ bulk_update(session, Pronunciation, cutoff_mappings)
+ session.flush()
+
initial_silence_count = counter.silence_before_counts[initial_key] + (
silence_probability * lambda_2
)
@@ -915,6 +962,11 @@ def collect_alignments(self) -> None:
cursor.copy_from(phone_buf, PhoneInterval.__tablename__, sep=",", null="")
phone_buf.truncate(0)
phone_buf.seek(0)
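+                    # Periodically commit and recycle the raw connection so COPY buffers
+                    # and transaction state do not accumulate in memory during large exports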
+ conn.commit()
+ cursor.close()
+ conn.close()
+ conn = self.db_engine.raw_connection()
+ cursor = conn.cursor()
if config.USE_POSTGRES:
if word_buf.tell() != 0:
@@ -1129,21 +1181,31 @@ def export_textgrids(
self.collect_alignments()
begin = time.time()
error_dict = {}
-
- with self.session() as session:
- files = (
- session.query(
- File.id,
- File.name,
- File.relative_path,
- SoundFile.duration,
- TextFile.text_file_path,
- )
- .join(File.sound_file)
- .join(File.text_file)
- ).all()
with tqdm(total=self.num_files, disable=config.QUIET) as pbar:
if config.USE_MP and config.NUM_JOBS > 1:
+ with self.session() as session:
+ files_per_job = math.ceil(self.num_files / len(self.jobs))
+ file_batches = [{}]
+ query = (
+ session.query(
+ File.id,
+ File.name,
+ File.relative_path,
+ SoundFile.duration,
+ TextFile.text_file_path,
+ )
+ .join(File.sound_file)
+ .join(File.text_file)
+ )
+ for file_id, file_name, relative_path, file_duration, text_file_path in query:
+ if len(file_batches[-1]) >= files_per_job:
+ file_batches.append({})
+ file_batches[-1][file_id] = (
+ file_name,
+ relative_path,
+ file_duration,
+ text_file_path,
+ )
stopped = mp.Event()
finished_adding = mp.Event()
@@ -1167,8 +1229,8 @@ def export_textgrids(
export_proc.start()
export_procs.append(export_proc)
try:
- for args in files:
- for_write_queue.put(args)
+ for batch in file_batches:
+ for_write_queue.put(batch)
time.sleep(1)
finished_adding.set()
while True:
@@ -1197,30 +1259,37 @@ def export_textgrids(
else:
logger.debug("Not using multiprocessing for TextGrid export")
- for file_id, name, relative_path, duration, text_file_path in files:
- output_path = construct_output_path(
- name,
- relative_path,
- self.export_output_directory,
- text_file_path,
- output_format,
- )
-
- data = construct_output_tiers(
- session,
- file_id,
- workflow,
- config.CLEANUP_TEXTGRIDS,
- self.clitic_marker,
- include_original_text,
- )
- export_textgrid(
- data,
- output_path,
- duration,
- self.export_frame_shift,
- output_format=output_format,
+ with self.session() as session:
+ file_batch = {}
+ query = (
+ session.query(
+ File.id,
+ File.name,
+ File.relative_path,
+ SoundFile.duration,
+ TextFile.text_file_path,
+ )
+ .join(File.sound_file)
+ .join(File.text_file)
)
+ for file_id, file_name, relative_path, file_duration, text_file_path in query:
+ file_batch[file_id] = (
+ file_name,
+ relative_path,
+ file_duration,
+ text_file_path,
+ )
+ for _ in construct_textgrid_output(
+ session,
+ file_batch,
+ workflow,
+ config.CLEANUP_TEXTGRIDS,
+ self.clitic_marker,
+ self.export_output_directory,
+ self.export_frame_shift,
+ output_format,
+ include_original_text,
+ ):
pbar.update(1)
if error_dict:
@@ -1321,7 +1390,7 @@ def evaluate_alignments(
Directory to save results, if not specified, it will be saved in the log directory
comparison_source: :class:`~montreal_forced_aligner.data.WorkflowType`
Workflow to compare to the reference intervals, defaults to :attr:`~montreal_forced_aligner.data.WorkflowType.alignment`
- comparison_source: :class:`~montreal_forced_aligner.data.WorkflowType`
+ reference_source: :class:`~montreal_forced_aligner.data.WorkflowType`
Workflow to use as the reference intervals, defaults to :attr:`~montreal_forced_aligner.data.WorkflowType.reference`
"""
diff --git a/montreal_forced_aligner/alignment/multiprocessing.py b/montreal_forced_aligner/alignment/multiprocessing.py
index 22f4a6be..2b963385 100644
--- a/montreal_forced_aligner/alignment/multiprocessing.py
+++ b/montreal_forced_aligner/alignment/multiprocessing.py
@@ -63,11 +63,7 @@
)
from montreal_forced_aligner.exceptions import AlignmentCollectionError, AlignmentExportError
from montreal_forced_aligner.helper import mfa_open, split_phone_position
-from montreal_forced_aligner.textgrid import (
- construct_output_path,
- construct_output_tiers,
- export_textgrid,
-)
+from montreal_forced_aligner.textgrid import construct_textgrid_output
from montreal_forced_aligner.utils import thread_logger
if TYPE_CHECKING:
@@ -418,7 +414,7 @@ def _run(self):
text_column = Utterance.normalized_character_text
else:
text_column = Utterance.normalized_text
- for d in job.dictionaries:
+ for d in job.training_dictionaries:
begin = time.time()
if self.lexicon_compilers and d.id in self.lexicon_compilers:
lexicon = self.lexicon_compilers[d.id]
@@ -477,7 +473,7 @@ def __init__(self, args: AccStatsArguments):
self.working_directory = args.working_directory
self.model_path = args.model_path
- def _run(self) -> typing.Generator[typing.Tuple[int, int]]:
+ def _run(self) -> None:
"""Run the function"""
with self.session() as session, thread_logger(
"kalpy.train", self.log_path, job_name=self.job_name
@@ -488,7 +484,7 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int]]:
.filter(Job.id == self.job_name)
.first()
)
- for d in job.dictionaries:
+ for d in job.training_dictionaries:
train_logger.debug(f"Accumulating stats for dictionary {d.name} ({d.id})")
train_logger.debug(f"Accumulating stats for model: {self.model_path}")
dict_id = d.id
@@ -496,6 +492,8 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int]]:
feature_archive = job.construct_feature_archive(self.working_directory, dict_id)
ali_path = job.construct_path(self.working_directory, "ali", "ark", dict_id)
+ if not ali_path.exists():
+ continue
alignment_archive = AlignmentArchive(ali_path)
train_logger.debug("Feature Archive information:")
train_logger.debug(f"CMVN: {feature_archive.cmvn_read_specifier}")
@@ -565,7 +563,7 @@ def _run(self) -> None:
**align_options,
)
aligner.boost_silence(boost_silence, silence_phones)
- for d in job.dictionaries:
+ for d in job.training_dictionaries:
align_logger.debug(f"Aligning for dictionary {d.name} ({d.id})")
align_logger.debug(f"Aligning with model: {aligner.acoustic_model_path}")
dict_id = d.id
@@ -1114,21 +1112,10 @@ def _run(self) -> None:
silence_words = session.query(Word.word).filter(Word.word_type == WordType.silence)
self.silence_words.update(x for x, in silence_words)
- for d in job.dictionaries:
+ for d in job.training_dictionaries:
ali_path = job.construct_path(workflow.working_directory, "ali", "ark", d.id)
if not os.path.exists(ali_path):
continue
-
- utts = (
- session.query(Utterance.id, Utterance.normalized_text)
- .join(Utterance.speaker)
- .filter(Utterance.job_id == self.job_name)
- .filter(Speaker.dictionary_id == d.id)
- )
- utterance_texts = {}
- for u_id, text in utts:
- utterance_texts[u_id] = text
-
if self.lexicon_compilers and d.id in self.lexicon_compilers:
lexicon_compiler = self.lexicon_compilers[d.id]
else:
@@ -1142,11 +1129,15 @@ def _run(self) -> None:
)
utterance = int(alignment.utterance_id.split("-")[-1])
ctm = lexicon_compiler.phones_to_pronunciations(alignment.words, intervals)
- word_pronunciations = [(x.label, x.pronunciation) for x in ctm.word_intervals]
- # word_pronunciations = [
- # x if x[1] != OOV_PHONE else (OOV_WORD, OOV_PHONE)
- # for x in word_pronunciations
- # ]
+ word_pronunciations = []
+ for wi in ctm.word_intervals:
+ label = wi.label
+ pronunciation = wi.pronunciation
+ if label.startswith(d.cutoff_word[:-1]):
+ label = d.cutoff_word
+ if pronunciation != d.oov_phone:
+ pronunciation = "cutoff_model"
+ word_pronunciations.append((label, pronunciation))
if self.for_g2p:
phones = []
for i, x in enumerate(word_pronunciations):
@@ -1256,6 +1247,8 @@ def _run(self) -> None:
lexicon_compiler = d.lexicon_compiler
if self.transcription:
lat_path = job.construct_path(workflow.working_directory, "lat", "ark", d.id)
+ if not lat_path.exists():
+ continue
transcription_archive = TranscriptionArchive(
lat_path, acoustic_scale=self.score_options["acoustic_scale"]
@@ -1304,6 +1297,8 @@ def _run(self) -> None:
self.callback((utterance, d.id, ctm))
else:
ali_path = job.construct_path(workflow.working_directory, "ali", "ark", d.id)
+ if not ali_path.exists():
+ continue
words_path = job.construct_path(
workflow.working_directory, "words", "ark", d.id
)
@@ -1505,13 +1500,7 @@ def run(self) -> None:
)
while True:
try:
- (
- file_id,
- name,
- relative_path,
- duration,
- text_file_path,
- ) = self.for_write_queue.get(timeout=1)
+                    file_batch = self.for_write_queue.get(timeout=1)
except Empty:
if self.finished_adding.is_set():
self.finished_processing.set()
@@ -1521,25 +1510,18 @@ def run(self) -> None:
if self.stopped.is_set():
continue
try:
- output_path = construct_output_path(
- name,
- relative_path,
- self.output_directory,
- text_file_path,
- self.output_format,
- )
- data = construct_output_tiers(
+ for output_path in construct_textgrid_output(
session,
- file_id,
+ file_batch,
workflow,
self.cleanup_textgrids,
self.clitic_marker,
+ self.output_directory,
+ self.export_frame_shift,
+ self.output_format,
self.include_original_text,
- )
- export_textgrid(
- data, output_path, duration, self.export_frame_shift, self.output_format
- )
- self.return_queue.put(1)
+ ):
+ self.return_queue.put(1)
except Exception:
exc_type, exc_value, exc_traceback = sys.exc_info()
self.return_queue.put(
diff --git a/montreal_forced_aligner/command_line/train_acoustic_model.py b/montreal_forced_aligner/command_line/train_acoustic_model.py
index e040a2c8..f16930ca 100644
--- a/montreal_forced_aligner/command_line/train_acoustic_model.py
+++ b/montreal_forced_aligner/command_line/train_acoustic_model.py
@@ -81,6 +81,13 @@
"https://github.com/MontrealCorpusTools/mfa-models/tree/main/config/acoustic/rules for examples.",
type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
)
+@click.option(
+ "--topology_path",
+ "topology_path",
+ help="Path to yaml file defining topologies. See "
+ "https://github.com/MontrealCorpusTools/mfa-models/tree/main/config/acoustic/topologies for examples.",
+ type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
+)
@click.option(
"--output_format",
help="Format for aligned output files (default is long_textgrid).",
diff --git a/montreal_forced_aligner/command_line/utils.py b/montreal_forced_aligner/command_line/utils.py
index b02024f9..01d684ad 100644
--- a/montreal_forced_aligner/command_line/utils.py
+++ b/montreal_forced_aligner/command_line/utils.py
@@ -21,7 +21,6 @@
FileArgumentNotFoundError,
ModelExtensionError,
ModelTypeNotSupportedError,
- NoDefaultSpeakerDictionaryError,
PretrainedModelNotFoundError,
)
from montreal_forced_aligner.helper import mfa_open
@@ -145,7 +144,7 @@ def common_options(f: typing.Callable) -> typing.Callable:
return functools.reduce(lambda x, opt: opt(x), options, f)
-def validate_model_arg(name: str, model_type: str) -> Path:
+def validate_model_arg(name: str, model_type: str) -> typing.Union[Path, str]:
"""
Validate pretrained model name argument
@@ -181,6 +180,8 @@ def validate_model_arg(name: str, model_type: str) -> Path:
model_class = MODEL_TYPES[model_type]
if name in available_models:
name = model_class.get_pretrained_path(name)
+ elif model_type == "dictionary" and str(name).lower() in {"default", "nonnative"}:
+ return name
else:
if isinstance(name, str):
name = Path(name)
@@ -193,8 +194,6 @@ def validate_model_arg(name: str, model_type: str) -> Path:
paths = sorted(set(data.values()))
for path in paths:
validate_model_arg(path, "dictionary")
- if "default" not in data:
- raise click.BadParameter(str(NoDefaultSpeakerDictionaryError()))
else:
if name.exists():
if name.suffix:
@@ -410,7 +409,7 @@ def check_server() -> None:
logger.debug(f"pg_ctl stdout: {stdout}")
logger.debug(f"pg_ctl stderr: {stderr}")
if "no server running" in stdout:
- raise DatabaseError()
+ raise DatabaseError(f"stdout: {stdout}\nstderr: {stderr}")
def start_server() -> None:
diff --git a/montreal_forced_aligner/command_line/validate.py b/montreal_forced_aligner/command_line/validate.py
index 43eb3e8c..b65fdc4e 100644
--- a/montreal_forced_aligner/command_line/validate.py
+++ b/montreal_forced_aligner/command_line/validate.py
@@ -139,6 +139,7 @@ def validate_corpus_cli(context, **kwargs) -> None:
validator.dirty = True
raise
finally:
+ config.FINAL_CLEAN = True
validator.cleanup()
@@ -196,4 +197,5 @@ def validate_dictionary_cli(*args, **kwargs) -> None:
validator.dirty = True
raise
finally:
+ config.FINAL_CLEAN = True
validator.cleanup()
diff --git a/montreal_forced_aligner/corpus/acoustic_corpus.py b/montreal_forced_aligner/corpus/acoustic_corpus.py
index 30a9e573..0044d449 100644
--- a/montreal_forced_aligner/corpus/acoustic_corpus.py
+++ b/montreal_forced_aligner/corpus/acoustic_corpus.py
@@ -1111,6 +1111,7 @@ def load_corpus(self) -> None:
self.write_lexicon_information()
logger.debug(f"Wrote lexicon information in {time.time() - begin:.3f} seconds")
else:
+ self.load_phone_topologies()
self.load_phone_groups()
self.load_lexicon_compilers()
diff --git a/montreal_forced_aligner/corpus/base.py b/montreal_forced_aligner/corpus/base.py
index 48198d76..33ae6aa9 100644
--- a/montreal_forced_aligner/corpus/base.py
+++ b/montreal_forced_aligner/corpus/base.py
@@ -4,6 +4,7 @@
import collections
import logging
import os
+import re
import threading
import time
import typing
@@ -1182,38 +1183,68 @@ def create_subset(self, subset: int) -> None:
Number of utterances to include in subset
"""
logger.info(f"Creating subset directory with {subset} utterances...")
- multiword_pattern = r"\s\S+\s"
+ if hasattr(self, "cutoff_word") and hasattr(self, "brackets"):
+ initial_brackets = re.escape("".join(x[0] for x in self.brackets))
+ final_brackets = re.escape("".join(x[1] for x in self.brackets))
+ cutoff_identifier = re.sub(
+ rf"[{initial_brackets}{final_brackets}]", "", self.cutoff_word
+ )
+ cutoff_pattern = f"[{initial_brackets}]({cutoff_identifier}|hes)"
+ else:
+ cutoff_pattern = "<(cutoff|hes)"
+
+ def add_filters(query):
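+            # Keep multi-word, non-ignored utterances with reasonable duration deviation;
+            # for subsets of 25k utterances or fewer, also exclude cutoffs/hesitations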
+ multiword_pattern = r"\s\S+\s"
+ filtered = (
+ query.filter(
+ Utterance.normalized_text.op("~")(multiword_pattern)
+ if config.USE_POSTGRES
+ else Utterance.normalized_text.regexp_match(multiword_pattern)
+ )
+ .filter(Utterance.ignored == False) # noqa
+ .filter(
+ sqlalchemy.or_(
+ Utterance.duration_deviation == None, # noqa
+ Utterance.duration_deviation < 10,
+ )
+ )
+ )
+ if subset <= 25000:
+ filtered = filtered.filter(
+ sqlalchemy.not_(
+ Utterance.normalized_text.op("~")(cutoff_pattern)
+ if config.USE_POSTGRES
+ else Utterance.normalized_text.regexp_match(cutoff_pattern)
+ )
+ )
+
+ return filtered
+
with self.session() as session:
begin = time.time()
session.query(Utterance).filter(Utterance.in_subset == True).update( # noqa
{Utterance.in_subset: False}
)
session.commit()
- dictionary_lookup = {k: v for k, v in session.query(Dictionary.name, Dictionary.id)}
+ dictionary_query = session.query(Dictionary.name, Dictionary.id).filter(
+ Dictionary.name != "default"
+ )
+ if subset <= 25000:
+ dictionary_query = dictionary_query.filter(Dictionary.name != "nonnative")
+ dictionary_lookup = {k: v for k, v in dictionary_query}
num_dictionaries = len(dictionary_lookup)
if num_dictionaries > 1:
subsets_per_dictionary = {}
utts_per_dictionary = {}
subsetted = 0
for dict_name, dict_id in dictionary_lookup.items():
- num_utts = (
+ base_query = (
session.query(Utterance)
.join(Utterance.speaker)
- .filter(Speaker.dictionary_id == dict_id)
- .filter(
- Utterance.normalized_text.op("~")(multiword_pattern)
- if config.USE_POSTGRES
- else Utterance.normalized_text.regexp_match(multiword_pattern)
- )
- .filter(Utterance.ignored == False) # noqa
- .filter(
- sqlalchemy.or_(
- Utterance.duration_deviation == None, # noqa
- Utterance.duration_deviation < 10,
- )
- ) # noqa
- .count()
+ .filter(Speaker.dictionary_id == dict_id) # noqa
)
+ base_query = add_filters(base_query)
+ num_utts = base_query.count()
utts_per_dictionary[dict_name] = num_utts
if num_utts < int(subset / num_dictionaries):
subsets_per_dictionary[dict_name] = num_utts
@@ -1240,42 +1271,22 @@ def create_subset(self, subset: int) -> None:
larger_subset_num = int(subset_per_dictionary * 10)
speaker_ids = None
average_duration = (
- session.query(sqlalchemy.func.avg(Utterance.duration))
- .join(Utterance.speaker)
- .filter(Speaker.dictionary_id == dict_id)
- .filter(
- Utterance.normalized_text.op("~")(multiword_pattern)
- if config.USE_POSTGRES
- else Utterance.normalized_text.regexp_match(multiword_pattern)
+ add_filters(
+ session.query(sqlalchemy.func.avg(Utterance.duration))
+ .join(Utterance.speaker)
+ .filter(Speaker.dictionary_id == dict_id)
)
- .filter(Utterance.ignored == False) # noqa
- .filter(
- sqlalchemy.or_(
- Utterance.duration_deviation == None, # noqa
- Utterance.duration_deviation < 10,
- )
- ) # noqa
).first()[0]
for utt_count_cutoff in [30, 15, 5]:
sq = (
- session.query(
- Speaker.id.label("speaker_id"),
- sqlalchemy.func.count(Utterance.id).label("utt_count"),
- )
- .join(Utterance.speaker)
- .filter(Speaker.dictionary_id == dict_id)
- .filter(
- Utterance.normalized_text.op("~")(multiword_pattern)
- if config.USE_POSTGRES
- else Utterance.normalized_text.regexp_match(multiword_pattern)
- )
- .filter(Utterance.ignored == False) # noqa
- .filter(
- sqlalchemy.or_(
- Utterance.duration_deviation == None, # noqa
- Utterance.duration_deviation < 10,
+ add_filters(
+ session.query(
+ Speaker.id.label("speaker_id"),
+ sqlalchemy.func.count(Utterance.id).label("utt_count"),
)
- ) # noqa
+ .join(Utterance.speaker)
+ .filter(Speaker.dictionary_id == dict_id)
+ )
.filter(Utterance.duration <= average_duration)
.group_by(Speaker.id.label("speaker_id"))
.subquery()
@@ -1286,32 +1297,21 @@ def create_subset(self, subset: int) -> None:
)
).first()[0]
if total_speaker_utterances >= subset_per_dictionary:
- speaker_ids = (
+ speaker_ids = [
x
for x, in session.query(sq.c.speaker_id).filter(
sq.c.utt_count >= utt_count_cutoff
)
- )
+ ]
break
if num_utts > larger_subset_num:
larger_subset_query = (
session.query(Utterance.id)
.join(Utterance.speaker)
- .filter(Speaker.dictionary_id == dict_id)
- .filter(
- Utterance.normalized_text.op("~")(multiword_pattern)
- if config.USE_POSTGRES
- else Utterance.normalized_text.regexp_match(multiword_pattern)
- )
- .filter(Utterance.ignored == False) # noqa
- .filter(
- sqlalchemy.or_(
- Utterance.duration_deviation == None, # noqa
- Utterance.duration_deviation < 10,
- )
- ) # noqa
+ .filter(Speaker.dictionary_id == dict_id) # noqa
)
- if speaker_ids is not None:
+ larger_subset_query = add_filters(larger_subset_query)
+ if speaker_ids:
larger_subset_query = larger_subset_query.filter(
Speaker.id.in_(speaker_ids)
)
@@ -1333,7 +1333,8 @@ def create_subset(self, subset: int) -> None:
)
session.execute(query)
- # Remove speakers with less than 5 utterances from subset, can't estimate speaker transforms well for low utterance counts
+ # Remove speakers with less than 5 utterances from subset,
+ # can't estimate speaker transforms well for low utterance counts
sq = (
session.query(
Utterance.speaker_id.label("speaker_id"),
@@ -1343,9 +1344,9 @@ def create_subset(self, subset: int) -> None:
.group_by(Utterance.speaker_id.label("speaker_id"))
.subquery()
)
- speaker_ids = (
+ speaker_ids = [
x for x, in session.query(sq.c.speaker_id).filter(sq.c.utt_count < 5)
- )
+ ]
session.query(Utterance).filter(
Utterance.speaker_id.in_(speaker_ids)
).update({Utterance.in_subset: False})
@@ -1355,21 +1356,10 @@ def create_subset(self, subset: int) -> None:
larger_subset_query = (
session.query(Utterance.id)
.join(Utterance.speaker)
- .filter(Speaker.dictionary_id == dict_id)
- .filter(
- Utterance.normalized_text.op("~")(multiword_pattern)
- if config.USE_POSTGRES
- else Utterance.normalized_text.regexp_match(multiword_pattern)
- )
- .filter(Utterance.ignored == False) # noqa
- .filter(
- sqlalchemy.or_(
- Utterance.duration_deviation == None, # noqa
- Utterance.duration_deviation < 10,
- )
- ) # noqa
+ .filter(Speaker.dictionary_id == dict_id) # noqa
)
- if speaker_ids is not None:
+ larger_subset_query = add_filters(larger_subset_query)
+ if speaker_ids:
larger_subset_query = larger_subset_query.filter(
Speaker.id.in_(speaker_ids)
)
@@ -1431,27 +1421,15 @@ def create_subset(self, subset: int) -> None:
).first()[0]
remaining = subset_per_dictionary - total_speaker_utterances
if remaining > 0:
- speaker_ids = (x for x, in session.query(sq.c.speaker_id))
+ speaker_ids = [x for x, in session.query(sq.c.speaker_id)]
larger_subset_query = (
session.query(Utterance.id)
.join(Utterance.speaker)
- .filter(Speaker.dictionary_id == dict_id)
- .filter(
- Utterance.normalized_text.op("~")(multiword_pattern)
- if config.USE_POSTGRES
- else Utterance.normalized_text.regexp_match(multiword_pattern)
- )
- .filter(Utterance.ignored == False) # noqa
- .filter(Utterance.in_subset == False) # noqa
- .filter(
- sqlalchemy.or_(
- Utterance.duration_deviation == None, # noqa
- Utterance.duration_deviation < 10,
- )
- ) # noqa
+ .filter(Speaker.dictionary_id == dict_id) # noqa
)
- if speaker_ids is not None:
+ larger_subset_query = add_filters(larger_subset_query)
+ if speaker_ids:
larger_subset_query = larger_subset_query.filter(
Speaker.id.in_(speaker_ids)
)
@@ -1478,13 +1456,7 @@ def create_subset(self, subset: int) -> None:
if subset < self.num_utterances:
# Get all shorter utterances that are not one word long
larger_subset_query = (
- session.query(Utterance.id)
- .filter(
- Utterance.normalized_text.op("~")(multiword_pattern)
- if config.USE_POSTGRES
- else Utterance.normalized_text.regexp_match(multiword_pattern)
- )
- .filter(Utterance.ignored == False) # noqa
+ add_filters(session.query(Utterance.id))
.order_by(Utterance.duration)
.limit(larger_subset_num)
)
diff --git a/montreal_forced_aligner/corpus/features.py b/montreal_forced_aligner/corpus/features.py
index 8bcd6c5c..335d809a 100644
--- a/montreal_forced_aligner/corpus/features.py
+++ b/montreal_forced_aligner/corpus/features.py
@@ -511,6 +511,8 @@ def _run(self) -> None:
**self.fmllr_options,
)
ali_path = job.construct_path(self.working_directory, "ali", "ark", dict_id)
+ if not ali_path.exists():
+ continue
fmllr_logger.debug(f"Alignment path: {ali_path}")
alignment_archive = AlignmentArchive(ali_path)
temp_trans_path = job.construct_path(
@@ -594,6 +596,7 @@ def __init__(
self,
feature_type: str = "mfcc",
use_energy: bool = True,
+ raw_energy: bool = False,
frame_shift: int = 10,
frame_length: int = 25,
snip_edges: bool = False,
@@ -644,6 +647,7 @@ def __init__(
# MFCC options
self.use_energy = use_energy
+ self.raw_energy = raw_energy
self.low_frequency = low_frequency
self.high_frequency = high_frequency
self.sample_frequency = sample_frequency
@@ -783,6 +787,7 @@ def mfcc_options(self) -> MetaDict:
else:
options = {
"use_energy": self.use_energy,
+ "raw_energy": self.raw_energy,
"dither": self.dither,
"energy_floor": self.energy_floor,
"num_coefficients": self.num_coefficients,
diff --git a/montreal_forced_aligner/db.py b/montreal_forced_aligner/db.py
index 9e1c2675..374d03bf 100644
--- a/montreal_forced_aligner/db.py
+++ b/montreal_forced_aligner/db.py
@@ -112,7 +112,7 @@ def full_load_utterance(session: sqlalchemy.orm.Session, utterance_id: int):
def bulk_update(
session: sqlalchemy.orm.Session,
table: MfaSqlBase,
- values: typing.List[typing.Dict[str, typing.Any]],
+ values: typing.Collection[typing.Dict[str, typing.Any]],
id_field=None,
) -> None:
"""
@@ -353,7 +353,7 @@ class Dictionary(MfaSqlBase):
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(50), nullable=False)
- path = Column(PathType, unique=True)
+ path = Column(PathType)
rules_applied = Column(Boolean, default=False)
phone_set_type = Column(Enum(PhoneSetType), nullable=True)
root_temp_directory = Column(PathType, nullable=True)
@@ -1114,7 +1114,7 @@ def save(
save_transcription: bool
Flag for whether the hypothesized transcription text should be saved instead of the default text
"""
- from montreal_forced_aligner.alignment.multiprocessing import construct_output_path
+ from montreal_forced_aligner.textgrid import construct_output_path
utterance_count = len(self.utterances)
if output_format is None: # Saving directly
@@ -2024,6 +2024,14 @@ def __str__(self):
def has_dictionaries(self) -> bool:
return len(self.dictionaries) > 0
+ @property
+    def training_dictionaries(self) -> typing.List[Dictionary]:
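+        """Dictionaries to use for the current training round; the generated
+        dictionaries are phased in as the subset grows (nonnative above 25k
+        utterances, default only when training on the full corpus)"""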
+ if self.corpus.current_subset == 0:
+ return self.dictionaries
+ if self.corpus.current_subset <= 25000:
+ return [x for x in self.dictionaries if x.name not in {"default", "nonnative"}]
+ return [x for x in self.dictionaries if x.name not in {"default"}]
+
@property
def dictionary_ids(self) -> typing.List[int]:
return [x.id for x in self.dictionaries]
diff --git a/montreal_forced_aligner/dictionary/mixins.py b/montreal_forced_aligner/dictionary/mixins.py
index 65f61ded..3d13bc1a 100644
--- a/montreal_forced_aligner/dictionary/mixins.py
+++ b/montreal_forced_aligner/dictionary/mixins.py
@@ -6,7 +6,7 @@
import os
import re
import typing
-from collections import Counter
+from collections import Counter, defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
@@ -27,7 +27,7 @@
DEFAULT_QUOTE_MARKERS = list("“„\"”〝〟″「」『』‚ʻʿ‘′'")
DEFAULT_CLITIC_MARKERS = list("'’‘")
-DEFAULT_COMPOUND_MARKERS = list("-/")
+DEFAULT_COMPOUND_MARKERS = list("-‑/")
DEFAULT_BRACKETS = [("<", ">"), ("[", "]"), ("{", "}"), ("(", ")"), ("＜", "＞")]
__all__ = ["DictionaryMixin", "TemporaryDictionaryMixin"]
@@ -190,6 +190,7 @@ def __init__(
self.bracket_sanitize_regex = None
self.use_cutoff_model = use_cutoff_model
self._phone_groups = {}
+ self._topologies = {}
@property
def tokenizer(self):
@@ -491,6 +492,12 @@ def kaldi_non_silence_phones(self) -> List[str]:
@property
def phone_groups(self) -> typing.Dict[str, typing.List[str]]:
+ if (
+ not self._phone_groups
+            and getattr(self, "phone_groups_path", None)
+ and hasattr(self, "load_phone_groups")
+ ):
+ self.load_phone_groups()
if not self._phone_groups:
for p in sorted(self.non_silence_phones):
base_phone = self.get_base_phone(p)
@@ -665,8 +672,14 @@ def _write_topo(self) -> None:
"""
Write the topo file to the temporary directory
"""
-
+ if (
+ not self._topologies
+ and getattr(self, "topology_path", None)
+ and hasattr(self, "load_phone_topologies")
+ ):
+ self.load_phone_topologies()
sil_transp = 1 / (self.num_silence_states - 1)
+ topo_groups = defaultdict(set)
silence_lines = [
"",
@@ -694,9 +707,16 @@ def _write_topo(self) -> None:
silence_topo_string = "\n".join(silence_lines)
topo_sections = [silence_topo_string]
- topo_phones = self._get_grouped_phones()
- for phone_list in topo_phones.values():
+ for k, v in self._topologies.items():
+ min_states = v.get("min_states", 1)
+ max_states = v.get("max_states", self.num_non_silence_states)
+ topo_groups[(min_states, max_states)].add(k)
+ for phone in self.non_silence_phones:
+ if phone not in self._topologies:
+ topo_groups[(1, self.num_non_silence_states)].add(phone)
+
+ for (min_states, max_states), phone_list in topo_groups.items():
if not phone_list:
continue
            non_silence_lines = [
                "<TopologyEntry>",
                "<ForPhones>",
                " ".join(
                    str(self.phone_mapping[x]) for x in phone_list
                ),
                "</ForPhones>",
            ]
- # num_states = state_mapping[phone_type]
- num_states = self.num_non_silence_states
+ num_states = max_states
for i in range(num_states):
if i == 0: # Initial non_silence state
- transition_probability = 1 / self.num_non_silence_states
- transition_string = " ".join(
-                        f"<Transition> {x} {transition_probability}"
- for x in range(1, self.num_non_silence_states + 1)
- )
+ if min_states == max_states:
+                        transition_string = f"<Transition> {i} 0.5 <Transition> {i+1} 0.5"
+ else:
+ transition_probability = 1 / max_states
+ transition_string = " ".join(
+                            f"<Transition> {x} {transition_probability}"
+ for x in range(min_states, max_states + 1)
+ )
non_silence_lines.append(
f" {i} {i} {transition_string} "
)
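As a sanity check on the Kaldi markup restored above: a topo group with min_states = 1 and max_states = 3 would produce an entry of roughly this shape (phone ids and rounded probabilities are illustrative; the later states fall outside this hunk and are elided):

```
<TopologyEntry>
<ForPhones>
10 11 12
</ForPhones>
<State> 0 <PdfClass> 0 <Transition> 1 0.333 <Transition> 2 0.333 <Transition> 3 0.333 </State>
...
</TopologyEntry>
```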
diff --git a/montreal_forced_aligner/dictionary/multispeaker.py b/montreal_forced_aligner/dictionary/multispeaker.py
index 7dd9addb..23b3b14b 100644
--- a/montreal_forced_aligner/dictionary/multispeaker.py
+++ b/montreal_forced_aligner/dictionary/multispeaker.py
@@ -38,7 +38,11 @@
bulk_update,
)
from montreal_forced_aligner.dictionary.mixins import TemporaryDictionaryMixin
-from montreal_forced_aligner.exceptions import DictionaryError, DictionaryFileError
+from montreal_forced_aligner.exceptions import (
+ DictionaryError,
+ DictionaryFileError,
+ PhoneGroupTopologyMismatchError,
+)
from montreal_forced_aligner.helper import comma_join, mfa_open, split_phone_position
from montreal_forced_aligner.models import DictionaryModel
from montreal_forced_aligner.utils import parse_dictionary_file
@@ -86,6 +90,7 @@ def __init__(
dictionary_path: typing.Union[str, Path] = None,
rules_path: typing.Union[str, Path] = None,
phone_groups_path: typing.Union[str, Path] = None,
+ topology_path: typing.Union[str, Path] = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -109,8 +114,11 @@ def __init__(
rules_path = Path(rules_path)
if isinstance(phone_groups_path, str):
phone_groups_path = Path(phone_groups_path)
+ if isinstance(topology_path, str):
+ topology_path = Path(topology_path)
self.rules_path = rules_path
self.phone_groups_path = phone_groups_path
+ self.topology_path = topology_path
self.lexicon_compilers: typing.Dict[int, typing.Union[LexiconCompiler, G2PCompiler]] = {}
self._tokenizers = {}
@@ -165,6 +173,25 @@ def load_phone_groups(self) -> None:
self._phone_groups[k] = sorted(
set(x for x in v if x in self.non_silence_phones)
)
+ errors = []
+ for phone_group in self._phone_groups.values():
+ topos = set()
+ for phone in phone_group:
+ min_states = 1
+ max_states = self.num_non_silence_states
+ if phone in self._topologies:
+ min_states = self._topologies[phone].get("min_states", 1)
+ max_states = self._topologies[phone].get(
+ "max_states", self.num_non_silence_states
+ )
+ topos.add((min_states, max_states))
+ if len(topos) > 1:
+ errors.append((phone_group, topos))
+ if errors:
+ raise PhoneGroupTopologyMismatchError(
+ errors, self.phone_groups_path, self.topology_path
+ )
+
found_phones = set()
for phones in self._phone_groups.values():
found_phones.update(phones)
@@ -177,6 +204,33 @@ def load_phone_groups(self) -> None:
else:
logger.debug("All phones were included in phone groups")
+ def load_phone_topologies(self) -> None:
+ """
+        Load phone topologies from the dictionary's topology file path
+ """
+ if self.topology_path is not None and self.topology_path.exists():
+ with mfa_open(self.topology_path) as f:
+ self._topologies = {
+ k: v
+ for k, v in yaml.load(f, Loader=yaml.Loader).items()
+ if k in self.non_silence_phones
+ }
+ found_phones = set(self._topologies.keys())
+ missing_phones = self.non_silence_phones - found_phones
+ if missing_phones:
+ logger.debug(
+ f"The following phones will use the default topology (min states = 1, max states = {self.num_non_silence_states}): "
+ f"{comma_join(sorted(missing_phones))}"
+ )
+ else:
+ logger.debug("All phones were included in topology config")
+ logger.debug("The following phones will use custom topologies: ")
+ for k, v in self._topologies.items():
+ min_states = self._topologies[k].get("min_states", 1)
+ max_states = self._topologies[k].get("max_states", self.num_non_silence_states)
+ logger.debug(f"{k}: min states = {min_states}, max states = {max_states}")
+ assert min_states <= max_states
+
@property
def speaker_mapping(self) -> typing.Dict[str, int]:
"""Mapping of speakers to dictionaries"""
@@ -349,6 +403,7 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]:
)
if default:
self._default_dictionary_id = dictionary_id
+
word_primary_key = 1
pronunciation_primary_key = 1
word_objs = []
@@ -357,6 +412,7 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]:
phone_counts = collections.Counter()
graphemes = set(self.clitic_markers + self.compound_markers)
clitic_cleanup_regex = None
+ has_nonnative_speakers = False
if len(self.clitic_markers) >= 1:
other_clitic_markers = self.clitic_markers[1:]
if other_clitic_markers:
@@ -370,10 +426,54 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]:
dictionary_model,
speakers,
) in self.dictionary_model.load_dictionary_paths().values():
+ if dictionary_model == "nonnative":
+ if not has_nonnative_speakers:
+ dialect = "nonnative"
+ if dialect not in dialect_id_cache:
+ dialect_obj = Dialect(name=dialect)
+ session.add(dialect_obj)
+ session.flush()
+ dialect_id_cache[dialect] = dialect_obj.id
+ dictionary = Dictionary(
+ name=dialect,
+ dialect_id=dialect_id_cache[dialect],
+ path="",
+ phone_set_type=self.phone_set_type,
+ root_temp_directory=self.dictionary_output_directory,
+ position_dependent_phones=self.position_dependent_phones,
+ clitic_marker=self.clitic_marker
+ if self.clitic_marker is not None
+ else "",
+ default=False,
+ use_g2p=False,
+ max_disambiguation_symbol=0,
+ silence_word=self.silence_word,
+ oov_word=self.oov_word,
+ bracketed_word=self.bracketed_word,
+ cutoff_word=self.cutoff_word,
+ laughter_word=self.laughter_word,
+ optional_silence_phone=self.optional_silence_phone,
+ oov_phone=self.oov_phone,
+ )
+ session.add(dictionary)
+ session.flush()
+ dictionary_id_cache[dialect] = dictionary.id
+ for speaker in speakers:
+ if speaker not in self._speaker_ids:
+ speaker_objs.append(
+ {
+ "id": self._current_speaker_index,
+ "name": speaker,
+ "dictionary_id": dictionary_id_cache[dialect],
+ }
+ )
+ self._speaker_ids[speaker] = self._current_speaker_index
+ self._current_speaker_index += 1
+ has_nonnative_speakers = True
+ continue
if dictionary_model.path not in dictionary_id_cache and not self.use_g2p:
word_cache = {}
pronunciation_cache = set()
- subsequences = set()
if self.phone_set_type not in auto_set:
if (
self.phone_set_type != dictionary_model.phone_set_type
@@ -543,10 +643,6 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]:
pronunciation_primary_key += 1
pronunciation_cache.add((word, pron_string))
phone_counts.update(pron)
- pron = pron[:-1]
- while pron:
- subsequences.add(tuple(pron))
- pron = pron[:-1]
for w, wt in special_words.items():
if w in specials_found:
@@ -650,13 +746,219 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]:
conn.execute(sqlalchemy.insert(Grapheme.__table__), grapheme_objs)
session.commit()
+ self.create_default_dictionary()
+ if has_nonnative_speakers:
+ self.create_nonnative_dictionary()
if pron_objs:
self.calculate_disambiguation()
self.load_phonological_rules()
self.calculate_phone_mapping()
+ self.load_phone_topologies()
self.load_phone_groups()
return graphemes, phone_counts
+ def create_default_dictionary(self):
+ if self._default_dictionary_id is not None: # Already a default dictionary
+ return
+ with self.session() as session:
+ dictionary = session.query(Dictionary).filter_by(default=True).first()
+ if dictionary is not None: # Already a default dictionary
+ self._default_dictionary_id = dictionary.id
+ return
+ word_objs = []
+ pron_objs = []
+ dialect = session.query(Dialect).filter_by(name="unknown").first()
+ if dialect is None:
+ dialect = Dialect(name="unknown")
+ session.add(dialect)
+ session.flush()
+ dictionary = Dictionary(
+ name="default",
+ dialect_id=dialect.id,
+ path="",
+ phone_set_type=self.phone_set_type,
+ root_temp_directory=self.dictionary_output_directory,
+ position_dependent_phones=self.position_dependent_phones,
+ clitic_marker=self.clitic_marker if self.clitic_marker is not None else "",
+ default=True,
+ use_g2p=False,
+ max_disambiguation_symbol=0,
+ silence_word=self.silence_word,
+ oov_word=self.oov_word,
+ bracketed_word=self.bracketed_word,
+ cutoff_word=self.cutoff_word,
+ laughter_word=self.laughter_word,
+ optional_silence_phone=self.optional_silence_phone,
+ oov_phone=self.oov_phone,
+ )
+ session.add(dictionary)
+ session.commit()
+
+ special_words = {
+ self.silence_word: WordType.silence,
+ self.oov_word: WordType.oov,
+ self.bracketed_word: WordType.bracketed,
+ self.cutoff_word: WordType.cutoff,
+ self.laughter_word: WordType.laughter,
+ }
+ self._default_dictionary_id = dictionary.id
+ self.dictionary_lookup[dictionary.name] = dictionary.id
+ self._words_mappings[dictionary.id] = {}
+ word_primary_key = self.get_next_primary_key(Word)
+ pronunciation_primary_key = self.get_next_primary_key(Pronunciation)
+ current_index = 0
+ word_cache = {}
+ for w, w_type in special_words.items():
+ if w_type is WordType.silence:
+ pron = self.optional_silence_phone
+ else:
+ pron = self.oov_phone
+ word_objs.append(
+ {
+ "id": word_primary_key,
+ "mapping_id": current_index,
+ "word": w,
+ "word_type": w_type,
+ "dictionary_id": dictionary.id,
+ }
+ )
+ self._words_mappings[dictionary.id][w] = current_index
+ current_index += 1
+
+ pron_objs.append(
+ {
+ "id": pronunciation_primary_key,
+ "pronunciation": pron,
+ "probability": 1.0,
+ "disambiguation": None,
+ "silence_after_probability": None,
+ "silence_before_correction": None,
+ "non_silence_before_correction": None,
+ "word_id": word_primary_key,
+ }
+ )
+ word_primary_key += 1
+ pronunciation_primary_key += 1
+
+ query = (
+ session.query(
+ Word.word,
+ Word.word_type,
+ Pronunciation.pronunciation,
+ )
+ .join(Pronunciation.word)
+ .filter(~Word.word.in_(special_words.keys()))
+ .distinct()
+ .order_by(Word.word, Pronunciation.pronunciation)
+ )
+ for word, word_type, pronunciation in query:
+ if word not in word_cache:
+ word_objs.append(
+ {
+ "id": word_primary_key,
+ "mapping_id": current_index,
+ "word": word,
+ "word_type": word_type,
+ "dictionary_id": dictionary.id,
+ }
+ )
+ self._words_mappings[dictionary.id][word] = current_index
+ current_index += 1
+ word_cache[word] = word_primary_key
+ word_primary_key += 1
+ pron_objs.append(
+ {
+ "id": pronunciation_primary_key,
+ "pronunciation": pronunciation,
+ "probability": None,
+ "disambiguation": None,
+ "silence_after_probability": None,
+ "silence_before_correction": None,
+ "non_silence_before_correction": None,
+ "word_id": word_cache[word],
+ }
+ )
+ pronunciation_primary_key += 1
+            for s in ["#0", "<s>", "</s>"]:
+ word_objs.append(
+ {
+ "id": word_primary_key,
+ "word": s,
+ "dictionary_id": dictionary.id,
+ "mapping_id": current_index,
+ "word_type": WordType.disambiguation,
+ }
+ )
+ self._words_mappings[dictionary.id][s] = current_index
+ word_primary_key += 1
+ current_index += 1
+ with self.session() as session:
+ with session.bind.begin() as conn:
+ if word_objs:
+ conn.execute(sqlalchemy.insert(Word.__table__), word_objs)
+ if pron_objs:
+ conn.execute(sqlalchemy.insert(Pronunciation.__table__), pron_objs)
+ session.commit()
+
+ def create_nonnative_dictionary(self):
+ with self.session() as session:
+ dictionary = session.query(Dictionary).filter_by(name="nonnative").first()
+ word_objs = []
+ pron_objs = []
+
+ self.dictionary_lookup[dictionary.name] = dictionary.id
+ self._words_mappings[dictionary.id] = {}
+ word_primary_key = self.get_next_primary_key(Word)
+ pronunciation_primary_key = self.get_next_primary_key(Pronunciation)
+ word_cache = {}
+ query = (
+ session.query(
+ Word.word,
+ Word.mapping_id,
+ Word.word_type,
+ Pronunciation.pronunciation,
+ )
+ .join(Word.pronunciations, isouter=True)
+ .filter(Word.dictionary_id == self._default_dictionary_id)
+ .distinct()
+ .order_by(Word.mapping_id)
+ )
+ for word, mapping_id, word_type, pronunciation in query:
+ if word not in word_cache:
+ word_objs.append(
+ {
+ "id": word_primary_key,
+ "mapping_id": mapping_id,
+ "word": word,
+ "word_type": word_type,
+ "dictionary_id": dictionary.id,
+ }
+ )
+ self._words_mappings[dictionary.id][word] = mapping_id
+ word_cache[word] = word_primary_key
+ word_primary_key += 1
+ if pronunciation is not None:
+ pron_objs.append(
+ {
+ "id": pronunciation_primary_key,
+ "pronunciation": pronunciation,
+ "probability": None,
+ "disambiguation": None,
+ "silence_after_probability": None,
+ "silence_before_correction": None,
+ "non_silence_before_correction": None,
+ "word_id": word_cache[word],
+ }
+ )
+ pronunciation_primary_key += 1
+ with self.session() as session:
+ with session.bind.begin() as conn:
+ if word_objs:
+ conn.execute(sqlalchemy.insert(Word.__table__), word_objs)
+ if pron_objs:
+ conn.execute(sqlalchemy.insert(Pronunciation.__table__), pron_objs)
+ session.commit()
+
def calculate_disambiguation(self) -> None:
"""Calculate the number of disambiguation symbols necessary for the dictionary"""
with self.session() as session:
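
Both dictionary-generation methods above finish with the same bulk-insert pattern: plain dicts are accumulated and handed to a single insert call, which runs as one executemany rather than one INSERT per ORM object. A minimal sketch of that pattern, with a hypothetical engine URL and row values:

    import sqlalchemy

    from montreal_forced_aligner.db import Word

    # Hypothetical engine and rows, for illustration only
    engine = sqlalchemy.create_engine("postgresql+psycopg2://localhost/mfa_sketch")
    word_objs = [
        {"id": 1, "mapping_id": 0, "word": "example", "dictionary_id": 2},
        {"id": 2, "mapping_id": 1, "word": "sketch", "dictionary_id": 2},
    ]
    with engine.begin() as conn:  # one transaction, committed on exit
        conn.execute(sqlalchemy.insert(Word.__table__), word_objs)  # executemany
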
diff --git a/montreal_forced_aligner/exceptions.py b/montreal_forced_aligner/exceptions.py
index aab12922..f82dea66 100644
--- a/montreal_forced_aligner/exceptions.py
+++ b/montreal_forced_aligner/exceptions.py
@@ -306,6 +306,34 @@ def __init__(self, missing_phones: Collection[str]):
self.message_lines.append(comma_join(missing_phones))
+class PhoneGroupTopologyMismatchError(DictionaryError):
+ """
+ Exception class for when phones in the same phone group have conflicting HMM topologies
+
+ Parameters
+ ----------
+ error_topologies: List[Tuple[List[str], List[Tuple[int, int]]]]
+ Phone groups paired with the conflicting (min_states, max_states) topologies found for their phones
+ phone_groups_path: Path or str
+ Path to the phone groups configuration file
+ topologies_path: Path or str
+ Path to the topology configuration file
+ """
+
+ def __init__(
+ self,
+ error_topologies: typing.List[
+ typing.Tuple[typing.List[str], typing.List[typing.Tuple[int, int]]]
+ ],
+ phone_groups_path: typing.Union[Path, str],
+ topologies_path: typing.Union[Path, str],
+ ):
+ super().__init__("There were multiple topologies found for phones in the same group: ")
+ for k, v in error_topologies:
+ v = [f"(min_states = {x[0]}, max_states = {x[1]})" for x in v]
+ self.message_lines.append(f"{comma_join(k)}: {comma_join(v)}")
+ self.message_lines.append(
+ f"Please update {phone_groups_path} or {topologies_path} to "
+ f"ensure that phone groups have a single set of minimum and maximum states"
+ )
+
+
class NoDefaultSpeakerDictionaryError(DictionaryError):
"""
Exception class for errors in creating MultispeakerDictionary objects
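
Constructing the new error with a hypothetical mismatch shows how the message is assembled from the grouped topologies:

    from montreal_forced_aligner.exceptions import PhoneGroupTopologyMismatchError

    # Hypothetical mismatch: "b" and "bʲ" share a group but differ in topology
    error = PhoneGroupTopologyMismatchError(
        error_topologies=[(["b", "bʲ"], [(1, 1), (1, 3)])],
        phone_groups_path="phone_groups.yaml",  # illustrative paths
        topologies_path="topology.yaml",
    )
    # message_lines lists the group with both conflicting
    # (min_states, max_states) pairs, then points at the two config files
    print("\n".join(error.message_lines))
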
diff --git a/montreal_forced_aligner/models.py b/montreal_forced_aligner/models.py
index 124d3659..087a7bf9 100644
--- a/montreal_forced_aligner/models.py
+++ b/montreal_forced_aligner/models.py
@@ -647,7 +647,7 @@ def meta(self) -> MetaDict:
self._meta["other_noise_phone"] = "sp"
if "phone_set_type" not in self._meta:
self._meta["phone_set_type"] = "UNKNOWN"
- if "language" not in self._meta:
+ if "language" not in self._meta or self._meta["version"] <= "3.0":
self._meta["language"] = "unknown"
self._meta["phones"] = set(self._meta.get("phones", []))
if (
@@ -1589,7 +1589,9 @@ def load_dictionary_paths(self) -> Dict[str, Tuple[DictionaryModel, typing.Set[s
data = yaml.load(f, Loader=yaml.Loader)
for speaker, path in data.items():
if path not in mapping:
- mapping[path] = (DictionaryModel(path), set())
+ if path != "nonnative":
+ path = DictionaryModel(path)
+ mapping[path] = (path, set())
mapping[path][1].add(speaker)
else:
mapping[str(self.path)] = (self, {"default"})
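
The mapping parsed here comes from a speaker-to-dictionary YAML file; with this change, a literal nonnative value stays a plain string key instead of being wrapped in a DictionaryModel. A hypothetical example of such a file, parsed the same way (the paths are placeholders):

    import yaml

    data = yaml.safe_load(
        "default: /dictionaries/english_us_mfa.dict\n"
        "speaker_one: /dictionaries/english_uk_mfa.dict\n"
        "speaker_two: nonnative\n"
    )
    for speaker, path in data.items():
        print(speaker, path)  # "nonnative" is kept as a sentinel string
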
diff --git a/montreal_forced_aligner/textgrid.py b/montreal_forced_aligner/textgrid.py
index 140d63f6..f804c5bb 100644
--- a/montreal_forced_aligner/textgrid.py
+++ b/montreal_forced_aligner/textgrid.py
@@ -13,6 +13,7 @@
from pathlib import Path
from typing import Dict, List
+import sqlalchemy
from praatio import textgrid as tgio
from praatio.data_classes.interval_tier import Interval
from praatio.utilities import utils as tgio_utils
@@ -40,7 +41,7 @@
__all__ = [
"process_ctm_line",
"export_textgrid",
- "construct_output_tiers",
+ "construct_textgrid_output",
"construct_output_path",
"output_textgrid_writing_errors",
]
@@ -275,87 +276,150 @@ def parse_aligned_textgrid(
return data
-def construct_output_tiers(
+def construct_textgrid_output(
session: Session,
- file_id: int,
+ file_batch: typing.Dict[int, typing.Tuple],
workflow: CorpusWorkflow,
cleanup_textgrids: bool,
clitic_marker: str,
- include_original_text: bool,
-) -> Dict[str, Dict[str, List[CtmInterval]]]:
- """
- Construct aligned output tiers for a file
-
- Parameters
- ----------
- session: Session
- SqlAlchemy session
- file_id: int
- Integer ID for the file
-
- Returns
- -------
- Dict[str, Dict[str,List[CtmInterval]]]
- Aligned tiers
- """
- data = {}
- phone_intervals = (
- session.query(PhoneInterval.begin, PhoneInterval.end, Phone.phone, Speaker.name)
+ output_directory: Path,
+ frame_shift: float,
+ output_format: str = TextgridFormats.SHORT_TEXTGRID,
+ include_original_text: bool = False,
+):
+ phone_interval_query = (
+ sqlalchemy.select(
+ PhoneInterval.begin, PhoneInterval.end, Phone.phone, Speaker.name, Utterance.file_id
+ )
+ .execution_options(yield_per=1000)
.join(PhoneInterval.phone)
.join(PhoneInterval.utterance)
.join(Utterance.speaker)
- .filter(Utterance.file_id == file_id)
.filter(PhoneInterval.workflow_id == workflow.id)
.filter(PhoneInterval.duration > 0)
- .order_by(PhoneInterval.begin)
+ .filter(Utterance.file_id.in_(list(file_batch.keys())))
)
- word_intervals = (
- session.query(WordInterval.begin, WordInterval.end, Word.word, Speaker.name)
+ word_interval_query = (
+ sqlalchemy.select(
+ WordInterval.begin, WordInterval.end, Word.word, Speaker.name, Utterance.file_id
+ )
+ .execution_options(yield_per=1000)
.join(WordInterval.word)
.join(WordInterval.utterance)
.join(Utterance.speaker)
- .filter(Utterance.file_id == file_id)
.filter(WordInterval.workflow_id == workflow.id)
.filter(WordInterval.duration > 0)
- .order_by(WordInterval.begin)
+ .filter(Utterance.file_id.in_(list(file_batch.keys())))
)
if cleanup_textgrids:
- phone_intervals = phone_intervals.filter(Phone.phone_type != PhoneType.silence)
- word_intervals = word_intervals.filter(Word.word_type != WordType.silence)
-
- for w_begin, w_end, w, speaker_name in word_intervals:
- if speaker_name not in data:
- data[speaker_name] = {"words": [], "phones": []}
- if include_original_text:
- data[speaker_name]["utterances"] = []
- if (
- cleanup_textgrids
- and data[speaker_name]["words"]
- and w_begin - data[speaker_name]["words"][-1].end < 0.02
- and clitic_marker
- and (
- data[speaker_name]["words"][-1].label.endswith(clitic_marker)
- or w.startswith(clitic_marker)
- )
- ):
- data[speaker_name]["words"][-1].end = w_end
- data[speaker_name]["words"][-1].label += w
-
- else:
- data[speaker_name]["words"].append(CtmInterval(w_begin, w_end, w))
-
+ phone_interval_query = phone_interval_query.filter(Phone.phone_type != PhoneType.silence)
+ word_interval_query = word_interval_query.filter(Word.word_type != WordType.silence)
+ phone_intervals = session.execute(
+ phone_interval_query.order_by(Utterance.file_id, PhoneInterval.begin)
+ )
+ word_intervals = session.execute(
+ word_interval_query.order_by(Utterance.file_id, WordInterval.begin)
+ )
+ utterances = None
if include_original_text:
- utterances = (
- session.query(Utterance.begin, Utterance.end, Utterance.text, Speaker.name)
+ utterances = session.execute(
+ sqlalchemy.select(
+ Utterance.begin, Utterance.end, Utterance.text, Speaker.name, Utterance.file_id
+ )
+ .execution_options(yield_per=1000)
.join(Utterance.speaker)
- .filter(Utterance.file_id == file_id)
+ .filter(Utterance.file_id.in_(list(file_batch.keys())))
+ .order_by(Utterance.file_id)
)
- for utt_begin, utt_end, utt_text, speaker_name in utterances:
- data[speaker_name]["utterances"].append(CtmInterval(utt_begin, utt_end, utt_text))
-
- for p_begin, p_end, phone, speaker_name in phone_intervals:
- data[speaker_name]["phones"].append(CtmInterval(p_begin, p_end, phone))
- return data
+ pi_current_file_id = None
+ wi_current_file_id = None
+ u_current_file_id = None
+ word_data = []
+ phone_data = []
+ utterance_data = []
+
+ def process_phone_data():
+ for beg, end, p, speaker_name in phone_data:
+ if speaker_name not in data:
+ data[speaker_name] = {"words": [], "phones": []}
+ if include_original_text:
+ data[speaker_name]["utterances"] = []
+ data[speaker_name]["phones"].append(CtmInterval(beg, end, p))
+
+ def process_word_data():
+ for beg, end, w, speaker_name in word_data:
+ if (
+ cleanup_textgrids
+ and data[speaker_name]["words"]
+ and beg - data[speaker_name]["words"][-1].end < 0.02
+ and clitic_marker
+ and (
+ data[speaker_name]["words"][-1].label.endswith(clitic_marker)
+ or w.startswith(clitic_marker)
+ )
+ ):
+ data[speaker_name]["words"][-1].end = end
+ data[speaker_name]["words"][-1].label += w
+ else:
+ data[speaker_name]["words"].append(CtmInterval(beg, end, w))
+
+ def process_utterance_data():
+ for beg, end, u, speaker_name in utterance_data:
+ data[speaker_name]["utterances"].append(CtmInterval(beg, end, u))
+
+ while True:
+ data = {}
+ for pi_begin, pi_end, phone, pi_speaker_name, pi_file_id in phone_intervals:
+ if pi_current_file_id is None:
+ pi_current_file_id = pi_file_id
+ if pi_file_id != pi_current_file_id:
+ process_phone_data()
+ phone_data = [(pi_begin, pi_end, phone, pi_speaker_name)]
+ current_file_id = pi_current_file_id
+ pi_current_file_id = pi_file_id
+ break
+ phone_data.append((pi_begin, pi_end, phone, pi_speaker_name))
+ else:
+ if phone_data:
+ process_phone_data()
+ phone_data = []
+ current_file_id = pi_current_file_id
+ else:
+ break
+ for wi_begin, wi_end, word, wi_speaker_name, wi_file_id in word_intervals:
+ if wi_current_file_id is None:
+ wi_current_file_id = wi_file_id
+ if wi_file_id != wi_current_file_id:
+ process_word_data()
+ word_data = [(wi_begin, wi_end, word, wi_speaker_name)]
+ wi_current_file_id = wi_file_id
+ break
+ word_data.append((wi_begin, wi_end, word, wi_speaker_name))
+ else:
+ if word_data:
+ process_word_data()
+ word_data = []
+ if include_original_text:
+ for u_begin, u_end, text, u_speaker_name, u_file_id in utterances:
+ if u_current_file_id is None:
+ u_current_file_id = u_file_id
+ if u_file_id != u_current_file_id:
+ process_utterance_data()
+ utterance_data = [(u_begin, u_end, text, u_speaker_name)]
+ u_current_file_id = u_file_id
+ break
+ utterance_data.append((u_begin, u_end, text, u_speaker_name))
+ else:
+ if utterance_data:
+ process_utterance_data()
+ utterance_data = []
+
+ file_name, relative_path, file_duration, text_file_path = file_batch[current_file_id]
+ output_path = construct_output_path(
+ file_name, relative_path, output_directory, text_file_path, output_format
+ )
+ export_textgrid(data, output_path, file_duration, frame_shift, output_format)
+ yield output_path
def construct_output_path(
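
construct_textgrid_output is now a generator over a batch of files rather than a per-file tier builder, so callers iterate it and collect the written paths. A hedged usage sketch, assuming an open session and workflow, where file_batch maps each file id to the (name, relative_path, duration, text_file_path) tuple unpacked above (all values here are placeholders):

    from pathlib import Path

    file_batch = {
        1: ("recording_1", Path("speaker_a"), 12.5, Path("speaker_a/recording_1.lab")),
    }
    for output_path in construct_textgrid_output(
        session,                  # an open SqlAlchemy session (assumed)
        file_batch,
        workflow,                 # the current CorpusWorkflow (assumed)
        cleanup_textgrids=True,
        clitic_marker="'",
        output_directory=Path("aligned_output"),
        frame_shift=0.01,         # e.g. 10 ms frames
    ):
        print("Exported", output_path)
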
diff --git a/montreal_forced_aligner/tokenization/japanese.py b/montreal_forced_aligner/tokenization/japanese.py
index beae7db2..551e9393 100644
--- a/montreal_forced_aligner/tokenization/japanese.py
+++ b/montreal_forced_aligner/tokenization/japanese.py
@@ -16,7 +16,7 @@ class JapaneseTokenizer:
def __init__(self, ignore_case: bool = True):
self.ignore_case = ignore_case
resource_dir = pathlib.Path(__file__).parent.joinpath("resources")
- config_path = resource_dir.joinpath("japanese", "sudachi_config.json")
+ config_path = str(resource_dir.joinpath("japanese", "sudachi_config.json"))
try:
self.tokenizer = sudachipy.Dictionary(dict="full", config_path=config_path).create(
mode=sudachipy.SplitMode.B
diff --git a/montreal_forced_aligner/tokenization/simple.py b/montreal_forced_aligner/tokenization/simple.py
index a7bd355d..74abd9a6 100644
--- a/montreal_forced_aligner/tokenization/simple.py
+++ b/montreal_forced_aligner/tokenization/simple.py
@@ -129,6 +129,7 @@ def __init__(
initial_clitic_regex: typing.Optional[re.Pattern],
final_clitic_regex: typing.Optional[re.Pattern],
compound_regex: typing.Optional[re.Pattern],
+ cutoff_regex: typing.Optional[re.Pattern],
non_speech_regexes: typing.Dict[str, re.Pattern],
oov_word: typing.Optional[str] = None,
grapheme_set: typing.Optional[typing.Collection[str]] = None,
@@ -136,6 +137,7 @@ def __init__(
self.word_table = word_table
self.clitic_marker = clitic_marker
self.compound_regex = compound_regex
+ self.cutoff_regex = cutoff_regex
self.oov_word = oov_word
+ self.specials_set = {self.oov_word, "<s>", "</s>"}
if not grapheme_set:
@@ -172,6 +174,8 @@ def to_str(self, normalized_text: str) -> str:
if self.word_table and self.word_table.member(normalized_text):
return normalized_text
+ if self.cutoff_regex.match(normalized_text):
+ return normalized_text
for word, regex in self.non_speech_regexes.items():
if regex.match(normalized_text):
return word
return normalized_text
@@ -195,7 +199,8 @@ def split_clitics(
"""
split = []
if self.compound_regex is not None:
- s = self.compound_regex.split(item)
+ s = [x for x in self.compound_regex.split(item) if x]
+
else:
s = [item]
if self.word_table is None:
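
Paired with the compound-marker pattern tightened further below, a split can now also match a marker at the end of a word and yield a trailing empty string, which the added filter drops:

    import re

    compound_regex = re.compile(r"(?<=\w)[-](?:$|(?=\w))")
    compound_regex.split("well-known")                  # ['well', 'known']
    compound_regex.split("know-")                       # ['know', ''] <- empty tail
    [x for x in compound_regex.split("know-") if x]     # ['know'], as filtered above
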
@@ -292,6 +297,8 @@ def __call__(
"""
if self.word_table and self.word_table.member(item):
return [item]
+ if self.cutoff_regex.match(item):
+ return [item]
for regex in self.non_speech_regexes.values():
if regex.match(item):
return [item]
@@ -325,7 +332,10 @@ def __init__(
self.laughter_word = laughter_word
self.oov_word = oov_word
self.bracketed_word = bracketed_word
- self.cutoff_word = cutoff_word
+
+ initial_brackets = re.escape("".join(x[0] for x in self.brackets))
+ final_brackets = re.escape("".join(x[1] for x in self.brackets))
+ self.cutoff_identifier = re.sub(rf"[{initial_brackets}{final_brackets}]", "", cutoff_word)
self.ignore_case = ignore_case
self.use_g2p = use_g2p
self.clitic_set = set()
@@ -372,6 +382,7 @@ def __init__(
self.initial_clitic_regex,
self.final_clitic_regex,
self.compound_regex,
+ self.cutoff_regex,
self.non_speech_regexes,
self.oov_word,
self.grapheme_set,
@@ -396,15 +407,19 @@ def _compile_regexes(self) -> None:
if "-" in self.compound_markers:
extra = "-"
compound_markers = [x for x in compound_markers if x != "-"]
- self.compound_regex = re.compile(rf"(?<=\w)[{extra}{''.join(compound_markers)}](?=\w)")
+ self.compound_regex = re.compile(
+ rf"(?<=\w)[{extra}{''.join(compound_markers)}](?:$|(?=\w))"
+ )
if self.brackets:
- left_brackets = [x[0] for x in self.brackets]
- right_brackets = [x[1] for x in self.brackets]
- self.bracket_regex = re.compile(
- rf"[{re.escape(''.join(left_brackets))}].*?[{re.escape(''.join(right_brackets))}]+"
+ left_brackets = re.escape("".join(x[0] for x in self.brackets))
+ right_brackets = re.escape("".join(x[1] for x in self.brackets))
+ self.cutoff_regex = re.compile(
+ rf"^[{left_brackets}]({self.cutoff_identifier}|hes(itation)?)([-_](?P[^{right_brackets}]+))?[{right_brackets}]$",
+ flags=re.IGNORECASE,
)
+ self.bracket_regex = re.compile(rf"[{left_brackets}].*?[{right_brackets}]+")
self.laughter_regex = re.compile(
- rf"[{re.escape(''.join(left_brackets))}](laugh(ing|ter)?|lachen|lg)[{re.escape(''.join(right_brackets))}]+",
+ rf"[{left_brackets}](laugh(ing|ter)?|lachen|lg)[{right_brackets}]+",
flags=re.IGNORECASE,
)
all_punctuation = set()
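
To make the new cutoff pattern concrete: with angle brackets and the identifier cutoff (both assumptions for illustration), it matches a bare cutoff or hesitation marker, optionally suffixed with the word that was cut off:

    import re

    left_brackets, right_brackets = re.escape("<"), re.escape(">")
    cutoff_identifier = "cutoff"  # assumed value of self.cutoff_identifier
    cutoff_regex = re.compile(
        rf"^[{left_brackets}]({cutoff_identifier}|hes(itation)?)"
        rf"([-_](?P<word>[^{right_brackets}]+))?[{right_brackets}]$",
        flags=re.IGNORECASE,
    )
    assert cutoff_regex.match("<cutoff>")
    assert cutoff_regex.match("<cutoff-word>").group("word") == "word"
    assert cutoff_regex.match("<hes>")
    assert cutoff_regex.match("<laugh>") is None
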
diff --git a/tests/conftest.py b/tests/conftest.py
index e1d0d53c..9e4a62f4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -996,11 +996,11 @@ def multispeaker_dictionary_config_path(generated_dir, basic_dict_path, english_
@pytest.fixture(scope="session")
-def mfa_speaker_dict_path(generated_dir, english_uk_mfa_dictionary, english_us_mfa_dictionary):
+def mfa_speaker_dict_path(generated_dir, english_uk_mfa_dictionary, english_us_mfa_reduced_dict):
path = generated_dir.joinpath("test_multispeaker_mfa_dictionary.yaml")
with mfa_open(path, "w") as f:
yaml.dump(
- {"default": english_us_mfa_dictionary, "speaker": english_uk_mfa_dictionary},
+ {"default": english_us_mfa_reduced_dict, "speaker": english_us_mfa_reduced_dict},
f,
Dumper=yaml.Dumper,
allow_unicode=True,
@@ -1008,6 +1008,30 @@ def mfa_speaker_dict_path(generated_dir, english_uk_mfa_dictionary, english_us_m
return path
+@pytest.fixture(scope="session")
+def english_mfa_phone_groups_path(config_directory):
+ path = config_directory.joinpath("acoustic", "english_mfa_phone_groups.yaml")
+ return path
+
+
+@pytest.fixture(scope="session")
+def english_mfa_rules_path(config_directory):
+ path = config_directory.joinpath("acoustic", "english_mfa_rules.yaml")
+ return path
+
+
+@pytest.fixture(scope="session")
+def english_mfa_topology_path(config_directory):
+ path = config_directory.joinpath("acoustic", "english_mfa_topology.yaml")
+ return path
+
+
+@pytest.fixture(scope="session")
+def bad_topology_path(config_directory):
+ path = config_directory.joinpath("acoustic", "bad_topology.yaml")
+ return path
+
+
@pytest.fixture(scope="session")
def test_align_config():
return {"beam": 100, "retry_beam": 400}
diff --git a/tests/data/configs/acoustic/bad_topology.yaml b/tests/data/configs/acoustic/bad_topology.yaml
new file mode 100644
index 00000000..048869a9
--- /dev/null
+++ b/tests/data/configs/acoustic/bad_topology.yaml
@@ -0,0 +1,4 @@
+b:
+ max_states: 3
+bʲ:
+ max_states: 1
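
The fixture is deliberately inconsistent: b and bʲ share a phone group in the file below but are given different max_states, which is the condition PhoneGroupTopologyMismatchError reports. A minimal sketch of that check, with assumed defaults for unlisted values:

    import yaml
    from pathlib import Path

    phone_groups = [["b", "bʲ"]]  # from english_mfa_phone_groups.yaml below
    topologies = yaml.safe_load(Path("bad_topology.yaml").read_text(encoding="utf-8"))

    error_topologies = []
    for group in phone_groups:
        # (min_states, max_states) per phone; the defaults of 1 and 3 are assumptions
        tops = {
            (
                topologies.get(p, {}).get("min_states", 1),
                topologies.get(p, {}).get("max_states", 3),
            )
            for p in group
        }
        if len(tops) > 1:
            error_topologies.append((group, sorted(tops)))
    # error_topologies == [(['b', 'bʲ'], [(1, 1), (1, 3)])]
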
diff --git a/tests/data/configs/acoustic/english_mfa_phone_groups.yaml b/tests/data/configs/acoustic/english_mfa_phone_groups.yaml
new file mode 100644
index 00000000..52a5e88b
--- /dev/null
+++ b/tests/data/configs/acoustic/english_mfa_phone_groups.yaml
@@ -0,0 +1,170 @@
+-
+ - p
+ - pʷ
+ - pʰ
+ - pʲ
+-
+ - kp
+-
+ - b
+ - bʲ
+-
+ - ɡb
+-
+ - f
+ - fʷ
+ - fʲ
+-
+ - v
+ - vʷ
+ - vʲ
+-
+ - θ
+-
+ - t̪
+-
+ - ð
+-
+ - d̪
+-
+ - t
+ - tʷ
+ - tʰ
+ - tʲ
+-
+ - ʈ
+ - ʈʲ
+ - ʈʷ
+-
+ - ʔ
+-
+ - d
+ - dʲ
+-
+ - ɖ
+ - ɖʲ
+-
+ - ɾ
+ - ɾʲ
+-
+ - tʃ
+-
+ - dʒ
+-
+ - ʃ
+-
+ - ʒ
+-
+ - s
+-
+ - z
+-
+ - ɹ
+-
+ - m
+-
+ - mʲ
+-
+ - m̩
+-
+ - ɱ
+-
+ - n
+-
+ - n̩
+-
+ - ɲ
+-
+ - ɾ̃
+-
+ - ŋ
+-
+ - l
+-
+ - ɫ
+-
+ - ɫ̩
+-
+ - ʎ
+-
+ - ɟ
+ - ɟʷ
+-
+ - ɡ
+ - ɡʷ
+-
+ - c
+ - cʷ
+ - cʰ
+-
+ - k
+ - kʷ
+ - kʰ
+-
+ - ç
+-
+ - h
+-
+ - ɐ
+-
+ - ə
+-
+ - ɜː
+ - ɜ
+-
+ - ɝ
+-
+ - ɚ
+-
+ - ʊ
+-
+ - ɪ
+-
+ - ɑ
+ - ɑː
+-
+ - ɒ
+ - ɒː
+-
+ - ɔ
+-
+ - aː
+ - a
+-
+ - æ
+-
+ - aj
+-
+ - aw
+-
+ - i
+ - iː
+-
+ - j
+-
+ - ɛː
+ - ɛ
+-
+ - e
+ - eː
+-
+ - ej
+-
+ - ʉ
+ - ʉː
+-
+ - uː
+ - u
+-
+ - w
+-
+ - ʋ
+-
+ - ɔj
+-
+ - ow
+-
+ - əw
+-
+ - o
+ - oː
diff --git a/tests/data/configs/acoustic/english_mfa_rules.yaml b/tests/data/configs/acoustic/english_mfa_rules.yaml
new file mode 100644
index 00000000..9b8379f0
--- /dev/null
+++ b/tests/data/configs/acoustic/english_mfa_rules.yaml
@@ -0,0 +1,383 @@
+dialects:
+ nonnative:
+ - following_context: ''
+ preceding_context: ''
+ replacement: z
+ segment: 'ð'
+ - following_context: ''
+ preceding_context: ''
+ replacement: s
+ segment: 'θ'
+ nigeria:
+ - following_context: $
+ preceding_context: ''
+ replacement: p s
+ segment: 'b z'
+ - following_context: $
+ preceding_context: ''
+ replacement: t s
+ segment: 'd z'
+ - following_context: $
+ preceding_context: ''
+ replacement: k s
+ segment: 'ɡ z'
+ - following_context: $
+ preceding_context: ''
+ replacement: s
+ segment: 'z'
+ - following_context: ''
+ preceding_context: '^'
+ replacement: ''
+ segment: 'ç'
+ - following_context: ''
+ preceding_context: '^'
+ replacement: ''
+ segment: 'h'
+ - following_context: '$'
+ preceding_context: 'ŋ'
+ replacement: ''
+ segment: 'ɡ'
+ uk:
+ - following_context: ''
+ preceding_context: (ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+ replacement: ʔ
+ segment: 't[ʲʷ]?'
+ - following_context: ʉː?
+ preceding_context: ''
+ replacement: tʃ
+ segment: 'tʲ'
+ - following_context: ʉː?
+ preceding_context: ''
+ replacement: dʒ
+ segment: 'dʲ'
+ - following_context: ''
+ preceding_context: '^'
+ replacement: ''
+ segment: 'ç'
+ - following_context: '' # mitten glottalized syllabic
+ preceding_context: ''
+ replacement: ʔ n̩
+ segment: t ə n
+ - following_context: '[^ʊɔɝaɔɛɜeuoæɐɪəɚɑʉɒi].*' # syllabic n
+ preceding_context: ''
+ replacement: n̩
+ segment: ə n
+ - following_context: $ # syllabic n
+ preceding_context: ''
+ replacement: n̩
+ segment: ə n
+ - following_context: $ # syllabic m
+ preceding_context: ''
+ replacement: m̩
+ segment: ə m
+ - following_context: '[^ʊɔɝaɔɛɜeuoæɐɪəɚɑʉɒi].*' # syllabic m
+ preceding_context: ''
+ replacement: m̩
+ segment: ə m
+ - following_context: $ # syllabic l
+ preceding_context: ''
+ replacement: ɫ̩
+ segment: ə ɫ
+ - following_context: '[^ʊɔɝaɔɛɜeuoæɐɪəɚɑʉɒi].*' # syllabic l
+ preceding_context: ''
+ replacement: ɫ̩
+ segment: ə ɫ
+ india:
+ - following_context: ''
+ preceding_context: ''
+ replacement: ɾ
+ segment: 'ɹ'
+ - following_context: ''
+ preceding_context: ''
+ replacement: a
+ segment: 'ɒ'
+ - following_context: ''
+ preceding_context: ''
+ replacement: a
+ segment: 'ɑ'
+ - following_context: ''
+ preceding_context: ''
+ replacement: aː
+ segment: 'ɒː'
+ - following_context: ''
+ preceding_context: ''
+ replacement: aː
+ segment: 'ɑː'
+ - following_context: ''
+ preceding_context: ''
+ replacement: f
+ segment: 'pʰ'
+ - following_context: ''
+ preceding_context: ''
+ replacement: dʒ
+ segment: 'z'
+ - following_context: ''
+ preceding_context: ''
+ replacement: dʒ
+ segment: 'ʒ'
+ - following_context: ''
+ preceding_context: ''
+ replacement: z
+ segment: 'ʒ'
+ - following_context: ''
+ preceding_context: ''
+ replacement: ʃ
+ segment: 'ʒ'
+ us:
+ - following_context: (ɪ|ə|ɚ|i) # t/d flapping
+ preceding_context: (ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+ replacement: ɾ
+ segment: '[td]'
+ - following_context: (ɪ|ə|ɚ|i) # t/d flapping
+ preceding_context: ɹ
+ replacement: ɾ
+ segment: '[td]'
+ - following_context: (ɪ|ə|ɚ|i) # t/d flapping
+ preceding_context: ɫ
+ replacement: ɾ
+ segment: '[td]'
+ - following_context: (ɪ|ə|ɚ|i) # t/d flapping
+ preceding_context: (ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+ replacement: ɾʲ
+ segment: '[td]ʲ'
+ - following_context: (ɪ|ə|ɚ|i) # t/d flapping
+ preceding_context: ɹ
+ replacement: ɾʲ
+ segment: '[td]ʲ'
+ - following_context: (ɪ|ə|ɚ|i) # t/d flapping
+ preceding_context: '(ɫ|ɫ̩)'
+ replacement: ɾʲ
+ segment: '[td]ʲ'
+ - following_context: (ɪ|ə|ɚ|i) # nasal flapping
+ preceding_context: (ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+ replacement: ɾ̃
+ segment: (ɲ|n)
+ - following_context: (ɪ|ə|ɚ|i) # nasal flapping
+ preceding_context: (ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+ replacement: ɾ̃
+ segment: (ɲ|n) [td][ʲʷ]?
+ - following_context: $ # caught-cot merger
+ preceding_context: ''
+ replacement: ɑː
+ segment: ɒː
+ - following_context: '[^ɹ]' # caught-cot merger
+ preceding_context: ''
+ replacement: ɑː
+ segment: ɒː
+ - following_context: $ # caught-cot merger
+ preceding_context: ''
+ replacement: ɑ
+ segment: ɒ
+ - following_context: '[^ɹ]' # caught-cot merger
+ preceding_context: ''
+ replacement: ɑ
+ segment: ɒ
+ - following_context: $ # t/d flapping
+ preceding_context: (ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+ replacement: ɾ
+ segment: d
+ - following_context: '' # mitten glottalized syllabic
+ preceding_context: ''
+ replacement: ʔ n̩
+ segment: t ə n
+ - following_context: '[^ʊɔɝaɔɛɜeuoæɐɪəɚɑʉɒi].*' # syllabic n
+ preceding_context: ''
+ replacement: n̩
+ segment: ə n
+ - following_context: $ # syllabic n
+ preceding_context: ''
+ replacement: n̩
+ segment: ə n
+ - following_context: $ # syllabic m
+ preceding_context: ''
+ replacement: m̩
+ segment: ə m
+ - following_context: '[^ʊɔɝaɔɛɜeuoæɐɪəɚɑʉɒi].*' # syllabic m
+ preceding_context: ''
+ replacement: m̩
+ segment: ə m
+ - following_context: $ # syllabic l
+ preceding_context: ''
+ replacement: ɫ̩
+ segment: ə ɫ
+ - following_context: '[^ʊɔɝaɔɛɜeuoæɐɪəɚɑʉɒi].*' # syllabic l
+ preceding_context: ''
+ replacement: ɫ̩
+ segment: ə ɫ
+rules:
+- following_context: .*(ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+ preceding_context: ^
+ replacement: ''
+ segment: ə
+- following_context: $
+ preceding_context: '[sʃn]'
+ replacement: ''
+ segment: '[tʈ]'
+- following_context: $
+ preceding_context: '[zʒn]'
+ replacement: ''
+ segment: '[dɖ]'
+- following_context: ''
+ preceding_context: 'n'
+ replacement: ''
+ segment: '[dɖ]'
+- following_context: ə|ɚ
+ preceding_context: ''
+ replacement: j
+ segment: i
+- following_context: ə|ɚ
+ preceding_context: ''
+ replacement: w
+ segment: '[ʉu]'
+- following_context: '$'
+ preceding_context: ''
+ replacement: t̪
+ segment: '[tʈ] θ'
+- following_context: ''
+ preceding_context: ''
+ replacement: t̪
+ segment: 'θ [tʈ]'
+- following_context: '$'
+ preceding_context: ''
+ replacement: d̪
+ segment: 'ð [ɖd]'
+- following_context: $ # t deletion
+ preceding_context: (ʊ|ɔj|ɝ|ɛ|ej|ɜ|a|u|o|ow|æ|aw|əw|aj|ɐ|ɪ|ə|ɔ|e|ɚ|ɑ|ʉ|ɒ|i)ː?
+ replacement: ''
+ segment: '[tʈ]'
+- following_context: ''
+ preceding_context: ''
+ replacement: d̪
+ segment: '[ɖd] ð'
+- following_context: ''
+ preceding_context: ''
+ replacement: d̪
+ segment: ð
+- following_context: ''
+ preceding_context: ''
+ replacement: t̪
+ segment: θ
+- following_context: '[tʈpkcsʃf][ʲʷ]?'
+ preceding_context: ''
+ replacement: s
+ segment: z
+- following_context: '[tʈpkcsʃf][ʲʷ]?'
+ preceding_context: ''
+ replacement: t
+ segment: d
+- following_context: '[tʈpkcsʃf][ʲʷ]?'
+ preceding_context: ''
+ replacement: p
+ segment: b
+- following_context: '[tpkcsʃf][ʲʷ]?'
+ preceding_context: ''
+ replacement: k
+ segment: ɡ
+- following_context: '[tʈpkcsʃf][ʲʷ]?'
+ preceding_context: ''
+ replacement: tʃ
+ segment: dʒ
+- following_context: '[vʋf]ʲ?'
+ preceding_context: ''
+ replacement: ɱ
+ segment: m
+- following_context: '[vʋf]ʲ?'
+ preceding_context: ''
+ replacement: ɱ
+ segment: 'n'
+- following_context: '[tʈdɖ]$'
+ preceding_context: m
+ replacement: ''
+ segment: '[pb]'
+- following_context: '[sz]$'
+ preceding_context: 'n'
+ replacement: ''
+ segment: '[tʈdɖ]'
+- following_context: '[s]$'
+ preceding_context: ''
+ replacement: k
+ segment: '[sʃ] k'
+- following_context: '[s]$'
+ preceding_context: ''
+ replacement: ''
+ segment: '[sʃ] t'
+- following_context: $ # ask metathesis
+ preceding_context: ''
+ replacement: k s
+ segment: s k
+- following_context: '[dɖʈtcɟɡk][ʲʷ]?'
+ preceding_context: ''
+ replacement: ''
+ segment: p
+- following_context: '[pbcɟɡk][ʲʷ]?'
+ preceding_context: ''
+ replacement: ''
+ segment: '[tʈ]'
+- following_context: '[pbcɟɡk][ʲʷ]?'
+ preceding_context: ''
+ replacement: ''
+ segment: d
+- following_context: '[dɖtʈcɟɡk][ʲʷ]?'
+ preceding_context: ''
+ replacement: ''
+ segment: b
+- following_context: '[dɖtʈpb][ʲʷ]?'
+ preceding_context: ''
+ replacement: ''
+ segment: k
+- following_context: '[dɖtʈpb][ʲʷ]?'
+ preceding_context: ''
+ replacement: ''
+ segment: ɡ
+- following_context: ([tʈpkc][ʲʷ]?)? ɹ
+ preceding_context: ''
+ replacement: ʃ
+ segment: s
+- following_context: 'ɹ'
+ preceding_context: ''
+ replacement: tʃ
+ segment: '[tʈ][ʲʷ]?'
+- following_context: ''
+ preceding_context: ''
+ replacement: tʃ
+ segment: '[tʈ][ʲʷ]? ɹ'
+- following_context: ''
+ preceding_context: ''
+ replacement: dʒ
+ segment: '[dɖ][ʲʷ]? ɹ'
+- following_context: 'ɹ'
+ preceding_context: ''
+ replacement: dʒ
+ segment: '[dɖ][ʲʷ]?'
+- following_context: ə n
+ preceding_context: ''
+ replacement: ʔ
+ segment: '[tʈ]'
+- following_context: $ # ing/in' variation
+ preceding_context: ɪ
+ replacement: 'n'
+ segment: ŋ
+- following_context: z$ # ing/in' variation
+ preceding_context: ɪ
+ replacement: 'n'
+ segment: ŋ
+- following_context: '' # schwa deletion
+ preceding_context: ''
+ replacement: l ə
+ segment: 'ə l ə'
+- following_context: '' # schwa deletion
+ preceding_context: ''
+ replacement: ʎ ə
+ segment: 'ə ʎ ə'
+- following_context: '' # schwa deletion
+ preceding_context: ''
+ replacement: n ə
+ segment: 'ə n ə'
+- following_context: '' # schwa deletion
+ preceding_context: ''
+ replacement: m ə
+ segment: 'ə m ə'
+- following_context: '' # schwa deletion
+ preceding_context: ''
+ replacement: ɹ ə
+ segment: 'ə ɹ ə'
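
Each rule rewrites a segment of a space-separated pronunciation when its preceding and following contexts match. A toy illustration of the US flapping rule above, outside of MFA's actual rule engine:

    import re

    # segment '[td]' becomes 'ɾ' between a vowel and (ɪ|ə|ɚ|i)
    pron = "s ɪ t i"  # hypothetical pronunciation of "city"
    flapped = re.sub(r"(?<=ɪ )[td](?= i)", "ɾ", pron)
    assert flapped == "s ɪ ɾ i"
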
diff --git a/tests/data/configs/acoustic/english_mfa_topology.yaml b/tests/data/configs/acoustic/english_mfa_topology.yaml
new file mode 100644
index 00000000..18af2846
--- /dev/null
+++ b/tests/data/configs/acoustic/english_mfa_topology.yaml
@@ -0,0 +1,46 @@
+ɾ:
+ max_states: 1
+ min_states: 1
+ɾʲ:
+ max_states: 1
+ min_states: 1
+ɾ̃:
+ max_states: 1
+ min_states: 1
+ʔ:
+ max_states: 1
+ min_states: 1
+ə:
+ max_states: 3
+ɚ:
+ max_states: 3
+ɪ:
+ max_states: 3
+e:
+ max_states: 3
+eː:
+ max_states: 3
+ɛ:
+ max_states: 3
+ɛː:
+ max_states: 3
+ɐ:
+ max_states: 3
+i:
+ max_states: 3
+iː:
+ max_states: 3
+o:
+ max_states: 3
+oː:
+ max_states: 3
+u:
+ max_states: 3
+uː:
+ max_states: 3
+ɝ:
+ max_states: 3
+j:
+ max_states: 3
+w:
+ max_states: 3
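
Flaps and the glottal stop are pinned to a single state, while vowels and glides only cap max_states; values left unspecified presumably fall back to the trainer's defaults. A small sketch of reading the file (the default of one minimum state is an assumption):

    import yaml
    from pathlib import Path

    topology = yaml.safe_load(Path("english_mfa_topology.yaml").read_text(encoding="utf-8"))
    for phone, spec in topology.items():
        print(phone, spec.get("min_states", 1), spec["max_states"])  # assumed default min
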
diff --git a/tests/test_commandline_train.py b/tests/test_commandline_train.py
index f42c44e1..9e63b749 100644
--- a/tests/test_commandline_train.py
+++ b/tests/test_commandline_train.py
@@ -4,6 +4,7 @@
import pytest
from montreal_forced_aligner.command_line.mfa import mfa_cli
+from montreal_forced_aligner.exceptions import PhoneGroupTopologyMismatchError
@pytest.mark.skip("Inconsistent failing on CI")
@@ -56,11 +57,45 @@ def test_train_and_align_basic_speaker_dict(
temp_dir,
basic_train_config_path,
textgrid_output_model_path,
+ english_mfa_phone_groups_path,
+ english_mfa_rules_path,
+ english_mfa_topology_path,
+ bad_topology_path,
db_setup,
):
if os.path.exists(textgrid_output_model_path):
os.remove(textgrid_output_model_path)
output_directory = generated_dir.joinpath("ipa speaker output")
+ with pytest.raises(PhoneGroupTopologyMismatchError):
+ command = [
+ "train",
+ multilingual_ipa_tg_corpus_dir,
+ mfa_speaker_dict_path,
+ textgrid_output_model_path,
+ "--config_path",
+ basic_train_config_path,
+ "-q",
+ "--clean",
+ "--no_debug",
+ "--output_directory",
+ output_directory,
+ "--single_speaker",
+ "--phone_groups_path",
+ english_mfa_phone_groups_path,
+ "--rules_path",
+ english_mfa_rules_path,
+ "--topology_path",
+ bad_topology_path,
+ ]
+ command = [str(x) for x in command]
+ result = click.testing.CliRunner(mix_stderr=False).invoke(
+ mfa_cli, command, catch_exceptions=True
+ )
+ print(result.stdout)
+ print(result.stderr)
+ if result.exception:
+ print(result.exc_info)
+ raise result.exception
command = [
"train",
multilingual_ipa_tg_corpus_dir,
@@ -74,6 +109,12 @@ def test_train_and_align_basic_speaker_dict(
"--output_directory",
output_directory,
"--single_speaker",
+ "--phone_groups_path",
+ english_mfa_phone_groups_path,
+ "--rules_path",
+ english_mfa_rules_path,
+ "--topology_path",
+ english_mfa_topology_path,
]
command = [str(x) for x in command]
result = click.testing.CliRunner(mix_stderr=False).invoke(