Tracking all issues related to PyTorch Lightning support here #89

Open · 2 of 6 tasks
siddharth9820 opened this issue Jul 1, 2024 · 0 comments
  • AxoNN and PyTorch Lightning do not behave well together in Slurm interactive sessions. Reason: PL tries to grab ranks and world sizes from MPI (this is also a bug on their end, since it fails to detect Slurm env variables in interactive jobs), whereas AxoNN disables automatic MPI initialization as soon as it is imported (https://github.com/axonn-ai/axonn/blob/develop/axonn/axonn.py#L17-L23). A potential fix is to initialize MPI manually in the Strategy constructor, i.e. move those lines inside the AxoNNStrategy constructor (see the first sketch after this list).
  • Per-device batch sizes - PyTorch dataloaders expect the user to pass a per-device batch size rather than the global batch size (bad design imo). The per-device batch size is easy to compute in a pure data-parallel/FSDP/HSDP setting: it is simply the global batch size divided by the number of GPUs. It is more complicated with AxoNN, since GPUs in the same row/column tensor-parallel group see the same shard of the batch. So we should add a utility function to fabric that computes the per-device batch size for the user (see the second sketch after this list).
  • Communication optimizations - As of now, to use our overlap optimizations, the user has to run their forward and backward passes under the context manager returned by fabric._strategy.optimize_communication(model), which is not a standard PyTorch Lightning API call (see the third sketch after this list for the current usage pattern). Maybe we should add a constructor argument called overlap_communication, and if it is set to True, wrap the forward and backward passes inside the context manager internally.
  • Checkpoint combiner - provide an AxoNN script callable as a Python module: python -m axonn.intra_layer.combine_checkpoints
  • Support pretraining, finetuning, and inference in our fork of litgpt. Efficient inference will require a parallel GPT implementation in the expert mode.
  • In the documentation, we should state that the only way to auto-parallelize a model is through fabric.init_module; fabric.setup... does nothing. The other way is to manually replace nn.Linear with axonn.intra_layer.Linear in your model definitions (see the last sketch after this list).
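
A minimal sketch of the first item's proposed fix, assuming mpi4py is installed. The real AxoNNStrategy subclasses a Lightning strategy class and takes more arguments; this only illustrates where the MPI initialization would move:

```python
# Sketch: move manual MPI initialization into the strategy constructor,
# instead of only disabling automatic init at axonn import time.
import mpi4py

mpi4py.rc.initialize = False  # keep "from mpi4py import MPI" side-effect free
from mpi4py import MPI


class AxoNNStrategy:  # stand-in for the real Lightning strategy subclass
    def __init__(self):
        # By the time a user constructs the strategy, we know MPI is wanted,
        # so initialize it here if the launcher has not already done so.
        if not MPI.Is_initialized():
            MPI.Init()
        self.world_size = MPI.COMM_WORLD.Get_size()
        self.global_rank = MPI.COMM_WORLD.Get_rank()
```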
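
A hedged sketch of the second item's utility; the function name and the row/column tensor-parallel group sizes G_row and G_col are illustrative, not an existing AxoNN/fabric API:

```python
def per_device_batch_size(global_batch_size: int, world_size: int,
                          G_row: int, G_col: int) -> int:
    # GPUs in the same row/column tensor-parallel group see the same shard,
    # so the batch is only split across world_size // (G_row * G_col) groups.
    num_unique_shards = world_size // (G_row * G_col)
    assert global_batch_size % num_unique_shards == 0, \
        "global batch size must divide evenly across batch shards"
    return global_batch_size // num_unique_shards


# e.g. 16 GPUs with 2x2 row/column tensor parallelism -> 4 unique shards,
# so a global batch of 64 becomes 16 samples per device.
assert per_device_batch_size(64, 16, 2, 2) == 16
```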
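
For reference, the current (non-standard) usage pattern from the third item looks roughly like this; fabric, model, optimizer, loss_fn, inputs, and targets are the usual Fabric training-loop objects:

```python
# Today the user must wrap the forward and backward passes themselves:
with fabric._strategy.optimize_communication(model):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()
```

With the proposed overlap_communication=True constructor flag, the strategy would enter this context manager internally and the loop body would shrink to the standard Fabric calls.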
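
And a sketch of the two parallelization paths from the last item; MyModel is a placeholder, and fabric is assumed to be a lightning.fabric.Fabric configured with the AxoNN strategy:

```python
import torch.nn as nn
from axonn.intra_layer import Linear as AxoNNLinear

# Path 1: auto-parallelize by constructing the model under init_module,
# which is what swaps layers for their tensor-parallel equivalents.
with fabric.init_module():
    model = MyModel()

# Path 2: use axonn.intra_layer.Linear directly in the model definition.
class MLP(nn.Module):
    def __init__(self, d_in: int, d_hidden: int, d_out: int):
        super().__init__()
        self.fc1 = AxoNNLinear(d_in, d_hidden)  # instead of nn.Linear
        self.act = nn.GELU()
        self.fc2 = AxoNNLinear(d_hidden, d_out)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))
```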
siddharth9820 added this to the v0.2.0 milestone Jul 1, 2024