Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jul 8, 2024
1 parent ef27f40 commit dd522fa
Show file tree
Hide file tree
Showing 10 changed files with 14 additions and 21 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def get_extensions():

ext_modules = [
extension(
"tensordict._C._tensordict",
"tensordict._C",
sources,
include_dirs=[this_dir],
extra_compile_args=extra_compile_args,
Expand Down
15 changes: 4 additions & 11 deletions tensordict/_C/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

def unravel_key_list(list_of_keys):
...

def unravel_keys(keys):
...

def unravel_key(key):
...

def _unravel_key_to_tuple(key):
...
def unravel_key_list(list_of_keys): ...
def unravel_keys(keys): ...
def unravel_key(key): ...
def _unravel_key_to_tuple(key): ...
2 changes: 1 addition & 1 deletion tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
from tensordict.utils import _ftdim_mock as ftdim

_has_funcdim = False
from tensordict._C import _unravel_key_to_tuple, unravel_key_list
from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict
from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list
from tensordict.base import (
_is_leaf_nontensor,
_is_tensor_collection,
Expand Down
2 changes: 1 addition & 1 deletion tensordict/csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

namespace py = pybind11;

PYBIND11_MODULE(_tensordict, m) {
PYBIND11_MODULE(_C, m) {
m.def("unravel_keys", &unravel_key, py::arg("key")); // for bc compat
m.def("unravel_key", &unravel_key, py::arg("key"));
m.def("_unravel_key_to_tuple", &_unravel_key_to_tuple, py::arg("key"));
Expand Down
2 changes: 1 addition & 1 deletion tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

import torch
from cloudpickle import dumps as cloudpickle_dumps, loads as cloudpickle_loads
from tensordict._C import _unravel_key_to_tuple, unravel_key_list

from tensordict._td import is_tensor_collection, TensorDictBase
from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list
from tensordict.functional import make_tensordict
from tensordict.nn.functional_modules import (
_swap_state,
Expand Down
2 changes: 1 addition & 1 deletion tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality."


from tensordict._tensordict import unravel_key_list
from tensordict._C import unravel_key_list
from tensordict.nn.common import dispatch, TensorDictModule
from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
from tensordict.utils import NestedKey
Expand Down
2 changes: 1 addition & 1 deletion tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
import tensordict as tensordict_lib

import torch
from tensordict._C import _unravel_key_to_tuple
from tensordict._lazy import LazyStackedTensorDict
from tensordict._pytree import _register_td_node
from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase
from tensordict._tensordict import _unravel_key_to_tuple
from tensordict._torch_func import TD_HANDLED_FUNCTIONS
from tensordict.base import (
_ACCEPTED_CLASSES,
Expand Down
4 changes: 2 additions & 2 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@
except ImportError:
_has_funcdim = False
from packaging.version import parse
from tensordict._contextlib import _DecoratorContextManager
from tensordict._tensordict import ( # noqa: F401
from tensordict._C import ( # noqa: F401
_unravel_key_to_tuple,
unravel_key,
unravel_key_list,
unravel_keys,
)
from tensordict._contextlib import _DecoratorContextManager

from torch import Tensor
from torch._C import _disabled_torch_function_impl
Expand Down
2 changes: 1 addition & 1 deletion test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
tensorclass,
TensorDict,
)
from tensordict._tensordict import unravel_key_list
from tensordict._C import unravel_key_list
from tensordict.nn import (
dispatch,
probabilistic as nn_probabilistic,
Expand Down
2 changes: 1 addition & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from _utils_internal import get_available_devices
from tensordict import TensorDict, unravel_key, unravel_key_list
from tensordict._tensordict import _unravel_key_to_tuple
from tensordict._C import _unravel_key_to_tuple
from tensordict.utils import (
_getitem_batch_size,
_make_cache_key,
Expand Down

0 comments on commit dd522fa

Please sign in to comment.