diff --git a/benchmarks/distributed/dataloading.py b/benchmarks/distributed/dataloading.py index 5d1777b57..0c8b03cdf 100644 --- a/benchmarks/distributed/dataloading.py +++ b/benchmarks/distributed/dataloading.py @@ -28,8 +28,7 @@ import torch import tqdm -from tensordict import MemoryMappedTensor -from tensordict.prototype import tensorclass +from tensordict import MemoryMappedTensor, tensorclass from torch import multiprocessing as mp, nn from torch.distributed import rpc from torch.utils.data import DataLoader diff --git a/benchmarks/tensorclass/test_tensorclass_speed.py b/benchmarks/tensorclass/test_tensorclass_speed.py index 0c8bf8c7d..452f6b86a 100644 --- a/benchmarks/tensorclass/test_tensorclass_speed.py +++ b/benchmarks/tensorclass/test_tensorclass_speed.py @@ -9,7 +9,7 @@ import pytest import torch -from tensordict.prototype import tensorclass +from tensordict import tensorclass @tensorclass diff --git a/benchmarks/tensorclass/test_torch_functions.py b/benchmarks/tensorclass/test_torch_functions.py index 0dc3c560c..7ed717872 100644 --- a/benchmarks/tensorclass/test_torch_functions.py +++ b/benchmarks/tensorclass/test_torch_functions.py @@ -6,7 +6,7 @@ import pytest import torch -from tensordict.prototype import tensorclass +from tensordict import tensorclass @tensorclass