diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 4519900ae8b..70fdf03c0ff 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -347,6 +347,8 @@ TorchRL offers a series of custom built-in environments. PendulumEnv TicTacToeEnv + LLMHashingEnv + Multi-agent environments ------------------------ diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index 8f6be633743..264534a725c 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -79,7 +79,7 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a - **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward - logger (``LogScaler``) and such. Hooks should return a dictionary (or a None value) containing the + logger (``LogScalar``) and such. Hooks should return a dictionary (or a None value) containing the data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value should be displayed on the progression bar printed on the training log. @@ -174,7 +174,7 @@ Trainer and hooks BatchSubSampler ClearCudaCache CountFramesLog - LogScaler + LogScalar OptimizerHook LogValidationReward ReplayBufferTrainer diff --git a/test/test_env.py b/test/test_env.py index b48b1a1cf8f..cef7a507f2a 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -8,6 +8,7 @@ import functools import gc import os.path +import random import re from collections import defaultdict from functools import partial @@ -114,6 +115,7 @@ DoubleToFloat, EnvBase, EnvCreator, + LLMHashingEnv, ParallelEnv, PendulumEnv, SerialEnv, @@ -3419,6 +3421,29 @@ def test_pendulum_env(self, device): r = env.rollout(10, tensordict=TensorDict(batch_size=[5], device=device)) assert r.shape == torch.Size((5, 10)) + def test_llm_hashing_env(self): + vocab_size = 5 + + class Tokenizer: + def __call__(self, obj): + return torch.randint(vocab_size, (len(obj.split(" ")),)).tolist() + + def decode(self, obj): + words = ["apple", "banana", "cherry", "date", "elderberry"] + return " ".join(random.choice(words) for _ in obj) + + def batch_decode(self, obj): + return [self.decode(_obj) for _obj in obj] + + def encode(self, obj): + return self(obj) + + tokenizer = Tokenizer() + env = LLMHashingEnv(tokenizer=tokenizer, vocab_size=vocab_size) + td = env.make_tensordict("some sentence") + assert isinstance(td, TensorDict) + env.check_env_specs(tensordict=td) + @pytest.mark.parametrize("device", [None, *get_default_devices()]) @pytest.mark.parametrize("env_device", [None, *get_default_devices()]) diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 513a7b94e58..c09db75aa5b 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -135,35 +135,40 @@ def make_node( def full_observation_spec(self): """The observation spec of the tree. - This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`.""" + This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`. + """ return self.specs["output_spec", "full_observation_spec"] @property def full_reward_spec(self): """The reward spec of the tree. - This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`.""" + This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`. + """ return self.specs["output_spec", "full_reward_spec"] @property def full_done_spec(self): """The done spec of the tree. - This is an alias for `Tree.specs['output_spec', 'full_done_spec']`.""" + This is an alias for `Tree.specs['output_spec', 'full_done_spec']`. + """ return self.specs["output_spec", "full_done_spec"] @property def full_state_spec(self): """The state spec of the tree. - This is an alias for `Tree.specs['input_spec', 'full_state_spec']`.""" + This is an alias for `Tree.specs['input_spec', 'full_state_spec']`. + """ return self.specs["input_spec", "full_state_spec"] @property def full_action_spec(self): """The action spec of the tree. - This is an alias for `Tree.specs['input_spec', 'full_action_spec']`.""" + This is an alias for `Tree.specs['input_spec', 'full_action_spec']`. + """ return self.specs["input_spec", "full_action_spec"] @property diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 36e4ec1a908..f3dec221ce0 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -5,7 +5,7 @@ from .batched_envs import ParallelEnv, SerialEnv from .common import EnvBase, EnvMetaData, make_tensordict -from .custom import PendulumEnv, TicTacToeEnv +from .custom import LLMHashingEnv, PendulumEnv, TicTacToeEnv from .env_creator import env_creator, EnvCreator, get_env_metadata from .gym_like import default_info_dict_reader, GymLikeEnv from .libs import ( diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index bafe88b639a..4f6002dedd3 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -14,8 +14,14 @@ import numpy as np import torch import torch.nn as nn -from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key -from tensordict.utils import NestedKey +from tensordict import ( + is_tensor_collection, + LazyStackedTensorDict, + TensorDictBase, + unravel_key, +) +from tensordict.base import _is_leaf_nontensor +from tensordict.utils import is_non_tensor, NestedKey from torchrl._utils import ( _ends_with, _make_ordinal_device, @@ -25,7 +31,13 @@ seed_generator, ) -from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded +from torchrl.data.tensor_specs import ( + Categorical, + Composite, + NonTensor, + TensorSpec, + Unbounded, +) from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.utils import ( _make_compatible_policy, @@ -430,7 +442,6 @@ def auto_specs_( done_key: NestedKey | List[NestedKey] | None = None, observation_key: NestedKey | List[NestedKey] = "observation", reward_key: NestedKey | List[NestedKey] = "reward", - batch_size: torch.Size | None = None, ): """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy. @@ -484,6 +495,7 @@ def auto_specs_( tensordict2, named=True, nested_keys=True, + is_leaf=_is_leaf_nontensor, ) input_spec = Composite(input_spec_stack, batch_size=batch_size) if not self.batch_locked and batch_size != self.batch_size: @@ -501,6 +513,7 @@ def auto_specs_( nexts_1, named=True, nested_keys=True, + is_leaf=_is_leaf_nontensor, ) output_spec = Composite(output_spec_stack, batch_size=batch_size) @@ -523,7 +536,8 @@ def auto_specs_( full_observation_spec = output_spec.separates(*observation_key, default=None) if not output_spec.is_empty(recurse=True): raise RuntimeError( - f"Keys {list(output_spec.keys(True, True))} are unaccounted for." + f"Keys {list(output_spec.keys(True, True))} are unaccounted for. " + f"Make sure you have passed all the leaf names to the auto_specs_ method." ) if full_action_spec is not None: @@ -541,6 +555,8 @@ def auto_specs_( @wraps(check_env_specs_func) def check_env_specs(self, *args, **kwargs): + return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs) + kwargs["return_contiguous"] = return_contiguous return check_env_specs_func(self, *args, **kwargs) check_env_specs.__doc__ = check_env_specs_func.__doc__ @@ -3206,7 +3222,10 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase: """ if self._simple_done: done = tensordict._get_str("done", default=None) - any_done = done.any() + if done is not None: + any_done = done.any() + else: + any_done = False if any_done: tensordict._set_str( "_reset", @@ -3572,6 +3591,12 @@ def _has_dynamic_specs(spec: Composite): def _tensor_to_spec(name, leaf, leaf_compare=None, *, stack): + if not (isinstance(leaf, torch.Tensor) or is_tensor_collection(leaf)): + stack[name] = NonTensor(shape=()) + return + elif is_non_tensor(leaf): + stack[name] = NonTensor(shape=leaf.shape) + return shape = leaf.shape if leaf_compare is not None: shape_compare = leaf_compare.shape diff --git a/torchrl/envs/custom/__init__.py b/torchrl/envs/custom/__init__.py index 8649d3d3e97..375a0e23a57 100644 --- a/torchrl/envs/custom/__init__.py +++ b/torchrl/envs/custom/__init__.py @@ -3,5 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from .llm import LLMHashingEnv from .pendulum import PendulumEnv from .tictactoeenv import TicTacToeEnv diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py new file mode 100644 index 00000000000..2f456482147 --- /dev/null +++ b/torchrl/envs/custom/llm.py @@ -0,0 +1,213 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import Callable, List, Union + +import torch +from tensordict import NestedKey, TensorDict, TensorDictBase +from tensordict.tensorclass import NonTensorData, NonTensorStack + +from torchrl.data import ( + Categorical as CategoricalSpec, + Composite, + NonTensor, + SipHash, + Unbounded, +) +from torchrl.envs import EnvBase +from torchrl.envs.utils import _StepMDP + + +class LLMHashingEnv(EnvBase): + """A text generation environment that uses a hashing module to identify unique observations. + + The primary goal of this environment is to identify token chains using a hashing function. + This allows the data to be stored in a :class:`~torchrl.data.MCTSForest` using nothing but hashes as node + identifiers, or easily prune repeated token chains in a data structure. + The following figure gives an overview of this workflow: + + .. figure:: /_static/img/rollout-llm.png + :alt: Data collection loop with our LLM environment. + + .. seealso:: the :ref:`Beam Search ` tutorial gives a practical example of how this env can be used. + + Args: + vocab_size (int): The size of the vocabulary. Can be omitted if the tokenizer is passed. + + Keyword Args: + hashing_module (Callable[[torch.Tensor], torch.Tensor], optional): + A hashing function that takes a tensor as input and returns a hashed tensor. + Defaults to :class:`~torchrl.data.SipHash` if not provided. + observation_key (NestedKey, optional): The key for the observation in the TensorDict. + Defaults to "observation". + text_output (bool, optional): Whether to include the text output in the observation. + Defaults to True. + tokenizer (transformers.Tokenizer | None, optional): + A tokenizer function that converts text to tensors. + Only used when `text_output` is `True`. + Must implement the following methods: `decode` and `batch_decode`. + Defaults to ``None``. + text_key (NestedKey | None, optional): The key for the text output in the TensorDict. + Defaults to "text". + + Examples: + >>> from tensordict import TensorDict + >>> from torchrl.envs import LLMHashingEnv + >>> from transformers import GPT2Tokenizer + >>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + >>> x = tokenizer(["Check out TorchRL!"])["input_ids"] + >>> env = LLMHashingEnv(tokenizer=tokenizer) + >>> td = TensorDict(observation=x, batch_size=[1]) + >>> td = env.reset(td) + >>> print(td) + TensorDict( + fields={ + done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False), + hash: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), + observation: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.int64, is_shared=False), + terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False), + text: NonTensorStack( + ['Check out TorchRL!'], + batch_size=torch.Size([1]), + device=None)}, + batch_size=torch.Size([1]), + device=None, + is_shared=False) + + """ + + def __init__( + self, + vocab_size: int | None = None, + *, + hashing_module: Callable[[torch.Tensor], torch.Tensor] = None, + observation_key: NestedKey = "observation", + text_output: bool = True, + tokenizer: Callable[[Union[str, List[str]]], torch.Tensor] | None = None, + text_key: NestedKey | None = "text", + ): + super().__init__() + if vocab_size is None: + if tokenizer is None: + raise TypeError( + "You must provide a vocab_size integer if tokenizer is `None`." + ) + vocab_size = tokenizer.vocab_size + self._batch_locked = False + if hashing_module is None: + hashing_module = SipHash() + + self._hashing_module = hashing_module + self._tokenizer = tokenizer + self.observation_key = observation_key + observation_spec = { + observation_key: CategoricalSpec(n=vocab_size, shape=(-1,)), + "hashing": Unbounded(shape=(1,), dtype=torch.int64), + } + self.text_output = text_output + if not text_output: + text_key = None + elif text_key is None: + text_key = "text" + if text_key is not None: + observation_spec[text_key] = NonTensor(shape=()) + self.text_key = text_key + self.observation_spec = Composite(observation_spec) + self.action_spec = Composite(action=CategoricalSpec(vocab_size, shape=(1,))) + _StepMDP(self) + + def make_tensordict(self, input: str | List[str]) -> TensorDict: + """Converts a string or list of strings in a TensorDict with appropriate shape and device.""" + list_len = len(input) if isinstance(input, list) else 0 + tensordict = TensorDict( + {self.observation_key: self._tokenizer(input)}, device=self.device + ) + if list_len: + tensordict.batch_size = [list_len] + return self.reset(tensordict) + + def _reset(self, tensordict: TensorDictBase): + """Initializes the environment with a given observation. + + Args: + tensordict (TensorDictBase): A TensorDict containing the initial observation. + + Returns: + A TensorDict containing the initial observation, its hash, and other relevant information. + + """ + out = tensordict.empty() + obs = tensordict.get(self.observation_key, None) + if obs is None: + raise RuntimeError( + f"Resetting the {type(self).__name__} environment requires a prompt." + ) + if self.text_output: + if obs.ndim > 1: + text = self._tokenizer.batch_decode(obs) + text = NonTensorStack.from_list(text) + else: + text = self._tokenizer.decode(obs) + text = NonTensorData(text) + out.set(self.text_key, text) + + if obs.ndim > 1: + out.set("hashing", self._hashing_module(obs).unsqueeze(-1)) + else: + out.set("hashing", self._hashing_module(obs.unsqueeze(0)).transpose(0, -1)) + + if not self.full_done_spec.is_empty(): + out.update(self.full_done_spec.zero(tensordict.shape)) + else: + out.set("done", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool)) + out.set( + "terminated", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool) + ) + return out + + def _step(self, tensordict): + """Takes an action (i.e., the next token to generate) and returns the next observation and reward. + + Args: + tensordict: A TensorDict containing the current observation and action. + + Returns: + A TensorDict containing the next observation, its hash, and other relevant information. + """ + out = tensordict.empty() + action = tensordict.get("action") + obs = torch.cat([tensordict.get(self.observation_key), action], -1) + kwargs = {self.observation_key: obs} + + catval = torch.cat([tensordict.get("hashing"), action], -1) + if obs.ndim > 1: + new_hash = self._hashing_module(catval).unsqueeze(-1) + else: + new_hash = self._hashing_module(catval.unsqueeze(0)).transpose(0, -1) + + if self.text_output: + if obs.ndim > 1: + text = self._tokenizer.batch_decode(obs) + text = NonTensorStack.from_list(text) + else: + text = self._tokenizer.decode(obs) + text = NonTensorData(text) + kwargs[self.text_key] = text + kwargs.update( + { + "hashing": new_hash, + "done": torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool), + "terminated": torch.zeros( + (*tensordict.batch_size, 1), dtype=torch.bool + ), + } + ) + return out.update(kwargs) + + def _set_seed(self, *args): + """Sets the seed for the environment's randomness. + + .. note:: This environment has no randomness, so this method does nothing. + """ + pass diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 209349878ec..d2ec66475ab 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -76,7 +76,7 @@ def __get__(self, cls, owner): class _StepMDP: - """Stateful version of step_mdp. + """Stateful version of :func:`~torchrl.envs.step_mdp`. Precomputes the list of keys to include and exclude during a call to step_mdp to reduce runtime. @@ -778,12 +778,15 @@ def check_env_specs( ) zeroing_err_msg = ( "zeroing the two tensordicts did not make them identical. " - "Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}" + f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}" ) from torchrl.envs.common import _has_dynamic_specs if _has_dynamic_specs(env.specs): - for real, fake in zip(real_tensordict.unbind(-1), fake_tensordict.unbind(-1)): + for real, fake in zip( + real_tensordict.filter_non_tensor_data().unbind(-1), + fake_tensordict.filter_non_tensor_data().unbind(-1), + ): fake = fake.apply(lambda x, y: x.expand_as(y), real) if (torch.zeros_like(real) != torch.zeros_like(fake)).any(): raise AssertionError(zeroing_err_msg) @@ -1367,6 +1370,8 @@ def _update_during_reset( reset_keys: List[NestedKey], ): """Updates the input tensordict with the reset data, based on the reset keys.""" + if not reset_keys: + return tensordict.update(tensordict_reset) roots = set() for reset_key in reset_keys: # get the node of the reset key