
[Feature] Add TensorDict storage #826

Closed · wants to merge 18 commits

Conversation

@mjlaali mjlaali commented Jun 24, 2024

Description

TensorDict storage is a utility that allows external storage (e.g. storage in CPU memory) to be used within other torch modules.

Motivation and Context

Why is this change required? What problem does it solve?
This helps to implement algorithms such as MCTS (see https://github.com/mjlaali/torchrl_mcts) or RAG modules.

cc @shagunsodhani @dtsaras

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 24, 2024
@vmoens vmoens added the enhancement New feature or request label Jun 24, 2024

@vmoens vmoens left a comment

I did a very quick initial review but overall it looks marvellous!

Do we need all of these class attributes to be public?

Can we have a short example in each docstring?

    expands as necessary.
    """

    def __init__(self, default_tensor: torch.Tensor):
Contributor

The default could be a TensorDict, no?

@mjlaali mjlaali Jun 24, 2024

We can switch to TensorDict, I have no strong feeling either way.

My thought process was that torch.Tensor is more granular than TensorDict; in particular, we may want different storage for different keys of a TensorDict.

Note that we can have a DynamicStorage of TensorDicts with TensorDictStorage.

"""

def __init__(self, default_tensor: torch.Tensor):
self.tensor_dict: Dict[int, torch.Tensor] = {}
Contributor

Same here, we could use a TensorDict?

    def clear(self) -> None:
        self.tensor_dict.clear()

    def __getitem__(self, indices: torch.Tensor) -> torch.Tensor:
Contributor

Should we put in some checks, e.g. are the indices tensors? Are they 1d? Are they of long type?

Author

Added the 1d check. On the long (or int) type, I am not sure what the right decision is: float could be supported, but it is not safe and can be error-prone.
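
As a rough sketch of the checks discussed here (a hypothetical helper, not code from this PR; the dtype policy is only a suggestion):

import torch

def _check_indices(indices: torch.Tensor) -> None:
    # Hypothetical validation mirroring the review discussion above.
    if not isinstance(indices, torch.Tensor):
        raise TypeError(f"indices must be a torch.Tensor, got {type(indices)}")
    if indices.ndim != 1:
        raise ValueError(f"indices must be 1d, got shape {tuple(indices.shape)}")
    if indices.dtype not in (torch.int32, torch.int64):
        # Float indices are error-prone; require an integer dtype.
        raise TypeError(f"indices must be of integer dtype, got {indices.dtype}")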

@mjlaali mjlaali commented Jun 25, 2024

Do we need all of these class attributes to be public?

I am not sure I understood the question. If you mean the abstraction over the storage: yes, the list of methods is extracted from what I need in the MCTS implementation.

Can we have a short example in each docstring?

Sure

@vmoens vmoens commented Jun 25, 2024

@mjlaali

I am not sure I understood the question. If you mean the abstraction over the storage: yes, the list of methods is extracted from what I need in the MCTS implementation.

I just mean that in all these classes, any attribute we think the public will / should not care about should be named with a leading underscore, to let us change / remove them at will.
Making all attributes public is stringent, as it means that we will need deprecation warnings every time we want to change the internals of a class.

@mjlaali mjlaali commented Jun 25, 2024

Only the methods defined in the TensorStorage interface should be public; the rest are private.

Regarding classes, I have developed two types of storage (Dynamic/Fixed Storage). For MCTS, we only need DynamicStorage.

        raise NotImplementedError


class DynamicStorage(TensorStorage[torch.Tensor, torch.Tensor]):
Contributor

What about making this or the parent class an nn.Module?
I'm suggesting this because there is a tensor default_tensor in it, and it may be useful to register it as a buffer such that calling dyn_storage.to("cuda") sends the tensor to the device.
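
A minimal sketch of that suggestion (assuming a DynamicStorage-like class; BufferedDynamicStorage is a made-up name, not part of this PR):

import torch
from torch import nn

class BufferedDynamicStorage(nn.Module):
    # Registering default_tensor as a buffer means storage.to("cuda")
    # (or .to(dtype)) moves it together with the rest of the module.
    def __init__(self, default_tensor: torch.Tensor):
        super().__init__()
        self.register_buffer("default_tensor", default_tensor)
        self._tensor_dict: dict[int, torch.Tensor] = {}

storage = BufferedDynamicStorage(torch.zeros((1,)))
# storage.to("cuda")  # default_tensor would now live on the GPU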

    Examples:
        >>> storage = DynamicStorage(default_tensor=torch.zeros((1,)))
        >>> index = torch.randn((3,))
        >>> value = torch.rand((2, 1))
Contributor

Do you want to broadcast this?
I changed the first dim such that it matches the index's, but if we want broadcasting we should make sure it works that way.

Author

Thanks for catching this.
I think there is a typo here; I checked the unit test as well, and it has to be:
value = torch.rand((3, 1))

@vmoens vmoens commented Jun 26, 2024

I made some improvements in da61af8, LMK if you think some should be reverted.

I wonder if TensorDictStorage could not be simplified if there is only one index key in the query_module:

        >>> import torch
        >>> from tensordict import TensorDict
        >>> from tensordict.nn.storage import QueryModule, SipHash, DynamicStorage, TensorDictStorage
        >>> query_module = QueryModule(
        ...     in_keys=["key1", "key2"],
        ...     index_key="index",
        ...     hash_module=SipHash(),
        ... )
        >>> embedding_storage = DynamicStorage(
        ...     default_tensor=torch.zeros((1,)),
        ... )
        >>> tensor_dict_storage = TensorDictStorage(
        ...     query_module=query_module,
        ...     # key_to_storage={"index": embedding_storage},  # We know it's "index" since there is only one index key
        ...     key_to_storage=embedding_storage,
        ... )
        >>> index = TensorDict(
        ...     {
        ...         "key1": torch.Tensor([[-1], [1], [3], [-3]]),
        ...         "key2": torch.Tensor([[0], [2], [4], [-4]]),
        ...     },
        ...     batch_size=(4,),
        ... )
        >>> # value = TensorDict({"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,))  # We know it's "index" since there is only one index key
        >>> # Use this instead
        >>> value = torch.Tensor([[10], [20], [30], [40]])
        >>> tensor_dict_storage[index] = value
        >>> # indexing returns a single tensor / tensordict since there is only one key
        >>> tensor_dict_storage[index]
        Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False)

but maybe I'm missing the point and in most cases we'll have way more than one index key?

@mjlaali mjlaali commented Jun 27, 2024

I made some improvements in da61af8, LMK if you think some should be reverted.

Great work, thanks for the edits.

I wonder if TensorDictStorage could not be simplified if there is only one index key in the query_module:

I think there is some confusion because of my bad naming; I should not have used "index" in the example.

In particular, key_to_storage is not related to query_module. key_to_storage is about the output tensordict (the storage of each output tensordict key) and query_module is about the input tensordict (how to create indices from the input tensordict).

Let me know if this makes sense.

@vmoens vmoens commented Jun 27, 2024

In particular, key_to_storage is not related to query_module. key_to_storage is about the output tensordict (the storage of each output tensordict key) and query_module is about the input tensordict (how to create indices from the input tensordict).

Let me know if this makes sense.

Not sure I'm following, sorry. Can you maybe give an example where "index" is used properly (as you say it's bad naming) and where there are more keys in key_to_storage, such that it's apparent why we need a dict?

@mjlaali mjlaali commented Jun 28, 2024

Here is an example of how key_to_storage allows us to have a different storage for each output key.

import torch
from tensordict import TensorDict
from tensordict.nn.storage import QueryModule, FixedStorage, TensorDictStorage, SipHash

query_module = QueryModule(
    in_keys=["index_key1", "index_key2"],
    index_key="index",
    hash_module=SipHash(),
)

emd_1d = FixedStorage(
    torch.nn.Embedding(num_embeddings=50, embedding_dim=1),
)
emd_2d = FixedStorage(
    torch.nn.Embedding(num_embeddings=100, embedding_dim=2),
)

tensor_dict_storage = TensorDictStorage(
    query_module=query_module,
    key_to_storage={"out_key1": emd_1d, "out_key2": emd_2d},
)
index = TensorDict(
    {
        "index_key1": torch.Tensor([[-1], [1], [3], [-3]]),
        "index_key2": torch.Tensor([[0], [2], [4], [-4]]),
    },
    batch_size=(4,),
)
value = TensorDict(
    {
        "out_key1": torch.Tensor([[-1], [1], [3], [-3]]),
        "out_key2": torch.Tensor([[0, 1], [1, 2], [2, 3], [3, 4]]),
    },
    batch_size=(4,),
)
tensor_dict_storage[index] = value

In this case, we have to use different storages because the dim is forced to be the same for every index when we use FixedStorage.
DynamicStorage would not have this issue, though we would still have to internally keep a different storage per output key (which is what key_to_storage provides).
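
For completeness, a sketch of the retrieval side of the example above (assuming __getitem__ mirrors __setitem__ and returns a TensorDict keyed like key_to_storage):

retrieved = tensor_dict_storage[index]
# One entry per key_to_storage key, each served by its own storage:
assert set(retrieved.keys()) == {"out_key1", "out_key2"}
assert retrieved["out_key1"].shape == (4, 1)  # from emd_1d (embedding_dim=1)
assert retrieved["out_key2"].shape == (4, 2)  # from emd_2d (embedding_dim=2)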

@vmoens vmoens commented Jul 1, 2024

Ok got it, I'll play a bit with it to make sure I fully grasp that feature and come back to you!

@vmoens vmoens commented Jul 4, 2024

Ok @mjlaali, I think I would pitch this feature like this:
TensorDictStorage allows you to use a tensordict as an index, to build a bijective function

TensorDict -> TensorDict

The bijective part is important: for a specific tensordict (e.g. a value of "observation" and a value of "action") we want to get one tensordict out (with entries "a", "b" and "c" always having corresponding values).
In principle the opposite should also hold: for every {"a", "b", "c"} combination there should be one associated {"observation", "action"} pair.
I think that the example should have 2 entries on one side and 3 entries on the other, to make it clear that it's not that input0 points to output0 and input1 to output1, but that (input0, input1) points to (output0, output1).

For this feature to work we need to parse the tensordict used as index into a "real" long tensor index. We provide a SipHash class to do that along with a query module. This is our TensorDict2Index pipeline.

Then we need to find a way to store values and associate them with a long-tensor index. We can do that via DynamicStorage, which writes a tensor for each index value (we assume that the index is a single integer).
A FixedStorage relies on nn.Embedding and will write on the weights of that module.
We can have multiple storages associated with our module, and if we have a nested key associated with each of these we can combine the results in a TensorDict.
This is the Index2TensorDict part. Joining TensorDict2Index and Index2TensorDict gives us a TensorDict2TensorDict map.
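
As a rough, self-contained illustration of that two-stage composition (plain dicts of tensors stand in for TensorDict, a Python dict stands in for the hash table, and none of these helper names come from the PR):

import torch

def tensordict_to_index(td, in_keys, hash_table):
    # TensorDict2Index: hash each row of the concatenated in_keys into a long index.
    flat = torch.cat([td[k].reshape(td[k].shape[0], -1) for k in in_keys], dim=-1)
    idx = [hash_table.setdefault(tuple(row.tolist()), len(hash_table)) for row in flat]
    return torch.tensor(idx, dtype=torch.long)

def index_to_tensordict(index, key_to_storage):
    # Index2TensorDict: one storage per output key, gathered back into a dict.
    return {key: torch.stack([store[i] for i in index.tolist()])
            for key, store in key_to_storage.items()}

# Composed, this is the TensorDict -> TensorDict map described above.
hash_table = {}
key_to_storage = {"a": {}, "b": {}}
td_in = {"observation": torch.randn(4, 3), "action": torch.randn(4, 1)}
index = tensordict_to_index(td_in, ["observation", "action"], hash_table)
for key, store in key_to_storage.items():
    for i in index.tolist():
        store[i] = torch.randn(2)  # write step: associate a value with each index
td_out = index_to_tensordict(index, key_to_storage)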

The most crucial thing I can think of right now is what the limitations and assumptions are, and in which situations this will fail to be efficient.
Are we limited in size, i.e. can we use billions of tensordicts as indices or will this break? Which class will work best in a large-dimensional setting? This info should be put in the docstrings (like: "this module is particularly fit for settings where there are fewer than X observations" or "this module is particularly fit for settings where the observations have fewer than X elements").
What are the constraints in terms of input dimension? Can I use an image as in_keys in my tensordict storage? Should we just flatten it? If we have images, maybe we should also give people the opportunity to write their own embedding (maybe it's trivial using TensorDictModules...).

After giving a thorough look at all this I think that TensorDictMap is more appropriate:

  1. There seems to be some confusion about storage, because we use Storage in the Index2TensorDict part but the main module is called TensorDictStorage.
  2. The term storage in pytorch usually refers to the content of a tensor. Since we now have tensordict.consolidate() which plays a lot with the storage of its content, using the term storage here may be confusing.

About that, I also think that the term is a bit overloaded here as DynamicStorage is about the Index2TensorDict part whereas TensorDictStorage refers to the whole TensorDict2TensorDict pipeline. I think that using Map could clarify that.

Comment on lines +585 to +608

    Examples:
        >>> # The following example requires torchrl and gymnasium to be installed
        >>> from tensordict.nn.storage import TensorDictStorage, RandomProjectionHash
        >>> from torchrl.envs import GymEnv
        >>> env = GymEnv("CartPole-v1")
        >>> rollout = env.rollout(100)
        >>> source, dest = rollout.exclude("next"), rollout.get("next")
        >>> storage = TensorDictStorage.from_tensordict_pair(
        ...     source, dest,
        ...     in_keys=["observation", "action"],
        ... )
        >>> # maps the (obs, action) tuple to a corresponding next state
        >>> storage[source] = dest
        >>> storage[source]
        TensorDict(
            fields={
                done: Tensor(shape=torch.Size([35, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([35, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([35, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([35, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([35, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([35]),
            device=None,
            is_shared=False)
Contributor

I'd like to have a better, standalone example, but this one is nice because it's a bit "real-life" and makes the point about the storage clear.

cc @dtsaras for info

@mjlaali mjlaali commented Jul 6, 2024

TensorDictStorage allows you to use a tensordict as an index, to build a bijective function

The bijective function is an interesting view of TensorDictStorage, though originally the reason I needed it in MCTS was to have a way to manipulate the output of a function (not necessarily a bijective one).

TensorDictModule is also a function, but it is difficult to change its output. I needed something like a memory/storage to save, retrieve, and update a value.

For this feature to work we need to parse the tensordict used as index into a "real" long tensor index.

This is true for this implementation, but I hope we can extend this in the future so that we can support KNN search with the same abstraction, such as retrieving the K-closest items to the input. This can be quite useful for RAG LLM.

TensorDict2Index and Index2TensorDict

I like these names, they are more intuitive 👍

The most crucial thing I can think of right now is what are the limitations and assumptions, and in which situations will this fail to be efficient.

To me, the biggest limitation is that it does not allow gradients to flow to memory, at least in this implementation. The second issue could be performance-related: in particular, DynamicStorage and SipHash are both implemented in CPU memory, so tensors have to be moved off the GPU. (SipHash could have been implemented with PyTorch operations, so this can be alleviated in the future.)

Finally, concurrency is another issue, especially in distributed training jobs.
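
On the SipHash point above, a toy illustration of what a hash computed purely with PyTorch ops (so it can stay on the device) could look like; this is not SipHash and not code from this PR, just a polynomial rolling hash over float32 inputs:

import torch

def toy_rolling_hash(x: torch.Tensor, prime: int = 1_000_003, mod: int = 2**31 - 1) -> torch.Tensor:
    # Reinterpret float32 bits as integers so that equal inputs hash equally.
    bits = x.reshape(x.shape[0], -1).contiguous().view(torch.int32).to(torch.int64)
    h = torch.zeros(bits.shape[0], dtype=torch.int64, device=x.device)
    for col in bits.unbind(dim=1):
        h = (h * prime + col) % mod  # stays on the same device as x
    return h

hashes = toy_rolling_hash(torch.randn(4, 8))  # one integer hash per row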

Are we limited in size, ie can we use billions of tensordicts as indices or will this break?

This is a good question; it would be quite interesting to see if we can wrap FAISS in this API.

After giving a thorough look at all this I think that TensorDictMap is more appropriate

While I originally thought about this API as a memory/storage API (the main added value is the ease of updating a storage/memory value), I have no strong opinion here. Either name works.

    ):
        super().__init__()
        from sklearn.random_projection import (
            GaussianRandomProjection,
Author

Is this the same as an MLP where the weights are initialized from a Gaussian distribution?

Contributor

No, I believe there is some fitting to the data involved, where it's trying to find the best separation for your examples using the projection.

Author

I did a quick check, and GaussianRandomProjection does not seem to use X in the fit function; it only uses it to get the input shape:
https://github.com/scikit-learn/scikit-learn/blob/70fdc843a4b8182d97a3508c1a426acc5e87e980/sklearn/random_projection.py#L364
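
A quick way to check that reading (the fitted projection depends only on the input shape and random_state, not on the data itself):

import numpy as np
from sklearn.random_projection import GaussianRandomProjection

x1 = np.random.randn(100, 64)
x2 = np.random.randn(100, 64)  # same shape, different data

p1 = GaussianRandomProjection(n_components=8, random_state=0).fit(x1)
p2 = GaussianRandomProjection(n_components=8, random_state=0).fit(x2)

# Identical projection matrices: fit() only reads the shape of X.
assert np.allclose(p1.components_, p2.components_)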

        return self.flag[index]


class LazyFixedStorage(FixedStorage):
Author

Does it make sense to change the FixedStorage implementation such that by default it is lazily initialized?

Is there any use case where we don't want lazy initialization?

Contributor

I had the same thought, and actually that train of thought went even further... let me elaborate.

I refactored FixedStorage a bit such that we don't use a modulo to get the index transform, but a dict that maps the hash value to an int from 0 to M-1, where M is the max number of items we are going to store.

Also I think we don't need to rely on the Embedding class; we could just have a single tensor (akin to the weight that we're modifying) and assign values to it. That would make one less "code dependency" and clarify things IMO.

Now let's put these things together: (1) we need to convert any sparse index from 0 to N-1 (with N >> M) to an index from 0 to M-1, and (2) once we have that "dense" index we need to write to a storage (possibly lazily).
That looks a lot like torchrl's TensorStorage. The DynamicStorage looks a lot like torchrl's ListStorage (provided that you add the index-to-index transform explained in (1)).

So I'm starting to wonder: we are coding something with the purpose of applying it to MCTS, and recoding something very similar to what we have in torchrl (at least the second block, the Index2TensorDict part). What's missing in torchrl is

  1. The TensorDict2Index part (how to map a tensordict to a hash value)
  2. How to map the hash value to an index from 0 to M-1
    but we do have the storage part, and it's pretty neat (it works with prioritized values, N dimensions, arbitrary pytrees...)

Not sure what the way forward is, but here are various versions of what we could do

  1. Make sure that whatever API we propose here matches TorchRL's storage API, such that we can have a minimal number of storages coded here and just use TorchRL's when we need to
  2. Move things to torchrl and make an API that leverages the storages but provides more flexibility in terms of indexing. Something like:
num_elts = 100
t2t_storage = TD2TDStorage(
    storage=LazyTensorStorage(num_elts),
    in_keys=["obs", "action"],
    out_keys=None,  # out_keys dynamically determined
    map=TensorToIndex(),
)

(1) would be a bit weird if we don't do things well, because there will be some implicit dependency on TorchRL's API that will make little sense here, but it could make sense if we have other usages of this kind of storage than MCTS (can you name any?)
(2) would make more sense if we are only thinking about MCTS

@mjlaali mjlaali Jul 9, 2024

ListStorage is pretty neat, I missed it in TorchRL.

My typical strategy is to optimize for simplicity, and in this case having the code in torchrl makes many decisions simpler. So my vote would be to move this PR to TorchRL. (Let me know if you want to close this PR here and re-open it in TorchRL.)

At this moment, I don't have any use case other than MCTS. In the future, if the same abstraction can be used to implement K-NN search, then we may have other use cases (e.g. RAG), though I would not optimize for this potential use case.

@vmoens vmoens mentioned this pull request Jul 9, 2024
@vmoens vmoens commented Jul 22, 2024

Closing in favor of pytorch/rl#2283

@vmoens vmoens closed this Jul 22, 2024