diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py
index 48711851c..01a5aec37 100644
--- a/test/test_tensorclass.py
+++ b/test/test_tensorclass.py
@@ -125,7 +125,6 @@ def test_tensorclass_stub_methods():
     ]
 
     if missing_methods:
-        raise Exception(f"Missing methods in tensorclass.pyi: {missing_methods}")
         raise Exception(
             f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}"
         )
@@ -156,6 +155,34 @@ class X:
     )
 
 
+def test_sorted_methods():
+    from tensordict.tensorclass import (
+        _FALLBACK_METHOD_FROM_TD,
+        _FALLBACK_METHOD_FROM_TD_FORCE,
+        _FALLBACK_METHOD_FROM_TD_NOWRAP,
+        _METHOD_FROM_TD,
+    )
+
+    lists_to_check = [
+        _FALLBACK_METHOD_FROM_TD_NOWRAP,
+        _METHOD_FROM_TD,
+        _FALLBACK_METHOD_FROM_TD_FORCE,
+        _FALLBACK_METHOD_FROM_TD,
+    ]
+    # Check that each list is sorted and has unique elements
+    for lst in lists_to_check:
+        assert lst == sorted(lst), f"List {lst} is not sorted"
+        assert len(lst) == len(set(lst)), f"List {lst} has duplicate elements"
+    # Check that no two lists share any elements
+    for i, lst1 in enumerate(lists_to_check):
+        for j, lst2 in enumerate(lists_to_check):
+            if i != j:
+                shared_elements = set(lst1) & set(lst2)
+                assert (
+                    not shared_elements
+                ), f"Lists {lst1} and {lst2} share elements: {shared_elements}"
+
+
 def _make_data(shape):
     return MyData(
         X=torch.rand(*shape),