diff --git a/pdm.lock b/pdm.lock
index 8278b7b..467e81b 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
 groups = ["default", "dev"]
 strategy = ["cross_platform", "inherit_metadata"]
 lock_version = "4.4.1"
-content_hash = "sha256:ba436b1fde50163f18ee08183fcedc49fa36b2c6dc2d4ac1a005056c052128dd"
+content_hash = "sha256:c9d1751b45e83e3705d56d669dcaa4a911db6fdf591aa53f15f7bb5aec672e5d"
 
 [[package]]
 name = "appdirs"
diff --git a/src/word_prediction/__init__.py b/src/word_prediction/__init__.py
index 9a1b7b2..3af85f9 100644
--- a/src/word_prediction/__init__.py
+++ b/src/word_prediction/__init__.py
@@ -1,3 +1,5 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple
 from sparv.api import (  # type: ignore [import-untyped]
     annotator,
     Output,
@@ -10,26 +12,16 @@
 from transformers import (  # type: ignore [import-untyped]
     BertTokenizer,
     BertForMaskedLM,
-    FillMaskPipeline,
 )
+from word_prediction.predictor import HuggingFaceTopKPredictor
 
-__description__ = "Calculating word neighbours by mask a word in a BERT model."
+__description__ = "Calculating word predictions by masking a word in a BERT model."
 
 __config__ = [
     Config(
-        "word_prediction.model",
-        description="Huggingface pretrained model name",
-        default="KBLab/bert-base-swedish-cased",
-    ),
-    Config(
-        "word_prediction.tokenizer",
-        description="HuggingFace pretrained tokenizer name",
-        default="KBLab/bert-base-swedish-cased",
-    ),
-    Config(
-        "word_prediction.num_neighbours",
-        description="The number of neighbours to list",
+        "word_prediction.num_predictions",
+        description="The number of predictions to list",
         default=5,
     ),
 ]
@@ -41,36 +33,83 @@
 TOK_SEP = " "
 
 
+@dataclass
+class HuggingfaceModel:
+    model_name: str
+    model_revision: str
+    tokenizer_name: Optional[str] = None
+    tokenizer_revision: Optional[str] = None
+
+    def tokenizer_name_and_revision(self) -> Tuple[str, str]:
+        if tokenizer_name := self.tokenizer_name:
+            return tokenizer_name, self.tokenizer_revision or "main"
+        else:
+            return self.model_name, self.model_revision
+
+
+MODELS = {
+    "kb-bert": HuggingfaceModel(
+        model_name="KBLab/bert-base-swedish-cased",
+        model_revision="c710fb8dff81abb11d704cd46a8a1e010b2b022c",
+    )
+}
+
+
 @annotator(
-    "Word neighbour tagging with a masked Bert model",
+    "Word prediction tagging with a masked Bert model",
 )
-def annotate_masked_bert(
-    out_neighbour: Output = Output(
-        "<token>:word_prediction.transformer-neighbour",
-        cls="transformer_neighbour",
-        description="Transformer neighbours from masked BERT (format: '|<word>:<score>|...|')",
+def predict_words__kb_bert(
+    out_prediction: Output = Output(
+        "<token>:word_prediction.word-prediction--kb-bert",
+        cls="word_prediction",
+        description="Word predictions from masked BERT (format: '|<word>:<score>|...|')",
     ),
     word: Annotation = Annotation("<token>"),
     sentence: Annotation = Annotation("<sentence>"),
-    model_name: str = Config("word_prediction.model"),
-    tokenizer_name: str = Config("word_prediction.tokenizer"),
-    num_neighbours_str: str = Config("word_prediction.num_neighbours"),
+    num_predictions_str: str = Config("word_prediction.num_predictions"),
 ) -> None:
-    logger.info("annotate_masked_bert")
+    logger.info("predict_words")
     try:
-        num_neighbours = int(num_neighbours_str)
+        num_predictions = int(num_predictions_str)
     except ValueError as exc:
         raise SparvErrorMessage(
-            f"'word_prediction.num_neighbours' must contain an 'int' got: '{num_neighbours_str}'"
+            f"'word_prediction.num_predictions' must contain an 'int', got: '{num_predictions_str}'"
         ) from exc
 
-    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
-    model = BertForMaskedLM.from_pretrained(model_name)
+    tokenizer_name, tokenizer_revision = MODELS["kb-bert"].tokenizer_name_and_revision()
+
+    tokenizer = BertTokenizer.from_pretrained(
+        tokenizer_name, revision=tokenizer_revision
+    )
+    model = BertForMaskedLM.from_pretrained(
+        MODELS["kb-bert"].model_name, revision=MODELS["kb-bert"].model_revision
+    )
 
-    hf_top_k_predictor = HuggingFaceTopKPredictor(model=model, tokenizer=tokenizer)
+    predictor = HuggingFaceTopKPredictor(model=model, tokenizer=tokenizer)
     sentences, _orphans = sentence.get_children(word)
     token_word = list(word.read())
-    out_neighbour_annotation = word.create_empty_attribute()
+    out_prediction_annotation = word.create_empty_attribute()
+
+    run_word_prediction(
+        predictor=predictor,
+        num_predictions=num_predictions,
+        sentences=sentences,
+        token_word=token_word,
+        out_prediction_annotations=out_prediction_annotation,
+    )
+
+    logger.info("writing annotations")
+    out_prediction.write(out_prediction_annotation)
+
+
+def run_word_prediction(
+    predictor: HuggingFaceTopKPredictor,
+    num_predictions: int,
+    sentences,
+    token_word: list,
+    out_prediction_annotations,
+) -> None:
+    logger.info("run_word_prediction")
 
     logger.progress(total=len(sentences))  # type: ignore
     for sent in sentences:
@@ -86,27 +125,7 @@ def annotate_masked_bert(
             for token_index in sent
         )
 
-        neighbours_scores = hf_top_k_predictor.get_top_k_predictions(
-            sent_to_tag, k=num_neighbours
-        )
-        out_neighbour_annotation[token_index_to_mask] = neighbours_scores
-
-    logger.info("writing annotations")
-    out_neighbour.write(out_neighbour_annotation)
-
-
-class HuggingFaceTopKPredictor:
-    def __init__(self, *, tokenizer, model) -> None:
-        self.tokenizer = tokenizer
-        self.model = model
-        self.pipeline = FillMaskPipeline(model=model, tokenizer=tokenizer)
-
-    def get_top_k_predictions(self, text: str, k=5) -> str:
-        if predictions := self.pipeline(text, top_k=k):
-            predictions_str = "|".join(
-                f"{pred['token_str']}:{pred['score']}"  # type: ignore
-                for pred in predictions
+        predictions_scores = predictor.get_top_k_predictions(
+            sent_to_tag, k=num_predictions
         )
-            return f"|{predictions_str}|"
-        else:
-            return "|"
+        out_prediction_annotations[token_index_to_mask] = predictions_scores
diff --git a/src/word_prediction/predictor.py b/src/word_prediction/predictor.py
new file mode 100644
index 0000000..b046d41
--- /dev/null
+++ b/src/word_prediction/predictor.py
@@ -0,0 +1,20 @@
+from transformers import (  # type: ignore [import-untyped]
+    FillMaskPipeline,
+)
+
+
+class HuggingFaceTopKPredictor:
+    def __init__(self, *, tokenizer, model) -> None:
+        self.tokenizer = tokenizer
+        self.model = model
+        self.pipeline = FillMaskPipeline(model=model, tokenizer=tokenizer)
+
+    def get_top_k_predictions(self, text: str, k=5) -> str:
+        if predictions := self.pipeline(text, top_k=k):
+            predictions_str = "|".join(
+                f"{pred['token_str']}:{pred['score']}"  # type: ignore
+                for pred in predictions
+            )
+            return f"|{predictions_str}|"
+        else:
+            return "|"
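
For reference, a minimal usage sketch of the predictor extracted into
src/word_prediction/predictor.py (not part of the patch; the example sentence
and the availability of the KBLab weights are assumptions):

    from transformers import BertForMaskedLM, BertTokenizer

    from word_prediction.predictor import HuggingFaceTopKPredictor

    # Load KB-BERT; the plugin pins a revision via MODELS, omitted here.
    tokenizer = BertTokenizer.from_pretrained("KBLab/bert-base-swedish-cased")
    model = BertForMaskedLM.from_pretrained("KBLab/bert-base-swedish-cased")

    predictor = HuggingFaceTopKPredictor(model=model, tokenizer=tokenizer)

    # The input must contain the tokenizer's mask token ("[MASK]" for BERT).
    # The result is a Sparv-style attribute string "|word:score|...|",
    # or "|" if the pipeline returns no predictions.
    print(predictor.get_top_k_predictions("Han bor i [MASK].", k=5))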