AxoNN and PyTorch Lightning do not behave well together in Slurm interactive sessions. The reason: PL wants to grab ranks and world sizes from MPI (this is also a bug on their end, since it fails to detect Slurm environment variables in interactive jobs), whereas AxoNN disables automatic initialization of MPI as soon as it is imported (https://github.com/axonn-ai/axonn/blob/develop/axonn/axonn.py#L17-L23). A potential solution is to initialize MPI manually in the Strategy constructor, i.e. move those lines (https://github.com/axonn-ai/axonn/blob/develop/axonn/axonn.py#L17-L23) inside the `AxoNNStrategy` constructor, as sketched below.
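A minimal sketch of the proposed fix, assuming mpi4py is available; the class shown here is an illustrative stand-in, not the actual strategy implementation.

```python
from mpi4py import MPI


class AxoNNStrategy:  # illustrative stand-in for the real strategy class
    def __init__(self):
        # Initialize MPI here rather than disabling automatic init at axonn
        # import time, so PyTorch Lightning can still read ranks and world
        # sizes from MPI inside Slurm interactive sessions.
        if not MPI.Is_initialized():
            MPI.Init()
        self.global_rank = MPI.COMM_WORLD.Get_rank()
        self.world_size = MPI.COMM_WORLD.Get_size()
```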
Per-device batch sizes - PyTorch dataloaders expect the user to pass the per-device batch size instead of the global batch size (bad design imo). In a pure data parallel/FSDP/HSDP setting the per-device batch size is easy to calculate, because it is simply the global batch size divided by the number of GPUs. It is more complicated with AxoNN, since row/column tensor-parallel GPUs see the same shard of the batch. So we should have a utility function in fabric that calculates the per-device batch size for the user, along the lines of the sketch below.
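A minimal sketch, assuming only the data-parallel dimension shards the batch and that row/column tensor-parallel ranks see identical shards. The function and argument names (`G_row`, `G_col`) are illustrative, not AxoNN's actual API.

```python
def per_device_batch_size(global_batch_size: int, world_size: int,
                          G_row: int, G_col: int) -> int:
    # Ranks inside one row x column tensor-parallel group replicate the same
    # batch shard, so the effective data-parallel size is smaller than the
    # total number of GPUs.
    data_parallel_size = world_size // (G_row * G_col)
    assert global_batch_size % data_parallel_size == 0, (
        "global batch size must be divisible by the data-parallel size"
    )
    return global_batch_size // data_parallel_size
```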
Communication optimizations - As of now, if the user wants to employ our overlap optimizations, they need to run their forward and backward passes under the context manager returned by `fabric._strategy.optimize_communication(model)`. This is not a standard PyTorch Lightning API call. Maybe we should have a constructor argument called `overlap_communication`, and if it is set to `True`, wrap the forward and backward passes inside the `optimize_communication` context manager internally (see the sketch below).
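A minimal sketch of how an `overlap_communication` constructor flag could be honored internally. The method names and training-step wiring are illustrative; the real strategy would reuse its existing `optimize_communication(model)` context manager (the one currently reached via `fabric._strategy`), shown here as a stub so the example runs.

```python
import contextlib


class AxoNNStrategy:  # illustrative stand-in for the real strategy class
    def __init__(self, overlap_communication: bool = False):
        self.overlap_communication = overlap_communication

    def optimize_communication(self, model):
        # Stand-in for the existing AxoNN overlap context manager.
        return contextlib.nullcontext()

    def run_forward_backward(self, model, batch, loss_fn):
        # Enter the overlap context only when the user opted in at
        # construction time, so the call site stays plain Lightning code.
        ctx = (self.optimize_communication(model)
               if self.overlap_communication else contextlib.nullcontext())
        with ctx:
            loss = loss_fn(model(batch))
            loss.backward()
        return loss
```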
Checkpoint combiner - have an AxoNN script callable as a Python module: `python -m axonn.intra_layer.combine_checkpoints` (a possible entry point is sketched below).
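A minimal sketch of what such a module entry point could look like; the flags (`--checkpoint-dir`, `--output`) and the merging logic are hypothetical, not an existing AxoNN CLI.

```python
import argparse
import glob

import torch


def main():
    parser = argparse.ArgumentParser(
        description="Combine per-rank AxoNN tensor-parallel checkpoint "
                    "shards into a single state dict.")
    parser.add_argument("--checkpoint-dir", required=True,
                        help="directory containing the per-rank shard files")
    parser.add_argument("--output", required=True,
                        help="path for the combined checkpoint")
    args = parser.parse_args()

    combined = {}
    for shard_path in sorted(glob.glob(f"{args.checkpoint_dir}/*.pt")):
        shard = torch.load(shard_path, map_location="cpu")
        # Real combining logic would concatenate the row/column-parallel
        # shards of each linear layer's weight; shown as a plain merge here.
        combined.update(shard)
    torch.save(combined, args.output)


if __name__ == "__main__":
    main()
```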
Support pretraining, finetuning, and inference in our fork of litgpt. Efficient inference would require us to create a parallel GPT implementation using AxoNN's expert mode.
In the documentation, we should say that the only way to auto-parallelize a model is through `fabric.init_module`; `fabric.setup...` does nothing. The other way is to manually replace `nn.Linear` with `axonn.intra_layer.Linear` in your model definitions (see the example below).
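A minimal sketch of the two approaches for the docs; the model, layer sizes, and the assumption that `axonn.intra_layer.Linear` mirrors the `nn.Linear(in_features, out_features)` signature are illustrative.

```python
import torch.nn as nn
from axonn.intra_layer import Linear as AxoNNLinear

# Option 1: auto-parallelization -- construct the model under
# fabric.init_module() so its nn.Linear layers are parallelized
# (fabric.setup(...) alone does not do this):
#
#     with fabric.init_module():
#         model = MyModel()

# Option 2: manual replacement in the model definition itself.
class MyModel(nn.Module):
    def __init__(self, hidden: int = 1024):
        super().__init__()
        self.proj_in = AxoNNLinear(hidden, 4 * hidden)   # instead of nn.Linear
        self.act = nn.GELU()
        self.proj_out = AxoNNLinear(4 * hidden, hidden)  # instead of nn.Linear

    def forward(self, x):
        return self.proj_out(self.act(self.proj_in(x)))
```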