Skip to content

Commit

Permalink
feat: enable rounding by num decimals
Browse files Browse the repository at this point in the history
  • Loading branch information
kod-kristoff committed Mar 19, 2024
1 parent 625ae6b commit a2e4427
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 5 deletions.
18 changes: 17 additions & 1 deletion word-prediction-kb-bert/src/word_prediction_kb_bert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
description="The number of predictions to list",
default=5,
),
Config(
    # The key prefix is the plugin name "word_prediction_kb_bert"; the entry
    # function looks the setting up under that prefix.
    "word_prediction_kb_bert.num_decimals",
    description="The number of decimals to round the score to",
    default=3,
),
]

__version__ = "0.4.0"
Expand Down Expand Up @@ -65,6 +70,7 @@ def predict_words__kb_bert(
word: Annotation = Annotation("<token:word>"),
sentence: Annotation = Annotation("<sentence>"),
num_predictions_str: str = Config("word_prediction_kb_bert.num_predictions"),
# Number of decimals to round scores to; arrives as a string and is parsed below.
num_decimals_str: str = Config("word_prediction_kb_bert.num_decimals"),
) -> None:
logger.info("predict_words")
try:
Expand All @@ -73,6 +79,12 @@ def predict_words__kb_bert(
raise SparvErrorMessage(
f"'word_prediction_kb_bert.num_predictions' must contain an 'int' got: '{num_predictions_str}'"
) from exc
try:
num_decimals = int(num_decimals_str)
except ValueError as exc:
raise SparvErrorMessage(
f"'word_prediction_kb_bert.num_decimals' must contain an 'int' got: '{num_decimals_str}'"
) from exc
tokenizer_name, tokenizer_revision = MODELS["kb-bert"].tokenizer_name_and_revision()

tokenizer = BertTokenizer.from_pretrained(
Expand All @@ -82,7 +94,11 @@ def predict_words__kb_bert(
MODELS["kb-bert"].model_name, revision=MODELS["kb-bert"].model_revision
)

predictor = TopKPredictor(model=model, tokenizer=tokenizer)
predictor = TopKPredictor(
model=model,
tokenizer=tokenizer,
num_decimals=num_decimals,
)

sentences, _orphans = sentence.get_children(word)
token_word = list(word.read())
Expand Down
35 changes: 31 additions & 4 deletions word-prediction-kb-bert/src/word_prediction_kb_bert/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,28 @@
FillMaskPipeline,
)

def _zero_score_predicate(num_decimals: int):
    """Build a predicate recognising a score that was formatted as exactly zero.

    The predicate compares against the full zero string (e.g. ``"0.000"``)
    rather than just the trailing zeros, so a score that rounds to ``"1.000"``
    is not mistaken for zero.
    """
    zero = f"{0.0:.{num_decimals}f}"
    # Tiny negative values format as "-0.000"; treat them as zero as well.
    zero_forms = (zero, f"-{zero}")

    def is_zero(formatted_score: str) -> bool:
        return formatted_score in zero_forms

    return is_zero


# Maps each supported number of decimals (1 through 10) to a pair of
# (format string for the score, predicate telling whether a formatted score
# is zero at that precision).
SCORE_FORMATS = {
    num_decimals: (
        f"{{:.{num_decimals}f}}",
        _zero_score_predicate(num_decimals),
    )
    for num_decimals in range(1, 11)
}


class TopKPredictor:
def __init__(self, *, tokenizer, model, num_decimals: int = 3) -> None:
    """Set up a predictor backed by a fill-mask pipeline.

    Args:
        tokenizer: Tokenizer handed to ``FillMaskPipeline``; also stored on
            the instance.
        model: Model handed to ``FillMaskPipeline``; also stored on the
            instance.
        num_decimals: Number of decimals that prediction scores are rounded
            to when they are formatted.
    """
    self.tokenizer = tokenizer
    self.model = model
    self.num_decimals = num_decimals
    self.pipeline = FillMaskPipeline(model=model, tokenizer=tokenizer)

def get_top_k_predictions(self, text: str, k=5) -> str:
def get_top_k_predictions(self, text: str, k: int = 5) -> str:
tokenized_inputs = self.tokenizer(text)
if len(tokenized_inputs["input_ids"]) <= 512:
return self._run_pipeline(text, k)
Expand All @@ -34,10 +48,23 @@ def compute_context(self, text):

def _run_pipeline(self, text, k) -> str:
if predictions := self.pipeline(text, top_k=k):
predictions_str = "|".join(
f"{pred['token_str']}:{pred['score']}" # type: ignore
collect_token_and_score = (
(pred["token_str"], pred["score"]) # type: ignore
for pred in predictions
)
score_format, score_pred = SCORE_FORMATS[self.num_decimals]
format_scores = (
(token, score_format.format(score))
for token, score in collect_token_and_score
)
filter_out_zero_scores = (
(token, score)
for token, score in format_scores
if not score_pred(score)
)
predictions_str = "|".join(
f"{token}:{score}" for token, score in filter_out_zero_scores
)
return f"|{predictions_str}|"
else:
return "|"
13 changes: 13 additions & 0 deletions word-prediction-kb-bert/tests/test_predictor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from itertools import islice
from typing import Tuple
import pytest
from word_prediction_kb_bert.predictor import TopKPredictor
Expand Down Expand Up @@ -30,6 +31,18 @@ def test_short_text(kb_bert_predictor: TopKPredictor) -> None:
assert actual == expected


def test_rounding(kb_bert_predictor: TopKPredictor) -> None:
    """Scores are rounded to ``num_decimals``; zero-rounded predictions are dropped."""
    text = "namnet på det hus där historien börjar och slutar [MASK] men annars pratas det mest om huset . "
    # The predictor comes from a fixture that other tests also receive, so put
    # its setting back even if the prediction call raises.
    previous_num_decimals = kb_bert_predictor.num_decimals
    kb_bert_predictor.num_decimals = 2
    try:
        actual = kb_bert_predictor.get_top_k_predictions(text)
    finally:
        kb_bert_predictor.num_decimals = previous_num_decimals
    print(f"{actual=}")
    expected = [",:0.92", "-:0.06"]

    # The result is "|a|b|"; skip the empty strings outside the outer bars.
    num_bars = actual.count("|")
    assert list(islice(actual.split("|"), 1, num_bars)) == expected


def remove_scores(actual):
    """Strip the ``:score`` suffix from each entry of a ``|``-delimited string.

    Only the last ``:`` in an entry separates token from score, so a predicted
    token that itself contains ``:`` is kept intact. Entries without a ``:``
    (including the empty strings outside the outer bars) pass through as-is.
    """
    return "|".join(entry.rsplit(":", 1)[0] for entry in actual.split("|"))

Expand Down

0 comments on commit a2e4427

Please sign in to comment.