diff --git a/tensordict/nn/storage.py b/tensordict/nn/storage.py index e7d58cead..6bad8870b 100644 --- a/tensordict/nn/storage.py +++ b/tensordict/nn/storage.py @@ -97,7 +97,7 @@ def __init__( def clear(self): self.init_fm(self.embedding.weight) - self.flag = torch.zeros(size=(self.embedding.num_embeddings, 1)).to(torch.int64) + self.flag = torch.zeros((self.embedding.num_embeddings, 1), dtype=torch.int64) def to_index(self, item: torch.Tensor) -> torch.Tensor: return torch.remainder(item.to(torch.int64), self.num_embedding).to(torch.int64)