Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] First class dim compatibility #525

Merged
merged 3 commits into from
Oct 4, 2023
Merged

[Feature] First class dim compatibility #525

merged 3 commits into from
Oct 4, 2023

Conversation

vmoens
Copy link
Contributor

@vmoens vmoens commented Sep 12, 2023

Description

WIP to make TensorDict compatible with FCD.

Example usage:

import torch
from functorch import dim
from tensordict import TensorDict

t = torch.randn(3, 4, 5)
td = TensorDict({"t": t}, [3, 4])
d0, d1 = dim.dims(2)

std = td[d0]
print('[d0]', std)

# [d0] TensorDict(
#     fields={
#         t: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=None)},
#     batch_size=torch.Size([4]),
#     device=None,
#     is_shared=False)

std = td[d0, d1]
print('[d0, d1]', std)

# [d0, d1] TensorDict(
#     fields={
#         t: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=None)},
#     batch_size=torch.Size([]),
#     device=None,
#     is_shared=False)

std = td[:, d1]
print('[:, d1]', std)

# [:, d1] TensorDict(
#     fields={
#         t: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.float32, is_shared=None)},
#     batch_size=torch.Size([3]),
#     device=None,
#     is_shared=False)

This can be used with modules to batch operations across sets of parameters seamlessly. We extract the parameters,


from torch import nn

net = nn.Sequential(
    nn.Linear(3, 4),
    nn.Tanh(),
    nn.Linear(4, 4),
)
params = TensorDict.from_module(net)
params = params.expand(10).clone()  # a TensorDict of shape [10]
d0 = dim.dims(1)
params = params[d0]

params.to_module(net)  # replace the params

y = net(torch.randn(11, 3))
print(y)

to_module is a draft of what we could use.

The idea of this example is to avoid functional calls through tensordict and functorch.dim when working with model ensembles.

cc @ezyang @zou3519 @zdevito
@matteobettini for MARL MLP and ConvNets
@smorad for model ensembles

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 12, 2023
@github-actions
Copy link

