Training modern ML models at scale always looks simple on slides: add more GPUs, increase batch size, call it “distributed,” and enjoy faster results. In reality, scaling introduces as many problems as it solves—data bottlenecks, process crashes, manual orchestration scripts, and infrastructure that becomes harder to reason about over time.
I joined the High-Performance and Robust Model Training With PyTorch and Ray Train workshop because I wanted to see whether Ray’s training abstractions could simplify the parts of distributed training that typically turn into glue code: cluster setup, DDP orchestration, checkpoint handling, and fault recovery. I already use Ray in other parts of my workflow (for data processing and multi-agent tasks), so testing it as a training layer felt like a natural next step.
This is the core pipeline we worked with during the session—Ray handles orchestration, worker lifecycle, dataset sharding, and distributed execution, while the PyTorch Lightning training loop stays mostly intact.
The session was built around one theme: keeping the training code nearly identical, while letting Ray manage everything that normally requires custom scripts or vendor-specific tooling. Topics included:
Running the same training loop on a laptop or a multi-GPU cluster
Ray Data for efficient ingestion and preprocessing across workers
How Ray Train integrates with PyTorch Lightning without “framework lock-in”
Fault-tolerant execution (workers can die and resume mid-epoch)
Where Ray replaces vs. reuses your existing code
Instead of rewriting your training logic, you introduce Ray in a few structured places: scaling config, trainer strategy, run config, and the per-worker loop.
Setting up Ray Train
Distributed configuration is handled through ScalingConfig, which controls how many workers are launched and whether they use GPUs:
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)
That one object determines how far the same training script scales—local CPU, single GPU, or multi-GPU cluster—without changing the rest of the code.
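To make that concrete, here's a small sketch (with illustrative worker counts of my own, not the workshop's exact settings) of how the same script is pointed at different hardware by swapping only this object:

import ray.train

# Local smoke test on CPU workers (illustrative values)
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=False)

# The same script on an 8-GPU cluster: only this object changes
scaling_config = ray.train.ScalingConfig(num_workers=8, use_gpu=True)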
Integrating Ray with an existing Lightning trainer
Here’s the same Lightning trainer I had before, but wrapped with Ray’s distributed strategy and environment:
import pytorch_lightning as pl
from ray.train.lightning import RayDDPStrategy, RayLightningEnvironment

trainer = pl.Trainer(
    max_steps=config["max_steps"],
    max_epochs=config["max_epochs"],
    accelerator="gpu",
    precision="bf16-mixed",
    devices="auto",
    strategy=RayDDPStrategy(),            # Ray-aware replacement for Lightning's default DDP strategy
    plugins=[RayLightningEnvironment()],  # picks up the cluster environment Ray sets up for each worker
    enable_checkpointing=False,           # artifacts are managed through Ray's RunConfig instead
)
trainer.fit(model, train_dataloaders=train_dataloader)
No extra torch.distributed.launch scripts, no environment variable setup, no explicit init-process-group calls. Ray handles the backend wiring, but Lightning stays in control of the training logic.
Defining where training state and artifacts live
Ray uses a RunConfig to define experiment metadata, logging, and checkpoints:
storage_path = "/mnt/cluster_storage/"
experiment_name = "stable-diffusion-pretraining"
run_config = ray.train.RunConfig(
    name=experiment_name,
    storage_path=storage_path,
)
This removes the usual “log folder sprawl” that happens when running multiple distributed experiments manually.
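Fault tolerance and checkpoint retention live in the same place. As a rough sketch, with values I picked for illustration rather than the workshop's settings, the RunConfig above can be extended like this:

run_config = ray.train.RunConfig(
    name=experiment_name,
    storage_path=storage_path,
    # Keep only the most recent checkpoints on shared storage (illustrative value)
    checkpoint_config=ray.train.CheckpointConfig(num_to_keep=2),
    # Automatically retry the run if workers crash (illustrative value)
    failure_config=ray.train.FailureConfig(max_failures=3),
)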
Worker-level training loop and dataset sharding
Ray exposes dataset shards directly in the worker loop, avoiding manual slicing or DistributedSampler setup:
def train_loop_per_worker(config: dict):
    # Each worker pulls only its own shard of the dataset registered as "train"
    train_ds = ray.train.get_dataset_shard("train")
Each worker receives only the data it needs, trains independently, and synchronizes gradients through Ray’s DDP implementation.
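To show how the pieces from the previous sections fit together, here's a condensed sketch that expands the snippet above, assuming a TorchTrainer entry point; the hyperparameter values are placeholders, and the model and pl.Trainer construction are the ones shown earlier, elided here:

import ray.train
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config: dict):
    # Each worker streams its shard of the "train" dataset as torch batches
    train_ds = ray.train.get_dataset_shard("train")
    train_dataloader = train_ds.iter_torch_batches(batch_size=config["batch_size"])
    model = ...    # the LightningModule, built exactly as before
    trainer = ...  # the pl.Trainer with RayDDPStrategy shown above
    trainer.fit(model, train_dataloaders=train_dataloader)

# The configs from the earlier sections come together in one trainer object
trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={"batch_size": 64, "max_steps": 1000, "max_epochs": 1},
    scaling_config=scaling_config,
    run_config=run_config,
    datasets={"train": train_dataset},  # a Ray Dataset created beforehand
)
result = trainer.fit()

Calling trainer.fit() launches the workers defined by the ScalingConfig and hands each one its shard of the registered dataset.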
One of the best parts of the workshop was realizing what didn’t change:
Before Ray → After Ray
Custom multi-GPU launcher scripts → Scaling defined in one line (num_workers)
Manual distributed sampler setup → Sharded data handled by Ray Data
Restarts required re-running entire job → Fault tolerance built in automatically
Infra knowledge embedded inside training loop → Trainer code remains almost identical
Ray didn’t replace Lightning—it only replaced the parts of the system I didn’t want to maintain myself.
The session demonstrated how Ray Train allows a single-GPU workload to scale horizontally without altering the training loop. That applies whether you move from 1 → 2 GPUs, 2 → 8, or from a single machine to a cluster. The important shift is not “more GPUs = faster,” but that the scaling boundary moves out of the model code and into configuration.
The takeaway wasn't about chasing perfect linear speedups—it was about reducing the engineering cost of scaling, so performance tuning becomes iterative rather than architectural.
Most of my existing PyTorch Lightning code migrated with small, targeted changes.
Ray Train handles orchestration, not modeling—so you’re not tied to Ray-specific APIs.
The debugging experience improves because Ray gives you visibility into workers, datasets, system load, and retries.
Dataset ingestion and training are now part of a single pipeline instead of two separate systems.
Scaling strategy is now config-driven, not script-driven.
How I expect to use this going forward
The most immediate fit is integrating Ray Train into workflows that already rely on Ray for data preparation or task orchestration. Instead of handing off preprocessed outputs to a separate training system, data and training can share the same runtime, hardware pool, and failure model.
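A rough sketch of what I have in mind, assuming a hypothetical Parquet source and a user-defined preprocess_batch function (neither is from the workshop):

import ray.data
from ray.train.torch import TorchTrainer

# Ingestion and preprocessing run on the same Ray cluster as training
train_dataset = (
    ray.data.read_parquet("s3://bucket/training-data/")  # placeholder path
    .map_batches(preprocess_batch)                        # assumed user-defined function
)

# Handed straight to the trainer from earlier; no export/re-import step in between
trainer = TorchTrainer(train_loop_per_worker, datasets={"train": train_dataset},
                       scaling_config=scaling_config, run_config=run_config)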
I also like that it doesn’t trap you in a “Ray-only” world—if the day comes when training needs to run outside Ray, the underlying Lightning code remains portable.
If you already have a working single-GPU Lightning training setup and want to scale it without rebuilding the surrounding infrastructure, Ray Train offers a pragmatic upgrade path. It doesn’t hide distributed training, but it removes the boilerplate that traditionally surrounds it.
The workshop was a useful way to see these abstractions in practice, not in marketing diagrams. I left with a clearer sense of where Ray fits into the ML stack—and how to adopt it incrementally instead of all at once.