
A revisit on improving the performance of Data Loader #70

Open
lchu6 opened this issue Apr 5, 2024 · 2 comments

Comments

lchu6 (Contributor) commented Apr 5, 2024

We have been noticing a training slowdown introduced by our dataloader. On closer inspection, we traced it to our dataset class maintaining several very large lists.

Background

Each logical shard maintains a list of (dataset_id, shard_id, doc_id) tuples to track its documents. For example, ("c4", 3, 110) refers to the 110th document inside the file dataset_root_folder/dataset=c4/xxx.part3.arrow. When we distribute billions of documents over thousands of logical shard workers, each worker gets a list of millions of such tuples, so in total we are maintaining hundreds of GBs worth of lists internally.
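To make that scale concrete, here is a rough back-of-the-envelope sketch. It measures, on whatever CPython build runs it, the cost of one (dataset_id, shard_id, doc_id) tuple held in a list, then scales it to a corpus size. The example tuple and the 2-billion document count are illustrative assumptions, not numbers from our setup, and `approx_bytes_per_entry` is a hypothetical helper.

```python
import struct
import sys

POINTER_BYTES = struct.calcsize("P")  # one list slot holds a pointer to the tuple


def approx_bytes_per_entry(entry):
    """Rough CPython footprint of one (dataset_id, shard_id, doc_id) list entry.

    Counts the list slot, the tuple, and both int objects. The dataset_id string
    is assumed to be shared by all entries of a dataset, so it is not counted.
    CPython caches ints from -5 to 256, so small ids make this an overestimate.
    """
    _, shard_id, doc_id = entry
    return POINTER_BYTES + sys.getsizeof(entry) + sys.getsizeof(shard_id) + sys.getsizeof(doc_id)


per_entry = approx_bytes_per_entry(("c4", 3, 1_234_567))
num_docs = 2_000_000_000  # hypothetical corpus size ("billions of documents")
print(f"~{per_entry} bytes per entry, ~{per_entry * num_docs / 1e9:.0f} GB in total")
```

The exact figure depends on the interpreter version, but the point stands: every document costs on the order of a hundred bytes of Python objects, multiplied across the whole corpus.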

Why did we do this in the first place? The datasets are assumed to be unshuffled, so we need to shuffle our billions of (dataset_id, shard_id, doc_id) tuples, and each logical shard therefore maintains a shuffled list of millions of them. Such a list has to be materialized at some point (even with lazy initialization or something similar) for our dataloader to be stateful: we need to know, and checkpoint, exactly which documents have been visited, which are still to be visited, and in what order, so that we can resume training flawlessly and deterministically.
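For contrast, a minimal sketch of the list-based design described above. The class and method names are hypothetical, and `random.Random(seed).shuffle` stands in for whatever shuffle is actually used. The shuffled order lives in memory and a cursor is the checkpoint; restoring from the cursor is only meaningful once the same full list has been rebuilt, which is exactly the O(n) cost per logical shard we want to avoid.

```python
import random


class MaterializedShardIndex:
    """Hypothetical list-based shard index: O(n) memory per logical shard."""

    def __init__(self, doc_refs, seed=42):
        # doc_refs: (dataset_id, shard_id, doc_id) tuples owned by this shard.
        self.order = list(doc_refs)
        random.Random(seed).shuffle(self.order)  # the whole permutation is materialized
        self.cursor = 0  # number of documents handed out so far

    def next(self):
        # Wraps around, replaying the same shuffled order in every epoch.
        ref = self.order[self.cursor % len(self.order)]
        self.cursor += 1
        return ref

    def state_dict(self):
        # The cursor is small, but it only makes sense relative to self.order,
        # which must be rebuilt in full (same refs, same seed) before restoring.
        return {"cursor": self.cursor}

    def load_state_dict(self, state):
        self.cursor = state["cursor"]


refs = [("c4", 0, i) for i in range(100)]
index = MaterializedShardIndex(refs, seed=1)
seen = [index.next() for _ in range(30)]
ckpt = index.state_dict()
```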

Solution

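If we peel the onion completely, the question boils down to: how can we maintain a list that is truly stateful, supports reading in random order, and is easy to checkpoint and recover?
This leads us to a linear congruential generator (LCG), whose statefulness we can borrow to make the list itself stateful. The generator's entire state is a single integer, and with full-period parameters it visits every value in [0, period) exactly once per cycle, so it can stand in for a shuffled list of indices without ever materializing one.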
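For reference, the standard facts this relies on (the Hull-Dobell theorem for mixed LCGs, plus rejection of out-of-range values) can be summarized as follows; the notation is ours, with $N$ the list size and $N \le m$:

```latex
\begin{aligned}
&x_{n+1} = (a\,x_n + c) \bmod m \\
&\text{period} = m \iff \gcd(c, m) = 1,\quad
  p \mid (a-1)\ \text{for every prime } p \mid m,\quad
  4 \mid m \Rightarrow 4 \mid (a-1) \\
&\text{emit } x_n \text{ only if } x_n < N
  \;\Rightarrow\; \text{each period emits a permutation of } \{0,\dots,N-1\},
  \text{ at } m/N \text{ raw steps per output on average}
\end{aligned}
```

The average cost of $m/N$ raw steps per output is why the parameter table offers several moduli: picking the smallest one that covers $N$ keeps the rejection overhead bounded.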

A quick overview of the LCG we built for an arbitrary sized list:

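```python
# Parameters (m, a, c) for x -> (a * x + c) % m, from ZX81, cc65, Knuth and H. W. Lewis
LCG_PARAMS = [(2 ** 16 + 1, 75, 74), (2 ** 23, 65793, 4282663), (2 ** 32, 1664525, 1013904223)]


class LCG:
    """Pseudo-random permutation of range(size) whose whole state is one integer."""

    def __init__(self, size, seed=42):
        self.size = size
        for m, a, c in LCG_PARAMS:
            # The ZX81 modulus 2**16 + 1 is prime: m - 1 is a fixed point, so the
            # generator only cycles through the 2**16 values 0..m-2.
            period = m - 1 if m == 2 ** 16 + 1 else m
            if size <= period:
                self.m, self.a, self.c = m, a, c
                break
        else:
            raise ValueError(f"size={size} exceeds the largest supported period (2**32)")
        self.state = seed % self.m  # same output sequence as the raw seed
        if (self.a * self.state + self.c) % self.m == self.state:
            raise ValueError(f"seed={seed} is a fixed point of this generator")

    def _next(self):
        # One raw LCG step over [0, m).
        self.state = (self.a * self.state + self.c) % self.m
        return self.state

    def next(self):
        # Reject values >= size. One full period visits every value below the
        # period exactly once, so every `size` consecutive calls return each
        # index exactly once.
        while True:
            res = self._next()
            if res < self.size:
                return res
```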
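Because the generator's whole state is one integer, checkpointing reduces to saving that integer (plus a visit counter, if the loader needs to know when an epoch ends). A minimal sketch, assuming PyTorch-style `state_dict`/`load_state_dict` method names (hypothetical here, not part of the `LCG` class) and repeating the parameter table, with an explicit period column, so it runs on its own:

```python
# (m, a, c, period). The ZX81 entry (prime m = 2**16 + 1) only cycles through 0..m-2.
LCG_PARAMS = [
    (2 ** 16 + 1, 75, 74, 2 ** 16),
    (2 ** 23, 65793, 4282663, 2 ** 23),
    (2 ** 32, 1664525, 1013904223, 2 ** 32),
]


class StatefulLCG:
    """Permutation of range(size) whose checkpoint is two integers."""

    def __init__(self, size, seed=42):
        self.size = size
        self.visited = 0  # indices returned so far
        for m, a, c, period in LCG_PARAMS:
            if size <= period:
                self.m, self.a, self.c = m, a, c
                break
        else:
            raise ValueError(f"size={size} exceeds the largest supported period 2**32")
        self.state = seed % self.m
        if (self.a * self.state + self.c) % self.m == self.state:
            raise ValueError(f"seed={seed} is a fixed point of this generator")

    def next(self):
        # Step the LCG, rejecting values outside [0, size).
        while True:
            self.state = (self.a * self.state + self.c) % self.m
            if self.state < self.size:
                self.visited += 1
                return self.state

    def state_dict(self):
        return {"state": self.state, "visited": self.visited}

    def load_state_dict(self, sd):
        self.state, self.visited = sd["state"], sd["visited"]


# Usage: consume part of an epoch, checkpoint, then resume in a fresh object.
loader = StatefulLCG(1000, seed=7)
head = [loader.next() for _ in range(300)]
ckpt = loader.state_dict()
tail = [loader.next() for _ in range(700)]

resumed = StatefulLCG(1000, seed=7)
resumed.load_state_dict(ckpt)
assert [resumed.next() for _ in range(700)] == tail  # identical continuation
assert sorted(head + tail) == list(range(1000))      # one epoch is one permutation
```

Compared with a materialized list, the checkpoint no longer grows with the number of documents; how an index is mapped back to a (dataset_id, shard_id, doc_id) triple is a separate concern not shown here.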

And a quick validation that `next()` returns every index in `range(size)` exactly once:

```python
selector = LCG(1000000)
res = [selector.next() for _ in range(1000000)]
expected = list(range(1000000))
assert sorted(res) == expected
```
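As a complementary sanity check on the parameter table itself, a small sketch (the helper names `prime_factors`, `hull_dobell` and `cycle_length` are ours) tests each entry against the Hull-Dobell full-period conditions and brute-forces the ZX81 cycle. The ZX81 entry fails Hull-Dobell because its modulus 2**16 + 1 is prime; it is instead equivalent to the multiplicative generator y -> 75 * y mod 65537 with y = x + 1, and 75 is a primitive root mod 65537, which is why its cycle has length 2**16 and it only covers list sizes up to 2**16.

```python
from math import gcd

# (m, a, c) as in the parameter table above.
LCG_PARAMS = [(2 ** 16 + 1, 75, 74), (2 ** 23, 65793, 4282663), (2 ** 32, 1664525, 1013904223)]


def prime_factors(n):
    """Distinct prime factors of n by trial division (fine for n <= 2**32)."""
    factors, p = set(), 2
    while p * p <= n:
        while n % p == 0:
            factors.add(p)
            n //= p
        p += 1
    if n > 1:
        factors.add(n)
    return factors


def hull_dobell(m, a, c):
    """True iff x -> (a*x + c) % m has full period m (Hull-Dobell theorem)."""
    return (
        gcd(c, m) == 1
        and all((a - 1) % p == 0 for p in prime_factors(m))
        and (m % 4 != 0 or (a - 1) % 4 == 0)
    )


def cycle_length(m, a, c, seed):
    """Brute-force length of the cycle through seed (0 <= seed < m, small m only).

    Terminates because gcd(a, m) == 1 makes the map a bijection, so every
    state lies on a cycle.
    """
    x, steps = (a * seed + c) % m, 1
    while x != seed:
        x, steps = (a * x + c) % m, steps + 1
    return steps


for m, a, c in LCG_PARAMS:
    print(f"m = {m}: Hull-Dobell full period = {hull_dobell(m, a, c)}")
print("ZX81 cycle length from seed 42:", cycle_length(2 ** 16 + 1, 75, 74, seed=42))
```

This reports `False` for the ZX81 entry and `True` for the cc65 and Knuth/Lewis entries, and a ZX81 cycle length of 65536.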
lchu6 mentioned this issue Apr 5, 2024
lchu6 (Contributor, Author) commented Apr 5, 2024

@daviswer

This comment was marked as outdated.
