Skip to content

Commit

Permalink
refactor: solve some typing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
kod-kristoff committed May 6, 2024
1 parent 70e2c34 commit 472caf5
Show file tree
Hide file tree
Showing 16 changed files with 103 additions and 68 deletions.
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ namespace_packages = True
explicit_package_bases = True
show_error_codes = True
ignore_missing_imports = True
python_version = "3.8"
python_version = 3.8
; plugins = adt.mypy_plugin
12 changes: 11 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies = [
"diff-match-patch>=20230430",
"more-itertools>=10.2.0",
"typing-extensions>=4.11.0",
"strenum>=0.4.15", # Backport of StrEnum for Python < 3.11
]
requires-python = ">=3.8"
readme = "README.md"
Expand Down
37 changes: 21 additions & 16 deletions src/parallel_corpus/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import logging
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, TypeVar

from typing_extensions import Self
from typing import Dict, Iterable, List, Optional, TypeVar

import parallel_corpus.shared.ranges
import parallel_corpus.shared.str_map
Expand Down Expand Up @@ -38,7 +36,7 @@ class Edge:
comment: Optional[str] = None


Edges = dict[str, Edge]
Edges = Dict[str, Edge]


@dataclass
Expand All @@ -48,12 +46,12 @@ class Graph(SourceTarget[List[Token]]):

def copy_with_updated_side_and_edges(
self, side: Side, new_tokens: List[Token], edges: Edges
) -> Self:
) -> "Graph":
source = self.source if side == Side.target else new_tokens
target = new_tokens if side == Side.target else self.target
return Graph(source=source, target=target, edges=edges, comment=self.comment)

def copy_with_edges(self, edges: Edges) -> Self:
def copy_with_edges(self, edges: Edges) -> "Graph":
return Graph(source=self.source, target=self.target, edges=edges, comment=self.comment)


Expand All @@ -79,7 +77,7 @@ def edge(
)


def edge_record(es: List[Edge]) -> Dict[str, Edge]:
def edge_record(es: Iterable[Edge]) -> Dict[str, Edge]:
return {e.id: e for e in es}


Expand Down Expand Up @@ -145,10 +143,12 @@ def align(g: Graph) -> Graph:
for c in char_diff:
# Entries whose id is None make the alignment skip spaces;
# those None ids originate from to_char_ids.
if c.change == diffs.ChangeType.CONSTANT and (c.a.id is not None and c.b.id is not None):
if c.change == diffs.ChangeType.CONSTANT and (
c.a is not None and c.b is not None and c.a.id is not None and c.b.id is not None
):
uf.union(c.a.id, c.b.id)
proto_edges = {k: e for k, e in g.edges.items() if e.manual}
first = UniqueCheck()
first: UniqueCheck[str] = UniqueCheck()

def update_edges(tokens, _side):
for tok in tokens:
Expand All @@ -157,7 +157,10 @@ def update_edges(tokens, _side):
labels = e_repr.labels if first(e_repr.id) else []
e_token = edge([tok.id], labels, manual=False, comment=e_repr.comment)
dicts.modify(
proto_edges, uf.find(tok.id), zero_edge, lambda e: merge_edges(e, e_token)
proto_edges,
uf.find(tok.id),
zero_edge,
lambda e: merge_edges(e, e_token), # noqa: B023
)

map_sides(g, update_edges)
Expand Down Expand Up @@ -203,7 +206,9 @@ def unaligned_set_side(g: Graph, side: Side, text: str) -> Graph:
return unaligned_modify(g, from_, to, new_text, side)


