[Feature] torch.func.functional_call compatibility #526

Draft · wants to merge 8 commits into base: main
3 changes: 3 additions & 0 deletions tensordict/__init__.py
@@ -51,3 +51,6 @@
]

# from tensordict._pytree import *

# apply the torch.func.functional_call patch (defined in _functional_patch.py)
from ._functional_patch import *
100 changes: 100 additions & 0 deletions tensordict/_functional_patch.py
@@ -0,0 +1,100 @@
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch

from tensordict import TensorDictBase
from torch import Tensor
from torch._functorch.utils import exposed_in
from torch.nn.utils._named_member_accessor import _MISSING

from torch.func import functional_call
from functorch import dim as funcdim

__all__ = []


@exposed_in("torch.func")
def functional_call_patched(
module: "torch.nn.Module",
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
args: Union[Any, Tuple],
kwargs: Optional[Dict[str, Any]] = None,
*,
tie_weights: bool = True,
strict: bool = False,
):
if isinstance(parameter_and_buffer_dicts, TensorDictBase):
# Ideally, we would like to make the whole stack compatible with tensordict,
# from functional_call to NamedMemberAccessor.
        # Within the tensordict library, this would require an awful lot of monkey
        # patching, so we'll do the quickest hack instead.
        # It's quite inefficient, as we flatten the parameter structure to conventional
        # ("."-separated) names only to reconstruct the structure from scratch,
        # but it works.
parameter_and_buffer_dicts = dict(parameter_and_buffer_dicts.flatten_keys("."))
return functional_call(
module,
parameter_and_buffer_dicts,
args,
kwargs,
tie_weights=tie_weights,
strict=strict,
)

def swap_tensor(
module: "torch.nn.Module",
name: str,
tensor: torch.Tensor,
allow_missing: bool = False,
) -> torch.Tensor:
if not isinstance(module, torch.nn.Module):
raise TypeError(f"{module} is not an instance of torch.nn.Module")
if (
tensor is not _MISSING
and not isinstance(tensor, (torch.Tensor, funcdim.Tensor)) # <= adding funcdim.Tensor
and tensor is not None
):
raise TypeError(f"{tensor} is not an instance of torch.Tensor")
if "." in name:
raise KeyError('tensor name can\'t contain "."')
if name == "":
raise KeyError('tensor name can\'t be empty string ""')

orig_tensor: torch.Tensor
if name in module._parameters:
orig_tensor = module._parameters[name] # type: ignore[assignment]
if tensor is not _MISSING:
module._parameters[name] = tensor # type: ignore[assignment]
else:
del module._parameters[name]
elif name in module._buffers:
orig_tensor = module._buffers[name] # type: ignore[assignment]
if tensor is not _MISSING:
module._buffers[name] = tensor
else:
del module._buffers[name]
else:
try:
orig_tensor = getattr(module, name)
except AttributeError as ex:
if not allow_missing:
raise AttributeError(
f"{module._get_name()} has no attribute `{name}`"
) from ex
orig_tensor = _MISSING
if (
orig_tensor is not _MISSING
and not isinstance(orig_tensor, torch.Tensor)
and orig_tensor is not None
):
raise TypeError(
f"attribute `{name}`: {orig_tensor} is not an instance of torch.Tensor"
)
if tensor is not _MISSING:
setattr(module, name, tensor)
elif hasattr(module, name):
delattr(module, name)
return orig_tensor

torch.nn.utils._named_member_accessor.swap_tensor = swap_tensor
torch._functorch.functional_call.functional_call = functional_call_patched
torch.func.functional_call = functional_call_patched
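
For context, here is a minimal sketch of what the patch enables once `tensordict` is imported (the package `__init__.py` above runs `_functional_patch` as a side effect). The toy module, shapes, and variable names are illustrative only:

```python
import torch
from tensordict import TensorDict  # importing tensordict applies the patch

module = torch.nn.Sequential(torch.nn.Linear(3, 4))
# Nested TensorDict mirroring the module tree: {"0": {"weight": ..., "bias": ...}}
params = TensorDict.from_module(module)

x = torch.randn(2, 3)
# The patched functional_call flattens the nested keys to "0.weight"/"0.bias"
# and hands the flat dict to the stock implementation.
out = torch.func.functional_call(module, params, (x,))
print(out.shape)  # torch.Size([2, 4])
```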
7 changes: 4 additions & 3 deletions tensordict/nn/params.py
@@ -13,6 +13,7 @@
from typing import Any, Callable, Iterator, Sequence

import torch
from functorch import dim as ftdim

from tensordict import TensorDictBase
from tensordict.nn.utils import Buffer
@@ -72,7 +73,7 @@ def _get_args_dict(func, args, kwargs):

def _maybe_make_param(tensor):
if (
isinstance(tensor, Tensor)
isinstance(tensor, (Tensor, ftdim.Tensor))
and not isinstance(tensor, nn.Parameter)
and tensor.dtype in (torch.float, torch.double, torch.half)
):
@@ -82,7 +83,7 @@ def _maybe_make_param(tensor):

def _maybe_make_param_or_buffer(tensor):
if (
isinstance(tensor, Tensor)
isinstance(tensor, (Tensor, ftdim.Tensor))
and not isinstance(tensor, nn.Parameter)
and tensor.dtype in (torch.float, torch.double, torch.half)
):
@@ -319,7 +320,7 @@ def __torch_function__(
if kwargs is None:
kwargs = {}
if func not in TDPARAM_HANDLED_FUNCTIONS or not all(
issubclass(t, (Tensor, TensorDictBase)) for t in types
issubclass(t, (Tensor, ftdim.Tensor, TensorDictBase)) for t in types
):
return NotImplemented
return TDPARAM_HANDLED_FUNCTIONS[func](*args, **kwargs)
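
The widened checks matter because tensors carrying first-class dimensions are instances of `functorch.dim.Tensor`, not `torch.Tensor`, so the previous `isinstance`/`issubclass` tests would skip them. A small illustration, assuming functorch's first-class-dims API (the dim names are made up for the example):

```python
import torch
from functorch import dim as ftdim

batch, feature = ftdim.dims(2)           # create two first-class dimensions
x = torch.randn(3, 4)[batch, feature]    # bind both positional axes to them

print(isinstance(x, torch.Tensor))       # False: a plain Tensor check misses it
print(isinstance(x, ftdim.Tensor))       # True: hence the widened checks above
```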
89 changes: 33 additions & 56 deletions tensordict/tensordict.py
@@ -32,11 +32,13 @@
TypeVar,
Union,
)

from warnings import warn

import numpy as np

import torch
from functorch import dim as ftdim
from tensordict._tensordict import _unravel_key_to_tuple
from tensordict.memmap import memmap_tensor_as_tensor, MemmapTensor
from tensordict.utils import (
@@ -69,7 +71,7 @@
NestedKey,
prod,
)
from torch import distributed as dist, multiprocessing as mp, Tensor
from torch import distributed as dist, multiprocessing as mp, nn, Tensor
from torch.utils._pytree import tree_map

try:
@@ -408,6 +410,24 @@ def from_module(module, as_module: bool = False):
return TensorDictParams(td, no_convert=True)
return td

def to_module(self, module):
from tensordict.nn.functional_modules import set_tensor_dict

__base__setattr__ = nn.Module.__setattr__
# we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
__dict__ = module.__dict__

for key, value in self.items():
cls = value.__class__
if _is_tensor_collection(cls) or issubclass(cls, dict):
value.to_module(__dict__["_modules"][key])
else:
if module.__class__.__setattr__ is __base__setattr__:
set_tensor_dict(__dict__, module, key, value)
else:
# use specialized __setattr__ if needed
setattr(module, key, value)

@property
def shape(self) -> torch.Size:
"""See :obj:`TensorDictBase.batch_size`."""
@@ -3394,6 +3414,8 @@ def _get_names_idx(self, idx):
else:

def is_boolean(idx):
if isinstance(idx, ftdim.Dim):
return None
if isinstance(idx, tuple) and len(idx) == 1:
return is_boolean(idx[0])
if hasattr(idx, "dtype") and idx.dtype is torch.bool:
@@ -3765,6 +3787,7 @@ def type(self, dst_type):
Tensor,
MemmapTensor,
TensorDictBase,
ftdim.Tensor,
]
if _has_torchrec:
_ACCEPTED_CLASSES += [KeyedJaggedTensor]
@@ -4452,11 +4475,6 @@ def memmap_(
raise RuntimeError(
"memmap and shared memory are mutually exclusive features."
)
# if not self._tensordict.keys():
# raise Exception(
# "memmap_() must be called when the TensorDict is (partially) "
# "populated. Set a tensor first."
# )
for key, value in self.items():
if value.requires_grad:
raise Exception(
@@ -6393,7 +6411,13 @@ def _split_index(self, index):
continue
if cursor == self.stack_dim:
# we need to check which tds need to be indexed
if isinstance(idx, slice) or _is_number(idx):
if isinstance(idx, ftdim.Dim):
raise ValueError(
"Cannot index a lazy stacked tensordict along the stack dimension with "
"a first-class dimension index. Consider consolidating the tensordict first "
"using `tensordict.contiguous()`."
)
elif isinstance(idx, slice) or _is_number(idx):
selected_td_idx = range(len(self.tensordicts))[idx]
if not isinstance(selected_td_idx, range):
isinteger = True
@@ -6425,6 +6449,7 @@ def _split_index(self, index):
idx,
(
int,
ftdim.Dim,
slice,
list,
range,
@@ -7238,54 +7263,6 @@ def __getitem__(self, index: IndexType) -> T:
out._td_dim_name = self._td_dim_name
return out

# index_dict = _convert_index_lazystack(index, self.stack_dim, self.batch_size)
# if index_dict is None:
# # then we use a sub-tensordict
# return self.get_sub_tensordict(index)
# td_index = index_dict["remaining_index"]
# stack_index = index_dict["stack_index"]
# new_stack_dim = index_dict["new_stack_dim"]
# if new_stack_dim is not None:
# if isinstance(stack_index, slice):
# # we can't iterate but we can index the list directly
# out = LazyStackedTensorDict(
# *[td[td_index] for td in self.tensordicts[stack_index]],
# stack_dim=new_stack_dim,
# )
# elif isinstance(stack_index, (list, range)):
# # then we can iterate
# out = LazyStackedTensorDict(
# *[self.tensordicts[idx][td_index] for idx in stack_index],
# stack_dim=new_stack_dim,
# )
# elif isinstance(stack_index, Tensor):
# # td_index is a nested tuple that mimics the shape of stack_index
# def _nested_stack(t: list, stack_idx: Tensor, td_index):
# if stack_idx.ndim:
# out = LazyStackedTensorDict(
# *[
# _nested_stack(t, _idx, td_index[i])
# for i, _idx in enumerate(stack_idx.unbind(0))
# ],
# stack_dim=new_stack_dim,
# )
# return out
# return t[stack_idx][td_index]
#
# # print(index, td_index, stack_index)
# out = _nested_stack(self.tensordicts, stack_index, td_index)
# else:
# raise TypeError("Invalid index used for stack dimension.")
# out._td_dim_name = self._td_dim_name
# return out
# out = self.tensordicts[stack_index]
# if td_index:
# return out[td_index]
# return out

# def __hash__(self):
# return hash(self.tensordicts)

def __eq__(self, other):
if is_tensorclass(other):
return other == self
@@ -8948,7 +8925,7 @@ def _clone_value(value: CompatibleType, recurse: bool) -> CompatibleType:


def _is_number(item):
if isinstance(item, Number):
if isinstance(item, (Number, ftdim.Dim)):
return True
if isinstance(item, Tensor) and item.ndim == 0:
return True
6 changes: 5 additions & 1 deletion tensordict/utils.py
@@ -20,6 +20,7 @@

import numpy as np
import torch
from functorch import dim as ftdim

from packaging.version import parse
from tensordict._tensordict import ( # noqa: F401
@@ -150,7 +151,7 @@ def _getitem_batch_size(batch_size, index):
out.extend(bs_shape)
bs_shape = None
continue
elif isinstance(idx, int):
elif isinstance(idx, (int, ftdim.Dim)):
# could be spared for efficiency
continue
elif isinstance(idx, slice):
@@ -761,9 +762,12 @@ def _is_shared(tensor: torch.Tensor) -> bool:
if torch._C._functorch.is_batchedtensor(tensor):
return None
return tensor.is_shared()
if isinstance(tensor, ftdim.Tensor):
return None
elif isinstance(tensor, KeyedJaggedTensor):
return False
    else:
        return tensor.is_shared()

