[Refactor] Refactor C++ binaries location (#860)
vmoens authored Jul 9, 2024
1 parent 222e126 commit 612fbbc
Showing 14 changed files with 61 additions and 23 deletions.
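
The user-visible effect of the refactor is the import path of the compiled extension: tensordict._tensordict becomes tensordict._C, with the exported symbols unchanged. A minimal sketch of what this means for downstream code, using unravel_key as an example:

# Before this commit:
#   from tensordict._tensordict import unravel_key
# After this commit:
from tensordict._C import unravel_key

# unravel_key flattens a nested key into a flat tuple:
assert unravel_key(("a", ("b", "c"))) == ("a", "b", "c")
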
2 changes: 1 addition & 1 deletion setup.py
@@ -135,7 +135,7 @@ def get_extensions():
 
     ext_modules = [
        extension(
-            "tensordict._tensordict",
+            "tensordict._C",
            sources,
            include_dirs=[this_dir],
            extra_compile_args=extra_compile_args,
9 changes: 9 additions & 0 deletions tensordict/_C/__init__.pyi
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+def unravel_key_list(list_of_keys): ...
+def unravel_keys(keys): ...
+def unravel_key(key): ...
+def _unravel_key_to_tuple(key): ...
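
Compiled extension modules are opaque to static type checkers, so the new stub file declares the signatures the pybind11 module exports. A sketch of what the stub buys you (behaviour as declared above; the example list is made up):

from tensordict._C import unravel_key_list  # type checkers resolve this via _C/__init__.pyi

# mypy/pyright can now check this call instead of treating _C as untyped:
flat = unravel_key_list([("a", ("b", "c")), "d"])  # -> [("a", "b", "c"), "d"]
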
4 changes: 2 additions & 2 deletions tensordict/__init__.py
@@ -26,11 +26,11 @@
     set_lazy_legacy,
 )
 from tensordict._pytree import *
-from tensordict._tensordict import unravel_key, unravel_key_list
+from tensordict._C import unravel_key, unravel_key_list  # @manual=//tensordict:_C
 from tensordict.nn import TensorDictParams
 
 try:
-    from tensordict.version import __version__
+    from tensordict.version import __version__  # @manual=//tensordict:version
 except ImportError:
     __version__ = None
 
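
Note: the # @manual=... annotations added throughout this commit are directives for Meta's internal Buck/autodeps tooling, marking imports whose build dependencies are declared by hand rather than inferred; outside that build system they are inert comments.
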
9 changes: 6 additions & 3 deletions tensordict/_lazy.py
@@ -45,8 +45,11 @@
     from tensordict.utils import _ftdim_mock as ftdim
 
     _has_funcdim = False
+from tensordict._C import (  # @manual=//tensordict:_C
+    _unravel_key_to_tuple,
+    unravel_key_list,
+)
 from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict
-from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list
 from tensordict.base import (
     _is_leaf_nontensor,
     _is_tensor_collection,
@@ -89,13 +92,13 @@
 _has_functorch = False
 try:
     try:
-        from torch._C._functorch import (
+        from torch._C._functorch import (  # @manual=fbcode//caffe2:torch
             _add_batch_dim,
             _remove_batch_dim,
             is_batchedtensor,
         )
     except ImportError:
-        from functorch._C import is_batchedtensor
+        from functorch._C import is_batchedtensor  # @manual=fbcode//caffe2/functorch:_C
 
     _has_functorch = True
 except ImportError:
4 changes: 2 additions & 2 deletions tensordict/_td.py
@@ -101,13 +101,13 @@
 _has_functorch = False
 try:
     try:
-        from torch._C._functorch import (
+        from torch._C._functorch import (  # @manual=fbcode//caffe2:torch
             _add_batch_dim,
             _remove_batch_dim,
             is_batchedtensor,
         )
     except ImportError:
-        from functorch._C import is_batchedtensor
+        from functorch._C import is_batchedtensor  # @manual=fbcode//functorch:_C
 
     _has_functorch = True
 except ImportError:
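
_lazy.py and _td.py share the same nested try/except so the import works across torch versions: recent releases bundle the functorch batching helpers under torch._C._functorch, while older setups ship them in the separate functorch package. A standalone sketch of the pattern (version boundary approximate):

_has_functorch = False
try:
    try:
        # torch >= ~1.13: functorch is bundled into torch itself
        from torch._C._functorch import is_batchedtensor
    except ImportError:
        # older installs: functorch is a separate package
        from functorch._C import is_batchedtensor
    _has_functorch = True
except ImportError:
    # neither is available: vmap-dependent features stay disabled
    pass
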
17 changes: 17 additions & 0 deletions tensordict/_tensordict/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import warnings
+
+from tensordict._C import (  # noqa
+    _unravel_key_to_tuple,
+    unravel_key,
+    unravel_key_list,
+    unravel_keys,
+)
+
+warnings.warn(
+    "tensordict._tensordict will soon be removed in favour of tensordict._C.",
+    category=DeprecationWarning,
+)
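
This shim keeps the old import path alive while signalling the rename. A sketch of the expected behaviour, assuming tensordict._tensordict has not yet been imported in the process (the module-level warning fires once, at first import):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    from tensordict._tensordict import unravel_key  # re-exported from tensordict._C

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
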
2 changes: 1 addition & 1 deletion tensordict/csrc/pybind.cpp
@@ -12,7 +12,7 @@
 
 namespace py = pybind11;
 
-PYBIND11_MODULE(_tensordict, m) {
+PYBIND11_MODULE(_C, m) {
   m.def("unravel_keys", &unravel_key, py::arg("key")); // for bc compat
   m.def("unravel_key", &unravel_key, py::arg("key"));
   m.def("_unravel_key_to_tuple", &_unravel_key_to_tuple, py::arg("key"));
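
The name passed to PYBIND11_MODULE must match the filename of the built shared object, which is why this edit is paired with the setup.py change above; a mismatch raises ImportError when the module is loaded. The correspondence, sketched as comments:

# setup.py names the extension target:
#   extension("tensordict._C", sources, ...)  ->  builds tensordict/_C.<abi-tag>.so
# pybind.cpp must declare the same trailing module name:
#   PYBIND11_MODULE(_C, m) { ... }
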
5 changes: 4 additions & 1 deletion tensordict/nn/common.py
@@ -13,9 +13,12 @@
 
 import torch
 from cloudpickle import dumps as cloudpickle_dumps, loads as cloudpickle_loads
+from tensordict._C import (  # @manual=//tensordict:_C
+    _unravel_key_to_tuple,
+    unravel_key_list,
+)
 
 from tensordict._td import is_tensor_collection, TensorDictBase
-from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list
 from tensordict.functional import make_tensordict
 from tensordict.nn.functional_modules import (
     _swap_state,
8 changes: 4 additions & 4 deletions tensordict/nn/functional_modules.py
@@ -109,8 +109,8 @@ def set_tensor_dict(  # noqa: F811
 
 _RESET_OLD_TENSORDICT = True
 try:
-    import torch._functorch.vmap as vmap_src
-    from torch._functorch.vmap import (
+    import torch._functorch.vmap as vmap_src  # @manual=fbcode//caffe2:torch
+    from torch._functorch.vmap import (  # @manual=fbcode//caffe2:torch
         _add_batch_dim,
         _broadcast_to_and_flatten,
         _get_name,
@@ -124,7 +124,7 @@ def set_tensor_dict(  # noqa: F811
     _has_functorch = True
 except ImportError:
     try:
-        from functorch._src.vmap import (
+        from functorch._src.vmap import (  # @manual=fbcode//caffe2/functorch:functorch_src
             _add_batch_dim,
             _broadcast_to_and_flatten,
             _get_name,
@@ -136,7 +136,7 @@ def set_tensor_dict(  # noqa: F811
         )
 
         _has_functorch = True
-        import functorch._src.vmap as vmap_src
+        import functorch._src.vmap as vmap_src  # @manual=fbcode//caffe2/functorch:functorch_src
     except ImportError:
         _has_functorch = False
 
2 changes: 1 addition & 1 deletion tensordict/nn/sequence.py
@@ -26,7 +26,7 @@
 FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality."
 
 
-from tensordict._tensordict import unravel_key_list
+from tensordict._C import unravel_key_list  # @manual=//tensordict:_C
 from tensordict.nn.common import dispatch, TensorDictModule
 from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
 from tensordict.utils import NestedKey
2 changes: 1 addition & 1 deletion tensordict/tensorclass.py
@@ -30,10 +30,10 @@
 import tensordict as tensordict_lib
 
 import torch
+from tensordict._C import _unravel_key_to_tuple  # @manual=//tensordict:_C
 from tensordict._lazy import LazyStackedTensorDict
 from tensordict._pytree import _register_td_node
 from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase
-from tensordict._tensordict import _unravel_key_to_tuple
 from tensordict._torch_func import TD_HANDLED_FUNCTIONS
 from tensordict.base import (
     _ACCEPTED_CLASSES,
16 changes: 11 additions & 5 deletions tensordict/utils.py
@@ -46,13 +46,13 @@
 except ImportError:
     _has_funcdim = False
 from packaging.version import parse
-from tensordict._contextlib import _DecoratorContextManager
-from tensordict._tensordict import (  # noqa: F401
+from tensordict._C import (  # noqa: F401 # @manual=//tensordict:_C
     _unravel_key_to_tuple,
     unravel_key,
     unravel_key_list,
     unravel_keys,
 )
+from tensordict._contextlib import _DecoratorContextManager
 
 from torch import Tensor
 from torch._C import _disabled_torch_function_impl
@@ -69,9 +69,15 @@
 
 try:
     try:
-        from functorch._C import get_unwrapped, is_batchedtensor
+        from torch._C._functorch import (  # @manual=fbcode//caffe2:torch
+            get_unwrapped,
+            is_batchedtensor,
+        )
     except ImportError:
-        from torch._C._functorch import get_unwrapped, is_batchedtensor
+        from functorch._C import (  # @manual=fbcode//functorch:_C # noqa
+            get_unwrapped,
+            is_batchedtensor,
+        )
 except ImportError:
     pass
 
@@ -2315,7 +2321,7 @@ class _add_batch_dim_pre_hook:
     def __call__(self, mod: torch.nn.Module, args, kwargs):
         for name, param in list(mod.named_parameters(recurse=False)):
             if hasattr(param, "in_dim") and hasattr(param, "vmap_level"):
-                from torch._C._functorch import _add_batch_dim
+                from torch._C._functorch import _add_batch_dim  # @manual=//caffe2:_C
 
                 param = _add_batch_dim(param, param.in_dim, param.vmap_level)
                 delattr(mod, name)
2 changes: 1 addition & 1 deletion test/test_nn.py
@@ -20,7 +20,7 @@
     tensorclass,
     TensorDict,
 )
-from tensordict._tensordict import unravel_key_list
+from tensordict._C import unravel_key_list
 from tensordict.nn import (
     dispatch,
     probabilistic as nn_probabilistic,
2 changes: 1 addition & 1 deletion test/test_utils.py
@@ -10,7 +10,7 @@
 import torch
 from _utils_internal import get_available_devices
 from tensordict import TensorDict, unravel_key, unravel_key_list
-from tensordict._tensordict import _unravel_key_to_tuple
+from tensordict._C import _unravel_key_to_tuple
 from tensordict.utils import (
     _getitem_batch_size,
     _make_cache_key,
