Skip to content

Commit

Permalink
[BE] Check ordering and exclusivity of tensorclass registers
Browse files Browse the repository at this point in the history
ghstack-source-id: 3dc907f4dd3047238adb0bb309d9ae75d24c5085
Pull Request resolved: #1176
  • Loading branch information
vmoens committed Jan 9, 2025
1 parent dedec04 commit b493178
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def test_tensorclass_stub_methods():
]

if missing_methods:
raise Exception(f"Missing methods in tensorclass.pyi: {missing_methods}")
raise Exception(
f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}"
)
Expand Down Expand Up @@ -156,6 +155,34 @@ class X:
)


def test_sorted_methods():
from tensordict.tensorclass import (
_FALLBACK_METHOD_FROM_TD,
_FALLBACK_METHOD_FROM_TD_FORCE,
_FALLBACK_METHOD_FROM_TD_NOWRAP,
_METHOD_FROM_TD,
)

lists_to_check = [
_FALLBACK_METHOD_FROM_TD_NOWRAP,
_METHOD_FROM_TD,
_FALLBACK_METHOD_FROM_TD_FORCE,
_FALLBACK_METHOD_FROM_TD,
]
# Check that each list is sorted and has unique elements
for lst in lists_to_check:
assert lst == sorted(lst), f"List {lst} is not sorted"
assert len(lst) == len(set(lst)), f"List {lst} has duplicate elements"
# Check that no two lists share any elements
for i, lst1 in enumerate(lists_to_check):
for j, lst2 in enumerate(lists_to_check):
if i != j:
shared_elements = set(lst1) & set(lst2)
assert (
not shared_elements
), f"Lists {lst1} and {lst2} share elements: {shared_elements}"


def _make_data(shape):
return MyData(
X=torch.rand(*shape),
Expand Down

0 comments on commit b493178

Please sign in to comment.