Fix two checkpoint bugs #93

Open
wants to merge 10 commits into
base: master

kalekundert
Contributor

This PR fixes two errors I encountered when trying to resume training a model from a checkpoint. The first involves basis expansion, and the second involves batch normalization.

Basis expansion

Here's a simplified example that reproduces the first error I encountered:

import torch
from escnn.gspaces import rot3dOnR3
from escnn.nn import FieldType, R3Conv

gs = rot3dOnR3()
so3 = gs.fibergroup
ft = FieldType(gs, [so3.irrep(1), so3.irrep(0), so3.irrep(1)])

conv = R3Conv(ft, ft, kernel_size=3)

torch.save(conv.state_dict(), 'demo_conv.ckpt')
ckpt = torch.load('demo_conv.ckpt')

conv.load_state_dict(ckpt)

Here's the resulting stack trace:

Traceback (most recent call last):
  File "/home/kale/hacking/bugs/escnn_checkpoint/demo_conv.py", line 14, in <module>
    conv.load_state_dict(ckpt)
  File "/home/kale/.pyenv/versions/3.10.0/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for R3Conv:
	While copying the parameter named "_basisexpansion.in_indices_('irrep_0', 'irrep_1')", whose dimensions in the model are torch.Size([6]) and whose dimensions in the checkpoint are torch.Size([6]), an exception occurred : ('unsupported operation: more than one element of the written-to tensor refers to a single memory location. Please clone() the tensor before performing the operation.',).
	While copying the parameter named "_basisexpansion.out_indices_('irrep_1', 'irrep_0')", whose dimensions in the model are torch.Size([6]) and whose dimensions in the checkpoint are torch.Size([6]), an exception occurred : ('unsupported operation: more than one element of the written-to tensor refers to a single memory location. Please clone() the tensor before performing the operation.',).

The error ultimately happens because BlocksBasisExpansion uses meshgrid() to create its in_indices_* and out_indices_* buffers. The significance of meshgrid() is that it uses stride tricks to save space. For example, to make a tensor where every row is the same, meshgrid() allocates only a single row, then sets the stride along the row dimension to 0 so that the same memory is reused for every row. This causes problems because Module.load_state_dict() uses Tensor.copy_() to copy the checkpointed parameters/buffers back into the module. This operation fails if "more than one element of the written-to tensor refers to a single memory location", i.e. if the destination tensor is using stride tricks.
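As a minimal sketch of the stride behavior described above, independent of escnn, the following builds a small grid with torch.meshgrid() and then attempts the same kind of in-place copy that load_state_dict() performs. The grid sizes (3 and 4) are arbitrary choices for illustration.

```python
import torch

# meshgrid() returns expanded views: one stride is 0, so an entire row or
# column of each grid aliases the same underlying memory.
rows, cols = torch.meshgrid(torch.arange(3), torch.arange(4), indexing='ij')
print(rows.shape, rows.stride())
print(cols.shape, cols.stride())

# Tensor.copy_() refuses to write into a tensor whose elements overlap in
# memory, which is the failure reported in the stack trace.
try:
    cols.copy_(torch.zeros(3, 4, dtype=torch.long))
    copy_error = None
except RuntimeError as exc:
    copy_error = str(exc)

print(copy_error)
```

Reading from such a tensor is fine; only writing to it in place is rejected.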

Here's the actual code in question:

out_indices, in_indices = torch.meshgrid([_out_indices[io_pair[1]], _in_indices[io_pair[0]]], indexing='ij')
in_indices = in_indices.reshape(-1)
out_indices = out_indices.reshape(-1)
# register the indices tensors and the bases tensors as parameters of this module
self.register_buffer('in_indices_{}'.format(self._escape_pair(io_pair)), in_indices)
self.register_buffer('out_indices_{}'.format(self._escape_pair(io_pair)), out_indices)

I didn't mention the reshape() calls above, but they're significant in that they usually—but don't always—get rid of the stride tricks. I constructed the $\psi_1 \oplus \psi_0 \oplus \psi_1$ representation in the above example specifically so that the stride tricks would be kept, and thus trigger the error.
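The sketch below illustrates when reshape() keeps the stride tricks, under the assumption that the relevant case is one of the two index tensors having length 1 (as with a single irrep_0 field). The helper name flat_meshgrid and the index values are hypothetical stand-ins, not escnn code.

```python
import torch

def flat_meshgrid(out_idx, in_idx):
    """Mimic the meshgrid() + reshape(-1) pattern from the snippet above."""
    out_grid, in_grid = torch.meshgrid(out_idx, in_idx, indexing='ij')
    return out_grid.reshape(-1), in_grid.reshape(-1)

# Both inputs longer than 1: reshape() cannot return a view, so it copies
# and the result has ordinary contiguous strides.
out_a, in_a = flat_meshgrid(torch.arange(6), torch.arange(3))
print(in_a.stride())

# One input of length 1: reshape() can return a view, and the zero stride
# survives. The result has 6 elements that all share one memory location.
out_b, in_b = flat_meshgrid(torch.arange(6), torch.arange(1))
print(in_b.shape, in_b.stride())

# Writing into the zero-stride tensor fails...
try:
    in_b.copy_(torch.zeros(6, dtype=torch.long))
    failed = False
except RuntimeError:
    failed = True

# ...but a clone() owns its own memory and can be written to.
in_fixed = in_b.clone()
in_fixed.copy_(torch.zeros(6, dtype=torch.long))
print(failed, in_fixed.stride())
```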

The obvious solution is to do what the error message suggests, and call clone() after reshape(). I can confirm that this works, but after looking at the code more closely, I think the best solution is to simply not store these indices in a buffer at all. My understanding is that buffers are for things like the running averages in batch normalization layers: data that aren't subject to optimization (i.e. not parameters), but that still change over the course of a training run and need to be restored from checkpoints. These indices don't ever change, so there's no reason for them to be buffers. They can just be normal object attributes, and the whole problem of loading them from checkpoints goes away.

Batch normalization

Here's a simplified example that reproduces the second error I encountered:

import torch
import sys

from escnn.gspaces import rot3dOnR3
from escnn.nn import FieldType, GeometricTensor, IIDBatchNorm3d
from torch.optim import Adam

gs = rot3dOnR3()
so3 = gs.fibergroup
ft = FieldType(gs, [so3.irrep(0), so3.irrep(1), so3.irrep(1)] * 2)

bn = IIDBatchNorm3d(ft)
opt = Adam(bn.parameters())

x = GeometricTensor(
        torch.randn(2, 14, 3, 3, 3, requires_grad=True),
        ft,
)

def step(opt, bn, x):
    opt.zero_grad()
    y = bn(x)
    y = torch.sum(y.tensor)  # arbitrary function to get scalar
    y.backward()
    opt.step()

if '-k' in sys.argv:
    step(opt, bn, x)

    ckpt = dict(
            bn=bn.state_dict(),
            opt=opt.state_dict(),
    )
    torch.save(ckpt, 'demo_bn.ckpt')

else:
    ckpt = torch.load('demo_bn.ckpt')

    bn.load_state_dict(ckpt['bn'])
    opt.load_state_dict(ckpt['opt'])

    step(opt, bn, x)

To get the error, you first need to run this script with the -k option to create a checkpoint. You then need to run the script without the -k option a couple of times, because the crash isn't deterministic. Eventually, you'll get the following stack trace:

Traceback (most recent call last):
  File "/home/kale/hacking/bugs/escnn_checkpoint/demo_bn.py", line 42, in <module>
    step(opt, bn, x)
  File "/home/kale/hacking/bugs/escnn_checkpoint/demo_bn.py", line 25, in step
    opt.step()
  File "/home/kale/.pyenv/versions/3.10.0/lib/python3.10/site-packages/torch/optim/optimizer.py", line 280, in wrapper
    out = func(*args, **kwargs)
  File "/home/kale/.pyenv/versions/3.10.0/lib/python3.10/site-packages/torch/optim/optimizer.py", line 33, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/kale/.pyenv/versions/3.10.0/lib/python3.10/site-packages/torch/optim/adam.py", line 141, in step
    adam(
  File "/home/kale/.pyenv/versions/3.10.0/lib/python3.10/site-packages/torch/optim/adam.py", line 281, in adam
    func(params,
  File "/home/kale/.pyenv/versions/3.10.0/lib/python3.10/site-packages/torch/optim/adam.py", line 344, in _single_tensor_adam
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0

This error happens because _IIDBatchNorm registers its buffers in a nondeterministic order. More specifically, _IIDBatchNorm uses a set to get rid of duplicate representations, and then registers its buffers within a loop over that set. But set iteration order can change each time Python runs, because Python chooses a different random value to incorporate into the hash values of some built-in types each time it starts. This apparently helps protect servers written in Python from DoS attacks.
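To see the effect in isolation, this sketch starts several fresh interpreters with different PYTHONHASHSEED values and prints the iteration order of the same set of strings in each. The strings are stand-ins chosen for illustration; what escnn actually stores in its set is not shown here.

```python
import os
import subprocess
import sys

# Code run in each child interpreter: iterate over a small set of strings.
SNIPPET = "print(list({'irrep_0', 'irrep_1', 'irrep_2', 'regular'}))"

def set_order_with_seed(seed):
    """Return the set's iteration order in a fresh interpreter with a fixed hash seed."""
    env = dict(os.environ, PYTHONHASHSEED=str(seed))
    result = subprocess.run(
        [sys.executable, '-c', SNIPPET],
        env=env, capture_output=True, text=True, check=True,
    )
    return result.stdout.strip()

orders = {seed: set_order_with_seed(seed) for seed in range(1, 6)}
for seed, order in orders.items():
    print(seed, order)
```

Each seed gives a reproducible order, but different seeds are free to give different orders. When PYTHONHASHSEED is unset, Python picks a random seed at startup, which is why the crash only shows up on some runs.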

The actual crash happens when the optimizer tries to update the parameters after having been restored from a checkpoint. The checkpoint contains some metadata on each parameter, stored in whatever order the parameters were originally generated in. When the optimizer is reconstituted with the parameters in a different order, the checkpointed metadata gets applied to the wrong parameters. The best-case scenario at this point is for the program to crash, which happens when the parameters have incompatible dimensions. The worst-case scenario is that the program doesn't crash, and instead silently shuffles the metadata. I believe this will happen if each different representation has the same multiplicity. Models such as the example SE(3) CNN might exhibit this behavior.
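The following sketch reproduces the mechanism with plain tensors instead of escnn modules. The parameter sizes (4 and 2) are assumptions chosen to echo the sizes in the error message, and make_params and step are hypothetical helpers.

```python
import torch
from torch.optim import Adam

def make_params(sizes):
    return [torch.nn.Parameter(torch.ones(n)) for n in sizes]

def step(opt, params):
    opt.zero_grad()
    sum(p.sum() for p in params).backward()
    opt.step()

# "First run": parameters are created with sizes 4, then 2.
params_a = make_params([4, 2])
opt_a = Adam(params_a)
step(opt_a, params_a)
ckpt = opt_a.state_dict()

# Per-parameter state is keyed by integer position, not by name.
state_keys = sorted(ckpt['state'].keys())
saved_shapes = [tuple(ckpt['state'][k]['exp_avg'].shape) for k in state_keys]
print(state_keys, saved_shapes)

# "Second run": the same parameters, but created in the opposite order.
params_b = make_params([2, 4])
opt_b = Adam(params_b)

try:
    # In the traceback above the failure surfaces in step(), not while loading.
    opt_b.load_state_dict(ckpt)
    step(opt_b, params_b)
    error = None
except RuntimeError as exc:
    error = str(exc)

print(error)
```

If both parameters had the same size, nothing here would raise, and each parameter would continue training with the other's moment estimates.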

The solution is to iterate over the full list of representations, and to manually remove duplicates. This guarantees that the parameters will be generated in the same order every time. While I was making this fix, I noticed the same bug in the GnormBatchNorm module, so I fixed it there, too.
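A minimal sketch of order-preserving deduplication is shown below. The helper name unique_in_order is hypothetical and is not necessarily how the PR implements it; the representation names are placeholders.

```python
def unique_in_order(items):
    """Yield each distinct item once, in order of first appearance."""
    seen = set()
    for item in items:
        if item not in seen:
            seen.add(item)
            yield item

names = ['irrep_0', 'irrep_1', 'irrep_1', 'irrep_0', 'irrep_1', 'irrep_1']
deduped = list(unique_in_order(names))
print(deduped)

# For hashable items, dict.fromkeys() gives the same result, because dicts
# preserve insertion order.
assert deduped == list(dict.fromkeys(names))
```

The set is still used for membership tests, but never iterated over, so hash randomization no longer affects the order in which buffers and parameters are created.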

Signed-off-by: Kale Kundert <[email protected]>
This pattern is replaced with `for x in unique_ever_seen(y)`, which also
removes duplicate elements, but is guaranteed to always produce elements
in the same order as in the input list.

Signed-off-by: Kale Kundert <[email protected]>
@kalekundert
Contributor Author

After giving this PR another look, I decided that converting the basis expansion indices from buffers into normal attributes wasn't quite the right approach. The problem is that nn.Module.to() affects buffers, but not normal attributes. In light of this, I just changed the indices back into buffers. However, I made the buffers non-persistent. This way they don't get recorded to the checkpoint file, which still fixes the bug.
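The difference between the three options can be sketched with a toy module. The class and attribute names are hypothetical, and a dtype conversion via double() stands in for a device move so the example runs without a GPU; real index tensors are integer-valued and would mainly be affected by device moves.

```python
import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        values = torch.arange(6, dtype=torch.float32)
        # Ordinary buffer: saved in checkpoints and converted by the module.
        self.register_buffer('persistent_buf', values.clone())
        # Non-persistent buffer: converted by the module, but not saved.
        self.register_buffer('transient_buf', values.clone(), persistent=False)
        # Plain attribute: neither saved nor converted.
        self.plain_attr = values.clone()

m = Demo().double()

saved_keys = list(m.state_dict().keys())
print(saved_keys)
print(m.persistent_buf.dtype, m.transient_buf.dtype, m.plain_attr.dtype)
```

Because the non-persistent buffer never appears in state_dict(), load_state_dict() never tries to copy_() into it, so its strides no longer matter.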
