Rescalability via IBM dataset layers #1372
base: main
Conversation
""" | ||
|
||
|
||
def _shard_partition(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]: |
Are tail elements just truncated?
No, this code distributes extra elements as evenly as possible, even if it's not perfectly even. Technically nothing breaks if you load into a worldsize that doesn't divide logical_shards evenly; you just end up with some shards progressing faster than others (since some devices now have an extra logical shard).
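For reference, a minimal sketch of how a helper with this signature might spread the remainder across the lowest ranks. This is an illustrative reconstruction, not necessarily the exact implementation in the PR:

from typing import Any, List

def _shard_partition(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]:
    # Every rank gets len(itemlist) // worldsize items; the first
    # len(itemlist) % worldsize ranks each take one extra item, so no
    # element is dropped and counts differ by at most one.
    base, extra = divmod(len(itemlist), worldsize)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return itemlist[start:end]

For example, 10 items over 3 ranks gives slices of length 4, 3, 3.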
        # Setup / loading flags
        self.is_setup = False

    def setup(self):
This can be mapped pretty easily to BaseNode.reset()
Yeah I thought so too!
        [setattr(self, flag, state_dict[self.statename(flag)]) for flag in self.state_params]


class _WrapperDataset(_StatefulDataset):
thinking out loud: could we do this with mixins instead of extending the type hierarchy?
Actually, what's the benefit of having two subclasses?
The base reader is a _StatefulDataset but not a _WrapperDataset, so the distinction is meaningful. But yeah, the only reason it's not mixins is my lack of familiarity with building mixins!
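For what it's worth, a rough sketch of how a mixin-based split could look. The names StateMixin, WrapperMixin, and ShardLayer are hypothetical and not from this PR; this only illustrates the pattern:

from typing import Any, Dict, List


class StateMixin:
    # Hypothetical mixin: flat state_dict()/load_state_dict() over named flags.
    state_params: List[str] = []

    def state_dict(self) -> Dict[str, Any]:
        return {flag: getattr(self, flag) for flag in self.state_params}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        for flag in self.state_params:
            setattr(self, flag, state[flag])


class WrapperMixin:
    # Hypothetical mixin: holds sub-datasets to delegate to.
    def __init__(self, datasets: List[Any]) -> None:
        self.data = datasets


class ShardLayer(StateMixin, WrapperMixin):
    # A wrapper layer composes both behaviors; a leaf reader would use StateMixin only.
    state_params = ["current_reader"]

    def __init__(self, datasets: List[Any]) -> None:
        super().__init__(datasets)
        self.current_reader = 0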
        while True:
            ind = self.current_reader
            # Read doc
            out = next(data[ind])
How is StopIteration handled?
It's not; in this framework we just assume each iterator loops forever. Converting to a next()-based framework would make this pretty easy to handle, though.
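For illustration, a self-contained sketch of what next()-based StopIteration handling could look like; round_robin is a hypothetical helper, not part of this PR:

from typing import Any, Iterable, Iterator, List


def round_robin(sources: List[Iterable[Any]]) -> Iterator[Any]:
    # Cycle over child datasets; when one raises StopIteration, restart it
    # (i.e. begin its next epoch) instead of assuming it loops forever.
    # Sources are assumed non-empty and re-iterable (e.g. lists or dataset objects).
    iterators = [iter(s) for s in sources]
    ind = 0
    while True:
        try:
            yield next(iterators[ind])
        except StopIteration:
            iterators[ind] = iter(sources[ind])
            yield next(iterators[ind])
        ind = (ind + 1) % len(sources)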
    # Convert to tensor form
    out = {}
    for k, v in state_dict.items():
        v = torch.tensor(v)
        if len(v.shape) == 0:
            k = k + ".scalar"
            v = v.unsqueeze(0)
        out[k] = v
Is this done to satisfy DCP requirements?
Yes, dtensors have to have at least one dimension
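A sketch of the round trip implied here, reconstructed from the snippet above plus a plausible inverse for load time (not necessarily the exact PR code):

from typing import Any, Dict

import torch


def to_tensor_form(state_dict: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    # DTensors need at least one dimension, so promote 0-d scalars to shape (1,)
    # and tag the key so the promotion can be undone on load.
    out = {}
    for k, v in state_dict.items():
        v = torch.tensor(v)
        if v.dim() == 0:
            k = k + ".scalar"
            v = v.unsqueeze(0)
        out[k] = v
    return out


def from_tensor_form(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    # Inverse: strip the ".scalar" tag and squeeze back to a plain scalar.
    out = {}
    for k, v in state_dict.items():
        if k.endswith(".scalar"):
            out[k[: -len(".scalar")]] = v.squeeze(0).item()
        else:
            out[k] = v
    return out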
#### ------------------------- CHECKPOINT FUNCTIONS ------------------------- ####


def __pop_dstate(state, device_mesh, placements):
We should create standard utilities to get these in torchdata #1337
            self.current_reader = (self.current_reader + 1) % self.n_logicals
            yield out

    def state_dict(self):
{
    "my_children": [c.state_dict() for c in self.children],
    "scalar_state": self.scalar,  # e.g. "my_string"
    "my_reshardable_state": torch.tensor([1, 2, 3, 4, 5]),  # reshardable tensor state
}
Question: what happens if the above state_dict gets passed to DCP?
Answer: it will fail because torch.tensor gets called on everything?
Andrew to follow up with @pradeepfn on this.
        assert len(logical_shard_states) > 0, f"Worker {self.rank} owns no shards???"
        # Flip list[dict[Any]] to dict[list[Any]]
        state_dict = {k: [d[k] for d in logical_shard_states] for k in logical_shard_states[0].keys()}
        state_dict.update(_StatefulDataset.state_dict(self))
Does self.current_reader need to be stored too? I.e., for determinism in the case where resharding doesn't happen at all.
Yes, self.current_reader is stored here so that when we don't rescale, we'll maintain the same data ordering.
        self.generator.set_state(torch.tensor(self.g_state, dtype=torch.uint8))


class ScalableShardDataset(_WrapperDataset):
Suggestion: make this Protocol-based instead, i.e. replace
class ScalableShardDataset(_WrapperDataset):
with
class ScalableShardDataset(BaseNode[T], Reshardable):
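For reference, a rough sketch of what a Protocol-based Reshardable contract might look like; the Reshardable name and its methods here are hypothetical, not defined in this PR or in torchdata:

from typing import Any, Dict, List, Protocol, runtime_checkable


@runtime_checkable
class Reshardable(Protocol):
    # Hypothetical protocol: a node whose state can be split into per-logical-shard
    # pieces and regrouped under a different world size.

    def reshardable_state(self) -> List[Dict[str, Any]]:
        """Return one state dict per logical shard owned by this worker."""
        ...

    def load_resharded_state(self, shard_states: List[Dict[str, Any]]) -> None:
        """Accept a (possibly re-partitioned) list of logical-shard states."""
        ...

Any class that structurally implements these methods satisfies the protocol without inheriting from it, which is the main appeal over deepening the class hierarchy.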
        data = [iter(d) for d in self.data]
        while True:
            ind = self.current_reader
            # Read doc
            out = next(data[ind])
Can we bound the number of open iterators/file pointers/etc. in some way here, while still maintaining re-shardability?
We'd possibly run into the end-of-epoch problem.
Can we remove the assumption of indexable files?
We need indexable files in order to support fractional file ownership between logical shards. I don't see a way around that one if we want to maintain rough epoch boundaries. We can bound the number of open files, though.

We'd discussed how the reshard layer makes sense either at the very top of the pipeline (as here) or at the very bottom, just above the base reader (as in the older version). If we move the reshard layer to the very bottom, and maybe merge it with the base reader(s), we can allow base readers to share file pointers. Then instead of partitioning all data across shards based on file size, we can split every data file over all the workers. That way all logical shards on the same device will share a single file pointer, and assuming they all finish at roughly the same time, we'll end up maintaining no more than two open files per dataset per physical worker (instead of one open file per dataset per logical worker, as here). Total pulls per file goes up if we do this, but I think it's still a good tradeoff.
        logical_shard_states = [d.state_dict() for d in self.data]
        assert len(logical_shard_states) > 0, f"Worker {self.rank} owns no shards???"
        # Flip list[dict[Any]] to dict[list[Any]]
        state_dict = {k: [d[k] for d in logical_shard_states] for k in logical_shard_states[0].keys()}
qq: is len(logical_shard_states) always 1? Looking at the list comprehension, it seems so, but I do not understand why. Thanks.
Update: I think I got it. The keys are the same across sub-datasets, so we use logical_shard_states[0].keys() as the anchor?
Yes exactly!
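A tiny illustration of that flip, with toy keys and values (not from the PR), assuming every logical shard reports the same keys:

# Two logical shards owned by this worker, each reporting the same keys
logical_shard_states = [
    {"current_doc": 3, "epoch": 0},
    {"current_doc": 7, "epoch": 0},
]

# Flip list[dict] -> dict[list], using shard 0's keys as the anchor
state_dict = {
    k: [d[k] for d in logical_shard_states]
    for k in logical_shard_states[0].keys()
}
# -> {'current_doc': [3, 7], 'epoch': [0, 0]}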
        writer,
    )
    # Write nondistributed state dict
    torch.save(state, os.path.join(path, f"__nondist_cp_{rank}.pth"))
Shall we also make this state part of the main checkpoint? We can use torch.save serialization (the output is a bytestream) to store the state as part of the DCP checkpoint:

buff = io.BytesIO()
torch.save(state, buff)  # or we could serialize individual keys in the state dict, but there's no strong need
buff.seek(0)
# Assume the unique key is 'trainer_dataloader_state_rank_k';
# update dstate with that new key -> value, then:
checkpoint.save(dstate)
Some context on how DCP handles regular (non-dtensor) tensors and non-tensor values found in a state dict:
1. Regular tensors (not distributed tensors) are saved against a unique fqn. The state dict can contain, e.g., { 'foo.bar.1': torch.ones((2, 3), dtype=float) }. As long as the fqns do not collide between different ranks, we are good.
2. DCP can save/load non-tensor values found in the state dict as well. That means a value can be a string, a list of ints, or a BytesIO object. Under the hood, DCP uses torch.save/torch.load to serialize/deserialize these non-tensor values (we call them blobs) during checkpoint save/load. E.g. UT -> https://github.com/pytorch/pytorch/blob/11bb94b7eaf272da8e1f1dfd94c3d8872247e895/test/distributed/checkpoint/test_file_system_checkpoint_cpu.py#L367
Therefore, we can use the dataloader state-dict values as-is and save them as part of the DCP checkpoint.
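A rough sketch of that idea (dcp.save and FileSystemWriter are existing torch.distributed.checkpoint APIs, but the keys, values, and path below are made up, and process-group setup and error handling are omitted):

import io

import torch
import torch.distributed.checkpoint as dcp

# Pack the non-distributed dataloader state into a byte blob via torch.save
dataloader_state = {"current_reader": 2, "files_seen": ["a.arrow", "b.arrow"]}
buff = io.BytesIO()
torch.save(dataloader_state, buff)
buff.seek(0)

state = {
    "reshardable_positions": torch.tensor([1, 2, 3, 4, 5]),  # saved per-fqn as a regular tensor
    "trainer_dataloader_state_rank_0": buff,                 # saved as a non-tensor blob
}

dcp.save(state, storage_writer=dcp.FileSystemWriter("/tmp/ckpt"))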
Oh good to know! How are non-dtensor values like these handled during rescaling? I guess if all fqns are unique, they can all be broadcast over all ranks?
i.e. with worldsize 2, if rank 0 saves "field_0" in its checkpoint shard and rank 1 saves "field_1" in its checkpoint shard, what happens when we load into worldsize 4?
    ckp_ws = 0 if not os.path.exists(path) else len([x for x in os.listdir(path) if "__nondist_cp_" in x])
    # Check that number of loaders matches
    if ckp_ws == loader.dataset.worldsize:
        state = torch.load(os.path.join(path, f"__nondist_cp_{rank}.pth"))
Noob q: what are we missing out on if we just set data_loader_state = base, without considering the rescaling property of the training run?
You mean if we always set state = base, and just drop the non-reshardable state every time? Nothing breaks; we just lose the ability to maintain the original data ordering when we're not rescaling, because every local worker index resets to zero.
Implements rescaling of checkpoints to different world sizes and numbers of workers. The user specifies the number of data partitions in advance, and when saving/loading checkpoints with a different number of total workers (which must divide the partition count evenly), stateful guarantees are maintained: seen data is not revisited until the next epoch.
Based on the datasets in the corresponding IBM torchtitan PR, but uses StatefulDataLoader and DCP to manage checkpointing from the master process. Sampling and Dummy datasets are included for demo purposes. It is possible that the IBM datasets can be merged into the existing node structure.
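For example (illustrative numbers, not from the PR, and assuming "total workers" means world size times dataloader workers per rank): with 24 logical shards, any configuration whose total worker count divides 24 can regroup the shard states on load without splitting any single shard's state:

n_logical_shards = 24  # chosen up front by the user (illustrative value)

# Total workers = world size * dataloader workers per rank
for world_size, num_workers in [(2, 4), (4, 3), (8, 1), (6, 2)]:
    total_workers = world_size * num_workers
    assert n_logical_shards % total_workers == 0
    print(f"{total_workers} workers -> {n_logical_shards // total_workers} logical shards each")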
Changes
torchdata/stateful_dataloader/ibm_rescalable.py
examples/ibm_rescaling/rescaling_demo.py