github-actions bot commented Sep 12, 2023

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 109. Improved: $\large\color{#35bf28}26$. Worsened: $\large\color{#d91a1a}3$.

Expand to view detailed results
Name Max Mean Ops Ops on Repo HEAD Change
test_plain_set_nested 40.8030μs 22.8378μs 43.7870 KOps/s 42.1301 KOps/s $\color{#35bf28}+3.93\%$
test_plain_set_stack_nested 0.2576ms 0.2135ms 4.6837 KOps/s 4.5755 KOps/s $\color{#35bf28}+2.37\%$
test_plain_set_nested_inplace 0.1127ms 27.1638μs 36.8137 KOps/s 35.4453 KOps/s $\color{#35bf28}+3.86\%$
test_plain_set_stack_nested_inplace 1.8214ms 0.2543ms 3.9322 KOps/s 3.7469 KOps/s $\color{#35bf28}+4.95\%$
test_items 58.0050μs 4.1970μs 238.2634 KOps/s 229.5948 KOps/s $\color{#35bf28}+3.78\%$
test_items_nested 0.5103ms 0.4152ms 2.4087 KOps/s 2.3127 KOps/s $\color{#35bf28}+4.15\%$
test_items_nested_locked 0.5175ms 0.4156ms 2.4060 KOps/s 2.3016 KOps/s $\color{#35bf28}+4.53\%$
test_items_nested_leaf 2.0650ms 0.2752ms 3.6337 KOps/s 3.7637 KOps/s $\color{#d91a1a}-3.45\%$
test_items_stack_nested 2.4735ms 2.3262ms 429.8859 Ops/s 404.0665 Ops/s $\textbf{\color{#35bf28}+6.39\%}$
test_items_stack_nested_leaf 2.3999ms 2.1087ms 474.2295 Ops/s 444.0169 Ops/s $\textbf{\color{#35bf28}+6.80\%}$
test_items_stack_nested_locked 3.1510ms 1.2269ms 815.0581 Ops/s 819.9756 Ops/s $\color{#d91a1a}-0.60\%$
test_keys 23.0020μs 6.1172μs 163.4735 KOps/s 153.8631 KOps/s $\textbf{\color{#35bf28}+6.25\%}$
test_keys_nested 2.2918ms 0.2147ms 4.6579 KOps/s 4.4352 KOps/s $\textbf{\color{#35bf28}+5.02\%}$
test_keys_nested_locked 0.2919ms 0.2124ms 4.7091 KOps/s 4.5592 KOps/s $\color{#35bf28}+3.29\%$
test_keys_nested_leaf 0.3513ms 0.2049ms 4.8794 KOps/s 4.3601 KOps/s $\textbf{\color{#35bf28}+11.91\%}$
test_keys_stack_nested 3.1491ms 2.2531ms 443.8280 Ops/s 436.3837 Ops/s $\color{#35bf28}+1.71\%$
test_keys_stack_nested_leaf 2.2429ms 2.1161ms 472.5734 Ops/s 443.0617 Ops/s $\textbf{\color{#35bf28}+6.66\%}$
test_keys_stack_nested_locked 1.0903ms 0.9436ms 1.0598 KOps/s 938.1109 Ops/s $\textbf{\color{#35bf28}+12.97\%}$
test_values 36.3020μs 1.8935μs 528.1175 KOps/s 504.4815 KOps/s $\color{#35bf28}+4.69\%$
test_values_nested 0.1407ms 73.5073μs 13.6041 KOps/s 13.6382 KOps/s $\color{#d91a1a}-0.25\%$
test_values_nested_locked 0.4006ms 73.4395μs 13.6166 KOps/s 13.7244 KOps/s $\color{#d91a1a}-0.78\%$
test_values_nested_leaf 0.1528ms 65.7384μs 15.2118 KOps/s 15.2686 KOps/s $\color{#d91a1a}-0.37\%$
test_values_stack_nested 2.4459ms 1.8560ms 538.8000 Ops/s 534.7590 Ops/s $\color{#35bf28}+0.76\%$
test_values_stack_nested_leaf 2.0091ms 1.8486ms 540.9581 Ops/s 537.0398 Ops/s $\color{#35bf28}+0.73\%$
test_values_stack_nested_locked 0.9060ms 0.7596ms 1.3164 KOps/s 1.2777 KOps/s $\color{#35bf28}+3.03\%$
test_membership 19.8010μs 2.1541μs 464.2206 KOps/s 451.5150 KOps/s $\color{#35bf28}+2.81\%$
test_membership_nested 38.4020μs 4.2442μs 235.6141 KOps/s 232.5990 KOps/s $\color{#35bf28}+1.30\%$
test_membership_nested_leaf 76.3060μs 4.1557μs 240.6342 KOps/s 230.1685 KOps/s $\color{#35bf28}+4.55\%$
test_membership_stacked_nested 50.6030μs 16.9154μs 59.1179 KOps/s 55.5477 KOps/s $\textbf{\color{#35bf28}+6.43\%}$
test_membership_stacked_nested_leaf 96.7070μs 16.8607μs 59.3096 KOps/s 55.3671 KOps/s $\textbf{\color{#35bf28}+7.12\%}$
test_membership_nested_last 24.6010μs 8.7516μs 114.2644 KOps/s 107.4573 KOps/s $\textbf{\color{#35bf28}+6.33\%}$
test_membership_nested_leaf_last 32.1020μs 8.7086μs 114.8292 KOps/s 107.4175 KOps/s $\textbf{\color{#35bf28}+6.90\%}$
test_membership_stacked_nested_last 0.3429ms 0.2594ms 3.8547 KOps/s 3.6090 KOps/s $\textbf{\color{#35bf28}+6.81\%}$
test_membership_stacked_nested_leaf_last 0.1070ms 19.6227μs 50.9614 KOps/s 47.2588 KOps/s $\textbf{\color{#35bf28}+7.83\%}$
test_nested_getleaf 93.4070μs 17.7802μs 56.2423 KOps/s 51.5113 KOps/s $\textbf{\color{#35bf28}+9.18\%}$
test_nested_get 50.5030μs 16.8795μs 59.2436 KOps/s 55.7017 KOps/s $\textbf{\color{#35bf28}+6.36\%}$
test_stacked_getleaf 1.1854ms 1.0236ms 976.9444 Ops/s 911.3819 Ops/s $\textbf{\color{#35bf28}+7.19\%}$
test_stacked_get 1.1296ms 0.9775ms 1.0230 KOps/s 988.3704 Ops/s $\color{#35bf28}+3.50\%$
test_nested_getitemleaf 96.4070μs 17.7327μs 56.3930 KOps/s 52.8133 KOps/s $\textbf{\color{#35bf28}+6.78\%}$
test_nested_getitem 86.4060μs 16.9241μs 59.0872 KOps/s 55.8261 KOps/s $\textbf{\color{#35bf28}+5.84\%}$
test_stacked_getitemleaf 2.9759ms 1.0283ms 972.4494 Ops/s 912.9131 Ops/s $\textbf{\color{#35bf28}+6.52\%}$
test_stacked_getitem 1.1051ms 0.9763ms 1.0242 KOps/s 968.1876 Ops/s $\textbf{\color{#35bf28}+5.79\%}$
test_lock_nested 73.2636ms 1.8029ms 554.6577 Ops/s 594.2184 Ops/s $\textbf{\color{#d91a1a}-6.66\%}$
test_lock_stack_nested 0.1016s 23.0279ms 43.4257 Ops/s 45.8036 Ops/s $\textbf{\color{#d91a1a}-5.19\%}$
test_unlock_nested 71.0487ms 1.8132ms 551.5093 Ops/s 565.3265 Ops/s $\color{#d91a1a}-2.44\%$
test_unlock_stack_nested 0.1032s 23.7151ms 42.1672 Ops/s 45.1116 Ops/s $\textbf{\color{#d91a1a}-6.53\%}$
test_flatten_speed 1.2824ms 1.1718ms 853.4100 Ops/s 849.3135 Ops/s $\color{#35bf28}+0.48\%$
test_unflatten_speed 2.2633ms 2.1058ms 474.8786 Ops/s 476.0059 Ops/s $\color{#d91a1a}-0.24\%$
test_common_ops 5.8099ms 1.2812ms 780.4994 Ops/s 783.2521 Ops/s $\color{#d91a1a}-0.35\%$
test_creation 28.7020μs 7.2687μs 137.5767 KOps/s 141.4394 KOps/s $\color{#d91a1a}-2.73\%$
test_creation_empty 58.1040μs 15.9956μs 62.5173 KOps/s 64.2385 KOps/s $\color{#d91a1a}-2.68\%$
test_creation_nested_1 0.1380ms 29.1215μs 34.3389 KOps/s 35.6341 KOps/s $\color{#d91a1a}-3.63\%$
test_creation_nested_2 0.1228ms 31.6025μs 31.6431 KOps/s 32.6988 KOps/s $\color{#d91a1a}-3.23\%$
test_clone 0.1615ms 28.3325μs 35.2951 KOps/s 35.3838 KOps/s $\color{#d91a1a}-0.25\%$
test_getitem[int] 0.1232ms 32.3697μs 30.8931 KOps/s 31.5161 KOps/s $\color{#d91a1a}-1.98\%$
test_getitem[slice_int] 0.1468ms 63.8013μs 15.6737 KOps/s 15.8951 KOps/s $\color{#d91a1a}-1.39\%$
test_getitem[range] 0.1302ms 96.5474μs 10.3576 KOps/s 10.5466 KOps/s $\color{#d91a1a}-1.79\%$
test_getitem[tuple] 0.1160ms 53.5406μs 18.6774 KOps/s 19.3390 KOps/s $\color{#d91a1a}-3.42\%$
test_getitem[list] 0.3439ms 92.5050μs 10.8102 KOps/s 11.1285 KOps/s $\color{#d91a1a}-2.86\%$
test_setitem_dim[int] 55.8040μs 38.8641μs 25.7307 KOps/s 26.2328 KOps/s $\color{#d91a1a}-1.91\%$
test_setitem_dim[slice_int] 0.1135ms 69.2660μs 14.4371 KOps/s 14.6669 KOps/s $\color{#d91a1a}-1.57\%$
test_setitem_dim[range] 0.2208ms 94.9757μs 10.5290 KOps/s 10.8262 KOps/s $\color{#d91a1a}-2.74\%$
test_setitem_dim[tuple] 82.4050μs 57.3171μs 17.4468 KOps/s 17.9345 KOps/s $\color{#d91a1a}-2.72\%$
test_setitem 0.1728ms 36.9297μs 27.0785 KOps/s 27.3960 KOps/s $\color{#d91a1a}-1.16\%$
test_set 0.1562ms 35.8199μs 27.9175 KOps/s 28.1877 KOps/s $\color{#d91a1a}-0.96\%$
test_set_shared 2.9139ms 0.2100ms 4.7626 KOps/s 4.8148 KOps/s $\color{#d91a1a}-1.08\%$
test_update 0.2191ms 40.2059μs 24.8719 KOps/s 24.9126 KOps/s $\color{#d91a1a}-0.16\%$
test_update_nested 0.2248ms 60.2317μs 16.6026 KOps/s 16.8783 KOps/s $\color{#d91a1a}-1.63\%$
test_set_nested 0.1428ms 39.3426μs 25.4177 KOps/s 25.7714 KOps/s $\color{#d91a1a}-1.37\%$
test_set_nested_new 0.2376ms 61.3428μs 16.3018 KOps/s 16.6495 KOps/s $\color{#d91a1a}-2.09\%$
test_select 0.2717ms 0.1139ms 8.7824 KOps/s 8.9222 KOps/s $\color{#d91a1a}-1.57\%$
test_unbind_speed 0.8707ms 0.7594ms 1.3169 KOps/s 1.3290 KOps/s $\color{#d91a1a}-0.91\%$
test_unbind_speed_stack0 8.5867ms 8.3202ms 120.1899 Ops/s 94.7012 Ops/s $\textbf{\color{#35bf28}+26.91\%}$
test_unbind_speed_stack1 17.5010μs 1.3187μs 758.3332 KOps/s 777.0370 KOps/s $\color{#d91a1a}-2.41\%$
test_creation[device0] 3.4052ms 0.5331ms 1.8759 KOps/s 1.9014 KOps/s $\color{#d91a1a}-1.34\%$
test_creation_from_tensor 0.6736ms 0.5903ms 1.6941 KOps/s 1.6526 KOps/s $\color{#35bf28}+2.51\%$
test_add_one[memmap_tensor0] 1.9809ms 38.7326μs 25.8181 KOps/s 26.5326 KOps/s $\color{#d91a1a}-2.69\%$
test_contiguous[memmap_tensor0] 61.4040μs 9.9932μs 100.0676 KOps/s 100.4496 KOps/s $\color{#d91a1a}-0.38\%$
test_stack[memmap_tensor0] 85.4050μs 31.3238μs 31.9246 KOps/s 32.0624 KOps/s $\color{#d91a1a}-0.43\%$
test_memmaptd_index 0.4830ms 0.3580ms 2.7934 KOps/s 2.7074 KOps/s $\color{#35bf28}+3.18\%$
test_memmaptd_index_astensor 1.6289ms 1.5756ms 634.6893 Ops/s 624.1487 Ops/s $\color{#35bf28}+1.69\%$
test_memmaptd_index_op 3.2206ms 3.1206ms 320.4559 Ops/s 330.7082 Ops/s $\color{#d91a1a}-3.10\%$
test_reshape_pytree 0.1294ms 44.0568μs 22.6980 KOps/s 22.8842 KOps/s $\color{#d91a1a}-0.81\%$
test_reshape_td 81.9060μs 53.9422μs 18.5384 KOps/s 19.3242 KOps/s $\color{#d91a1a}-4.07\%$
test_view_pytree 0.1114ms 41.6940μs 23.9843 KOps/s 24.4412 KOps/s $\color{#d91a1a}-1.87\%$
test_view_td 77.4050μs 10.3294μs 96.8112 KOps/s 95.7883 KOps/s $\color{#35bf28}+1.07\%$
test_unbind_pytree 93.8060μs 44.6648μs 22.3890 KOps/s 21.7918 KOps/s $\color{#35bf28}+2.74\%$
test_unbind_td 0.2131ms 0.1123ms 8.9046 KOps/s 9.0284 KOps/s $\color{#d91a1a}-1.37\%$
test_split_pytree 68.2050μs 46.3313μs 21.5837 KOps/s 19.4295 KOps/s $\textbf{\color{#35bf28}+11.09\%}$
test_split_td 0.8416ms 0.1264ms 7.9119 KOps/s 7.4986 KOps/s $\textbf{\color{#35bf28}+5.51\%}$
test_add_pytree 0.1256ms 54.7281μs 18.2722 KOps/s 18.1519 KOps/s $\color{#35bf28}+0.66\%$
test_add_td 0.2196ms 88.8607μs 11.2536 KOps/s 11.5918 KOps/s $\color{#d91a1a}-2.92\%$
test_distributed 33.3020μs 10.8269μs 92.3625 KOps/s 94.0733 KOps/s $\color{#d91a1a}-1.82\%$
test_tdmodule 0.2245ms 33.4279μs 29.9151 KOps/s 29.7587 KOps/s $\color{#35bf28}+0.53\%$
test_tdmodule_dispatch 0.3167ms 63.6643μs 15.7074 KOps/s 15.3457 KOps/s $\color{#35bf28}+2.36\%$
test_tdseq 0.6137ms 36.9631μs 27.0540 KOps/s 25.8168 KOps/s $\color{#35bf28}+4.79\%$
test_tdseq_dispatch 0.2286ms 76.7709μs 13.0258 KOps/s 12.6993 KOps/s $\color{#35bf28}+2.57\%$
test_instantiation_functorch 2.0226ms 1.8947ms 527.7792 Ops/s 518.8319 Ops/s $\color{#35bf28}+1.72\%$
test_instantiation_td 2.3753ms 1.5682ms 637.6577 Ops/s 626.4042 Ops/s $\color{#35bf28}+1.80\%$
test_exec_functorch 0.3001ms 0.2196ms 4.5528 KOps/s 4.5480 KOps/s $\color{#35bf28}+0.10\%$
test_exec_td 0.2601ms 0.2057ms 4.8623 KOps/s 4.7988 KOps/s $\color{#35bf28}+1.32\%$
test_vmap_mlp_speed[True-True] 8.9455ms 1.3783ms 725.5226 Ops/s 710.4665 Ops/s $\color{#35bf28}+2.12\%$
test_vmap_mlp_speed[True-False] 4.4109ms 0.7194ms 1.3901 KOps/s 1.3543 KOps/s $\color{#35bf28}+2.64\%$
test_vmap_mlp_speed[False-True] 9.4309ms 1.1802ms 847.3028 Ops/s 836.5257 Ops/s $\color{#35bf28}+1.29\%$
test_vmap_mlp_speed[False-False] 1.1355ms 0.5364ms 1.8643 KOps/s 1.7111 KOps/s $\textbf{\color{#35bf28}+8.95\%}$
test_vmap_transformer_speed[True-True] 25.3063ms 16.0569ms 62.2785 Ops/s 32.4655 Ops/s $\textbf{\color{#35bf28}+91.83\%}$
test_vmap_transformer_speed[True-False] 14.0185ms 10.9114ms 91.6474 Ops/s 84.7445 Ops/s $\textbf{\color{#35bf28}+8.15\%}$
test_vmap_transformer_speed[False-True] 24.9099ms 16.8639ms 59.2984 Ops/s 61.2021 Ops/s $\color{#d91a1a}-3.11\%$
test_vmap_transformer_speed[False-False] 19.7356ms 10.6362ms 94.0187 Ops/s 91.4549 Ops/s $\color{#35bf28}+2.80\%$

@vmoens vmoens added the enhancement New feature or request label Sep 12, 2023
@ezyang
Copy link

ezyang commented Sep 12, 2023

Cc @zdevito

@ezyang
Copy link

ezyang commented Sep 13, 2023

Nice! What kind of feedback are you looking for?

@vmoens
Copy link
Contributor Author

vmoens commented Sep 13, 2023

For now nothing, it's just an FYI.
There seem to be some issue regarding compatibility with nn.Parameter and also it's a bit surprising that it doesn't subclass Tensor but I'll make a more comprehensive list of advantages and drawbacks when the PR is ready.

@vmoens vmoens marked this pull request as ready for review September 13, 2023 13:07
@vmoens
Copy link
Contributor Author

vmoens commented Sep 13, 2023

This is ready for review if anyone wants to give it a shot.

One of my goals with this is to do what we did in the example:

  • Extract params from a model (or several models) to create a mixture of experts
  • replace this stack of params within the module to avoid functional calls.
    Basically, I was hoping I could have a regular nn.Module with batched params that would act as a mixture of expect without loops, functional calls or direct call to vmap. I love these features but I hoped I could package them in through tensordict in a way that made the end user less reliant on these primitives.

The issues I see currently are:

  • It is not very clear how that should interact with nn.Parameter. nn.Parameter(tensor[dim]) fails because tensor[dim] is not a tensor anymore. Maybe we should do nn.Parameter(tensor)[dim] but then we replace in the module something that is not a parameter anymore. Some module may fail with that (eg, RNNs but they will fail for other reasons too, like CuDNN compatibility etc).
  • For instance, this currently works:
from functorch import dim as ftdim

module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 1))
# create batched params
params = TensorDict.from_module(module)
params = params.expand(3).clone().apply(lambda x: x.requires_grad_())
# FCD indexing
d0 = ftdim.dims(1)
# execute grad
x = torch.randn(3)
def func(params):
    params_batched = params[d0]
    params_batched.to_module(module)
    return module(x)._tensor.sum()
g = torch.func.grad(func)(params)
print(g)
print("params", params['0', 'weight'], "\ngrads", g['0', 'weight'])

which is cool. But we must set the parameters within the call to func which introduces some overhead. Also it isn't very pytorch-y. What I understood from my conversations with @ezyang was that FCD promise was to hide away some dimensions of a tensor while keeping its behaviour as natural as possible. Hence, my (unreasonable) expectation was that I could simply do

module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 1))
params = TensorDict.from_module(module)
# create a batch of params
params = params.expand(3).clone().apply(lambda x: nn.Parameter(x))
# reduce the first dim
d0 = dim.dims(1)
params = params[0]
# repopulate module
params.to_module(module)
# run module
loss = loss_fn(module(x), target)
# backprop
loss.backward()

But because my tensordict is not full of tensors but dim.Tensor and because backward does not work with tensors that are build within functorch, this cannot currently work.

Having to modify that logic makes FCD less attractive -- I might as well use plain vmap + functional calls if I have to repopulate my module at every call, which I believe will be faster to execute.

Do you guys have any thought on that?

@zdevito
Copy link

zdevito commented Sep 18, 2023

The semantics for adding first-class dimensions to tensordict make sense to me. It also makes sense to be able to install parameters that have first-class dimensions. I am less clear on how we would accomplish it with how parameters in modules currently exist. First-class dims are implemented as their own objects that are not tensor subclasses in order to run fast enough in eager mode without overhead. I think parameters need to be real tensors in a lot of apis. In particular trying to set .grad properties on first-class tensors probably doesn't work. I wonder if there is a way to install the underlying raw tensor that is wrapped by the first-class dim tensor as the parameter, and then overload the property access for that tensor with something that constructs the tensor with first-class dims on the fly. This could be installed in the to_module call. I don't thing this is possible today but maybe @ezyang knows more.

@ezyang
Copy link

ezyang commented Sep 19, 2023

So... we could relax the constraint here:

# Metaclass to combine _TensorMeta and the instance check override for Parameter.
class _ParameterMeta(torch._C._TensorMeta):
    # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag.
    def __instancecheck__(self, instance):
        return super().__instancecheck__(instance) or (
            isinstance(instance, torch.Tensor) and getattr(instance, '_is_param', False))

and say that it has to either be a torch.Tensor, or it can be one of the first-class dim objects. Is that enough? Might be!

@ezyang
Copy link

ezyang commented Sep 19, 2023

Actually, I read the example more carefully (and fixed up some stuff) and I see:

import torch
import torch.nn as nn
from tensordict import TensorDict
import functorch.dim as dim

module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 1))
params = TensorDict.from_module(module)
# create a batch of params
params = params.expand(3).clone().apply(lambda x: nn.Parameter(x))
# reduce the first dim
d0 = dim.dims(1)
params = params[d0]
# repopulate module
params.to_module(module)
# run module
loss = module(torch.randn(3, 3)).sum()
# backprop
loss.backward()

fails with

RuntimeError: backward() called inside a functorch transform. This is not supported, please use functorch.grad or functorch.vjp instead or call backward() outside of functorch transforms.

Back in the day, @zou3519 talked a lot about the potential of non-lexical functorch transforms, one of the use cases being this kind of backward() call. I think we concluded that in principle it would be possible, but we have to implement it (which we never found time to do.)

So drat, I guess you can't actually replace the auto-batching behavior from tensordict with first class dims, what a bummer.

@vmoens
Copy link
Contributor Author

vmoens commented Sep 19, 2023

Thanks @ezyang and @zdevito for looking into this!
I don't think "wrapping" the dim.Tensor within an nn.Parameter is crucial for a first prototype.
Based on #526 I think that the model ensemble API will initially look like this:

import torch
import torch.nn as nn
from tensordict import TensorDict
import functorch.dim as dim

module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 1))
params = TensorDict.from_module(module)
# create a batch of params
num_models = 5
params = params.expand(num_models).clone().apply(lambda x: nn.Parameter(x))
x = torch.randn(3)

def compute(params_vals):
    # reduce the first dim
    d0 = dim.dims(1)
    for key, val_source in zip(params.keys(include_nested=True, leaves_only=True), params_vals):
        params.set(key, val_source)
    params_dims = params[d0]
    # no vmap
    y = torch.func.functional_call(module, params_dims, (x,))
    return y._tensor.sum() # or any other loss

grads = torch.func.grad(compute)(list(params.values(True, True)))
# put grads in a tensordict for clarity
grads = TensorDict(
    {key: grad for key, grad in zip(params.keys(include_nested=True, leaves_only=True), grads)}, batch_size=params.batch_size
)
print("grads of our 5 models", grads)

which gives you

grads of our 5 models TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Tensor(shape=torch.Size([5, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([5]),
            device=None,
            is_shared=False),
        1: TensorDict(
            fields={
                bias: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Tensor(shape=torch.Size([5, 1, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([5]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

ie, you can get your gradients using fcd and no apparent vmap

Note that this requires some modif in functorch and torch nn (see #526 which is based on this PR).

I think that this is pretty cool, what is happening is pretty apparent. Not as easy as substituting the params with their fcd counterpart but with some woodwork this could look pretty nice already. For instance, I already made vmap compatible with TensorDict, making the same work for grad and similar functorch functions would bring some value. Ideally we'd like to simplify the above to

# pass directly the params
def compute(params):
    d0 = dim.dims(1)
    params_dims = params[d0]
    y = torch.func.functional_call(module, params_dims, (x,))
    return y._tensor.sum() # or any other loss

# pass directly the params
grads = torch.func.grad(compute)(params)
# grads are already in a tensordict
print("grads of our 5 models", grads)

@ezyang
Copy link

ezyang commented Sep 25, 2023

Oh this is the thing where you wanted to override how functorch detects batching on inputs passed to grad. @zou3519 I don't remember what your objection to this was?

@zou3519
Copy link

zou3519 commented Sep 25, 2023

I don't think I have an objection? The code in #525 (comment) looks reasonable to me

@vmoens
Copy link
Contributor Author

vmoens commented Oct 3, 2023

@zou3519 Specifically referring to https://github.com/pytorch-labs/tensordict/pull/526/files#diff-b60cca91a0a60483cc294cdb1ae495bac71dbdb25d06e008443a41873e2a1891

That and this too, which I think is more contentious.
The pitch of this feature is this:
We could just register TensorDict in PyTree and call it a day.
However, if we just "pytree" over a tensordict (ie we register it in pytree, deconstruct it and reconstruct it) we lose the ability to do this:

def func(tensordict_or_tensor):
    assert tensordict_or_tensor.shape == (1, 2, 4)
    return tensordict_or_tensor
tensor = torch.randn(1, 2, 3, 4)
vmap(func, (2,))(tensor)
tensordict = TensorDict({}, batch_size=[1, 2, 3, 4])
vmap(func, (2,))(tensordict)

Currently, with the monkey patch, this code runs. If we just register tensordict within pytree but do not patch, it will fail.
This is because params will be deconstructed, a batch-dimension will be added and then the tensordict will be reconstructed but then the batch-size will either mismatch or be lost.
For this, we need some ad-hoc mechanism to keep the batch-size of the tensordict consistent through vmap, hence the ad-hoc monkey-patch in tensordict.nn.

Now, I totally get that this has ramifications that go beyond this (I guess we don't want to overcharge vmap with every new class that comes into sight). To me, any solution that makes vmap possible with tensordict (ie, a tensordict passed through vmap sees its shape changed to a consistent shape within the vmap call) is great.

Copy link
Contributor

@matteobettini matteobettini left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@vmoens vmoens merged commit 4c0eb1d into main Oct 4, 2023
@vmoens vmoens deleted the first_class_dim branch October 4, 2023 13:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants