From b493178d09b0107e02119ff4eacb72be2645e712 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 9 Jan 2025 18:18:04 +0000 Subject: [PATCH] [BE] Check ordering and exclusivity of tensorclass registers ghstack-source-id: 3dc907f4dd3047238adb0bb309d9ae75d24c5085 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1176 --- test/test_tensorclass.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 48711851c..01a5aec37 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -125,7 +125,6 @@ def test_tensorclass_stub_methods(): ] if missing_methods: - raise Exception(f"Missing methods in tensorclass.pyi: {missing_methods}") raise Exception( f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}" ) @@ -156,6 +155,34 @@ class X: ) +def test_sorted_methods(): + from tensordict.tensorclass import ( + _FALLBACK_METHOD_FROM_TD, + _FALLBACK_METHOD_FROM_TD_FORCE, + _FALLBACK_METHOD_FROM_TD_NOWRAP, + _METHOD_FROM_TD, + ) + + lists_to_check = [ + _FALLBACK_METHOD_FROM_TD_NOWRAP, + _METHOD_FROM_TD, + _FALLBACK_METHOD_FROM_TD_FORCE, + _FALLBACK_METHOD_FROM_TD, + ] + # Check that each list is sorted and has unique elements + for lst in lists_to_check: + assert lst == sorted(lst), f"List {lst} is not sorted" + assert len(lst) == len(set(lst)), f"List {lst} has duplicate elements" + # Check that no two lists share any elements + for i, lst1 in enumerate(lists_to_check): + for j, lst2 in enumerate(lists_to_check): + if i != j: + shared_elements = set(lst1) & set(lst2) + assert ( + not shared_elements + ), f"Lists {lst1} and {lst2} share elements: {shared_elements}" + + def _make_data(shape): return MyData( X=torch.rand(*shape),