From 2d7becafa649e869e56cff7935b5ba45019878b7 Mon Sep 17 00:00:00 2001 From: Kristoffer Andersson Date: Tue, 11 Jun 2024 08:45:26 +0200 Subject: [PATCH 1/5] test: remove unneeded parenthesis --- word-prediction-kb-bert/tests/test_predictor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/word-prediction-kb-bert/tests/test_predictor.py b/word-prediction-kb-bert/tests/test_predictor.py index 568b057..b6624f4 100644 --- a/word-prediction-kb-bert/tests/test_predictor.py +++ b/word-prediction-kb-bert/tests/test_predictor.py @@ -44,7 +44,7 @@ def test_rounding(kb_bert_predictor: TopKPredictor) -> None: def remove_scores(actual): - return "|".join((x.split(":")[0] for x in actual.split("|"))) + return "|".join(x.split(":")[0] for x in actual.split("|")) def test_long_text(kb_bert_predictor: TopKPredictor) -> None: From 690bf94c387326d8a4788cd7016617e3ee63db96 Mon Sep 17 00:00:00 2001 From: Kristoffer Andersson Date: Tue, 11 Jun 2024 11:14:08 +0200 Subject: [PATCH 2/5] chore: add py.typed marker --- word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/py.typed diff --git a/word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/py.typed b/word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/py.typed new file mode 100644 index 0000000..e69de29 From 578031372b83ec77dc5928fbf53a4d64713ac7d4 Mon Sep 17 00:00:00 2001 From: Kristoffer Andersson Date: Tue, 11 Jun 2024 11:15:44 +0200 Subject: [PATCH 3/5] refactor: load model and tokenizer as default --- .../sbx_word_prediction_kb_bert/__init__.py | 39 ------------- .../sbx_word_prediction_kb_bert/predictor.py | 57 +++++++++++++++++-- word-prediction-kb-bert/tests/conftest.py | 14 +---- 3 files changed, 54 insertions(+), 56 deletions(-) diff --git a/word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/__init__.py b/word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/__init__.py index 3f00039..6dfccfb 100644 --- a/word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/__init__.py +++ b/word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/__init__.py @@ -1,6 +1,3 @@ -from dataclasses import dataclass -from typing import Optional, Tuple - from sparv.api import ( # type: ignore [import-untyped] Annotation, Config, @@ -9,10 +6,6 @@ annotator, get_logger, ) -from transformers import ( # type: ignore [import-untyped] - BertForMaskedLM, - BertTokenizer, -) from sbx_word_prediction_kb_bert.predictor import TopKPredictor @@ -39,28 +32,6 @@ TOK_SEP = " " -@dataclass -class HuggingfaceModel: - model_name: str - model_revision: str - tokenizer_name: Optional[str] = None - tokenizer_revision: Optional[str] = None - - def tokenizer_name_and_revision(self) -> Tuple[str, str]: - if tokenizer_name := self.tokenizer_name: - return tokenizer_name, self.tokenizer_revision or "main" - else: - return self.model_name, self.model_revision - - -MODELS = { - "kb-bert": HuggingfaceModel( - model_name="KBLab/bert-base-swedish-cased", - model_revision="c710fb8dff81abb11d704cd46a8a1e010b2b022c", - ) -} - - @annotator("Word prediction tagging with a masked Bert model", language=["swe"]) def predict_words__kb_bert( out_prediction: Output = Output( @@ -86,18 +57,8 @@ def predict_words__kb_bert( raise SparvErrorMessage( f"'sbx_word_prediction_kb_bert.num_decimals' must contain an 'int' got: '{num_decimals_str}'" # noqa: E501 ) from exc - tokenizer_name, tokenizer_revision = MODELS["kb-bert"].tokenizer_name_and_revision() - - tokenizer = BertTokenizer.from_pretrained( - tokenizer_name, revision=tokenizer_revision - ) - model = BertForMaskedLM.from_pretrained( - MODELS["kb-bert"].model_name, revision=MODELS["kb-bert"].model_revision - ) predictor = TopKPredictor( - model=model, - tokenizer=tokenizer, num_decimals=num_decimals, ) diff --git a/word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/predictor.py b/word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/predictor.py index 845bceb..455bf50 100644 --- a/word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/predictor.py +++ b/word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/predictor.py @@ -1,4 +1,9 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + from transformers import ( # type: ignore [import-untyped] + BertForMaskedLM, + BertTokenizer, FillMaskPipeline, ) @@ -17,11 +22,33 @@ class TopKPredictor: - def __init__(self, *, tokenizer, model, num_decimals: int = 3) -> None: - self.tokenizer = tokenizer - self.model = model + def __init__( + self, + *, + tokenizer: Optional[BertTokenizer] = None, + model: Optional[BertForMaskedLM] = None, + num_decimals: int = 3, + ) -> None: + self.tokenizer = tokenizer or self._default_tokenizer() + self.model = model or self._default_model() self.num_decimals = num_decimals - self.pipeline = FillMaskPipeline(model=model, tokenizer=tokenizer) + self.pipeline = FillMaskPipeline(model=self.model, tokenizer=self.tokenizer) + + @classmethod + def _default_model(cls) -> BertForMaskedLM: + return BertForMaskedLM.from_pretrained( + MODELS["kb-bert"].model_name, revision=MODELS["kb-bert"].model_revision + ) + + @classmethod + def _default_tokenizer(cls) -> BertTokenizer: + tokenizer_name, tokenizer_revision = MODELS[ + "kb-bert" + ].tokenizer_name_and_revision() + + return BertTokenizer.from_pretrained( + tokenizer_name, revision=tokenizer_revision + ) def get_top_k_predictions(self, text: str, k: int = 5) -> str: tokenized_inputs = self.tokenizer(text) @@ -69,3 +96,25 @@ def _run_pipeline(self, text, k) -> str: return f"|{predictions_str}|" if predictions_str else "|" else: return "|" + + +@dataclass +class HuggingfaceModel: + model_name: str + model_revision: str + tokenizer_name: Optional[str] = None + tokenizer_revision: Optional[str] = None + + def tokenizer_name_and_revision(self) -> Tuple[str, str]: + if tokenizer_name := self.tokenizer_name: + return tokenizer_name, self.tokenizer_revision or "main" + else: + return self.model_name, self.model_revision + + +MODELS = { + "kb-bert": HuggingfaceModel( + model_name="KBLab/bert-base-swedish-cased", + model_revision="c710fb8dff81abb11d704cd46a8a1e010b2b022c", + ) +} diff --git a/word-prediction-kb-bert/tests/conftest.py b/word-prediction-kb-bert/tests/conftest.py index c943124..aad8a3a 100644 --- a/word-prediction-kb-bert/tests/conftest.py +++ b/word-prediction-kb-bert/tests/conftest.py @@ -1,21 +1,9 @@ import pytest from sbx_word_prediction_kb_bert import ( - MODELS, TopKPredictor, ) -from transformers import ( # type: ignore [import-untyped] - BertForMaskedLM, - BertTokenizer, -) @pytest.fixture(scope="session") def kb_bert_predictor() -> TopKPredictor: - tokenizer_name, tokenizer_revision = MODELS["kb-bert"].tokenizer_name_and_revision() - tokenizer = BertTokenizer.from_pretrained( - tokenizer_name, revision=tokenizer_revision - ) - model = BertForMaskedLM.from_pretrained( - MODELS["kb-bert"].model_name, revision=MODELS["kb-bert"].model_revision - ) - return TopKPredictor(model=model, tokenizer=tokenizer) + return TopKPredictor() From a93e93a9bdf22feb72138adfcd1f75c94ec759fd Mon Sep 17 00:00:00 2001 From: Kristoffer Andersson Date: Tue, 11 Jun 2024 11:15:59 +0200 Subject: [PATCH 4/5] chore(deps): add torch --- word-prediction-kb-bert/pdm.lock | 48 +++++++++++++------------- word-prediction-kb-bert/pyproject.toml | 6 +++- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/word-prediction-kb-bert/pdm.lock b/word-prediction-kb-bert/pdm.lock index 61ba54f..2904228 100644 --- a/word-prediction-kb-bert/pdm.lock +++ b/word-prediction-kb-bert/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:ae5c879ee52fe6a0ae43754676bbb64a7960e079d7948483ac996fef40aab011" +content_hash = "sha256:0d176a27e2c00139bc4c403890fb86ecd8a9236e5a4ce6102c860b5bbeaa2956" [[package]] name = "appdirs" @@ -1954,7 +1954,7 @@ files = [ [[package]] name = "torch" -version = "2.3.0" +version = "2.3.1" requires_python = ">=3.8.0" summary = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" groups = ["default"] @@ -1976,26 +1976,26 @@ dependencies = [ "nvidia-nccl-cu12==2.20.5; platform_system == \"Linux\" and platform_machine == \"x86_64\"", "nvidia-nvtx-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"", "sympy", - "triton==2.3.0; platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\"", + "triton==2.3.1; platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\"", "typing-extensions>=4.8.0", ] files = [ - {file = "torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac"}, - {file = "torch-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c"}, - {file = "torch-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459"}, - {file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"}, - {file = "torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788"}, - {file = "torch-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace"}, - {file = "torch-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877"}, - {file = "torch-2.3.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73"}, - {file = "torch-2.3.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:20572f426965dd8a04e92a473d7e445fa579e09943cc0354f3e6fef6130ce061"}, - {file = "torch-2.3.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e65ba85ae292909cde0dde6369826d51165a3fc8823dc1854cd9432d7f79b932"}, - {file = "torch-2.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:5515503a193781fd1b3f5c474e89c9dfa2faaa782b2795cc4a7ab7e67de923f6"}, - {file = "torch-2.3.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:6ae9f64b09516baa4ef890af0672dc981c20b1f0d829ce115d4420a247e88fba"}, - {file = "torch-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:cd0dc498b961ab19cb3f8dbf0c6c50e244f2f37dbfa05754ab44ea057c944ef9"}, - {file = "torch-2.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e05f836559251e4096f3786ee99f4a8cbe67bc7fbedba8ad5e799681e47c5e80"}, - {file = "torch-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:4fb27b35dbb32303c2927da86e27b54a92209ddfb7234afb1949ea2b3effffea"}, - {file = "torch-2.3.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:760f8bedff506ce9e6e103498f9b1e9e15809e008368594c3a66bf74a8a51380"}, + {file = "torch-2.3.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:605a25b23944be5ab7c3467e843580e1d888b8066e5aaf17ff7bf9cc30001cc3"}, + {file = "torch-2.3.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f2357eb0965583a0954d6f9ad005bba0091f956aef879822274b1bcdb11bd308"}, + {file = "torch-2.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:32b05fe0d1ada7f69c9f86c14ff69b0ef1957a5a54199bacba63d22d8fab720b"}, + {file = "torch-2.3.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:7c09a94362778428484bcf995f6004b04952106aee0ef45ff0b4bab484f5498d"}, + {file = "torch-2.3.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:b2ec81b61bb094ea4a9dee1cd3f7b76a44555375719ad29f05c0ca8ef596ad39"}, + {file = "torch-2.3.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:490cc3d917d1fe0bd027057dfe9941dc1d6d8e3cae76140f5dd9a7e5bc7130ab"}, + {file = "torch-2.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5802530783bd465fe66c2df99123c9a54be06da118fbd785a25ab0a88123758a"}, + {file = "torch-2.3.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a7dd4ed388ad1f3d502bf09453d5fe596c7b121de7e0cfaca1e2017782e9bbac"}, + {file = "torch-2.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:07e9ba746832b8d069cacb45f312cadd8ad02b81ea527ec9766c0e7404bb3feb"}, + {file = "torch-2.3.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:462d1c07dbf6bb5d9d2f3316fee73a24f3d12cd8dacf681ad46ef6418f7f6626"}, + {file = "torch-2.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:ff60bf7ce3de1d43ad3f6969983f321a31f0a45df3690921720bcad6a8596cc4"}, + {file = "torch-2.3.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:bee0bd33dc58aa8fc8a7527876e9b9a0e812ad08122054a5bff2ce5abf005b10"}, + {file = "torch-2.3.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:aaa872abde9a3d4f91580f6396d54888620f4a0b92e3976a6034759df4b961ad"}, + {file = "torch-2.3.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3d7a7f7ef21a7520510553dc3938b0c57c116a7daee20736a9e25cbc0e832bdc"}, + {file = "torch-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:4777f6cefa0c2b5fa87223c213e7b6f417cf254a45e5829be4ccd1b2a4ee1011"}, + {file = "torch-2.3.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:2bb5af780c55be68fe100feb0528d2edebace1d55cb2e351de735809ba7391eb"}, ] [[package]] @@ -2048,7 +2048,7 @@ files = [ [[package]] name = "triton" -version = "2.3.0" +version = "2.3.1" summary = "A language and compiler for custom Deep Learning operations" groups = ["default"] marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\"" @@ -2056,10 +2056,10 @@ dependencies = [ "filelock", ] files = [ - {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"}, - {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"}, - {file = "triton-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:038e06a09c06a164fef9c48de3af1e13a63dc1ba3c792871e61a8e79720ea440"}, - {file = "triton-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8f636e0341ac348899a47a057c3daea99ea7db31528a225a3ba4ded28ccc65"}, + {file = "triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c84595cbe5e546b1b290d2a58b1494df5a2ef066dd890655e5b8a8a92205c33"}, + {file = "triton-2.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9d64ae33bcb3a7a18081e3a746e8cf87ca8623ca13d2c362413ce7a486f893e"}, + {file = "triton-2.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63381e35ded3304704ea867ffde3b7cfc42c16a55b3062d41e017ef510433d66"}, + {file = "triton-2.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d968264523c7a07911c8fb51b4e0d1b920204dae71491b1fe7b01b62a31e124"}, ] [[package]] diff --git a/word-prediction-kb-bert/pyproject.toml b/word-prediction-kb-bert/pyproject.toml index 017455c..39bec2f 100644 --- a/word-prediction-kb-bert/pyproject.toml +++ b/word-prediction-kb-bert/pyproject.toml @@ -5,7 +5,11 @@ description = "A sparv plugin for computing word neighbours using a BERT model." authors = [ { name = "Kristoffer Andersson", email = "kristoffer.andersson@gu.se" }, ] -dependencies = ["sparv-pipeline >=5.2.0", "transformers>=4.34.1"] +dependencies = [ + "sparv-pipeline >=5.2.0", + "transformers>=4.34.1", + "torch>=2.3.1", +] license = "MIT" readme = "README.md" requires-python = ">= 3.8.1,<3.12" From 40b9ff6a3730249b5c31750f70fa1e832637eacf Mon Sep 17 00:00:00 2001 From: Kristoffer Andersson Date: Tue, 11 Jun 2024 14:49:02 +0200 Subject: [PATCH 5/5] feat: use gpu if available --- .../sbx_word_prediction_kb_bert/predictor.py | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/predictor.py b/word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/predictor.py index 455bf50..729eb04 100644 --- a/word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/predictor.py +++ b/word-prediction-kb-bert/src/sbx_word_prediction_kb_bert/predictor.py @@ -1,12 +1,16 @@ from dataclasses import dataclass from typing import Optional, Tuple +import torch +from sparv import api as sparv_api # type: ignore [import-untyped] from transformers import ( # type: ignore [import-untyped] BertForMaskedLM, BertTokenizer, FillMaskPipeline, ) +logger = sparv_api.get_logger(__name__) + SCORE_FORMATS = { 1: ("{:.1f}", lambda s: s.endswith(".0")), 2: ("{:.2f}", lambda s: s.endswith(".00")), @@ -36,9 +40,25 @@ def __init__( @classmethod def _default_model(cls) -> BertForMaskedLM: - return BertForMaskedLM.from_pretrained( - MODELS["kb-bert"].model_name, revision=MODELS["kb-bert"].model_revision + if torch.cuda.is_available(): + logger.info("Using GPU (cuda)") + dtype = torch.float16 + else: + logger.warning("Using CPU, is cuda available?") + dtype = torch.float32 + model = BertForMaskedLM.from_pretrained( + MODELS["kb-bert"].model_name, + revision=MODELS["kb-bert"].model_revision, + torch_dtype=dtype, + device_map=( + "auto" + if torch.cuda.is_available() and torch.cuda.device_count() > 1 + else None + ), ) + if torch.cuda.is_available() and torch.cuda.device_count() == 1: + model = model.cuda() + return model @classmethod def _default_tokenizer(cls) -> BertTokenizer: