Skip to content

Commit

Permalink
Remove all instances of for x in set(y)
Browse files Browse the repository at this point in the history
This pattern is replaced with `for x in unique_ever_seen(y)`, which also
removes duplicate elements, but is guaranteed to always produce elements
in the same order as in the input list.

Signed-off-by: Kale Kundert <[email protected]>
  • Loading branch information
kalekundert committed Oct 31, 2024
1 parent 1ef554a commit a4188a2
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 27 deletions.
5 changes: 3 additions & 2 deletions escnn/kernels/steerable_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from escnn.group import Group
from escnn.group import IrreducibleRepresentation
from escnn.group import Representation
from escnn.utils import unique_ever_seen

import torch

Expand Down Expand Up @@ -238,9 +239,9 @@ def __init__(self,
js = set()

# loop over all input irreps
for i_irrep_id in set(in_repr.irreps):
for i_irrep_id in unique_ever_seen(in_repr.irreps):
# loop over all output irreps
for o_irrep_id in set(out_repr.irreps):
for o_irrep_id in unique_ever_seen(out_repr.irreps):
try:
# retrieve the irrep intertwiner basis
intertwiner_basis = irreps_basis._generator(basis, i_irrep_id, o_irrep_id, **kwargs)
Expand Down
3 changes: 2 additions & 1 deletion escnn/kernels/wignereckart_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .steerable_filters_basis import SteerableFiltersBasis

from escnn.group import *
from escnn.utils import unique_ever_seen

import torch

Expand Down Expand Up @@ -302,7 +303,7 @@ def __init__(self,
_js_restriction = defaultdict(list)

# for each harmonic j' to consider
for _j in set(_j for _j, _ in basis.js):
for _j in unique_ever_seen(_j for _j, _ in basis.js):
if basis.multiplicity(_j) == 0:
continue

Expand Down
2 changes: 1 addition & 1 deletion escnn/nn/modules/basismanager/basisexpansion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from escnn.kernels import KernelBasis, EmptyBasisException
from escnn.group import Representation
from escnn.nn.modules import utils
from escnn.nn.modules.utils import unique_ever_seen
from escnn.utils import unique_ever_seen

from .basismanager import BasisManager
from .basisexpansion_singleblock import block_basisexpansion
Expand Down
2 changes: 1 addition & 1 deletion escnn/nn/modules/basismanager/basissampler_blocks.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

from escnn.group import Representation
from escnn.kernels import KernelBasis, EmptyBasisException
from escnn.utils import unique_ever_seen


from escnn.nn.modules.basismanager import retrieve_indices
from .basismanager import BasisManager

from escnn.nn.modules.basismanager.basissampler_singleblock import block_basissampler
from escnn.nn.modules.utils import unique_ever_seen

from typing import Callable, Tuple, Dict, List, Iterable, Union
from collections import defaultdict
Expand Down
2 changes: 1 addition & 1 deletion escnn/nn/modules/batchnormalization/gnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from escnn.gspaces import *
from escnn.nn import FieldType
from escnn.nn import GeometricTensor
from escnn.nn.modules.utils import unique_ever_seen
from escnn.utils import unique_ever_seen

from ..equivariant_module import EquivariantModule

Expand Down
2 changes: 1 addition & 1 deletion escnn/nn/modules/batchnormalization/iid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from escnn.gspaces import *
from escnn.nn import FieldType
from escnn.nn import GeometricTensor
from escnn.nn.modules.utils import unique_ever_seen
from escnn.utils import unique_ever_seen

from ..equivariant_module import EquivariantModule

Expand Down
13 changes: 1 addition & 12 deletions escnn/nn/modules/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

from escnn.nn import FieldType
from typing import List, Dict, Tuple, Iterable
from typing import List, Dict, Tuple
from collections import defaultdict


Expand Down Expand Up @@ -55,14 +55,3 @@ def indexes_from_labels(in_type: FieldType, labels: List[str]) -> Dict[str, Tupl

return groups


def unique_ever_seen(iterable: Iterable) -> Iterable:
already_seen = set()

for item in iterable:
if item in already_seen:
continue
else:
already_seen.add(item)
yield item

13 changes: 13 additions & 0 deletions escnn/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Iterable

def unique_ever_seen(iterable: Iterable) -> Iterable:
already_seen = set()

for item in iterable:
if item in already_seen:
continue
else:
already_seen.add(item)
yield item


8 changes: 4 additions & 4 deletions test/nn/test_basisexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,17 @@ def compare(self, basis: BlocksBasisExpansion):

for i, attr1 in enumerate(basis.get_basis_info()):
attr2 = basis.get_element_info(i)
self.assertEquals(attr1, attr2)
self.assertEquals(attr1['id'], i)
self.assertEqual(attr1, attr2)
self.assertEqual(attr1['id'], i)

for _ in range(5):
w = torch.randn(basis.dimension())

f1 = basis(w)
f2 = basis(w)
assert torch.allclose(f1, f2)
self.assertEquals(f1.shape[1], basis._input_size)
self.assertEquals(f1.shape[0], basis._output_size)
self.assertEqual(f1.shape[1], basis._input_size)
self.assertEqual(f1.shape[0], basis._output_size)


def test_checkpoint_meshgrid(self):
Expand Down
8 changes: 4 additions & 4 deletions test/nn/test_basissampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def compare(self, basis: BlocksBasisSampler, d: int):

for i, attr1 in enumerate(basis.get_basis_info()):
attr2 = basis.get_element_info(i)
self.assertEquals(attr1, attr2)
self.assertEquals(attr1['id'], i)
self.assertEqual(attr1, attr2)
self.assertEqual(attr1['id'], i)

for _ in range(5):
P = 20
Expand All @@ -150,8 +150,8 @@ def compare(self, basis: BlocksBasisSampler, d: int):
f1 = basis(w, edge_delta)
f2 = basis(w, edge_delta)
self.assertTrue(torch.allclose(f1, f2))
self.assertEquals(f1.shape[2], basis._input_size)
self.assertEquals(f1.shape[1], basis._output_size)
self.assertEqual(f1.shape[2], basis._input_size)
self.assertEqual(f1.shape[1], basis._output_size)

y1 = basis.compute_messages(w, x_j, edge_delta, conv_first=False)
y2 = basis.compute_messages(w, x_j, edge_delta, conv_first=True)
Expand Down

0 comments on commit a4188a2

Please sign in to comment.