Describe the bug
The model summary does not display submodules in execution order. I believe they are sorted in alphabetical order.
Reproduce
class VAE(eg.Module):
    @eg.compact
    def __call__(self, img):
        z = EncoderConv()(img)
        y = DecoderConv()(z)
        return dict(img=y, latents=z)
Expected behavior
It would be great to have the rows listed in execution order, i.e. encoder -> decoder. I might be recalling this wrong, but I don't think I had this issue in the past.
Additional context
I found the rows are output by jax.tree_flatten, so it might not be trivial to sort the rows again based on execution order.
flat, _ = jax.tree_flatten(self)
tree_part_types: tp.Tuple[tp.Type[types.TreePart], ...] = tuple(
    {
        field_info.kind
        for field_info in flat
        if utils._generic_issubclass(field_info.kind, types.TreePart)
    }
)
Hey @lkhphuc, thanks for bringing this up. Just guessing, but I think this has to do with the fact that JAX sorts dict keys alphabetically. I'll look into it. Should we move this to Treex?
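For reference, here is a minimal sketch of the dict-key sorting mentioned above. It does not use the actual module or summary code: `submodules` is a hypothetical stand-in dict whose keys are inserted in execution order, and the values are just the execution index. It calls `jax.tree_util.tree_flatten`, which is where newer JAX releases keep the function the issue refers to as `jax.tree_flatten`.

```python
import jax

# Hypothetical stand-in for a model's submodules (not the real summary code).
# Keys are inserted in execution order; values are the execution index.
submodules = {"encoder": 0, "decoder": 1}

# A plain Python dict preserves insertion order.
insertion_order = list(submodules)

# Flattening a dict pytree visits keys in sorted order, so the
# "decoder" leaf comes out before the "encoder" leaf.
leaves, treedef = jax.tree_util.tree_flatten(submodules)

print(insertion_order)  # ['encoder', 'decoder']
print(leaves)           # [1, 0]
```

If the summary rows are built from the flattened leaves, this would explain why the decoder row shows up before the encoder row, though that is still only a guess about where the ordering is lost.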