diff --git a/word-prediction-kb-bert/src/word_prediction_kb_bert/__init__.py b/word-prediction-kb-bert/src/word_prediction_kb_bert/__init__.py index 5c68f93..0689128 100644 --- a/word-prediction-kb-bert/src/word_prediction_kb_bert/__init__.py +++ b/word-prediction-kb-bert/src/word_prediction_kb_bert/__init__.py @@ -24,6 +24,11 @@ description="The number of predictions to list", default=5, ), + Config( + "word_prediction_kb_bert.num_decimals", + description="The number of decimals to round the score to", + default=3, + ), ] __version__ = "0.4.0" @@ -65,6 +70,7 @@ def predict_words__kb_bert( word: Annotation = Annotation(""), sentence: Annotation = Annotation(""), num_predictions_str: str = Config("word_prediction_kb_bert.num_predictions"), + num_decimals_str: str = Config("word_prediction_kb_bert.num_decimals"), ) -> None: logger.info("predict_words") try: @@ -73,6 +79,12 @@ def predict_words__kb_bert( raise SparvErrorMessage( f"'word_prediction_kb_bert.num_predictions' must contain an 'int' got: '{num_predictions_str}'" ) from exc + try: + num_decimals = int(num_decimals_str) + except ValueError as exc: + raise SparvErrorMessage( + f"'word_prediction_kb_bert.num_decimals' must contain an 'int' got: '{num_decimals_str}'" + ) from exc tokenizer_name, tokenizer_revision = MODELS["kb-bert"].tokenizer_name_and_revision() tokenizer = BertTokenizer.from_pretrained( @@ -82,7 +94,11 @@ def predict_words__kb_bert( MODELS["kb-bert"].model_name, revision=MODELS["kb-bert"].model_revision ) - predictor = TopKPredictor(model=model, tokenizer=tokenizer) + predictor = TopKPredictor( + model=model, + tokenizer=tokenizer, + num_decimals=num_decimals, + ) sentences, _orphans = sentence.get_children(word) token_word = list(word.read()) diff --git a/word-prediction-kb-bert/src/word_prediction_kb_bert/predictor.py b/word-prediction-kb-bert/src/word_prediction_kb_bert/predictor.py index 6d9e348..8e2e999 100644 --- a/word-prediction-kb-bert/src/word_prediction_kb_bert/predictor.py +++ 
b/word-prediction-kb-bert/src/word_prediction_kb_bert/predictor.py @@ -2,14 +2,28 @@ FillMaskPipeline, ) +SCORE_FORMATS = { + 1: ("{:.1f}", lambda s: s.endswith(".0")), + 2: ("{:.2f}", lambda s: s.endswith(".00")), + 3: ("{:.3f}", lambda s: s.endswith(".000")), + 4: ("{:.4f}", lambda s: s.endswith(".0000")), + 5: ("{:.5f}", lambda s: s.endswith(".00000")), + 6: ("{:.6f}", lambda s: s.endswith(".000000")), + 7: ("{:.7f}", lambda s: s.endswith(".0000000")), + 8: ("{:.8f}", lambda s: s.endswith(".00000000")), + 9: ("{:.9f}", lambda s: s.endswith(".000000000")), + 10: ("{:.10f}", lambda s: s.endswith(".0000000000")), +} + class TopKPredictor: - def __init__(self, *, tokenizer, model) -> None: + def __init__(self, *, tokenizer, model, num_decimals: int = 3) -> None: self.tokenizer = tokenizer self.model = model + self.num_decimals = num_decimals self.pipeline = FillMaskPipeline(model=model, tokenizer=tokenizer) - def get_top_k_predictions(self, text: str, k=5) -> str: + def get_top_k_predictions(self, text: str, k: int = 5) -> str: tokenized_inputs = self.tokenizer(text) if len(tokenized_inputs["input_ids"]) <= 512: return self._run_pipeline(text, k) @@ -34,10 +48,23 @@ def compute_context(self, text): def _run_pipeline(self, text, k) -> str: if predictions := self.pipeline(text, top_k=k): - predictions_str = "|".join( - f"{pred['token_str']}:{pred['score']}" # type: ignore + collect_token_and_score = ( + (pred["token_str"], pred["score"]) # type: ignore for pred in predictions ) + score_format, score_pred = SCORE_FORMATS[self.num_decimals] + format_scores = ( + (token, score_format.format(score)) + for token, score in collect_token_and_score + ) + filter_out_zero_scores = ( + (token, score) + for token, score in format_scores + if not score_pred(score) + ) + predictions_str = "|".join( + f"{token}:{score}" for token, score in filter_out_zero_scores + ) return f"|{predictions_str}|" else: return "|" diff --git a/word-prediction-kb-bert/tests/test_predictor.py 
b/word-prediction-kb-bert/tests/test_predictor.py index b28b42d..8be2e56 100644 --- a/word-prediction-kb-bert/tests/test_predictor.py +++ b/word-prediction-kb-bert/tests/test_predictor.py @@ -1,3 +1,4 @@ +from itertools import islice from typing import Tuple import pytest from word_prediction_kb_bert.predictor import TopKPredictor @@ -30,6 +31,18 @@ def test_short_text(kb_bert_predictor: TopKPredictor) -> None: assert actual == expected +def test_rounding(kb_bert_predictor: TopKPredictor) -> None: + text = "namnet på det hus där historien börjar och slutar [MASK] men annars pratas det mest om huset . " + kb_bert_predictor.num_decimals = 2 + actual = kb_bert_predictor.get_top_k_predictions(text) + kb_bert_predictor.num_decimals = 3 + print(f"{actual=}") + expected = [",:0.92", "-:0.06"] + + num_bars = actual.count("|") + assert list(islice(actual.split("|"), 1, num_bars)) == expected + + def remove_scores(actual): return "|".join(map(lambda x: x.split(":")[0], actual.split("|")))