We have been noticing a training slowdown introduced by our dataloader. Upon further investigation, we identified the cause: our dataset class maintains a number of very large lists.
Background
Each logical shard maintains a list of (dataset_id, shard_id, doc_id) tuples in order to track documents; e.g. ("c4", 3, 110) refers to the 110th document inside the file dataset_root_folder/dataset=c4/xxx.part3.arrow. When we distribute billions of documents over thousands of logical shard workers, each logical shard worker gets such a list of millions of (dataset_id, shard_id, doc_id) tuples, so in total we are maintaining hundreds of GBs worth of lists internally.
And why did we do this in the first place? Datasets are assumed to be unshuffled, so we need to shuffle our billions of (dataset_id, shard_id, doc_id) tuples, and each logical shard therefore maintains a shuffled list containing millions of them. Such a list has to be materialized at some point (even with lazy initialization or something similar) in order to keep our dataloader stateful: we need to know and checkpoint exactly which documents have been visited, which remain to be visited, and in what order, so that we can resume training deterministically and flawlessly.
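For a rough sense of scale (a hedged back-of-envelope, not a measurement from this repo; the 2B document count is an assumed round number), the per-tuple cost of plain Python objects alone already lands in the hundreds-of-GB range:

```python
import sys

# Assumed example entry; real ids/shards/doc counts differ.
entry = ("c4", 3, 110)
per_entry = sys.getsizeof(entry) + sum(sys.getsizeof(x) for x in entry)
print(per_entry)                               # roughly 170 bytes on 64-bit CPython
print(f"{per_entry * 2e9 / 1e9:.0f} GB")       # ~340 GB for an assumed 2B documents
# Interning/sharing of small ints and dataset_id strings lowers the real footprint,
# but the order of magnitude stays the same.
```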
Solution
If we peel the onion here completely, the question actually boils down to:
how can we maintain a list that is truly stateful, provides random reading, and provides easy checkpointing and recovery?
This leads us to leverage an LCG (linear congruential generator) and use its statefulness to achieve the statefulness of the list.
A quick overview of the LCG we built for an arbitrarily sized list:
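Below is a minimal sketch of the idea (illustrative only, not the exact implementation in this repo; class and constant names are made up): run a full-period LCG over the smallest power-of-two modulus m >= n and skip any output >= n ("cycle walking"). This visits every index in [0, n) exactly once in pseudo-random order, and the entire "shuffled list" is recoverable from two integers.

```python
class StatefulShuffledRange:
    """Virtual shuffled list: yields each index in [0, n) exactly once, pseudo-randomly."""

    def __init__(self, n: int, seed: int = 0):
        assert n > 0
        self.n = n
        # Smallest power of two >= n. With c odd and a % 4 == 1, the LCG
        # x -> (a*x + c) % m has full period m (Hull–Dobell conditions).
        self.m = 1 << (n - 1).bit_length()
        self.a = 5              # illustrative multiplier; real code may pick better constants
        self.c = 2 * seed + 1   # any odd increment works; derive it from the seed
        self.state = seed % self.m
        self.visited = 0        # how many valid indices we have emitted so far

    def __iter__(self):
        return self

    def __next__(self) -> int:
        if self.visited == self.n:
            raise StopIteration
        # Advance until we land inside [0, n); since m < 2*n this takes
        # fewer than 2 steps on average ("cycle walking").
        while True:
            self.state = (self.a * self.state + self.c) % self.m
            if self.state < self.n:
                self.visited += 1
                return self.state

    # Checkpointing: the whole shuffled order is recoverable from two ints.
    def state_dict(self) -> dict:
        return {"state": self.state, "visited": self.visited}

    def load_state_dict(self, sd: dict) -> None:
        self.state = sd["state"]
        self.visited = sd["visited"]
```

Each emitted index can then be mapped back to a (dataset_id, shard_id, doc_id) entry arithmetically (e.g. from cumulative per-shard document counts), so nothing beyond these two integers needs to be materialized or checkpointed.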
And validation:
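A quick sanity check for the sketch above (again illustrative, not the repo's actual test): confirm the virtual list is a true permutation of range(n), and that saving and restoring its state mid-iteration reproduces exactly the same remaining order.

```python
def validate(n: int = 100_003, seed: int = 42) -> None:
    # The full traversal must be a permutation of range(n).
    order = list(StatefulShuffledRange(n, seed))
    assert sorted(order) == list(range(n)), "not a permutation"

    # Resume-from-checkpoint determinism: stop halfway, snapshot, restore, compare.
    first = StatefulShuffledRange(n, seed)
    prefix = [next(first) for _ in range(n // 2)]
    ckpt = first.state_dict()

    resumed = StatefulShuffledRange(n, seed)
    resumed.load_state_dict(ckpt)
    assert prefix + list(resumed) == order, "resume is not deterministic"

validate()
```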