[Feature] Add TensorDict storage #826
Conversation
I did a very quick initial review but overall it looks marvellous!
Do we need all of these class attributes to be public?
Can we have a short example in each docstring?
    expands as necessary.
    """

    def __init__(self, default_tensor: torch.Tensor):
The default could be a TensorDict, no?
We can switch to TensorDict, no strong feeling.
My thought process was that torch.Tensor is more granular than TensorDict. In particular, we may have different storages for different keys of a TensorDict.
Note that we can have a DynamicStorage of TensorDict with TensorDictStorage.

    """

    def __init__(self, default_tensor: torch.Tensor):
        self.tensor_dict: Dict[int, torch.Tensor] = {}
Same here, we could use a TensorDict?
    def clear(self) -> None:
        self.tensor_dict.clear()

    def __getitem__(self, indices: torch.Tensor) -> torch.Tensor:
Should we put in some checks? E.g., are the indices tensors? Are they 1d? Are they long type?
Added the 1d check. On the long type (or int type), I am not sure what the right decision is. Float can be supported, but it is not safe and can be error-prone.
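(For illustration, a minimal sketch of the kind of validation being discussed; the helper name is hypothetical and not part of the PR:)

import torch

def _check_indices(indices: torch.Tensor) -> None:
    # Hypothetical validation helper: reject non-tensor, non-1d, or floating-point indices.
    if not isinstance(indices, torch.Tensor):
        raise TypeError(f"indices must be a torch.Tensor, got {type(indices)}")
    if indices.ndim != 1:
        raise ValueError(f"indices must be 1-dimensional, got ndim={indices.ndim}")
    if indices.dtype not in (torch.int32, torch.int64):
        raise TypeError(f"indices must have an integer dtype, got {indices.dtype}")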
I am not sure if I understood the question. If you mean abstraction over the storage, yes, the list of methods is extracted from what I need in the MCTS implementation.
Sure
I just mean that in all these classes, any attribute we think the public will not / should not care about should be named with a leading underscore, to let us change or remove them at will.
Only methods defined in the TensorStorage interface should be public; the rest is private. Regarding classes, I have developed two types of storage (DynamicStorage / FixedStorage). For MCTS, we only need DynamicStorage.
        raise NotImplementedError


class DynamicStorage(TensorStorage[torch.Tensor, torch.Tensor]):
What about making this or the parent class an nn.Module? I'm suggesting this because there is a tensor (default_tensor) in it, and it may be useful to register it as a buffer such that calling dyn_storage.to("cuda") sends the tensor to the device.
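(For illustration, a minimal sketch of that suggestion; the class name is hypothetical. Registering the tensor as a buffer lets .to() move it with the module:)

import torch
from torch import nn

class BufferedDynamicStorage(nn.Module):
    # Hypothetical sketch: default_tensor is registered as a buffer so that
    # storage.to("cuda") (or .to(dtype)) moves it along with the module.
    def __init__(self, default_tensor: torch.Tensor):
        super().__init__()
        self.register_buffer("default_tensor", default_tensor)

storage = BufferedDynamicStorage(torch.zeros((1,)))
# storage = storage.to("cuda")  # would place default_tensor on the GPU if available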
    Examples:
        >>> storage = DynamicStorage(default_tensor=torch.zeros((1,)))
        >>> index = torch.randn((3,))
        >>> value = torch.rand((2, 1))
Do you want to broadcast this?
I changed the first dim such that it matches the index's, but if we want broadcasting, we should make sure it works that way.
Thanks for catching this.
I think there is a typo here; I checked the unit test as well and it has to be
value = torch.rand((3, 1))
I made some improvements in da61af8, LMK if you think some should be reverted. I wonder if something like this would work:

>>> import torch
>>> from tensordict import TensorDict
>>> from typing import cast
>>> query_module = QueryModule(
...     in_keys=["key1", "key2"],
...     index_key="index",
...     hash_module=SipHash(),
... )
>>> embedding_storage = DynamicStorage(
...     default_tensor=torch.zeros((1,)),
... )
>>> tensor_dict_storage = TensorDictStorage(
...     query_module=query_module,
...     # key_to_storage={"index": embedding_storage},  # We know it's "index" since there is only one index key
...     key_to_storage=embedding_storage,
... )
>>> index = TensorDict(
...     {
...         "key1": torch.Tensor([[-1], [1], [3], [-3]]),
...         "key2": torch.Tensor([[0], [2], [4], [-4]]),
...     },
...     batch_size=(4,),
... )
>>> # value = TensorDict({"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,))  # We know it's "index" since there is only one index key
>>> # Use this instead
>>> value = torch.Tensor([[10], [20], [30], [40]])
>>> tensor_dict_storage[index] = value
>>> # indexing returns a single tensor / tensordict since there is only one key
>>> tensor_dict_storage[index]
Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False)

But maybe I'm missing the point, and in most cases we'll have way more than one index key?
Great work, thanks for the edits.
I think there is some confusion because of my bad naming; I should not have used "index". Let me know if this makes sense?
Not sure I'm following, sorry. Can you maybe give an example where "index" is used properly (as you say it's bad naming) and where there are more keys in key_to_storage?
This is an example of how key_to_storage allows us to have different storages with different output keys:

query_module = QueryModule(
    in_keys=["index_key1", "index_key2"],
    index_key="index",
    hash_module=SipHash(),
)
emd_1d = FixedStorage(
    torch.nn.Embedding(num_embeddings=50, embedding_dim=1),
)
emd_2d = FixedStorage(
    torch.nn.Embedding(num_embeddings=100, embedding_dim=2),
)
tensor_dict_storage = TensorDictStorage(
    query_module=query_module,
    key_to_storage={"out_key1": emd_1d, "out_key2": emd_2d},
)
index = TensorDict(
    {
        "index_key1": torch.Tensor([[-1], [1], [3], [-3]]),
        "index_key2": torch.Tensor([[0], [2], [4], [-4]]),
    },
    batch_size=(4,),
)
value = TensorDict(
    {
        "out_key1": torch.Tensor([[-1], [1], [3], [-3]]),
        "out_key2": torch.Tensor([[0, 1], [1, 2], [2, 3], [3, 4]]),
    },
    batch_size=(4,),
)
tensor_dict_storage[index] = value

In this case, we have to use different storages because the dim is forced to be the same for every index when we use FixedStorage.
Ok got it, I'll play a bit with it to make sure I fully grasp that feature and come back to you!
Ok @mjlaali, here is how I would pitch this feature: a (bijective) map from a tensordict used as an index to a tensordict of stored values.

The bijective part is important: we want that for a specific tensordict (e.g. a value of "observation" and a value of "action") we get one tensor out (with entries "a", "b" and "c" always having corresponding values).

For this feature to work we need to parse the tensordict used as index into a "real" long tensor index. We provide a SipHash class to do that, along with a query module. This is our TensorDict2Index pipeline.

Then we need to find a way to store values and associate them with a long-tensor index. We can do that via DynamicStorage, which writes a tensor for each index value (we assume that the index is a single integer).

The most crucial thing I can think of right now is what the limitations and assumptions are, and in which situations this will fail to be efficient.

After giving a thorough look at all this, I think that the name TensorDictMap is more appropriate. About that, I also think that the term "storage" is a bit overloaded here, as DynamicStorage is about the Index2TensorDict part whereas TensorDictStorage refers to the whole pipeline.
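(For illustration, a rough standalone sketch of that two-stage pipeline, using a plain dict and Python's built-in hash in place of the PR's QueryModule/SipHash/DynamicStorage classes; all names below are hypothetical:)

import torch
from tensordict import TensorDict

def td_to_index(td: TensorDict, in_keys) -> torch.Tensor:
    # "TensorDict2Index": hash the selected entries row by row into a long index
    # (Python's hash() of the raw bytes stands in for SipHash here).
    rows = torch.cat([td.get(k).reshape(td.batch_size[0], -1) for k in in_keys], dim=-1)
    return torch.tensor([hash(row.numpy().tobytes()) for row in rows], dtype=torch.long)

class SimpleDynamicStorage:
    # "Index2TensorDict" part: a dict keyed by the integer index, one stored value per row.
    def __init__(self):
        self._data = {}

    def __setitem__(self, index: torch.Tensor, value: torch.Tensor) -> None:
        for i, idx in enumerate(index.tolist()):
            self._data[idx] = value[i]

    def __getitem__(self, index: torch.Tensor) -> torch.Tensor:
        return torch.stack([self._data[idx] for idx in index.tolist()])

td = TensorDict({"observation": torch.randn(4, 3), "action": torch.randn(4, 2)}, batch_size=[4])
index = td_to_index(td, ["observation", "action"])  # one integer per (obs, action) pair
storage = SimpleDynamicStorage()
storage[index] = torch.arange(4.0).unsqueeze(-1)
assert torch.equal(storage[index], torch.arange(4.0).unsqueeze(-1))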
    Examples:
        >>> # The following example requires torchrl and gymnasium to be installed
        >>> from tensordict.nn.storage import TensorDictStorage, RandomProjectionHash
        >>> from torchrl.envs import GymEnv
        >>> env = GymEnv("CartPole-v1")
        >>> rollout = env.rollout(100)
        >>> source, dest = rollout.exclude("next"), rollout.get("next")
        >>> storage = TensorDictStorage.from_tensordict_pair(
        ...     source, dest,
        ...     in_keys=["observation", "action"],
        ... )
        >>> # maps the (obs, action) tuple to a corresponding next state
        >>> storage[source] = dest
        >>> storage[source]
        TensorDict(
            fields={
                done: Tensor(shape=torch.Size([35, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([35, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([35, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([35, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([35, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([35]),
            device=None,
            is_shared=False)
I'd like to have a better, standalone example, but this one is nice because it's a bit "real-life" and makes the point clear about the storage.
cc @dtsaras for info
The bijective function is an interesting view of TensorDictStorage, though originally the reason I needed it in MCTS was to have a way to manipulate the output of a function (not necessarily a bijective one). TensorDictModule is also a function, but it is difficult to change its output. I needed something like a memory/storage to save, retrieve, and update a value.
This is true for this implementation, but I hope we can extend it in the future so that we can support KNN search with the same abstraction, such as retrieving the K closest items to the input. This can be quite useful for RAG with LLMs.
I like these names, they are more intuitive 👍
To me, the biggest limitation is that it does not allow the gradient to flow to memory, at least in this implementation. The second issue could be performance-related: in particular, DynamicStorage and SipHash have both been implemented in CPU memory, so tensors have to move from the GPU. (SipHash could have been implemented with PyTorch operations, so this can be alleviated in the future.) Finally, concurrency is another issue, especially in distributed training jobs.
This is a good question; it would be quite interesting to see if we can wrap FAISS in this API.
While I originally thought about this API as a memory/storage API (the main added value is the ease of updating a storage/memory value), I have no strong opinion here. Either name works.
    ):
        super().__init__()
        from sklearn.random_projection import (
            GaussianRandomProjection,
Is this the same as an MLP where the weights are initialized from a Gaussian distribution?
No, I believe there is some fitting to the data involved, where it's trying to find the best separation for your examples using the projection.
I did a quick check, and GaussianRandomProjection does not seem to use X in the fit function; it only uses it to get the input shape:
https://github.com/scikit-learn/scikit-learn/blob/70fdc843a4b8182d97a3508c1a426acc5e87e980/sklearn/random_projection.py#L364
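(For reference, a small snippet illustrating that behaviour, assuming scikit-learn is installed: fit only reads the input dimensionality and draws a random Gaussian matrix, so the data values themselves don't shape the projection.)

import numpy as np
from sklearn.random_projection import GaussianRandomProjection

X = np.random.randn(8, 32)
proj = GaussianRandomProjection(n_components=4, random_state=0)
proj.fit(X)                    # only the number of features (32) is taken from X
print(proj.components_.shape)  # (4, 32): random Gaussian matrix, independent of X's values
Y = proj.transform(X)          # the projection is just X @ components_.T
print(Y.shape)                 # (8, 4)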
        return self.flag[index]


class LazyFixedStorage(FixedStorage):
Does it make sense to convert the FixedStorage implementation such that by default it is lazy-initialized?
Is there any use case where we don't want lazy initialization?
I had the same thought, and actually that train of thought went even further... let me elaborate.
I refactored FixedStorage a bit such that we don't use modulo to get the index transform but a dict that maps the hash value to an int from 0 to M-1, where M is the max number of items we are going to store.
Also, I think we don't need to rely on the Embedding class; we could just have a single tensor (akin to the weight that we're modifying) and assign values to it. That would make one less "code dependency" and clarify things IMO.
Now let's put these things together: (1) we need to convert any sparse index from 0 to N-1 with N >> M to an index from 0 to M-1, and (2) once we have that "dense" index we need to write to a storage (possibly lazily).
That looks a lot like torchrl's TensorStorage. The DynamicStorage looks a lot like torchrl's ListStorage (provided that you add the index-to-index transform explained in (1)).
So I'm starting to wonder: we are coding something with the purpose of applying it to MCTS, and recoding something very similar to what we have in torchrl (at least the second block, the Index2TensorDict part). What's missing in torchrl is
- The TensorDict2Index (how to map a tensordict to a hash value)
- How to map the hash value to an index from 0 to M-1
but we have the storage part, and it's pretty neat (works with prioritized values, N dimensions, arbitrary pytrees...).
Not sure what the way forward is, but here are various versions of what we could do:
(1) Make sure that whatever API we propose here matches TorchRL's storage API, such that we can have a minimal number of storages coded here and just use TorchRL's when we need to.
(2) Move things to torchrl and make an API that leverages the storages but provides more flexibility in terms of indexing. Something like:

num_elts = 100
t2t_storage = TD2TDStorage(
    storage=LazyTensorStorage(num_elts),
    in_keys=["obs", "action"],
    out_keys=None,
    map=TensorToIndex())  # out_keys dynamically determined

Option (1) would be a bit weird if we don't do things carefully, because there will be some implicit dependency on TorchRL's API that will make little sense here, but it could make sense if we have other usages of this kind of storage than MCTS (can you name any?).
Option (2) would make more sense if we are only thinking about MCTS.
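(For illustration, a rough sketch of the two pieces described above: a dict mapping arbitrary hash values to dense slots in [0, M), plus a single preallocated value tensor. The class and names are hypothetical, not TorchRL's actual API:)

import torch

class DenseIndexStorage:
    # (1) a dict maps sparse hash values to dense slots 0..M-1;
    # (2) values live in one preallocated tensor instead of an nn.Embedding.
    def __init__(self, max_items: int, value_dim: int):
        self._hash_to_slot = {}
        self._values = torch.zeros(max_items, value_dim)

    def _slots(self, hashes: torch.Tensor, create: bool) -> torch.Tensor:
        slots = []
        for h in hashes.tolist():
            if create and h not in self._hash_to_slot:
                self._hash_to_slot[h] = len(self._hash_to_slot)  # next free slot (no overflow check)
            slots.append(self._hash_to_slot[h])
        return torch.tensor(slots, dtype=torch.long)

    def __setitem__(self, hashes: torch.Tensor, values: torch.Tensor) -> None:
        self._values[self._slots(hashes, create=True)] = values

    def __getitem__(self, hashes: torch.Tensor) -> torch.Tensor:
        return self._values[self._slots(hashes, create=False)]

storage = DenseIndexStorage(max_items=100, value_dim=2)
hashes = torch.tensor([987654321, 123456789, 555])  # sparse hash values with N >> M
storage[hashes] = torch.randn(3, 2)
print(storage[hashes].shape)  # torch.Size([3, 2])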
ListStorage is pretty neat, I missed it in TorchRL.
My typical strategy is to optimize for simplicity; in this case, having the code in torchrl makes many decisions simpler. So my vote would be to move this PR to TorchRL. (Let me know if you want to close this PR here and re-open it in TorchRL.)
At this moment, I don't have any use case other than MCTS. In the future, if the same abstraction can be used to implement K-NN search, then we may have other use cases (e.g. RAG), though I would not optimize for this potential use case.
Closing in favor of pytorch/rl#2283
Description
TensorDict storage is a utility that allows using external storage (e.g. storage in CPU memory) within other modules in torch.
Motivation and Context
Why is this change required? What problem does it solve?
This helps to implement algorithms such as MCTS (see https://github.com/mjlaali/torchrl_mcts) or RAG modules.
cc @shagunsodhani @dtsaras