Deep learning at scale: lessons from experiments on Perlmutter
Nov 17, 2025
| System | JUWELS (JSC) | Perlmutter (NERSC) | JUPITER (JSC, EuroHPC) | Colossus (xAI) |
|---|---|---|---|---|
| Year | 2018 / 2020 | 2021 | 2024–2025 | 2024–2025 |
| Purpose | Petascale HPC | HPC + GPU acceleration | Exascale HPC + AI | Large-scale AI training |
| Compute | Xeon + EPYC/A100 | EPYC + 4× A100 (GPU nodes) | ~6000 GH200 nodes (4× GH200) | 100k–200k H100/H200 GPUs |
| GPU Count | 3,744 A100 GPUs | ~7,168 A100 GPUs | ~24,000 GH200 GPUs | 100k+ Hopper-class GPUs |
| CPU | Xeon / EPYC | AMD EPYC 7763 | Grace (4×72 cores/node) | Minimal (GPU-centric) |
| Peak Perf | ~70 PFLOP/s | ~70 PFLOP/s (GPU partition) | ~1 EFLOP/s (FP64) | AI-precision PFLOPs (not FP64) |
| Network | HDR InfiniBand | Slingshot 11 | NDR200 InfiniBand | Spectrum-X Ethernet |
| Power | ~10–20 MW | ~20 MW | ~18 MW | 150–300 MW |
| Use-case | Scientific HPC | HPC + ML/AI | HPC + AI hybrid | LLM training |
| Cost | ~€145M | ~$600M | ~€500M | ~$10–15B |
The problem
Figure 2: (a) The multi-layer transformer architecture that utilizes the Adaptive Fourier Neural Operator with shared MLP and frequency soft-thresholding for spatial token mixing. The input frame is first divided into an h × w grid of patches, where each patch has a small size p × p × c. Each patch is then embedded in a higher-dimensional space with a large number of latent channels, and a position embedding is added to form a sequence of tokens. Tokens are then mixed spatially using AFNO, and subsequently the latent channels of each token are mixed. This process is repeated for L layers, and finally a linear decoder reconstructs the patches for the next frame from the final embedding. The right-hand panels describe the FourCastNet model's additional training and inference modes: (b) two-step fine-tuning, (c) backbone model that forecasts the 20 variables in Table 1 with a secondary precipitation diagnostic model (note that p(k + 1) denotes the 6-hour accumulated total precipitation that falls between the k + 1 and k + 2 time steps), (d) forecast model in free-running autoregressive inference mode. For details, see Pathak et al. (2022), FourCastNet: A Global Data-driven High-resolution Weather Model using Adaptive Fourier Neural Operators, arXiv:2202.11214.
Dataset
ERA5 is more than just a climate dataset: it is a cornerstone for large-scale climate research and AI. At approximately 0.25° spatial resolution, every global surface field comprises over a million grid points. Factoring in 100+ vertical levels, that is hundreds of millions of data values per hour. Spanning 24 hours a day, 365 days a year, for more than 80 years (starting from 1940), ERA5 contains trillions of data points, amounting to multiple petabytes of compressed information. This data volume lets researchers investigate global and regional weather patterns, train high-capacity machine learning models, and reconstruct the hour-by-hour evolution of Earth's climate from a uniquely data-rich perspective.
Special thanks and credit to the following contributors for the slides and code underpinning this work: Shashank Subramanian, Steven Farrell, Josh Romero, Thorsten Kurth, and Corneel Casert. Resources are available in the NERSC repo.
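To put the ERA5 numbers above in perspective, here is a rough back-of-envelope estimate. The 1440 × 721 grid at 0.25° and the 137 model levels are standard ERA5 values; the year count and bytes-per-value are illustrative assumptions, not exact archive figures.
# Rough ERA5 data-volume estimate (illustrative; the real archive size depends
# on which variables are stored and on compression).
lat, lon = 721, 1440                            # 0.25-degree global grid
levels = 137                                    # "100+ vertical levels"
hours = 24 * 365 * 85                           # roughly 85 years, 1940 to today
points_per_field = lat * lon                    # ~1.04 million grid points
values_per_hour = points_per_field * levels     # ~1.4e8 values per 3D variable
values_total = values_per_hour * hours          # ~1.1e14 per 3D variable
print(f"{points_per_field:,} grid points per surface field")
print(f"{values_per_hour:.2e} values per hour for one 3D variable")
print(f"{values_total:.2e} values over the full record for one 3D variable")
# Multiplying by tens of variables and ~4 bytes per value lands in the petabyte range.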
Tue Nov 18 19:28:52 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.86.15 Driver Version: 570.86.15 CUDA Version: 12.9 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA H100 NVL On | 00000001:00:00.0 Off | 0 |
| N/A 39C P0 64W / 400W | 1MiB / 95830MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA H100 NVL On | 00000002:00:00.0 Off | 0 |
| N/A 39C P0 67W / 400W | 1MiB / 95830MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
import torch.nn.functional as F
import torch
import torch.nn as nn
from functools import partial
from networks.helpers import DropPath, trunc_normal_
# mp stuff
from utils import comm
from distributed.layers import (
DistributedMatmul,
DistributedMLP,
DistributedAttention,
DistributedLayerNorm,
)
from distributed.helpers import compute_split_shapes
from distributed.mappings import scatter_to_parallel_region, gather_from_parallel_region
class MLP(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.fused_attn = True
# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.k = nn.Linear(dim, dim, bias=qkv_bias)
self.v = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
k = self.k(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
v = self.v(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
cp_shapes=None,
):
super().__init__()
mlp_hidden_dim = int(dim * mlp_ratio)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
if (comm.get_size("tp-cp")) > 1:
# model parallelism is on, distribute the layers
# tp: tensor parallel shards the weights
# cp: context parallel shards the sequence
self.norm1 = DistributedLayerNorm(dim, comm_tp_name="tp", comm_cp_name="cp")
self.norm2 = DistributedLayerNorm(dim, comm_tp_name="tp", comm_cp_name="cp")
self.attn = DistributedAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
comm_tp_name="tp",
comm_cp_name="cp",
cp_shapes=cp_shapes,
)
self.mlp = DistributedMLP(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
comm_tp_name="tp",
comm_cp_name="cp",
)
else:
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.mlp = MLP(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
def forward(self, x):
y = self.attn(self.norm1(x))
x = x + self.drop_path(y)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
"""Image to Patch Embedding"""
def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
# grid of patches
self.h = img_size[0] // patch_size
self.w = img_size[1] // patch_size
num_patches = self.h * self.w
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class VisionTransformer(nn.Module):
"""
Vision Transformer Architecture — Structure Diagram (Left-to-Right, Detailed)
+-----------------+ +-----------------+ +-------------------------------+ +-----------------+ +-----------------+
| Input Image |-->| PatchEmbed |-->| + Positional Embedding |-->| Dropout |-->| Scatter |
| (B, C, H, W) | | Conv2d splits | | (learned, per token) | | (pos_drop) | | to context |
| | | into patches | | | | | | parallel |
+-----------------+ +-----------------+ +-------------------------------+ +-----------------+ +-------+---------+
|
v
----------------------------------------------- Transformer Encoder Blocks ---------------------------------------+-------------
Repeats "depth" times: (for i = 1 to depth)
+--------------+ +------------+ +--------------+ +------------+ +-------------+ +------------+ +-------------+
| Input |-->| Norm1 |-->| Multi-Head |-->| DropPath |-->| Residual |-->| Norm2 |-->| MLP |
| (B, N, D) | |(Layer/Dist)| | Attention | |(stoch. | | Connection | | | | (2 layers, |
| | | | | (LN -> MHSA | | depth, | | (+Input) | | | | GELU, |
| | | | | QKV-Proj) | | optional) | | | | | | Dropout) |
+--------------+ +------------+ +--------------+ +------------+ +------+------+ +------------+ +------|------+
^ |
| v
| +--------------+------+
|------------------>| DropPath |
+---------------------+
--------------+------------------------------------------------------------------------------------------------------------------
|
V
+-------------+-------------+ +------------------------------+ +-------------------------------+ +------------------------------+
| LayerNorm / |-->| Gather from context |-->| Head: Linear projection |-->| Reshape & Rearrangement |
| DistributedLayerNorm | | parallel (if needed) | | to patch×patch×channels | | (to output image shape) |
+---------------------------+ +------------------------------+ +-------------------------------+ +------------------------------+
Key Properties:
- patch_embed: Converts image to sequence of patch tokens
- pos_embed: Learnable position embeddings per token
- blocks: N repeated Transformer blocks with LayerNorm, MHSA, MLP, DropPath and residual
- norm: Final normalization; distributed option if needed
- head: Linear projection from embedding space to reconstructed output
- All computation left-to-right (data flow shown above)
Notation:
B = batch size, C = in channels, H,W = input dims, N = num patches, D = embed_dim
"""
def __init__(
self,
img_size=[224, 224],
patch_size=16,
in_chans=3,
out_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=False,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
norm_layer=nn.LayerNorm,
**kwargs
):
super().__init__()
self.num_features = self.embed_dim = embed_dim
self.patch_size = patch_size
self.img_size = img_size
self.out_ch = out_chans
self.drop_rate = drop_rate
# ─── Patch Embedding ───────────────────────────────────────────────────────────
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.embed_dim,
)
num_patches = self.patch_embed.num_patches
# ─── Parallel Shapes for Sequence/Context ──────────────────────────────────────
self.cp_shapes = compute_split_shapes(num_patches, comm.get_size("cp"))
# ─── Learnable Positional Embedding ────────────────────────────────────────────
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
# ─── Stochastic Depth/DropPath Scheduling ──────────────────────────────────────
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
]
# ─── Stacked Transformer Blocks ────────────────────────────────────────────────
self.blocks = nn.ModuleList(
[
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
cp_shapes=self.cp_shapes,
)
for i in range(depth)
]
)
# ─── Final LayerNorm (possibly distributed) ────────────────────────────────────
if (comm.get_size("tp-cp")) > 1:
self.norm = DistributedLayerNorm(embed_dim, comm_tp_name="tp", comm_cp_name="cp")
else:
self.norm = norm_layer(embed_dim)
# ─── Head: Linear map to patch × patch × channel output ────────────────────────
self.out_size = self.out_ch * self.patch_size * self.patch_size
self.head = nn.Linear(embed_dim, self.out_size, bias=False)
# ─── Initialization ────────────────────────────────────────────────────────────
trunc_normal_(self.pos_embed, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def prepare_tokens(self, x):
B, nc, w, h = x.shape
x = self.patch_embed(x) # patch linear embedding
# add positional encoding to each token
x = x + self.pos_embed
return self.pos_drop(x)
def forward_head(self, x):
B, _, _ = x.shape # B x N x embed_dim
x = x.reshape(B, self.patch_embed.h, self.patch_embed.w, self.embed_dim)
B, h, w, _ = x.shape
# apply head
x = self.head(x)
x = x.reshape(shape=(B, h, w, self.patch_size, self.patch_size, self.out_ch))
x = torch.einsum("nhwpqc->nchpwq", x)
x = x.reshape(shape=(B, self.out_ch, self.img_size[0], self.img_size[1]))
return x
def forward(self, x):
x = self.prepare_tokens(x)
# split sequence if cp is on (shape of x is (batch, seq, embed))
x = scatter_to_parallel_region(x, dim=1, comm_name="cp")
# if cp is on, each block operates on a sequence shard
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
# gather sequence if cp is on
x = gather_from_parallel_region(x, dim=1, shapes=self.cp_shapes, comm_name="cp")
x = self.forward_head(x)
return x
def ViT(params, **kwargs):
model = VisionTransformer(
img_size=tuple(params.img_size),
in_chans=params.n_in_channels,
out_chans=params.n_out_channels,
patch_size=params.patch_size,
embed_dim=params.embed_dim,
depth=params.depth,
num_heads=params.num_heads,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
drop_path_rate=float(params.dropout),
drop_rate=float(params.dropout),
attn_drop_rate=float(params.dropout),
**kwargs
)
return model
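Before moving on to the data loader, here is a minimal, self-contained walkthrough of the same patchify / encode / reconstruct shape bookkeeping using only stock PyTorch modules. The sizes follow the "short" config used below (360 × 720 inputs, 20 channels, patch size 8, embed_dim 384), and nn.TransformerEncoderLayer stands in for the Block class above; this is a sketch, not the distributed code path.
# Sketch: trace tensor shapes through a ViT-style patchify -> encode -> reconstruct
# pipeline with plain PyTorch modules.
import torch
import torch.nn as nn

B, C, H, W, p, D = 1, 20, 360, 720, 8, 384
x = torch.randn(B, C, H, W)                                 # one ERA5-like frame

patch_embed = nn.Conv2d(C, D, kernel_size=p, stride=p)      # as in PatchEmbed
tokens = patch_embed(x).flatten(2).transpose(1, 2)          # (B, N, D), N = 45 * 90 = 4050

encoder = nn.TransformerEncoderLayer(
    d_model=D, nhead=8, dim_feedforward=4 * D, batch_first=True, norm_first=True
)
tokens = encoder(tokens)                                    # (B, N, D), pre-norm attention + MLP

head = nn.Linear(D, p * p * C, bias=False)                  # as in forward_head
out = head(tokens).reshape(B, H // p, W // p, p, p, C)
out = torch.einsum("nhwpqc->nchpwq", out).reshape(B, C, H, W)
print(tokens.shape, out.shape)  # torch.Size([1, 4050, 384]) torch.Size([1, 20, 360, 720])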
Data loader
root@86ffe15cc233:/dli/task# \
ENABLE_PROFILING=1 \
PROFILE_OUTPUT=baseline-dli_dw8 \
./submit_dli.sh \
--config=short \
--num_data_workers 8 \
--run_num=nw8
Enabling profiling...
+ mpirun --allow-run-as-root -np 1 bash -c '
source export_DDP_vars.sh
nsys profile \
--trace=cuda,cublas,nvtx \
--kill none \
-c cudaProfilerApi \
-f true \
-o /dli/task/logs/baseline-dli_dw8 \
python train.py \
--config=short \
--num_data_workers 8 \
--run_num=nw8
'
2025-11-17 21:32:08,311 - root - INFO - ------------------ Configuration ------------------
2025-11-17 21:32:08,311 - root - INFO - Configuration file: /dli/task/config/ViT.yaml
2025-11-17 21:32:08,311 - root - INFO - Configuration name: short
2025-11-17 21:32:08,311 - root - INFO - limit_nsamples 512
2025-11-17 21:32:08,311 - root - INFO - limit_nsamples_val 128
2025-11-17 21:32:08,311 - root - INFO - num_iters 128
2025-11-17 21:32:08,311 - root - INFO - embed_dim 384
2025-11-17 21:32:08,311 - root - INFO - depth 12
2025-11-17 21:32:08,311 - root - INFO - dropout 0.0
2025-11-17 21:32:08,311 - root - INFO - patch_size 8
2025-11-17 21:32:08,311 - root - INFO - num_heads 8
2025-11-17 21:32:08,311 - root - INFO - model_backend pytorch
2025-11-17 21:32:08,311 - root - INFO - img_size [360, 720]
2025-11-17 21:32:08,311 - root - INFO - dt 1
2025-11-17 21:32:08,311 - root - INFO - global_batch_size 16
2025-11-17 21:32:08,312 - root - INFO - amp_mode none
2025-11-17 21:32:08,312 - root - INFO - enable_fused False
2025-11-17 21:32:08,312 - root - INFO - enable_jit False
2025-11-17 21:32:08,312 - root - INFO - expdir /dli/task/logs
2025-11-17 21:32:08,312 - root - INFO - lr_schedule cosine
2025-11-17 21:32:08,312 - root - INFO - lr 0.0005
2025-11-17 21:32:08,312 - root - INFO - warmup 0
2025-11-17 21:32:08,312 - root - INFO - optimizer Adam
2025-11-17 21:32:08,312 - root - INFO - data_loader_config pytorch
2025-11-17 21:32:08,312 - root - INFO - num_data_workers 8
2025-11-17 21:32:08,312 - root - INFO - n_in_channels 20
2025-11-17 21:32:08,312 - root - INFO - n_out_channels 20
2025-11-17 21:32:08,312 - root - INFO - train_data_path /data/train
2025-11-17 21:32:08,312 - root - INFO - valid_data_path /data/valid
2025-11-17 21:32:08,312 - root - INFO - inf_data_path /data/test
2025-11-17 21:32:08,312 - root - INFO - time_means_path /data/stats/time_means.npy
2025-11-17 21:32:08,312 - root - INFO - global_means_path /data/stats/global_means.npy
2025-11-17 21:32:08,312 - root - INFO - global_stds_path /data/stats/global_stds.npy
2025-11-17 21:32:08,312 - root - INFO - wireup_info env
2025-11-17 21:32:08,312 - root - INFO - wireup_store tcp
2025-11-17 21:32:08,312 - root - INFO - amp_enabled False
2025-11-17 21:32:08,312 - root - INFO - amp_dtype torch.float32
2025-11-17 21:32:08,312 - root - INFO - ---------------------------------------------------
2025-11-17 21:32:08,513 - root - INFO - rank 0, begin data loader init
2025-11-17 21:32:08,551 - root - INFO - Getting file stats from /data/train/2012.h5
2025-11-17 21:32:08,552 - root - INFO - Overriding total number of samples to: 512
2025-11-17 21:32:08,552 - root - INFO - Number of samples per year: 1460
2025-11-17 21:32:08,552 - root - INFO - Found data at path /data/train. Number of examples: 512. Image Shape: 360 x 720 x 20
2025-11-17 21:32:08,552 - root - INFO - Getting file stats from /data/valid/2016.h5
2025-11-17 21:32:08,553 - root - INFO - Overriding total number of samples to: 128
2025-11-17 21:32:08,553 - root - INFO - Number of samples per year: 1460
2025-11-17 21:32:08,553 - root - INFO - Found data at path /data/valid. Number of examples: 128. Image Shape: 360 x 720 x 20
2025-11-17 21:32:08,553 - root - INFO - rank 0, data loader initialized
2025-11-17 21:32:08,943 - root - INFO - VisionTransformer(
(patch_embed): PatchEmbed(
(proj): Conv2d(20, 384, kernel_size=(8, 8), stride=(8, 8))
)
(pos_drop): Dropout(p=0.0, inplace=False)
(blocks): ModuleList(
(0-11): 12 x Block(
(drop_path): Identity()
(norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
(norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=384, out_features=384, bias=True)
(k): Linear(in_features=384, out_features=384, bias=True)
(v): Linear(in_features=384, out_features=384, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=384, out_features=384, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(mlp): MLP(
(fc1): Linear(in_features=384, out_features=1536, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=1536, out_features=384, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
)
(norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
(head): Linear(in_features=384, out_features=1280, bias=False)
)
2025-11-17 21:32:08,943 - root - INFO - Starting Training Loop...
/dli/task/train.py:176: UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.
Consider using tensor.detach() first. (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/native/Scalar.cpp:22.)
tr_loss.append(loss.item() / world_size)
2025-11-17 21:32:48,380 - root - INFO - Time taken for epoch 1 is 31.060786 sec, avg 16.483807 samples/sec
2025-11-17 21:32:48,380 - root - INFO - Avg train loss=0.583575
2025-11-17 21:32:52,875 - root - INFO - Avg val loss=0.4340733289718628
2025-11-17 21:32:52,876 - root - INFO - Total validation time: 3.9601058959960938 sec
2025-11-17 21:33:23,642 - root - INFO - Time taken for epoch 2 is 30.764906 sec, avg 16.642339 samples/sec
2025-11-17 21:33:23,642 - root - INFO - Avg train loss=0.400669
2025-11-17 21:33:27,975 - root - INFO - Avg val loss=0.3799823522567749
2025-11-17 21:33:27,976 - root - INFO - Total validation time: 3.7919225692749023 sec
2025-11-17 21:33:59,814 - root - INFO - Time taken for epoch 3 is 31.836994 sec, avg 16.081920 samples/sec
2025-11-17 21:33:59,814 - root - INFO - Avg train loss=0.365027
2025-11-17 21:34:04,047 - root - INFO - Avg val loss=0.3569737672805786
2025-11-17 21:34:04,047 - root - INFO - Total validation time: 3.7000949382781982 sec
Capture range started in the application.
Capture range ended in the application.
Generating '/tmp/nsys-report-7df0.qdstrm'
[1/1] [0% ] baseline-dli_dw8.nsys-repProcessing events...
[1/1] [========================100%] baseline-dli_dw8.nsys-rep
Generated:
/dli/task/logs/baseline-dli_dw8.nsys-rep
2025-11-17 21:34:43,205 - root - INFO - Time taken for epoch 4 is 39.157385 sec, avg 13.075439 samples/sec
2025-11-17 21:34:43,206 - root - INFO - Avg train loss=0.352449
2025-11-17 21:34:47,627 - root - INFO - Avg val loss=0.3528955578804016
2025-11-17 21:34:47,627 - root - INFO - Total validation time: 3.882009744644165 sec
2025-11-17 21:34:47,631 - root - INFO - DONE ---- rank 0
Using multiple data workers (typically 4 or 8) is essential for keeping the GPU fed with data during training. Each data worker is an independent process (or thread, depending on the framework) that loads and preprocesses batches in parallel. By overlapping I/O and preprocessing with model computation, additional workers hide disk-access latency, mitigate bottlenecks from slow data augmentations, and make better use of available CPU resources. This reduces idle time between training steps, so the next mini-batch is ready exactly when the GPU needs it. The optimal number of workers depends on system resources, dataset characteristics, and preprocessing complexity, but in practice adding workers improves throughput significantly, up to the point where CPU, memory bandwidth, or storage becomes the limit.
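The knob being swept in the run above is the standard num_workers argument of the PyTorch DataLoader; a minimal sketch follows, where SyntheticERA5 is a hypothetical stand-in for the repo's HDF5-backed dataset.
# Sketch of the num_workers / pin_memory knobs on a plain PyTorch DataLoader.
import torch
from torch.utils.data import Dataset, DataLoader

class SyntheticERA5(Dataset):                   # hypothetical stand-in dataset
    def __init__(self, n=512, channels=20, h=360, w=720):
        self.n, self.shape = n, (channels, h, w)
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        # The real loader does an HDF5 read plus normalization on the CPU here.
        x = torch.randn(self.shape)
        return x, x                             # (input, target) pair

loader = DataLoader(
    SyntheticERA5(),
    batch_size=16,
    num_workers=8,              # worker processes overlap I/O and preprocessing with compute
    pin_memory=True,            # page-locked host buffers speed up host-to-device copies
    prefetch_factor=2,          # batches prefetched per worker
    persistent_workers=True,    # keep workers alive across epochs
)
With pin_memory=True, a subsequent .to(device, non_blocking=True) copy can also overlap with GPU compute, since asynchronous transfers require pinned host memory.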
Even with many CPU workers, the preprocessed batches still sit in host memory and must be copied to the device on every step. Using DALI can avoid this host-to-device memcpy bottleneck, as DALI performs a wide range of data loading and augmentation operations directly on the GPU. This leverages GPU acceleration for preprocessing, and DALI's C++ backend maintains its own highly efficient worker threads, improving concurrency and throughput well beyond traditional CPU-based data pipelines.
# -----------------------------------------------------------------------------
# How DALI Interacts with CUDA for Accelerated Data Loading
# -----------------------------------------------------------------------------
#
# The DALI pipeline is tightly integrated with CUDA. Its workers and GPU operators move data
# efficiently from storage to device, avoiding Python and host memory bottlenecks.
#
# +---------------------+ (1) read/augment +-----------------+
# Data/FileSystem | DALI CPU Workers |--------------------->| CPU Buffer |
# (images, HDF5, +---------------------+ (NumPy, +-----------------+
# .npy, etc.) Pillow...) |
# | (2) memcpy (optional)
# | |
# | +--------------+ (3) DALI GPU ops +---------------------+
# +------->| DALI Pipeline|------------------->| CUDA Device Mem |
# +--------------+ (Crop, Normalize, +---------------------+
# Augment, etc) |
# (4) tensor returned to PyTorch on GPU, with no H2D copy needed.
# |
# [PyTorch training loop, tensor already in cuda:0, ready for GPU kernel]
#
# - DALI operators marked as "device='gpu'" run as CUDA kernels and operate on device memory.
# - No host-to-device PCIe transfers are needed after data is loaded into GPU memory.
# - The pipeline uses internal CUDA streams and events for asynchronicity:
# - Different sets of CUDA streams can be used for each worker/pipeline.
# - DALI synchronizes only when the output tensor is returned to PyTorch.
# - This approach is analogous to the CUDA buffer/execution region (see /distributed-gpu).
#
# CPU (Data Workers) CUDA Device (GPU) PyTorch Training Loop
# +---------------------+ +----------------------+ +-----------------------+
# | Disk/Net IO |--read-----> | DALI pipeline ops | ---> | model(input.cuda()) |
# | Decoding, Sharding | | (normalize, augment) | +-----------------------+
# | Preprocessing (CPU) | +----------------------+ │
# +---------------------+ │ │ Backprop
# | └------------------------┘
# sharding logic ensures
# each GPU/process gets
# the correct portion of data
#
# See: Memory regions, CUDA device buffer layout, and device-kernel launches (cf. the
# diagrams and commentary in /distributed-gpu).
# Imports needed to run this excerpt on its own (the comm utility and the ERA5
# external-source helper, referenced below as "esh", come from the tutorial repo).
import numpy as np
import torch
import torch.distributed as dist
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy
from utils import comm
class DaliDataLoader(object):
def get_pipeline(self):
pipeline = Pipeline(
batch_size=self.batch_size,
num_threads=2,
device_id=self.device_index,
py_num_workers=self.num_data_workers,
py_start_method="spawn",
seed=self.global_seed,
)
with pipeline: # get input and target from external source
inp, tar = fn.external_source(
source=esh.ERA5ES(
self.location,
self.train,
self.batch_size,
self.dt,
self.img_size,
self.n_in_channels,
self.n_out_channels,
self.num_shards,
self.shard_id,
self.limit_nsamples,
enable_logging=False,
seed=self.global_seed,
),
num_outputs=2,
layout=["CHW", "CHW"],
batch=False,
no_copy=True,
parallel=True,
)
# upload directly to GPU (DALI op executes as CUDA kernel)
inp = inp.gpu()
tar = tar.gpu()
if self.normalize:
inp = fn.normalize(
inp,
device="gpu",
axis_names="HW",
batch=False,
mean=self.in_bias,
stddev=self.in_scale,
)
tar = fn.normalize(
tar,
device="gpu",
axis_names="HW",
batch=False,
mean=self.out_bias,
stddev=self.out_scale,
)
pipeline.set_outputs(inp, tar)
return pipeline
def __init__(self, params, location, train, seed=333):
# set up seeds
self.global_seed = seed # same seed across all ranks
model_id = comm.get_world_rank() // comm.get_size("tp-cp-pp")
self.model_seed = self.global_seed + model_id # model-wise seed
self.local_seed = self.global_seed + comm.get_world_rank() # unique per-rank
self.num_data_workers = params.num_data_workers
self.device_index = torch.cuda.current_device()
self.batch_size = int(params.local_batch_size)
self.location = location
self.train = train
self.dt = params.dt
self.n_in_channels = params.n_in_channels
self.n_out_channels = params.n_out_channels
self.img_size = params.img_size
self.limit_nsamples = (
params.limit_nsamples if train else params.limit_nsamples_val
)
# load normalization stats
self.normalize = True
means = np.load(params.global_means_path)[0][: self.n_in_channels]
stds = np.load(params.global_stds_path)[0][: self.n_in_channels]
self.in_bias = means
self.in_scale = stds
means = np.load(params.global_means_path)[0][: self.n_out_channels]
stds = np.load(params.global_stds_path)[0][: self.n_out_channels]
self.out_bias = means
self.out_scale = stds
# set sharding for distributed, or local single process
if dist.is_initialized():
self.num_shards = params.data_num_shards
self.shard_id = params.data_shard_id
else:
self.num_shards = 1
self.shard_id = 0
# compute number of batches for __len__
extsource = esh.ERA5ES(
self.location,
self.train,
self.batch_size,
self.dt,
self.img_size,
self.n_in_channels,
self.n_out_channels,
self.num_shards,
self.shard_id,
self.limit_nsamples,
seed=self.global_seed,
)
self.num_batches = extsource.num_steps_per_epoch
del extsource
# create DALI pipeline & launch py workers
self.pipeline = self.get_pipeline()
self.pipeline.start_py_workers()
self.pipeline.build()
# Create iterator over GPU-preprocessed batches
self.iterator = DALIGenericIterator(
[self.pipeline],
["inp", "tar"],
auto_reset=True,
last_batch_policy=LastBatchPolicy.DROP,
prepare_first_batch=True,
)
def __len__(self):
return self.num_batches
def __iter__(self):
for token in self.iterator:
inp = token[0]["inp"]
tar = token[0]["tar"]
yield inp, tar
Mixed precision training
Before looking at the code for mixed precision training, let’s briefly explain how BF16 Automatic Mixed Precision (AMP) works in PyTorch and other modern deep learning frameworks. BF16 (Brain Floating Point) is a 16-bit floating point format with the same dynamic range as FP32 but fewer mantissa bits, which makes it well suited to training deep neural networks. Automatic Mixed Precision (AMP) mixes computations in lower precision (BF16 or FP16) with full precision automatically. In PyTorch, this is handled by torch.amp.autocast (torch.cuda.amp.autocast in older releases), which transparently casts operations to the appropriate precision based on hardware support and the requested data type (dtype=torch.bfloat16 for BF16). BF16 is supported on recent NVIDIA (Ampere and newer), Intel, and AMD accelerators, and usually matches FP32 accuracy for deep learning while delivering significantly higher throughput and lower memory usage. The autocast context keeps numerically sensitive operations, such as reductions and softmax, in FP32 for robustness, while most matrix multiplications and convolutions run in BF16 for performance.
# Hierarchy of AMP in Deep Learning Workflows (BF16 / FP16 Mixed Precision)
#
# +---------------------------------------------------------+
# | Deep Neural Networks (Model) |
# +---------------------------------------------------------+
# |
# | DL Frameworks (e.g., PyTorch) |
# |---------------------------------------------------------|
# | - Provides AMP support (autocast/GradScaler) |
# | - Exposes APIs to enable mixed precision training |
# +---------------------------------------------------------+
# |
# | Automatic Mixed Precision (AMP) - Hardware Assisted |
# |---------------------------------------------------------|
# | - Integrates with hardware features |
# | - Controls casting (BF16/FP16 ↔ FP32) on the fly |
# | - Ensures safe ops remain in FP32 automatically |
# +---------------------------------------------------------+
# |
# | Hardware (Tensor Cores, CPUs/GPUs) |
# |---------------------------------------------------------|
# | - BF16/FP16 Tensor Core acceleration |
# | - High-throughput, hardware-level support |
# +---------------------------------------------------------+
#
# Data travels down through these layers, where AMP orchestrates use of specialized hardware to accelerate
# mixed-precision compute, as visualized in the image above.
#
# In code, AMP typically looks like:
#
# with autocast(device_type, enabled=True, dtype=torch.bfloat16):
# y = model(x) # matmul/conv ops run on Tensor Cores as BF16/FP16, safe ops stay FP32
# loss = loss_fn(y, target)
#
# Gradient scaling and safe weight management are handled by the framework (see image above).
#
# Model/DL framework
# |
# [autocast]
# |
# (BF16/FP16 ops scheduled for tensor cores by AMP)
# |
# Hardware execution
# (Tensor Cores/accelerators)
with torch.no_grad():
inp, tar = map(lambda x: x.to(device), next(iter(train_data_loader)))
with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
gen = model(inp)
tr_loss = loss_func(gen, tar)
inp, tar = map(lambda x: x.to(device), next(iter(val_data_loader)))
with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
gen = model(inp)
val_loss = loss_func(gen, tar)
val_rmse = weighted_rmse(gen, tar)
if params.distributed:
torch.distributed.all_reduce(
tr_loss, op=ReduceOp.AVG, group=comm.get_group("dp")
)
torch.distributed.all_reduce(
val_loss, op=ReduceOp.AVG, group=comm.get_group("dp")
)
torch.distributed.all_reduce(
val_rmse, op=ReduceOp.AVG, group=comm.get_group("dp")
)
if world_rank == 0:
args.tboard_writer.add_scalar("Loss/train", tr_loss.item(), 0)
args.tboard_writer.add_scalar("Loss/valid", val_loss.item(), 0)
args.tboard_writer.add_scalar(
"RMSE(u10m)/valid", val_rmse.cpu().numpy()[0], 0
)
...
torch.cuda.nvtx.range_push(f"forward")
with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
gen = model(inp)
loss = loss_func(gen, tar)
torch.cuda.nvtx.range_pop() # forward
...
with torch.inference_mode():
with torch.no_grad():
for i, data in enumerate(val_data_loader, 0):
with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
inp, tar = map(lambda x: x.to(device), data)
gen = model(inp)
val_loss += loss_func(gen, tar)
val_rmse += weighted_rmse(gen, tar)
valid_steps += 1
if params.distributed:
torch.distributed.all_reduce(
val_loss, op=ReduceOp.AVG, group=comm.get_group("dp")
)
torch.distributed.all_reduce(
val_rmse, op=ReduceOp.AVG, group=comm.get_group("dp")
)
2025-11-17 21:44:35,117 - root - INFO - Starting Training Loop...
/dli/task/train.py:176: UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.
Consider using tensor.detach() first. (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/native/Scalar.cpp:22.)
tr_loss.append(loss.item() / world_size)
2025-11-17 21:44:41,821 - root - INFO - Time taken for epoch 1 is 5.921940 sec, avg 83.756342 samples/sec
2025-11-17 21:44:41,821 - root - INFO - Avg train loss=0.586307
2025-11-17 21:44:43,182 - root - INFO - Avg val loss=0.431182861328125
2025-11-17 21:44:43,183 - root - INFO - Total validation time: 0.9876902103424072 sec
2025-11-17 21:44:48,198 - root - INFO - Time taken for epoch 2 is 5.013919 sec, avg 102.115723 samples/sec
2025-11-17 21:44:48,198 - root - INFO - Avg train loss=0.399898
2025-11-17 21:44:49,094 - root - INFO - Avg val loss=0.37918391823768616
2025-11-17 21:44:49,095 - root - INFO - Total validation time: 0.5196459293365479 sec
2025-11-17 21:44:54,139 - root - INFO - Time taken for epoch 3 is 5.043273 sec, avg 101.521369 samples/sec
2025-11-17 21:44:54,139 - root - INFO - Avg train loss=0.363822
2025-11-17 21:44:54,990 - root - INFO - Avg val loss=0.3559249937534332
2025-11-17 21:44:54,990 - root - INFO - Total validation time: 0.4742579460144043 sec
Capture range started in the application.
Capture range ended in the application.
Generating '/tmp/nsys-report-cb82.qdstrm'
[1/1] [0% ] baseline-dli_dw8_dali_bf16.nsys-repProcessing events...
[1/1] [========================100%] baseline-dli_dw8_dali_bf16.nsys-rep
Generated:
/dli/task/logs/baseline-dli_dw8_dali_bf16.nsys-rep
2025-11-17 21:45:07,261 - root - INFO - Time taken for epoch 4 is 12.269330 sec, avg 41.730072 samples/sec
2025-11-17 21:45:07,261 - root - INFO - Avg train loss=0.351357
2025-11-17 21:45:08,138 - root - INFO - Avg val loss=0.35201317071914673
2025-11-17 21:45:08,138 - root - INFO - Total validation time: 0.49225330352783203 sec
2025-11-17 21:45:10,384 - root - INFO - DONE ---- rank 0
DDP and FSDP
DistributedDataParallel (DDP) and FullyShardedDataParallel (FSDP) are two strategies in PyTorch for training models across multiple GPUs or nodes. DDP works by keeping a full copy of the model on each GPU, splitting each batch of data across devices, and then averaging gradients between them using efficient all-reduce operations; this is usually the default choice for multi-GPU training because it’s simple, fast, and well-optimized, but it requires that every GPU has enough memory to hold the whole model. FSDP, on the other hand, is designed for very large models that don’t fit in memory on a single GPU: instead of replicating the whole model on each device, it shards (splits) parameters, gradients, and optimizer states across GPUs so that each GPU only stores a slice of them, temporarily gathering full parameters only when needed for forward and backward passes. This sharding drastically reduces memory usage and allows training of much larger models, at the cost of more complex communication patterns and potentially higher communication overhead than standard DDP.
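The training script below uses DDP. For comparison, here is a minimal sketch of how the same model could instead be wrapped with FSDP; the auto-wrap policy targeting the Block class and the BF16 mixed-precision policy are illustrative choices, not the configuration used in this repo.
# Minimal FSDP sketch (not the repo's code path): shard parameters, gradients,
# and optimizer state across data-parallel ranks instead of replicating them.
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from networks.vit import Block  # the transformer block defined above

def wrap_with_fsdp(model, local_rank):
    # Wrap each transformer Block as its own FSDP unit, so full parameters are
    # gathered and freed block by block during forward and backward.
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={Block}
    )
    bf16 = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )
    return FSDP(
        model,
        auto_wrap_policy=wrap_policy,
        mixed_precision=bf16,
        device_id=local_rank,
    )
For a model of this size, DDP is usually the right tool; FSDP pays off once the replicated parameters, gradients, and optimizer state no longer fit in a single GPU's memory.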
import sys
import os
import time
import numpy as np
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler
import torch.multiprocessing
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel
import logging
from utils import logging_utils
logging_utils.config_logger()
from utils.YParams import YParams
from utils import get_data_loader_distributed
from utils.loss import l2_loss, l2_loss_opt
from utils.metrics import weighted_rmse
from utils.plots import generate_images
from networks import vit
def train(params, args, local_rank, world_rank, world_size):
"""
==========================================================================
train()
==========================================================================
| Overall Structure:
|
| +-----------------------------------------------------------+
| | Setup, Data, Model, Optimizer, Scheduler |
| +-----------------------------------------------------------+
| | Initial Logging & Initial Loss/Eval (TensorBoard) |
| +-----------------------------------------------------------+
| | ========== Main Training Loop (epochs) =========== |
| | | | |
| | | 1. Per-epoch setup (sampler, timers, mode) | |
| | | 2. --------- Per-step (batch) ----------- | |
| | | | a. Data to device, timing/metrics | | |
| | | | b. Forward, loss, backward, opt | | |
| | | | c. Distributed/allreduce | | |
| | | | d. Running stats, LR step | | |
| | | +-------------------------------------+ | |
| | | 3. Per-epoch eval: validation loader | |
| | +-----------------------------------------------+ |
| +-----------------------------------------------------------+
| | Final timing and summary stats |
+---------------------------------------------------------------+
"""
# -------------- (1) Device and CUDNN setup --------------------
torch.backends.cudnn.benchmark = True
torch.cuda.set_device(local_rank)
device = torch.device("cuda:%d" % local_rank)
# -------------- (2) Data Loaders |
logging.info("rank %d, begin data loader init" % world_rank)
train_data_loader, train_dataset, train_sampler = get_data_loader_distributed(
params, params.train_data_path, params.distributed, train=True
)
val_data_loader, valid_dataset = get_data_loader_distributed(
params, params.valid_data_path, params.distributed, train=False
)
logging.info("rank %d, data loader initialized" % (world_rank))
# -------------- (3) Model Instantiation |
model = vit.ViT(params).to(device)
if params.enable_jit:
# Optionally use torch.compile for JIT optimization
model = torch.compile(model)
# -------------- (4) AMP/Distributed Setup |
if params.amp_dtype == torch.float16:
scaler = GradScaler("cuda")
if params.distributed and not args.noddp:
"""
=========================================================
DistributedDataParallel (DDP) Overview:
DDP is PyTorch's mechanism for achieving data-parallel training
across multiple GPUs, using NCCL for high-performance GPU-to-GPU
communication. For each GPU/rank, you launch a separate process.
Backward gradients are automatically all-reduced so each model
replica stays in sync.
Example:
+------------------+
| Process/Rank 0 | GPU:0
| +------------+ | model parameters θ₀
| | Model |---|-------+
| +------------+ | |
+------------------+ |
| v
+------------------+ AllReduce: gradients
| Process/Rank 1 | across NCCL communicator
| +------------+ | (compare to halo exchange
| | Model |---|--------- ring in Jacobi)
| +------------+ |
+------------------+
- Each process has its own CUDA device and model replica
- Forward/backward compute locally on own batch
- After backward(), gradients are automatically averaged
across all GPUs/ranks (using NCCL/AllReduce)
- Optimizer step updates local model parameters
Compare this to stencil/halo exchange with MPI/NCCL in
/distributed-gpu, but DDP allreduces the *whole* gradient,
not just boundaries.
Buffer sharing: broadcast_buffers True/False
- Buffers = non-parameter state (e.g., running stats)
- True: Syncs these each iteration (default/PyTorch default)
- False: Leaves them independent on each process
More: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
=========================================================
"""
# Optional buffer sharing config for DDP
if args.disable_broadcast_buffers:
model = DistributedDataParallel(
model,
device_ids=[local_rank],
bucket_cap_mb=args.bucket_cap_mb,
broadcast_buffers=False,
gradient_as_bucket_view=True,
)
else:
model = DistributedDataParallel(
model, device_ids=[local_rank], bucket_cap_mb=args.bucket_cap_mb
)
# -------------- (5) Optimizer |
if params.enable_fused:
optimizer = optim.Adam(
model.parameters(), lr=params.lr, fused=True, betas=(0.9, 0.95)
)
else:
optimizer = optim.Adam(model.parameters(), lr=params.lr, betas=(0.9, 0.95))
# -------------- (6) Log Model |
if world_rank == 0:
logging.info(model)
# -------------- (7) Scheduler |
iters = 0
startEpoch = 0
if params.lr_schedule == "cosine":
if params.warmup > 0:
lr_scale = lambda x: min(
(x + 1) / params.warmup,
0.5 * (1 + np.cos(np.pi * x / params.num_iters)),
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scale)
else:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=params.num_iters
)
else:
scheduler = None
# -------------- (8) Loss Selection |
if params.enable_jit:
loss_func = l2_loss_opt
else:
loss_func = l2_loss
# -------------- (9) Initial Logging to TensorBoard |
if world_rank == 0:
logging.info("Starting Training Loop...")
# ==[ Initial Loss/Validation Logging ]==========================
with torch.no_grad():
inp, tar = map(lambda x: x.to(device), next(iter(train_data_loader)))
gen = model(inp)
tr_loss = loss_func(gen, tar)
inp, tar = map(lambda x: x.to(device), next(iter(val_data_loader)))
gen = model(inp)
val_loss = loss_func(gen, tar)
val_rmse = weighted_rmse(gen, tar)
if params.distributed:
torch.distributed.all_reduce(tr_loss)
torch.distributed.all_reduce(val_loss)
torch.distributed.all_reduce(val_rmse)
if world_rank == 0:
args.tboard_writer.add_scalar("Loss/train", tr_loss.item() / world_size, 0)
args.tboard_writer.add_scalar("Loss/valid", val_loss.item() / world_size, 0)
args.tboard_writer.add_scalar(
"RMSE(u10m)/valid", val_rmse.cpu().numpy()[0] / world_size, 0
)
# ========== MAIN TRAINING LOOP (epochs) ==========
params.num_epochs = params.num_iters // len(train_data_loader)
iters = 0
t1 = time.time()
for epoch in range(startEpoch, startEpoch + params.num_epochs):
# ---[A] Epoch: prep, sampler, timers ---
torch.cuda.synchronize() # Ensure correct epoch timing (sync GPU)
if params.distributed and (train_sampler is not None):
train_sampler.set_epoch(epoch)
start = time.time()
tr_loss = []
tr_time = 0.0
dat_time = 0.0
log_time = 0.0
model.train()
step_count = 0
# ----------- B. PER-BATCH TRAINING ------------
for i, data in enumerate(train_data_loader, 0):
# (Optional) Profiling: Begin/End CUDA profiler at a certain epoch
if world_rank == 0:
if epoch == 3 and i == 0:
torch.cuda.profiler.start()
if epoch == 3 and i == len(train_data_loader) - 1:
torch.cuda.profiler.stop()
torch.cuda.nvtx.range_push(f"step {i}")
iters += 1
dat_start = time.time()
torch.cuda.nvtx.range_push(f"data copy in {i}")
# --- Copy batch to GPU ---
inp, tar = map(lambda x: x.to(device), data)
torch.cuda.nvtx.range_pop() # data copy in
tr_start = time.time()
b_size = inp.size(0)
# --- ZERO GRAD ---
optimizer.zero_grad()
# --- Forward w/ autocast (AMP) ---
torch.cuda.nvtx.range_push(f"forward")
with autocast("cuda", enabled=params.amp_enabled, dtype=params.amp_dtype):
gen = model(inp)
loss = loss_func(gen, tar)
torch.cuda.nvtx.range_pop() # forward
# --- Backward, Step (AMP or standard) ---
if params.amp_dtype == torch.float16:
scaler.scale(loss).backward()
torch.cuda.nvtx.range_push(f"optimizer")
scaler.step(optimizer)
torch.cuda.nvtx.range_pop() # optimizer
scaler.update()
else:
loss.backward()
torch.cuda.nvtx.range_push(f"optimizer")
optimizer.step()
torch.cuda.nvtx.range_pop() # optimizer
# --- Allreduce if distributed ---
if params.distributed:
torch.distributed.all_reduce(loss)
tr_loss.append(loss.item() / world_size)
torch.cuda.nvtx.range_pop() # step
# --- LR Scheduler Step ---
scheduler.step()
# --- Time/Stats ---
tr_end = time.time()
tr_time += tr_end - tr_start
dat_time += tr_start - dat_start
step_count += 1
# =====[ End of Training steps for this epoch ]=====
torch.cuda.synchronize() # Accurate epoch timing
end = time.time()
# ----------- C. Epoch-End Reporting/Logging (Rank 0) -------
if world_rank == 0:
iters_per_sec = step_count / (end - start)
samples_per_sec = params["global_batch_size"] * iters_per_sec
logging.info(
"Time taken for epoch %i is %f sec, avg %f samples/sec",
epoch + 1,
end - start,
samples_per_sec,
)
logging.info(" Avg train loss=%f" % np.mean(tr_loss))
args.tboard_writer.add_scalar("Loss/train", np.mean(tr_loss), iters)
args.tboard_writer.add_scalar(
"Learning Rate", optimizer.param_groups[0]["lr"], iters
)
args.tboard_writer.add_scalar("Avg iters per sec", iters_per_sec, iters)
args.tboard_writer.add_scalar("Avg samples per sec", samples_per_sec, iters)
fig = generate_images([inp, tar, gen])
args.tboard_writer.add_figure("Visualization, t2m", fig, iters, close=True)
# ========== D. Validation Evaluation ============
val_start = time.time()
val_loss = torch.zeros(1, device=device)
val_rmse = torch.zeros(
(params.n_out_channels), dtype=torch.float32, device=device
)
valid_steps = 0
model.eval()
with torch.inference_mode():
with torch.no_grad():
for i, data in enumerate(val_data_loader, 0):
with autocast(
"cuda", enabled=params.amp_enabled, dtype=params.amp_dtype
):
inp, tar = map(lambda x: x.to(device), data)
gen = model(inp)
val_loss += loss_func(gen, tar)
val_rmse += weighted_rmse(gen, tar)
valid_steps += 1
if params.distributed:
torch.distributed.all_reduce(val_loss)
val_loss /= world_size
torch.distributed.all_reduce(val_rmse)
val_rmse /= world_size
val_rmse /= valid_steps # Avg validation rmse
val_loss /= valid_steps
val_end = time.time()
# ---------- (E) Validation logging (TensorBoard/Console) ----------
if world_rank == 0:
logging.info(" Avg val loss={}".format(val_loss.item()))
logging.info(" Total validation time: {} sec".format(val_end - val_start))
args.tboard_writer.add_scalar("Loss/valid", val_loss, iters)
args.tboard_writer.add_scalar(
"RMSE(u10m)/valid", val_rmse.cpu().numpy()[0], iters
)
args.tboard_writer.flush()
# ===== Final Timing Summary =====
t2 = time.time()
tottime = t2 - t1
# ==========================================================================
# End of train()
# ==========================================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--run_num",
default="00",
type=str,
help="tag for indexing the current experiment",
)
parser.add_argument(
"--yaml_config",
default="./config/ViT.yaml",
type=str,
help="path to yaml file containing training configs",
)
parser.add_argument(
"--config", default="base", type=str, help="name of desired config in yaml file"
)
parser.add_argument(
"--amp_mode",
default="none",
type=str,
choices=["none", "fp16", "bf16"],
help="select automatic mixed precision mode",
)
parser.add_argument(
"--enable_fused", action="store_true", help="enable fused Adam optimizer"
)
parser.add_argument(
"--enable_jit", action="store_true", help="enable JIT compilation"
)
parser.add_argument(
"--local_batch_size",
default=None,
type=int,
help="local batchsize (manually override global_batch_size config setting)",
)
parser.add_argument(
"--num_iters", default=None, type=int, help="number of iters to run"
)
parser.add_argument(
"--num_data_workers",
default=None,
type=int,
help="number of data workers for data loader",
)
parser.add_argument(
"--data_loader_config",
default=None,
type=str,
choices=["pytorch", "dali"],
help="dataloader configuration. choices: 'pytorch', 'dali'",
)
parser.add_argument(
"--bucket_cap_mb", default=25, type=int, help="max message bucket size in mb"
)
parser.add_argument(
"--disable_broadcast_buffers",
action="store_true",
help="disable syncing broadcasting buffers",
)
parser.add_argument(
"--noddp", action="store_true", help="disable DDP communication"
)
args = parser.parse_args()
run_num = args.run_num
params = YParams(os.path.abspath(args.yaml_config), args.config)
# Update config with modified args
# set up amp
if args.amp_mode != "none":
params.update({"amp_mode": args.amp_mode})
amp_dtype = torch.float32
if params.amp_mode == "fp16":
amp_dtype = torch.float16
elif params.amp_mode == "bf16":
amp_dtype = torch.bfloat16
params.update(
{"amp_enabled": amp_dtype is not torch.float32, "amp_dtype": amp_dtype}
)
if args.enable_fused:
params.update({"enable_fused": args.enable_fused})
if args.enable_jit:
params.update({"enable_jit": args.enable_jit})
if args.data_loader_config:
params.update({"data_loader_config": args.data_loader_config})
if args.num_iters:
params.update({"num_iters": args.num_iters})
if args.num_data_workers:
params.update({"num_data_workers": args.num_data_workers})
params.distributed = False
if "WORLD_SIZE" in os.environ:
params.distributed = int(os.environ["WORLD_SIZE"]) > 1
world_size = int(os.environ["WORLD_SIZE"])
else:
world_size = 1
world_rank = 0
local_rank = 0
if params.distributed:
torch.distributed.init_process_group(backend="nccl", init_method="env://")
world_rank = torch.distributed.get_rank()
local_rank = int(os.environ["LOCAL_RANK"])
if args.local_batch_size:
# Manually override batch size
params.local_batch_size = args.local_batch_size
params.update({"global_batch_size": world_size * args.local_batch_size})
else:
# Compute local batch size based on number of ranks
params.local_batch_size = params.global_batch_size // world_size
# for dali data loader, set the actual number of data shards and id
params.data_num_shards = world_size
params.data_shard_id = world_rank
# Set up directory
baseDir = params.expdir
expDir = os.path.join(
baseDir, args.config + "/%dGPU/" % (world_size) + str(run_num) + "/"
)
if world_rank == 0:
if not os.path.isdir(expDir):
os.makedirs(expDir)
logging_utils.log_to_file(
logger_name=None, log_filename=os.path.join(expDir, "out.log")
)
params.log()
args.tboard_writer = SummaryWriter(log_dir=os.path.join(expDir, "logs/"))
params.experiment_dir = os.path.abspath(expDir)
train(params, args, local_rank, world_rank, world_size)
if params.distributed:
torch.distributed.barrier()
logging.info("DONE ---- rank %d" % world_rank)
Model parallelism
Model parallelism is a complementary strategy where a single model is split across multiple devices, instead of each device holding a full copy as in DDP or FSDP. In tensor (intra-layer) model parallelism, individual layers are partitioned—e.g., splitting a large linear layer’s weight matrix by columns or rows—so that each GPU computes part of the layer’s output and then the partial results are combined. In pipeline model parallelism, different groups of layers are placed on different devices (like stages in a pipeline), and microbatches are streamed through these stages so multiple parts of the model can run concurrently. Model parallelism is especially useful when a single layer or the whole model is too large to fit on one GPU, but it usually involves more complex orchestration and communication patterns than pure data parallel approaches, and in practice large-scale systems often combine data parallelism (DDP/FSDP) with model parallelism.
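Before the full model-parallel training script, here is a minimal sketch of the tensor-parallel idea using only torch.distributed collectives: each rank holds a column shard of a large linear layer's weight, and the partial outputs are reassembled with an all-gather. The repo's DistributedMatmul/DistributedMLP layers implement this pattern with autograd-aware communication mappings; ColumnParallelLinear here is purely illustrative.
# Illustrative column-parallel linear layer (not the repo's DistributedMatmul).
import torch
import torch.distributed as dist
import torch.nn as nn

class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, group=None):
        super().__init__()
        self.group = group
        world = dist.get_world_size(group)
        assert out_features % world == 0, "out_features must divide evenly across ranks"
        self.local_out = out_features // world
        # Each rank stores only its column shard of the weight: (out/world, in).
        self.weight = nn.Parameter(torch.empty(self.local_out, in_features))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):                  # x: (..., in_features), replicated on all ranks
        y_local = x @ self.weight.t()      # (..., out_features / world)
        shards = [torch.empty_like(y_local) for _ in range(dist.get_world_size(self.group))]
        # NOTE: plain all_gather is not autograd-aware; real training code (as in
        # distributed/mappings.py) wraps collectives in autograd functions so
        # gradients flow back to every shard.
        dist.all_gather(shards, y_local.contiguous(), group=self.group)
        return torch.cat(shards, dim=-1)   # (..., out_features), identical on all ranks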
import sys
import os
import time
import numpy as np
import argparse
import pynvml
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler
import torch.multiprocessing
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import ReduceOp
import logging
from utils import logging_utils
import warnings
warnings.simplefilter("ignore", FutureWarning)
logging_utils.config_logger()
from utils.YParams import YParams
from utils import get_data_loader_distributed
from utils import comm
from utils.loss import l2_loss, l2_loss_opt
from utils.metrics import weighted_rmse
from networks import vit
from networks import vit_te
from distributed.mappings import init_ddp_model_and_reduction_hooks
from distributed.helpers import init_params_for_shared_weights
from utils.plots import generate_images
def train(params, args, local_rank, world_rank, world_size):
# ----------------------------------------------------------------------- #
# (1) Device & Library Setup
# ----------------------------------------------------------------------- #
# Ensure deterministic performance for convolution operations and set device
torch.backends.cudnn.benchmark = True
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")
# Initialize GPU monitoring to track memory usage
pynvml.nvmlInit()
nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(device.index)
# ----------------------------------------------------------------------- #
# (2) DataLoader Initialization
# ----------------------------------------------------------------------- #
# Get distributed train/validation data loaders and samplers
logging.info("rank %d, begin data loader init" % world_rank)
train_data_loader, train_dataset, train_sampler = get_data_loader_distributed(
params, params.train_data_path, params.distributed, train=True
)
val_data_loader, valid_dataset = get_data_loader_distributed(
params, params.valid_data_path, params.distributed, train=False
)
logging.info("rank %d, data loader initialized" % world_rank)
# ----------------------------------------------------------------------- #
# (3) Model Initialization & JIT Compilation
# ----------------------------------------------------------------------- #
# Choose Transformer Engine or native ViT backend
if params.model_backend == 'transformer-engine':
logging.info("using transformer-engine backend")
model = vit_te.ViT(params).to(device)
else:
logging.info("using native backend")
model = vit.ViT(params).to(device)
# JIT compilation (optional, for acceleration)
if params.enable_jit:
# NOTE: DDP interoperability with torch.compile may need special handling
model = torch.compile(model)
# AMP: Automatic Mixed Precision
if params.amp_dtype == torch.float16:
scaler = GradScaler('cuda')
# ----------------------------------------------------------------------- #
# (4) Distributed/TP Shared Weight Synchronization
# ----------------------------------------------------------------------- #
# If tensor/pipeline parallel model requires, sync weights
"""
MODEL PARALLELISM: TENSOR & PIPELINE (PyTorch/pynvml, Exascale Style)
-----------------------------------------------------------------------
[How deep learning tensor/pipeline parallelism actually works in PyTorch,
and how parameter/init synchronization mimics exascale Jacobi/MPI code.
Device/memory setup and monitoring via torch.cuda & pynvml -- see below.]
=====================================================================
(1) TENSOR PARALLELISM (PyTorch distributed, weight sharding, sync)
=====================================================================
Model parameter broadcast/init (like MPI_Bcast in Jacobi, see @2025-11-16):
+------------------- torch.distributed.broadcast (NCCL/MPI)--------------------+
| |
+----------+ +----------+ +----------+ +----------+ +---v------+
| GPU 0 | | GPU 1 | | GPU 2 | | GPU 3 | | Python |
| (rank 0) | | (rank 1) | | (rank 2) | | (rank 3) | | process |
+----+-----+ +-----+----+ +-----+----+ +-----+----+ +---+------+
| | | | |
| | | | |
| (shard large weights via param.split()) | |
+-----------------|---------------------------------+----------------------------+
| (each GPU stores a partition)
+-------------------------+
| (forward/backward: partial matmuls, reduce/add)
(compare: Jacobi halo rows exchanged with MPI_Sendrecv() and parameters broadcast with MPI_Bcast)
=====================================================================
(2) PIPELINE PARALLELISM (PyTorch: staged blocks, microbatch pipeline)
=====================================================================
Input Batch
|
+----------+ +----------+ +----------+ +----------+
| GPU 0 |--->| GPU 1 |--->| GPU 2 |--->| GPU 3 |
| (Stage 0)| | (Stage 1)| | (Stage 2)| | (Stage 3)|
+----------+ +----------+ +----------+ +----------+
| | | |
[Encoder 0] [Encoder 1] [Encoder 2] [Encoder 3]
| | | |
(microbatch split, data passed stage to stage)
|
torch.distributed.broadcast(init params) --- like Jacobi: all subdomains/ranks
must agree on boundaries at setup
=====================================================================
(3) DEVICE & MEMORY MANAGEMENT (PyTorch, pynvml, MPI Analogy)
=====================================================================
+-----------------------------------------------------------+
| Python Process (per rank) |
+-----------------------------------------------------------+
| torch.cuda.set_device(local_rank) |
| device = torch.device(f"cuda:{local_rank}") |
| model.to(device) |
| pynvml.nvmlInit() (get handle for GPU)|
| nvmlDeviceGetMemoryInfo(nvml_handle) (GB used) |
+-----------------------------------------------------------+
|
[like Jacobi/MPI: cudaSetDevice, deviceCount, rank <-> device mapping]
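    The corresponding rank-to-device bookkeeping, stripped to its essentials
    (local_rank would normally come from the launcher environment, e.g. SLURM or
    torchrun, rather than being hard-coded):

        import pynvml
        import torch

        local_rank = 0                                   # illustrative; set by the launcher
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")

        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank)
        used_gb = pynvml.nvmlDeviceGetMemoryInfo(handle).used / (1024.0**3)
        print(f"mapped to cuda:{local_rank}, {used_gb:.2f} GB currently in use")
        pynvml.nvmlShutdown()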
=====================================================================
(4) SYNC & COLLECTIVES: PyTorch vs Jacobi MPI
=====================================================================
    +------------------------------+       +-------------------------------+
    | PyTorch World                |       | Exascale Jacobi (C/MPI)       |
    +------------------------------+       +-------------------------------+
    | torch.distributed.broadcast  | <---> | MPI_Bcast (e.g. ncclUniqueId) |
    | torch.distributed.all_reduce | <---> | MPI_Allreduce (convergence)   |
    | torch.distributed.barrier    | <---> | MPI_Barrier                   |
    +------------------------------+       +-------------------------------+
KEY FLOW:
- PyTorch uses torch.distributed backends (NCCL/MPI) to implement
the same process-group collectives as Jacobi C/MPI: sync weights,
share state, coordinate progress.
- init_params_for_shared_weights(model) triggers these bcasts.
    - The pynvml and torch.cuda setup mirrors the CUDA runtime and device-info
      calls in the Jacobi code ("assign rank to device", "log memory stats");
      a minimal sketch of the collectives follows below.
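    The three collectives side by side, as a minimal sketch (assumes an NCCL-backed
    process group where ReduceOp.AVG is available, as in the training loop below):

        import torch.distributed as dist

        def sync_like_jacobi(model, loss):
            for p in model.parameters():                    # ~ MPI_Bcast of initial state
                dist.broadcast(p.data, src=0)
            dist.all_reduce(loss, op=dist.ReduceOp.AVG)     # ~ MPI_Allreduce (convergence)
            dist.barrier()                                  # ~ MPI_Barrier
            return loss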
=====================================================================
(5) KEY MAPPING
=====================================================================
PyTorch init_params_for_shared_weights(model):
|
+---> torch.distributed.broadcast(...) for all shared (sharded) params,
often per process-group (tp, cp)
MPI Jacobi:
|
+---> MPI_Bcast()/Sendrecv() for domain boundaries, ids
    See /distributed-gpu for the equivalent device setup, rank assignment, and
    MPI_Bcast/ncclUniqueId & ncclCommInitRank flow, mirrored here in torch
    (a minimal sketch follows at the end of this note):
- torch.cuda.device <-> cudaSetDevice(rank % num_devices)
- torch.distributed actions <-> MPI_SENDRECV/ALLREDUCE/BCAST
- pynvml.nvmlDeviceGet... <-> cudaMemGetInfo/MPI logging
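    Minimal sketch of the per-process-group broadcast that a helper like
    init_params_for_shared_weights performs (the helper itself lives elsewhere in
    the repo; 'ranks' and the group layout below are illustrative assumptions):

        import torch.distributed as dist

        def broadcast_shared_weights(model, ranks):
            # every rank in the tp/cp subgroup adopts the first member's weights;
            # new_group() must be called collectively by all ranks in the world
            group = dist.new_group(ranks)                 # e.g. ranks = [0, 1] for tp=2
            for p in model.parameters():
                dist.broadcast(p.data, src=ranks[0], group=group)
            return group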
"""
if comm.get_size("tp-cp") > 1:
init_params_for_shared_weights(model)
# Wrap model with DDP if required
if params.distributed and not args.noddp:
model = init_ddp_model_and_reduction_hooks(
model,
device_ids=[local_rank],
output_device=[local_rank],
bucket_cap_mb=args.bucket_cap_mb
)
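    # For reference only: without the custom reduction hooks that
    # init_ddp_model_and_reduction_hooks installs, a plain DDP wrap would look
    # roughly like this (a sketch, not the path this script actually takes):
    #
    #   model = torch.nn.parallel.DistributedDataParallel(
    #       model,
    #       device_ids=[local_rank],
    #       output_device=local_rank,
    #       bucket_cap_mb=args.bucket_cap_mb,
    #       broadcast_buffers=not args.disable_broadcast_buffers,
    #   )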
# ----------------------------------------------------------------------- #
# (5) Optimizer Setup
# ----------------------------------------------------------------------- #
# Use fused Adam optimizer if enabled
if params.enable_fused:
optimizer = optim.Adam(
model.parameters(), lr=params.lr, fused=True, betas=(0.9, 0.95)
)
else:
optimizer = optim.Adam(model.parameters(), lr=params.lr, betas=(0.9, 0.95))
# ----------------------------------------------------------------------- #
# (6) Model Logging & Memory Monitoring (rank 0 master only)
# ----------------------------------------------------------------------- #
if world_rank == 0:
logging.info(model)
all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).used / (1024.0**3)
logging.info(f"Scaffolding memory high watermark: {all_mem_gb} GB.")
iters = 0
startEpoch = 0
# ----------------------------------------------------------------------- #
# (7) Learning Rate Scheduler Configuration
# ----------------------------------------------------------------------- #
if params.lr_schedule == "cosine":
if params.warmup > 0:
# Linear warmup to cosine decay
lr_scale = lambda x: min(
(x + 1) / params.warmup,
0.5 * (1 + np.cos(np.pi * x / params.num_iters)),
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scale)
else:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=params.num_iters
)
else:
scheduler = None
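    # Worked example of the warmup+cosine lambda above (illustrative values, not the
    # shipped config): with warmup=100 and num_iters=1000 the multiplier applied to
    # params.lr is approximately
    #   x = 0    -> min(1/100, 1.000) = 0.010   (linear warmup dominates)
    #   x = 99   -> min(1.000, 0.976) = 0.976   (hand-off to the cosine branch)
    #   x = 500  -> min(5.010, 0.500) = 0.500   (halfway through the decay)
    #   x = 999  -> ~0.000                      (annealed to ~zero)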
# ----------------------------------------------------------------------- #
# (8) Loss Function Selection
# ----------------------------------------------------------------------- #
# Use JIT-optimized or plain loss depending on setting
if params.enable_jit:
loss_func = l2_loss_opt
else:
loss_func = l2_loss
# ----------------------------------------------------------------------- #
# (9) Initial Logging and Baseline Metrics
# ----------------------------------------------------------------------- #
if world_rank == 0:
logging.info("Starting Training Loop...")
# Compute & log initial train/val loss and RMSE before epochs begin
with torch.no_grad():
# Take first batch from training and validation for baseline
inp, tar = map(lambda x: x.to(device), next(iter(train_data_loader)))
with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
gen = model(inp)
tr_loss = loss_func(gen, tar)
inp, tar = map(lambda x: x.to(device), next(iter(val_data_loader)))
with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
gen = model(inp)
val_loss = loss_func(gen, tar)
val_rmse = weighted_rmse(gen, tar)
# Reduce baseline metrics across data-parallel group if needed
if params.distributed:
torch.distributed.all_reduce(
tr_loss, op=ReduceOp.AVG, group=comm.get_group("dp")
)
torch.distributed.all_reduce(
val_loss, op=ReduceOp.AVG, group=comm.get_group("dp")
)
torch.distributed.all_reduce(
val_rmse, op=ReduceOp.AVG, group=comm.get_group("dp")
)
if world_rank == 0:
args.tboard_writer.add_scalar("Loss/train", tr_loss.item(), 0)
args.tboard_writer.add_scalar("Loss/valid", val_loss.item(), 0)
args.tboard_writer.add_scalar(
"RMSE(u10m)/valid", val_rmse.cpu().numpy()[0], 0
)
# ----------------------------------------------------------------------- #
# (10) Main Training/Validation Loop
# ----------------------------------------------------------------------- #
params.num_epochs = params.num_iters // len(train_data_loader)
iters = 0
t1 = time.time()
for epoch in range(startEpoch, startEpoch + params.num_epochs):
torch.cuda.synchronize() # ensure accurate epoch timings
        # In the distributed case, set the sampler epoch so shuffling changes each epoch but stays consistent across ranks
if params.distributed and (train_sampler is not None):
train_sampler.set_epoch(epoch)
start = time.time()
tr_loss = []
tr_time = 0.0
dat_time = 0.0
log_time = 0.0
model.train()
step_count = 0
# (10.1) Training steps per batch
for i, data in enumerate(train_data_loader, 0):
# (a) Optionally enable profiling for diagnostics on rank 0
if world_rank == 0:
if epoch == 3 and i == 0:
torch.cuda.profiler.start()
if epoch == 3 and i == len(train_data_loader) - 1:
torch.cuda.profiler.stop()
torch.cuda.nvtx.range_push(f"step {i}")
iters += 1
dat_start = time.time()
torch.cuda.nvtx.range_push(f"data copy in {i}")
# (b) Copy data to device/GPU
inp, tar = map(lambda x: x.to(device), data)
torch.cuda.nvtx.range_pop() # end data copy timing
tr_start = time.time()
b_size = inp.size(0)
optimizer.zero_grad()
torch.cuda.nvtx.range_push("forward")
# (c) Forward pass and loss under autocast for AMP
with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
gen = model(inp)
loss = loss_func(gen, tar)
torch.cuda.nvtx.range_pop() # end forward timing
# Optionally log memory usage after first forward pass
if world_rank == 0 and i == 1:
all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).used / (1024.0**3)
logging.info(f" Memory usage after forward pass: {all_mem_gb} GB.")
# (d) Backward & optimizer step (support autocast/scaler if fp16 AMP)
if params.amp_dtype == torch.float16:
scaler.scale(loss).backward()
torch.cuda.nvtx.range_push("optimizer")
scaler.step(optimizer)
torch.cuda.nvtx.range_pop() # optimizer
scaler.update()
else:
loss.backward()
torch.cuda.nvtx.range_push("optimizer")
optimizer.step()
torch.cuda.nvtx.range_pop() # optimizer
# (e) Synchronize loss across data-parallel processes if distributed
if params.distributed:
torch.distributed.all_reduce(
loss, op=ReduceOp.AVG, group=comm.get_group("dp")
)
tr_loss.append(loss.item())
torch.cuda.nvtx.range_pop() # end step
            # (f) Learning rate schedule step (scheduler is None when no schedule is configured)
            if scheduler is not None:
                scheduler.step()
# (g) Timing and diagnostics
tr_end = time.time()
tr_time += tr_end - tr_start
dat_time += tr_start - dat_start
step_count += 1
torch.cuda.synchronize() # epoch synchronization for timing
end = time.time()
# (10.2) Logging and visualization (rank 0 only)
if world_rank == 0:
iters_per_sec = step_count / (end - start)
samples_per_sec = params["global_batch_size"] * iters_per_sec
logging.info(
"Time taken for epoch %i is %f sec, avg %f samples/sec",
epoch + 1,
end - start,
samples_per_sec,
)
logging.info(" Avg train loss=%f" % np.mean(tr_loss))
args.tboard_writer.add_scalar("Loss/train", np.mean(tr_loss), iters)
args.tboard_writer.add_scalar(
"Learning Rate", optimizer.param_groups[0]["lr"], iters
)
args.tboard_writer.add_scalar("Avg iters per sec", iters_per_sec, iters)
args.tboard_writer.add_scalar("Avg samples per sec", samples_per_sec, iters)
# Visualize a sample result
fig = generate_images([inp, tar, gen])
args.tboard_writer.add_figure("Visualization, t2m", fig, iters, close=True)
# ------------------------------------------------------------------- #
# (10.3) Validation Evaluation Phase
# ------------------------------------------------------------------- #
val_start = time.time()
val_loss = torch.zeros(1, device=device)
val_rmse = torch.zeros(
(params.n_out_channels), dtype=torch.float32, device=device
)
valid_steps = 0
model.eval()
with torch.inference_mode():
with torch.no_grad():
for i, data in enumerate(val_data_loader, 0):
with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
inp, tar = map(lambda x: x.to(device), data)
gen = model(inp)
val_loss += loss_func(gen, tar)
val_rmse += weighted_rmse(gen, tar)
valid_steps += 1
# Reduce validation metrics across all ranks for distributed
if params.distributed:
torch.distributed.all_reduce(
val_loss, op=ReduceOp.AVG, group=comm.get_group("dp")
)
torch.distributed.all_reduce(
val_rmse, op=ReduceOp.AVG, group=comm.get_group("dp")
)
# Normalize validation metrics by number of steps
val_rmse /= valid_steps
val_loss /= valid_steps
val_end = time.time()
# (10.4) Log validation results (only rank 0 logs and writes)
if world_rank == 0:
logging.info(" Avg val loss={}".format(val_loss.item()))
logging.info(" Total validation time: {} sec".format(val_end - val_start))
args.tboard_writer.add_scalar("Loss/valid", val_loss, iters)
args.tboard_writer.add_scalar(
"RMSE(u10m)/valid", val_rmse.cpu().numpy()[0], iters
)
args.tboard_writer.flush()
# ----------------------------------------------------------------------- #
# (11) Cleanup and Final Timing
# ----------------------------------------------------------------------- #
torch.cuda.synchronize()
t2 = time.time()
tottime = t2 - t1
pynvml.nvmlShutdown()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--run_num",
default="00",
type=str,
help="tag for indexing the current experiment",
)
parser.add_argument(
"--yaml_config",
default="./config/ViT.yaml",
type=str,
help="path to yaml file containing training configs",
)
parser.add_argument(
"--config", default="base", type=str, help="name of desired config in yaml file"
)
parser.add_argument(
"--amp_mode",
default="none",
type=str,
choices=["none", "fp16", "bf16"],
help="select automatic mixed precision mode",
)
parser.add_argument(
"--enable_fused", action="store_true", help="enable fused Adam optimizer"
)
parser.add_argument(
"--enable_jit", action="store_true", help="enable JIT compilation"
)
parser.add_argument(
"--local_batch_size",
default=None,
type=int,
help="local batchsize (manually override global_batch_size config setting)",
)
parser.add_argument(
"--num_iters", default=None, type=int, help="number of iters to run"
)
parser.add_argument(
"--num_data_workers",
default=None,
type=int,
help="number of data workers for data loader",
)
parser.add_argument(
"--data_loader_config",
default=None,
type=str,
choices=["pytorch", "dali"],
help="dataloader configuration. choices: 'pytorch', 'dali'",
)
parser.add_argument(
"--bucket_cap_mb", default=25, type=int, help="max message bucket size in mb"
)
parser.add_argument(
"--disable_broadcast_buffers",
action="store_true",
help="disable syncing broadcasting buffers",
)
parser.add_argument(
"--noddp", action="store_true", help="disable DDP communication"
)
# model parallelism arguments
parser.add_argument(
"--tensor_parallel",
default=1,
type=int,
help="Number of GPUs for tensor parallelism",
)
parser.add_argument(
"--context_parallel",
default=1,
type=int,
help="Number of GPUs for context parallelism",
)
parser.add_argument(
"--parallel_order",
default="tp-cp-dp",
type=str,
help="Order of ranks for parallelism",
)
args = parser.parse_args()
run_num = args.run_num
params = YParams(os.path.abspath(args.yaml_config), args.config)
# Update config with modified args
# set up amp
if args.amp_mode != "none":
params.update({"amp_mode": args.amp_mode})
amp_dtype = torch.float32
if params.amp_mode == "fp16":
amp_dtype = torch.float16
elif params.amp_mode == "bf16":
amp_dtype = torch.bfloat16
params.update(
{"amp_enabled": amp_dtype is not torch.float32, "amp_dtype": amp_dtype}
)
if args.enable_fused:
params.update({"enable_fused": args.enable_fused})
if args.enable_jit:
params.update({"enable_jit": args.enable_jit})
if args.data_loader_config:
params.update({"data_loader_config": args.data_loader_config})
if args.num_iters:
params.update({"num_iters": args.num_iters})
if args.num_data_workers:
params.update({"num_data_workers": args.num_data_workers})
params.distributed = False
# setup model parallel sizes
params["tp"] = args.tensor_parallel
params["cp"] = args.context_parallel
params["order"] = args.parallel_order
# initialize comm
comm.init(params, verbose=True)
# get info from comm
world_size = comm.get_world_size()
world_rank = comm.get_world_rank()
local_rank = comm.get_local_rank()
params.distributed = world_size > 1
assert (
params["global_batch_size"] % comm.get_size("dp") == 0
), f"Error, cannot evenly distribute {params['global_batch_size']} across {comm.get_size('dp')} GPU."
if args.local_batch_size:
# Manually override batch size
params.local_batch_size = args.local_batch_size
params.update(
{"global_batch_size": comm.get_size("dp") * args.local_batch_size}
)
else:
# Compute local batch size based on number of ranks
params.local_batch_size = int(
params["global_batch_size"] // comm.get_size("dp")
)
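    # Worked example (illustrative numbers, not the shipped config): with
    # global_batch_size = 64 and a data-parallel size of 8, each rank gets
    # local_batch_size = 64 // 8 = 8 samples per step; passing --local_batch_size 16
    # instead rewrites global_batch_size to 8 * 16 = 128.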
# for data loader, set the actual number of data shards and id
params.data_num_shards = comm.get_size("dp")
params.data_shard_id = comm.get_rank("dp")
# Set up directory
baseDir = params.expdir
expDir = os.path.join(
baseDir, args.config + "/%dMP/" % (comm.get_size("tp-cp")) + str(run_num) + "/"
)
if world_rank == 0:
if not os.path.isdir(expDir):
os.makedirs(expDir)
logging_utils.log_to_file(
logger_name=None, log_filename=os.path.join(expDir, "out.log")
)
params.log()
args.tboard_writer = SummaryWriter(log_dir=os.path.join(expDir, "logs/"))
params.experiment_dir = os.path.abspath(expDir)
train(params, args, local_rank, world_rank, world_size)
if params.distributed:
torch.distributed.barrier()
logging.info("DONE ---- rank %d" % world_rank)