Deep learning at scale: lessons from experiments on Perlmutter

Nov 17, 2025

| System     | JUWELS (JSC)     | Perlmutter (NERSC)          | JUPITER (JSC, EuroHPC)       | Colossus (xAI)                  |
|------------|------------------|-----------------------------|------------------------------|---------------------------------|
| Year       | 2018 / 2020      | 2021                        | 2024–2025                    | 2024–2025                       |
| Purpose    | Petascale HPC    | HPC + GPU acceleration      | Exascale HPC + AI            | Large-scale AI training         |
| Compute    | Xeon + EPYC/A100 | EPYC + 4× A100 (GPU nodes)  | ~6000 GH200 nodes (4× GH200) | 100k–200k H100/H200 GPUs        |
| GPU count  | 3,744 A100 GPUs  | ~7,168 A100 GPUs            | ~24,000 GH200 GPUs           | 100k+ Hopper-class GPUs         |
| CPU        | Xeon / EPYC      | AMD EPYC 7763               | Grace (4×72 cores/node)      | Minimal (GPU-centric)           |
| Peak perf. | ~70 PFLOP/s      | ~70 PFLOP/s (GPU partition) | ~1 EFLOP/s (FP64)            | AI-precision PFLOP/s (not FP64) |
| Network    | HDR InfiniBand   | Slingshot 11                | NDR200 InfiniBand            | Spectrum-X Ethernet             |
| Power      | ~10–20 MW        | ~20 MW                      | ~18 MW                       | 150–300 MW                      |
| Use case   | Scientific HPC   | HPC + ML/AI                 | HPC + AI hybrid              | LLM training                    |
| Cost       | ~€145M           | ~$600M                      | ~€500M                       | ~$10–15B                        |

The problem

Figure 2: (a) The multi-layer transformer architecture that uses the Adaptive Fourier Neural Operator (AFNO) with a shared MLP and frequency soft-thresholding for spatial token mixing. The input frame is first divided into an h × w grid of patches, where each patch has a small size p × p × c. Each patch is then embedded in a higher-dimensional space with a large number of latent channels, and a position embedding is added to form a sequence of tokens. Tokens are then mixed spatially using AFNO, and subsequently the latent channels of each token are mixed. This process is repeated for L layers, and finally a linear decoder reconstructs the patches for the next frame from the final embedding. The right-hand panels describe the FourCastNet model's additional training and inference modes: (b) two-step fine-tuning; (c) the backbone model that forecasts the 20 variables in Table 1, with a secondary precipitation diagnostic model (note that p(k + 1) denotes the 6-hour accumulated total precipitation that falls between time steps k + 1 and k + 2); (d) the forecast model in free-running autoregressive inference mode. For details, see Pathak et al. (2022), FourCastNet: A Global Data-driven High-resolution Weather Model using Adaptive Fourier Neural Operators, arXiv:2202.11214.
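The patch bookkeeping in panel (a), an h × w grid of p × p × c patches that a linear decoder later turns back into a frame, can be sketched in plain Python. This is an illustrative sketch, not FourCastNet's or this repo's code: patchify and unpatchify are hypothetical helpers. The token layout (row within patch, column within patch, channel) mirrors the reshape and einsum("nhwpqc->nchpwq") in forward_head of the ViT code below; that model embeds patches with a Conv2d whose kernel size and stride equal the patch size, which amounts to one linear map applied to each flattened patch.

```python
from typing import List

Frame = List[List[List[float]]]  # indexed as frame[channel][row][col]


def patchify(frame: Frame, p: int) -> List[List[float]]:
    """Split a C x H x W frame into (H//p)*(W//p) tokens of length p*p*C.

    Tokens are ordered row-major over the patch grid; inside a token the values
    are flattened in (row-in-patch, col-in-patch, channel) order ("pqc").
    """
    c, H, W = len(frame), len(frame[0]), len(frame[0][0])
    assert H % p == 0 and W % p == 0, "frame must tile exactly into p x p patches"
    tokens = []
    for hi in range(H // p):
        for wi in range(W // p):
            tok = []
            for pi in range(p):
                for qi in range(p):
                    for ch in range(c):
                        tok.append(frame[ch][hi * p + pi][wi * p + qi])
            tokens.append(tok)
    return tokens


def unpatchify(tokens: List[List[float]], h: int, w: int, p: int, c: int) -> Frame:
    """Inverse of patchify: the index mapping behind reshape + "nhwpqc->nchpwq"."""
    frame = [[[0.0] * (w * p) for _ in range(h * p)] for _ in range(c)]
    for hi in range(h):
        for wi in range(w):
            tok = tokens[hi * w + wi]
            for pi in range(p):
                for qi in range(p):
                    for ch in range(c):
                        frame[ch][hi * p + pi][wi * p + qi] = tok[(pi * p + qi) * c + ch]
    return frame


# Tiny example: 2 channels, 4 x 6 pixels, 2 x 2 patches -> a 2 x 3 grid of tokens.
frame = [[[ch * 100 + r * 10 + col for col in range(6)] for r in range(4)] for ch in range(2)]
tokens = patchify(frame, p=2)
restored = unpatchify(tokens, h=2, w=3, p=2, c=2)
print(len(tokens), "tokens of length", len(tokens[0]))
```

Each pixel value encodes its (channel, row, column), so comparing restored with frame confirms that the mapping is an exact round trip.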

Dataset

ERA5 is more than just a climate dataset; it is a cornerstone of large-scale climate research and AI. At approximately 0.25° spatial resolution, every global surface field comprises over a million grid points. Factoring in 100+ vertical levels, that is hundreds of millions of data values per variable per hour. Spanning every hour of every day for more than 80 years (starting from 1940), ERA5 accumulates trillions of data points, amounting to multiple petabytes of compressed data. This volume lets researchers investigate global and regional weather patterns, train high-capacity machine learning models, and reconstruct the hour-by-hour evolution of Earth's climate in unusual detail.

Special thanks and credit to the following contributors for the slides and code underpinning this research: Shashank Subramanian, Steven Farrell, Josh Romero, Thorsten Kurth, and Corneel Casert. Resources are available in the NERSC repo.
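The ERA5 size estimates above follow from simple arithmetic. The sketch below assumes a regular 0.25° latitude-longitude grid that includes both poles (721 × 1440 points) and ERA5's 137 model levels, and it counts values for a single 3D variable only; the constants are illustrative inputs, not values read from the dataset.

```python
from datetime import datetime

# Assumptions: regular 0.25° lat-lon grid including both poles, 137 model
# levels. All counts below are for ONE variable.
RES_DEG = 0.25
N_LAT = int(180 / RES_DEG) + 1   # 721 latitudes, pole to pole inclusive
N_LON = int(360 / RES_DEG)       # 1440 longitudes (0° and 360° coincide)
N_LEVELS = 137

points_per_field = N_LAT * N_LON                     # one 2D (surface) field
values_per_3d_var_hour = points_per_field * N_LEVELS


def hours_between(start_year: int, end_year: int) -> int:
    """Hourly snapshots from 1 Jan start_year up to (excluding) 1 Jan end_year."""
    delta = datetime(end_year, 1, 1) - datetime(start_year, 1, 1)
    return int(delta.total_seconds()) // 3600


hours = hours_between(1940, 2025)                    # 1940 through 2024
total_one_var = values_per_3d_var_hour * hours

print(f"{points_per_field:,} grid points per 2D field")
print(f"{values_per_3d_var_hour:,} values per 3D variable per hour")
print(f"{hours:,} hours from 1940 through 2024")
print(f"{total_one_var:.2e} values for a single 3D variable")
```

A single 3D variable already comes to roughly 10^14 values over the record, consistent with the "trillions" quoted above before multiplying by the number of variables.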

GPU configuration reported by nvidia-smi:

Tue Nov 18 19:28:52 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.86.15              Driver Version: 570.86.15      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 NVL                On  |   00000001:00:00.0 Off |                    0 |
| N/A   39C    P0             64W /  400W |       1MiB /  95830MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 NVL                On  |   00000002:00:00.0 Off |                    0 |
| N/A   39C    P0             67W /  400W |       1MiB /  95830MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
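
Model

The model used in these experiments is a Vision Transformer (ViT) with the same patch-embed, mix, decode structure. When tensor parallelism (tp) or context parallelism (cp) is enabled, each block uses distributed layers that shard the weights (tp) or the token sequence (cp); on a single GPU it falls back to standard PyTorch layers.

import torch
import torch.nn as nn
import torch.nn.functional as F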
from functools import partial
from networks.helpers import DropPath, trunc_normal_

# mp stuff
from utils import comm
from distributed.layers import (
    DistributedMatmul,
    DistributedMLP,
    DistributedAttention,
    DistributedLayerNorm,
)
from distributed.helpers import compute_split_shapes
from distributed.mappings import scatter_to_parallel_region, gather_from_parallel_region


class MLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = True

        # separate q, k, v projections (equivalent to one fused Linear(dim, 3 * dim))
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape

        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.k(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.v(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,  # no attention dropout at eval time
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        cp_shapes=None,
    ):
        super().__init__()

        mlp_hidden_dim = int(dim * mlp_ratio)

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        if (comm.get_size("tp-cp")) > 1:
            # model parallelism is on, distribute the layers
            # tp: tensor parallel shards the weights
            # cp: context parallel shards the sequence
            self.norm1 = DistributedLayerNorm(dim, comm_tp_name="tp", comm_cp_name="cp")
            self.norm2 = DistributedLayerNorm(dim, comm_tp_name="tp", comm_cp_name="cp")
            self.attn = DistributedAttention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                attn_drop=attn_drop,
                proj_drop=drop,
                comm_tp_name="tp",
                comm_cp_name="cp",
                cp_shapes=cp_shapes,
            )
            self.mlp = DistributedMLP(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                drop=drop,
                comm_tp_name="tp",
                comm_cp_name="cp",
            )
        else:
            self.norm1 = norm_layer(dim)
            self.norm2 = norm_layer(dim)
            self.attn = Attention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                attn_drop=attn_drop,
                proj_drop=drop,
            )
            self.mlp = MLP(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                drop=drop,
            )

    def forward(self, x):
        y = self.attn(self.norm1(x))
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        # grid of patches
        self.h = img_size[0] // patch_size
        self.w = img_size[1] // patch_size
        num_patches = self.h * self.w
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class VisionTransformer(nn.Module):
    """
    Vision Transformer Architecture — Structure Diagram (Left-to-Right, Detailed)

    +-----------------+   +-----------------+   +-------------------------------+   +-----------------+   +-----------------+
    |   Input Image   |-->|   PatchEmbed    |-->| + Positional Embedding        |-->|    Dropout      |-->|     Scatter     |
    |   (B, C, H, W)  |   | Conv2d splits   |   | (learned, per token)          |   |   (pos_drop)    |   | to context      |
    |                 |   | into patches    |   |                               |   |                 |   | parallel        |
    +-----------------+   +-----------------+   +-------------------------------+   +-----------------+   +-------+---------+
                                                                                                                  |
                                                                                                                  v
----------------------------------------------- Transformer Encoder Blocks ---------------------------------------+-------------

                              Repeats "depth" times: (for i = 1 to depth)
+--------------+   +------------+   +--------------+   +------------+   +-------------+   +------------+   +-------------+
|    Input     |-->|   Norm1    |-->| Multi-Head   |-->|  DropPath  |-->|  Residual   |-->|   Norm2    |-->|    MLP      |
|  (B, N, D)   |   |(Layer/Dist)|   |  Attention   |   |(stoch.     |   | Connection  |   |            |   | (2 layers,  |
|              |   |            |   | (LN -> MHSA  |   | depth,     |   |   (+Input)  |   |            |   |  GELU,      |
|              |   |            |   |   QKV-Proj)  |   | optional)  |   |             |   |            |   |  Dropout)   |
+--------------+   +------------+   +--------------+   +------------+   +------+------+   +------------+   +------|------+
                                                                               ^                                  |
                                                                               |                                  v
                                                                               |                   +--------------+------+
                                                                               |------------------>|         DropPath    |
                                                                                                   +---------------------+
                                                                                                  
--------------+------------------------------------------------------------------------------------------------------------------
              |
              V
+-------------+-------------+   +------------------------------+   +-------------------------------+   +------------------------------+
|   LayerNorm /             |-->|    Gather from context       |-->|  Head: Linear projection      |-->|   Reshape & Rearrangement    |
| DistributedLayerNorm      |   |    parallel (if needed)      |   |  to patch×patch×channels      |   | (to output image shape)      |
+---------------------------+   +------------------------------+   +-------------------------------+   +------------------------------+

    Key Properties:
      - patch_embed: Converts image to sequence of patch tokens
      - pos_embed: Learnable position embeddings per token
      - blocks: depth repeated Transformer blocks (LayerNorm, MHSA, MLP, DropPath, residual connections)
      - norm: Final normalization; distributed option if needed
      - head: Linear projection from embedding space to reconstructed output
      - All computation left-to-right (data flow shown above)

    Notation:
      B = batch size, C = in channels, H,W = input dims, N = num patches, D = embed_dim

    """
    def __init__(
        self,
        img_size=[224, 224],
        patch_size=16,
        in_chans=3,
        out_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        **kwargs
    ):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.img_size = img_size
        self.out_ch = out_chans
        self.drop_rate = drop_rate

        # ─── Patch Embedding ───────────────────────────────────────────────────────────
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        # ─── Parallel Shapes for Sequence/Context ──────────────────────────────────────
        self.cp_shapes = compute_split_shapes(num_patches, comm.get_size("cp"))

        # ─── Learnable Positional Embedding ────────────────────────────────────────────
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # ─── Stochastic Depth/DropPath Scheduling ──────────────────────────────────────
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]

        # ─── Stacked Transformer Blocks ────────────────────────────────────────────────
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    cp_shapes=self.cp_shapes,
                )
                for i in range(depth)
            ]
        )

        # ─── Final LayerNorm (possibly distributed) ────────────────────────────────────
        if (comm.get_size("tp-cp")) > 1:
            self.norm = DistributedLayerNorm(embed_dim, comm_tp_name="tp", comm_cp_name="cp")
        else:
            self.norm = norm_layer(embed_dim)

        # ─── Head: Linear map to patch × patch × channel output ────────────────────────
        self.out_size = self.out_ch * self.patch_size * self.patch_size
        self.head = nn.Linear(embed_dim, self.out_size, bias=False)

        # ─── Initialization ────────────────────────────────────────────────────────────
        trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def prepare_tokens(self, x):
        B, C, H, W = x.shape  # input is (batch, channels, height, width)
        x = self.patch_embed(x)  # patch linear embedding
        # add positional encoding to each token
        x = x + self.pos_embed
        return self.pos_drop(x)

    def forward_head(self, x):
        B, _, _ = x.shape  # B x N x embed_dim
        x = x.reshape(B, self.patch_embed.h, self.patch_embed.w, self.embed_dim)
        B, h, w, _ = x.shape

        # apply head
        x = self.head(x)
        x = x.reshape(shape=(B, h, w, self.patch_size, self.patch_size, self.out_ch))
        x = torch.einsum("nhwpqc->nchpwq", x)
        x = x.reshape(shape=(B, self.out_ch, self.img_size[0], self.img_size[1]))

        return x

    def forward(self, x):
        x = self.prepare_tokens(x)

        # split sequence if cp is on (shape of x is (batch, seq, embed))
        x = scatter_to_parallel_region(x, dim=1, comm_name="cp")

        # if cp is on, each block operates on a sequence shard
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        # gather sequence if cp is on
        x = gather_from_parallel_region(x, dim=1, shapes=self.cp_shapes, comm_name="cp")

        x = self.forward_head(x)
        return x


def ViT(params, **kwargs):
    model = VisionTransformer(
        img_size=tuple(params.img_size),
        in_chans=params.n_in_channels,
        out_chans=params.n_out_channels,
        patch_size=params.patch_size,
        embed_dim=params.embed_dim,
        depth=params.depth,
        num_heads=params.num_heads,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop_path_rate=float(params.dropout),
        drop_rate=float(params.dropout),
        attn_drop_rate=float(params.dropout),
        **kwargs
    )
    return model
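With the "short" configuration from the training log below (img_size [360, 720], patch_size 8, 20 input and output channels, embed_dim 384, depth 12, qkv_bias=True), the model's sizes can be checked by hand. The helper below is a hypothetical sketch that counts tokens and trainable parameters from the layer definitions above in plain Python, without instantiating the model; it covers only the non-distributed layers.

```python
def linear_params(fan_in: int, fan_out: int, bias: bool = True) -> int:
    """Weights (+ bias) of an nn.Linear(fan_in, fan_out)."""
    return fan_in * fan_out + (fan_out if bias else 0)


def vit_sizes(img_h, img_w, patch, in_ch, out_ch, embed_dim, depth, mlp_ratio=4):
    """Token grid and parameter counts for the single-GPU VisionTransformer above."""
    h, w = img_h // patch, img_w // patch
    num_tokens = h * w
    hidden = int(embed_dim * mlp_ratio)

    patch_embed = in_ch * embed_dim * patch * patch + embed_dim   # Conv2d weight + bias
    pos_embed = num_tokens * embed_dim                            # learned, one per token
    attn = 4 * linear_params(embed_dim, embed_dim)                # q, k, v, proj
    mlp = linear_params(embed_dim, hidden) + linear_params(hidden, embed_dim)
    norms = 2 * 2 * embed_dim                                     # norm1, norm2: weight + bias
    block = attn + mlp + norms
    final_norm = 2 * embed_dim
    head_out = out_ch * patch * patch                             # patch x patch x channels
    head = linear_params(embed_dim, head_out, bias=False)

    total = patch_embed + pos_embed + depth * block + final_norm + head
    return {"grid": (h, w), "num_tokens": num_tokens, "head_out": head_out,
            "params_per_block": block, "total_params": total}


sizes = vit_sizes(360, 720, patch=8, in_ch=20, out_ch=20, embed_dim=384, depth=12)
print(sizes)
```

Each sample becomes a 45 × 90 grid of 4,050 tokens, and the head's 1,280 outputs per token match the Linear(in_features=384, out_features=1280) in the printed model. At 4,050 tokens every attention head scores about 16.4 million query-key pairs per sample and layer, which is one reason the fused scaled_dot_product_attention path and context parallelism matter as resolution grows.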

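Data loader

Baseline run: the PyTorch data loader (data_loader_config pytorch) with 8 data workers, profiled with Nsight Systems (ENABLE_PROFILING=1):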

root@86ffe15cc233:/dli/task# \
ENABLE_PROFILING=1 \
PROFILE_OUTPUT=baseline-dli_dw8 \
./submit_dli.sh \
    --config=short \
    --num_data_workers 8 \
    --run_num=nw8
Enabling profiling...
+ mpirun --allow-run-as-root -np 1 bash -c '
    source export_DDP_vars.sh
    nsys profile \
        --trace=cuda,cublas,nvtx \
        --kill none \
        -c cudaProfilerApi \
        -f true \
        -o /dli/task/logs/baseline-dli_dw8 \
        python train.py \
        --config=short \
        --num_data_workers 8 \
        --run_num=nw8
    '
2025-11-17 21:32:08,311 - root - INFO - ------------------ Configuration ------------------
2025-11-17 21:32:08,311 - root - INFO - Configuration file: /dli/task/config/ViT.yaml
2025-11-17 21:32:08,311 - root - INFO - Configuration name: short
2025-11-17 21:32:08,311 - root - INFO - limit_nsamples 512
2025-11-17 21:32:08,311 - root - INFO - limit_nsamples_val 128
2025-11-17 21:32:08,311 - root - INFO - num_iters 128
2025-11-17 21:32:08,311 - root - INFO - embed_dim 384
2025-11-17 21:32:08,311 - root - INFO - depth 12
2025-11-17 21:32:08,311 - root - INFO - dropout 0.0
2025-11-17 21:32:08,311 - root - INFO - patch_size 8
2025-11-17 21:32:08,311 - root - INFO - num_heads 8
2025-11-17 21:32:08,311 - root - INFO - model_backend pytorch
2025-11-17 21:32:08,311 - root - INFO - img_size [360, 720]
2025-11-17 21:32:08,311 - root - INFO - dt 1
2025-11-17 21:32:08,311 - root - INFO - global_batch_size 16
2025-11-17 21:32:08,312 - root - INFO - amp_mode none
2025-11-17 21:32:08,312 - root - INFO - enable_fused False
2025-11-17 21:32:08,312 - root - INFO - enable_jit False
2025-11-17 21:32:08,312 - root - INFO - expdir /dli/task/logs
2025-11-17 21:32:08,312 - root - INFO - lr_schedule cosine
2025-11-17 21:32:08,312 - root - INFO - lr 0.0005
2025-11-17 21:32:08,312 - root - INFO - warmup 0
2025-11-17 21:32:08,312 - root - INFO - optimizer Adam
2025-11-17 21:32:08,312 - root - INFO - data_loader_config pytorch
2025-11-17 21:32:08,312 - root - INFO - num_data_workers 8
2025-11-17 21:32:08,312 - root - INFO - n_in_channels 20
2025-11-17 21:32:08,312 - root - INFO - n_out_channels 20
2025-11-17 21:32:08,312 - root - INFO - train_data_path /data/train
2025-11-17 21:32:08,312 - root - INFO - valid_data_path /data/valid
2025-11-17 21:32:08,312 - root - INFO - inf_data_path /data/test
2025-11-17 21:32:08,312 - root - INFO - time_means_path /data/stats/time_means.npy
2025-11-17 21:32:08,312 - root - INFO - global_means_path /data/stats/global_means.npy
2025-11-17 21:32:08,312 - root - INFO - global_stds_path /data/stats/global_stds.npy
2025-11-17 21:32:08,312 - root - INFO - wireup_info env
2025-11-17 21:32:08,312 - root - INFO - wireup_store tcp
2025-11-17 21:32:08,312 - root - INFO - amp_enabled False
2025-11-17 21:32:08,312 - root - INFO - amp_dtype torch.float32
2025-11-17 21:32:08,312 - root - INFO - ---------------------------------------------------
2025-11-17 21:32:08,513 - root - INFO - rank 0, begin data loader init
2025-11-17 21:32:08,551 - root - INFO - Getting file stats from /data/train/2012.h5
2025-11-17 21:32:08,552 - root - INFO - Overriding total number of samples to: 512
2025-11-17 21:32:08,552 - root - INFO - Number of samples per year: 1460
2025-11-17 21:32:08,552 - root - INFO - Found data at path /data/train. Number of examples: 512. Image Shape: 360 x 720 x 20
2025-11-17 21:32:08,552 - root - INFO - Getting file stats from /data/valid/2016.h5
2025-11-17 21:32:08,553 - root - INFO - Overriding total number of samples to: 128
2025-11-17 21:32:08,553 - root - INFO - Number of samples per year: 1460
2025-11-17 21:32:08,553 - root - INFO - Found data at path /data/valid. Number of examples: 128. Image Shape: 360 x 720 x 20
2025-11-17 21:32:08,553 - root - INFO - rank 0, data loader initialized
2025-11-17 21:32:08,943 - root - INFO - VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(20, 384, kernel_size=(8, 8), stride=(8, 8))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (drop_path): Identity()
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (q): Linear(in_features=384, out_features=384, bias=True)
        (k): Linear(in_features=384, out_features=384, bias=True)
        (v): Linear(in_features=384, out_features=384, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (mlp): MLP(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (head): Linear(in_features=384, out_features=1280, bias=False)
)
2025-11-17 21:32:08,943 - root - INFO - Starting Training Loop...
/dli/task/train.py:176: UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.
Consider using tensor.detach() first. (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/native/Scalar.cpp:22.)
  tr_loss.append(loss.item() / world_size)
2025-11-17 21:32:48,380 - root - INFO - Time taken for epoch 1 is 31.060786 sec, avg 16.483807 samples/sec
2025-11-17 21:32:48,380 - root - INFO -   Avg train loss=0.583575
2025-11-17 21:32:52,875 - root - INFO -   Avg val loss=0.4340733289718628
2025-11-17 21:32:52,876 - root - INFO -   Total validation time: 3.9601058959960938 sec
2025-11-17 21:33:23,642 - root - INFO - Time taken for epoch 2 is 30.764906 sec, avg 16.642339 samples/sec
2025-11-17 21:33:23,642 - root - INFO -   Avg train loss=0.400669
2025-11-17 21:33:27,975 - root - INFO -   Avg val loss=0.3799823522567749
2025-11-17 21:33:27,976 - root - INFO -   Total validation time: 3.7919225692749023 sec
2025-11-17 21:33:59,814 - root - INFO - Time taken for epoch 3 is 31.836994 sec, avg 16.081920 samples/sec
2025-11-17 21:33:59,814 - root - INFO -   Avg train loss=0.365027
2025-11-17 21:34:04,047 - root - INFO -   Avg val loss=0.3569737672805786
2025-11-17 21:34:04,047 - root - INFO -   Total validation time: 3.7000949382781982 sec
Capture range started in the application.
Capture range ended in the application.
Generating '/tmp/nsys-report-7df0.qdstrm'
[1/1] [0%                          ] baseline-dli_dw8.nsys-repProcessing events...
[1/1] [========================100%] baseline-dli_dw8.nsys-rep
Generated:
        /dli/task/logs/baseline-dli_dw8.nsys-rep
2025-11-17 21:34:43,205 - root - INFO - Time taken for epoch 4 is 39.157385 sec, avg 13.075439 samples/sec
2025-11-17 21:34:43,206 - root - INFO -   Avg train loss=0.352449
2025-11-17 21:34:47,627 - root - INFO -   Avg val loss=0.3528955578804016
2025-11-17 21:34:47,627 - root - INFO -   Total validation time: 3.882009744644165 sec
2025-11-17 21:34:47,631 - root - INFO - DONE ---- rank 0
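To compare runs (for example, different --num_data_workers values) it helps to pull the per-epoch timings out of the log. The sketch below is a hypothetical helper, not part of train.py; it parses the "Time taken for epoch" lines shown above. Epoch 4 is excluded from the baseline on the assumption that it overlaps the Nsight Systems capture, since the capture messages appear just before its timing line.

```python
import re
from statistics import mean

# Matches lines such as:
# "... Time taken for epoch 1 is 31.060786 sec, avg 16.483807 samples/sec"
EPOCH_RE = re.compile(
    r"Time taken for epoch (\d+) is ([\d.]+) sec, avg ([\d.]+) samples/sec"
)


def parse_epochs(log_text: str):
    """Return a list of (epoch, seconds, samples_per_sec) tuples."""
    return [
        (int(m.group(1)), float(m.group(2)), float(m.group(3)))
        for m in EPOCH_RE.finditer(log_text)
    ]


# The four epoch-timing lines from the log above.
log_text = """
2025-11-17 21:32:48,380 - root - INFO - Time taken for epoch 1 is 31.060786 sec, avg 16.483807 samples/sec
2025-11-17 21:33:23,642 - root - INFO - Time taken for epoch 2 is 30.764906 sec, avg 16.642339 samples/sec
2025-11-17 21:33:59,814 - root - INFO - Time taken for epoch 3 is 31.836994 sec, avg 16.081920 samples/sec
2025-11-17 21:34:43,205 - root - INFO - Time taken for epoch 4 is 39.157385 sec, avg 13.075439 samples/sec
"""

epochs = parse_epochs(log_text)
PROFILED_EPOCHS = {4}  # assumption: the nsys capture overlaps epoch 4 in this log
baseline = mean(rate for ep, _, rate in epochs if ep not in PROFILED_EPOCHS)

print(f"baseline throughput: {baseline:.2f} samples/sec")
for ep, secs, rate in epochs:
    print(f"epoch {ep}: {secs:.1f} s, {rate:.2f} samples/s, ~{secs * rate:.0f} samples")
```

Multiplying each epoch's time by its rate recovers the 512 samples set by limit_nsamples, a quick check that the reported rate is averaged over the whole epoch.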

Using multiple data workers (typically 4 or 8) is essential for keeping the GPU fed during training. In PyTorch's DataLoader each worker is a separate process (other frameworks may use threads) that loads and preprocesses whole batches in parallel with the training loop. Overlapping I/O and preprocessing with model computation hides disk-access latency, absorbs the cost of slow augmentations, and puts otherwise idle CPU cores to work, so the next mini-batch is ready when the GPU needs it. The best worker count depends on the available CPU cores and memory, the dataset, and the cost of preprocessing. Throughput typically rises as workers are added until the GPU, rather than the input pipeline, becomes the bottleneck; beyond that point extra workers only consume CPU and memory.
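The saturation behaviour can be illustrated with a simple, hypothetical model: with W workers each producing whole batches in parallel, a new batch is ready every (batch load time) / W seconds, and a training step cannot run faster than either that interval or the GPU compute time. With zero workers, loading happens in the training process and adds to every step. The numbers used here (0.05 s to load a sample, 0.2 s of compute per 16-sample step) are made up for illustration, not measured on this system.

```python
def steady_state_throughput(batch_size: int, load_s_per_sample: float,
                            compute_s_per_step: float, num_workers: int) -> float:
    """Samples/sec under an idealized model of a worker-based data loader.

    Assumptions (illustrative): each worker prepares whole batches, workers run
    fully in parallel, and loading overlaps perfectly with GPU compute.
    """
    load_s_per_batch = batch_size * load_s_per_sample
    if num_workers == 0:
        # loading runs in the training process, so it adds to every step
        step_s = load_s_per_batch + compute_s_per_step
    else:
        # W workers deliver a batch every load_s_per_batch / W seconds; the step
        # is limited by whichever is slower, input delivery or compute
        step_s = max(compute_s_per_step, load_s_per_batch / num_workers)
    return batch_size / step_s


for w in (0, 1, 2, 4, 8):
    rate = steady_state_throughput(16, load_s_per_sample=0.05,
                                   compute_s_per_step=0.2, num_workers=w)
    print(f"{w} workers: {rate:.0f} samples/s")
```

In this toy setting throughput grows with the worker count until four workers, where input delivery matches the 0.2 s compute time; past that the GPU is the bottleneck. Real loaders also contend for CPU cores, memory bandwidth, and file-system I/O, so measured curves usually flatten earlier and less sharply.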

Even with enough workers, a CPU-based loader still prepares every batch in host memory and then copies it to the GPU, and these host-to-device memcpy operations can show up as a bottleneck in the profile. NVIDIA DALI can hide much of this cost: it moves samples to the GPU asynchronously and runs preprocessing (here, normalization) directly on the GPU as CUDA kernels. Its C++ backend also maintains its own worker threads and prefetch queue, so batch preparation overlaps with training instead of stalling it.
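A rough size estimate shows what each host-to-device copy involves for this run. The sketch assumes float32 samples (the log reports amp_dtype torch.float32, but the on-disk dtype is not shown), one input and one target tensor per sample, and the local batch of 16 (global_batch_size 16 on a single rank). The 25 GB/s bandwidth is a placeholder, not a measured value; the achievable rate depends on the PCIe generation and on whether host buffers are pinned.

```python
def bytes_per_batch(batch_size: int, channels: int, height: int, width: int,
                    bytes_per_value: int = 4, tensors_per_sample: int = 2) -> int:
    """Bytes copied host-to-device per batch (input + target by default)."""
    return batch_size * tensors_per_sample * channels * height * width * bytes_per_value


# Run configuration from the log: batch 16, 20 channels, 360 x 720 grid.
batch_bytes = bytes_per_batch(16, 20, 360, 720)  # float32 assumed (4 bytes/value)

ASSUMED_H2D_BYTES_PER_S = 25e9  # hypothetical effective PCIe bandwidth, not measured
copy_s = batch_bytes / ASSUMED_H2D_BYTES_PER_S

print(f"{batch_bytes / 1e6:.1f} MB per batch")
print(f"~{copy_s * 1e3:.1f} ms per batch copy at the assumed bandwidth")
```

Comparing this copy time with the step time seen in the profile indicates how much of each step the transfer could account for when it is not overlapped with compute.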
# -----------------------------------------------------------------------------
# How DALI Interacts with CUDA for Accelerated Data Loading
# -----------------------------------------------------------------------------
#
# The DALI pipeline is tightly integrated with CUDA. Its workers and GPU operators move data
# efficiently from storage to device, avoiding Python and host memory bottlenecks.
#
#                  +---------------------+   (1) read/augment   +-----------------+
# Data/FileSystem  | DALI CPU Workers    |--------------------->|   CPU Buffer    |
# (images, HDF5,   +---------------------+      (NumPy,        +-----------------+
#  .npy, etc.)                                    Pillow...)                 |
#          |                                                        (2) memcpy (optional)
#          |                                                                 |
#          |        +--------------+  (3) DALI GPU ops  +---------------------+
#          +------->| DALI Pipeline|------------------->|   CUDA Device Mem   |
#                   +--------------+ (Crop, Normalize,  +---------------------+
#                                     Augment, etc)               |
#                                                            (4) tensor returned to PyTorch already on the GPU; no extra H2D copy in the training loop.
#                                                                 |
#         [PyTorch training loop, tensor already in cuda:0, ready for GPU kernel]
#
#  - DALI operators marked as "device='gpu'" run as CUDA kernels and operate on device memory.
#  - No host-to-device PCIe transfers are needed after data is loaded into GPU memory.
#  - The pipeline uses internal CUDA streams and events for asynchronicity:
#         - Different sets of CUDA streams can be used for each worker/pipeline.
#         - DALI synchronizes only when the output tensor is returned to PyTorch.
#  - This approach is analogous to the CUDA buffer/execution region (see /distributed-gpu).
#
#      CPU (Data Workers)                  CUDA Device (GPU)           PyTorch Training Loop
#  +---------------------+             +----------------------+      +-----------------------+
#  | Disk/Net IO         |--read-----> | DALI pipeline ops    | ---> |   model(input)        |
#  | Decoding, Sharding  |             | (normalize, augment) |      +-----------------------+
#  | Preprocessing (CPU) |             +----------------------+        │
#  +---------------------+                    │                        │ Backprop
#                |                            └------------------------┘
#        sharding logic ensures                                       
#        each GPU/process gets                                       
#        the correct portion of data
#
#  See: Memory regions, CUDA device buffer layout, and device-kernel launches (cf. the
#  diagrams and commentary in /distributed-gpu).

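import numpy as np
import torch
import torch.distributed as dist
import nvidia.dali.fn as fn
from nvidia.dali.pipeline import Pipeline

from utils import comm

# esh: the repo's ERA5 external-source helper module (provides ERA5ES);
# its import path is not shown in this excerpt.


class DaliDataLoader(object):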
    def get_pipeline(self):
        pipeline = Pipeline(
            batch_size=self.batch_size,
            num_threads=2,
            device_id=self.device_index,
            py_num_workers=self.num_data_workers,
            py_start_method="spawn",
            seed=self.global_seed,
        )

        with pipeline:  # get input and target from external source
            inp, tar = fn.external_source(
                source=esh.ERA5ES(
                    self.location,
                    self.train,
                    self.batch_size,
                    self.dt,
                    self.img_size,
                    self.n_in_channels,
                    self.n_out_channels,
                    self.num_shards,
                    self.shard_id,
                    self.limit_nsamples,
                    enable_logging=False,
                    seed=self.global_seed,
                ),
                num_outputs=2,
                layout=["CHW", "CHW"],
                batch=False,
                no_copy=True,
                parallel=True,
            )

            # move samples to the GPU inside DALI (asynchronous H2D copy); later ops run as CUDA kernels
            inp = inp.gpu()
            tar = tar.gpu()

            if self.normalize:
                inp = fn.normalize(
                    inp,
                    device="gpu",
                    axis_names="HW",
                    batch=False,
                    mean=self.in_bias,
                    stddev=self.in_scale,
                )

                tar = fn.normalize(
                    tar,
                    device="gpu",
                    axis_names="HW",
                    batch=False,
                    mean=self.out_bias,
                    stddev=self.out_scale,
                )

            pipeline.set_outputs(inp, tar)
        return pipeline

    def __init__(self, params, location, train, seed=333):
        # set up seeds
        self.global_seed = seed  # same seed across all ranks
        model_id = comm.get_world_rank() // comm.get_size("tp-cp-pp")
        self.model_seed = self.global_seed + model_id  # model-wise seed
        self.local_seed = self.global_seed + comm.get_world_rank()  # unique per-rank

        self.num_data_workers = params.num_data_workers
        self.device_index = torch.cuda.current_device()
        self.batch_size = int(params.local_batch_size)

        self.location = location
        self.train = train
        self.dt = params.dt
        self.n_in_channels = params.n_in_channels
        self.n_out_channels = params.n_out_channels
        self.img_size = params.img_size
        self.limit_nsamples = (
            params.limit_nsamples if train else params.limit_nsamples_val
        )

        # load normalization stats
        self.normalize = True
        means = np.load(params.global_means_path)[0][: self.n_in_channels]
        stds = np.load(params.global_stds_path)[0][: self.n_in_channels]
        self.in_bias = means
        self.in_scale = stds
        means = np.load(params.global_means_path)[0][: self.n_out_channels]
        stds = np.load(params.global_stds_path)[0][: self.n_out_channels]
        self.out_bias = means
        self.out_scale = stds

        # set sharding for distributed, or local single process
        if dist.is_initialized():
            self.num_shards = params.data_num_shards
            self.shard_id = params.data_shard_id
        else:
            self.num_shards = 1
            self.shard_id = 0

        # compute number of batches for __len__
        extsource = esh.ERA5ES(
            self.location,
            self.train,
            self.batch_size,
            self.dt,
            self.img_size,
            self.n_in_channels,
            self.n_out_channels,
            self.num_shards,
            self.shard_id,
            self.limit_nsamples,
            seed=self.global_seed,
        )
        self.num_batches = extsource.num_steps_per_epoch
        del extsource

        # create DALI pipeline & launch py workers
        self.pipeline = self.get_pipeline()
        self.pipeline.start_py_workers()
        self.pipeline.build()

        # Create iterator over GPU-preprocessed batches
        self.iterator = DALIGenericIterator(
            [self.pipeline],
            ["inp", "tar"],
            auto_reset=True,
            last_batch_policy=LastBatchPolicy.DROP,
            prepare_first_batch=True,
        )

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        for token in self.iterator:
            inp = token[0]["inp"]
            tar = token[0]["tar"]
            yield inp, tar
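The sharding step in the diagram above (each GPU/process gets its own portion of the data) is delegated to `esh.ERA5ES` through `num_shards` and `shard_id`, whose internals are not shown here. As a hedged illustration only, this sketch shows one common way to split a sample index range into disjoint, nearly equal contiguous shards. `shard_indices` is a hypothetical helper, not the ERA5ES implementation.

```python
def shard_indices(num_samples: int, num_shards: int, shard_id: int) -> range:
    """Return the contiguous block of sample indices owned by `shard_id`.

    Hypothetical helper for illustration: the first `num_samples % num_shards`
    shards receive one extra sample, so shard sizes differ by at most one.
    """
    if not 0 <= shard_id < num_shards:
        raise ValueError("shard_id must be in [0, num_shards)")
    base, extra = divmod(num_samples, num_shards)
    start = shard_id * base + min(shard_id, extra)
    stop = start + base + (1 if shard_id < extra else 0)
    return range(start, stop)


if __name__ == "__main__":
    # 10 samples over 4 ranks: sizes 3, 3, 2, 2 with no overlap and no gaps
    for rank in range(4):
        print(rank, list(shard_indices(10, 4, rank)))
```

Uneven shards mean some ranks could run one more step than others, and collectives such as all-reduce can then hang. That is one reason loaders often drop incomplete batches, as the DALI iterator above does with `LastBatchPolicy.DROP`.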

Mixed precision training

Before looking at the code for mixed precision training, let's briefly explain how BF16 Automatic Mixed Precision (AMP) works in PyTorch and other modern deep learning frameworks. BF16 (bfloat16, "brain floating point") is a 16-bit floating-point format with the same 8-bit exponent as FP32, and therefore the same dynamic range, but only 7 stored mantissa bits instead of 23, so it is less precise. Automatic Mixed Precision lets the framework run some operations in lower precision (BF16 or FP16) and others in full FP32 precision without manual casting. In PyTorch this is done with the autocast context manager (`torch.amp.autocast('cuda', ...)` in the code below; `torch.cuda.amp.autocast` is the older, now-deprecated spelling), which, according to per-operation lists, either casts each operation to the requested dtype (`dtype=torch.bfloat16` for BF16) or keeps it in FP32.

BF16 is supported on recent NVIDIA (Ampere and newer), Intel, and AMD accelerators and often gives nearly the same accuracy as FP32 for deep learning, with significantly higher throughput and lower memory use. Inside the autocast context, matrix multiplications and convolutions run in BF16 for performance, while numerically sensitive operations such as softmax and sums/reductions remain in FP32 for robustness.
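The range-versus-precision trade-off described above can be checked directly with `torch.finfo`. This short sketch prints the limits of the three formats and shows that a value just above FP16's maximum (65504) overflows in FP16 but stays finite in BF16.

```python
import torch

# Compare the numeric limits of the three formats discussed above.
for dtype in (torch.float32, torch.bfloat16, torch.float16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15s} max={info.max:.4g}  eps={info.eps:.3g}")

# BF16 keeps FP32's 8 exponent bits, so a value beyond FP16's max stays finite,
# but with only 7 mantissa bits it is rounded coarsely (eps = 2**-7).
x = torch.tensor(70000.0)
print(x.to(torch.bfloat16))  # finite, rounded to the nearest representable BF16 value
print(x.to(torch.float16))   # overflows to inf
```

The larger BF16 `eps` is the price paid for FP32-like range; FP16 has finer precision but a much narrower range, which is why FP16 training usually needs loss scaling.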

# Hierarchy of AMP in Deep Learning Workflows (BF16 / FP16 Mixed Precision)
#
#      +---------------------------------------------------------+
#      |               Deep Neural Networks (Model)              |
#      +---------------------------------------------------------+
#                               |
#      +---------------------------------------------------------+
#      |                 DL Frameworks (e.g., PyTorch)           |
#      |---------------------------------------------------------|
#      |   - Provides AMP support (autocast/GradScaler)          |
#      |   - Exposes APIs to enable mixed precision training     |
#      +---------------------------------------------------------+
#                               |
#      +---------------------------------------------------------+
#      |     Automatic Mixed Precision (AMP) - Hardware Assisted |
#      |---------------------------------------------------------|
#      |   - Integrates with hardware features                   |
#      |   - Controls casting (BF16/FP16 ↔ FP32) on the fly      |
#      |   - Ensures safe ops remain in FP32 automatically       |
#      +---------------------------------------------------------+
#                               |
#      +---------------------------------------------------------+
#      |                Hardware (Tensor Cores, CPUs/GPUs)       |
#      |---------------------------------------------------------|
#      |   - BF16/FP16 Tensor Core acceleration                  |
#      |   - High-throughput, hardware-level support             |
#      +---------------------------------------------------------+
#
# Data travels down through these layers, and AMP orchestrates the use of
# specialized hardware to accelerate mixed-precision compute.
#
# In code, AMP typically looks like:
#
#   with autocast(device_type, enabled=True, dtype=torch.bfloat16):
#       y = model(x)    # matmul/conv ops run on Tensor Cores as BF16/FP16, safe ops stay FP32
#       loss = loss_fn(y, target)
#
# Gradient scaling (required for FP16, generally unnecessary for BF16) and keeping
# the weights in FP32 are handled by the framework.
#
#     Model/DL framework
#           |
#       [autocast]
#           |
#    (BF16/FP16 ops scheduled for tensor cores by AMP)
#           |
#        Hardware execution
#        (Tensor Cores/accelerators)
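To make the hierarchy concrete, the following sketch runs a single BF16 autocast training step. It uses `device_type="cpu"` (an assumption, so that it runs without a GPU; the training code here uses `'cuda'`) and a tiny `nn.Linear` model as a stand-in. It shows that autocast changes the dtype of the op outputs (the linear layer produces BF16), while the parameters and their gradients stay in FP32. No `GradScaler` is used: loss scaling matters for FP16, whose narrow range can underflow small gradients, but is generally unnecessary for BF16.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(16, 4)                      # parameters are created in FP32
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, target = torch.randn(8, 16), torch.randn(8, 4)

# CPU autocast so the sketch runs anywhere; on A100 nodes use device_type="cuda".
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)                              # the linear (matmul) op runs in BF16

loss = F.mse_loss(y.float(), target)          # compute the loss explicitly in FP32
opt.zero_grad()
loss.backward()
opt.step()

print(y.dtype)                   # torch.bfloat16
print(model.weight.dtype)        # torch.float32: weights stay FP32
print(model.weight.grad.dtype)   # torch.float32: grads match the parameter dtype
```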

with torch.no_grad():
    inp, tar = map(lambda x: x.to(device), next(iter(train_data_loader)))
    with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
        gen = model(inp)
    tr_loss = loss_func(gen, tar)
    inp, tar = map(lambda x: x.to(device), next(iter(val_data_loader)))
    with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
        gen = model(inp)
    val_loss = loss_func(gen, tar)
    val_rmse = weighted_rmse(gen, tar)
    if params.distributed:
        torch.distributed.all_reduce(
            tr_loss, op=ReduceOp.AVG, group=comm.get_group("dp")
        )
        torch.distributed.all_reduce(
            val_loss, op=ReduceOp.AVG, group=comm.get_group("dp")
        )
        torch.distributed.all_reduce(
            val_rmse, op=ReduceOp.AVG, group=comm.get_group("dp")
        )
    if world_rank == 0:
        args.tboard_writer.add_scalar("Loss/train", tr_loss.item(), 0)
        args.tboard_writer.add_scalar("Loss/valid", val_loss.item(), 0)
        args.tboard_writer.add_scalar(
            "RMSE(u10m)/valid", val_rmse.cpu().numpy()[0], 0
        )

...

    torch.cuda.nvtx.range_push(f"forward")
    with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
        gen = model(inp)
        loss = loss_func(gen, tar)
    torch.cuda.nvtx.range_pop()  # forward
...

with torch.inference_mode():
    with torch.no_grad():
        for i, data in enumerate(val_data_loader, 0):
            with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
                inp, tar = map(lambda x: x.to(device), data)
                gen = model(inp)
                val_loss += loss_func(gen, tar)
                val_rmse += weighted_rmse(gen, tar)
            valid_steps += 1

        if params.distributed:
            torch.distributed.all_reduce(
                val_loss, op=ReduceOp.AVG, group=comm.get_group("dp")
            )
            torch.distributed.all_reduce(
                val_rmse, op=ReduceOp.AVG, group=comm.get_group("dp")
            )
2025-11-17 21:44:35,117 - root - INFO - Starting Training Loop...
/dli/task/train.py:176: UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.
Consider using tensor.detach() first. (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/native/Scalar.cpp:22.)
  tr_loss.append(loss.item() / world_size)
2025-11-17 21:44:41,821 - root - INFO - Time taken for epoch 1 is 5.921940 sec, avg 83.756342 samples/sec
2025-11-17 21:44:41,821 - root - INFO -   Avg train loss=0.586307
2025-11-17 21:44:43,182 - root - INFO -   Avg val loss=0.431182861328125
2025-11-17 21:44:43,183 - root - INFO -   Total validation time: 0.9876902103424072 sec
2025-11-17 21:44:48,198 - root - INFO - Time taken for epoch 2 is 5.013919 sec, avg 102.115723 samples/sec
2025-11-17 21:44:48,198 - root - INFO -   Avg train loss=0.399898
2025-11-17 21:44:49,094 - root - INFO -   Avg val loss=0.37918391823768616
2025-11-17 21:44:49,095 - root - INFO -   Total validation time: 0.5196459293365479 sec
2025-11-17 21:44:54,139 - root - INFO - Time taken for epoch 3 is 5.043273 sec, avg 101.521369 samples/sec
2025-11-17 21:44:54,139 - root - INFO -   Avg train loss=0.363822
2025-11-17 21:44:54,990 - root - INFO -   Avg val loss=0.3559249937534332
2025-11-17 21:44:54,990 - root - INFO -   Total validation time: 0.4742579460144043 sec
Capture range started in the application.
Capture range ended in the application.
Generating '/tmp/nsys-report-cb82.qdstrm'
[1/1] [0%                          ] baseline-dli_dw8_dali_bf16.nsys-repProcessing events...
[1/1] [========================100%] baseline-dli_dw8_dali_bf16.nsys-rep
Generated:
        /dli/task/logs/baseline-dli_dw8_dali_bf16.nsys-rep
2025-11-17 21:45:07,261 - root - INFO - Time taken for epoch 4 is 12.269330 sec, avg 41.730072 samples/sec
2025-11-17 21:45:07,261 - root - INFO -   Avg train loss=0.351357
2025-11-17 21:45:08,138 - root - INFO -   Avg val loss=0.35201317071914673
2025-11-17 21:45:08,138 - root - INFO -   Total validation time: 0.49225330352783203 sec
2025-11-17 21:45:10,384 - root - INFO - DONE ---- rank 0

DDP and FSDP

DistributedDataParallel (DDP) and FullyShardedDataParallel (FSDP) are two PyTorch strategies for training a model across multiple GPUs or nodes. DDP keeps a full copy of the model on every GPU, gives each GPU a different slice of the global batch, and after the backward pass averages the gradients across GPUs with efficient all-reduce operations, so all replicas apply the same update. It is usually the default choice for multi-GPU training because it is simple, fast, and well optimized, but every GPU must have enough memory to hold the whole model together with its gradients and optimizer state.

FSDP is designed for very large models that do not fit in a single GPU's memory. Instead of replicating the whole model on each device, it shards (splits) parameters, gradients, and optimizer states across GPUs so that each GPU stores only a slice of them, and it temporarily gathers the full parameters of each wrapped module only when they are needed for the forward and backward passes. This sharding greatly reduces per-GPU memory use and makes much larger models trainable, at the cost of more complex communication patterns and typically higher communication overhead than DDP.
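To put rough numbers on the memory argument, here is a back-of-the-envelope sketch. It assumes plain FP32 training with Adam (4 bytes per parameter for weights, 4 for gradients, 8 for Adam's two moment buffers) and ignores activations, communication buffers, and the temporarily gathered parameters FSDP materializes. The 7B-parameter model size is purely hypothetical.

```python
def per_gpu_state_gib(num_params: int, num_gpus: int, sharded: bool) -> float:
    """Rough per-GPU memory (GiB) for parameters, gradients and Adam state.

    Assumes FP32 everything: 4 B params + 4 B grads + 8 B Adam moments (m, v)
    = 16 B per parameter. Activations, communication buffers and the
    temporarily all-gathered modules that FSDP materializes are ignored.
    """
    bytes_per_param = 4 + 4 + 8
    total = num_params * bytes_per_param
    if sharded:  # full sharding: each rank keeps roughly 1/N of the state
        total /= num_gpus
    return total / 2**30


if __name__ == "__main__":
    n = 7_000_000_000  # hypothetical 7B-parameter model
    for gpus in (4, 64):
        ddp = per_gpu_state_gib(n, gpus, sharded=False)
        fsdp = per_gpu_state_gib(n, gpus, sharded=True)
        print(f"{gpus:3d} GPUs: DDP {ddp:7.1f} GiB/GPU, FSDP {fsdp:6.1f} GiB/GPU")
```

Under these assumptions DDP needs about 104 GiB per GPU regardless of GPU count, more than a single A100 (40 or 80 GB) holds, while the fully sharded state shrinks roughly as 1/N (about 26 GiB on 4 GPUs, under 2 GiB on 64).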

import sys
import os
import time
import numpy as np
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler
import torch.multiprocessing
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel

import logging
from utils import logging_utils

logging_utils.config_logger()
from utils.YParams import YParams
from utils import get_data_loader_distributed
from utils.loss import l2_loss, l2_loss_opt
from utils.metrics import weighted_rmse
from utils.plots import generate_images
from networks import vit


def train(params, args, local_rank, world_rank, world_size):
    """
    ==========================================================================
                                   train()
    ==========================================================================
    |  Overall Structure:
    |   
    |   +-----------------------------------------------------------+
    |   |        Setup, Data, Model, Optimizer, Scheduler           |
    |   +-----------------------------------------------------------+
    |   |    Initial Logging & Initial Loss/Eval (TensorBoard)      |
    |   +-----------------------------------------------------------+
    |   |    ========== Main Training Loop (epochs) ===========     |
    |   |    |                                               |      |
    |   |    |  1. Per-epoch setup (sampler, timers, mode)   |      |
    |   |    |  2. --------- Per-step (batch) -----------    |      |
    |   |    |  |   a. Data to device, timing/metrics |      |      |
    |   |    |  |   b. Forward, loss, backward, opt   |      |      |
    |   |    |  |   c. Distributed/allreduce          |      |      |
    |   |    |  |   d. Running stats, LR step         |      |      |
    |   |    |  +-------------------------------------+      |      |
    |   |    |  3. Per-epoch eval: validation loader         |      |
    |   |    +-----------------------------------------------+      |
    |   +-----------------------------------------------------------+
    |   |            Final timing and summary stats                 |
    |   +-----------------------------------------------------------+
    +---------------------------------------------------------------+
    """
    # -------------- (1) Device and CUDNN setup --------------------
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda:%d" % local_rank)

    # -------------- (2) Data Loaders                               |
    logging.info("rank %d, begin data loader init" % world_rank)
    train_data_loader, train_dataset, train_sampler = get_data_loader_distributed(
        params, params.train_data_path, params.distributed, train=True
    )
    val_data_loader, valid_dataset = get_data_loader_distributed(
        params, params.valid_data_path, params.distributed, train=False
    )
    logging.info("rank %d, data loader initialized" % (world_rank))

    # -------------- (3) Model Instantiation                       |
    model = vit.ViT(params).to(device)

    if params.enable_jit:
        # Optionally use torch.compile for JIT optimization
        model = torch.compile(model)

    # -------------- (4) AMP/Distributed Setup                    |
    if params.amp_dtype == torch.float16:
        scaler = GradScaler("cuda")
    if params.distributed and not args.noddp:
        """
        =========================================================
        DistributedDataParallel (DDP) Overview:

        DDP is PyTorch's mechanism for achieving data-parallel training
        across multiple GPUs, using NCCL for high-performance GPU-to-GPU
        communication. For each GPU/rank, you launch a separate process.
        Backward gradients are automatically all-reduced so each model
        replica stays in sync.

        Example:
            +------------------+
            | Process/Rank 0   |    GPU:0
            | +------------+   |    model parameters θ₀
            | |   Model    |---|-------+
            | +------------+   |       |
            +------------------+       |
                             |         v
            +------------------+     AllReduce: gradients
            | Process/Rank 1   |    across NCCL communicator
            | +------------+   |    (compare to halo exchange
            | |   Model    |---|--------- ring in Jacobi)
            | +------------+   |
            +------------------+

          - Each process has its own CUDA device and model replica
          - Forward/backward compute locally on own batch
          - After backward(), gradients are automatically averaged
            across all GPUs/ranks (using NCCL/AllReduce)
          - Optimizer step updates local model parameters

        Compare this to stencil/halo exchange with MPI/NCCL in
        /distributed-gpu, but DDP allreduces the *whole* gradient,
        not just boundaries.

        Buffer sharing: broadcast_buffers True/False
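          - Buffers = non-parameter state (e.g., BatchNorm running stats)
          - True (PyTorch default): rank 0's buffers are broadcast to all
            ranks at the start of every forward pass
          - False: buffers evolve independently on each process and the
            per-iteration broadcast is skipped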

        More: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
        =========================================================
        """

        # Optional buffer sharing config for DDP
        if args.disable_broadcast_buffers:
            model = DistributedDataParallel(
                model,
                device_ids=[local_rank],
                bucket_cap_mb=args.bucket_cap_mb,
                broadcast_buffers=False,
                gradient_as_bucket_view=True,
            )
        else:
            model = DistributedDataParallel(
                model, device_ids=[local_rank], bucket_cap_mb=args.bucket_cap_mb
            )

    # -------------- (5) Optimizer                                |
    if params.enable_fused:
        optimizer = optim.Adam(
            model.parameters(), lr=params.lr, fused=True, betas=(0.9, 0.95)
        )
    else:
        optimizer = optim.Adam(model.parameters(), lr=params.lr, betas=(0.9, 0.95))

    # -------------- (6) Log Model                               |
    if world_rank == 0:
        logging.info(model)

    # -------------- (7) Scheduler                               |
    iters = 0
    startEpoch = 0
    if params.lr_schedule == "cosine":
        if params.warmup > 0:
            lr_scale = lambda x: min(
                (x + 1) / params.warmup,
                0.5 * (1 + np.cos(np.pi * x / params.num_iters)),
            )
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scale)
        else:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=params.num_iters
            )
    else:
        scheduler = None

    # -------------- (8) Loss Selection                          |
    if params.enable_jit:
        loss_func = l2_loss_opt
    else:
        loss_func = l2_loss

    # -------------- (9) Initial Logging to TensorBoard          |
    if world_rank == 0:
        logging.info("Starting Training Loop...")

    # ==[ Initial Loss/Validation Logging ]==========================
    with torch.no_grad():
        inp, tar = map(lambda x: x.to(device), next(iter(train_data_loader)))
        gen = model(inp)
        tr_loss = loss_func(gen, tar)
        inp, tar = map(lambda x: x.to(device), next(iter(val_data_loader)))
        gen = model(inp)
        val_loss = loss_func(gen, tar)
        val_rmse = weighted_rmse(gen, tar)
        if params.distributed:
            torch.distributed.all_reduce(tr_loss)
            torch.distributed.all_reduce(val_loss)
            torch.distributed.all_reduce(val_rmse)
        if world_rank == 0:
            args.tboard_writer.add_scalar("Loss/train", tr_loss.item() / world_size, 0)
            args.tboard_writer.add_scalar("Loss/valid", val_loss.item() / world_size, 0)
            args.tboard_writer.add_scalar(
                "RMSE(u10m)/valid", val_rmse.cpu().numpy()[0] / world_size, 0
            )

    # ========== MAIN TRAINING LOOP (epochs) ==========
    params.num_epochs = params.num_iters // len(train_data_loader)
    iters = 0
    t1 = time.time()
    for epoch in range(startEpoch, startEpoch + params.num_epochs):
        # ---[A] Epoch: prep, sampler, timers ---
        torch.cuda.synchronize()  # Ensure correct epoch timing (sync GPU)
        if params.distributed and (train_sampler is not None):
            train_sampler.set_epoch(epoch)
        start = time.time()
        tr_loss = []
        tr_time = 0.0
        dat_time = 0.0
        log_time = 0.0

        model.train()
        step_count = 0

        # ----------- B. PER-BATCH TRAINING ------------
        for i, data in enumerate(train_data_loader, 0):
            # (Optional) Profiling on rank 0: start/stop the CUDA profiler (the
            # cudaProfilerApi capture range used by Nsight Systems) for epoch index 3,
            # i.e. the 4th epoch
            if world_rank == 0:
                if epoch == 3 and i == 0:
                    torch.cuda.profiler.start()
                if epoch == 3 and i == len(train_data_loader) - 1:
                    torch.cuda.profiler.stop()

            torch.cuda.nvtx.range_push(f"step {i}")
            iters += 1
            dat_start = time.time()
            torch.cuda.nvtx.range_push(f"data copy in {i}")

            # --- Copy batch to GPU ---
            inp, tar = map(lambda x: x.to(device), data)
            torch.cuda.nvtx.range_pop()  # data copy in

            tr_start = time.time()
            b_size = inp.size(0)

            # --- ZERO GRAD ---
            optimizer.zero_grad()

            # --- Forward w/ autocast (AMP) ---
            torch.cuda.nvtx.range_push(f"forward")
            with autocast("cuda", enabled=params.amp_enabled, dtype=params.amp_dtype):
                gen = model(inp)
                loss = loss_func(gen, tar)
            torch.cuda.nvtx.range_pop()  # forward

            # --- Backward, Step (AMP or standard) ---
            if params.amp_dtype == torch.float16:
                scaler.scale(loss).backward()
                torch.cuda.nvtx.range_push(f"optimizer")
                scaler.step(optimizer)
                torch.cuda.nvtx.range_pop()  # optimizer
                scaler.update()
            else:
                loss.backward()
                torch.cuda.nvtx.range_push(f"optimizer")
                optimizer.step()
                torch.cuda.nvtx.range_pop()  # optimizer

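            # --- Allreduce if distributed ---
            # Detach first: the logged value needs no autograd history, and calling
            # .item() on a tensor that requires grad can trigger a UserWarning.
            loss_val = loss.detach()
            if params.distributed:
                torch.distributed.all_reduce(loss_val)
            tr_loss.append(loss_val.item() / world_size)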

            torch.cuda.nvtx.range_pop()  # step
            # --- LR Scheduler Step (scheduler is None unless lr_schedule == "cosine") ---
            if scheduler is not None:
                scheduler.step()

            # --- Time/Stats ---
            tr_end = time.time()
            tr_time += tr_end - tr_start
            dat_time += tr_start - dat_start
            step_count += 1

        # =====[ End of Training steps for this epoch ]=====
        torch.cuda.synchronize()  # Accurate epoch timing
        end = time.time()

        # ----------- C. Epoch-End Reporting/Logging (Rank 0) -------
        if world_rank == 0:
            iters_per_sec = step_count / (end - start)
            samples_per_sec = params["global_batch_size"] * iters_per_sec
            logging.info(
                "Time taken for epoch %i is %f sec, avg %f samples/sec",
                epoch + 1,
                end - start,
                samples_per_sec,
            )
            logging.info("  Avg train loss=%f" % np.mean(tr_loss))
            args.tboard_writer.add_scalar("Loss/train", np.mean(tr_loss), iters)
            args.tboard_writer.add_scalar(
                "Learning Rate", optimizer.param_groups[0]["lr"], iters
            )
            args.tboard_writer.add_scalar("Avg iters per sec", iters_per_sec, iters)
            args.tboard_writer.add_scalar("Avg samples per sec", samples_per_sec, iters)
            fig = generate_images([inp, tar, gen])
            args.tboard_writer.add_figure("Visualization, t2m", fig, iters, close=True)

        # ========== D. Validation Evaluation ============
        val_start = time.time()
        val_loss = torch.zeros(1, device=device)
        val_rmse = torch.zeros(
            (params.n_out_channels), dtype=torch.float32, device=device
        )
        valid_steps = 0
        model.eval()

        with torch.inference_mode():
            with torch.no_grad():
                for i, data in enumerate(val_data_loader, 0):
                    with autocast(
                        "cuda", enabled=params.amp_enabled, dtype=params.amp_dtype
                    ):
                        inp, tar = map(lambda x: x.to(device), data)
                        gen = model(inp)
                        val_loss += loss_func(gen, tar)
                        val_rmse += weighted_rmse(gen, tar)
                    valid_steps += 1

                if params.distributed:
                    torch.distributed.all_reduce(val_loss)
                    val_loss /= world_size
                    torch.distributed.all_reduce(val_rmse)
                    val_rmse /= world_size

        val_rmse /= valid_steps  # Avg validation rmse
        val_loss /= valid_steps
        val_end = time.time()

        # ---------- (E) Validation logging (TensorBoard/Console) ----------
        if world_rank == 0:
            logging.info("  Avg val loss={}".format(val_loss.item()))
            logging.info("  Total validation time: {} sec".format(val_end - val_start))
            args.tboard_writer.add_scalar("Loss/valid", val_loss, iters)
            args.tboard_writer.add_scalar(
                "RMSE(u10m)/valid", val_rmse.cpu().numpy()[0], iters
            )
            args.tboard_writer.flush()

    # ===== Final Timing Summary =====
    t2 = time.time()
    tottime = t2 - t1

    # ==========================================================================
    #    End of train()
    # ==========================================================================


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--run_num",
        default="00",
        type=str,
        help="tag for indexing the current experiment",
    )
    parser.add_argument(
        "--yaml_config",
        default="./config/ViT.yaml",
        type=str,
        help="path to yaml file containing training configs",
    )
    parser.add_argument(
        "--config", default="base", type=str, help="name of desired config in yaml file"
    )
    parser.add_argument(
        "--amp_mode",
        default="none",
        type=str,
        choices=["none", "fp16", "bf16"],
        help="select automatic mixed precision mode",
    )
    parser.add_argument(
        "--enable_fused", action="store_true", help="enable fused Adam optimizer"
    )
    parser.add_argument(
        "--enable_jit", action="store_true", help="enable JIT compilation"
    )
    parser.add_argument(
        "--local_batch_size",
        default=None,
        type=int,
        help="local batchsize (manually override global_batch_size config setting)",
    )
    parser.add_argument(
        "--num_iters", default=None, type=int, help="number of iters to run"
    )
    parser.add_argument(
        "--num_data_workers",
        default=None,
        type=int,
        help="number of data workers for data loader",
    )
    parser.add_argument(
        "--data_loader_config",
        default=None,
        type=str,
        choices=["pytorch", "dali"],
        help="dataloader configuration. choices: 'pytorch', 'dali'",
    )
    parser.add_argument(
        "--bucket_cap_mb", default=25, type=int, help="max message bucket size in mb"
    )
    parser.add_argument(
        "--disable_broadcast_buffers",
        action="store_true",
        help="disable syncing broadcasting buffers",
    )
    parser.add_argument(
        "--noddp", action="store_true", help="disable DDP communication"
    )
    args = parser.parse_args()

    run_num = args.run_num

    params = YParams(os.path.abspath(args.yaml_config), args.config)

    # Update config with modified args
    # set up amp
    if args.amp_mode != "none":
        params.update({"amp_mode": args.amp_mode})
    amp_dtype = torch.float32

    if params.amp_mode == "fp16":
        amp_dtype = torch.float16
    elif params.amp_mode == "bf16":
        amp_dtype = torch.bfloat16

    params.update(
        {"amp_enabled": amp_dtype is not torch.float32, "amp_dtype": amp_dtype}
    )

    if args.enable_fused:
        params.update({"enable_fused": args.enable_fused})

    if args.enable_jit:
        params.update({"enable_jit": args.enable_jit})

    if args.data_loader_config:
        params.update({"data_loader_config": args.data_loader_config})

    if args.num_iters:
        params.update({"num_iters": args.num_iters})

    if args.num_data_workers:
        params.update({"num_data_workers": args.num_data_workers})

    params.distributed = False
    if "WORLD_SIZE" in os.environ:
        params.distributed = int(os.environ["WORLD_SIZE"]) > 1
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        world_size = 1

    world_rank = 0
    local_rank = 0
    if params.distributed:
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        world_rank = torch.distributed.get_rank()
        local_rank = int(os.environ["LOCAL_RANK"])

    if args.local_batch_size:
        # Manually override batch size
        params.local_batch_size = args.local_batch_size
        params.update({"global_batch_size": world_size * args.local_batch_size})
    else:
        # Compute local batch size based on number of ranks
        params.local_batch_size = params.global_batch_size // world_size

    # for dali data loader, set the actual number of data shards and id
    params.data_num_shards = world_size
    params.data_shard_id = world_rank

    # Set up directory
    baseDir = params.expdir
    expDir = os.path.join(
        baseDir, args.config + "/%dGPU/" % (world_size) + str(run_num) + "/"
    )
    if world_rank == 0:
        if not os.path.isdir(expDir):
            os.makedirs(expDir)
        logging_utils.log_to_file(
            logger_name=None, log_filename=os.path.join(expDir, "out.log")
        )
        params.log()
        args.tboard_writer = SummaryWriter(log_dir=os.path.join(expDir, "logs/"))

    params.experiment_dir = os.path.abspath(expDir)

    train(params, args, local_rank, world_rank, world_size)

    if params.distributed:
        torch.distributed.barrier()
    logging.info("DONE ---- rank %d" % world_rank)
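The claim in the DDP docstring above, that averaging gradients keeps the replicas equivalent to single-device training, can be checked numerically without GPUs. This NumPy sketch simulates four ranks using a linear least-squares model (an illustrative stand-in for the ViT). With a mean-reduced loss and equal-sized local batches, the all-reduced average of the per-rank gradients equals the gradient of the full global batch, so every replica applies the same optimizer step. Equal local batch sizes matter here; the script enforces them via `local_batch_size = global_batch_size // world_size`.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 5))   # one global batch of 32 samples
y = rng.normal(size=32)
w = rng.normal(size=5)         # identical replica weights on every simulated rank


def mse_grad(Xb, yb, w):
    """Gradient of mean((Xb @ w - yb) ** 2) with respect to w."""
    return 2.0 * Xb.T @ (Xb @ w - yb) / len(yb)


world_size = 4
# Each simulated rank computes a gradient on its own equal-sized slice of the batch...
local_grads = [
    mse_grad(Xs, ys, w)
    for Xs, ys in zip(np.split(X, world_size), np.split(y, world_size))
]
# ...then an all-reduce (sum) followed by division by world_size averages them.
ddp_grad = np.sum(local_grads, axis=0) / world_size
full_grad = mse_grad(X, y, w)
print(np.allclose(ddp_grad, full_grad))  # True
```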

Model parallelism

Model parallelism is a complementary strategy where a single model is split across multiple devices, instead of each device holding a full copy as in DDP or FSDP. In tensor (intra-layer) model parallelism, individual layers are partitioned—e.g., splitting a large linear layer’s weight matrix by columns or rows—so that each GPU computes part of the layer’s output and then the partial results are combined. In pipeline model parallelism, different groups of layers are placed on different devices (like stages in a pipeline), and microbatches are streamed through these stages so multiple parts of the model can run concurrently. Model parallelism is especially useful when a single layer or the whole model is too large to fit on one GPU, but it usually involves more complex orchestration and communication patterns than pure data parallel approaches, and in practice large-scale systems often combine data parallelism (DDP/FSDP) with model parallelism.
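The column and row splits of a linear layer mentioned above can be illustrated with NumPy on a single machine, where each "device" is just a shard of the weight matrix (a simulation, not a multi-GPU implementation). A column split produces output slices that are concatenated (an all-gather across devices). A row split produces partial sums that must be added (an all-reduce). Frameworks such as Megatron-LM pair a column-parallel layer with a row-parallel one so that only the final sum needs communication.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 8))    # activations: batch of 3, hidden size 8
W = rng.normal(size=(8, 6))    # full weight of a linear layer, y = x @ W
y_ref = x @ W
tp = 2                         # simulated tensor-parallel degree

# Column parallel: each device holds a slice of W's output columns and computes
# a slice of y; concatenating the slices (an all-gather) restores the full y.
col_shards = np.split(W, tp, axis=1)
y_col = np.concatenate([x @ Wc for Wc in col_shards], axis=1)

# Row parallel: each device holds a slice of W's input rows plus the matching
# slice of x's features; summing the partial products (an all-reduce) gives y.
row_shards = np.split(W, tp, axis=0)
x_shards = np.split(x, tp, axis=1)
y_row = sum(xs @ Wr for xs, Wr in zip(x_shards, row_shards))

print(np.allclose(y_col, y_ref), np.allclose(y_row, y_ref))  # True True
```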

import sys
import os
import time
import numpy as np
import argparse
import pynvml

import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler
import torch.multiprocessing
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import ReduceOp

import logging
from utils import logging_utils
import warnings
warnings.simplefilter("ignore", FutureWarning)

logging_utils.config_logger()
from utils.YParams import YParams
from utils import get_data_loader_distributed
from utils import comm
from utils.loss import l2_loss, l2_loss_opt
from utils.metrics import weighted_rmse
from networks import vit
from networks import vit_te

from distributed.mappings import init_ddp_model_and_reduction_hooks
from distributed.helpers import init_params_for_shared_weights

from utils.plots import generate_images


def train(params, args, local_rank, world_rank, world_size):
    # ----------------------------------------------------------------------- #
    # (1) Device & Library Setup
    # ----------------------------------------------------------------------- #
    # Let cuDNN autotune convolution algorithms for the fixed input shapes
    # (faster, but non-deterministic) and bind this process to its local GPU
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

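    # Initialize NVML to track GPU memory usage. NVML ignores CUDA_VISIBLE_DEVICES,
    # so this assumes local_rank is also the physical GPU index on the node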
    pynvml.nvmlInit()
    nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(device.index)

    # ----------------------------------------------------------------------- #
    # (2) DataLoader Initialization
    # ----------------------------------------------------------------------- #
    # Get distributed train/validation data loaders and samplers
    logging.info("rank %d, begin data loader init" % world_rank)
    train_data_loader, train_dataset, train_sampler = get_data_loader_distributed(
        params, params.train_data_path, params.distributed, train=True
    )
    val_data_loader, valid_dataset = get_data_loader_distributed(
        params, params.valid_data_path, params.distributed, train=False
    )
    logging.info("rank %d, data loader initialized" % world_rank)

    # ----------------------------------------------------------------------- #
    # (3) Model Initialization & JIT Compilation
    # ----------------------------------------------------------------------- #
    # Choose Transformer Engine or native ViT backend
    if params.model_backend == 'transformer-engine':
        logging.info("using transformer-engine backend")
        model = vit_te.ViT(params).to(device)
    else:
        logging.info("using native backend")
        model = vit.ViT(params).to(device)

    # JIT compilation (optional, for acceleration)
    if params.enable_jit:
        # NOTE: DDP interoperability with torch.compile may need special handling
        model = torch.compile(model)

    # AMP: a GradScaler is only needed for fp16; bf16 keeps fp32's exponent range
    if params.amp_dtype == torch.float16:
        scaler = GradScaler('cuda')

    # ----------------------------------------------------------------------- #
    # (4) Distributed/TP Shared Weight Synchronization
    # ----------------------------------------------------------------------- #
    # With tensor/context parallelism (tp-cp group size > 1), make replicated
    # (non-sharded) parameters identical on every rank of the tp-cp group

    """
    MODEL PARALLELISM: TENSOR & PIPELINE (PyTorch/pynvml, Exascale Style)
    -----------------------------------------------------------------------

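    [How tensor and pipeline parallelism work in PyTorch, and how parameter
     initialization/synchronization mirrors the exascale Jacobi/MPI code.
     This script itself combines tensor and context parallelism (no pipeline
     stages); pipeline parallelism is sketched for comparison. Device/memory
     setup and monitoring via torch.cuda and pynvml: see (3).]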

    =====================================================================
    (1) TENSOR PARALLELISM (PyTorch distributed, weight sharding, sync)
    =====================================================================

    Tensor parallelism shards large weights across the TP group; parameters
    that are NOT sharded (e.g. LayerNorm weights) are replicated and must start
    identical on every rank, like MPI_Bcast in Jacobi (see @2025-11-16):

        W = [ W_0 | W_1 | W_2 | W_3 ]        (e.g. split by columns)

        +----------+     +----------+     +----------+     +----------+
        |  GPU 0   |     |  GPU 1   |     |  GPU 2   |     |  GPU 3   |
        | (rank 0) |     | (rank 1) |     | (rank 2) |     | (rank 3) |
        |   W_0    |     |   W_1    |     |   W_2    |     |   W_3    |
        +----+-----+     +----+-----+     +----+-----+     +----+-----+
             |                |                |                |
             +----------------+-------+--------+----------------+
                                      |
            forward/backward: each GPU multiplies by its local shard; partial
            results are combined with all-gather (column split) or
            all-reduce (row split)

        replicated params: rank 0 --torch.distributed.broadcast--> other TP ranks

    (compare: Jacobi halo rows exchanged with MPI_Sendrecv() and parameters broadcast with MPI_Bcast)

    =====================================================================
    (2) PIPELINE PARALLELISM (PyTorch: staged blocks, microbatch pipeline)
    =====================================================================

        Input Batch
            |
        +----------+    +----------+    +----------+    +----------+
        |  GPU 0   |--->|  GPU 1   |--->|  GPU 2   |--->|  GPU 3   |
        | (Stage 0)|    | (Stage 1)|    | (Stage 2)|    | (Stage 3)|
        +----------+    +----------+    +----------+    +----------+
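             |               |               |               |
        [Encoder 0]     [Encoder 1]     [Encoder 2]     [Encoder 3]

        The batch is split into microbatches that flow GPU 0 -> 1 -> 2 -> 3
        (activations are sent forward stage to stage, gradients backward), so
        once the pipeline fills, all stages work on different microbatches at
        the same time. Each stage owns different layers, so no parameter
        broadcast is needed between stages; the stage-to-stage hand-off plays
        the role of the Jacobi halo exchange (MPI_Sendrecv).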

    =====================================================================
    (3) DEVICE & MEMORY MANAGEMENT (PyTorch, pynvml, MPI Analogy)
    =====================================================================

        +-----------------------------------------------------------+
        |                   Python Process (per rank)               |
        +-----------------------------------------------------------+
        | torch.cuda.set_device(local_rank)                         |
        | device = torch.device(f"cuda:{local_rank}")               |
        | model.to(device)                                          |
        | pynvml.nvmlInit()                     (get handle for GPU)|
        | nvmlDeviceGetMemoryInfo(nvml_handle)     (GB used)        |
        +-----------------------------------------------------------+
                          |                                     
    [like Jacobi/MPI: cudaSetDevice, deviceCount, rank <-> device mapping]

    =====================================================================
    (4) SYNC & COLLECTIVES: PyTorch vs Jacobi MPI
    =====================================================================

        +------------------------------+       +-------------------------------+
        |        PyTorch World         |       |    Exascale Jacobi (C/MPI)    |
        +------------------------------+       +-------------------------------+
        | torch.distributed.broadcast  |  <->  | MPI_Bcast (e.g. ncclUniqueId) |
        | torch.distributed.all_reduce |  <->  | MPI_Allreduce (convergence)   |
        | torch.distributed.barrier    |  <->  | MPI_Barrier                   |
        +------------------------------+       +-------------------------------+

    KEY FLOW:
      - PyTorch uses torch.distributed backends (NCCL/MPI) to implement
        the same process-group collectives as Jacobi C/MPI: sync weights,
        share state, coordinate progress.
      - init_params_for_shared_weights(model) synchronizes the replicated
        (non-sharded) parameters across the tp-cp group, so all ranks start equal.
      - pynvml, torch.cuda lifecycles mirror cuda runtime and info calls
        in Jacobi code (“assign rank to device”, “log mem stats”).

    =====================================================================
    (5) KEY MAPPING
    =====================================================================

        PyTorch init_params_for_shared_weights(model):
            |
            +---> torch.distributed.broadcast(...) for all shared (replicated,
                  non-sharded) params, often per process-group (tp, cp)
        MPI Jacobi:
            |
            +---> MPI_Bcast()/Sendrecv() for domain boundaries, ids

    See /distributed-gpu for the equivalent device set,
    rank assignment, MPI_Bcast/ncclUniqueId & ncclCommInitRank, mirrored here in torch:
        - torch.cuda.set_device(local_rank) <-> cudaSetDevice(rank % num_devices)
        - torch.distributed actions         <-> MPI_SENDRECV/ALLREDUCE/BCAST
        - pynvml.nvmlDeviceGet...           <-> cudaMemGetInfo/MPI logging

    """

    if comm.get_size("tp-cp") > 1:
        init_params_for_shared_weights(model)

    # Wrap model with DDP if required
    if params.distributed and not args.noddp:
        model = init_ddp_model_and_reduction_hooks(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            bucket_cap_mb=args.bucket_cap_mb
        )

    # ----------------------------------------------------------------------- #
    # (5) Optimizer Setup
    # ----------------------------------------------------------------------- #
    # Use fused Adam optimizer if enabled
    if params.enable_fused:
        optimizer = optim.Adam(
            model.parameters(), lr=params.lr, fused=True, betas=(0.9, 0.95)
        )
    else:
        optimizer = optim.Adam(model.parameters(), lr=params.lr, betas=(0.9, 0.95))

    # ----------------------------------------------------------------------- #
    # (6) Model Logging & Memory Monitoring (rank 0 master only)
    # ----------------------------------------------------------------------- #
    if world_rank == 0:
        logging.info(model)
        all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).used / (1024.0**3)
        logging.info(f"Scaffolding memory usage after setup (NVML): {all_mem_gb} GB.")

    iters = 0
    startEpoch = 0

    # ----------------------------------------------------------------------- #
    # (7) Learning Rate Scheduler Configuration
    # ----------------------------------------------------------------------- #
    if params.lr_schedule == "cosine":
        if params.warmup > 0:
            # Linear warmup, capped by a cosine decay that starts at step 0
            lr_scale = lambda x: min(
                (x + 1) / params.warmup,
                0.5 * (1 + np.cos(np.pi * x / params.num_iters)),
            )
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scale)
        else:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=params.num_iters
            )
    else:
        scheduler = None

    # ----------------------------------------------------------------------- #
    # (8) Loss Function Selection
    # ----------------------------------------------------------------------- #
    # Use JIT-optimized or plain loss depending on setting
    if params.enable_jit:
        loss_func = l2_loss_opt
    else:
        loss_func = l2_loss

    # ----------------------------------------------------------------------- #
    # (9) Initial Logging and Baseline Metrics
    # ----------------------------------------------------------------------- #
    if world_rank == 0:
        logging.info("Starting Training Loop...")

    # Compute & log initial train/val loss and RMSE before epochs begin
    with torch.no_grad():
        # Take first batch from training and validation for baseline
        inp, tar = map(lambda x: x.to(device), next(iter(train_data_loader)))
        with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
            gen = model(inp)
        tr_loss = loss_func(gen, tar)
        inp, tar = map(lambda x: x.to(device), next(iter(val_data_loader)))
        with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
            gen = model(inp)
        val_loss = loss_func(gen, tar)
        val_rmse = weighted_rmse(gen, tar)
        # Reduce baseline metrics across data-parallel group if needed
        if params.distributed:
            torch.distributed.all_reduce(
                tr_loss, op=ReduceOp.AVG, group=comm.get_group("dp")
            )
            torch.distributed.all_reduce(
                val_loss, op=ReduceOp.AVG, group=comm.get_group("dp")
            )
            torch.distributed.all_reduce(
                val_rmse, op=ReduceOp.AVG, group=comm.get_group("dp")
            )
        if world_rank == 0:
            args.tboard_writer.add_scalar("Loss/train", tr_loss.item(), 0)
            args.tboard_writer.add_scalar("Loss/valid", val_loss.item(), 0)
            args.tboard_writer.add_scalar(
                "RMSE(u10m)/valid", val_rmse.cpu().numpy()[0], 0
            )

    # ----------------------------------------------------------------------- #
    # (10) Main Training/Validation Loop
    # ----------------------------------------------------------------------- #
    params.num_epochs = params.num_iters // len(train_data_loader)
    iters = 0
    t1 = time.time()
    for epoch in range(startEpoch, startEpoch + params.num_epochs):
        torch.cuda.synchronize()  # ensure accurate epoch timings
        # Reseed the sampler each epoch: a new shuffle per epoch, consistent across ranks
        if params.distributed and (train_sampler is not None):
            train_sampler.set_epoch(epoch)
        start = time.time()
        tr_loss = []
        tr_time = 0.0
        dat_time = 0.0
        log_time = 0.0

        model.train()
        step_count = 0

        # (10.1) Training steps per batch
        for i, data in enumerate(train_data_loader, 0):
            # (a) Optionally enable profiling for diagnostics on rank 0
            if world_rank == 0:
                if epoch == 3 and i == 0:
                    torch.cuda.profiler.start()
                if epoch == 3 and i == len(train_data_loader) - 1:
                    torch.cuda.profiler.stop()

            torch.cuda.nvtx.range_push(f"step {i}")
            iters += 1
            dat_start = time.time()
            torch.cuda.nvtx.range_push(f"data copy in {i}")

            # (b) Copy data to device/GPU
            inp, tar = map(lambda x: x.to(device), data)
            torch.cuda.nvtx.range_pop()  # end data copy timing

            tr_start = time.time()
            b_size = inp.size(0)
            optimizer.zero_grad()

            torch.cuda.nvtx.range_push("forward")
            # (c) Forward pass and loss under autocast for AMP
            with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
                gen = model(inp)
                loss = loss_func(gen, tar)
            torch.cuda.nvtx.range_pop()  # end forward timing

            # Log memory usage once per epoch, after the forward pass of the second step (i == 1)
            if world_rank == 0 and i == 1:
                all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).used / (1024.0**3)
                logging.info(f" Memory usage after forward pass: {all_mem_gb} GB.")

            # (d) Backward & optimizer step (support autocast/scaler if fp16 AMP)
            if params.amp_dtype == torch.float16:
                scaler.scale(loss).backward()
                torch.cuda.nvtx.range_push("optimizer")
                scaler.step(optimizer)
                torch.cuda.nvtx.range_pop()  # optimizer
                scaler.update()
            else:
                loss.backward()
                torch.cuda.nvtx.range_push("optimizer")
                optimizer.step()
                torch.cuda.nvtx.range_pop()  # optimizer

            # (e) Average the loss over data-parallel ranks for logging only
            # (gradient averaging is done by DDP during backward)
            if params.distributed:
                torch.distributed.all_reduce(
                    loss, op=ReduceOp.AVG, group=comm.get_group("dp")
                )
            tr_loss.append(loss.item())

            torch.cuda.nvtx.range_pop()  # end step

            # (f) Learning rate schedule step (scheduler is None unless lr_schedule == "cosine")
            if scheduler is not None:
                scheduler.step()

            # (g) Timing and diagnostics
            tr_end = time.time()
            tr_time += tr_end - tr_start
            dat_time += tr_start - dat_start
            step_count += 1

        torch.cuda.synchronize()  # epoch synchronization for timing
        end = time.time()

        # (10.2) Logging and visualization (rank 0 only)
        if world_rank == 0:
            iters_per_sec = step_count / (end - start)
            samples_per_sec = params["global_batch_size"] * iters_per_sec
            logging.info(
                "Time taken for epoch %i is %f sec, avg %f samples/sec",
                epoch + 1,
                end - start,
                samples_per_sec,
            )
            logging.info("  Avg train loss=%f" % np.mean(tr_loss))
            args.tboard_writer.add_scalar("Loss/train", np.mean(tr_loss), iters)
            args.tboard_writer.add_scalar(
                "Learning Rate", optimizer.param_groups[0]["lr"], iters
            )
            args.tboard_writer.add_scalar("Avg iters per sec", iters_per_sec, iters)
            args.tboard_writer.add_scalar("Avg samples per sec", samples_per_sec, iters)
            # Visualize a sample result
            fig = generate_images([inp, tar, gen])
            args.tboard_writer.add_figure("Visualization, t2m", fig, iters, close=True)

        # ------------------------------------------------------------------- #
        # (10.3) Validation Evaluation Phase
        # ------------------------------------------------------------------- #
        val_start = time.time()
        val_loss = torch.zeros(1, device=device)
        val_rmse = torch.zeros(
            (params.n_out_channels), dtype=torch.float32, device=device
        )
        valid_steps = 0
        model.eval()

        with torch.inference_mode():  # inference_mode already disables autograd
            for i, data in enumerate(val_data_loader, 0):
                # Copy the batch to the GPU outside autocast; only the math needs it
                inp, tar = map(lambda x: x.to(device), data)
                with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype):
                    gen = model(inp)
                    val_loss += loss_func(gen, tar)
                    val_rmse += weighted_rmse(gen, tar)
                valid_steps += 1

            # Average the summed validation metrics across data-parallel ranks
            if params.distributed:
                torch.distributed.all_reduce(
                    val_loss, op=ReduceOp.AVG, group=comm.get_group("dp")
                )
                torch.distributed.all_reduce(
                    val_rmse, op=ReduceOp.AVG, group=comm.get_group("dp")
                )

        # Normalize validation metrics by number of steps
        val_rmse /= valid_steps
        val_loss /= valid_steps
        val_end = time.time()

        # (10.4) Log validation results (only rank 0 logs and writes)
        if world_rank == 0:
            logging.info("  Avg val loss={}".format(val_loss.item()))
            logging.info("  Total validation time: {} sec".format(val_end - val_start))
            args.tboard_writer.add_scalar("Loss/valid", val_loss.item(), iters)
            args.tboard_writer.add_scalar(
                "RMSE(u10m)/valid", val_rmse.cpu().numpy()[0], iters
            )
            args.tboard_writer.flush()

    # ----------------------------------------------------------------------- #
    # (11) Cleanup and Final Timing
    # ----------------------------------------------------------------------- #
    torch.cuda.synchronize()
    t2 = time.time()
    tottime = t2 - t1
    pynvml.nvmlShutdown()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--run_num",
        default="00",
        type=str,
        help="tag for indexing the current experiment",
    )
    parser.add_argument(
        "--yaml_config",
        default="./config/ViT.yaml",
        type=str,
        help="path to yaml file containing training configs",
    )
    parser.add_argument(
        "--config", default="base", type=str, help="name of desired config in yaml file"
    )
    parser.add_argument(
        "--amp_mode",
        default="none",
        type=str,
        choices=["none", "fp16", "bf16"],
        help="select automatic mixed precision mode",
    )
    parser.add_argument(
        "--enable_fused", action="store_true", help="enable fused Adam optimizer"
    )
    parser.add_argument(
        "--enable_jit", action="store_true", help="enable JIT compilation"
    )
    parser.add_argument(
        "--local_batch_size",
        default=None,
        type=int,
        help="local batchsize (manually override global_batch_size config setting)",
    )
    parser.add_argument(
        "--num_iters", default=None, type=int, help="number of iters to run"
    )
    parser.add_argument(
        "--num_data_workers",
        default=None,
        type=int,
        help="number of data workers for data loader",
    )
    parser.add_argument(
        "--data_loader_config",
        default=None,
        type=str,
        choices=["pytorch", "dali"],
        help="dataloader configuration. choices: 'pytorch', 'dali'",
    )
    parser.add_argument(
        "--bucket_cap_mb", default=25, type=int, help="max message bucket size in mb"
    )
    parser.add_argument(
        "--disable_broadcast_buffers",
        action="store_true",
        help="disable syncing broadcasting buffers",
    )
    parser.add_argument(
        "--noddp", action="store_true", help="disable DDP communication"
    )

    # model parallelism arguments
    parser.add_argument(
        "--tensor_parallel",
        default=1,
        type=int,
        help="Number of GPUs for tensor parallelism",
    )
    parser.add_argument(
        "--context_parallel",
        default=1,
        type=int,
        help="Number of GPUs for context parallelism",
    )
    parser.add_argument(
        "--parallel_order",
        default="tp-cp-dp",
        type=str,
        help="Order of ranks for parallelism",
    )

    args = parser.parse_args()

    run_num = args.run_num

    params = YParams(os.path.abspath(args.yaml_config), args.config)

    # Update config with modified args
    # set up amp
    if args.amp_mode != "none":
        params.update({"amp_mode": args.amp_mode})
    amp_dtype = torch.float32
    if params.amp_mode == "fp16":
        amp_dtype = torch.float16
    elif params.amp_mode == "bf16":
        amp_dtype = torch.bfloat16

    params.update(
        {"amp_enabled": amp_dtype is not torch.float32, "amp_dtype": amp_dtype}
    )

    if args.enable_fused:
        params.update({"enable_fused": args.enable_fused})

    if args.enable_jit:
        params.update({"enable_jit": args.enable_jit})

    if args.data_loader_config:
        params.update({"data_loader_config": args.data_loader_config})

    if args.num_iters:
        params.update({"num_iters": args.num_iters})

    if args.num_data_workers:
        params.update({"num_data_workers": args.num_data_workers})

    params.distributed = False

    # setup model parallel sizes
    params["tp"] = args.tensor_parallel
    params["cp"] = args.context_parallel
    params["order"] = args.parallel_order
    # initialize comm
    comm.init(params, verbose=True)

    # get info from comm
    world_size = comm.get_world_size()
    world_rank = comm.get_world_rank()
    local_rank = comm.get_local_rank()
    params.distributed = world_size > 1

    assert (
        params["global_batch_size"] % comm.get_size("dp") == 0
    ), f"Error, cannot evenly distribute {params['global_batch_size']} across {comm.get_size('dp')} data-parallel ranks."

    if args.local_batch_size:
        # Manually override batch size
        params.local_batch_size = args.local_batch_size
        params.update(
            {"global_batch_size": comm.get_size("dp") * args.local_batch_size}
        )
    else:
        # Compute local batch size based on number of ranks
        params.local_batch_size = int(
            params["global_batch_size"] // comm.get_size("dp")
        )

    # for data loader, set the actual number of data shards and id
    params.data_num_shards = comm.get_size("dp")
    params.data_shard_id = comm.get_rank("dp")

    # Set up directory
    baseDir = params.expdir
    expDir = os.path.join(
        baseDir, args.config + "/%dMP/" % (comm.get_size("tp-cp")) + str(run_num) + "/"
    )
    if world_rank == 0:
        if not os.path.isdir(expDir):
            os.makedirs(expDir)
        logging_utils.log_to_file(
            logger_name=None, log_filename=os.path.join(expDir, "out.log")
        )
        params.log()
        args.tboard_writer = SummaryWriter(log_dir=os.path.join(expDir, "logs/"))

    params.experiment_dir = os.path.abspath(expDir)

    train(params, args, local_rank, world_rank, world_size)

    if params.distributed:
        torch.distributed.barrier()
    logging.info("DONE ---- rank %d" % world_rank)
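The flags --tensor_parallel, --context_parallel and --parallel_order decide which ranks talk to each other, and the data-parallel size that is left over (world_size / (tp * cp)) is what comm.get_size("dp") feeds into the batch split and data sharding. The sketch below shows one plausible layout for 8 ranks. It is hypothetical: build_layout, groups and local_batch_size are illustrative names, and the rule that the first name in the order string varies fastest across consecutive ranks is a common convention (it keeps tensor-parallel partners on the same node), not a statement about what utils.comm actually does.

```python
# Hypothetical sketch: carve world ranks into tp / cp / dp groups.
# ASSUMPTION: the first name in `order` varies fastest across consecutive ranks.

def build_layout(world_size, tp, cp, order="tp-cp-dp"):
    """Return ({dim: size}, {rank: {dim: coordinate}}); dp is whatever is left over."""
    assert world_size % (tp * cp) == 0, "world size must be divisible by tp * cp"
    sizes = {"tp": tp, "cp": cp, "dp": world_size // (tp * cp)}
    dims = order.split("-")
    assert sorted(dims) == ["cp", "dp", "tp"], "order must name tp, cp and dp once"
    layout = {}
    for rank in range(world_size):
        coords, rem = {}, rank
        for d in dims:  # mixed-radix decomposition, fastest dimension first
            coords[d] = rem % sizes[d]
            rem //= sizes[d]
        layout[rank] = coords
    return sizes, layout

def groups(layout, dims):
    """Ranks that agree on every dimension NOT listed in `dims` (e.g. "tp" or "tp-cp")."""
    varying = set(dims.split("-"))
    buckets = {}
    for rank in sorted(layout):
        coords = layout[rank]
        key = tuple((d, coords[d]) for d in sorted(coords) if d not in varying)
        buckets.setdefault(key, []).append(rank)
    return sorted(buckets.values())

def local_batch_size(global_batch_size, dp):
    # Mirrors the script's check: the global batch is split over dp ranks only.
    assert global_batch_size % dp == 0, "cannot evenly distribute the global batch"
    return global_batch_size // dp

sizes, layout = build_layout(world_size=8, tp=2, cp=2)
print(sizes)
for name in ("tp", "cp", "dp", "tp-cp"):
    print(f"{name:>5} groups: {groups(layout, name)}")
print("local batch:", local_batch_size(64, sizes["dp"]))  # 64 is an illustrative global batch
```

Under this layout, the "tp-cp" groups are the model-parallel replicas whose replicated weights init_params_for_shared_weights must keep identical, while the "dp" groups are the ones used for the loss and metric all-reduces in the training loop.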

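The cosine schedule with warmup in the training script is easiest to understand by evaluating its lambda directly. LambdaLR multiplies the base learning rate (params.lr) by this value, and because the script calls scheduler.step() once per iteration, step counts iterations. The function below reproduces the script's lr_scale expression with math.cos in place of np.cos; warmup_cosine_scale and the numbers used are illustrative, not taken from the YAML config.

```python
import math

def warmup_cosine_scale(step, warmup, num_iters):
    """LR multiplier from the script's lr_scale lambda (math.cos instead of np.cos)."""
    linear = (step + 1) / warmup                               # linear warmup ramp
    cosine = 0.5 * (1 + math.cos(math.pi * step / num_iters))  # cosine decay, measured from step 0
    return min(linear, cosine)

base_lr, warmup, num_iters = 1e-3, 100, 1000  # illustrative values only
for step in (0, 49, 97, 99, 500, 999):
    lr = base_lr * warmup_cosine_scale(step, warmup, num_iters)
    print(f"step {step:4d}: lr = {lr:.3e}")
peak = max(warmup_cosine_scale(s, warmup, num_iters) for s in range(num_iters))
print(f"peak multiplier = {peak:.3f}")
```

Because the cosine term is measured from step 0 rather than from the end of warmup, the two curves cross just before step `warmup`, so the peak multiplier stays slightly below 1 (about 0.977 with these values). If the full base LR should be reached, a common variant evaluates the cosine on (step - warmup) / (num_iters - warmup) instead.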