def unaligned_modify(g: Graph, from_: int, to: int, text: str, side: Side = "target") -> Graph:
def unaligned_modify(
g: Graph, from_: int, to: int, text: str, side: Side = Side.target
) -> Graph:
"""Replace the text at some position, merging the spans it touches upon.
>>> show = lambda g: [t.text for t in g.target]
Expand Down Expand Up @@ -242,7 +247,7 @@ def unaligned_modify(g: Graph, from_: int, to: int, text: str, side: Side = "tar
Indexes are character offsets (use CodeMirror's doc.posFromIndex and doc.indexFromPos to convert)
"""
""" # noqa: E501

tokens = get_side_texts(g, side)
token_at = token.token_at(tokens, from_)
Expand All @@ -264,7 +269,7 @@ def get_side_texts(g: Graph, side: Side) -> List[str]:
return token.texts(g.get_side(side))


def unaligned_modify_tokens(
def unaligned_modify_tokens( # noqa: C901
g: Graph, from_: int, to: int, text: str, side: Side = Side.target
) -> Graph:
"""# /** Replace the text at some position, merging the spans it touches upon.
Expand Down Expand Up @@ -295,7 +300,7 @@ def unaligned_modify_tokens(
# idsS(unaligned_modify_tokens(g, 0, 0, 'this ', 'source')) // => 's3 s0 s1 s2'
# Indexes are token offsets
"""
""" # noqa: E501

if (
from_ < 0
Expand Down Expand Up @@ -370,13 +375,13 @@ def unaligned_rearrange(g: Graph, begin: int, end: int, dest: int) -> Graph:
target_text(unaligned_rearrange(init('apa bepa cepa depa'), 1, 2, 0)) // => 'bepa cepa apa depa '
Indexes are token offsets"""
Indexes are token offsets""" # noqa: E501
em = edge_map(g)
edge_ids_to_update = {em[t.id].id for t in g.target[begin : (end + 1)]}
new_edges = {}
new_edges.update(g.edges)
for id_ in edge_ids_to_update:
new_edges[id_] = merge_edges(g.edges[id_], edge([], [], manual=True))
return g.copy_with_updated_side_and_edges(
"target", lists.rearrange(g.target, begin, end, dest), new_edges
Side.target, lists.rearrange(g.target, begin, end, dest), new_edges
)
4 changes: 1 addition & 3 deletions src/parallel_corpus/shared/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import re
from typing import List, TypeVar

from typing_extensions import Self

from . import diffs

__all__ = ["diffs"]
Expand All @@ -19,7 +17,7 @@ def end_with_space(s: str) -> str:

def uniq(xs: List[str]) -> List[str]:
used = set()
return [x for x in xs if x not in used and (used.add(x) or True)]
return [x for x in xs if x not in used and (used.add(x) or True)] # type: ignore [func-returns-value]


A = TypeVar("A")
Expand Down
10 changes: 8 additions & 2 deletions src/parallel_corpus/shared/dicts.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from typing import Callable, Dict, List, TypeVar
from typing import TYPE_CHECKING, Callable, Dict, List, TypeVar

if TYPE_CHECKING:
from _typeshed import SupportsRichComparison

K = TypeVar("K", bound=SupportsRichComparison)
else:
K = TypeVar("K")

A = TypeVar("A")
B = TypeVar("B")
K = TypeVar("K")
V = TypeVar("V")


Expand Down
15 changes: 7 additions & 8 deletions src/parallel_corpus/shared/diffs.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import enum
import itertools
from typing import Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

import diff_match_patch as dmp_module
from typing_extensions import Self

from parallel_corpus.shared.str_map import str_map
from parallel_corpus.source_target import Side

dmp = dmp_module.diff_match_patch()

Expand Down Expand Up @@ -45,7 +43,7 @@ def deleted(cls, a: A) -> Self:
def inserted(cls, b: B) -> Self:
return cls(ChangeType.INSERTED, b=b)

def model_dump(self) -> dict[str, Union[int, A, B]]:
def model_dump(self) -> Dict[str, Union[int, A, B]]:
out: Dict[str, Union[int, A, B]] = {
"change": int(self.change),
}
Expand All @@ -55,7 +53,9 @@ def model_dump(self) -> dict[str, Union[int, A, B]]:
out["b"] = self.b
return out

def __eq__(self, other: Self) -> bool:
def __eq__(self, other) -> bool:
if not isinstance(other, Change):
return NotImplemented
return self.change == other.change and self.a == other.a and self.b == other.b

def __repr__(self) -> str:
Expand Down Expand Up @@ -87,7 +87,7 @@ def char_stream():
i += 1


def hdiff(
def hdiff( # noqa: C901
xs: List[A],
ys: List[B],
a_cmp: Callable[[A], str] = str,
Expand Down Expand Up @@ -115,8 +115,8 @@ def assign(c: C, c_cmp: Callable[[C], str], c_from: Dict[str, List[C]]) -> str:
s2 = "".join((assign(b, b_cmp, b_from) for b in ys))
d = dmp.diff_main(s1, s2)

def str_map_change(change: int) -> Callable[[str, Side], Change]:
def inner(c: str, _side: Side) -> Change:
def str_map_change(change: int) -> Callable[[str, int], Change]:
def inner(c: str, _: int) -> Change:
if change == 0:
a = a_from.get(c, []).pop(0)
b = b_from.get(c, []).pop(0)
Expand All @@ -139,7 +139,6 @@ def map_change(change: int, cs):
# print(f"{changes=}")
out.extend(changes)
return out
return list(itertools.chain(*(map_change(change, cs) for change, cs in d)))


def token_diff(s1: str, s2: str) -> List[Tuple[int, str]]:
Expand Down
8 changes: 5 additions & 3 deletions src/parallel_corpus/shared/functional.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import List
from typing import Callable, Sequence, TypeVar

A = TypeVar("A")

def take_last_while(predicate, xs: List) -> List:

def take_last_while(predicate: Callable[[A], bool], xs: Sequence[A]) -> Sequence[A]:
start = 0
for e in reversed(xs):
if not predicate(e):
break
start -= 1
return xs[start:] if start < 0 else []
return xs[start:] if start < 0 else xs[:0]
3 changes: 1 addition & 2 deletions src/parallel_corpus/shared/ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ def edit_range(s0: str, s: str) -> EditRange:
{'from': 0, 'to': 0, 'insert': '01'}
"""
patches = token_diff(s0, s)
pre = itertools.takewhile(lambda i: i[0] == 0, patches)
pre = list(itertools.takewhile(lambda i: i[0] == 0, patches))
post = take_last_while(lambda i: i[0] == 0, patches)
pre = list(pre)
from_ = len("".join((i[1] for i in pre)))
postlen = len("".join((i[1] for i in post)))
to = len(s0) - postlen
Expand Down
18 changes: 10 additions & 8 deletions src/parallel_corpus/shared/union_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ def unions(self, xs: List[A]) -> None:

class UnionFind(UnionFindOperations[int]):
def __init__(self, *, rev: Optional[List[int]] = None) -> None:
self._rev: List[Optional[int]] = rev or []
self._rev: List[int] = rev or []

def find(self, x: int) -> int:
while x >= len(self._rev):
self._rev.append(None)
self._rev.append(None) # type: ignore [arg-type]
if self._rev[x] is None:
self._rev[x] = x
elif self._rev[x] != x:
self._rev[x] = self.find(self._rev[x])
return self._rev[x]
self._rev[x] = self.find(self._rev[x]) # type: ignore [arg-type]
return self._rev[x] # type: ignore [return-value]

def union(self, x: int, y: int) -> int:
find_x = self.find(x)
Expand All @@ -52,7 +52,7 @@ def unions(self, xs: List[int]) -> None:
@dataclass
class Renumber(Generic[A]):
bw: Dict[str, int]
fw: Dict[str, A]
fw: Dict[int, A]
i = 0
serialize: Callable[[A], str]

Expand All @@ -74,7 +74,7 @@ def init(cls, serialize: Callable[[A], str] = json.dumps) -> Self:

def renumber(
serialize: Callable[[A], str] = json.dumps,
) -> Tuple[Callable[[int], A], Callable[[A], int]]:
) -> Tuple[Callable[[int], Optional[A]], Callable[[A], int]]:
"""
Assign unique numbers to each distinct element
Expand All @@ -91,7 +91,7 @@ def renumber(
num('FOO') // => 0
un(0) // => 'foo'
"""
renum = Renumber(bw={}, fw={}, serialize=serialize)
renum: Renumber[A] = Renumber(bw={}, fw={}, serialize=serialize)

return renum.un, renum.num

Expand All @@ -111,7 +111,9 @@ def union(self, x: A, y: A) -> Optional[A]:
return self._renum.un(self._uf.union(self._renum.num(x), self._renum.num(y)))

def unions(self, xs: List[A]) -> None:
self._uf.unions(map(self._renum.num, xs))
num_xs_0 = self._renum.num(xs[0])
for x in xs[1:]:
self._uf.union(num_xs_0, self._renum.num(x))


def poly_union_find(serialize: Callable[[str], str]) -> PolyUnionFind:
Expand Down
2 changes: 1 addition & 1 deletion src/parallel_corpus/shared/unique_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class UniqueCheck(Generic[S]):
"""

def __init__(self) -> None:
self.c = Count()
self.c: Count[S] = Count()

def __call__(self, s: S) -> bool:
return self.c.inc(s) == 1
Expand Down
12 changes: 6 additions & 6 deletions src/parallel_corpus/source_target.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import enum
from dataclasses import dataclass
from typing import Callable, Generic, TypeVar

# Backport used to provide StrEnum on Python 3.8-3.10 (enum.StrEnum was added in 3.11).
# Note: not a drop-in replacement for enum.StrEnum in Python 3.11.
import strenum

A = TypeVar("A")
B = TypeVar("B")


class Side(enum.StrEnum):
class Side(strenum.StrEnum):
source = "source"
target = "target"

Expand All @@ -17,10 +20,7 @@ class SourceTarget(Generic[A]):
target: A

def get_side(self, side: Side) -> A:
if side == Side.source:
return self.source
if side == Side.target:
return self.target
return self.source if side == Side.source else self.target


def map_sides(g: SourceTarget[A], f: Callable[[A, Side], B]) -> SourceTarget[B]:
Expand Down
Loading

0 comments on commit 472caf5

Please sign in to comment.