Managing Large Datasets
Distributed Data Parallel (DDP) is a data loading technique for multi-host training on large datasets: each host loads only its own portion of the data.
When to use?
- You are running multi-host training and your dataset is too large to fit in memory on a single host.
- Large-batch processing is the bottleneck in your training pipeline.
How it works?
- Split your dataset into N shards, where N is the number of hosts. Each shard has the same number of samples.
Tip
Use ray.data.Dataset.split() on your Ray dataset, or ray.data.Dataset.streaming_split() for streamed datasets.
- Each host loads and processes data only from its assigned shard.
- For example, with a 32,000-sample dataset split across 4 hosts, each host manages 8,000 samples instead of the full dataset.
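The sharding step above can be illustrated with a minimal pure-Python sketch. It is not Kithara or Ray code: `shard_evenly` is a hypothetical helper written only for this illustration, and it assumes that samples which do not divide evenly are dropped so that every host receives the same count.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_evenly(samples: Sequence[T], num_hosts: int) -> List[Sequence[T]]:
    """Split samples into num_hosts contiguous shards of equal length.

    Hypothetical helper for illustration only. Trailing samples that do not
    divide evenly are dropped so every host sees the same number of samples.
    """
    if num_hosts <= 0:
        raise ValueError("num_hosts must be positive")
    shard_size = len(samples) // num_hosts
    return [
        samples[i * shard_size : (i + 1) * shard_size]
        for i in range(num_hosts)
    ]


dataset = range(32_000)  # stand-in for a real dataset
shards = shard_evenly(dataset, num_hosts=4)
shard_sizes = [len(shard) for shard in shards]
print(shard_sizes)  # [8000, 8000, 8000, 8000]
```

Host `i` would then read only `shards[i]`, which is why no single host has to hold all 32,000 samples.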
Example DDP Workflow
Here's an example of how to implement DDP in Kithara:
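from typing import Any, List

import ray

# num_chips_per_host, num_tpu_hosts, ray_dataset and split_dataset are
# assumed to be defined earlier in your script.

@ray.remote(resources={"TPU": num_chips_per_host})
def main(train_ds):
    ...
    dataloader = Kithara.Dataloader(
        train_ds,
        per_device_batch_size=1,  # or any other batch size
        dataset_is_sharded_per_host=True,  # Enable DDP
    )
    ...
    trainer.train()

# One shard per host
train_ds: List[Any] = split_dataset(ray_dataset, num_hosts=num_tpu_hosts)

# Launch one task per host, each with its own shard, and wait for all to finish
ray.get(
    [
        main.remote(train_ds[i]) for i in range(num_tpu_hosts)
    ]
)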
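To get a feel for how `per_device_batch_size` interacts with per-host shards, the sketch below works through the arithmetic. It rests on assumptions that are not stated above: that each host draws `per_device_batch_size` samples for each of its local devices on every step, and that each host has 4 devices. The function name `steps_per_epoch` is hypothetical and is not part of Kithara.

```python
def steps_per_epoch(
    shard_size: int, per_device_batch_size: int, devices_per_host: int
) -> int:
    """Number of full batches one host can draw from its shard.

    Assumes (hypothetically) that a host consumes
    per_device_batch_size * devices_per_host samples per step.
    """
    per_host_batch = per_device_batch_size * devices_per_host
    return shard_size // per_host_batch


num_hosts = 4              # from the example above
shard_size = 32_000 // 4   # 8,000 samples per host
devices_per_host = 4       # assumed value, for illustration only

steps = steps_per_epoch(shard_size, per_device_batch_size=1,
                        devices_per_host=devices_per_host)
global_batch = 1 * devices_per_host * num_hosts
print(steps, global_batch)  # 2000 16
```

Under these assumptions, every host completes the same number of steps per epoch, which is the reason for giving each shard the same number of samples.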