From 472caf51b3554b9da5f46e82e086f64a5ceb4b72 Mon Sep 17 00:00:00 2001 From: Kristoffer Andersson Date: Mon, 6 May 2024 10:44:44 +0200 Subject: [PATCH] refactor: solve some typing issues --- mypy.ini | 2 +- pdm.lock | 12 ++++++- pyproject.toml | 1 + src/parallel_corpus/graph.py | 37 ++++++++++++---------- src/parallel_corpus/shared/__init__.py | 4 +-- src/parallel_corpus/shared/dicts.py | 10 ++++-- src/parallel_corpus/shared/diffs.py | 15 ++++----- src/parallel_corpus/shared/functional.py | 8 +++-- src/parallel_corpus/shared/ranges.py | 3 +- src/parallel_corpus/shared/union_find.py | 18 ++++++----- src/parallel_corpus/shared/unique_check.py | 2 +- src/parallel_corpus/source_target.py | 12 +++---- src/parallel_corpus/token.py | 11 +++---- tests/test_graph.py | 26 +++++++++++---- tests/test_shared/test_diffs.py | 6 ++-- tests/test_shared/test_union_find.py | 4 +-- 16 files changed, 103 insertions(+), 68 deletions(-) diff --git a/mypy.ini b/mypy.ini index 3a8eefc..18a6155 100644 --- a/mypy.ini +++ b/mypy.ini @@ -4,5 +4,5 @@ namespace_packages = True explicit_package_bases = True show_error_codes = True ignore_missing_imports = True -python_version = "3.8" +python_version = 3.8 ; plugins = adt.mypy_plugin diff --git a/pdm.lock b/pdm.lock index 578d0a9..876864e 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:adaa82ef7accbac4507c9429d6896eb5271db46c371408c1d7cfda2680f917a3" +content_hash = "sha256:45a31179520f4206be41a3c63086952a5a2fb833ae1a85e98262d71ed7988196" [[package]] name = "colorama" @@ -334,6 +334,16 @@ files = [ {file = "ruff-0.4.2.tar.gz", hash = "sha256:33bcc160aee2520664bc0859cfeaebc84bb7323becff3f303b8f1f2d81cb4edc"}, ] +[[package]] +name = "strenum" +version = "0.4.15" +summary = "An Enum that inherits from str." +groups = ["default"] +files = [ + {file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"}, + {file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"}, +] + [[package]] name = "syrupy" version = "3.0.6" diff --git a/pyproject.toml b/pyproject.toml index 5cf1458..50d1d34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "diff-match-patch>=20230430", "more-itertools>=10.2.0", "typing-extensions>=4.11.0", + "strenum>=0.4.15", # For StrEnum i Python < 3.10 ] requires-python = ">=3.8" readme = "README.md" diff --git a/src/parallel_corpus/graph.py b/src/parallel_corpus/graph.py index b9c8009..3f74f9e 100644 --- a/src/parallel_corpus/graph.py +++ b/src/parallel_corpus/graph.py @@ -2,9 +2,7 @@ import logging import re from dataclasses import dataclass -from typing import Dict, List, Optional, TypeVar - -from typing_extensions import Self +from typing import Dict, Iterable, List, Optional, TypeVar import parallel_corpus.shared.ranges import parallel_corpus.shared.str_map @@ -38,7 +36,7 @@ class Edge: comment: Optional[str] = None -Edges = dict[str, Edge] +Edges = Dict[str, Edge] @dataclass @@ -48,12 +46,12 @@ class Graph(SourceTarget[List[Token]]): def copy_with_updated_side_and_edges( self, side: Side, new_tokens: List[Token], edges: Edges - ) -> Self: + ) -> "Graph": source = self.source if side == Side.target else new_tokens target = new_tokens if side == Side.target else self.target return Graph(source=source, target=target, edges=edges, comment=self.comment) - def copy_with_edges(self, edges: Edges) -> Self: + def copy_with_edges(self, edges: Edges) -> "Graph": return Graph(source=self.source, target=self.target, edges=edges, comment=self.comment) @@ -79,7 +77,7 @@ def edge( ) -def edge_record(es: List[Edge]) -> Dict[str, Edge]: +def edge_record(es: Iterable[Edge]) -> Dict[str, Edge]: return {e.id: e for e in es} @@ -145,10 +143,12 @@ def align(g: Graph) -> Graph: for c in char_diff: # these undefined makes the alignment skip spaces. # they originate from to_char_ids - if c.change == diffs.ChangeType.CONSTANT and (c.a.id is not None and c.b.id is not None): + if c.change == diffs.ChangeType.CONSTANT and ( + c.a is not None and c.b is not None and c.a.id is not None and c.b.id is not None + ): uf.union(c.a.id, c.b.id) proto_edges = {k: e for k, e in g.edges.items() if e.manual} - first = UniqueCheck() + first: UniqueCheck[str] = UniqueCheck() def update_edges(tokens, _side): for tok in tokens: @@ -157,7 +157,10 @@ def update_edges(tokens, _side): labels = e_repr.labels if first(e_repr.id) else [] e_token = edge([tok.id], labels, manual=False, comment=e_repr.comment) dicts.modify( - proto_edges, uf.find(tok.id), zero_edge, lambda e: merge_edges(e, e_token) + proto_edges, + uf.find(tok.id), + zero_edge, + lambda e: merge_edges(e, e_token), # noqa: B023 ) map_sides(g, update_edges) @@ -203,7 +206,9 @@ def unaligned_set_side(g: Graph, side: Side, text: str) -> Graph: return unaligned_modify(g, from_, to, new_text, side) -def unaligned_modify(g: Graph, from_: int, to: int, text: str, side: Side = "target") -> Graph: +def unaligned_modify( + g: Graph, from_: int, to: int, text: str, side: Side = Side.target +) -> Graph: """Replace the text at some position, merging the spans it touches upon. >>> show = lambda g: [t.text for t in g.target] @@ -242,7 +247,7 @@ def unaligned_modify(g: Graph, from_: int, to: int, text: str, side: Side = "tar Indexes are character offsets (use CodeMirror's doc.posFromIndex and doc.indexFromPos to convert) - """ + """ # noqa: E501 tokens = get_side_texts(g, side) token_at = token.token_at(tokens, from_) @@ -264,7 +269,7 @@ def get_side_texts(g: Graph, side: Side) -> List[str]: return token.texts(g.get_side(side)) -def unaligned_modify_tokens( +def unaligned_modify_tokens( # noqa: C901 g: Graph, from_: int, to: int, text: str, side: Side = Side.target ) -> Graph: """# /** Replace the text at some position, merging the spans it touches upon. @@ -295,7 +300,7 @@ def unaligned_modify_tokens( # idsS(unaligned_modify_tokens(g, 0, 0, 'this ', 'source')) // => 's3 s0 s1 s2' # Indexes are token offsets - """ + """ # noqa: E501 if ( from_ < 0 @@ -370,7 +375,7 @@ def unaligned_rearrange(g: Graph, begin: int, end: int, dest: int) -> Graph: target_text(unaligned_rearrange(init('apa bepa cepa depa'), 1, 2, 0)) // => 'bepa cepa apa depa ' - Indexes are token offsets""" + Indexes are token offsets""" # noqa: E501 em = edge_map(g) edge_ids_to_update = {em[t.id].id for t in g.target[begin : (end + 1)]} new_edges = {} @@ -378,5 +383,5 @@ def unaligned_rearrange(g: Graph, begin: int, end: int, dest: int) -> Graph: for id_ in edge_ids_to_update: new_edges[id_] = merge_edges(g.edges[id_], edge([], [], manual=True)) return g.copy_with_updated_side_and_edges( - "target", lists.rearrange(g.target, begin, end, dest), new_edges + Side.target, lists.rearrange(g.target, begin, end, dest), new_edges ) diff --git a/src/parallel_corpus/shared/__init__.py b/src/parallel_corpus/shared/__init__.py index bf8a293..04e5599 100644 --- a/src/parallel_corpus/shared/__init__.py +++ b/src/parallel_corpus/shared/__init__.py @@ -1,8 +1,6 @@ import re from typing import List, TypeVar -from typing_extensions import Self - from . import diffs __all__ = ["diffs"] @@ -19,7 +17,7 @@ def end_with_space(s: str) -> str: def uniq(xs: List[str]) -> List[str]: used = set() - return [x for x in xs if x not in used and (used.add(x) or True)] + return [x for x in xs if x not in used and (used.add(x) or True)] # type: ignore [func-returns-value] A = TypeVar("A") diff --git a/src/parallel_corpus/shared/dicts.py b/src/parallel_corpus/shared/dicts.py index 59822b9..3176a48 100644 --- a/src/parallel_corpus/shared/dicts.py +++ b/src/parallel_corpus/shared/dicts.py @@ -1,8 +1,14 @@ -from typing import Callable, Dict, List, TypeVar +from typing import TYPE_CHECKING, Callable, Dict, List, TypeVar + +if TYPE_CHECKING: + from _typeshed import SupportsRichComparison + + K = TypeVar("K", bound=SupportsRichComparison) +else: + K = TypeVar("K") A = TypeVar("A") B = TypeVar("B") -K = TypeVar("K") V = TypeVar("V") diff --git a/src/parallel_corpus/shared/diffs.py b/src/parallel_corpus/shared/diffs.py index 40e588a..56d55d0 100644 --- a/src/parallel_corpus/shared/diffs.py +++ b/src/parallel_corpus/shared/diffs.py @@ -1,12 +1,10 @@ import enum -import itertools from typing import Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union import diff_match_patch as dmp_module from typing_extensions import Self from parallel_corpus.shared.str_map import str_map -from parallel_corpus.source_target import Side dmp = dmp_module.diff_match_patch() @@ -45,7 +43,7 @@ def deleted(cls, a: A) -> Self: def inserted(cls, b: B) -> Self: return cls(ChangeType.INSERTED, b=b) - def model_dump(self) -> dict[str, Union[int, A, B]]: + def model_dump(self) -> Dict[str, Union[int, A, B]]: out: Dict[str, Union[int, A, B]] = { "change": int(self.change), } @@ -55,7 +53,9 @@ def model_dump(self) -> dict[str, Union[int, A, B]]: out["b"] = self.b return out - def __eq__(self, other: Self) -> bool: + def __eq__(self, other) -> bool: + if not isinstance(other, Change): + return NotImplemented return self.change == other.change and self.a == other.a and self.b == other.b def __repr__(self) -> str: @@ -87,7 +87,7 @@ def char_stream(): i += 1 -def hdiff( +def hdiff( # noqa: C901 xs: List[A], ys: List[B], a_cmp: Callable[[A], str] = str, @@ -115,8 +115,8 @@ def assign(c: C, c_cmp: Callable[[C], str], c_from: Dict[str, List[C]]) -> str: s2 = "".join((assign(b, b_cmp, b_from) for b in ys)) d = dmp.diff_main(s1, s2) - def str_map_change(change: int) -> Callable[[str, Side], Change]: - def inner(c: str, _side: Side) -> Change: + def str_map_change(change: int) -> Callable[[str, int], Change]: + def inner(c: str, _: int) -> Change: if change == 0: a = a_from.get(c, []).pop(0) b = b_from.get(c, []).pop(0) @@ -139,7 +139,6 @@ def map_change(change: int, cs): # print(f"{changes=}") out.extend(changes) return out - return list(itertools.chain(*(map_change(change, cs) for change, cs in d))) def token_diff(s1: str, s2: str) -> List[Tuple[int, str]]: diff --git a/src/parallel_corpus/shared/functional.py b/src/parallel_corpus/shared/functional.py index 555a10b..50a9d94 100644 --- a/src/parallel_corpus/shared/functional.py +++ b/src/parallel_corpus/shared/functional.py @@ -1,10 +1,12 @@ -from typing import List +from typing import Callable, Sequence, TypeVar +A = TypeVar("A") -def take_last_while(predicate, xs: List) -> List: + +def take_last_while(predicate: Callable[[A], bool], xs: Sequence[A]) -> Sequence[A]: start = 0 for e in reversed(xs): if not predicate(e): break start -= 1 - return xs[start:] if start < 0 else [] + return xs[start:] if start < 0 else xs[:0] diff --git a/src/parallel_corpus/shared/ranges.py b/src/parallel_corpus/shared/ranges.py index 6569a6e..6945fb6 100644 --- a/src/parallel_corpus/shared/ranges.py +++ b/src/parallel_corpus/shared/ranges.py @@ -37,9 +37,8 @@ def edit_range(s0: str, s: str) -> EditRange: {'from': 0, 'to': 0, 'insert': '01'} """ patches = token_diff(s0, s) - pre = itertools.takewhile(lambda i: i[0] == 0, patches) + pre = list(itertools.takewhile(lambda i: i[0] == 0, patches)) post = take_last_while(lambda i: i[0] == 0, patches) - pre = list(pre) from_ = len("".join((i[1] for i in pre))) postlen = len("".join((i[1] for i in post))) to = len(s0) - postlen diff --git a/src/parallel_corpus/shared/union_find.py b/src/parallel_corpus/shared/union_find.py index d5dea7c..e201c5d 100644 --- a/src/parallel_corpus/shared/union_find.py +++ b/src/parallel_corpus/shared/union_find.py @@ -27,16 +27,16 @@ def unions(self, xs: List[A]) -> None: class UnionFind(UnionFindOperations[int]): def __init__(self, *, rev: Optional[List[int]] = None) -> None: - self._rev: List[Optional[int]] = rev or [] + self._rev: List[int] = rev or [] def find(self, x: int) -> int: while x >= len(self._rev): - self._rev.append(None) + self._rev.append(None) # type: ignore [arg-type] if self._rev[x] is None: self._rev[x] = x elif self._rev[x] != x: - self._rev[x] = self.find(self._rev[x]) - return self._rev[x] + self._rev[x] = self.find(self._rev[x]) # type: ignore [arg-type] + return self._rev[x] # type: ignore [return-value] def union(self, x: int, y: int) -> int: find_x = self.find(x) @@ -52,7 +52,7 @@ def unions(self, xs: List[int]) -> None: @dataclass class Renumber(Generic[A]): bw: Dict[str, int] - fw: Dict[str, A] + fw: Dict[int, A] i = 0 serialize: Callable[[A], str] @@ -74,7 +74,7 @@ def init(cls, serialize: Callable[[A], str] = json.dumps) -> Self: def renumber( serialize: Callable[[A], str] = json.dumps, -) -> Tuple[Callable[[int], A], Callable[[A], int]]: +) -> Tuple[Callable[[int], Optional[A]], Callable[[A], int]]: """ Assign unique numbers to each distinct element @@ -91,7 +91,7 @@ def renumber( num('FOO') // => 0 un(0) // => 'foo' """ - renum = Renumber(bw={}, fw={}, serialize=serialize) + renum: Renumber[A] = Renumber(bw={}, fw={}, serialize=serialize) return renum.un, renum.num @@ -111,7 +111,9 @@ def union(self, x: A, y: A) -> Optional[A]: return self._renum.un(self._uf.union(self._renum.num(x), self._renum.num(y))) def unions(self, xs: List[A]) -> None: - self._uf.unions(map(self._renum.num, xs)) + num_xs_0 = self._renum.num(xs[0]) + for x in xs[1:]: + self._uf.union(num_xs_0, self._renum.num(x)) def poly_union_find(serialize: Callable[[str], str]) -> PolyUnionFind: diff --git a/src/parallel_corpus/shared/unique_check.py b/src/parallel_corpus/shared/unique_check.py index 9a14576..be6b0d2 100644 --- a/src/parallel_corpus/shared/unique_check.py +++ b/src/parallel_corpus/shared/unique_check.py @@ -21,7 +21,7 @@ class UniqueCheck(Generic[S]): """ def __init__(self) -> None: - self.c = Count() + self.c: Count[S] = Count() def __call__(self, s: S) -> bool: return self.c.inc(s) == 1 diff --git a/src/parallel_corpus/source_target.py b/src/parallel_corpus/source_target.py index b329387..f8c2dd2 100644 --- a/src/parallel_corpus/source_target.py +++ b/src/parallel_corpus/source_target.py @@ -1,12 +1,15 @@ -import enum from dataclasses import dataclass from typing import Callable, Generic, TypeVar +# Used to support StrEnum in python 3.8 and 3.9 +# Not drop-in of StrEnum in python 3.11 +import strenum + A = TypeVar("A") B = TypeVar("B") -class Side(enum.StrEnum): +class Side(strenum.StrEnum): source = "source" target = "target" @@ -17,10 +20,7 @@ class SourceTarget(Generic[A]): target: A def get_side(self, side: Side) -> A: - if side == Side.source: - return self.source - if side == Side.target: - return self.target + return self.source if side == Side.source else self.target def map_sides(g: SourceTarget[A], f: Callable[[A, Side], B]) -> SourceTarget[B]: diff --git a/src/parallel_corpus/token.py b/src/parallel_corpus/token.py index 376cd50..ff6e32c 100644 --- a/src/parallel_corpus/token.py +++ b/src/parallel_corpus/token.py @@ -1,7 +1,6 @@ -from dataclasses import dataclass import re -from typing import List, TypedDict - +from dataclasses import dataclass +from typing import List, Sequence, TypedDict from parallel_corpus import shared @@ -22,7 +21,7 @@ class Span: end: int -def text(ts: List[Text]) -> str: +def text(ts: Sequence[Text]) -> str: """The text in some tokens >>> text(identify(tokenize('apa bepa cepa '), '#')) @@ -32,13 +31,13 @@ def text(ts: List[Text]) -> str: return "".join(texts(ts)) -def texts(ts: List[Text]) -> List[str]: +def texts(ts: Sequence[Text]) -> List[str]: """The texts in some tokens >>> texts(identify(tokenize('apa bepa cepa '), '#')) ['apa ', 'bepa ', 'cepa '] """ - return list(map(lambda t: t.text, ts)) + return [t.text for t in ts] def tokenize(s: str) -> List[str]: diff --git a/tests/test_graph.py b/tests/test_graph.py index 6cbf967..8726e4b 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -2,6 +2,7 @@ import pytest from parallel_corpus import graph, token +from parallel_corpus.source_target import Side def test_graph_init() -> None: @@ -26,10 +27,21 @@ def test_graph_case1() -> None: assert "e-s0-t19-t20" in gm.edges +def test_graph_case2() -> None: + first = "Jonat han saknades , emedan han , med sin vapendragare , redan på annat håll sökt och anträffat fienden ." # noqa: E501 + second = "Jonathan saknaes , emedan han , med sin vapendragare , redan på annat håll sökt och anträffat fienden ." # noqa: E501 + + g = graph.init(first) + + gm = graph.set_target(g, second) + print(f"{gm=}") + assert "e-s0-s1-t20" in gm.edges + + def test_unaligned_set_side() -> None: g0 = graph.init("a bc d") print(">>> test_unaligned_set_side") - g = graph.unaligned_set_side(g0, "target", "ab c d") + g = graph.unaligned_set_side(g0, Side.target, "ab c d") print("<<< test_unaligned_set_side") expected_source = [ @@ -66,7 +78,7 @@ def test_unaligned_set_side() -> None: def test_graph_align() -> None: g0 = graph.init("a bc d") - g = graph.unaligned_set_side(g0, "target", "ab c d") + g = graph.unaligned_set_side(g0, Side.target, "ab c d") expected_source = [ token.Token(id="s0", text="a "), @@ -209,7 +221,9 @@ def test_unaligned_modify_tokens_ids(from_: int, to: int, text: str, snapshot) - ) def test_unaligned_modify_tokens_show_source(from_: int, to: int, text: str, snapshot) -> None: g = graph.init("test graph hello") - assert show_source(graph.unaligned_modify_tokens(g, from_, to, text, "source")) == snapshot + assert ( + show_source(graph.unaligned_modify_tokens(g, from_, to, text, Side.source)) == snapshot + ) @pytest.mark.parametrize( @@ -220,7 +234,7 @@ def test_unaligned_modify_tokens_show_source(from_: int, to: int, text: str, sna ) def test_unaligned_modify_tokens_ids_source(from_: int, to: int, text: str, snapshot) -> None: g = graph.init("test graph hello") - assert ids_source(graph.unaligned_modify_tokens(g, from_, to, text, "source")) == snapshot + assert ids_source(graph.unaligned_modify_tokens(g, from_, to, text, Side.source)) == snapshot # show(unaligned_modify_tokens(init('a '), 0, 1, ' ')) // => [' '] @@ -230,14 +244,14 @@ def test_unaligned_modify_tokens_ids_source(from_: int, to: int, text: str, snap # ids(unaligned_modify_tokens(g, 0, 1, 'this')) // => 't3 t2' # const showS = (g: Graph) => g.source.map(t => t.text) # const idsS = (g: Graph) => g.source.map(t => t.id).join(' ') -# showS(unaligned_modify_tokens(g, 0, 0, 'this ', 'source')) // => ['this ', 'test ', 'graph ', 'hello '] +# showS(unaligned_modify_tokens(g, 0, 0, 'this ', 'source')) // => ['this ', 'test ', 'graph ', 'hello '] # noqa: E501 # idsS(unaligned_modify_tokens(g, 0, 0, 'this ', 'source')) // => 's3 s0 s1 s2' def test_unaligned_rearrange() -> None: g = graph.init("apa bepa cepa depa") gr = graph.unaligned_rearrange(g, 1, 2, 0) - assert graph.target_text(gr) == "bepa cepa apa depa " + assert graph.target_text(gr) == "bepa cepa apa depa " # type: ignore [arg-type] # target_text(unaligned_rearrange(init(), 1, 2, 0)) // => diff --git a/tests/test_shared/test_diffs.py b/tests/test_shared/test_diffs.py index 065de39..754b8f0 100644 --- a/tests/test_shared/test_diffs.py +++ b/tests/test_shared/test_diffs.py @@ -2,8 +2,8 @@ def test_hdiff() -> None: - (*abcca,) = "abcca" - (*BACC,) = "BACC" + (*abcca,) = "abcca" # type: ignore + (*BACC,) = "BACC" # type: ignore expected = [ Change.deleted("a"), @@ -14,4 +14,4 @@ def test_hdiff() -> None: Change.deleted("a"), ] - assert hdiff(abcca, BACC, str.lower, str.lower) == expected + assert hdiff(abcca, BACC, str.lower, str.lower) == expected # type: ignore [has-type] diff --git a/tests/test_shared/test_union_find.py b/tests/test_shared/test_union_find.py index c723559..696eb44 100644 --- a/tests/test_shared/test_union_find.py +++ b/tests/test_shared/test_union_find.py @@ -14,7 +14,7 @@ def test_union_find() -> None: def test_renumber_default() -> None: - un, num = renumber() + un, num = renumber() # type: ignore [var-annotated] assert num("foo") == 0 assert num("bar") == 1 assert num("foo") == 0 @@ -24,7 +24,7 @@ def test_renumber_default() -> None: def test_renumber_lowercase() -> None: - un, num = renumber(lambda a: a.lower()) + un, num = renumber(str.lower) # type: ignore [var-annotated] assert num("foo") == 0 assert num("FOO") == 0