Skip to content

Commit

Permalink
refactor: use new name
Browse files Browse the repository at this point in the history
  • Loading branch information
kod-kristoff committed Feb 23, 2024
1 parent 37ff5cb commit 3ab4d79
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 54 deletions.
2 changes: 1 addition & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

125 changes: 72 additions & 53 deletions src/word_prediction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from dataclasses import dataclass
from typing import Optional, Tuple
from sparv.api import ( # type: ignore [import-untyped]
annotator,
Output,
Expand All @@ -10,26 +12,16 @@
from transformers import ( # type: ignore [import-untyped]
BertTokenizer,
BertForMaskedLM,
FillMaskPipeline,
)
from word_prediction.predictor import HuggingFaceTopKPredictor

__description__ = "Calculating word neighbours by mask a word in a BERT model."
__description__ = "Calculating word predictions by mask a word in a BERT model."


__config__ = [
Config(
"word_prediction.model",
description="Huggingface pretrained model name",
default="KBLab/bert-base-swedish-cased",
),
Config(
"word_prediction.tokenizer",
description="HuggingFace pretrained tokenizer name",
default="KBLab/bert-base-swedish-cased",
),
Config(
"word_prediction.num_neighbours",
description="The number of neighbours to list",
"word_prediction.num_predictions",
description="The number of predictions to list",
default=5,
),
]
Expand All @@ -41,36 +33,83 @@
TOK_SEP = " "


@dataclass
class HuggingfaceModel:
model_name: str
model_revision: str
tokenizer_name: Optional[str] = None
tokenizer_revision: Optional[str] = None

def tokenizer_name_and_revision(self) -> Tuple[str, str]:
if tokenizer_name := self.tokenizer_name:
return tokenizer_name, self.tokenizer_revision or "main"
else:
return self.model_name, self.model_revision


MODELS = {
"kb-bert": HuggingfaceModel(
model_name="KBLab/bert-base-swedish-cased",
model_revision="c710fb8dff81abb11d704cd46a8a1e010b2b022c",
)
}


@annotator(
"Word neighbour tagging with a masked Bert model",
"Word prediction tagging with a masked Bert model",
)
def annotate_masked_bert(
out_neighbour: Output = Output(
"<token>:word_prediction.transformer-neighbour",
cls="transformer_neighbour",
description="Transformer neighbours from masked BERT (format: '|<word>:<score>|...|)",
def predict_words__kb_bert(
out_prediction: Output = Output(
"<token>:word_prediction.word-prediction--kb-bert",
cls="word_prediction",
description="Word predictions from masked BERT (format: '|<word>:<score>|...|)",
),
word: Annotation = Annotation("<token:word>"),
sentence: Annotation = Annotation("<sentence>"),
model_name: str = Config("word_prediction.model"),
tokenizer_name: str = Config("word_prediction.tokenizer"),
num_neighbours_str: str = Config("word_prediction.num_neighbours"),
num_predictions_str: str = Config("word_prediction.num_predictions"),
) -> None:
logger.info("annotate_masked_bert")
logger.info("predict_words")
try:
num_neighbours = int(num_neighbours_str)
num_predictions = int(num_predictions_str)
except ValueError as exc:
raise SparvErrorMessage(
f"'word_prediction.num_neighbours' must contain an 'int' got: '{num_neighbours_str}'"
f"'word_prediction.num_predictions' must contain an 'int' got: '{num_predictions_str}'"
) from exc
tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
model = BertForMaskedLM.from_pretrained(model_name)
tokenizer_name, tokenizer_revision = MODELS["kb-bert"].tokenizer_name_and_revision()

tokenizer = BertTokenizer.from_pretrained(
tokenizer_name, revision=tokenizer_revision
)
model = BertForMaskedLM.from_pretrained(
MODELS["kb-bert"].model_name, revision=MODELS["kb-bert"].model_revision
)

hf_top_k_predictor = HuggingFaceTopKPredictor(model=model, tokenizer=tokenizer)
predictor = HuggingFaceTopKPredictor(model=model, tokenizer=tokenizer)

sentences, _orphans = sentence.get_children(word)
token_word = list(word.read())
out_neighbour_annotation = word.create_empty_attribute()
out_prediction_annotation = word.create_empty_attribute()

run_word_prediction(
predictor=predictor,
num_predictions=num_predictions,
sentences=sentences,
token_word=token_word,
out_prediction_annotations=out_prediction_annotation,
)

logger.info("writing annotations")
out_prediction.write(out_prediction_annotation)


def run_word_prediction(
predictor: HuggingFaceTopKPredictor,
num_predictions: int,
sentences,
token_word: list,
out_prediction_annotations,
) -> None:
logger.info("run_word_prediction")

logger.progress(total=len(sentences)) # type: ignore
for sent in sentences:
Expand All @@ -86,27 +125,7 @@ def annotate_masked_bert(
for token_index in sent
)

neighbours_scores = hf_top_k_predictor.get_top_k_predictions(
sent_to_tag, k=num_neighbours
)
out_neighbour_annotation[token_index_to_mask] = neighbours_scores

logger.info("writing annotations")
out_neighbour.write(out_neighbour_annotation)


class HuggingFaceTopKPredictor:
def __init__(self, *, tokenizer, model) -> None:
self.tokenizer = tokenizer
self.model = model
self.pipeline = FillMaskPipeline(model=model, tokenizer=tokenizer)

def get_top_k_predictions(self, text: str, k=5) -> str:
if predictions := self.pipeline(text, top_k=k):
predictions_str = "|".join(
f"{pred['token_str']}:{pred['score']}" # type: ignore
for pred in predictions
predictions_scores = predictor.get_top_k_predictions(
sent_to_tag, k=num_predictions
)
return f"|{predictions_str}|"
else:
return "|"
out_prediction_annotations[token_index_to_mask] = predictions_scores
20 changes: 20 additions & 0 deletions src/word_prediction/predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from transformers import ( # type: ignore [import-untyped]
FillMaskPipeline,
)


class HuggingFaceTopKPredictor:
def __init__(self, *, tokenizer, model) -> None:
self.tokenizer = tokenizer
self.model = model
self.pipeline = FillMaskPipeline(model=model, tokenizer=tokenizer)

def get_top_k_predictions(self, text: str, k=5) -> str:
if predictions := self.pipeline(text, top_k=k):
predictions_str = "|".join(
f"{pred['token_str']}:{pred['score']}" # type: ignore
for pred in predictions
)
return f"|{predictions_str}|"
else:
return "|"

0 comments on commit 3ab4d79

Please sign in to comment.