Fix two checkpoint bugs #93
base: master
Conversation
Force-pushed the branch from 2ea928d to a9487c2.
Fixes QUVA-Lab#103
This pattern is replaced with `for x in unique_ever_seen(y)`, which also removes duplicate elements, but is guaranteed to always produce elements in the same order as in the input list.
Force-pushed the branch from a084cfc to a4188a2.
After giving this PR another look, I decided that converting the basis expansion indices from buffers into normal attributes wasn't quite the right approach. The problem is that …
This PR fixes two errors I encountered when trying to resume training a model from a checkpoint. The first involves basis expansion, and the second involves batch normalization.
Basis expansion
Here's a simplified example that reproduces the first error I encountered:
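A minimal sketch along these lines (the specific layer, group, and checkpoint path are illustrative, not the PR's original script):

```python
# Hypothetical reproduction sketch: checkpoint a small equivariant layer and
# immediately reload it.  The output type psi_1 + psi_0 + psi_1 is chosen so
# that the layer's index buffers keep their stride tricks, which is what
# makes the reload fail.
import torch
from escnn import gspaces, nn

gspace = gspaces.rot2dOnR2(N=-1)
in_type = nn.FieldType(gspace, [gspace.trivial_repr])
out_type = nn.FieldType(
    gspace,
    [gspace.irrep(1), gspace.irrep(0), gspace.irrep(1)],
)

conv = nn.R2Conv(in_type, out_type, kernel_size=3)

torch.save(conv.state_dict(), 'checkpoint.pt')
conv.load_state_dict(torch.load('checkpoint.pt'))   # raises RuntimeError
```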
Here's the resulting stack trace:
The error ultimately happens because `BlocksBasisExpansion` uses `meshgrid()` to create its `in_indices_*` and `out_indices_*` buffers. The significance of `meshgrid()` is that it uses stride tricks to save space. For example, to make a tensor where every row is the same, `meshgrid()` will only allocate a single row, then set the stride such that that same memory gets reused for every row. This ends up causing problems because `Module.load_state_dict()` uses `Tensor.copy_()` to copy the checkpointed parameters/buffers back into the module. This operation fails if "more than one element of the written-to tensor refers to a single memory location", i.e. if the destination tensor is using stride tricks.
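The same constraint is easy to demonstrate outside of escnn; this standalone sketch just illustrates the PyTorch behavior described above:

```python
# `meshgrid()` returns expanded views (note the zero stride), and `copy_()`
# refuses to write into a tensor whose elements share memory.
import torch

i, j = torch.meshgrid(torch.arange(3), torch.arange(4), indexing='ij')
print(i.stride())   # (1, 0): the zero stride means every column of `i`
                    # reuses the same memory

i.copy_(torch.zeros(3, 4, dtype=torch.long))
# RuntimeError: unsupported operation: more than one element of the
# written-to tensor refers to a single memory location. ...
```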
Here's the actual code in question: `escnn/escnn/nn/modules/basismanager/basisexpansion_blocks.py`, lines 128 to 134 in fec08a3.
I didn't mention the `reshape()` calls above, but they're significant in that they usually, but not always, get rid of the stride tricks. I constructed the $\psi_1 \oplus \psi_0 \oplus \psi_1$ representation in the above example specifically so that the stride tricks would be kept, and thus trigger the error.

The obvious solution is to do what the error message suggests, and call `clone()` after `reshape()`. I can confirm that this works, but after looking at the code more closely, I think the best solution is to simply not store these indices in a buffer at all. My understanding is that buffers are for things like the running averages in batch normalization layers: data that aren't subject to optimization (i.e. not parameters), but that still change over the course of a training run and need to be restored from checkpoints. These indices don't ever change, so there's no reason for them to be buffers. They can just be normal object attributes, and the whole problem of loading them from checkpoints goes away.
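A minimal sketch of that idea, with made-up names rather than the actual escnn code:

```python
# Sketch only: `in_indices_0` and the toy index tensor are illustrative.
import torch

class Example(torch.nn.Module):

    def __init__(self):
        super().__init__()
        in_indices = torch.arange(4).view(4, 1).expand(4, 3)

        # Before: the indices end up in the state dict, and loading a
        # checkpoint later calls `copy_()` on the stride-tricked tensor.
        #self.register_buffer('in_indices_0', in_indices)

        # After: the indices never change, so a plain attribute is enough,
        # and they never need to be restored from a checkpoint.
        self.in_indices_0 = in_indices
```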
Batch normalization
Here's a simplified example that reproduces the second error I encountered:
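A minimal sketch along these lines (the field type, `IIDBatchNorm3d`, `Adam`, and the checkpoint path are all illustrative, not the PR's original script):

```python
# Hypothetical reproduction sketch.  Run once with `-k` to write a checkpoint,
# then run without `-k` a few times: because the buffers/parameters are
# registered in a set-dependent (i.e. random) order, the optimizer state
# eventually gets matched to the wrong parameters.
import sys
import torch
from escnn import gspaces, nn

gspace = gspaces.rot3dOnR3()
in_type = nn.FieldType(gspace, 2 * [gspace.irrep(0)] + 3 * [gspace.irrep(1)])

model = nn.IIDBatchNorm3d(in_type)
optimizer = torch.optim.Adam(model.parameters())

def train_step():
    x = nn.GeometricTensor(torch.randn(2, in_type.size, 8, 8, 8), in_type)
    model(x).tensor.sum().backward()
    optimizer.step()

if '-k' in sys.argv:
    # Take one step so the optimizer has per-parameter state to checkpoint.
    train_step()
    torch.save(
        {'model': model.state_dict(), 'optimizer': optimizer.state_dict()},
        'checkpoint.pt',
    )
else:
    checkpoint = torch.load('checkpoint.pt')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    train_step()   # intermittently raises a shape-mismatch RuntimeError
```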
To get the error, you first need to run this script with the `-k` option to create a checkpoint. You then need to run the script without the `-k` option a couple of times, because the crash isn't deterministic. Eventually, you'll get the following stack trace:

This error happens because `_IIDBatchNorm` registers its buffers in a random order. More specifically, `_IIDBatchNorm` uses a set to get rid of duplicate representations, and then registers its buffers within a loop over that set. But set iteration order actually changes each time Python runs, because Python chooses a different random value to incorporate into the hash values of some built-in types each time it starts. (This apparently helps protect servers written in Python from DoS attacks.)

The actual crash happens when the optimizer tries to update the parameters after having been restored from a checkpoint. The checkpoint contains some metadata on each parameter, stored in whatever order the parameters were originally generated in. When the optimizer is reconstituted with the parameters in a different order, the checkpointed metadata gets applied to the wrong parameters. The best-case scenario at this point is for the program to crash, which happens when the parameters have incompatible dimensions. The worst-case scenario is that the program doesn't crash, and instead silently shuffles the metadata between parameters. I believe this will happen if every distinct representation has the same multiplicity. Models such as the example SE(3) CNN might exhibit this behavior.
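The set-order randomization itself is easy to see (strings stand in for the representation objects here):

```python
# Run this in two separate interpreter sessions; the printed order usually
# differs between runs because string hashes are randomized at startup
# (unless PYTHONHASHSEED is fixed).
print(list({'regular', 'irrep_0', 'irrep_1'}))
```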
The solution is to iterate over the full list of representations, and to manually remove duplicates. This guarantees that the parameters will be generated in the same order every time. While I was making this fix, I noticed the same bug in the `GnormBatchNorm` module, so I fixed it there, too.
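For reference, a minimal standalone version of an order-preserving "unique" helper like the `unique_ever_seen` mentioned in the commits (`more_itertools.unique_everseen` behaves the same way):

```python
def unique_ever_seen(items):
    """Yield each item the first time it appears, preserving input order."""
    seen = set()
    for item in items:
        if item not in seen:
            seen.add(item)
            yield item

# Unlike `set(...)`, the output order depends only on the input order:
print(list(unique_ever_seen(['psi_1', 'psi_0', 'psi_1'])))   # ['psi_1', 'psi_0']
```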