Skip to content

Commit

Permalink
[π˜€π—½π—Ώ] initial version
Browse files Browse the repository at this point in the history
Created using spr 1.3.5
  • Loading branch information
kod-kristoff committed Feb 8, 2024
1 parent b675519 commit d041d9a
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 26 deletions.
6 changes: 3 additions & 3 deletions .bumpversion.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ allow_dirty = False
search = version = "{current_version}"
replace = version = "{new_version}"

[bumpversion:file:src/sparv_bert_neighbour/__init__.py]
[bumpversion:file:src/bert_neighbour/__init__.py]
search = __version__ = "{current_version}"
replace = __version__ = "{new_version}"

[bumpversion:file:tests/test_version.py]
search = assert sparv_bert_neighbour.__version__ == "{current_version}"
replace = assert sparv_bert_neighbour.__version__ == "{new_version}"
search = assert bert_neighbour.__version__ == "{current_version}"
replace = assert bert_neighbour.__version__ == "{new_version}"
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ jobs:
name: pypi_files
path: dist

- run: rm -r src/sparv_bert_neighbour
- run: rm -r src/bert_neighbour
- run: pip install typing-extensions
- run: pip install -r tests/requirements-testing.txt
- run: pip install sparv-bert-neighbour-plugin --no-index --no-deps --find-links dist --force-reinstall
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ help:

PLATFORM := `uname -o`
REPO := "sparv-bert-neighbour-plugin"
PROJECT_SRC := "src/sparv_bert_neighbour"
PROJECT_SRC := "src/bert_neighbour"

ifeq (${VIRTUAL_ENV},)
VENV_NAME = .venv
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ annotation exclusively by adding it as the only annotation to export under `xml_
```yaml
xml_export:
annotations:
- <token>:sparv_bert_neighbour.transformer-neighbour
- <token>:bert_neighbour.transformer-neighbour
```
To use it together with other annotations you might add it under `export`:

```yaml
export:
annotations:
- <token>:sparv_bert_neighbour.transformer-neighbour
- <token>:bert_neighbour.transformer-neighbour
...
```

Expand All @@ -50,7 +50,7 @@ You can configure this plugin by choosing a huggingface model, huggingface trans
The model defaults to [`KBLab/bert-base-swedish-cased`](https://huggingface.co/KBLab/bert-base-swedish-cased) but can be configured in `config.yaml`:

```yaml
sparv_bert_neighbour:
bert_neighbour:
model: "KBLab/bert-base-swedish-cased"
```

Expand All @@ -59,7 +59,7 @@ sparv_bert_neighbour:
The tokenizer defaults to [`KBLab/bert-base-swedish-cased`](https://huggingface.co/KBLab/bert-base-swedish-cased) but can be configured in `config.yaml`:

```yaml
sparv_bert_neighbour:
bert_neighbour:
tokenizer: "KBLab/bert-base-swedish-cased"
```

Expand All @@ -68,6 +68,6 @@ sparv_bert_neighbour:
The number of neighbours defaults to `5` but can be configured in `config.yaml`:

```yaml
sparv_bert_neighbour:
bert_neighbour:
num_neighbours: 5
```
2 changes: 1 addition & 1 deletion examples/hello-bert-mask/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export:
- <sentence>
- <token:word>
- <token>:stanza.pos
- <token>:sparv_bert_neighbour.transformer-neighbour
- <token>:bert_neighbour.transformer-neighbour

sparv:
compression: none
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ requires = ["hatchling"]
build-backend = "hatchling.build"

[project.entry-points."sparv.plugin"]
sparv_bert_neighbour = "sparv_bert_neighbour"
bert_neighbour = "bert_neighbour"

[project.urls]
Homepage = "https://github.com/spraakbanken/sparv-bert-neighbour-plugin"
Expand All @@ -55,7 +55,7 @@ dev-dependencies = [
exclude = ["/.github", "/docs"]

[tool.hatch.build.targets.wheel]
packages = ["src/sparv_bert_neighbour"]
packages = ["src/bert_neighbour"]

[tool.hatch.metadata]
allow-direct-references = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@

__config__ = [
Config(
"sparv_bert_neighbour.model",
"bert_neighbour.model",
description="Huggingface pretrained model name",
default="KBLab/bert-base-swedish-cased",
),
Config(
"sparv_bert_neighbour.tokenizer",
"bert_neighbour.tokenizer",
description="HuggingFace pretrained tokenizer name",
default="KBLab/bert-base-swedish-cased",
),
Config(
"sparv_bert_neighbour.num_neighbours",
"bert_neighbour.num_neighbours",
description="The number of neighbours to list",
default=5,
),
Expand All @@ -46,22 +46,22 @@
)
def annotate_masked_bert(
out_neighbour: Output = Output(
"<token>:sparv_bert_neighbour.transformer-neighbour",
"<token>:bert_neighbour.transformer-neighbour",
cls="transformer_neighbour",
description="Transformer neighbours from masked BERT (format: '|<word>:<score>|...|)",
),
word: Annotation = Annotation("<token:word>"),
sentence: Annotation = Annotation("<sentence>"),
model_name: str = Config("sparv_bert_neighbour.model"),
tokenizer_name: str = Config("sparv_bert_neighbour.tokenizer"),
num_neighbours_str: str = Config("sparv_bert_neighbour.num_neighbours"),
model_name: str = Config("bert_neighbour.model"),
tokenizer_name: str = Config("bert_neighbour.tokenizer"),
num_neighbours_str: str = Config("bert_neighbour.num_neighbours"),
) -> None:
logger.info("annotate_masked_bert")
try:
num_neighbours = int(num_neighbours_str)
except ValueError as exc:
raise SparvErrorMessage(
f"'sparv_bert_neighbour.num_neighbours' must contain an 'int' got: '{num_neighbours_str}'"
f"'bert_neighbour.num_neighbours' must contain an 'int' got: '{num_neighbours_str}'"
) from exc
tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
model = BertForMaskedLM.from_pretrained(model_name)
Expand All @@ -78,9 +78,11 @@ def annotate_masked_bert(
token_indices = list(sent)
for token_index_to_mask in token_indices:
sent_to_tag = TOK_SEP.join(
"[MASK]"
if token_index == token_index_to_mask
else token_word[token_index]
(
"[MASK]"
if token_index == token_index_to_mask
else token_word[token_index]
)
for token_index in sent
)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_version.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sparv_bert_neighbour
import bert_neighbour


def test_version() -> None:
assert sparv_bert_neighbour.__version__ == "0.2.1"
assert bert_neighbour.__version__ == "0.2.1"

0 comments on commit d041d9a

Please sign in to comment.