Skip to content

Commit

Permalink
[BugFix] Fix tensorclass register (#817)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 15, 2024
1 parent 1de6fb6 commit a989ee6
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 88 deletions.
5 changes: 4 additions & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -2045,11 +2045,14 @@ def _stack_onto_(self, list_item: list[CompatibleType], dim: int) -> TensorDict:
if all(v is None for v in vals):
continue
dest = self._get_str(key, NO_DEFAULT)
torch.stack(
new_dest = torch.stack(
vals,
dim=dim,
out=dest,
)
if new_dest is not dest:
# This can happen with non-tensor data
self._set_str(key, new_dest, inplace=False, validated=True)
return self

def entry_class(self, key: NestedKey) -> type:
Expand Down
181 changes: 94 additions & 87 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,13 @@ def __subclasscheck__(self, subclass):
}
# Methods to be executed from tensordict, any ref to self means 'tensorclass'
_METHOD_FROM_TD = [
"_get_str",
"_get_tuple",
"gather",
"is_memmap",
"is_shared",
"ndimension",
"numel",
"replace",
]
# Methods to be executed from tensordict, any ref to self means 'self._tensordict'
Expand All @@ -108,68 +114,111 @@ def __subclasscheck__(self, subclass):
"__sub__",
"__truediv__",
"_add_batch_dim",
"apply",
"_apply_nest",
"_fast_apply",
"apply_",
"named_apply",
"_check_unlock",
"unsqueeze",
"squeeze",
"_erase_names", # TODO: must be specialized
"_exclude", # TODO: must be specialized
"_get_str",
"_get_tuple",
"_set_at_tuple",
"_fast_apply",
"_has_names",
"_propagate_lock",
"_propagate_unlock",
"_remove_batch_dim",
"is_memmap",
"is_shared",
"_select", # TODO: must be specialized
"_set_at_tuple",
"_set_str",
"_set_tuple",
"abs",
"abs_",
"acos",
"acos_",
"add",
"add_",
"addcdiv",
"addcdiv_",
"addcmul",
"addcmul_",
"all",
"any",
"apply",
"apply_",
"asin",
"asin_",
"atan",
"atan_",
"ceil",
"ceil_",
"clamp_max",
"clamp_max_",
"clamp_min",
"clamp_min_",
"copy_",
"cos",
"cos_",
"cosh",
"cosh_",
"cpu",
"cuda",
"div",
"div_",
"empty",
"erf",
"erf_",
"erfc",
"erfc_",
"exclude",
"exp",
"exp_",
"expand",
"expand_as",
"expm1",
"expm1_",
"flatten",
"floor",
"floor_",
"frac",
"frac_",
"is_empty",
"is_memmap",
"is_shared",
"is_shared",
"items",
"keys",
"lerp",
"lerp_",
"lgamma",
"lgamma_",
"lock_",
"log",
"log10",
"log10_",
"log1p",
"log1p_",
"log2",
"log2_",
"log_",
"masked_fill",
"masked_fill_",
"permute",
"flatten",
"unflatten",
"ndimension",
"rename_", # TODO: must be specialized
"reshape",
"select",
"to",
"transpose",
"unlock_",
"values",
"view",
"zero_",
"add",
"add_",
"maximum",
"maximum_",
"minimum",
"minimum_",
"mul",
"mul_",
"abs",
"abs_",
"acos",
"acos_",
"exp",
"exp_",
"named_apply",
"ndimension",
"neg",
"neg_",
"norm",
"permute",
"pow",
"pow_",
"reciprocal",
"reciprocal_",
"rename_", # TODO: must be specialized
"reshape",
"round",
"round_",
"select",
"sigmoid",
"sigmoid_",
"sign",
Expand All @@ -178,67 +227,25 @@ def __subclasscheck__(self, subclass):
"sin_",
"sinh",
"sinh_",
"sqrt",
"sqrt_",
"squeeze",
"sub",
"sub_",
"tan",
"tan_",
"tanh",
"tanh_",
"to",
"transpose",
"trunc",
"trunc_",
"norm",
"lgamma",
"lgamma_",
"frac",
"frac_",
"expm1",
"expm1_",
"log",
"log_",
"log10",
"log10_",
"log1p",
"log1p_",
"log2",
"log2_",
"ceil",
"ceil_",
"floor",
"floor_",
"round",
"round_",
"erf",
"erf_",
"erfc",
"erfc_",
"asin",
"asin_",
"atan",
"atan_",
"cos",
"cos_",
"cosh",
"cosh_",
"lerp",
"lerp_",
"addcdiv",
"addcdiv_",
"addcmul",
"addcmul_",
"sub",
"sub_",
"maximum_",
"maximum",
"minimum_",
"minimum",
"clamp_max_",
"clamp_max",
"clamp_min_",
"clamp_min",
"pow",
"pow_",
"div",
"div_",
"sqrt",
"sqrt_",
"unflatten",
"unlock_",
"unsqueeze",
"values",
"view",
"zero_",
]
_FALLBACK_METHOD_FROM_TD_COPY = [
"_clone", # TODO: must be specialized
Expand Down

0 comments on commit a989ee6

Please sign in to comment.