Skip to content

Commit

Permalink
Slightly Optimize Trie.get_all
Browse files Browse the repository at this point in the history
  • Loading branch information
graphemecluster committed Jul 25, 2024
1 parent 44c44dc commit 362a2b9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/ToJyutping/Trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ def get_all(self, s: str, attr: None = None) -> List[Tuple[str, List[Union[Jyutp
def get_all(self, s: str, attr: Optional[Literal['jyutping', 'ipa']] = None) -> List[Tuple[str, List[Union[str, Jyutping.Jyutping, Jyutping.JyutpingList]]]]:
t = self.t
def initialize(c: str):
d = defaultdict(list)
d = utils.EdgeLengthToItems()
u = t.get(c)
if u is not None and u.v:
d[0] = [getattr(p, attr, None) for p in u.v] if attr else u.v
return d
r: List[Tuple[str, DefaultDict[int, List[Union[str, Jyutping.Jyutping, Jyutping.JyutpingList]]]]] = [(c, initialize(c)) for c in s]
r: List[Tuple[str, utils.EdgeLengthToItems[Union[str, Jyutping.Jyutping, Jyutping.JyutpingList]]]] = [(c, initialize(c)) for c in s]
for i in range(len(r)):
u = t.get(r[i][0])
if u is None:
Expand All @@ -94,4 +94,4 @@ def initialize(c: str):
for p in u.v:
for k in range(i, j + 1):
r[k][1][l].append(getattr(p[k - i], attr, None) if attr else p[k - i])
return [(c, utils.flat_dedupe(map(itemgetter(1), sorted(s.items(), key=itemgetter(0), reverse=True)))) for c, s in r]
return [(c, utils.flat_dedupe(s)) for c, s in r]
17 changes: 17 additions & 0 deletions src/ToJyutping/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import DefaultDict, List, TypeVar
import re

punct_dict = dict(
Expand Down Expand Up @@ -173,3 +174,19 @@ def is_iterable(o):
except TypeError:
return False
return True

T = TypeVar('T')

class EdgeLengthToItems(DefaultDict[int, List[T]]):
def __init__(self):
super().__init__(list)
self.max = 0

def __missing__(self, index: int):
result = super().__missing__(index)
self[index] = result
if index > self.max: self.max = index
return result

def __iter__(self):
return map(super().__getitem__, range(self.max, -1, -1))

0 comments on commit 362a2b9

Please sign in to comment.