Debugging the Dreaded NaN
Capturing and reproducing failures in PyTorch training with Lightning
Chaim Rand, Towards Data Science, Thu, 27 Feb 2025

The post Debugging the Dreaded NaN appeared first on Towards Data Science.

You are training your latest AI model, anxiously watching as the loss steadily decreases when suddenly — boom! Your logs are flooded with NaNs (Not a Number) — your model is irreparably corrupted and you’re left staring at your screen in despair. To make matters worse, the NaNs don’t appear consistently. Sometimes your model trains just fine; other times, it fails inexplicably. Sometimes it will crash immediately, sometimes after many days of training.

NaNs in Deep Learning workloads are amongst the most frustrating issues to encounter. And because they often appear sporadically — triggered by a specific combination of model state, input data, and stochastic factors — they can be incredibly difficult to reproduce and debug.

Given the considerable cost of training AI models and the potential waste caused by NaN failures, it is recommended to have dedicated tools for capturing and analyzing NaN occurrences. In a previous post, we discussed the challenge of debugging NaNs in a TensorFlow training workload. We proposed an efficient scheme for capturing and reproducing NaNs and shared a sample TensorFlow implementation. In this post, we adopt and demonstrate a similar mechanism for debugging NaNs in PyTorch workloads. The general scheme is as follows:

On each training step:

  1. Save a copy of the training input batch.
  2. Check the gradients for NaN values. If any appear, save a checkpoint with the current model weights before the model is corrupted. Also, save the input batch and, if necessary, the stochastic state. Discontinue the training job.
  3. Reproduce and debug the NaN occurrence by loading the saved experiment state.
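For comparison, here is a minimal sketch of the same scheme in a plain PyTorch loop, without Lightning. It assumes a loader that yields (input, target) pairs. The helper names grads_have_nan and train_with_nan_capture, and the checkpoint dictionary layout, are illustrative and not part of any library. Step 3 then amounts to calling torch.load on the saved file, restoring the model and optimizer state dicts, and rerunning the stored batch.

```python
import os
import torch

def grads_have_nan(model: torch.nn.Module) -> bool:
    # one 0-d flag per gradient, stacked so that only a single host sync is needed
    flags = [torch.isnan(p.grad).any() for p in model.parameters()
             if p.grad is not None]
    return bool(torch.stack(flags).any()) if flags else False

def train_with_nan_capture(model, optimizer, loss_fn, loader, ckpt_path):
    """Train until NaN gradients appear; return that step index, or None."""
    model.train()
    for step, (x, y) in enumerate(loader):
        # 1. keep a copy of the input batch and the RNG state before the step
        batch = (x.clone(), y.clone())
        rng_state = torch.get_rng_state()
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        # 2. check gradients *before* optimizer.step() corrupts the weights
        if grads_have_nan(model):
            os.makedirs(os.path.dirname(ckpt_path) or ".", exist_ok=True)
            torch.save({
                "step": step,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "batch": batch,
                "rng_state": rng_state,
            }, ckpt_path)
            return step  # discontinue training
        optimizer.step()
    return None
```

Because the check runs between backward() and step(), the saved weights are the last uncorrupted ones.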

Although this scheme can be easily implemented in native PyTorch, we will take the opportunity to demonstrate some of the conveniences of PyTorch Lightning, a powerful open-source framework designed to streamline the development of machine learning (ML) models. Built on PyTorch, Lightning abstracts away many of the boilerplate components of an ML experiment, such as training loops, data distribution, logging, and more, enabling developers to focus on the core logic of their models.

To implement our NaN capturing scheme, we will use Lightning’s callback interface — a dedicated structure that enables inserting custom logic at specific points during the flow of execution.

Importantly, please do not view our choice of Lightning or any other tool or technique that we mention as an endorsement of its use. The code that we will share is intended for demonstrative purposes — please do not rely on its correctness or optimality.

Many thanks to Rom Maltser for his contributions to this post.

NaNCapture Callback

To implement our NaN capturing solution, we create a NaNCapture Lightning callback. The constructor receives a directory path for storing/loading checkpoints and sets up the NaNCapture state. We also define utilities for checking for NaNs, storing checkpoints, and halting the training job.

import os
import torch
from copy import deepcopy
import lightning.pytorch as pl

class NaNCapture(pl.Callback):

    def __init__(self, dirpath: str):
        # path to checkpoint
        self.dirpath = dirpath

        # update to True when NaN is identified
        self.nan_captured = False

        # stores a copy of the last batch
        self.last_batch = None
        self.batch_idx = None

    @staticmethod
    def contains_nan(tensor):
        return torch.isnan(tensor).any().item()
        # alternatively, also catch infinities:
        # return not torch.isfinite(tensor).all().item()

    @staticmethod
    def halt_training(trainer):
        trainer.should_stop = True
        # communicate stop command to all other ranks
        trainer.should_stop = trainer.strategy.reduce_boolean_decision(
            trainer.should_stop, all=False)

    def save_ckpt(self, trainer):
        os.makedirs(self.dirpath, exist_ok=True)
        # include trainer.global_rank to avoid conflict
        filename = f"nan_checkpoint_rank_{trainer.global_rank}.ckpt"
        full_path = os.path.join(self.dirpath, filename)
        print(f"saving ckpt to {full_path}")
        trainer.save_checkpoint(full_path, weights_only=False)

Callback Function: on_train_batch_start

We begin by implementing the on_train_batch_start hook to store a copy of each input batch. In case of a NaN event, this batch will be stored in the checkpoint.

Callback Function: on_before_optimizer_step

Next we implement the on_before_optimizer_step hook. Here, we check for NaN entries in all of the gradient tensors. If found, we store a checkpoint with the uncorrupted model weights and halt the training.

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        if not self.nan_captured:
            # Check if gradients contain NaN
            grads = [p.grad.view(-1) for p in pl_module.parameters()
                     if p.grad is not None]
            all_grads = torch.cat(grads)
            if self.contains_nan(all_grads):
                print("nan found")
                # mark the capture so that state_dict() includes the batch
                self.nan_captured = True
                self.save_ckpt(trainer)
                self.halt_training(trainer)
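The torch.cat call above allocates a new buffer as large as all of the gradients combined. A possible alternative, sketched below for cases where memory headroom is tight, reduces each gradient tensor to a single flag and also catches infinities (the aim of the commented-out isfinite variant in contains_nan). The helper name grads_are_finite is hypothetical.

```python
import torch

def grads_are_finite(parameters) -> bool:
    """Return True if every existing gradient is free of NaN and +/-inf."""
    # one 0-d boolean per gradient; no concatenated copy of all gradients
    flags = [torch.isfinite(p.grad).all() for p in parameters
             if p.grad is not None]
    if not flags:
        return True
    # stack the flags so that only one device-to-host copy is needed
    return bool(torch.stack(flags).all())
```

Inside the callback, the check would read `if not grads_are_finite(pl_module.parameters()):`. The trade-off is one small reduction kernel per parameter tensor instead of a single large one.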

Capturing the Training State

To enable reproducibility, we include the NaNCapture state in the checkpoint by appending it to the training state dictionary. Lightning provides dedicated utilities for saving and loading a callback state:

    def state_dict(self):
        d = {"nan_captured": self.nan_captured}
        if self.nan_captured:
            d["last_batch"] = self.last_batch
        return d

    def load_state_dict(self, state_dict):
        self.nan_captured = state_dict.get("nan_captured", False)
        if self.nan_captured:
            self.last_batch = state_dict["last_batch"]
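A saved NaN checkpoint can also be inspected outside of Lightning. The sketch below assumes that Lightning stores each callback's state_dict in the checkpoint's "callbacks" entry, keyed by the callback's state_key (by default the class name, here NaNCapture); please verify this layout against your Lightning version. The function name load_captured_batch is hypothetical.

```python
import torch

def load_captured_batch(ckpt_path: str, state_key: str = "NaNCapture"):
    """Return the batch saved by the NaNCapture callback, or None if the
    checkpoint holds no captured NaN. Assumes the layout
    ckpt["callbacks"][state_key] == callback.state_dict()."""
    # weights_only=False because callback state may include arbitrary Python
    # objects (e.g. NumPy RNG state); only load checkpoints you trust
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state = ckpt.get("callbacks", {}).get(state_key, {})
    if not state.get("nan_captured", False):
        return None
    return state["last_batch"]
```

This makes it possible to, for example, look for extreme values or unexpected labels in the offending batch before rerunning the full training script.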

Reproducing the NaN Occurrence

We have described how our NaNCapture callback can be used to store the training state that resulted in a NaN, but how do we reload this state in order to reproduce the issue and debug it? To accomplish this, we leverage Lightning’s dedicated data loading class, LightningDataModule.

DataModule Function: on_before_batch_transfer

In the code block below, we extend the LightningDataModule class to allow injecting a fixed training input batch. This is achieved by overriding the on_before_batch_transfer hook, as shown below:

from lightning.pytorch import LightningDataModule

class InjectableDataModule(LightningDataModule):

    def __init__(self):
        super().__init__()
        self.cached_batch = None

    def set_custom_batch(self, batch):
        self.cached_batch = batch

    def on_before_batch_transfer(self, batch, dataloader_idx):
        if self.cached_batch is not None:
            return self.cached_batch
        return batch

Callback Function: on_train_start

The final step is modifying the on_train_start hook of our NaNCapture callback to inject the stored training batch into the LightningDataModule.

    def on_train_start(self, trainer, pl_module):
        if self.nan_captured:
            datamodule = trainer.datamodule
            datamodule.set_custom_batch(self.last_batch)

In the next section we will demonstrate the end-to-end solution using a toy example.

Toy Example

To test our new callback, we create a ResNet-50-based image classification model with a loss function deliberately designed to trigger NaN occurrences.

Instead of using the standard CrossEntropy loss, we compute binary_cross_entropy_with_logits for each class independently and divide the result by the number of samples belonging to that class. Inevitably, we will encounter a batch in which one or more classes are missing, leading to a division by zero. The resulting infinite loss produces NaN values in the gradients during the backward pass, corrupting the model if left unchecked.
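To see why the failure is inevitable, assume labels drawn independently and uniformly at random, as the synthetic dataset in this section does. By inclusion-exclusion over the number k of missing classes, the probability that a batch of B samples misses at least one of K classes is the sum over k = 1..K of (-1)^(k+1) * C(K, k) * (1 - k/K)^B. The short computation below (the function name is illustrative) evaluates it for B = 128 and K = 20: roughly 2.8% per batch, or about 22 affected batches in an 800-step epoch.

```python
from math import comb

def prob_batch_missing_class(batch_size: int, num_classes: int) -> float:
    """Probability that a batch of i.i.d. uniform labels misses at least
    one class, via inclusion-exclusion over the k missing classes."""
    total = 0.0
    for k in range(1, num_classes + 1):
        # choose which k classes are missing; every label must avoid all k
        total += ((-1) ** (k + 1) * comb(num_classes, k)
                  * (1 - k / num_classes) ** batch_size)
    return total

p = prob_batch_missing_class(batch_size=128, num_classes=20)
print(f"per-batch probability: {p:.4f}")
print(f"expected affected batches in 800 steps: {800 * p:.1f}")
```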

The implementation below follows Lightning’s introductory tutorial.

import lightning.pytorch as pl
import torch
import torchvision
import torch.nn.functional as F

num_classes = 20


# define a lightning module
class ResnetModel(pl.LightningModule):
    def __init__(self):
        """Initializes a new instance of the ResnetModel class."""
        super().__init__()
        self.model = torchvision.models.resnet50(num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        outputs = self(x)
        # uncomment for default loss
        # return F.cross_entropy(outputs, y)
        
        # calculate binary_cross_entropy for each class individually
        losses = []
        for c in range(num_classes):
            count = torch.count_nonzero(y==c)
            masked = torch.where(y==c, 1., 0.)
            loss = F.binary_cross_entropy_with_logits(
                outputs[..., c],
                masked,
                reduction='sum'
            )
            mean_loss = loss/count # could result in NaN
            losses.append(mean_loss)
        total_loss = torch.stack(losses).mean()
        return total_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
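The same failure can be reproduced at a much smaller scale. The sketch below applies the per-class loss to a tiny linear model with three classes, one of which never appears in the labels. The guard flag, which clamps the class count to at least one, is a hypothetical fix shown only for contrast; it is not part of the toy example.

```python
import torch
import torch.nn.functional as F

def per_class_bce(outputs, y, num_classes, guard=False):
    """Per-class BCE normalized by the class count, as in ResnetModel.
    guard=True (hypothetical fix) clamps the count to at least 1."""
    losses = []
    for c in range(num_classes):
        count = torch.count_nonzero(y == c)
        if guard:
            count = count.clamp(min=1)
        target = torch.where(y == c, 1.0, 0.0)
        loss = F.binary_cross_entropy_with_logits(
            outputs[..., c], target, reduction="sum")
        losses.append(loss / count)  # count == 0 gives inf when unguarded
    return torch.stack(losses).mean()

def run_step(guard):
    """One forward/backward pass; returns (loss value, NaN-in-grads flag)."""
    torch.manual_seed(0)
    model = torch.nn.Linear(2, 3)
    # each input column mixes positive and negative values
    x = torch.tensor([[1.0, -1.0], [-1.0, 1.0], [1.0, 1.0], [-1.0, -1.0]])
    y = torch.tensor([0, 1, 0, 1])  # class 2 is missing from the batch
    loss = per_class_bce(model(x), y, num_classes=3, guard=guard)
    loss.backward()
    return loss.item(), torch.isnan(model.weight.grad).any().item()

print(run_step(guard=False))
print(run_step(guard=True))
```

Without the guard, the forward pass yields an infinite loss. During the backward pass, every sample contributes an infinite gradient to the missing class's logit, and the weight gradient sums +inf and -inf terms, which produces NaN.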

We define a synthetic dataset and encapsulate it in our InjectableDataModule class:

import os
import random
from torch.utils.data import Dataset, DataLoader

batch_size = 128
num_steps = 800

# A dataset with random images and labels
class FakeDataset(Dataset):
    def __len__(self):
        return batch_size*num_steps

    def __getitem__(self, index):
        rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
        label = torch.tensor(random.randint(0, num_classes-1),
                             dtype=torch.int64)
        return rand_image, label



# define a lightning datamodule
class FakeDataModule(InjectableDataModule):

    def train_dataloader(self):
        dataset = FakeDataset()
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=os.cpu_count(),
            pin_memory=True
        )

Finally, we initialize a Lightning Trainer with our NaNCapture callback and call trainer.fit with our Lightning module and Lightning DataModule.

import time

if __name__ == "__main__":

    # Initialize a lightning module
    lit_module = ResnetModel()

    # Initialize a DataModule
    data_module = FakeDataModule()

    # Train the model
    ckpt_dir = "./ckpt_dir"
    trainer = pl.Trainer(
        max_epochs=1,
        callbacks=[NaNCapture(ckpt_dir)]
    )

    ckpt_path = None

    # check if nan ckpt exists
    if os.path.isdir(ckpt_dir):
        dir_contents = [os.path.join(ckpt_dir, f)
                        for f in os.listdir(ckpt_dir)]
        ckpts = [f for f in dir_contents
                 if os.path.isfile(f) and f.endswith('.ckpt')]
        if ckpts:
            ckpt_path = ckpts[0]

    t0 = time.perf_counter()
    trainer.fit(lit_module, datamodule=data_module, ckpt_path=ckpt_path)
    print(f"total runtime: {time.perf_counter() - t0}")

After a number of training steps, a NaN event will occur. At this point a checkpoint is saved with the full training state and the training is halted.

When the script is run again, the exact state that caused the NaN is reloaded, allowing us to easily reproduce the issue and debug its root cause.
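Once the captured state is reloaded, one way to locate the operation that first produces a NaN is PyTorch's anomaly detection mode, torch.autograd.detect_anomaly, a separate tool that the workflow above does not rely on. It raises an error naming the backward function whose output contains NaN. Because it slows execution considerably, it is best reserved for replaying a single captured batch. The sketch assumes a model, a loss function, and an (input, target) batch such as the one restored from the NaN checkpoint; locate_nan_backward is a hypothetical helper name.

```python
import torch

def locate_nan_backward(model, loss_fn, batch):
    """Replay one forward/backward pass with anomaly detection enabled.
    Returns the error message naming the backward function that produced
    NaN, or None if the gradients stay NaN-free."""
    x, y = batch
    model.zero_grad()
    try:
        with torch.autograd.detect_anomaly():
            loss = loss_fn(model(x), y)
            loss.backward()
    except RuntimeError as err:
        return str(err)
    return None
```

In anomaly mode, PyTorch also emits a warning containing the traceback of the forward operation that created the failing backward function, which usually points to the offending line in the model or loss code.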

Performance Overhead

To assess the impact of our NaNCapture callback on runtime performance, we modified our experiment to use CrossEntropyLoss (to avoid NaNs) and measured the average throughput when running with and without the NaNCapture callback. The experiments were conducted on an NVIDIA L40S GPU, with a PyTorch 2.5.1 Docker image.

Overhead of NaNCapture Callback (by Author)

For our toy model, the NaNCapture callback adds a minimal 1.5% overhead to the runtime performance — a small price to pay for the valuable debugging capabilities it provides.

Naturally, the actual overhead will depend on the specifics of the model and runtime environment.

How to Handle Stochasticity

The solution we have described so far will succeed in reproducing the training state, provided that the model does not include any randomness. However, introducing stochasticity into the model definition is often critical for convergence. A common example of a stochastic layer is torch.nn.Dropout.

You may find that your NaN event depends on the precise state of randomness when the failure occurred. Consequently, we would like to enhance our NaNCapture callback to capture and restore the random state at the point of failure. The random state is determined by a number of libraries. In the code block below, we attempt to capture the full state of randomness:

import os
import torch
import random
import numpy as np
from copy import deepcopy
import lightning.pytorch as pl

class NaNCapture(pl.Callback):

    def __init__(self, dirpath: str):
        # path to checkpoint
        self.dirpath = dirpath
        
        # update to True when Nan is identified
        self.nan_captured = False
        
        # stores a copy of the last batch
        self.last_batch = None
        self.batch_idx = None

        # rng state
        self.rng_state = {
            "torch": None,
            "torch_cuda": None,
            "numpy": None,
            "random": None
        }

    @staticmethod
    def contains_nan(tensor):
        return torch.isnan(tensor).any().item()
        # alternatively, also catch infinities:
        # return not torch.isfinite(tensor).all().item()

    @staticmethod
    def halt_training(trainer):
        trainer.should_stop = True
        # communicate stop command to all other ranks
        trainer.should_stop = trainer.strategy.reduce_boolean_decision(
            trainer.should_stop, all=False)

    def save_ckpt(self, trainer):
        os.makedirs(self.dirpath, exist_ok=True)
        # include trainer.global_rank to avoid conflict
        filename = f"nan_checkpoint_rank_{trainer.global_rank}.ckpt"
        full_path = os.path.join(self.dirpath, filename)
        print(f"saving ckpt to {full_path}")
        trainer.save_checkpoint(full_path, weights_only=False)

    def on_train_start(self, trainer, pl_module):
        if self.nan_captured:
            # inject batch
            datamodule = trainer.datamodule
            datamodule.set_custom_batch(self.last_batch)

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if self.nan_captured:
            # restore random state
            torch.random.set_rng_state(self.rng_state["torch"])
            torch.cuda.set_rng_state_all(self.rng_state["torch_cuda"])
            np.random.set_state(self.rng_state["numpy"])
            random.setstate(self.rng_state["random"])
        else:
            # capture current batch
            self.last_batch = deepcopy(batch)
            self.batch_idx = batch_idx

            # capture current random state
            self.rng_state["torch"] = torch.random.get_rng_state()
            self.rng_state["torch_cuda"] = torch.cuda.get_rng_state_all()
            self.rng_state["numpy"] = np.random.get_state()
            self.rng_state["random"] = random.getstate()

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        if not self.nan_captured:
            # Check if gradients contain NaN
            grads = [p.grad.view(-1) for p in pl_module.parameters()
                     if p.grad is not None]
            all_grads = torch.cat(grads)
            if self.contains_nan(all_grads):
                print("nan found")
                # mark the capture so that state_dict() includes batch and RNG state
                self.nan_captured = True
                self.save_ckpt(trainer)
                self.halt_training(trainer)

    def state_dict(self):
        d = {"nan_captured": self.nan_captured}
        if self.nan_captured:
            d["last_batch"] = self.last_batch
            d["rng_state"] = self.rng_state
        return d

    def load_state_dict(self, state_dict):
        self.nan_captured = state_dict.get("nan_captured", False)
        if self.nan_captured:
            self.last_batch = state_dict["last_batch"]
            self.rng_state = state_dict["rng_state"]
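The capture-and-restore logic can be checked in isolation. The sketch below factors it into two hypothetical helper functions and, unlike the callback, only touches the CUDA generators when a GPU is present. Restoring a saved state and then rerunning a dropout layer, NumPy, and Python's random module should reproduce exactly the same draws.

```python
import random
import numpy as np
import torch

def capture_rng_state() -> dict:
    """Snapshot the RNG states tracked by the NaNCapture callback."""
    state = {
        "torch": torch.random.get_rng_state(),
        "numpy": np.random.get_state(),
        "random": random.getstate(),
    }
    # only query the CUDA generators when a GPU is available
    if torch.cuda.is_available():
        state["torch_cuda"] = torch.cuda.get_rng_state_all()
    return state

def restore_rng_state(state: dict) -> None:
    """Restore a snapshot produced by capture_rng_state()."""
    torch.random.set_rng_state(state["torch"])
    np.random.set_state(state["numpy"])
    random.setstate(state["random"])
    if "torch_cuda" in state:
        torch.cuda.set_rng_state_all(state["torch_cuda"])
```

For example, capturing the state, applying torch.nn.Dropout(p=0.5) to a tensor, restoring the state, and applying the same layer again yields identical dropout masks.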

Importantly, setting the random state may not guarantee full reproducibility. The GPU owes its power to its massive parallelism. In some GPU operations, multiple threads may read or write concurrently to the same memory locations, resulting in nondeterminism. PyTorch allows for some control over this via torch.use_deterministic_algorithms, but this may impact the runtime performance. Additionally, there is a possibility that the NaN event will not be reproduced once this configuration setting is changed. Please see the PyTorch documentation on reproducibility for more details.
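If you decide to trade some speed for determinism while replaying a captured NaN, the settings below are a common starting point. The CUBLAS_WORKSPACE_CONFIG value is one of the options listed in the PyTorch reproducibility notes and should be set before any CUDA work starts in the process. enable_determinism is a hypothetical helper, and the PyTorch documentation remains the authoritative list of settings.

```python
import os
import torch

def enable_determinism(seed: int = 0) -> None:
    # required by some cuBLAS operations in deterministic mode (CUDA 10.2+);
    # it only takes effect if set before cuBLAS is first used
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
    torch.backends.cudnn.benchmark = False  # no run-to-run algorithm autotuning
    # raise an error whenever an op has no deterministic implementation
    torch.use_deterministic_algorithms(True)
```

Because an operation without a deterministic implementation will raise an error in this mode, the first replay may also reveal which nondeterministic kernels your model relies on.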

Summary

Encountering NaN failures is one of the most discouraging events that can happen in machine learning development. These errors not only waste valuable computation and development resources, but often indicate fundamental issues in the model architecture or experiment design. Due to their sporadic, sometimes elusive nature, debugging NaN failures can be a nightmare.

This post introduced a proactive approach for capturing and reproducing NaN errors using a dedicated Lightning callback. The solution we shared is a proposal which can be modified and extended for your specific use case.

While this solution may not address every possible NaN scenario, it significantly reduces debugging time when applicable, potentially saving developers countless hours of frustration and wasted effort.

Efficient Metric Collection in PyTorch: Avoiding the Performance Pitfalls of TorchMetrics
Fri, 07 Feb 2025

The post Efficient Metric Collection in PyTorch: Avoiding the Performance Pitfalls of TorchMetrics appeared first on Towards Data Science.

Metric collection is an essential part of every machine learning project, enabling us to track model performance and monitor training progress. Ideally, metrics should be collected and computed without introducing any additional overhead to the training process. However, just like other components of the training loop, inefficient metric computation can introduce unnecessary overhead, increase training-step times, and inflate training costs.

This post is the seventh in our series on performance profiling and optimization in PyTorch. The series has aimed to emphasize the critical role of performance analysis and optimization in machine learning development. Each post has focused on different stages of the training pipeline, demonstrating practical tools and techniques for analyzing and boosting resource utilization and runtime efficiency.

In this installment, we focus on metric collection. We will demonstrate how a naïve implementation of metric collection can negatively impact runtime performance and explore tools and techniques for its analysis and optimization.

To implement our metric collection, we will use TorchMetrics, a popular library designed to simplify and standardize metric computation in PyTorch. Our goals will be to:

  1. Demonstrate the runtime overhead caused by a naïve implementation of metric collection.
  2. Use PyTorch Profiler to pinpoint performance bottlenecks introduced by metric computation.
  3. Demonstrate optimization techniques to reduce metric collection overhead.

To facilitate our discussion, we will define a toy PyTorch model and assess how metric collection can impact its runtime performance. We will run our experiments on an NVIDIA A40 GPU, with a PyTorch 2.5.1 docker image and TorchMetrics 1.6.1.

It’s important to note that metric collection behavior can vary greatly depending on the hardware, runtime environment, and model architecture. The code snippets provided in this post are intended for demonstrative purposes only. Please do not interpret our mention of any tool or technique as an endorsement for its use.

Toy ResNet Model

In the code block below we define a simple image classification model with a ResNet-18 backbone.

import time
import torch
import torchvision

device = "cuda"

model = torchvision.models.resnet18().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters())

We define a synthetic dataset which we will use to train our toy model.

from torch.utils.data import Dataset, DataLoader

# A dataset with random images and labels
class FakeDataset(Dataset):
    def __len__(self):
        return 100000000

    def __getitem__(self, index):
        rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
        label = torch.tensor(data=index % 1000, dtype=torch.int64)
        return rand_image, label

train_set = FakeDataset()

batch_size = 128
num_workers = 12

train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True
)

We define a collection of standard metrics from TorchMetrics, along with a control flag to enable or disable metric calculation.

from torchmetrics import (
    MeanMetric,
    Accuracy,
    Precision,
    Recall,
    F1Score,
)

# toggle to enable/disable metric collection
capture_metrics = False

if capture_metrics:
    metrics = {
        "avg_loss": MeanMetric(),
        "accuracy": Accuracy(task="multiclass", num_classes=1000),
        "precision": Precision(task="multiclass", num_classes=1000),
        "recall": Recall(task="multiclass", num_classes=1000),
        "f1_score": F1Score(task="multiclass", num_classes=1000),
    }

    # Move all metrics to the device
    metrics = {name: metric.to(device) for name, metric in metrics.items()}

Next, we define a PyTorch Profiler instance, along with a control flag that allows us to enable or disable profiling. For a detailed tutorial on using PyTorch Profiler, please refer to the first post in this series.

from torch import profiler

# toggle to enable/disable profiling
enable_profiler = True

if enable_profiler:
    prof = profiler.profile(
        schedule=profiler.schedule(wait=10, warmup=2, active=3, repeat=1),
        on_trace_ready=profiler.tensorboard_trace_handler("./logs/"),
        profile_memory=True,
        with_stack=True
    )
    prof.start()

Lastly, we define a standard training step:

model.train()

t0 = time.perf_counter()
total_time = 0
count = 0

for idx, (data, target) in enumerate(train_loader):
    data = data.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    if capture_metrics:
        # update metrics
        metrics["avg_loss"].update(loss)
        for name, metric in metrics.items():
            if name != "avg_loss":
                metric.update(output, target)

        if (idx + 1) % 100 == 0:
            # compute metrics
            metric_results = {
                name: metric.compute().item() 
                    for name, metric in metrics.items()
            }
            # print metrics
            print(f"Step {idx + 1}: {metric_results}")
            # reset metrics
            for metric in metrics.values():
                metric.reset()

    elif (idx + 1) % 100 == 0:
        # print last loss value
        print(f"Step {idx + 1}: Loss = {loss.item():.4f}")

    batch_time = time.perf_counter() - t0
    t0 = time.perf_counter()
    if idx > 10:  # skip first steps
        total_time += batch_time
        count += 1

    if enable_profiler:
        prof.step()

    if idx > 200:
        break

if enable_profiler:
    prof.stop()

avg_time = total_time/count
print(f'Average step time: {avg_time}')
print(f'Throughput: {batch_size/avg_time:.2f} images/sec')

Metric Collection Overhead

To measure the impact of metric collection on training step time, we ran our training script both with and without metric calculation. The results are summarized in the following table.

The Overhead of Naive Metric Collection (by Author)

Our naïve metric collection resulted in a nearly 10% drop in runtime performance!! While metric collection is essential for machine learning development, it usually involves relatively simple mathematical operations and hardly warrants such a significant overhead. What is going on?!!

Identifying Performance Issues with PyTorch Profiler

To better understand the source of the performance degradation, we reran the training script with the PyTorch Profiler enabled. The resultant trace is shown below:

Trace of Metric Collection Experiment (by Author)

The trace reveals recurring "cudaStreamSynchronize" operations that coincide with noticeable drops in GPU utilization. These types of "CPU-GPU sync" events were discussed in detail in part two of our series. In a typical training step, the CPU and GPU work in parallel: The CPU manages tasks like data transfers to the GPU and kernel loading, and the GPU executes the model on the input data and updates its weights. Ideally, we would like to minimize the points of synchronization between the CPU and GPU in order to maximize performance. Here, however, we can see that the metric collection has triggered a sync event by performing a CPU to GPU data copy. This requires the CPU to suspend its processing until the GPU catches up which, in turn, causes the GPU to wait for the CPU to resume loading the subsequent kernel operations. The bottom line is that these synchronization points lead to inefficient utilization of both the CPU and GPU. Our metric collection implementation adds eight such synchronization events to each training step.
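Besides reading profiler traces, PyTorch can flag synchronizing CUDA operations as they happen via torch.cuda.set_sync_debug_mode, a tool that this post's workflow does not otherwise use. The sketch below wraps a callable with the "warn" setting and collects the resulting warnings; the helper name sync_warnings is hypothetical, and it returns None on machines without a GPU.

```python
import warnings
import torch

def sync_warnings(fn):
    """Run fn() with CUDA sync debugging set to "warn" and return the
    messages of any warnings raised; returns None when no GPU is present."""
    if not torch.cuda.is_available():
        return None
    previous = torch.cuda.get_sync_debug_mode()
    torch.cuda.set_sync_debug_mode("warn")
    try:
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            fn()
    finally:
        # restore whatever mode was active before
        torch.cuda.set_sync_debug_mode(previous)
    return [str(w.message) for w in caught]
```

For instance, wrapping a single metric update call on a GPU machine should report any synchronizing operation it performs, which narrows the search before opening a full trace.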

A closer examination of the trace shows that the sync events are coming from the update call of the MeanMetric TorchMetric. For the experienced profiling expert, this may be sufficient to identify the root cause, but we will go a step further and use the torch.profiler.record_function utility to identify the exact offending line of code.

Profiling with record_function

To pinpoint the exact source of the sync event, we extended the MeanMetric class and overrode the update method using record_function context blocks. This approach allows us to profile individual operations within the method and identify performance bottlenecks.

class ProfileMeanMetric(MeanMetric):
    def update(self, value, weight = 1.0):
        # broadcast weight to value shape
        with profiler.record_function("process value"):
            if not isinstance(value, torch.Tensor):
                value = torch.as_tensor(value, dtype=self.dtype,
                                        device=self.device)
        with profiler.record_function("process weight"):
            if weight is not None and not isinstance(weight, torch.Tensor):
                weight = torch.as_tensor(weight, dtype=self.dtype,
                                         device=self.device)
        with profiler.record_function("broadcast weight"):
            weight = torch.broadcast_to(weight, value.shape)
        with profiler.record_function("cast_and_nan_check"):
            value, weight = self._cast_and_nan_check_input(value, weight)

        if value.numel() == 0:
            return

        with profiler.record_function("update value"):
            self.mean_value += (value * weight).sum()
        with profiler.record_function("update weight"):
            self.weight += weight.sum()

We then updated our avg_loss metric to use the newly created ProfileMeanMetric and reran the training script.

Trace of Metric Collection with record_function (by Author)

The updated trace reveals that the sync event originates from the following line:

weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)

This operation converts the default scalar value weight=1.0 into a PyTorch tensor and places it on the GPU. The sync event occurs because this action triggers a CPU-to-GPU data copy, which requires the CPU to wait for the GPU to process the copied value.

Optimization 1: Specify Weight Value

Now that we have found the source of the issue, we can overcome it easily by specifying a weight value in our update call. This prevents the runtime from converting the default scalar weight=1.0 into a tensor on the GPU, avoiding the sync event:

# update metrics
if capture_metrics:
    metrics["avg_loss"].update(loss, weight=torch.ones_like(loss))

Rerunning the script after applying this change reveals that we have succeeded in eliminating the initial sync event… only to have uncovered a new one, this time coming from the _cast_and_nan_check_input function:

Trace of Metric Collection following Optimization 1 (by Author)

Profiling with record_function — Part 2

To explore our new sync event, we extended our custom metric with additional profiling probes and reran our script.

class ProfileMeanMetric(MeanMetric):
    def update(self, value, weight = 1.0):
        # broadcast weight to value shape
        with profiler.record_function("process value"):
            if not isinstance(value, torch.Tensor):
                value = torch.as_tensor(value, dtype=self.dtype,
                                        device=self.device)
        with profiler.record_function("process weight"):
            if weight is not None and not isinstance(weight, torch.Tensor):
                weight = torch.as_tensor(weight, dtype=self.dtype,
                                         device=self.device)
        with profiler.record_function("broadcast weight"):
            weight = torch.broadcast_to(weight, value.shape)
        with profiler.record_function("cast_and_nan_check"):
            value, weight = self._cast_and_nan_check_input(value, weight)

        if value.numel() == 0:
            return

        with profiler.record_function("update value"):
            self.mean_value += (value * weight).sum()
        with profiler.record_function("update weight"):
            self.weight += weight.sum()

    def _cast_and_nan_check_input(self, x, weight = None):
        """Convert input ``x`` to a tensor and check for Nans."""
        with profiler.record_function("process x"):
            if not isinstance(x, torch.Tensor):
                x = torch.as_tensor(x, dtype=self.dtype,
                                    device=self.device)
        with profiler.record_function("process weight"):
            if weight is not None and not isinstance(weight, torch.Tensor):
                weight = torch.as_tensor(weight, dtype=self.dtype,
                                         device=self.device)
            nans = torch.isnan(x)
            if weight is not None:
                nans_weight = torch.isnan(weight)
            else:
                nans_weight = torch.zeros_like(nans).bool()
                weight = torch.ones_like(x)

        with profiler.record_function("any nans"):
            anynans = nans.any() or nans_weight.any()

        with profiler.record_function("process nans"):
            if anynans:
                if self.nan_strategy == "error":
                    raise RuntimeError("Encountered `nan` values in tensor")
                if self.nan_strategy in ("ignore", "warn"):
                    if self.nan_strategy == "warn":
                        print("Encountered `nan` values in tensor."
                              " Will be removed.")
                    x = x[~(nans | nans_weight)]
                    weight = weight[~(nans | nans_weight)]
                else:
                    if not isinstance(self.nan_strategy, float):
                        raise ValueError(f"`nan_strategy` shall be float"
                                         f" but you pass {self.nan_strategy}")
                    x[nans | nans_weight] = self.nan_strategy
                    weight[nans | nans_weight] = self.nan_strategy

        with profiler.record_function("return value"):
            retval = x.to(self.dtype), weight.to(self.dtype)
        return retval

The resultant trace is captured below:

Trace of Metric Collection with record_function — part 2 (by Author)

The trace points directly to the offending line:

anynans = nans.any() or nans_weight.any()

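This operation checks for NaN values in the input tensors, but it introduces a costly CPU-GPU synchronization event. The Python "or" operator must evaluate the truth value of the tensor returned by nans.any(), so the result has to be copied from the GPU to the CPU before execution can continue.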

Upon closer inspection of the TorchMetrics BaseAggregator class, we find several options for handling NaN values, all of which pass through the offending line of code. However, for our use case (calculating the average loss metric), this check is unnecessary and does not justify the runtime performance penalty.
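If NaN handling is actually required, an alternative to removing the check is to make it branch-free, so the decision never leaves the GPU. The following is a minimal sketch of that idea, not TorchMetrics code: the hypothetical nan_ignoring_update function zeroes out NaN entries with torch.where instead of testing anynans in Python and then removing entries with boolean indexing (which, because it produces a data-dependent shape, would also force a synchronization). For a weighted mean, zeroing both a value and its weight has the same effect as dropping the entry, as the "ignore" strategy does.

```python
import torch

def nan_ignoring_update(sum_value, sum_weight, value, weight=1.0):
    """Hypothetical sync-free weighted-sum update that ignores NaN entries.

    No Python control flow depends on tensor values, so nothing has to be
    copied from the device to the host.
    """
    value = torch.as_tensor(value, dtype=sum_value.dtype, device=sum_value.device)
    weight = torch.as_tensor(weight, dtype=sum_value.dtype, device=sum_value.device)
    weight = torch.broadcast_to(weight, value.shape)
    # element-wise mask of entries whose value or weight is NaN (stays on device)
    nan_mask = torch.isnan(value) | torch.isnan(weight)
    zeros = torch.zeros_like(value)
    # zero out NaN entries rather than dropping them with value[~nan_mask],
    # whose data-dependent output shape would also require a sync
    value = torch.where(nan_mask, zeros, value)
    weight = torch.where(nan_mask, zeros, weight)
    return sum_value + (value * weight).sum(), sum_weight + weight.sum()

sum_value, sum_weight = torch.zeros(()), torch.zeros(())
sum_value, sum_weight = nan_ignoring_update(
    sum_value, sum_weight, torch.tensor([1.0, float("nan"), 3.0]))
print((sum_value / sum_weight).item())  # 2.0
```

The mean is only materialized on the host (via .item()) when the result is actually needed, for example when compute() is called.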

Optimization 2: Disable NaN Value Checks

To eliminate the overhead, we propose disabling the NaN value checks by overriding the _cast_and_nan_check_input function. Instead of a static override, we implemented a dynamic solution that can be applied flexibly to any descendants of the BaseAggregator class.

import torch
from torchmetrics.aggregation import BaseAggregator, MeanMetric

def suppress_nan_check(MetricClass):
    """Return a subclass of MetricClass whose input casting skips the NaN check."""
    assert issubclass(MetricClass, BaseAggregator), MetricClass

    class DisableNanCheck(MetricClass):
        def _cast_and_nan_check_input(self, x, weight=None):
            # same tensor conversion as the base class...
            if not isinstance(x, torch.Tensor):
                x = torch.as_tensor(x, dtype=self.dtype,
                                    device=self.device)
            if weight is not None and not isinstance(weight, torch.Tensor):
                weight = torch.as_tensor(weight, dtype=self.dtype,
                                         device=self.device)
            if weight is None:
                weight = torch.ones_like(x)
            # ...but without the NaN test that forces a CPU-GPU sync
            return x.to(self.dtype), weight.to(self.dtype)

    return DisableNanCheck

NoNanMeanMetric = suppress_nan_check(MeanMetric)

metrics["avg_loss"] = NoNanMeanMetric().to(device)

Post Optimization Results: Success

After implementing the two optimizations (specifying the weight value and disabling the NaN checks), we find the step time performance and the GPU utilization to match those of our baseline experiment. In addition, the resultant PyTorch Profiler trace shows that all of the added “cudaStreamSynchronize” events associated with the metric collection have been eliminated. With a few small changes, we have reduced the cost of training by ~10% without any changes to the behavior of the metric collection.

In the next section, we will explore an additional metric collection optimization.

Example 2: Optimizing Metric Device Placement

In the previous section, the metric values resided on the GPU, making it logical to store and compute the metrics on the GPU. However, in scenarios where the values we wish to aggregate reside on the CPU, it might be preferable to store the metrics on the CPU to avoid unnecessary device transfers.

In the code block below, we modify our script to calculate the average step time using a MeanMetric on the CPU. This change has no impact on the runtime performance of our training step:

avg_time = NoNanMeanMetric()
t0 = time.perf_counter()

for idx, (data, target) in enumerate(train_loader):
    # move data to device
    data = data.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    if capture_metrics:
        metrics["avg_loss"].update(loss)
        for name, metric in metrics.items():
            if name != "avg_loss":
                metric.update(output, target)

        if (idx + 1) % 100 == 0:
            # compute metrics
            metric_results = {
                name: metric.compute().item()
                    for name, metric in metrics.items()
            }
            # print metrics
            print(f"Step {idx + 1}: {metric_results}")
            # reset metrics
            for metric in metrics.values():
                metric.reset()

    elif (idx + 1) % 100 == 0:
        # print last loss value
        print(f"Step {idx + 1}: Loss = {loss.item():.4f}")

    batch_time = time.perf_counter() - t0
    t0 = time.perf_counter()
    if idx > 10:  # skip first steps
        avg_time.update(batch_time)

    if enable_profiler:
        prof.step()

    if idx > 200:
        break

if enable_profiler:
    prof.stop()

avg_time = avg_time.compute().item()
print(f'Average step time: {avg_time}')
print(f'Throughput: {batch_size/avg_time:.2f} images/sec')

The problem arises when we attempt to extend our script to support distributed training. To demonstrate the problem, we modified our model definition to use DistributedDataParallel (DDP):

# toggle to enable/disable ddp
use_ddp = True

if use_ddp:
    import os
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)
    model = DDP(torchvision.models.resnet18().to(device))
else:
    model = torchvision.models.resnet18().to(device)

# insert training loop

# append to end of the script:
if use_ddp:
    # destroy the process group
    dist.destroy_process_group()

The DDP modification results in the following error:

RuntimeError: No backend type associated with device type cpu

By default, TorchMetrics metrics synchronize their states across all processes when compute is called in a distributed setting. However, the NCCL backend with which we initialized the process group supports only GPU tensors, so it cannot synchronize metric states stored on the CPU.

One way to solve this is to disable the cross-device metric synchronization:

avg_time = NoNanMeanMetric(sync_on_compute=False)

In our case, where we are measuring the average time, this solution is acceptable. However, in some cases metric synchronization is essential, and we may have no choice but to move the metric onto the GPU:

avg_time = NoNanMeanMetric().to(device)

Unfortunately, this situation gives rise to a new CPU-GPU sync event coming from the update function.

Trace of avg_time Metric Collection (by Author)

This sync event should hardly come as a surprise: after all, we are updating a GPU metric with a value residing on the CPU, which would seem to necessitate a memory copy. However, in the case of a scalar metric, this data transfer can be avoided entirely with a simple optimization.

Optimization 3: Perform Metric Updates with Tensors instead of Scalars

The solution is straightforward: instead of updating the metric with a Python float, we convert the value to a tensor before calling update and pass an explicit weight tensor along with it.

batch_time = torch.as_tensor(batch_time)
avg_time.update(batch_time, torch.ones_like(batch_time))

This minor change bypasses the problematic line of code, eliminates the sync event, and restores the step time to the baseline performance.

At first glance, this result may seem surprising: we would expect that updating a GPU metric with a CPU tensor should still require a memory copy. However, when a zero-dimensional CPU tensor is combined with a GPU tensor, PyTorch treats it as a scalar and passes its value directly to the GPU kernel as an argument. No explicit data transfer is performed, and the expensive synchronization event that would otherwise occur is avoided.
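To see the mechanism in isolation, the minimal sketch below accumulates a running step-time average in device-resident tensors (simplified stand-ins for a mean metric's state, not the TorchMetrics implementation) using zero-dimensional CPU tensors. A CPU tensor with one or more dimensions could not be mixed with a GPU tensor this way. When no GPU is available, the sketch falls back to the CPU, in which case it only illustrates the API rather than the avoided transfer.

```python
import time
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# simplified stand-ins for a mean metric's state, kept on the device
total = torch.zeros((), device=device)
count = torch.zeros((), device=device)

t0 = time.perf_counter()
for _ in range(5):
    # ... a training step would run here ...
    step_time = torch.as_tensor(time.perf_counter() - t0)  # 0-dim tensor on the CPU
    t0 = time.perf_counter()
    # a 0-dim CPU tensor may be combined with a device tensor; PyTorch treats it
    # as a scalar instead of requiring an explicit .to(device) copy
    total += step_time
    count += torch.ones_like(step_time)

# .item() copies the result to the host (one sync, outside the loop)
print(f"average step time: {(total / count).item():.6f} s")
```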

Summary

In this post, we explored how a naïve approach to TorchMetrics can introduce CPU-GPU synchronization events and significantly degrade PyTorch training performance. Using PyTorch Profiler, we identified the lines of code responsible for these sync events and applied targeted optimizations to eliminate them:

  • Explicitly specify a weight tensor when calling the MeanMetric.update function instead of relying on the default value.
  • Disable NaN checks in the BaseAggregator class or replace them with a more efficient alternative.
  • Carefully manage the device placement of each metric to minimize unnecessary transfers.
  • Disable cross-device metric synchronization when not required.
  • When the metric resides on a GPU, convert floating-point scalars to tensors before passing them to the update function to avoid implicit synchronization.

We have created a dedicated pull request on the TorchMetrics GitHub page covering some of the optimizations discussed in this post. Please feel free to contribute your own improvements and optimizations!

The post Efficient Metric Collection in PyTorch: Avoiding the Performance Pitfalls of TorchMetrics appeared first on Towards Data Science.

Optimizing Transformer Models for Variable-Length Input Sequences
https://towardsdatascience.com/optimizing-transformer-models-for-variable-length-input-sequences-19fb88fddf71/ (Tue, 26 Nov 2024 14:45:19 +0000)
How PyTorch NestedTensors, FlashAttention2, and xFormers can Boost Performance and Reduce AI Costs

The post Optimizing Transformer Models for Variable-Length Input Sequences appeared first on Towards Data Science.

As generative AI (GenAI) models grow in both popularity and scale, so do the computational demands and costs associated with their training and deployment. Optimizing these models is crucial for enhancing their runtime performance and reducing their operational expenses. At the heart of modern GenAI systems is the Transformer architecture and its attention mechanism, which is notably compute-intensive.

In a previous post, we demonstrated how using optimized attention kernels can significantly accelerate the performance of Transformer models. In this post, we continue our exploration by addressing the challenge of variable-length input sequences – an inherent property of real-world data, including documents, code, time-series, and more.

The Challenge of Batching Variable-Length Input

In a typical deep learning workload, individual samples are grouped into batches before being copied to the GPU and fed to the AI model. Batching improves computational efficiency and often aids model convergence during training. Usually, batching involves stacking all of the sample tensors along a new dimension – the batch dimension. However, torch.stack requires all tensors to have the same shape, which is not the case with variable-length sequences.
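A toy illustration of the problem: torch.stack refuses token sequences of different lengths, whereas torch.cat (discussed below) simply joins them along their existing dimension. The sequences here are arbitrary placeholders.

```python
import torch

seqs = [torch.arange(3), torch.arange(5), torch.arange(2)]  # toy ragged "sequences"

# torch.stack creates a new batch dimension, so every input must have the same shape
try:
    torch.stack(seqs)
    stack_failed = False
except RuntimeError as err:
    stack_failed = True
    print(f"torch.stack failed: {err}")

# torch.cat joins along an existing dimension, so lengths may differ
packed = torch.cat(seqs)
print(packed.shape)  # torch.Size([10])
```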

Padding and its Inefficiencies

The traditional way to address this challenge is to pad the input sequences to a fixed length and then perform stacking. This solution requires appropriate masking within the model so that the output is not affected by the irrelevant tensor elements. In the case of attention layers, a padding mask indicates which tokens are padding and should not be attended to (e.g., see PyTorch MultiheadAttention). However, padding can waste considerable GPU resources, increasing costs and slowing development. This is especially true for large-scale AI models.
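As an illustration of such masking, the sketch below (toy sizes, random tensors) builds a boolean padding mask in the (batch, 1, 1, seq_len) form used later in our decoder and passes it to PyTorch SDPA, where True marks positions that may be attended to. It then checks that the outputs at the real positions of the padded sequence match attention computed over the unpadded sequence alone. Note that outputs at padded query positions are still computed (and wasted), which is why the loss must also ignore them.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_heads, head_dim, max_len = 2, 8, 6
lengths = torch.tensor([6, 4])  # the second sequence ends with 2 padding positions

# random q/k/v in SDPA's (batch, heads, seq_len, head_dim) layout
q, k, v = (torch.randn(len(lengths), num_heads, max_len, head_dim) for _ in range(3))

# boolean padding mask: True = real token that may be attended to, False = padding
valid = torch.arange(max_len)[None, :] < lengths[:, None]  # (batch, seq_len)
attn_mask = valid[:, None, None, :]  # (batch, 1, 1, seq_len), broadcast over heads/queries

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# real positions of the padded sequence should match attention over the trimmed sequence
n = int(lengths[1])
ref = F.scaled_dot_product_attention(q[1:, :, :n], k[1:, :, :n], v[1:, :, :n])
print(torch.allclose(out[1:, :, :n], ref, atol=1e-5))
```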

Don’t Pad, Concatenate

One way to avoid padding is to concatenate sequences along an existing dimension instead of stacking them along a new dimension. Contrary to torch.stack, torch.cat allows inputs whose sizes differ along the concatenation dimension. The output of concatenation is a single sequence whose length equals the sum of the lengths of the individual sequences. For this solution to work, our single sequence would need to be supplemented by an attention mask ensuring that each token attends only to other tokens from the same original sequence, a process sometimes referred to as document masking. Denoting the sum of the lengths of all of the individual sequences by N and adopting "big O" notation, the size of this mask would need to be O(N²), as would the compute complexity of a naive attention layer (which applies the mask only after calculating the attention scores), making this solution highly inefficient.
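To make the document mask concrete, here is a minimal sketch (with hypothetical sequence lengths) that builds it from per-token segment ids. The resulting boolean mask could be passed as attn_mask to torch.nn.functional.scaled_dot_product_attention, but it has N² entries even though only the sum of the squared sequence lengths of them are True.

```python
import torch

lengths = torch.tensor([3, 5, 2])  # hypothetical sequence lengths
N = int(lengths.sum())             # length of the concatenated sequence

# segment id of each token in the concatenated sequence: [0,0,0,1,1,1,1,1,2,2]
segment_ids = torch.repeat_interleave(torch.arange(len(lengths)), lengths)

# document mask: token i may attend to token j only if both come from the same sequence
doc_mask = segment_ids[:, None] == segment_ids[None, :]  # (N, N) boolean
# for causal models, additionally restrict each token to earlier positions
causal_doc_mask = doc_mask & torch.ones(N, N, dtype=torch.bool).tril()

print(doc_mask.numel())            # 100 = N**2 entries stored (and scored by naive attention)
print(int(doc_mask.sum()))         # 38 = 3**2 + 5**2 + 2**2 entries that matter
print(int(causal_doc_mask.sum()))  # 24 with causal masking
```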

Attention Layer Optimization

The solution to this problem comes in the form of specialized attention layers. Contrary to the standard attention layer, which computes the full set of O(N²) attention scores only to mask out the irrelevant ones, these optimized attention kernels are designed to calculate only the scores that matter. In this post we will explore several solutions, each with their own distinct characteristics. These include:

  • PyTorch SDPA with NestedTensors
  • FlashAttention2 (flash_attn_varlen_func)
  • xFormers memory_efficient_attention with BlockDiagonalMask

Integration into Existing HuggingFace Models

For teams working with pre-trained models, transitioning to these optimizations might seem challenging. We will demonstrate how HuggingFace’s APIs simplify this process, enabling developers to integrate these techniques with minimal code changes and effort.

Disclaimers

  • Please do not interpret our use of any platforms, libraries, or optimization techniques as an endorsement for their use. The best options for you will depend greatly on the specifics of your own use-case.
  • Some of the APIs discussed here are in prototype or beta stages and may change in the future.
  • The code examples provided are for demonstrative purposes only. We make no claims regarding their accuracy, optimality, or robustness.

Special thanks to Yitzhak Levi and Peleg Nahaliel for their contributions to this post.

Toy LLM Model

To facilitate our discussion we will define a simple generative model (partially inspired by the GPT model defined [here](https://www.youtube.com/watch?v=kCc8FmEb1nY)). For a more comprehensive guide on building language models, please see one of the many excellent tutorials available online (e.g., here).

Transformer Block

We begin by constructing a basic Transformer block, specifically designed to facilitate experimentation with different attention mechanisms and optimizations. While our block performs the same computation as standard Transformer blocks, we make slight modifications to the usual choice of operators in order to support the possibility of PyTorch NestedTensor inputs (as described here).

# general imports
import time, functools

# torch imports
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

# Define Transformer settings
BATCH_SIZE = 32
NUM_HEADS = 16
HEAD_DIM = 64
DIM = NUM_HEADS * HEAD_DIM
DEPTH = 24
NUM_TOKENS = 1024
MAX_SEQ_LEN = 1024
PAD_ID = 0
DEVICE = 'cuda'

class MyAttentionBlock(nn.Module):
    def __init__(
            self,
            attn_fn,
            dim,
            num_heads,
            format=None,
            **kwargs
    ):
        super().__init__()
        self.attn_fn = attn_fn
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.norm2 = nn.LayerNorm(dim, bias=False)
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        # mlp layers
        self.fc1 = nn.Linear(dim, dim * 4)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim * 4, dim)

        self.permute = functools.partial(torch.transpose, dim0=1, dim1=2)
        if format == 'bshd':
            self.permute = nn.Identity()

    def mlp(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

    def reshape_and_permute(self,x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return self.permute(x)

    def forward(self, x_in, attn_mask=None):
        batch_size = x_in.size(0)
        x = self.norm1(x_in)
        qkv = self.qkv(x)

        # rather than first reformatting and then splitting the input
        # state, we first split and then reformat q, k, v in order to
        # support PyTorch Nested Tensors
        q, k, v = qkv.chunk(3, -1)
        q = self.reshape_and_permute(q, batch_size)
        k = self.reshape_and_permute(k, batch_size)
        v = self.reshape_and_permute(v, batch_size)

        # call the attn_fn with the input attn_mask
        x = self.attn_fn(q, k, v, attn_mask=attn_mask)

        # reformat output
        x = self.permute(x).reshape(batch_size, -1, self.dim)
        x = self.proj(x)
        x = x + x_in
        x = x + self.mlp(self.norm2(x))
        return x

Transformer Decoder Model

Building on our programmable Transformer block, we construct a typical Transformer decoder model.

class MyDecoder(nn.Module):
    def __init__(
            self,
            block_fn,
            num_tokens,
            dim,
            num_heads,
            num_layers,
            max_seq_len,
            pad_idx=None
    ):
        super().__init__()
        self.num_heads = num_heads
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(num_tokens, dim, padding_idx=pad_idx)
        self.positional_embedding = nn.Embedding(max_seq_len, dim)
        self.blocks = nn.ModuleList([
            block_fn(
                dim=dim,
                num_heads=num_heads
            )
            for _ in range(num_layers)])
        self.output = nn.Linear(dim, num_tokens)

    def embed_tokens(self, input_ids, position_ids=None):
        x = self.embedding(input_ids)
        if position_ids is None:
            position_ids = torch.arange(input_ids.shape[1],
                                        device=x.device)
        x = x + self.positional_embedding(position_ids)
        return x

    def forward(self, input_ids, position_ids=None, attn_mask=None):
        # Embed tokens and add positional encoding
        x = self.embed_tokens(input_ids, position_ids)
        if self.pad_idx is not None:
            assert attn_mask is None
            # create a padding mask - we assume boolean masking
            attn_mask = (input_ids != self.pad_idx)
            attn_mask = attn_mask.view(input_ids.size(0), 1, 1, -1) \
                .expand(-1, self.num_heads, -1, -1)

        for b in self.blocks:
            x = b(x, attn_mask)

        logits = self.output(x)
        return logits

Variable Length Sequence Input

Next, we create a dataset containing sequences of variable lengths, where each sequence is made up of randomly generated tokens. For simplicity, we (arbitrarily) select a fixed distribution for the sequence lengths. In real-world scenarios, the distribution of sequence lengths typically reflects the nature of the data, such as the length of documents or audio segments. Note that the distribution of lengths directly affects the computational inefficiencies caused by padding.
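To quantify this for our dataset, the short calculation below (a back-of-the-envelope estimate of our own, not a measurement) computes the expected fraction of useful work when every sequence is padded to MAX_SEQ_LEN, given lengths drawn uniformly from [1, MAX_SEQ_LEN) as in FakeDataset. Only about half of the token positions carry data, and because attention cost grows quadratically with length, only about a third of the attention scores do; layers whose cost scales linearly with length are closer to the 50% figure.

```python
import torch

MAX_SEQ_LEN = 1024
# FakeDataset draws lengths uniformly from [1, MAX_SEQ_LEN) via torch.randint
lengths = torch.arange(1, MAX_SEQ_LEN, dtype=torch.float64)

mean_len = lengths.mean()              # expected real tokens per sequence
token_util = mean_len / MAX_SEQ_LEN    # share of padded token positions that are real
# attention cost grows with the square of the sequence length
attn_util = (lengths ** 2).mean() / MAX_SEQ_LEN ** 2

print(f"mean length: {mean_len.item():.1f}")                  # 512.0
print(f"useful token fraction: {token_util.item():.3f}")      # 0.500
print(f"useful attention fraction: {attn_util.item():.3f}")   # 0.333
```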

# Use random data
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        length = torch.randint(1, MAX_SEQ_LEN, (1,))
        sequence = torch.randint(1, NUM_TOKENS, (length + 1,))
        inputs = sequence[:-1]
        targets = sequence[1:]
        return inputs, targets

def pad_sequence(sequence, length, pad_val):
    return torch.nn.functional.pad(
        sequence,
        (0, length - sequence.shape[0]),
        value=pad_val
    )

def collate_with_padding(batch):
    padded_inputs = []
    padded_targets = []
    for b in batch:
        padded_inputs.append(pad_sequence(b[0], MAX_SEQ_LEN, PAD_ID))
        padded_targets.append(pad_sequence(b[1], MAX_SEQ_LEN, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_targets = torch.stack(padded_targets, dim=0)
    return {
        'inputs': padded_inputs,
        'targets': padded_targets
    }

def data_to_device(data, device):
    if isinstance(data, dict):
        return {
            key: data_to_device(val,device)
            for key, val in data.items()
        }
    elif isinstance(data, (list, tuple)):
        return type(data)(
            data_to_device(val, device) for val in data
        )
    elif isinstance(data, torch.Tensor):
        return data.to(device=device, non_blocking=True)
    else:
        return data.to(device=device)

Training/Evaluation Loop

Lastly, we implement a main function that performs training/evaluation on input sequences of varying length.

def main(
    block_fn, 
    data_collate_fn=collate_with_padding,
    pad_idx=None,
    train=True,
    compile=False
):
    torch.random.manual_seed(0)
    device = torch.device(DEVICE)
    torch.set_float32_matmul_precision("high")

    # Create dataset and dataloader
    data_set = FakeDataset()
    data_loader = DataLoader(
        data_set,
        batch_size=BATCH_SIZE,
        collate_fn=data_collate_fn,
        num_workers=12,
        pin_memory=True,
        drop_last=True
    )

    model = MyDecoder(
        block_fn=block_fn,
        num_tokens=NUM_TOKENS,
        dim=DIM,
        num_heads=NUM_HEADS,
        num_layers=DEPTH,
        max_seq_len=MAX_SEQ_LEN,
        pad_idx=pad_idx
    ).to(device)

    if compile:
        model = torch.compile(model)

    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)
    optimizer = torch.optim.SGD(model.parameters())

    def train_step(model, inputs, targets, 
                   position_ids=None, attn_mask=None):
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(inputs, position_ids, attn_mask)
            outputs = outputs.view(-1, NUM_TOKENS)
            targets = targets.flatten()
            loss = criterion(outputs, targets)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    @torch.no_grad()
    def eval_step(model, inputs, targets, 
                  position_ids=None, attn_mask=None):
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(inputs, position_ids, attn_mask)
            if outputs.is_nested:
                outputs = outputs.data._values
                targets = targets.data._values
            else:
                outputs = outputs.view(-1, NUM_TOKENS)
                targets = targets.flatten()
            loss = criterion(outputs, targets)
        return loss

    if train:
        model.train()
        step_fn = train_step
    else:
        model.eval()
        step_fn = eval_step

    t0 = time.perf_counter()
    summ = 0
    count = 0

    for step, data in enumerate(data_loader):
        # Copy data to GPU
        data = data_to_device(data, device=device)
        step_fn(model, data['inputs'], data['targets'],
                       position_ids=data.get('indices'),
                       attn_mask=data.get('attn_mask'))

        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step >= 100:
            break
    print(f'average step time: {summ / count}')

PyTorch SDPA with Padding

For our baseline experiments, we configure our Transformer block to utilize PyTorch’s SDPA mechanism. In our experiments, we run both training and evaluation, both with and without torch.compile. These were run on an NVIDIA H100 with CUDA 12.4 and PyTorch 2.5.1.

from torch.nn.functional import scaled_dot_product_attention as sdpa
block_fn = functools.partial(MyAttentionBlock, attn_fn=sdpa)
causal_block_fn = functools.partial(
    MyAttentionBlock,
    attn_fn=functools.partial(sdpa, is_causal=True)
)

for mode in ['eval', 'train']:
    for compile in [False, True]:
        # causal attention for training, full attention for evaluation
        block_func = causal_block_fn \
            if mode == 'train' else block_fn
        print(f'{mode} with collate_with_padding, '
              f'{"compiled" if compile else "uncompiled"}')
        main(block_fn=block_func,
             pad_idx=PAD_ID,
             train=mode=='train',
             compile=compile)

Performance Results:

  • Evaluation: 132 milliseconds (ms) without torch.compile, 130 ms with torch.compile
  • Training: 342 ms without torch.compile, 299 ms with torch.compile

Optimizing for Variable Length Input

In this section, we will explore several optimization techniques for handling variable-length input sequences in Transformer models.

Padding Optimization

Our first optimization relates not to the attention kernel but to our padding mechanism. Rather than padding the sequences in each batch to a constant length, we pad to the length of the longest sequence in the batch. The following block of code consists of our revised collation function and updated experiments.

def collate_pad_to_longest(batch):
    padded_inputs = []
    padded_targets = []
    max_length = max([b[0].shape[0] for b in batch])
    for b in batch:
        padded_inputs.append(pad_sequence(b[0], max_length, PAD_ID))
        padded_targets.append(pad_sequence(b[1], max_length, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_targets = torch.stack(padded_targets, dim=0)
    return {
        'inputs': padded_inputs,
        'targets': padded_targets
    }

for mode in ['eval', 'train']:
    for compile in [False, True]:
        # causal attention for training, full attention for evaluation
        block_func = causal_block_fn \
            if mode == 'train' else block_fn
        print(f'{mode} with collate_pad_to_longest, '
              f'{"compiled" if compile else "uncompiled"}')
        main(block_fn=block_func,
             data_collate_fn=collate_pad_to_longest,
             pad_idx=PAD_ID,
             train=mode=='train',
             compile=compile)

Padding to the longest sequence in each batch results in a slight performance acceleration:
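The modest size of this gain can be anticipated with a quick estimate (again an analytical sketch of our own, not a measurement). With 32 lengths drawn uniformly from {1, ..., 1023}, the longest sequence in a batch is expected to be about 992 tokens, so padding to the longest sequence shortens the padded batch by only about 3% and roughly half of the token positions remain padding.

```python
import torch

MAX_SEQ_LEN, BATCH_SIZE = 1024, 32
M = MAX_SEQ_LEN - 1  # FakeDataset lengths are uniform on {1, ..., M}

# E[max of BATCH_SIZE lengths] = sum_{k=1..M} P(max >= k)
#                             = sum_{k=1..M} (1 - ((k - 1) / M) ** BATCH_SIZE)
k = torch.arange(1, M + 1, dtype=torch.float64)
expected_max = (1 - ((k - 1) / M) ** BATCH_SIZE).sum()

mean_len = (M + 1) / 2  # expected real tokens per sequence (512)
print(f"expected padded length per batch: {expected_max.item():.1f}")
print(f"reduction vs. MAX_SEQ_LEN: {1 - expected_max.item() / MAX_SEQ_LEN:.1%}")
print(f"useful token fraction: {mean_len / expected_max.item():.3f}")
```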

  • Evaluation: 129 ms without torch.compile, 116 ms with torch.compile
  • Training: 337 ms without torch.compile, 294 ms with torch.compile

SDPA with PyTorch NestedTensors

Next, we take advantage of the built-in support for [PyTorch NestedTensors](https://pytorch.org/tutorials/prototype/nestedtensor.html) in SDPA in evaluation mode. Currently a prototype feature, PyTorch NestedTensors allow for grouping together tensors of varying length. These are sometimes referred to as jagged or ragged tensors. In the code block below, we define a collation function for grouping our sequences into NestedTensors. We also define an indices entry so that we can properly calculate the positional embeddings.

PyTorch NestedTensors are supported by a limited number of PyTorch ops. Working around these limitations can require some creativity. For example, addition between NestedTensors is only supported when they share precisely the same "jagged" shape. In the code below we use a workaround to ensure that the indices entry shares the same shape as the model inputs.

def nested_tensor_collate(batch):
    inputs = torch.nested.as_nested_tensor([b[0] for b in batch],
                                           layout=torch.jagged)
    targets = torch.nested.as_nested_tensor([b[1] for b in batch],
                                            layout=torch.jagged)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])

    # workaround for creating a NestedTensor with identical "jagged" shape
    xx = torch.empty_like(inputs)
    xx.data._values[:] = indices

    return {
        'inputs': inputs,
        'targets': targets,
        'indices': xx
    }

for compile in [False, True]:
    print(f'eval with nested tensors, '
          f'{"compiled" if compile else "uncompiled"}')
    main(
        block_fn=block_fn,
        data_collate_fn=nested_tensor_collate,
        train=False,
        compile=compile
    )

Without torch.compile, the NestedTensor optimization results in a step time of 131 ms, similar to our baseline result. In compiled mode, however, the step time drops to 42 ms, an impressive ~3x improvement.

FlashAttention2

In our previous post we demonstrated the use of FlashAttention and its impact on the performance of a transformer model. In this post we demonstrate the use of [flash_attn_varlen_func](https://github.com/Dao-AILab/flash-attention/blob/v2.7.0/hopper/flash_attn_interface.py#L429) from flash-attn (2.7.0), an API designed for use with variable-sized inputs. To use this function, we concatenate all of the sequences in the batch into a single sequence. We also create a cu_seqlens tensor holding the cumulative offsets that mark where each of the individual sequences starts and ends within the concatenated tensor. The code block below includes our collation function followed by evaluation and training experiments. Note that flash_attn_varlen_func does not support torch.compile (at the time of this writing).

def collate_concat(batch):
    inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)
    targets = torch.concat([b[1] for b in batch]).unsqueeze(0)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])
    seqlens = torch.tensor([b[0].shape[0] for b in batch])
    seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
    cu_seqlens = torch.nn.functional.pad(seqlens, (1, 0))

    return {
        'inputs': inputs,
        'targets': targets,
        'indices': indices,
        'attn_mask': cu_seqlens
    }

from flash_attn import flash_attn_varlen_func
fa_varlen = lambda q, k, v, attn_mask: flash_attn_varlen_func(
    q.squeeze(0),
    k.squeeze(0),
    v.squeeze(0),
    cu_seqlens_q=attn_mask,
    cu_seqlens_k=attn_mask,
    max_seqlen_q=MAX_SEQ_LEN,
    max_seqlen_k=MAX_SEQ_LEN
).unsqueeze(0)

fa_varlen_causal = lambda q, k, v, attn_mask: flash_attn_varlen_func(
    q.squeeze(0),
    k.squeeze(0),
    v.squeeze(0),
    cu_seqlens_q=attn_mask,
    cu_seqlens_k=attn_mask,
    max_seqlen_q=MAX_SEQ_LEN,
    max_seqlen_k=MAX_SEQ_LEN,
    causal=True
).unsqueeze(0)

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=fa_varlen,
                             format='bshd')

causal_block_fn = functools.partial(MyAttentionBlock,
                                    attn_fn=fa_varlen_causal,
                                    format='bshd')

print('flash-attn eval')
main(
    block_fn=block_fn,
    data_collate_fn=collate_concat,
    train=False
)

print('flash-attn train')
main(
    block_fn=causal_block_fn,
    data_collate_fn=collate_concat,
    train=True,
)

The impact of this optimization is dramatic: 51 ms for evaluation and 160 ms for training, amounting to 2.6x and 2.1x performance boosts, respectively, compared to our baseline experiment.
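Given how different these kernels are from the baseline, it can be worth validating their numerics. The following sketch defines varlen_attention_reference, a hypothetical, deliberately slow helper of our own that consumes the same packed (total_tokens, num_heads, head_dim) layout and cu_seqlens offsets as flash_attn_varlen_func and runs PyTorch SDPA separately on each sequence. On a GPU, its output could be compared against fa_varlen (after casting both to a common dtype, with a tolerance suited to bfloat16). Here, so that it runs anywhere, we only check it against a dense SDPA call with an explicit block-diagonal mask.

```python
import torch
import torch.nn.functional as F

def varlen_attention_reference(q, k, v, cu_seqlens, causal=False):
    """Slow reference for packed variable-length attention (debugging only).

    q, k, v: (total_tokens, num_heads, head_dim), packed as for flash_attn_varlen_func
    cu_seqlens: (batch + 1,) cumulative sequence offsets, starting at 0
    """
    out = torch.empty_like(q)
    bounds = cu_seqlens.tolist()  # host copy (sync) is acceptable for a reference
    for start, end in zip(bounds[:-1], bounds[1:]):
        # SDPA expects (..., seq_len, head_dim): move the head dim to the front
        qs, ks, vs = (t[start:end].transpose(0, 1) for t in (q, k, v))
        o = F.scaled_dot_product_attention(qs, ks, vs, is_causal=causal)
        out[start:end] = o.transpose(0, 1)
    return out

# toy check against dense attention with a block-diagonal (document) mask
torch.manual_seed(0)
lengths = torch.tensor([3, 5, 2])
cu_seqlens = F.pad(torch.cumsum(lengths, dim=0), (1, 0))  # tensor([0, 3, 8, 10])
total, num_heads, head_dim = int(lengths.sum()), 2, 8
q, k, v = (torch.randn(total, num_heads, head_dim) for _ in range(3))

segment_ids = torch.repeat_interleave(torch.arange(len(lengths)), lengths)
block_mask = segment_ids[:, None] == segment_ids[None, :]
causal_mask = block_mask & torch.ones(total, total, dtype=torch.bool).tril()

def dense_attention(mask):
    qh, kh, vh = (t.transpose(0, 1) for t in (q, k, v))  # (heads, tokens, dim)
    return F.scaled_dot_product_attention(qh, kh, vh, attn_mask=mask).transpose(0, 1)

print(torch.allclose(varlen_attention_reference(q, k, v, cu_seqlens),
                     dense_attention(block_mask), atol=1e-5))
print(torch.allclose(varlen_attention_reference(q, k, v, cu_seqlens, causal=True),
                     dense_attention(causal_mask), atol=1e-5))
```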

XFormers Memory Efficient Attention

In our previous post we demonstrated the use of the memory_efficient_attention operator from xFormers (0.0.28). Here we demonstrate the use of BlockDiagonalMask, which is specifically designed for input sequences of arbitrary length. The required collation function appears in the code block below, followed by the evaluation and training experiments. Note that torch.compile failed in training mode.

from xformers.ops import fmha
from xformers.ops import memory_efficient_attention as mea

def collate_xformer(batch):
    inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)
    targets = torch.concat([b[1] for b in batch]).unsqueeze(0)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])
    seqlens = [b[0].shape[0] for b in batch]
    batch_sizes = [1 for b in batch]
    block_diag = fmha.BlockDiagonalMask.from_seqlens(seqlens, device='cpu')
    block_diag._batch_sizes = batch_sizes

    return {
        'inputs': inputs,
        'targets': targets,
        'indices': indices,
        'attn_mask': block_diag
    }

mea_eval = lambda q, k, v, attn_mask: mea(
    q,k,v, attn_bias=attn_mask)

mea_train = lambda q, k, v, attn_mask: mea(
    q,k,v, attn_bias=attn_mask.make_causal())

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=mea_eval,
                             format='bshd')

causal_block_fn = functools.partial(MyAttentionBlock,
                                    attn_fn=mea_train,
                                    format='bshd')

print(f'xFormer Attention ')
for compile in [False, True]:
    print(f'eval with xFormer Attention, '
          f'{"compiled" if compile else "uncompiled"}')
    main(block_fn=block_fn,
         train=False,
         data_collate_fn=collate_xformer,
         compile=compile)

print(f'train with xFormer Attention')
main(block_fn=causal_block_fn,
     train=True,
     data_collate_fn=collate_xformer)

The resultant step times were 50 ms and 159 ms for evaluation and training, respectively, without torch.compile. Evaluation with torch.compile resulted in a step time of 42 ms.
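The plain-Python sketch below makes the role of BlockDiagonalMask concrete by materializing the dense boolean mask it stands for on a toy packed batch. A query may attend to a key only if both belong to the same sequence. The make_causal() variant used in mea_train further limits each query to keys at or before its own position. dense_block_diag_mask is a hypothetical reference helper. The point of BlockDiagonalMask is that the kernel can work from the sequence boundaries instead of from a dense mask like this one.

```python
def dense_block_diag_mask(seqlens, causal=False):
    """Hypothetical reference: dense boolean mask equivalent to a
    block-diagonal attention bias over packed sequences.
    mask[q][k] is True when query q may attend to key k."""
    # Id of the sequence each packed token belongs to
    seq_id = [i for i, n in enumerate(seqlens) for _ in range(n)]
    # Position of each token within its own sequence
    pos = [p for n in seqlens for p in range(n)]
    total = len(seq_id)
    mask = [[False] * total for _ in range(total)]
    for q in range(total):
        for k in range(total):
            same_seq = seq_id[q] == seq_id[k]
            # Causal variant: keys must not lie after the query in its sequence
            mask[q][k] = same_seq and (not causal or pos[k] <= pos[q])
    return mask


# Two packed sequences of lengths 2 and 3, causal ('x' = may attend)
for row in dense_block_diag_mask([2, 3], causal=True):
    print(''.join('x' if allowed else '.' for allowed in row))
```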

Results

The table below summarizes the results of our optimization methods.

Step time results for different optimization methods (lower is better) - by Author

The best performer for our toy model is xFormers' memory_efficient_attention, which delivered a ~3x speedup for evaluation and a ~2x speedup for training. We caution against deriving any conclusions from these results, as the performance impact of different attention functions can vary significantly depending on the specific model and use case.

Optimizing a HuggingFace Model for Variable-Length Input

The tools and techniques described above are easy to implement when creating a model from scratch. However, these days it is not uncommon for ML developers to adopt existing (pretrained) models and finetune them for their use case. While the optimizations we have described can be integrated without changing the set of model weights and without altering the model behavior, it is not entirely clear what the best way to do this is. In an ideal world, our ML framework would allow us to program the use of an attention mechanism that is optimized for variable-length inputs. In this section, we demonstrate how to optimize HuggingFace models for variable-length inputs.

A Toy HuggingFace Model – GPT2LMHeadModel

To facilitate the discussion, we create a toy example in which we train a HuggingFace GPT2LMHeadModel on variable-length sequences. This requires adapting our random dataset and data-padding collation function according to HuggingFace’s input specifications.

from transformers import GPT2Config, GPT2LMHeadModel

# Use random data
class HuggingFaceFakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        length = torch.randint(1, MAX_SEQ_LEN, (1,))
        input_ids = torch.randint(1, NUM_TOKENS, (length,))
        labels = input_ids.clone()
        labels[0] = PAD_ID # ignore first token
        return {
            'input_ids': input_ids,
            'labels': labels
        }

def hf_collate_with_padding(batch):
    padded_inputs = []
    padded_labels = []
    for b in batch:
        input_ids = b['input_ids']
        labels = b['labels']
        padded_inputs.append(pad_sequence(input_ids, MAX_SEQ_LEN, PAD_ID))
        padded_labels.append(pad_sequence(labels, MAX_SEQ_LEN, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_labels = torch.stack(padded_labels, dim=0)
    return {
        'input_ids': padded_inputs,
        'labels': padded_labels,
        'attention_mask': (padded_inputs != PAD_ID)
    }

Training Function

Our training function instantiates a GPT2LMHeadModel based on the requested GPT2Config and proceeds to train it on our variable-length sequences.

def hf_main(
    config,
    collate_fn=hf_collate_with_padding,
    compile=False
):
    torch.random.manual_seed(0)
    device = torch.device(DEVICE)
    torch.set_float32_matmul_precision("high")

    # Create dataset and dataloader
    data_set = HuggingFaceFakeDataset()
    data_loader = DataLoader(
        data_set,
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        num_workers=12 if DEVICE == "cuda" else 0,
        pin_memory=True,
        drop_last=True
    )

    model = GPT2LMHeadModel(config).to(device)

    if compile:
        model = torch.compile(model)

    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)
    optimizer = torch.optim.SGD(model.parameters())

    model.train()

    t0 = time.perf_counter()
    summ = 0
    count = 0

    for step, data in enumerate(data_loader):
        # Copy data to GPU
        data = data_to_device(data, device=device)
        input_ids = data['input_ids']
        labels = data['labels']
        position_ids = data.get('position_ids')
        attn_mask = data.get('attention_mask')
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(input_ids=input_ids,
                            position_ids=position_ids,
                            attention_mask=attn_mask)
            logits = outputs.logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
            loss = criterion(logits.view(-1, NUM_TOKENS), labels.flatten())

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step >= 100:
            break
    print(f'average step time: {summ / count}')
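The loss in hf_main shifts logits and labels by one position, so the output at position t is scored against token t + 1. CrossEntropyLoss(ignore_index=PAD_ID) then drops padded targets from the average. The CPU-only check below uses made-up sizes and a hypothetical PAD_ID of 0 (real tokens are drawn from [1, NUM_TOKENS)). It confirms that the masked loss equals the mean over only the non-padded targets.

This is presumably also why HuggingFaceFakeDataset sets labels[0] to PAD_ID. When sequences are concatenated, the first token of each sequence would otherwise be scored as a prediction from the last token of the previous one.

```python
import torch

torch.manual_seed(0)
PAD_ID = 0   # hypothetical pad id
VOCAB = 11   # toy vocabulary size

# One toy sequence of 6 positions whose last two positions are padding
labels = torch.tensor([[PAD_ID, 4, 7, 2, PAD_ID, PAD_ID]])
logits = torch.randn(1, 6, VOCAB)

# Same shift as in hf_main: the prediction at t is scored against token t + 1
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)
loss = criterion(shift_logits.view(-1, VOCAB), shift_labels.flatten())

# Manual reference: average the per-token loss over non-pad targets only
per_token = torch.nn.functional.cross_entropy(
    shift_logits.view(-1, VOCAB), shift_labels.flatten(), reduction='none')
keep = shift_labels.flatten() != PAD_ID
manual = per_token[keep].mean()

print(loss.item(), manual.item())  # the two values match
```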

SDPA with Padding

In the code block below, we call our training function with the default sequence-padding collator.

config = GPT2Config(
        n_layer=DEPTH,
        n_embd=DIM,
        n_head=NUM_HEADS,
        vocab_size=NUM_TOKENS,
    )

for compile in [False, True]:
    print(f"HF GPT2 train with SDPA, compile={compile}")
    hf_main(config=config, compile=compile)

The resultant step times are 815 ms without torch.compile and 440 ms with torch.compile.
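A quick estimate shows how much of this padded computation is wasted. Assume, as in HuggingFaceFakeDataset, that sequence lengths are drawn uniformly from 1 to MAX_SEQ_LEN - 1 and that every sequence is padded to MAX_SEQ_LEN. The sketch below computes the expected fraction of useful work in two cases: per-token operations, where work scales linearly with length, and attention scores, where work scales quadratically. MAX_SEQ_LEN = 256 is a hypothetical value, since the actual setting is defined earlier in the post.

```python
from fractions import Fraction


def useful_fraction(max_len, power):
    """Expected useful work / padded work when lengths are uniform on
    [1, max_len - 1] and work scales as length ** power."""
    lengths = range(1, max_len)
    mean_work = Fraction(sum(n ** power for n in lengths), len(lengths))
    return mean_work / max_len ** power


MAX_SEQ_LEN = 256  # hypothetical; the real value is defined earlier in the post
print(float(useful_fraction(MAX_SEQ_LEN, 1)))  # 0.5: half the token work is padding
print(float(useful_fraction(MAX_SEQ_LEN, 2)))  # about 1/3 for attention scores
```

Under these assumptions, the linear fraction is exactly 1/2 for any MAX_SEQ_LEN. The quadratic fraction is (2M - 1) / (6M), which approaches 1/3 as M grows.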

FlashAttention2

We now take advantage of HuggingFace’s built-in support for FlashAttention2 by setting the attn_implementation parameter to "flash_attention_2". Behind the scenes, HuggingFace will unpad the padded input data and then pass it to the optimized flash_attn_varlen_func function we saw above:

flash_config = GPT2Config(
        n_layer=DEPTH,
        n_embd=DIM,
        n_head=NUM_HEADS,
        vocab_size=NUM_TOKENS,
        attn_implementation='flash_attention_2'
    )

print(f"HF GPT2 train with flash")
hf_main(config=flash_config)

The resultant step time is 620 ms, amounting to a 30% boost (in uncompiled mode) with just a simple flick of a switch.
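Conceptually, the unpadding step gathers the non-padded token positions from the attention mask and computes the cumulative offsets that flash_attn_varlen_func expects. The plain-Python sketch below illustrates the idea on a tiny right-padded batch. It is not HuggingFace's internal implementation; the function name unpad and the pad id are hypothetical.

```python
from itertools import accumulate


def unpad(input_ids, attention_mask):
    """Hypothetical sketch of unpadding a right-padded batch.
    Returns the flat token list, cumulative sequence offsets and max length."""
    flat, lengths = [], []
    for ids, mask in zip(input_ids, attention_mask):
        # Keep only positions marked as real tokens (mask value True)
        kept = [tok for tok, m in zip(ids, mask) if m]
        flat.extend(kept)
        lengths.append(len(kept))
    cu_seqlens = [0] + list(accumulate(lengths))
    return flat, cu_seqlens, max(lengths)


PAD = 0  # hypothetical pad id
batch = [[5, 6, 7, PAD], [8, PAD, PAD, PAD], [9, 10, 11, 12]]
mask = [[tok != PAD for tok in row] for row in batch]
print(unpad(batch, mask))  # ([5, 6, 7, 8, 9, 10, 11, 12], [0, 3, 4, 8], 4)
```

As we understand HuggingFace's flash-attention integration, this unpadding happens inside each attention layer and the output is padded back afterwards. The remaining layers therefore still process the padded positions (4 of the 12 positions in this toy batch).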

FlashAttention2 with Unpadded Input

Of course, padding the sequences in the collation function only to have them unpadded hardly seems sensible. In a recent update to HuggingFace, support was added for passing in concatenated (unpadded) sequences to a select number of models. Unfortunately (as of the time of this writing), our GPT2 model did not make the cut. However, adding support requires just five small line additions to modeling_gpt2.py in order to propagate the sequence position_ids to the flash-attention kernel. The full patch appears in the block below:

@@ -370,0 +371 @@
+        position_ids = None
@@ -444,0 +446 @@
+            position_ids=position_ids
@@ -611,0 +614 @@
+        position_ids=None
@@ -621,0 +625 @@
+            position_ids=position_ids
@@ -1140,0 +1145 @@
+                    position_ids=position_ids

We define a collate function that concatenates our sequences and train our HuggingFace model on the unpadded sequences. (Also see the built-in DataCollatorWithFlattening utility.)

def collate_flatten(batch):
    input_ids = torch.concat([b['input_ids'] for b in batch]).unsqueeze(0)
    labels = torch.concat([b['labels'] for b in batch]).unsqueeze(0)
    position_ids = [torch.arange(b['input_ids'].shape[0]) for b in batch]
    position_ids = torch.concat(position_ids)

    return {
        'input_ids': input_ids,
        'labels': labels,
        'position_ids': position_ids
    }

print(f"HF GPT2 train with flash, no padding")
hf_main(config=flash_config, collate_fn=collate_flatten)

The resulting step time is 323 ms, 90% faster than running flash-attention on the padded input.
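With concatenated input, the only signal separating the sequences is position_ids. As produced by collate_flatten, these restart at 0 at the beginning of every sequence. The sketch below recovers cumulative offsets from such position ids. It mirrors the idea behind HuggingFace's position_ids-based flash-attention path, but cu_seqlens_from_position_ids is a hypothetical plain-Python helper, not the library code.

```python
def cu_seqlens_from_position_ids(position_ids):
    """Hypothetical helper: sequence offsets from position ids that
    restart at 0 at the start of each packed sequence."""
    # Every index whose position id is 0 opens a new sequence
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    return starts + [len(position_ids)]


# Same layout as collate_flatten: one arange(len) per concatenated sequence
position_ids = list(range(3)) + list(range(5)) + list(range(2))
print(position_ids)                                # [0, 1, 2, 0, 1, 2, 3, 4, 0, 1]
print(cu_seqlens_from_position_ids(position_ids))  # [0, 3, 8, 10]
```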

Results

The results of our HuggingFace experiments are summarized below.

Step time results for different optimization methods (lower is better) - by Author

With little effort, we were able to boost our runtime performance by 2.5x when compared to the uncompiled baseline experiment, and by 36% when compared to the compiled version.

In this section, we demonstrated how the HuggingFace APIs allow us to leverage the optimized kernels in FlashAttention2, significantly boosting the training performance of existing models on sequences of varying length.

Summary

As AI models continue to grow in both popularity and complexity, optimizing their performance has become essential for reducing runtime and costs. This is especially true for compute-intensive components like attention layers. In this post, we have continued our exploration of attention layer optimization, and demonstrated new tools and techniques for enhancing Transformer model performance. For more insights on AI model optimization, be sure to check out the first post in this series as well as our many other posts on this topic.

The post Optimizing Transformer Models for Variable-Length Input Sequences appeared first on Towards Data Science.

Increasing Transformer Model Efficiency Through Attention Layer Optimization https://towardsdatascience.com/increasing-transformer-model-efficiency-through-attention-layer-optimization-fefa6f87b1d6/ Mon, 18 Nov 2024 20:16:54 +0000 https://towardsdatascience.com/increasing-transformer-model-efficiency-through-attention-layer-optimization-fefa6f87b1d6/ How paying "better" attention can drive ML cost savings

The post Increasing Transformer Model Efficiency Through Attention Layer Optimization appeared first on Towards Data Science.

Introduced in the landmark 2017 paper "Attention Is All You Need" (Vaswani et al., 2017), the Transformer architecture is widely regarded as one of the most influential scientific breakthroughs of the past decade. At the core of the Transformer is the Attention mechanism, a novel approach that enables AI models to comprehend complex structures by focusing on different parts of input sequences based on the task at hand. Originally demonstrated in the world of natural language processing, the success of the Transformer architecture has quickly spread to many other domains, including speech recognition, scene understanding, reinforcement learning, protein structure prediction, and more. However, attention layers are highly resource-intensive, and as these layers become the standard across increasingly large models, the costs associated with their training and deployment have surged. This has created an urgent need for strategies that reduce the computational cost of this core layer so as to increase the efficiency and scalability of Transformer-based AI models.

In this post, we will explore several tools for optimizing attention in PyTorch. Our focus will be on methods that maintain the accuracy of the attention layer. These will include PyTorch SDPA, FlashAttention, TransformerEngine Attention, FlexAttention, and xFormer attention. Other methods that reduce the computational cost via approximation of the attention calculation (e.g., DeepSpeed’s Sparse Attention, Longformer, Linformer, and more) will not be considered. Additionally, we will not discuss general optimization techniques that, while beneficial to attention performance, are not specific to the attention computation itself (e.g., FP8 training, model sharding, and more).

Importantly, attention optimization is an active area of research with new methods coming out on a pretty regular basis. Our goal is to increase your awareness of some of the existing solutions and provide you with a foundation for further exploration and experimentation. The code we will share below is intended for demonstrative purposes only – we make no claims regarding its accuracy, optimality, or robustness. Please do not interpret our mention of any platforms, libraries, or optimization techniques as an endorsement for their use. The best options for you will depend greatly on the specifics of your own use-case.

Many thanks to Yitzhak Levi for his contributions to this post.

Toy Model

To facilitate our discussion, we build a Vision Transformer (ViT)-backed classification model using the popular timm Python package (version 0.9.7). We will use this model to illustrate the performance impact of various attention kernels.

We start by defining a simplified Transformer block that allows for programming the attention function by passing it into its constructor. Since attention implementations assume specific input tensor formats, we also include an option for controlling the format, ensuring compatibility with the attention kernel of our choosing.

# general imports
import os, time, functools

# torch imports
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

# timm imports
from timm.models.vision_transformer import VisionTransformer
from timm.layers import Mlp

IMG_SIZE = 224
BATCH_SIZE = 128

# Define ViT settings
NUM_HEADS = 16
HEAD_DIM = 64
DEPTH = 24
PATCH_SIZE = 16
SEQ_LEN = (IMG_SIZE // PATCH_SIZE)**2 # 196

class MyAttentionBlock(nn.Module):
    def __init__(
            self,
            attn_fn,
            format = None,
            dim: int = 768,
            num_heads: int = 12,
            **kwargs
    ) -> None:
        super().__init__()
        self.attn_fn = attn_fn
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=dim * 4,
        )
        permute = (2, 0, 3, 1, 4)
        self.permute_attn = functools.partial(torch.transpose,dim0=1,dim1=2)

        if format == 'bshd':
            permute = (2, 0, 1, 3, 4)
            self.permute_attn = nn.Identity()
        self.permute_qkv = functools.partial(torch.permute,dims=permute)

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x_in)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # permute tensor based on the specified format
        qkv = self.permute_qkv(qkv)
        q, k, v = qkv.unbind(0)
        # use the attention function specified by the user
        x = self.attn_fn(q, k, v)
        # permute output according to the specified format
        x = self.permute_attn(x).reshape(B, N, C)
        x = self.proj(x)
        x = x + x_in
        x = x + self.mlp(self.norm2(x))
        return x
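Tracking tensor shapes is the easiest way to follow the two permutations in MyAttentionBlock. The qkv projection is reshaped to (B, N, 3, H, D). The default permutation yields q, k, v in the (batch, heads, sequence, head_dim) layout that PyTorch SDPA expects, while the 'bshd' permutation keeps the sequence dimension ahead of the heads. The sketch below applies each permutation to a shape tuple in plain Python, using the ViT sizes defined above. permute_shape is a hypothetical helper.

```python
def permute_shape(shape, dims):
    """Shape that torch.permute(x, dims) would produce for a tensor of `shape`."""
    return tuple(shape[d] for d in dims)


B, N, H, D = 128, 196, 16, 64   # BATCH_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM
qkv_shape = (B, N, 3, H, D)      # after self.qkv(x).reshape(...)

default = permute_shape(qkv_shape, (2, 0, 3, 1, 4))
bshd = permute_shape(qkv_shape, (2, 0, 1, 3, 4))

# unbind(0) splits the leading dimension of size 3 into q, k and v
print('default q/k/v:', default[1:])  # (128, 16, 196, 64) -> (B, H, N, D)
print('bshd q/k/v:   ', bshd[1:])     # (128, 196, 16, 64) -> (B, N, H, D)
```

In the default layout the attention output is (B, H, N, D), so permute_attn transposes dimensions 1 and 2 before the final reshape to (B, N, C). In the 'bshd' layout the output is already (B, N, H, D), and the identity suffices.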

We define a randomly generated dataset, which we will use to feed our model during training.

# Use random data
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        rand_image = torch.randn([3, IMG_SIZE, IMG_SIZE],
                                 dtype=torch.float32)
        label = torch.tensor(data=index % 1000, dtype=torch.int64)
        return rand_image, label 

Next, we define our ViT training function. While our example focuses on demonstrating a training workload, it is crucial to emphasize that optimizing the attention layer is equally, if not more, important during model inference.

The training function we define accepts the customized Transformer block and a flag that controls the use of torch.compile.

def train_fn(block_fn, compile):
    torch.random.manual_seed(0)
    device = torch.device("cuda:0")
    torch.set_float32_matmul_precision("high")

    # Create dataset and dataloader
    train_set = FakeDataset()
    train_loader = DataLoader(
        train_set, batch_size=BATCH_SIZE,
        num_workers=12, pin_memory=True, drop_last=True)

    model = VisionTransformer(
       img_size=IMG_SIZE,
       patch_size=PATCH_SIZE,
       embed_dim=NUM_HEADS*HEAD_DIM,
       depth=DEPTH,
       num_heads=NUM_HEADS,
       class_token=False,
       global_pool="avg",
       block_fn=block_fn
    ).to(device)

    if compile:
        model = torch.compile(model)

    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters())

    model.train()

    t0 = time.perf_counter()
    summ = 0
    count = 0
    for step, data in enumerate(train_loader):
        # Copy data to GPU
        inputs = data[0].to(device=device, non_blocking=True)
        label = data[1].to(device=device, non_blocking=True)
        with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
            outputs = model(inputs)
            loss = criterion(outputs, label)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step > 100:
            break
    print(f'average step time: {summ / count}')

# define compiled and uncompiled variants of our train function
train = functools.partial(train_fn, compile=False)
train_compile = functools.partial(train_fn, compile=True)

In the code block below, we define a PyTorch-native attention function and use it to train our ViT model:

def attn_fn(q, k, v):
    scale = HEAD_DIM ** -0.5
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    attn = attn.softmax(dim=-1)
    x = attn @ v
    return x

block_fn = functools.partial(MyAttentionBlock, attn_fn=attn_fn)

print('Default Attention')
train(block_fn)
print('Compiled Default Attention')
train_compile(block_fn)

We ran this on an NVIDIA H100 with CUDA 12.4 and PyTorch 2.5.1. The uncompiled variant resulted in an average step time of 370 milliseconds (ms), while the compiled variant improved to 242 ms. We will use these results as a baseline for comparison as we consider alternative solutions for performing the attention computation.

PyTorch SDPA

One of the easiest ways to boost the performance of our attention layers in PyTorch is to use the scaled_dot_product_attention (SDPA) function. Currently in beta, PyTorch SDPA consolidates multiple kernel-level optimizations and dynamically selects the most efficient one based on the input’s properties. Supported backends (as of now) include FlashAttention-2, Memory-Efficient Attention, a C++-based Math Attention, and cuDNN. These backends fuse together high-level operations while employing GPU-level optimizations for increasing compute efficiency and memory utilization.

SDPA is continuously evolving, with new and improved backend implementations being introduced regularly. Staying up to date with the latest PyTorch releases is key to leveraging the most recent performance improvements. For example, PyTorch 2.5 introduced an updated cuDNN backend featuring a specialized SDPA primitive specifically tailored for training on NVIDIA Hopper architecture GPUs.
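SDPA is meant as a drop-in replacement for our native attn_fn, so a quick numerical check on CPU can confirm that the two compute the same result before any benchmarking. The sketch below uses small random tensors in the default (batch, heads, sequence, head_dim) layout; the sizes and tolerance are illustrative choices. SDPA applies 1/sqrt(head_dim) scaling by default, which matches the HEAD_DIM ** -0.5 factor in attn_fn.

```python
import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa

torch.manual_seed(0)
B, H, N, D = 2, 4, 16, 8  # toy sizes (batch, heads, sequence, head_dim)
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))


def native_attn(q, k, v):
    # Same computation as attn_fn above, with the scale taken from the head dim
    q = q * q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
    return attn @ v


out_native = native_attn(q, k, v)
out_sdpa = sdpa(q, k, v)
print(torch.allclose(out_native, out_sdpa, atol=1e-5))
```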

In the code block below, we iterate through the list of supported backends and assess the runtime performance of training with each one. We use a helper function, set_sdpa_backend, for programming the SDPA backend:

from torch.nn.functional import scaled_dot_product_attention as sdpa

def set_sdpa_backend(backend):
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(False)
    torch.backends.cuda.enable_cudnn_sdp(False)

    if backend in ['flash_sdp','all']:
        torch.backends.cuda.enable_flash_sdp(True)
    if backend in ['mem_efficient_sdp','all']:
        torch.backends.cuda.enable_mem_efficient_sdp(True)
    if backend in ['math_sdp','all']:
        torch.backends.cuda.enable_math_sdp(True)
    if backend in ['cudnn_sdp','all']:
        torch.backends.cuda.enable_cudnn_sdp(True)

for backend in ['flash_sdp', 'mem_efficient_sdp',
                'math_sdp', 'cudnn_sdp']:
    set_sdpa_backend(backend)
    block_fn = functools.partial(MyAttentionBlock,
                                 attn_fn=sdpa)

    print(f'PyTorch SDPA - {backend}')
    train(block_fn)
    print(f'Compiled PyTorch SDPA - {backend}')
    train_compile(block_fn)

We summarize our interim results in the table below:

Step times for various attention functions (lower is better) - by Author

While the choice of SDPA backend has a noticeable impact on performance when running in eager mode, the optimizations performed by model compilation appear to overshadow the differences between the attention kernels. Once again, we caution against deriving any conclusions from these results as the performance impact of different attention functions can vary significantly depending on the specific model and use case.

Third-Party Attention Kernels

While PyTorch SDPA is a great place to start, using third-party attention kernels can help accelerate your ML workloads further. These alternatives often come with added flexibility, offering a wider range of configuration options for attention. Some may also include optimizations tailored for specific hardware accelerators or newer GPU architectures.

In this section, we will explore some of the third-party attention kernels available and evaluate their potential impact on runtime performance.

FlashAttention-3

While PyTorch SDPA supports a FlashAttention backend, more advanced FlashAttention implementations can be found in the flash-attn library. Here we will explore the FlashAttention-3 beta release, which boasts speedups of up to 2x over FlashAttention-2. Given the early stage in its development, FlashAttention-3 can only be installed directly from the GitHub repository, and its use is limited to certain head dimensions. Additionally, it does not yet support model compilation. In the following code block, we configure our transformer block to use flash-attn-3 while setting the attention input format to "bshd" (batch, sequence, head, depth) to meet the expectations of the library.

# flash attention 3
from flash_attn_interface import flash_attn_func as fa3
attn_fn = lambda q,k,v: fa3(q,k,v)[0]
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=attn_fn,
                             format='bshd')

print(f'Flash Attention 3')
train(block_fn)

The resultant step time was 240 ms, making it 5% faster than SDPA with the flash backend.

Transformer Engine

Transformer Engine (TE) is a specialized library designed to accelerate Transformer models on NVIDIA GPUs. TE is updated regularly with optimizations that leverage the capabilities of the latest NVIDIA hardware and software offerings, giving users access to specialized kernels long before they are integrated into general-purpose frameworks such as PyTorch.

In the code block below, we use DotProductAttention from TE version 1.11.0. Similar to PyTorch SDPA, TE supports a number of backends, which are controlled via environment variables. Here we demonstrate the use of the NVTE_FUSED_ATTN backend.

def set_te_backend(backend):
    # must be applied before first use of
    # transformer_engine.pytorch.attention
    os.environ["NVTE_FLASH_ATTN"] = '0'
    os.environ["NVTE_FUSED_ATTN"] = '0'
    os.environ["NVTE_UNFUSED_ATTN"] = '0'
    if backend == 'flash':
        os.environ["NVTE_FLASH_ATTN"] = '1'
    if backend == 'fused':
        os.environ["NVTE_FUSED_ATTN"] = '1'
    if backend == 'unfused':
        os.environ["NVTE_UNFUSED_ATTN"] = '1'

from transformer_engine.pytorch.attention import DotProductAttention
set_te_backend('fused')
attn_fn = DotProductAttention(NUM_HEADS, HEAD_DIM, NUM_HEADS,
                              qkv_format='bshd',
                              # disable masking (default is causal mask)
                              attn_mask_type='no_mask')

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=attn_fn,
                             format='bshd')

print(f'Transformer Engine Attention')
train(block_fn)
print(f'Compiled Transformer Engine Attention')
train_compile(block_fn)

TE attention resulted in average step times of 243 ms and 204 ms for the eager and compiled model variants, respectively.

XFormer Attention

Underlying the memory-efficient backend of PyTorch SDPA is an attention kernel provided by the xFormers library. Once again, we can go to the source to benefit from the latest kernel optimizations and from the full set of API capabilities. In the following code block we use the memory_efficient_attention operator from xFormers version 0.0.28.

# xformer memory efficient attention
from xformers.ops import memory_efficient_attention as mea
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=mea,
                             format='bshd')

print(f'xFormer Attention ')
train(block_fn)
print(f'Compiled xFormer Attention ')
train_compile(block_fn)

The eager model variant resulted in an average step time of 246 ms, making it 10.5% faster than the SDPA memory-efficient kernel. The compiled variant resulted in a step time of 203 ms.

Results

The table below summarizes our experiments:

Step times for various attention functions (lower is better) - by Author

The winner for the eager model was flash-attn-3, with an average step time that is 54% faster than our baseline model (240 ms vs. 370 ms). Since training cost is proportional to step time, this translates to a roughly 35% reduction in training costs. In compiled mode, the performance across the optimized kernels was more or less equal, with the fastest implementations achieving 202 ms, representing a 20% improvement compared to the baseline experiment.

As mentioned above, the precise savings depend greatly on the model definition. To assess this variability, we reran the experiments using modified settings that increased the attention sequence length to 3136 tokens.

IMG_SIZE = 224
BATCH_SIZE = 8

# Define ViT settings
NUM_HEADS = 12
HEAD_DIM = 64
DEPTH = 6
PATCH_SIZE = 4
SEQ_LEN = (IMG_SIZE // PATCH_SIZE)**2 # 3136

The results are summarized in the table below:

Results for large seqlen (lower is better) - by Author

Our immediate observation is that when the sequence length is greater, the performance impact of the attention kernels is far more pronounced. Once again, flash-attn-3 came out in front for the eager execution mode, this time with a ~5x increase in performance compared to the PyTorch-native function. For the compiled model, the TE kernel broke away from the pack with an overall best step time of 53 ms.
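A back-of-the-envelope calculation helps explain why the kernel choice matters more at long sequence lengths. A naive implementation such as attn_fn materializes a score matrix with SEQ_LEN × SEQ_LEN entries per head and per sample. Growing the sequence from 196 to 3136 tokens (16x) therefore grows each score matrix 256x. The sketch below computes the resulting size in bf16 (2 bytes per element) for the long-sequence settings. It counts a single score tensor per layer and ignores intermediates saved for the backward pass. Fused kernels such as FlashAttention process the scores in tiles and never store this tensor in full.

```python
def score_matrix_bytes(seq_len, num_heads, batch_size, bytes_per_elem=2):
    """Size of one fully materialized attention score tensor
    (batch, heads, seq_len, seq_len), e.g. in bf16."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem


short_len, long_len = 196, 3136
print((long_len ** 2) // (short_len ** 2))  # 256: 16x longer -> 256x more scores

# Long-sequence settings from above: BATCH_SIZE=8, NUM_HEADS=12
size = score_matrix_bytes(long_len, num_heads=12, batch_size=8)
print(size, f'{size / 2**30:.2f} GiB per attention layer')
```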

Customizing Attention with FlexAttention

Thus far, we’ve focused on the standard attention function. However, sometimes we may want to use a variant of the typical attention computation in which we either mask out some of the values of intermediate tensors or apply some operation on them. These types of changes may interfere with our ability to use the optimized attention blocks we covered above. In this section we discuss some of the ways to address this:

Leverage Advanced Kernel APIs: Many optimized attention kernels provide extensive APIs with controls for customizing the attention computation. Before implementing a new solution, explore these APIs to determine if they already support your required functionality.

Implement a custom kernel: If the existing APIs do not meet your needs, you could consider creating your own custom attention implementation. In previous posts (e.g., here) we discussed some of the pros and cons of custom kernel development. Achieving optimal performance can be extremely difficult. If you do go down this path, one approach might be to start with an existing (optimal) kernel and apply minimal changes to integrate the desired change.

Use FlexAttention: A recent addition to PyTorch, FlexAttention empowers users to implement a wide variety of attention variants without needing to compromise on performance. Denoting the result of the dot product of the query and key tokens by score, flex_attention allows for programming either a score_mod function or a block_mask that is automatically applied to the score tensor. See the documentation as well as the accompanying attention-gym repository for examples of the types of operations that the API enables.

[FlexAttention](https://pytorch.org/blog/flexattention/) works by compiling the score_mod operator into the attention operator, thereby creating a single fused kernel. It also leverages the sparsity of block masks to avoid unnecessary computations. The benchmarks reported in the FlexAttention documentation show considerable performance gains for a variety of use cases.
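Conceptually, flex_attention produces the same result as the unfused computation sketched below for a single head. Each dot-product score is passed through score_mod together with its (b, h, q_idx, kv_idx) indices, pairs rejected by mask_mod are dropped, and then the softmax is applied. The plain-Python reference is hypothetical and only meant to pin down the semantics, including the 1/sqrt(head_dim) scaling that, as we understand the default, flex_attention applies before score_mod. Here score_mod receives Python floats, whereas the real API passes tensors. The real kernel fuses these steps and skips fully masked blocks.

```python
import math


def reference_flex_attention(q, k, v, score_mod=None, mask_mod=None, b=0, h=0):
    """Hypothetical unfused reference for one head of one sample.
    q, k, v are lists of equal-length vectors (lists of floats)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_idx, qv in enumerate(q):
        scores = []
        for kv_idx, kv in enumerate(k):
            s = scale * sum(a * c for a, c in zip(qv, kv))
            if score_mod is not None:
                s = score_mod(s, b, h, q_idx, kv_idx)
            if mask_mod is not None and not mask_mod(b, h, q_idx, kv_idx):
                s = float('-inf')  # masked pairs get zero weight after softmax
            scores.append(s)
        m = max(scores)  # assumes at least one key is unmasked per query
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vv[d] for w, vv in zip(weights, v)) / total
                    for d in range(len(v[0]))])
    return out


# Toy check: with a causal mask_mod, query 0 can only see (and copy) value 0
causal = lambda b, h, q_idx, kv_idx: kv_idx <= q_idx
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 1.0], [0.5, -0.5]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(reference_flex_attention(q, k, v, mask_mod=causal)[0])  # [1.0, 2.0]
```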

Let’s see both the score_mod and block_mask in action.

Score Mod Example – Soft-Capping with Tanh

Soft-capping is a common technique used to control the logit sizes (e.g., see here). The following code block extends our PyTorch-native attention kernel with soft-capping:

def softcap_attn(q, k, v):
    scale = HEAD_DIM ** -0.5
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    # apply soft-capping
    attn = 30 * torch.tanh(attn/30)
    attn = attn.softmax(dim=-1)
    x = attn @ v
    return x
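The soft-capping expression 30 * tanh(x / 30) behaves almost like the identity for small scores and saturates smoothly toward ±30 for large ones, which keeps the logits bounded. A quick plain-Python check with math.tanh, using the cap value 30 from above:

```python
import math

CAP = 30.0  # same cap as in softcap_attn


def softcap(x, cap=CAP):
    # Smoothly squashes x toward the interval (-cap, cap)
    return cap * math.tanh(x / cap)


for x in [0.5, 5.0, 30.0, 300.0, -1000.0]:
    print(f'{x:>8} -> {softcap(x):8.3f}')
```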

In the code block below we train our model, first with our PyTorch-native kernel, and then with the optimized Flex Attention API. These experiments were run with the 3136-length sequence settings.

# flex attention imports
from torch.nn.attention.flex_attention import (
    create_block_mask,
    create_mask,
    flex_attention
)
compiled_flex = torch.compile(flex_attention)

# score_mod definition
def tanh_softcap(score, b, h, q_idx, kv_idx):
    return 30 * torch.tanh(score/30)

block_fn = functools.partial(MyAttentionBlock, attn_fn=softcap_attn)

print(f'Attention with Softcap')
train(block_fn)
print(f'Compiled Attention with Softcap')
train_compile(block_fn)

flex_fn = functools.partial(flex_attention, score_mod=tanh_softcap)
compiled_flex_fn = functools.partial(compiled_flex, score_mod=tanh_softcap)

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=flex_fn)
compiled_block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=compiled_flex_fn)

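# eager model: use the separately compiled flex_attention kernel
print(f'Flex Attention with Softcap')
train(compiled_block_fn)
# compiled model: torch.compile compiles flex_attention together with the model
print(f'Compiled Flex Attention with Softcap')
train_compile(block_fn)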

The results of the experiments are captured in the table below:

Soft-cap step time results (lower is better) - by Author

The impact of the Flex Attention kernel is clearly evident, delivering performance boosts of approximately 3.5x in eager mode and 1.5x in compiled mode.

Mask Mod Example – Neighborhood Masking

We assess the mask_mod functionality by applying a sparse mask to our attention score. Recall that each token in our sequence represents a patch in our 2D input image. We modify our kernel so that each token attends only to tokens whose row and column indices in the corresponding 2D token array both differ from its own by less than 5, i.e., a 9×9 neighborhood centered on the token.

# convert the token id to a 2d index
def seq_indx_to_2d(idx):
    n_row_patches = IMG_SIZE // PATCH_SIZE
    r_ind = idx // n_row_patches
    c_ind = idx % n_row_patches
    return r_ind, c_ind

# only attend to tokens whose row and column offsets are both smaller than 5
# (a 9x9 window centered on the query token in our 2D token array)
def mask_mod(b, h, q_idx, kv_idx):
    q_r, q_c = seq_indx_to_2d(q_idx)
    kv_r, kv_c = seq_indx_to_2d(kv_idx)
    return torch.logical_and(torch.abs(q_r-kv_r)<5, torch.abs(q_c-kv_c)<5)
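Because the window condition in mask_mod is a strict inequality on the row and column offsets, it is worth checking exactly how many keys each query sees and how sparse the resulting mask is. The pure-Python sketch below restates the mask without torch (the helper names are hypothetical) and assumes IMG_SIZE // PATCH_SIZE == 56, consistent with the 3136-token (56 × 56) setting used here.

```python
def seq_indx_to_2d(idx: int, n_row_patches: int):
    # token id -> (row, column) in the 2D token grid
    return idx // n_row_patches, idx % n_row_patches


def neighbor_mask(q_idx: int, kv_idx: int, n_row_patches: int, bound: int = 5) -> bool:
    # same condition as mask_mod above: both offsets strictly less than `bound`
    q_r, q_c = seq_indx_to_2d(q_idx, n_row_patches)
    kv_r, kv_c = seq_indx_to_2d(kv_idx, n_row_patches)
    return abs(q_r - kv_r) < bound and abs(q_c - kv_c) < bound


def attended_count(q_idx: int, n_row_patches: int, bound: int = 5) -> int:
    """Number of key tokens a query may attend to (brute force over all keys)."""
    n_tokens = n_row_patches * n_row_patches
    return sum(neighbor_mask(q_idx, kv, n_row_patches, bound) for kv in range(n_tokens))


def mask_density(n_row_patches: int, bound: int = 5) -> float:
    """Fraction of the (seq_len x seq_len) score matrix that the mask keeps.

    The mask factorizes into independent row and column conditions, so the
    number of kept (query, key) pairs is the number of kept (row, row) pairs squared.
    """
    n = n_row_patches
    per_axis = sum(1 for a in range(n) for b in range(n) if abs(a - b) < bound)
    return per_axis ** 2 / n ** 4
```

Under these assumptions an interior query attends to 81 keys and a corner query to 25, and mask_density(56) evaluates to 484² / 56⁴ ≈ 0.024, so only about 2.4% of the 3136 × 3136 score matrix is kept. This is the kind of sparsity a block-sparse kernel can exploit.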

As a baseline for our experiment, we use PyTorch SDPA, which supports passing in an attention mask. The following block includes the masked SDPA experiment followed by the Flex Attention implementation:

# materialize the mask to use in SDPA
mask = create_mask(mask_mod, 1, 1, SEQ_LEN, SEQ_LEN, device='cuda')

set_sdpa_backend('all')
masked_sdpa = functools.partial(sdpa, attn_mask=mask)
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=masked_sdpa)
print(f'Masked SDPA Attention')
train(block_fn)
print(f'Compiled Masked SDPA Attention')
train_compile(block_fn)

block_mask = create_block_mask(mask_mod, None, None, SEQ_LEN, SEQ_LEN)
flex_fn = functools.partial(flex_attention, block_mask=block_mask)
compiled_flex_fn = functools.partial(compiled_flex, block_mask=block_mask)

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=flex_fn)
compiled_block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=compiled_flex_fn)

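# in eager mode, pass the separately compiled flex_attention; in compiled mode,
# train_compile compiles the whole model, flex_attention included
print(f'Masked Flex Attention')
train(compiled_block_fn)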
print(f'Compiled Masked Flex Attention')
train_compile(block_fn)
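The speedup of the block-mask variant comes from skipping work at tile granularity. The sketch below is a simplified, pure-Python illustration of that idea (not the create_block_mask implementation): it partitions the score matrix into block_size × block_size tiles and labels each tile as empty (can be skipped), full (no per-element masking needed), or partial (the mask must be evaluated element by element). The function names are hypothetical.

```python
from typing import Callable, Dict, List

MaskFn = Callable[[int, int], bool]


def classify_blocks(mask_fn: MaskFn, seq_len: int, block_size: int) -> List[List[str]]:
    """Label each (q_block, kv_block) tile as 'empty', 'partial', or 'full'."""
    n_blocks = (seq_len + block_size - 1) // block_size
    labels = []
    for qb in range(n_blocks):
        q_range = range(qb * block_size, min((qb + 1) * block_size, seq_len))
        row = []
        for kb in range(n_blocks):
            kv_range = range(kb * block_size, min((kb + 1) * block_size, seq_len))
            kept = sum(mask_fn(q, kv) for q in q_range for kv in kv_range)
            total = len(q_range) * len(kv_range)
            if kept == 0:
                row.append('empty')      # tile can be skipped entirely
            elif kept == total:
                row.append('full')       # compute without evaluating the mask
            else:
                row.append('partial')    # mask must be applied per element
        labels.append(row)
    return labels


def count_labels(labels: List[List[str]]) -> Dict[str, int]:
    counts = {'empty': 0, 'partial': 0, 'full': 0}
    for row in labels:
        for label in row:
            counts[label] += 1
    return counts
```

For a causal mask over 256 tokens with 64-token blocks, for example, 6 of the 16 tiles are empty, 6 are full, and only the 4 diagonal tiles are partial.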

The results of the experiments are captured below:

Masked attention step time results (lower is better) - by Author

Once again, Flex Attention offers a considerable performance boost, amounting to 2.19x in eager mode and 2.59x in compiled mode.

Flex Attention Limitations

Although we have succeeded in demonstrating the power and potential of Flex Attention, there are a few limitations that should be noted:

  1. Limited Scope of Modifications: With Flex Attention you can (as of the time of this writing) only modify the attention score (the result of the dot product between the query and key tokens). It does not support changes at other stages of the attention computation.
  2. Dependency on torch.compile: Given the reliance on torch.compile, care must be taken to avoid excessive recompilations, which could greatly degrade runtime performance. For instance, while the support for Document Masking is very compelling, it will perform as expected only if the sum of the lengths of all of the documents remains fixed; otherwise, the change in total sequence length would typically trigger a recompilation.
  3. No Support for Trainable Parameters in score_mod: At the time of this writing, Flex Attention does not support a score_mod implementation that includes trainable parameters. For example, while the documentation highlights support for relative position encodings, these are commonly implemented with trainable parameters (rather than fixed values), which cannot currently be accommodated.
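To illustrate the document-masking caveat, here is a minimal pure-Python sketch of a document mask for a packed sequence. The helper names are hypothetical, and in a real Flex Attention mask_mod the document ids would typically be held in a tensor indexed with q_idx and kv_idx. The packed sequence length, sum(doc_lengths), is what determines the shapes passed to the attention kernel, so varying the document boundaries while keeping that total fixed also keeps those shapes fixed.

```python
from typing import Callable, List, Tuple


def document_ids(doc_lengths: List[int]) -> List[int]:
    """Map each position in a packed sequence to the index of its document."""
    ids: List[int] = []
    for doc, length in enumerate(doc_lengths):
        ids.extend([doc] * length)
    return ids


def make_document_mask(doc_lengths: List[int]) -> Tuple[Callable[[int, int, int, int], bool], int]:
    ids = document_ids(doc_lengths)

    # mask_mod-style predicate: a token may attend only within its own document
    def doc_mask(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
        return ids[q_idx] == ids[kv_idx]

    # return the predicate and the packed sequence length it was built for
    return doc_mask, len(ids)
```

Two batches with document lengths [2, 3] and [4, 1] produce different masks but the same packed length of 5.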

In the face of these limitations, we can return to one of the other optimization opportunities discussed above.

Summary

As the reliance on transformer architectures and attention layers in ML models increases, so does the need for tools and techniques for optimizing these components. In this post, we have explored a number of attention kernel variants, each with its own unique properties, capabilities, and limitations. Importantly, one size does not fit all – different models and use cases will warrant the use of different kernels and different optimization strategies. This underscores the importance of having a wide variety of tools and techniques for optimizing attention layers.

In a sequel to this post, we will further explore attention layer optimization by focusing on applying some of the tools we discussed to tackle the challenge of handling variable-sized input sequences. Stay tuned…

The post Increasing Transformer Model Efficiency Through Attention Layer Optimization appeared first on Towards Data Science.

]]>
On the Programmability of AWS Trainium and Inferentia https://towardsdatascience.com/on-the-programmability-of-aws-trainium-and-inferentia-cd455826e26c/ Fri, 01 Nov 2024 08:17:22 +0000 https://towardsdatascience.com/on-the-programmability-of-aws-trainium-and-inferentia-cd455826e26c/ Accelerating AI/ML Model Training with Custom Operators - Part 4

The post On the Programmability of AWS Trainium and Inferentia appeared first on Towards Data Science.

]]>
In this post we continue our exploration of the opportunities for runtime optimization of machine learning (ML) workloads through custom operator development. This time, we focus on the tools provided by the AWS Neuron SDK for developing and running new kernels on AWS Trainium and AWS Inferentia. With the rapid development of the low-level model components (e.g., attention layers) driving the AI revolution, the programmability of the accelerators used for training and running ML models is crucial. Dedicated AI chips, in particular, must offer a worthy alternative to the widely used and highly impactful general-purpose GPU (GPGPU) development frameworks, such as CUDA and Triton.

In previous posts (e.g., [here](https://towardsdatascience.com/a-first-look-at-aws-trainium-1e0605071970) and here) we explored the opportunity for building and running ML models on AWS’s custom-built AI chips using the dedicated AWS Neuron SDK. In the most recent release of the SDK (version 2.20.0), AWS introduced the Neuron Kernel Interface (NKI) for developing custom kernels for [NeuronCore-v2](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/neuron-core-v2.html), the underlying accelerator powering both Trainium and Inferentia2. The NKI interface joins another API that enables NeuronCore-v2 programmability, Neuron Custom C++ Operators. In this post we will explore both opportunities and demonstrate them in action.

Disclaimers

Importantly, this post should not be viewed as a substitute for the official AWS Neuron SDK documentation. At the time of this writing, the Neuron SDK APIs for custom kernel development are in Beta and may change by the time you read this. The examples we share are intended for demonstrative purposes only. We make no claims as to their optimality, robustness, durability, or accuracy. Please do not view our mention of any platforms, tools, APIs, etc., as an endorsement for their use. The best choices for any project depend on the specifics of the use-case at hand and warrant appropriate investigation and analysis.

Developing Custom Kernels for Neuron Cores

Although the list of ML models supported by the Neuron SDK is continuously growing, some operations remain either unsupported or implemented suboptimally. By exposing APIs for Neuron kernel customization, the SDK empowers developers to create and/or optimize the low-level operations that they need, greatly increasing the opportunity for running ML workloads on Trainium and Inferentia.

As discussed in our previous posts in this series, fully leveraging the power of these AI chips requires a detailed understanding of their low-level architecture.

The Neuron Core Architecture

The NKI documentation includes a dedicated section on the architecture design of NeuronCore-v2 and its implications for custom operator development. Importantly, there are many differences between Neuron cores and their AI accelerator counterparts (e.g., GPUs and TPUs). Optimizing for Neuron cores requires a unique set of strategies and skills.

Similar to other dedicated AI chips, NeuronCore-v2 includes several internal acceleration engines, each of which specializes in performing certain types of computations. The engines can be run asynchronously and in parallel. The Neuron Compiler is responsible for transforming ML models into low-level operations and optimizing the choice of compute engine for each one.

The Tensor engine specializes in matrix multiplication. The Vector and Scalar engines both operate on tensors, with the Vector engine specializing in reduction operations and the Scalar engine in non-linear functions. [GpSimd](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/trainium_inferentia2_arch.html#gpsimd-engine) is a general-purpose engine capable of running arbitrary C/C++ programs. Note that while the NKI interface exposes access to all four compute engines, custom C++ operators are designed specifically for the GpSimd engine.

More details on the capabilities of each engine can be found in the architecture documentation. Furthermore, the NKI Instruction Set Architecture (ISA) documentation provides details on the engines on which different low-level operations are run.

Another important aspect of the Neuron chip is its memory architecture. A Neuron device includes three types of memory: HBM, SBUF, and PSUM. An intimate understanding of the capacities and capabilities of each one is crucial for optimal kernel development.

Given the architecture overview, you might conclude that Neuron kernel development requires a high level of expertise. While this may be true for creating fully optimized kernels that leverage all the capabilities of the Neuron core, our aim is to demonstrate the accessibility, value, and potential of the Neuron custom kernel APIs – even for non-expert developers.

Custom NKI Kernels

The NKI interface is a Python-level API that exposes the use of the Neuron core compute engines and memory resources to ML developers. The NKI Getting Started guide details the setup instructions and provides a soft landing with a simple "hello world" kernel. The NKI Programming Model guide details the three stages of a typical NKI kernel (loading inputs, running operations on the computation engines, and storing outputs) and introduces the NKI Tile and Tile-based operations. The NKI tutorials demonstrate a variety of NKI kernel sample applications, with each one introducing new core NKI APIs and capabilities. Given the presumed optimality of the sample kernels, one possible strategy for developing new kernels could be to 1) identify a sample that is similar to the operation you wish to implement and then 2) use it as a baseline and iteratively refine and adjust it to achieve the specific functionality you require.

The [NKI](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/nki.html) API Reference Manual details the Python API for kernel development. With a syntax and semantics that are similar to [Triton](https://triton-lang.org/main/index.html) and [NumPy](https://numpy.org/doc/stable/), the NKI language definition aims to maximize accessibility and ease of use. However, it is important to note that NKI kernel development is limited to the operations defined in the NKI library, which (as of the time of this writing) are fewer and more constrained than in libraries such as Triton and NumPy.

Toy Example – A GIOU Kernel

As in our previous posts, we assess the use of NKI by building a custom implementation of the Generalized Intersection Over Union (GIOU) operation on a pair of batches of input boxes. Since GIOU consists of element-wise operations, we used the exp kernel from the NKI Programming guide as a reference point and incorporated NKI’s advanced tensor indexing in our implementation. To facilitate debugging in a CPU environment, we also added options to run the code using the nki.simulate_kernel and nki.language.device_print APIs.

import torch
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import numpy as np

simulate = False

try:
    # if torch libraries are installed assume that we are running on Neuron
    import torch_xla.core.xla_model as xm
    import torch_neuronx
    from torch_neuronx import nki_jit

    device = xm.xla_device()

    # empty implementation 
    def debug_print(*args, **kwargs):
        pass
except:
    # if torch libraries are not installed assume that we are running on CPU
    # and program script to use nki simulation
    simulate = True
    nki_jit = nki.trace
    debug_print = nl.device_print
    device = 'cpu'

@nki_jit
def giou_kernel(preds_ptr,
                targets_ptr,
                output_ptr):
    epsilon = 1e-5
    TILE_M = nl.tile_size.pmax  # 128
    TILE_N = nl.tile_size.psum_fmax  # 512
    TILE_N_OUT = TILE_N // 4

    p_1, p_2 = preds_ptr.shape
    t_1, t_2 = targets_ptr.shape
    o_1, o_2 = output_ptr.shape

    #  verify input
    # batch size must be a multiple of 128
    assert p_1 % TILE_M == 0
    assert p_1 == t_1
    assert p_1 == o_1
    # num boxes * 4 must be a multiple of 512
    assert p_2 % TILE_N == 0
    assert p_2 == t_2
    assert p_2 // 4 == o_2

    num_tiles_m = p_1 // TILE_M
    num_tiles_n = p_2 // TILE_N

    # Generate tensors for advanced indexing
    i_p = nl.arange(TILE_M)[:, None]
    i_f = nl.arange(TILE_N // 4)[None, :]
    i_f_0 = (4 * i_f)
    i_f_1 = (4 * i_f + 1)
    i_f_2 = (4 * i_f + 2)
    i_f_3 = (4 * i_f + 3)

    # Use affine_range to loop over tiles
    for m in nl.affine_range(num_tiles_m):
        for n in nl.affine_range(num_tiles_n):
            # Load input data from HBM
            preds = nl.load(preds_ptr[m * TILE_M:(m + 1) * TILE_M,
                            n * TILE_N:(n + 1) * TILE_N])
            targets = nl.load(targets_ptr[m * TILE_M:(m + 1) * TILE_M,
                              n * TILE_N:(n + 1) * TILE_N])
            debug_print('preds', preds)
            preds_left = preds[i_p, i_f_0]
            preds_top = preds[i_p, i_f_1]
            preds_right = preds[i_p, i_f_2]
            preds_bottom = preds[i_p, i_f_3]

            gt_left = targets[i_p, i_f_0]
            gt_top = targets[i_p, i_f_1]
            gt_right = targets[i_p, i_f_2]
            gt_bottom = targets[i_p, i_f_3]

            # Compute the area of each box
            area1 = (preds_right - preds_left) * (preds_bottom - preds_top)
            area2 = (gt_right - gt_left) * (gt_bottom - gt_top)

            # Compute the intersection
            left = nl.maximum(preds_left, gt_left)
            top = nl.maximum(preds_top, gt_top)
            right = nl.minimum(preds_right, gt_right)
            bottom = nl.minimum(preds_bottom, gt_bottom)

            inter_w = nl.maximum(right - left, 0)
            inter_h = nl.maximum(bottom - top, 0)
            inter_area = inter_w * inter_h

            union_area = area1 + area2 - inter_area

            iou_val = inter_area / nl.maximum(union_area, epsilon)

            # Compute the smallest enclosing box
            enclose_left = nl.minimum(preds_left, gt_left)
            enclose_top = nl.minimum(preds_top, gt_top)
            enclose_right = nl.maximum(preds_right, gt_right)
            enclose_bottom = nl.maximum(preds_bottom, gt_bottom)

            enclose_w = nl.maximum(enclose_right - enclose_left, 0)
            enclose_h = nl.maximum(enclose_bottom - enclose_top, 0)
            enclose_area = enclose_w * enclose_h

            # Compute GIOU
            delta_area = (enclose_area - union_area)
            enclose_area = nl.maximum(enclose_area, epsilon)
            giou = iou_val - delta_area / enclose_area

            # Store results
            nl.store(output_ptr[m * TILE_M:(m + 1) * TILE_M,
                     n * TILE_N_OUT:(n + 1) * TILE_N_OUT],
                     giou)
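The kernel's correctness hinges on the strided indexing (box i occupies columns 4*i to 4*i+3 of each flattened row) and on the tile bookkeeping, so it can be useful to validate that logic on any machine before running on a Neuron device. The sketch below is a pure-Python emulation of the same tiling scheme, not NKI code: TILE_M = 128 and TILE_N = 512 are hardcoded from the comments in the kernel above, and tiled_giou and giou_box are hypothetical helpers.

```python
from typing import List

TILE_M = 128  # partition-dimension tile (nl.tile_size.pmax in the kernel)
TILE_N = 512  # free-dimension tile (nl.tile_size.psum_fmax in the kernel)
EPS = 1e-5


def giou_box(p: List[float], t: List[float]) -> float:
    """GIOU for a single pair of (left, top, right, bottom) boxes."""
    area1 = (p[2] - p[0]) * (p[3] - p[1])
    area2 = (t[2] - t[0]) * (t[3] - t[1])
    inter_w = max(min(p[2], t[2]) - max(p[0], t[0]), 0.0)
    inter_h = max(min(p[3], t[3]) - max(p[1], t[1]), 0.0)
    inter = inter_w * inter_h
    union = area1 + area2 - inter
    iou = inter / max(union, EPS)
    enc_w = max(max(p[2], t[2]) - min(p[0], t[0]), 0.0)
    enc_h = max(max(p[3], t[3]) - min(p[1], t[1]), 0.0)
    enc = enc_w * enc_h
    return iou - (enc - union) / max(enc, EPS)


def tiled_giou(preds: List[List[float]], targets: List[List[float]]) -> List[List[float]]:
    """Emulate the kernel's tiling on flattened (batch, 4 * n_boxes) inputs."""
    rows, cols = len(preds), len(preds[0])
    assert rows % TILE_M == 0 and cols % TILE_N == 0
    tile_n_out = TILE_N // 4
    out = [[0.0] * (cols // 4) for _ in range(rows)]
    for m in range(rows // TILE_M):
        for n in range(cols // TILE_N):
            base = n * TILE_N
            for i_p in range(TILE_M):
                row = m * TILE_M + i_p
                for i_f in range(tile_n_out):
                    # strided access: box i_f of this tile sits at columns 4*i_f .. 4*i_f+3
                    p = preds[row][base + 4 * i_f: base + 4 * i_f + 4]
                    t = targets[row][base + 4 * i_f: base + 4 * i_f + 4]
                    out[row][n * tile_n_out + i_f] = giou_box(p, t)
    return out
```

The same giou_box function can double as a reference for spot-checking individual entries of the kernel's output.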

To run our GIOU kernel, we generate two batches of random boxes and feed them to our function:

# generate random data in np
np.random.seed(0)
batch_size = 1024
n_boxes = 256
img_size = 256
boxes = []

for i in range(2):
    # Randomly generate box sizes and positions
    box_sizes = np.random.randint(1, img_size, size=(batch_size,n_boxes,2))
    top_left = np.random.randint(0, img_size-1, size=(batch_size,n_boxes,2))
    bottom_right = np.clip(top_left + box_sizes, 0, img_size - 1)

    # Concatenate top-left and bottom-right coordinates
    rand_boxes = np.concatenate((top_left, bottom_right), axis=2)

    boxes.append(rand_boxes.astype(np.float32))

out = np.empty((batch_size, n_boxes), np.float32)

# convert tensors to Pytorch
t_boxes_0 = torch.tensor(boxes[0]).to(device)
t_boxes_1 = torch.tensor(boxes[1]).to(device)
t_out = torch.tensor(out).to(device)

if simulate:
    # the simulation API requires numpy input
    nki.simulate_kernel(giou_kernel, 
                        boxes[0].reshape((batch_size, -1)),
                        boxes[1].reshape((batch_size, -1)),
                        out)
else:
    giou_kernel(t_boxes_0.view((batch_size, -1)),
                t_boxes_1.view((batch_size, -1)),
                t_out)

To assess the performance of our NKI kernel, we will compare it with the following naive implementation of GIOU in PyTorch:

def torch_giou(boxes1, boxes2):
    # loosely based on torchvision generalized_box_iou_loss code
    epsilon = 1e-5

    # Compute areas of both sets of boxes
    area1 = (boxes1[...,2]-boxes1[...,0])*(boxes1[...,3]-boxes1[...,1])
    area2 = (boxes2[...,2]-boxes2[...,0])*(boxes2[...,3]-boxes2[...,1])

    # Corners of intersection
    lt = torch.max(boxes1[..., :2], boxes2[..., :2])
    rb = torch.min(boxes1[..., 2:], boxes2[..., 2:])

    # Width and height of intersection
    wh = (rb - lt).clamp(min=0)

    # Area of the intersection
    inter = wh[..., 0] * wh[..., 1]

    # Union of the two boxes
    union = area1 + area2 - inter
    iou = inter / union.clamp(epsilon)

    # Corners of enclosing box
    lti = torch.min(boxes1[..., :2], boxes2[..., :2])
    rbi = torch.max(boxes1[..., 2:], boxes2[..., 2:])

    # Width and height of the enclosing box
    whi = (rbi - lti).clamp(min=0)

    # Area of the enclosing box
    areai = (whi[..., 0] * whi[..., 1]).clamp(epsilon)

    return iou - (areai - union) / areai

We use the following benchmarking utility to compare the runtime performance of our two functions:

import time
def benchmark(f, warmup_iters=20, ntrials: int = 100):
    def run(*args, **kwargs):
        # warmup
        for _ in range(warmup_iters):
            f(*args, **kwargs)
        start_time = time.time()
        for _ in range(ntrials):
            f(*args, **kwargs)
        end_time = time.time()
        # Calculate average time per iteration
        avg_time = (end_time - start_time) / ntrials
        return avg_time

    return run

avg_time = benchmark(torch_giou)(t_boxes_0, t_boxes_1)
print(f'torch_giou: {avg_time}')

avg_time = benchmark(giou_kernel)(t_boxes_0.view((batch_size, -1)),
                                  t_boxes_1.view((batch_size, -1)),
                                  t_out)
print(f'giou_kernel: {avg_time}')
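The benchmark utility above measures wall-clock time around Python-side calls. On lazily executed XLA devices, queued device work may not yet have executed when the clock is read, so one option is to synchronize explicitly before each timestamp. The sketch below generalizes the utility with an optional sync callback (benchmark_synced is a hypothetical name); with sync=None it behaves like the original. On a Neuron device one might pass something like sync=lambda: (xm.mark_step(), xm.wait_device_ops()), but treat that choice as an untested assumption.

```python
import time
from typing import Callable, Optional


def benchmark_synced(f: Callable, warmup_iters: int = 20, ntrials: int = 100,
                     sync: Optional[Callable[[], None]] = None):
    """Like benchmark(), but calls sync() before each clock reading.

    On lazily executed devices, sync should block until queued work finishes;
    with sync=None this reduces to plain wall-clock timing.
    """
    sync = sync or (lambda: None)

    def run(*args, **kwargs):
        # warmup
        for _ in range(warmup_iters):
            f(*args, **kwargs)
        sync()
        start = time.perf_counter()
        for _ in range(ntrials):
            f(*args, **kwargs)
        sync()
        # average time per iteration, in seconds
        return (time.perf_counter() - start) / ntrials

    return run
```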

Runtime Environment

We ran our script on an Amazon EC2 inf2.xlarge instance (containing two Neuron cores and four vCPUs). We used the most recent version of the Deep Learning AMI for Neuron available at the time of this writing, "Deep Learning AMI Neuron (Ubuntu 22.04) 20241027", with AWS Neuron 2.20.1 and PyTorch 2.1.

Results

Our custom GIOU kernel demonstrated an average runtime of 0.211 milliseconds, compared to 0.293 milliseconds for the torch_giou baseline, amounting to a 39% performance boost. Keep in mind that these results are unique to our toy example. Other operators, particularly ones that include matrix multiplications (and utilize the Tensor engine), are likely to exhibit different comparative results.

Optimizing NKI Kernel Performance

The next step in our kernel development – beyond the scope of this post – would be to analyze the performance of the GIOU kernel using the dedicated Neuron Profiler in order to identify bottlenecks and optimize our implementation. Please see the NKI performance guide for more details.

Neuron Custom C++ Operators

The second method for creating a custom Neuron kernel is to build a C++ operator for the GpSimd engine. This method is described in the Neuron Custom C++ Operators Developer Guide and demonstrated in the Neuron Custom C++ Operators in MLP and Neuron Custom C++ Operators Performance Optimization tutorials.

Neuron Custom C++ Operators presents an opportunity for "kernel fusion" on the GpSimd engine by facilitating the combination of multiple low-level operations into a single kernel execution. This approach can significantly reduce the overhead associated with: 1) loading multiple individual kernels, and 2) transferring data between different memory regions.

Toy Example – A GIOU C++ Kernel

In the code block below we implement a C++ GIOU operator for Neuron and save it to a file named giou.cpp. Our kernel uses the TCM accessor for optimizing memory read and write performance and applies the multicore setting in order to use all eight of the GpSimd’s internal processors.

#include <stdint.h>
#include <stdlib.h>
#include <torch/torch.h>
#include <neuron/neuron-utils.hpp>
#include <algorithm>

// input boxes of shape 1024x256x4
// output scores of shape 1024x256
torch::Tensor giou(const torch::Tensor& t_pred, 
                   const torch::Tensor& t_target) {
  size_t num_samples = t_pred.sizes()[0];
  size_t num_boxes = t_pred.sizes()[1];
  torch::Tensor t_out = get_dst_tensor();

  // get the number of GpSimd processors (8 in NeuronCoreV2) 
  uint32_t cpu_count = get_cpu_count();
  // get index of current processor
  uint32_t cpu_id = get_cpu_id();

  // divide the batch size into 8 partitions 
  uint32_t partition = num_samples / cpu_count;

  // use tcm buffers to load and write data
  size_t tcm_in_size = num_boxes*4;
  size_t tcm_out_size = num_boxes;
  float *tcm_pred = (float*)torch::neuron::tcm_malloc(
                                             sizeof(float)*tcm_in_size);
  float *tcm_target = (float*)torch::neuron::tcm_malloc(
                                             sizeof(float)*tcm_in_size);
  // the output buffer holds one score per box
  float *tcm_output = (float*)torch::neuron::tcm_malloc(
                                             sizeof(float)*tcm_out_size);
  auto t_pred_tcm_acc = t_pred.tcm_accessor();
  auto t_target_tcm_acc = t_target.tcm_accessor();
  auto t_out_tcm_acc = t_out.tcm_accessor();

  // iterate over each of the entries in the partition
  for (size_t i = 0; i < partition; i++) {
    // load the pred and target boxes into local memory
    t_pred_tcm_acc.tensor_to_tcm<float>(tcm_pred,
                                        partition*cpu_id + i*tcm_in_size,
                                        tcm_in_size);
    t_target_tcm_acc.tensor_to_tcm<float>(tcm_target,
                                          partition*cpu_id + i*tcm_in_size,
                                          tcm_in_size);

    // iterate over each of the boxes in the entry
    for (size_t j = 0; j < num_boxes; j++) {
      const float epsilon = 1e-5;
      const float* box1 = &amp;tcm_pred[j * 4];
      const float* box2 = &amp;tcm_target[j * 4];
      // Compute area of each box
      float area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]);
      float area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]);

      // Compute the intersection
      float left = std::max(box1[0], box2[0]);
      float top = std::max(box1[1], box2[1]);
      float right = std::min(box1[2], box2[2]);
      float bottom = std::min(box1[3], box2[3]);

      float inter_w = std::max(right - left, 0.f);
      float inter_h = std::max(bottom - top, 0.f);
      float inter_area = inter_w * inter_h;

      // Compute the union area
      float union_area = area1 + area2 - inter_area;

      // IoU
      float iou_val = inter_area / std::max(union_area, epsilon);

      // Compute the smallest enclosing box
      float enclose_left = std::min(box1[0], box2[0]);
      float enclose_top = std::min(box1[1], box2[1]);
      float enclose_right = std::max(box1[2], box2[2]);
      float enclose_bottom = std::max(box1[3], box2[3]);

      float enclose_w = std::max(enclose_right - enclose_left, 0.f);
      float enclose_h = std::max(enclose_bottom - enclose_top, 0.f);
      float enclose_area = std::max(enclose_w * enclose_h, epsilon);

      float result = iou_val - (enclose_area-union_area)/enclose_area;
      tcm_output[j] = result;
    }

    // write the giou scores of all boxes in the current entry
    t_out_tcm_acc.tcm_to_tensor<float>(tcm_output,
                                       partition*cpu_id + i*tcm_out_size,
                                       tcm_out_size);
  }

  torch::neuron::tcm_free(tcm_pred);
  torch::neuron::tcm_free(tcm_target);
  torch::neuron::tcm_free(tcm_output);
  return t_out;
}
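Because each GpSimd processor works on its own slice of the batch, it is worth checking that the element offsets passed to the TCM accessor tile the input without gaps or overlaps. The following pure-Python sketch (helper names hypothetical) assumes a contiguous row-major (num_samples, num_boxes, 4) float input and accessor offsets counted in elements, enumerates the chunk each processor reads, and checks coverage. The output offsets follow the same pattern with num_boxes elements per entry.

```python
from typing import Dict, List, Tuple


def read_offsets(num_samples: int, num_boxes: int,
                 cpu_count: int = 8) -> Dict[int, List[Tuple[int, int]]]:
    """(start, length) element chunks each processor copies into its TCM buffer."""
    tcm_in_size = num_boxes * 4          # elements per batch entry
    partition = num_samples // cpu_count  # entries per processor
    plan = {}
    for cpu_id in range(cpu_count):
        plan[cpu_id] = [((partition * cpu_id + i) * tcm_in_size, tcm_in_size)
                        for i in range(partition)]
    return plan


def covers_exactly_once(plan: Dict[int, List[Tuple[int, int]]], total_elems: int) -> bool:
    """True if the chunks touch every element in [0, total_elems) exactly once."""
    seen = [0] * total_elems
    for chunks in plan.values():
        for start, length in chunks:
            for e in range(start, start + length):
                if e >= total_elems:
                    return False
                seen[e] += 1
    return all(c == 1 for c in seen)
```

For the 1024 × 256 × 4 inputs used here, processor 1 starts reading at element 128 * 1024 = 131072, and the eight partitions together cover all 1,048,576 elements exactly once.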

We require a separate shape.cpp file that defines the output shape of our GIOU function and registers our custom operator with the Neuron library:

#include <stdint.h>
#include <stdlib.h>
#include <torch/torch.h>
#include "torchneuron/register.h"

torch::Tensor giou_shape(torch::Tensor boxes1, torch::Tensor boxes2) {
    torch::Tensor t_out = torch::zeros({boxes1.sizes()[0],
                                        boxes1.sizes()[1]},
                                       torch::kFloat);
    return t_out;
}

NEURON_LIBRARY(my_ops, m) {
  m.def("giou", &giou_shape, "giou");
}

The build.py script compiles the C++ operator and exposes it as a Python API:

import os
import torch_neuronx
from torch_neuronx.xla_impl import custom_op

custom_op.load(
    name='giou',
    compute_srcs=['giou.cpp'],
    shape_srcs=['shape.cpp'],
    build_directory=os.getcwd(),
    multicore=True,
    verbose=True
)

The compilation script generates a libgiou.so library containing the implementation of our C++ GIOU operator. In the code block below we load the library and measure the performance of our custom kernel using the benchmarking utility defined above:

from torch_neuronx.xla_impl import custom_op
custom_op.load_library('libgiou.so')

avg_time = benchmark(torch.ops.my_ops.giou)(t_boxes_0, t_boxes_1)
print(f'C++ giou: {avg_time}')

Runtime Environment

We used the same Neuron environment from our NKI experiments to compile and test our C++ kernel. Please note the installation steps that are required for custom C++ operator development.

Results

Our C++ GIOU kernel demonstrated an average runtime of 0.061 milliseconds – nearly five times faster than our baseline implementation. This is presumably a result of "kernel fusion", as discussed above.

Conclusion

The table below summarizes the runtime results of our experiments.

Avg time of different GIOU implementations (lower is better) - by Author

Please keep in mind that these results are specific to the toy example and runtime environment used in this study. The comparative results of other kernels might be very different – depending on the degree to which they can leverage the Neuron core’s internal compute engines.

The table below summarizes some of the differences we observed between the two methods of AWS Neuron kernel customization.

Comparison between kernel customization tools (by Author)

Through its high-level Python interface, the NKI APIs expose the power of the Neuron acceleration engines to ML developers in an accessible and user-friendly manner. The low-level C++ Custom Operators library enables even greater programmability, but is limited to the GpSimd engine. By effectively combining both tools, developers can fully leverage the AWS Neuron architecture’s capabilities.

Summary

With the AI revolution in full swing, many companies are developing advanced new AI chips to meet the growing demand for compute. While public announcements often highlight these chips’ runtime performance, cost savings, and energy efficiency, several core capabilities are essential to making these chips and their software stacks truly viable for ML development. These capabilities include robust debugging tools, performance analysis and optimization utilities, programmability, and more.

In this post, we focused on the utilities available for programming AWS’s homegrown AI accelerators, Trainium and Inferentia, and demonstrated their use in building custom ML operations. These tools empower developers to optimize the performance of their ML models on AWS’s AI chips and open up new opportunities for innovation and creativity.


]]>
AI Model Optimization on AWS Inferentia and Trainium https://towardsdatascience.com/ai-model-optimization-on-aws-inferentia-and-trainium-cfd48e85d5ac/ Sun, 20 Oct 2024 07:19:11 +0000 https://towardsdatascience.com/ai-model-optimization-on-aws-inferentia-and-trainium-cfd48e85d5ac/ Tips for accelerating ML with AWS Neuron SDK

The post AI Model Optimization on AWS Inferentia and Trainium appeared first on Towards Data Science.

]]>
We are in a golden age of AI, with cutting-edge models disrupting industries and poised to transform life as we know it. Powering these advancements are increasingly powerful AI accelerators, such as NVIDIA H100 GPUs, Google Cloud TPUs, AWS’s Trainium and Inferentia chips, and more. With the growing number of options comes the challenge of selecting the optimal platform for our machine learning (ML) workloads – a crucial decision considering the high costs associated with AI computation. Importantly, a comprehensive assessment of each option necessitates ensuring that we are maximizing its utilization to fully leverage its capabilities.

In this post, we will review several techniques for optimizing an ML workload on AWS’s custom-built AI chips using the AWS Neuron SDK. This continues our ongoing series of posts focused on ML model performance analysis and optimization across various platforms and environments (e.g., see [here](https://towardsdatascience.com/training-ai-models-on-cpu-3903adc9f388) and here). While our primary focus will be on an ML training workload and AWS Inferentia2, the techniques discussed are also applicable to AWS Trainium. (Recall that although AWS Inferentia is primarily designed as an AI inference chip, we have previously demonstrated its effectiveness in training tasks as well.)

Generally speaking, performance optimization is an iterative process that includes a performance analysis step to appropriately identify performance bottlenecks and resource under-utilization (e.g., see here). However, since the techniques we will discuss are general purpose (i.e., they are potentially applicable to any model, regardless of its performance profile), we defer the discussion on performance analysis with the Neuron SDK to a future post.

Disclaimers

The code we will share is intended for demonstrative purposes only – we make no claims regarding its accuracy, optimality, or robustness. Please do not view this post as a substitute for the official Neuron SDK documentation. Please do not interpret our mention of any platforms, libraries, or optimization techniques as an endorsement for their use. The best options for you will depend greatly on the specifics of your use-case and will require your own in-depth investigation and analysis.

The experiments described below were run on an Amazon EC2 inf2.xlarge instance (containing two Neuron cores and four vCPUs). We used the most recent version of the Deep Learning AMI for Neuron available at the time of this writing, "Deep Learning AMI Neuron (Ubuntu 22.04) 20240927", with AWS Neuron 2.20 and PyTorch 2.1. See the SDK documentation for more details on setup and installation. Keep in mind that the Neuron SDK is under active development and that the APIs we refer to, as well as the runtime measurements we report, may become outdated by the time you read this. Please be sure to stay up-to-date with the latest SDK and documentation available.

Toy Model

To facilitate our discussion, we introduce the following simple Vision Transformer (ViT)-backed classification model (based on timm version 1.0.10):

from torch.utils.data import Dataset
import time, os
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from timm.models.vision_transformer import VisionTransformer

# use random data
class FakeDataset(Dataset):
  def __len__(self):
    return 1000000

  def __getitem__(self, index):
    rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
    label = torch.tensor(data=index % 1000, dtype=torch.int64)
    return rand_image, label

def train(batch_size=16, num_workers=0):
  # Initialize XLA process group for torchrun
  import torch_xla.distributed.xla_backend
  torch.distributed.init_process_group('xla')

  # multi-processing: ensure each worker has same initial weights
  torch.manual_seed(0)
  dataset = FakeDataset()
  model = VisionTransformer()

  # load model to XLA device
  device = xm.xla_device()
  model = model.to(device)
  optimizer = torch.optim.Adam(model.parameters())
  data_loader = torch.utils.data.DataLoader(dataset,
                                            batch_size=batch_size,
                                            num_workers=num_workers)

  data_loader = pl.MpDeviceLoader(data_loader, device)
  loss_function = torch.nn.CrossEntropyLoss()
  summ = 0
  count = 0
  t0 = time.perf_counter()

  for step, (inputs, targets) in enumerate(data_loader, start=1):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
    loss.backward()
    xm.optimizer_step(optimizer)
    batch_time = time.perf_counter() - t0
    if step > 10:  # skip first steps
      summ += batch_time
      count += 1
    t0 = time.perf_counter()
    if step > 500:
      break
  print(f'average step time: {summ/count}')

if __name__ == '__main__':
  train()

# Initialization command:
# torchrun --nproc_per_node=2 train.py

Running our baseline model on the two cores of our AWS Inferentia instance results in a training speed of 251.98 samples per second.
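The training script prints an average step time, while the results in this post are reported in samples per second. A minimal sketch of the conversion, assuming that throughput is the per-core batch size multiplied by the number of Neuron cores (two on inf2.xlarge, one process per core under torchrun) and divided by the average step time; the helper names are hypothetical:

```python
def samples_per_second(avg_step_time_s, batch_size_per_core, num_cores=2):
    """Global training throughput across all data-parallel cores."""
    global_batch = batch_size_per_core * num_cores  # samples processed per step
    return global_batch / avg_step_time_s


def step_time_from_throughput(throughput, batch_size_per_core, num_cores=2):
    """Inverse conversion: average step time (seconds) implied by a throughput."""
    return batch_size_per_core * num_cores / throughput


# The baseline (16 samples per core, 2 cores) at 251.98 samples/s
# corresponds to an average step time of roughly 0.127 seconds.
baseline_step = step_time_from_throughput(251.98, batch_size_per_core=16)
print(f"{baseline_step:.3f} s per step")
```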

In the next sections, we will iteratively apply a number of potential optimization techniques and assess their impact on step time performance. While we won’t go into the full details of each method, we will provide references for further reading (e.g., here). Importantly, the list we will present is not all-inclusive – there are many techniques beyond what we will cover. We will organize the methods into three categories: PyTorch optimizations, OpenXLA optimizations, and Neuron-specific optimizations. However, the order of presentation is not binding. In fact, some of the techniques are interdependent – for example, applying the mixed precision optimization may free up enough device memory to enable increasing the batch size.

PyTorch Performance Optimizations

In previous posts (e.g., here) we have extensively covered the topic of PyTorch model performance analysis and optimization on GPU. Many of the techniques we discussed are relevant to other AI accelerators. In this section we will revisit a few of these techniques and apply them to AWS Inferentia.

Multi-process Data Loading

In multi-process data loading, the input data is prepared in one or more dedicated CPU processes rather than in the same process that runs the training step. This allows for overlapping data loading and training, which can increase system utilization and lead to a significant speed-up. The number of processes is controlled by the num_workers parameter of the PyTorch DataLoader. In the following block we run our script with num_workers set to one:

train(num_workers=1)

This change results in a training speed of 253.56 samples per second for a boost of less than 1%.

Batch Size Optimization

Another important hyperparameter that can influence training speed is the training batch size. Often, we have found that increasing the batch size improves system utilization and results in better performance. However, the effects can vary based on the model and platform. In the case of our toy model on AWS Inferentia, we find that running with a batch size of 8 samples per neuron core results in a speed of 265.68 samples per second – roughly 5% faster than a batch size of 16 samples per core.

train(batch_size=8, num_workers=1)

PyTorch Automatic Mixed Precision

Another common method for boosting performance is to use lower precision floats such as the 16-bit BFloat16. Importantly, some model components might not be compatible with reduced precision floats. PyTorch’s Automatic Mixed Precision (AMP) mode attempts to match the most appropriate floating point type to each model operation automatically. Although the Neuron compiler offers its own options for employing mixed precision, it also supports PyTorch AMP. In the code block below we include the modifications required to use PyTorch AMP.

def train(batch_size=16, num_workers=0):
  # Initialize XLA process group for torchrun
  import torch_xla.distributed.xla_backend
  torch.distributed.init_process_group('xla')

  # multi-processing: ensure each worker has same initial weights
  torch.manual_seed(0)
  dataset = FakeDataset()
  model = VisionTransformer()

  # load model to XLA device
  device = xm.xla_device()
  model = model.to(device)
  optimizer = torch.optim.Adam(model.parameters())
  data_loader = torch.utils.data.DataLoader(dataset,
                                            batch_size=batch_size,
                                            num_workers=num_workers)

  data_loader = pl.MpDeviceLoader(data_loader, device)
  loss_function = torch.nn.CrossEntropyLoss()
  summ = 0
  count = 0
  t0 = time.perf_counter()

  for step, (inputs, targets) in enumerate(data_loader, start=1):
    optimizer.zero_grad()

    # use PyTorch AMP; device_type='cuda' is paired with the
    # torch.cuda.is_bf16_supported override in __main__ below
    with torch.autocast(dtype=torch.bfloat16, device_type='cuda'):
      outputs = model(inputs)
      loss = loss_function(outputs, targets)
    loss.backward()
    xm.optimizer_step(optimizer)
    batch_time = time.perf_counter() - t0
    if step > 10:  # skip first steps
      summ += batch_time
      count += 1
    t0 = time.perf_counter()
    if step > 500:
      break
  print(f'average step time: {summ/count}')

if __name__ == '__main__':
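  # disable the Neuron compiler's automatic casting
  os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"
  # autocast with device_type='cuda' and dtype=bfloat16 queries
  # torch.cuda.is_bf16_supported(); override it to report support
  torch.cuda.is_bf16_supported = lambda: True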
  train(batch_size=8, num_workers=1)

The resultant training speed is 196.64 samples per second, about 26% lower than with the default mixed precision setting of the Neuron compiler (265.68 samples per second). It’s important to note that while this post focuses on performance, in real-world scenarios we would also need to evaluate the effect of the mixed precision policy we choose on model accuracy.
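The Neuron-specific pieces of the snippet above (device_type='cuda' and the is_bf16_supported override) require the Inferentia instance. To see what autocast itself does, here is a CPU-only sketch on a small hypothetical model (not the ViT used above): inside the autocast region the linear layers produce bfloat16 outputs, while the parameters and their gradients remain float32.

```python
import torch

torch.manual_seed(0)

# Small stand-in model; the layer sizes are arbitrary illustrative choices.
model = torch.nn.Sequential(torch.nn.Linear(64, 32),
                            torch.nn.ReLU(),
                            torch.nn.Linear(32, 10))
inputs = torch.randn(8, 64)
targets = torch.randint(0, 10, (8,))

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    logits = model(inputs)  # linear layers run in bfloat16 under CPU autocast
    loss = torch.nn.functional.cross_entropy(logits, targets)

loss.backward()  # as in the training loop above, backward runs outside autocast

print(logits.dtype)                # torch.bfloat16
print(model[0].weight.dtype)       # torch.float32 (parameters keep full precision)
print(model[0].weight.grad.dtype)  # torch.float32
```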

OpenXLA Optimizations

As discussed in a previous post, Neuron Cores are treated as XLA devices and the torch-neuronx Python package implements the PyTorch/XLA API. Consequently, any optimization opportunities provided by the OpenXLA framework, and specifically those offered by the PyTorch/XLA API, can be leveraged on AWS Inferentia and Trainium. In this section we consider a few of these opportunities.

BFloat16 Precision

OpenXLA supports the option of casting all floats to BFloat16 via the XLA_USE_BF16 environment variable, as shown in the code block below:

if __name__ == '__main__':
  os.environ['XLA_USE_BF16'] = '1'
  train(batch_size=8, num_workers=1)

The resultant training speed is 394.51 samples per second, nearly 50% faster than the speed of the default mixed precision option.
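To get a feel for what casting all floats to BFloat16 gives up, the sketch below compares the numeric properties of float32, bfloat16, and float16 on the host using torch.finfo. It does not exercise XLA_USE_BF16 itself; it only illustrates that BFloat16 keeps roughly the dynamic range of float32 while retaining far fewer mantissa bits.

```python
import torch

for dtype in (torch.float32, torch.bfloat16, torch.float16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):16} bits={info.bits:2d} eps={info.eps:.2e} max={info.max:.2e}")

# bfloat16 values near 1.0 are spaced 2**-7 apart, so 1 + 2**-9
# is not representable and rounds back to 1.0.
x = torch.tensor(1.0 + 2**-9, dtype=torch.float32)
print(x.to(torch.bfloat16).item())  # 1.0
```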

Multi-process Device Loading

The PyTorch/XLA MpDeviceLoader and its internal [ParallelLoader](https://pytorch.org/xla/master/_modules/torch_xla/distributed/parallel_loader.html), which are responsible for loading input data onto the accelerator, include a number of parameters for controlling the transfer of data from the host to the device. In the code block below we tune the batches_per_execution setting, which determines the number of batches copied to the device for each execution cycle of the ParallelLoader. By increasing this setting, we aim to reduce the overhead of host-to-device communication:

data_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=batch_size,
                                          num_workers=num_workers
                                          )
data_loader = pl.MpDeviceLoader(data_loader, 
                                device, batches_per_execution=10)

As a result of this optimization, the training speed increased to 1,027.39 samples per second, roughly 2.6 times the previous speed (an increase of about 160%).

Torch Compilation with OpenXLA Backend

In previous posts (e.g., here), we have demonstrated the potential performance gains from using PyTorch’s graph compilation offering. Although OpenXLA includes its own graph creation and Just-In-Time (JIT) compilation mechanisms, torch.compile can provide additional acceleration by eliminating the need for tracing the model operations at every step. The following code snippet demonstrates the use of the dedicated openxla backend for compiling the model:

model = model.to(device)
model = torch.compile(model, backend='openxla')

Although torch.compile is not yet supported by the Neuron SDK, we mention it in anticipation of future support.

Neuron SDK Optimizations

In this section we consider some of the optimization opportunities offered by the AWS Neuron SDK and, more specifically, by the Neuron compiler.

Mixed Precision

The Neuron SDK supports a variety of mixed precision settings. In the code block below we program the compiler to cast all floats to BFloat16 via the NEURON_CC_FLAGS environment variable.

if __name__ == '__main__':
  os.environ["NEURON_CC_FLAGS"] = "--auto-cast all --auto-cast-type bf16"
  train(batch_size=8, num_workers=1)

This results (unsurprisingly) in a similar training speed to the OpenXLA BFloat16 experiment described above.

FP8

One of the unique features of NeuronCoreV2 is its support for the eight-bit floating point type, fp8_e4m3. The code block below demonstrates how to configure the Neuron compiler to automatically cast all floating-point operations to FP8:

if __name__ == '__main__':
  os.environ["NEURON_CC_FLAGS"] = "--auto-cast all --auto-cast-type fp8_e4m3"
  train(batch_size=8, num_workers=1)

While FP8 can accelerate training in some cases, maintaining stable convergence can be more challenging than when using BFloat16 due to its reduced precision and dynamic range. Please see our previous post for more on the potential benefits and challenges of FP8 training.

In the case of our model, using FP8 actually harms runtime performance compared to BFloat16, reducing the training speed to 940.36 samples per second.

Compiler Optimizations

The Neuron compiler includes a number of controls for optimizing the runtime performance of the compiled graph. Two key settings are model-type and opt-level. The model-type setting applies optimizations tailored to specific model architectures, such as transformers, while the opt-level setting allows for balancing compilation time against runtime performance. In the code block below, we set model-type to transformer and opt-level to the highest performance option. We further specify the target runtime device, inf2, to ensure that the model is optimized for the target device.

if __name__ == '__main__':
  os.environ['XLA_USE_BF16'] = '1'
  # parentheses let the adjacent string literals span several lines
  os.environ["NEURON_CC_FLAGS"] = ("--model-type transformer "
                                   "--optlevel 3 "
                                   "--target inf2")
  train(batch_size=8, num_workers=1)

The above configuration resulted in a training speed of 1093.25 samples per second, amounting to a modest 6% improvement.
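Composing compiler flags from adjacent string literals makes it easy to drop a space or break the syntax. A small, hypothetical helper that joins keyword options into a NEURON_CC_FLAGS string is sketched below; it only formats the flags that appear in this post and does not validate them against the Neuron compiler.

```python
import os


def neuron_cc_flags(**options):
    """Join keyword options into a flag string (hypothetical helper).

    Underscores in option names become hyphens, e.g. model_type -> --model-type.
    """
    return " ".join(f"--{name.replace('_', '-')} {value}"
                    for name, value in options.items())


flags = neuron_cc_flags(model_type="transformer", optlevel=3, target="inf2")
# set before calling train(), as in the snippets above
os.environ["NEURON_CC_FLAGS"] = flags
print(flags)  # --model-type transformer --optlevel 3 --target inf2
```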

Results

We summarize the results of our experiments in the table below. Keep in mind that the effect of each of the optimization methods we discussed will depend greatly on the model and the runtime environment.

Experiment Results (by Author)

The techniques we employed resulted in a roughly 4.3x speed-up (1,093.25 vs. 251.98 samples per second) compared to our baseline experiment. It is likely that additional acceleration could be achieved by revisiting and fine-tuning some of the methods we discussed, or by applying other optimization techniques not covered in this post.
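For reference, the throughputs quoted throughout this post can be tabulated and expressed relative to the baseline with a few lines of Python. The Neuron compiler BFloat16 experiment is left out since no exact figure is given for it, and PyTorch AMP and FP8 were evaluated as alternatives rather than retained in the final configuration.

```python
# Throughputs (samples/second) as quoted in the text of this post.
results = [
    ("Baseline", 251.98),
    ("Multi-process data loading", 253.56),
    ("Batch size 8", 265.68),
    ("PyTorch AMP", 196.64),
    ("XLA_USE_BF16", 394.51),
    ("batches_per_execution=10", 1027.39),
    ("FP8 auto-cast", 940.36),
    ("Compiler optimizations", 1093.25),
]

baseline = results[0][1]
for name, throughput in results:
    # speed-up expressed as a multiple of the baseline throughput
    print(f"{name:28} {throughput:8.2f}  {throughput / baseline:5.2f}x")
```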

Our goal has been to demonstrate some of the available optimization strategies and their potential impact on runtime performance. However, in a real-world scenario, we would need to assess how each of these optimizations impacts model convergence. In some cases, adjustments to the model configuration may be necessary to ensure optimal performance without sacrificing accuracy. Additionally, using a performance profiler to identify bottlenecks and measure system resource utilization is essential for guiding and informing our optimization activities.

Summary

Nowadays, we are fortunate to have a wide variety of systems on which to run our ML workloads. No matter which platform we choose, our goal is to maximize its capabilities. In this post, we focused on AWS Inferentia and reviewed several techniques for accelerating ML workloads running on it. Be sure to check out our other posts for more optimization strategies across various AI accelerators.

The post AI Model Optimization on AWS Inferentia and Trainium appeared first on Towards Data Science.

]]>
Implementing Sequential Algorithms on TPU https://towardsdatascience.com/implementing-sequential-algorithms-on-tpu-41d75c6aaa95/ Mon, 07 Oct 2024 20:50:21 +0000 https://towardsdatascience.com/implementing-sequential-algorithms-on-tpu-41d75c6aaa95/ Accelerating AI/ML Model Training with Custom Operators - Part 3.A

The post Implementing Sequential Algorithms on TPU appeared first on Towards Data Science.

]]>
This is a direct sequel to a previous post on the topic of implementing custom TPU operations with Pallas. Of particular interest are custom kernels that leverage the unique properties of the TPU architecture in a manner that optimizes runtime performance. In this post, we will attempt to demonstrate this opportunity by applying the power of Pallas to the challenge of running sequential algorithms that are interspersed within a predominantly parallelizable deep learning (DL) workload.

We will focus on Non-Maximum Suppression (NMS) of bounding-box proposals as a representative algorithm, and explore ways to optimize its implementation. An important component of computer vision (CV) object detection solutions (e.g., Mask R-CNN), NMS is commonly used to filter out overlapping bounding boxes, keeping only the "best" ones. NMS receives a list of bounding box proposals, an associated list of scores, and an intersection-over-union (IOU) threshold, and proceeds to greedily and iteratively choose the remaining box with the highest score and disqualify all other boxes with which it has an IOU that exceeds the given threshold. The fact that the box chosen at the n-th iteration depends on the preceding n-1 steps of the algorithm dictates the sequential nature of its implementation. Please see [here](https://medium.com/analytics-vidhya/non-max-suppression-nms-6623e6572536) and/or here for more on the rationale behind NMS and its implementation. Although we have chosen to focus on one specific algorithm, most of our discussion should carry over to other sequential algorithms.
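To make the IOU computation and the greedy selection concrete, here is a pure-Python sketch on three hand-picked boxes (the boxes, scores, and the 0.5 threshold are illustrative choices). Box 2 overlaps box 0 heavily (IOU = 4/6) and is suppressed, while box 1 overlaps box 0 only slightly (IOU = 1/7) and survives.

```python
def iou(box_a, box_b, epsilon=1e-5):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes."""
    inter_w = max(min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]), 0)
    inter_h = max(min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]), 0)
    inter = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / max(area_a + area_b - inter, epsilon)


def greedy_nms(boxes, scores, threshold):
    """Repeatedly keep the best remaining box and drop boxes overlapping it."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)
        keep.append(best)
        # disqualify boxes whose IOU with the chosen box exceeds the threshold
        order = [i for i in order if iou(boxes[best], boxes[i]) <= threshold]
    return keep


boxes = [(0, 0, 2, 2), (1, 1, 3, 3), (0, 0, 2, 3)]
scores = [0.9, 0.8, 0.7]
print(iou(boxes[0], boxes[1]))                   # 0.14285714285714285 (= 1/7)
print(greedy_nms(boxes, scores, threshold=0.5))  # [0, 1]
```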

Offloading Sequential Algorithms to CPU

The presence of a sequential algorithm within a predominantly parallelizable ML model (e.g., Mask R-CNN) presents an interesting challenge. While GPUs, commonly used for such workloads, excel at executing parallel operations like matrix multiplication, they can significantly underperform compared to CPUs when handling sequential algorithms. This often leads to computation graphs that include crossovers between the GPU and CPU, where the GPU handles the parallel operations and the CPU handles the sequential ones. NMS is a prime example of a sequential algorithm that is commonly offloaded onto the CPU. In fact, a close analysis of torchvision’s "CUDA" implementation of NMS reveals that even it runs a significant portion of the algorithm on CPU.

Although offloading sequential operations to the CPU may lead to improved runtime performance, there are several potential drawbacks to consider:

  1. Cross-device execution between the CPU and GPU usually requires multiple points of synchronization between the devices, which commonly results in idle time on the GPU while it waits for the CPU to complete its tasks. Given that the GPU is typically the most expensive component of the training platform, our goal is to minimize such idle time.
  2. In standard ML workflows, the CPU is responsible for preparing and feeding data to the model, which resides on the GPU. If the data input pipeline involves compute-intensive processing, this can strain the CPU, leading to "input starvation" on the GPU. In such scenarios, offloading portions of the model’s computation to the CPU could further exacerbate this issue.

To avoid these drawbacks, you could consider alternative approaches, such as replacing the sequential algorithm with a comparable alternative (e.g., the one suggested here), settling for a slow/suboptimal GPU implementation of the sequential algorithm, or running the workload on CPU – each of which comes with its own potential trade-offs.

Sequential Algorithms on TPU

This is where the unique architecture of the TPU could present an opportunity. Contrary to GPUs, TPUs are sequential processors. While their ability to run highly vectorized operations makes them competitive with GPUs when running parallelizable operations such as matrix multiplication, their sequential nature could make them uniquely suited for running ML workloads that include a mix of both sequential and parallel components. Armed with the Pallas extension to JAX, our newfound TPU kernel creation tool, we will evaluate this opportunity by implementing a custom NMS operator for TPU and measuring its performance.

Disclaimers

The NMS implementations we will share below are intended for demonstrative purposes only. We have not made any significant effort to optimize them or to verify their robustness, durability, or accuracy. Please keep in mind that, as of the time of this writing, Pallas is an experimental feature – still under active development. The code we share (based on JAX version 0.4.32) may become outdated by the time you read this. Be sure to refer to the most up-to-date APIs and resources available for your Pallas development. Please do not view our mention of any algorithm, library, or API as an endorsement for their use.

NMS on CPU

We begin with a simple implementation of NMS in numpy that will serve as a baseline for performance comparison:

import numpy as np

def nms_cpu(boxes, scores, max_output_size, threshold=0.1):
    epsilon = 1e-5

    # Convert bounding boxes and scores to numpy
    boxes = np.array(boxes)
    scores = np.array(scores)

    # coordinates of bounding boxes
    start_x = boxes[:, 0]
    start_y = boxes[:, 1]
    end_x = boxes[:, 2]
    end_y = boxes[:, 3]

    # Compute areas of bounding boxes
    areas = (end_x - start_x) * (end_y - start_y)

    # Sort by confidence score of bounding boxes
    order = np.argsort(scores)

    # Picked bounding boxes
    picked_boxes = []

    # Iterate over bounding boxes
    while order.size > 0 and len(picked_boxes) < max_output_size:

        # The index of the remaining box with the highest score
        index = order[-1]

        # Pick the bounding box with largest confidence score
        picked_boxes.append(index.item())

        # Compute coordinates of intersection
        x1 = np.maximum(start_x[index], start_x[order[:-1]])
        x2 = np.minimum(end_x[index], end_x[order[:-1]])
        y1 = np.maximum(start_y[index], start_y[order[:-1]])
        y2 = np.minimum(end_y[index], end_y[order[:-1]])

        # Compute areas of intersection and union
        w = np.maximum(x2 - x1, 0.0)
        h = np.maximum(y2 - y1, 0.0)

        intersection = w * h
        union = areas[index] + areas[order[:-1]] - intersection

        # Compute the ratio between intersection and union
        ratio = intersection / np.maximum(union, epsilon)

        # discard boxes above overlap threshold
        keep = np.where(ratio < threshold)
        order = order[keep]

    return picked_boxes

To evaluate the performance of our NMS function, we generate a batch of random boxes and scores (as JAX tensors) and run the script on a Google Cloud TPU v5e system using the same environment and same benchmarking utility as in our previous post. For this experiment, we specify the CPU as the JAX default device:

import jax
from jax import random
import jax.numpy as jnp

def generate_random_boxes(run_on_cpu = False):
    if run_on_cpu:
        jax.config.update('jax_default_device', jax.devices('cpu')[0])
    else:
        jax.config.update('jax_default_device', jax.devices('tpu')[0])

    n_boxes = 1024
    img_size = 1024

    k1, k2, k3 = random.split(random.key(0), 3)

    # Randomly generate box sizes and positions
    box_sizes = random.randint(k1,
                               shape=(n_boxes, 2),
                               minval=1,
                               maxval=img_size)
    top_left = random.randint(k2,
                              shape=(n_boxes, 2),
                              minval=0,
                              maxval=img_size - 1)
    bottom_right = jnp.clip(top_left + box_sizes, 0, img_size - 1)

    # Concatenate top-left and bottom-right coordinates
    rand_boxes = jnp.concatenate((top_left, bottom_right),
                                 axis=1).astype(jnp.bfloat16)
    rand_scores = jax.random.uniform(k3, 
                                     shape=(n_boxes,),
                                     minval=0.0,
                                     maxval=1.0)

    return rand_boxes, rand_scores

rand_boxes, rand_scores = generate_random_boxes(run_on_cpu=True)

time = benchmark(nms_cpu)(rand_boxes, rand_scores, max_output_size=128)
print(f'nms_cpu: {time}')

The resultant average runtime is 2.99 milliseconds. Note the assumption that the input and output tensors reside on the CPU. If they are on the TPU, then the time to copy them between the devices should also be taken into consideration.
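The benchmark utility used above comes from our previous post and is not reproduced here. A minimal stand-in is sketched below; it is a hypothetical helper rather than the original utility, and it assumes the wrapped function returns either a JAX array (which exposes block_until_ready, needed because JAX dispatches work asynchronously) or a plain Python value. It imports perf_counter directly because the snippets above rebind the name time to the measured result.

```python
from time import perf_counter


def _block(result):
    # Wait for asynchronously dispatched JAX work before reading the clock.
    if hasattr(result, "block_until_ready"):
        result.block_until_ready()
    return result


def benchmark(fn, warmup=3, iters=100):
    """Return a wrapper that reports fn's average runtime in milliseconds."""
    def run(*args, **kwargs):
        for _ in range(warmup):  # warm-up calls also absorb JIT compilation
            _block(fn(*args, **kwargs))
        start = perf_counter()
        for _ in range(iters):
            _block(fn(*args, **kwargs))
        return (perf_counter() - start) * 1000 / iters
    return run
```

With this helper, benchmark(nms_cpu)(rand_boxes, rand_scores, max_output_size=128) returns the average time per call in milliseconds.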

NMS on TPU

If our NMS function is a component within a larger computation graph running on the TPU, we might prefer a TPU-compatible implementation to avoid the drawbacks of cross-device execution. The code block below contains a JAX implementation of NMS specifically designed to enable acceleration via JIT compilation. Denoting the number of boxes by N, we begin by calculating the IOU between each of the N(N-1) pairs of boxes and preparing an NxN boolean tensor (mask_threshold) where the (i,j)-th entry indicates whether the IOU between boxes i and j exceeds the predefined threshold.

To simplify the iterative selection of boxes, we create a copy of the mask tensor (mask_threshold2) where the diagonal elements are zeroed to prevent a box from suppressing itself. We further define two score-tracking tensors: out_scores, which retains the scores of the chosen boxes (and zeros the scores of the eliminated ones), and remaining_scores, which maintains the scores of the boxes still being considered. We then use the jax.lax.while_loop function to iteratively choose boxes while updating the out_scores and remaining_scores tensors. Note that the format of the output of this function differs from that of the previous function and may need to be adjusted to fit into subsequent steps of the computation graph.

import functools

# Given N boxes, calculates mask_threshold an NxN boolean mask
# where the (i,j) entry indicates whether the IOU of boxes i and j
# exceed the threshold. Returns mask_threshold, mask_threshold2
# which is equivalent to mask_threshold with zero diagonal and
# the scores modified so that all values are greater than 0
def init_tensors(boxes, scores, threshold=0.1):
    epsilon = 1e-5

    # Extract left, top, right, bottom coordinates
    left = boxes[:, 0]
    top = boxes[:, 1]
    right = boxes[:, 2]
    bottom = boxes[:, 3]

    # Compute areas of boxes
    areas = (right - left) * (bottom - top)

    # Calculate intersection points
    inter_l = jnp.maximum(left[None, :], left[:, None])
    inter_t = jnp.maximum(top[None, :], top[:, None])
    inter_r = jnp.minimum(right[None, :], right[:, None])
    inter_b = jnp.minimum(bottom[None, :], bottom[:, None])

    # Width, height, and area of the intersection
    inter_w = jnp.clip(inter_r - inter_l, 0)
    inter_h = jnp.clip(inter_b - inter_t, 0)
    inter_area = inter_w * inter_h

    # Union of the areas
    union = areas[None, :] + areas[:, None] - inter_area

    # IoU calculation
    iou = inter_area / jnp.clip(union, epsilon)

    # Shift scores to be greater than zero
    out_scores = scores - jnp.min(scores) + epsilon

    # Create mask based on IoU threshold
    mask_threshold = iou > threshold

    # Create mask excluding diagonal (i.e., self IoU is ignored)
    mask_threshold2 = mask_threshold * (1-jnp.eye(mask_threshold.shape[0],
                                                  dtype=mask_threshold.dtype))

    return mask_threshold, mask_threshold2, out_scores

@functools.partial(jax.jit, static_argnames=['max_output_size', 'threshold'])
def nms_jax(boxes, scores, max_output_size, threshold=0.1):
    # initialize mask and score tensors
    mask_threshold, mask_threshold2, out_scores = init_tensors(boxes,
                                                           scores,
                                                           threshold)

    # The out_scores tensor will retain the scores of the chosen boxes
    # and zero the scores of the eliminated ones
    # remaining_scores will maintain non-zero scores for boxes that
    # have not been chosen or eliminated
    remaining_scores = out_scores.copy()

    def choose_box(state):
        i, remaining_scores, out_scores = state
        # choose index of box with highest score from remaining scores
        index = jnp.argmax(remaining_scores)
        # check validity of chosen box
        valid = remaining_scores[index] > 0
        # If valid, zero all scores with IOU greater than threshold
        # (including the chosen index)
        remaining_scores = jnp.where(mask_threshold[index] *valid,
                                     0,
                                     remaining_scores)
        # zero the scores of the eliminated tensors (not including
        # the chosen index)
        out_scores = jnp.where(mask_threshold2[index]*valid,
                               0,
                               out_scores)

        i = i + 1
        return i, remaining_scores, out_scores

    def cond_fun(state):
        i, _, _ = state
        return (i < max_output_size)

    i = 0
    state = (i, remaining_scores, out_scores)

    _, _, out_scores = jax.lax.while_loop(cond_fun, choose_box, state)

    # Output the resultant scores. To extract the chosen boxes,
    # take the indexes of the highest nonzero scores (at most max_output_size):
    # n_picked = jnp.minimum(jnp.count_nonzero(out_scores), max_output_size)
    # indexes = jnp.argsort(out_scores, descending=True)[:n_picked]
    return out_scores

# nms_jax can be run on either the CPU or the TPU
rand_boxes, rand_scores = generate_random_boxes(run_on_cpu=True)

time = benchmark(nms_jax)(rand_boxes, rand_scores, max_output_size=128)
print(f'nms_jax on CPU: {time}')

rand_boxes, rand_scores = generate_random_boxes(run_on_cpu=False)

time = benchmark(nms_jax)(rand_boxes, rand_scores, max_output_size=128)
print(f'nms_jax on TPU: {time}')

The runtimes of this implementation of NMS are 1.231 and 0.416 milliseconds on CPU and TPU, respectively.
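Because nms_jax returns scores rather than indices, it can be useful to check that the mask-based loop selects the same boxes as the straightforward greedy algorithm. The sketch below mirrors the loop logic of nms_jax in NumPy (a hypothetical re-implementation, not the JAX code itself), converts the resulting scores into indices, and compares them against a simple greedy reference on random boxes. The conversion relies on the fact that every chosen box outranks any box that was never processed, so the top nonzero entries of out_scores are exactly the chosen boxes.

```python
import numpy as np


def iou_matrix(boxes):
    """Pairwise IOU for an (N, 4) array of (x1, y1, x2, y2) boxes."""
    x1, y1, x2, y2 = boxes.T
    areas = (x2 - x1) * (y2 - y1)
    inter_w = np.clip(np.minimum(x2[:, None], x2[None, :])
                      - np.maximum(x1[:, None], x1[None, :]), 0, None)
    inter_h = np.clip(np.minimum(y2[:, None], y2[None, :])
                      - np.maximum(y1[:, None], y1[None, :]), 0, None)
    inter = inter_w * inter_h
    union = areas[:, None] + areas[None, :] - inter
    return inter / np.maximum(union, 1e-5)


def mask_nms_scores(boxes, scores, max_output_size, threshold=0.1):
    """NumPy mirror of the nms_jax loop: returns out_scores (zero = eliminated)."""
    mask = iou_matrix(boxes) > threshold
    mask2 = mask & ~np.eye(len(scores), dtype=bool)  # a box never suppresses itself
    out_scores = scores - scores.min() + 1e-5        # shift so every score is > 0
    remaining = out_scores.copy()
    for _ in range(max_output_size):
        index = np.argmax(remaining)
        valid = remaining[index] > 0
        remaining = np.where(mask[index] & valid, 0, remaining)
        out_scores = np.where(mask2[index] & valid, 0, out_scores)
    return out_scores


def scores_to_indices(out_scores, max_output_size):
    """Indices of the chosen boxes, best first."""
    k = min(int(np.count_nonzero(out_scores)), max_output_size)
    return [int(i) for i in np.argsort(-out_scores, kind="stable")[:k]]


def greedy_nms_reference(boxes, scores, max_output_size, threshold=0.1):
    """Straightforward greedy NMS used as a sanity check."""
    ious = iou_matrix(boxes)
    order = [int(i) for i in np.argsort(-scores, kind="stable")]
    picked = []
    while order and len(picked) < max_output_size:
        best = order.pop(0)
        picked.append(best)
        order = [i for i in order if ious[best, i] <= threshold]
    return picked


rng = np.random.default_rng(0)
top_left = rng.integers(0, 900, size=(64, 2))
sizes = rng.integers(1, 124, size=(64, 2))
boxes = np.concatenate([top_left, top_left + sizes], axis=1).astype(np.float64)
scores = rng.uniform(size=64)

picked = scores_to_indices(mask_nms_scores(boxes, scores, 16), 16)
reference = greedy_nms_reference(boxes, scores, 16)
print(picked == reference)
```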

Custom NMS Pallas Kernel

We now present a custom implementation of NMS in which we explicitly leverage the fact that, on TPUs, the grid steps of a Pallas kernel are executed sequentially. Our implementation uses two boolean matrix masks and two score-keeping tensors, similar to the approach in our previous function.

We define a kernel function, choose_box, responsible for selecting the next box and updating the score-keeping tensors, which are maintained in scratch memory. We invoke the kernel across a one-dimensional grid where the number of steps (i.e., the grid size) is determined by the max_output_size parameter.

Note that due to some limitations (as of the time of this writing) on the operations supported by Pallas, some acrobatics are required to implement both the "argmax" function and the validity check for the selected boxes. For the sake of brevity, we omit the technical details and refer the interested reader to the comments in the code below.

from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# argmax helper function
def pallas_argmax(scores, n_boxes):
    # we assume that the index of each box is stored in the
    # least significant bits of the score (see below)
    idx = jnp.max(scores.astype(float)).astype(int) % n_boxes
    return idx

# Pallas kernel definition
def choose_box(scores, thresh_mask1, thresh_mask2, ret_scores,
               scores_scratch, remaining_scores_scratch, *, nsteps, n_boxes):
    # initialize scratch memory on first step
    @pl.when(pl.program_id(0) == 0)
    def _():
        scores_scratch[...] = scores[...]
        remaining_scores_scratch[...] = scores[...]

    remaining_scores = remaining_scores_scratch[...]

    # choose box
    idx = pallas_argmax(remaining_scores, n_boxes)

    # we use any to verify validity of the chosen box due
    # to limitations on indexing in pallas
    valid = (remaining_scores>0).any()

    # updating score tensors
    remaining_scores_scratch[...] = jnp.where(thresh_mask1[idx,...]*valid,
                                              0,
                                              remaining_scores)
    scores_scratch[...] = jnp.where(thresh_mask2[idx,...]*valid,
                                    0,
                                    scores_scratch[...])

    # set return value on final step
    @pl.when(pl.program_id(0) == nsteps - 1)
    def _():
        ret_scores[...] = scores_scratch[...]

@functools.partial(jax.jit, static_argnames=['max_output_size', 'threshold'])
def nms_pallas(boxes, scores, max_output_size, threshold=0.1):
    n_boxes = scores.size
    mask_threshold, mask_threshold2, scores = init_tensors(boxes, 
                                                           scores,
                                                           threshold)

    # In order to work around the Pallas argmax limitation
    # we create a new scores tensor with the same ordering as
    # the input scores tensor in which the index of each score
    # in the ordering is encoded in the least significant bits
    sorted = jnp.argsort(scores, descending=True)

    # descending integers: n_boxes, ..., 2, 1 (starting at 1 keeps
    # every encoded score strictly positive, since a zero score
    # marks an eliminated box)
    descending = jnp.flip(jnp.arange(1, n_boxes + 1))

    # new scores in descending order, with the least significant
    # bits carrying the argsort of the input scores
    ordered_scores = n_boxes * descending + sorted

    # new scores with same ordering as input scores
    scores = jnp.empty_like(ordered_scores
                            ).at[sorted].set(ordered_scores)

    grid = (max_output_size,)
    return pl.pallas_call(
        functools.partial(choose_box, 
                          nsteps=max_output_size,
                          n_boxes=n_boxes),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec(block_shape=(n_boxes,)),
                pl.BlockSpec(block_shape=(n_boxes, n_boxes)),
                pl.BlockSpec(block_shape=(n_boxes, n_boxes)),
            ],
            out_specs=pl.BlockSpec(block_shape=(n_boxes,)),
            scratch_shapes=[pltpu.VMEM((n_boxes,), scores.dtype),
                            pltpu.VMEM((n_boxes,), scores.dtype)],
            grid=grid,
        ),
        out_shape=jax.ShapeDtypeStruct((n_boxes,), scores.dtype),
        compiler_params=dict(mosaic=dict(
            dimension_semantics=("arbitrary",)))
    )(scores, mask_threshold, mask_threshold2)

rand_boxes, rand_scores = generate_random_boxes(run_on_cpu=False)

time = benchmark(nms_pallas)(rand_boxes, rand_scores, max_output_size=128)
print(f'nms_pallas: {time}')
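The argsort workaround at the top of nms_pallas can be checked on the host. The NumPy sketch below is a transcription we wrote for illustration: `encode_scores` is a hypothetical helper name, and NumPy stands in for jax.numpy. It rebuilds the encoded scores for four boxes. Each encoded value carries the box's rank in its high part and the box's own index in its low part (the value modulo n_boxes). A plain max over the encoded tensor is therefore presumably enough for the kernel to recover the index of the highest-scoring box without calling argsort.

```python
import numpy as np

def encode_scores(scores: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy transcription of the encoding step in nms_pallas."""
    n_boxes = scores.size
    # indices of the scores from highest to lowest (stable for ties)
    order = np.argsort(-scores, kind="stable")
    # descending integers: n_boxes-1, ..., 1, 0
    descending = np.arange(n_boxes)[::-1]
    # rank in the high part, index of the box in the low part
    ordered_scores = n_boxes * descending + order
    # scatter back so that each box keeps its original position
    encoded = np.empty_like(ordered_scores)
    encoded[order] = ordered_scores
    return encoded

scores = np.array([0.2, 0.9, 0.5, 0.7])
encoded = encode_scores(scores)
print(encoded)                            # [ 0 13  6 11]
print(int(encoded.max()) % scores.size)   # 1, the index of the 0.9 score
```

The largest encoded value is n_boxes**2 - 1, so the integer dtype must be wide enough for the number of boxes. With int32, for example, that limits the box count to 46,340.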

The average runtime of our custom NMS operator is 0.139 milliseconds, making it roughly three times faster than our JAX-native implementation. This result highlights the potential of tailoring the implementation of sequential algorithms to the unique properties of the TPU architecture.

Note that in our Pallas kernel implementation, we load the full input tensors into TPU VMEM memory. Given the limited capacity of VMEM, scaling up the input size (i.e., increasing the number of bounding boxes) will likely lead to memory issues. Typically, such limitations can be addressed by chunking the inputs with BlockSpecs. Unfortunately, applying this approach would break the current NMS implementation. Implementing NMS across input chunks would require a different design, which is beyond the scope of this post.
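To see why the full-tensor approach stops scaling, we can tally what the pallas_call in nms_pallas places in VMEM. That includes the scores vector, two n_boxes x n_boxes mask tensors, the output vector, and two scratch vectors. Below is a back-of-the-envelope sketch. It assumes 4-byte elements throughout, which is an assumption because init_tensors is not shown here and so the mask dtype is unknown. It also ignores any extra buffering the compiler may add.

```python
def nms_vmem_bytes(n_boxes: int, itemsize: int = 4) -> int:
    """Rough VMEM footprint of nms_pallas's operands (assumes itemsize bytes per element)."""
    vectors = 4 * n_boxes          # scores input, output, and two scratch buffers
    matrices = 2 * n_boxes ** 2    # mask_threshold and mask_threshold2
    return (vectors + matrices) * itemsize

for n in (1024, 2048, 4096):
    print(n, f"{nms_vmem_bytes(n) / 2**20:.1f} MiB")
```

The two mask tensors dominate the total, so doubling the number of boxes roughly quadruples the footprint.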

Results

The results of our experiments are summarized in the table below:

Results of NMS experiments (lower is better) - by Author

These results demonstrate the potential for running full ML computation graphs on TPU, even when they include sequential components. The performance improvement demonstrated by our Pallas NMS operator, in particular, highlights the opportunity of customizing kernels in a way that leverages the TPU’s strengths.

Summary

In our previous post we learned of the opportunity for building custom TPU operators using the Pallas extension for JAX. Maximizing this opportunity requires tailoring the kernel implementations to the specific properties of the TPU architecture. In this post, we focused on the sequential nature of the TPU processor and its use in optimizing a custom NMS kernel. While scaling the solution to support an unrestricted number of bounding boxes would require further work, the core principles we have discussed remain applicable.

Pallas is still in the experimental phase of its development, and some of its limitations may require creative workarounds. But its strength and potential are clearly evident, and we anticipate that they will only increase as the framework matures.

The post Implementing Sequential Algorithms on TPU appeared first on Towards Data Science.

The Rise of Pallas: Unlocking TPU Potential with Custom Kernels

Accelerating AI/ML Model Training with Custom Operators - Part 3

https://towardsdatascience.com/the-rise-of-pallas-unlocking-tpu-potential-with-custom-kernels-67be10ab846a/ (Sun, 06 Oct 2024 09:16:53 +0000)

The post The Rise of Pallas: Unlocking TPU Potential with Custom Kernels appeared first on Towards Data Science.

This is the third part of a series of posts on the topic of building custom operators for optimizing AI/ML workloads. In our previous post we demonstrated the simplicity and accessibility of Triton. Named for the Greek god of the sea, Triton empowers Python developers to increase their control over the GPU and optimize its use for the specific workload at hand. In this post we move one step down the lineage of Greek mythology to Triton’s daughter, Pallas, and discuss her namesake, the JAX extension for writing custom kernels for GPU and TPU.

One of the most important features of NVIDIA GPUs – and a significant factor in their rise to prominence – is their programmability. A key ingredient of the GPU offering is the availability of frameworks for creating General-Purpose GPU (GPGPU) operators, such as CUDA and Triton.

In previous posts (e.g., here) we discussed the opportunity for running ML workloads on Google TPUs and the potential for a meaningful increase in price performance and a reduction in training costs. One of the disadvantages that we noted at the time was the absence of tools for creating custom operators. As a result, models requiring unique operators that were either unsupported by the underlying ML framework (e.g., TensorFlow/XLA) or implemented in a suboptimal manner would underperform on TPU compared to GPU. This development gap was particularly noticeable over the past few years with the frequent introduction of newer and faster solutions for computing attention on GPU. Enabled by GPU kernel development frameworks, these led to a significant improvement in the efficiency of transformer models.

On TPUs, on the other hand, the lack of appropriate tooling prevented this innovation, and transformer models were stuck with the attention mechanisms that were supported by the official software stack. Fortunately, with the advent of Pallas, this gap has been addressed. Built as an extension to JAX and with dedicated support for PyTorch/XLA, Pallas enables the creation of custom kernels for GPU and TPU. For its GPU support Pallas utilizes Triton, and for its TPU support it uses a library called Mosaic. Although we will focus on custom kernels for TPU, it is worth noting that when developing in JAX, GPU kernel customization with Pallas offers some advantages over Triton (e.g., see here).

Our intention in this post is to draw attention to Pallas and demonstrate its potential. Please do not view this post as a replacement for the official Pallas documentation. The examples we will share were chosen for demonstrative purposes only. We have made no effort to optimize these or verify their robustness, durability, or accuracy.

Importantly, at the time of this writing Pallas is an experimental feature and still under active development. The samples we share (which are based on JAX version 0.4.32 and PyTorch version 2.4.1) may become outdated by the time you read this. Be sure to use the most up-to-date APIs and resources available for your Pallas development.

Many thanks to Yitzhak Levi for his contributions to this post.

Environment Setup

For the experiments described below we use the following environment setup commands:

# create TPU node
gcloud alpha compute tpus queued-resources create v5litepod-1-resource \
     --node-id v5litepod \
     --project <project-id> \
     --zone us-central1-a \
     --accelerator-type v5litepod-1 \
     --runtime-version v2-alpha-tpuv5-lite \
     --valid-until-duration 1d \
     --service-account <service-account>

# check TPU node status (wait for state to be ACTIVE)
gcloud alpha compute tpus queued-resources describe v5litepod-1-resource \
     --project <project-id> \
     --zone us-central1-a

# SSH to TPU node
gcloud alpha compute tpus tpu-vm ssh v5litepod \
     --project <project-id> \
     --zone us-central1-a

# install dependencies
pip install torch_xla[tpu] \
     -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas]
pip install timm

# run tests
python train.py

# exit ssh
exit

# delete TPU node
gcloud alpha compute tpus queued-resources delete v5litepod-1-resource \
     --project <project-id> \
     --zone us-central1-a --force --quiet

Pallas Kernels for TPU

In the toy example of our first post in this series, we distinguished between two different ways in which custom kernel development can potentially boost performance. The first is by combining (fusing) together multiple operations in a manner that reduces the overhead of: 1) loading multiple individual kernels, and 2) reading and writing intermediate values (e.g., see PyTorch’s tutorial on multiply-add fusion). The second is by meticulously applying the resources of the underlying accelerator in a manner that optimizes the function at hand. We briefly discuss these two opportunities as they pertain to developing custom TPU kernels and make note of the limitations of the Pallas support.

Operator Fusion on TPU

The TPU is an XLA (Accelerated Linear Algebra) device, i.e., it runs code that has been generated by the XLA compiler. When training an AI model in a framework such as JAX or PyTorch/XLA, the training step is first transformed into an intermediate graph representation (IR). This computation graph is then fed to the XLA compiler, which converts it into machine code that can run on the TPU. In contrast to eager execution mode, in which operations are executed individually, this mode of running models enables XLA to identify and implement opportunities for operator fusion during compilation. And, in fact, operator fusion is the XLA compiler’s most important optimization. Naturally, no compiler is perfect and we are certain to come across additional opportunities for fusion through custom kernels. But, generally speaking, we might expect the opportunity for boosting runtime performance in this manner to be lower than in the case of eager execution.
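Simple arithmetic gives an estimate of how much fusion saves by avoiding intermediate reads and writes. The sketch below counts memory traffic for out = x * y + z under an idealized model: there is no caching, and every kernel reads its inputs from HBM and writes its output back to HBM. Without fusion, the multiply writes a temporary that the add then reads back. With fusion, the temporary never leaves the chip. The function name and the traffic model are our own illustration, not an XLA API.

```python
def mul_add_traffic_bytes(n_elems: int, itemsize: int = 4, fused: bool = False) -> int:
    """Idealized HBM traffic for out = x * y + z."""
    if fused:
        # one kernel: read x, y, z; write out
        transfers = 4
    else:
        # kernel 1: read x, y; write tmp. kernel 2: read tmp, z; write out
        transfers = 6
    return transfers * n_elems * itemsize

n = 1 << 20  # about one million float32 elements
unfused = mul_add_traffic_bytes(n)
fused = mul_add_traffic_bytes(n, fused=True)
print(f"unfused: {unfused / 2**20:.0f} MiB, fused: {fused / 2**20:.0f} MiB")
```

In this model, fusion cuts traffic by a third and saves one kernel launch. XLA usually performs this kind of fusion by itself, which is why a custom kernel has less to gain here than it would under eager execution.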

Optimizing TPU Utilization

Creating optimal kernels for TPU requires a comprehensive and intimate understanding of the TPU system architecture. Importantly, TPUs are very different from GPUs: expertise in GPUs and CUDA does not immediately carry over to TPU development. For example, while GPUs contain a large number of processors and draw their strength from their ability to perform massive parallelization, TPUs are primarily sequential with dedicated engines for running highly vectorized operations and support for asynchronous scheduling and memory loading.

The differences between the underlying architectures of the GPU and TPU can have significant implications for how custom kernels should be designed. Mastering TPU kernel development requires 1) appropriate overlapping of memory and compute operations via pipelining, 2) knowing how to mix the use of the scalar, vector (VPU) and matrix (MXU) compute units and their associated scalar and vector registers (SREG and VREG) and memory caches (SMEM and VMEM), 3) a comprehension of the costs of different low-level operations, 4) appropriate megacore configuration (on supporting TPU generations), 5) a grasp of the different types of TPU topologies and their implications for how to support distributed computing, and more.
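To make point 1) concrete, consider an idealized model of a kernel that processes its input in blocks. Loading a block into VMEM takes t_load, and computing on it takes t_compute. Without overlap, the two costs simply add up. With double buffering, block i+1 is loaded while block i is being computed, so after the first load each step costs only the slower of the two. The sketch below is a simplified timing model of our own: it ignores DMA setup costs and output write-back, and it does not describe how Mosaic actually schedules work.

```python
def sequential_time(n_blocks: int, t_load: float, t_compute: float) -> float:
    """Each block is loaded, then computed, with no overlap."""
    return n_blocks * (t_load + t_compute)

def pipelined_time(n_blocks: int, t_load: float, t_compute: float) -> float:
    """Double buffering: loading block i+1 overlaps computing block i."""
    if n_blocks == 0:
        return 0.0
    # the first load cannot be hidden; the last compute has nothing to overlap with
    return t_load + (n_blocks - 1) * max(t_load, t_compute) + t_compute

print(sequential_time(4, 2.0, 3.0))  # 20.0
print(pipelined_time(4, 2.0, 3.0))   # 14.0
```

As the number of blocks grows, the pipelined time approaches n_blocks * max(t_load, t_compute). At that point the kernel is bound by whichever of memory or compute is slower.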

Framework Limitations

While the ability to create custom operators in Python using JAX functions and APIs greatly increases the simplicity and accessibility of Pallas kernel development, it also limits its expressivity. Additionally (as of the time of this writing), there are some JAX APIs that are not supported by Pallas on TPU (e.g., see here). As a result, you may approach Pallas with the intention of implementing a particular operation only to discover that the framework does not support the APIs that you need. This is in contrast to frameworks such as CUDA, which enable a great deal of flexibility when developing custom kernels (for GPU).

The matrix multiplication tutorial in the Pallas documentation provides an excellent introduction to Pallas kernel development, highlighting the potential for operator fusion and customization alongside the challenges involved in optimizing performance (e.g., appropriate tuning of the input block size). The tutorial clearly illustrates that maximizing the full potential of the TPU requires a certain degree of specialization. However, as we intend to demonstrate, even the novice ML developer can benefit from Pallas kernels.

Integrating the Use of Existing Pallas Kernels

To benefit from custom Pallas kernels you do not necessarily need to know how to build them. In our first example we demonstrate how you can leverage existing Pallas kernels from dedicated public repositories.

Example – Flash Attention in Torch/XLA

The JAX GitHub repository includes implementations of a number of Pallas kernels, including [flash attention](https://github.com/pytorch/xla/blob/v2.4.0/torch_xla/experimental/custom_kernel.py#L425). Here we will demonstrate its use in a Torch/XLA Vision Transformer (ViT) model. Although Pallas kernels are developed in JAX, they can be adopted into Torch/XLA, e.g., via the make_kernel_from_pallas utility (see the documentation for details). In the case of flash attention, the adoption is implemented by Torch/XLA.

In the following code block we define a stripped-down version of the classic timm attention block with an option to define the underlying attention operator in the constructor. We will use this option to compare the performance of the flash attention Pallas kernel to its alternatives.

# general imports
import os, time, functools
# torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch_xla.core.xla_model as xm
# custom kernel import
from torch_xla.experimental.custom_kernel import flash_attention
# timm imports
from timm.layers import Mlp
from timm.models.vision_transformer import VisionTransformer

class TPUAttentionBlock(nn.Module):
    def __init__(
            self,
            dim: int = 768,
            num_heads: int = 12,
            attn_fn = None,
            **kwargs
    ) -> None:
        super().__init__()
        self.attn_fn = attn_fn
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=dim * 4,
        )

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x_in)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if self.attn_fn is None:
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        else:
            x = self.attn_fn(q, k, v)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = x + x_in
        x = x + self.mlp(self.norm2(x))
        return x

In the following block we train a simple ViT-backed classification model using the input dataset and attention function (attn_fn) of choice.

def train(dataset, attn_fn=None):
    device = xm.xla_device()

    train_loader = DataLoader(
        dataset,
        batch_size=128,
        num_workers=os.cpu_count(),
        pin_memory=True
    )

    # configure the VisionTransformer in a manner that complies with the
    # Pallas flash_attention kernel constraints
    model = VisionTransformer(
        block_fn=functools.partial(TPUAttentionBlock, attn_fn=attn_fn),
        img_size=256,
        class_token=False,
        global_pool="avg"
    )

    optimizer = torch.optim.SGD(model.parameters())
    loss_fn = torch.nn.CrossEntropyLoss()

    # copy the model to the TPU
    model = model.to(device)

    model.train()

    t0 = time.perf_counter()
    summ = 0
    count = 0

    for step, data in enumerate(train_loader):
        # copy data to TPU
        inputs = data[0].to(device=device, non_blocking=True)
        label = data[1].to(device=device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.autocast('xla', dtype=torch.bfloat16):
            output = model(inputs)
            loss = loss_fn(output, label)
        loss.backward()
        optimizer.step()
        xm.mark_step()

        # capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step > 100:
            break

    print(f'average step time: {summ / count}')

Note the specific configuration we chose for the VisionTransformer. This is to comply with certain restrictions (as of the time of this writing) of the custom flash attention kernel (e.g., on tensor shapes).
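One way to see what this configuration achieves is to count tokens. The sketch below assumes timm's default ViT settings (patch_size=16, embed_dim=768, num_heads=12). It also assumes that the kernel tiles the sequence axis in blocks of 128, which is our reading of the shape restriction rather than a documented guarantee. With the default 224x224 input and a class token, the sequence length is 197. With img_size=256 and class_token=False, it becomes 256, a multiple of 128.

```python
def vit_seq_len(img_size: int, patch_size: int = 16, class_token: bool = True) -> int:
    """Number of tokens a ViT feeds into its attention blocks."""
    n_patches = (img_size // patch_size) ** 2
    return n_patches + (1 if class_token else 0)

SEQ_BLOCK = 128        # assumed tile size along the sequence axis
head_dim = 768 // 12   # embed_dim // num_heads, as computed in TPUAttentionBlock

for img_size, class_token in [(224, True), (256, False)]:
    n = vit_seq_len(img_size, class_token=class_token)
    print(img_size, class_token, n, n % SEQ_BLOCK == 0)
print("head_dim:", head_dim)
```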

Finally, we define a dataset and compare the runtimes of training with three different attention routines: 1) native PyTorch functions, 2) PyTorch’s built-in SDPA function, and 3) the custom Pallas operator:

# use random data
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        rand_image = torch.randn([3, 256, 256], dtype=torch.float32)
        label = torch.tensor(data=index % 1024, dtype=torch.int64)
        return rand_image, label

ds = FakeDataset()

print('PyTorch native')
train(ds, attn_fn=None)

print('PyTorch SDPA')
train(ds, attn_fn=functools.partial(F.scaled_dot_product_attention, scale=1.0))

print('Pallas flash_attention')
train(ds, attn_fn=flash_attention)
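Note the scale=1.0 passed to F.scaled_dot_product_attention. By default, SDPA multiplies q @ k^T by 1/sqrt(head_dim), whereas the native branch of TPUAttentionBlock applies no scaling. The explicit scale therefore presumably keeps the SDPA variant computing the same function as the native one. The NumPy sketch below is a reference implementation of our own, not PyTorch's kernel. It checks that equivalence on random tensors shaped (batch, heads, tokens, head_dim).

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def sdpa_reference(q, k, v, scale=None):
    """Attention on (..., tokens, head_dim) arrays. scale=None uses 1/sqrt(head_dim),
    mirroring the default of F.scaled_dot_product_attention."""
    if scale is None:
        scale = 1.0 / np.sqrt(q.shape[-1])
    return softmax(scale * (q @ np.swapaxes(k, -2, -1))) @ v

def native_branch(q, k, v):
    """The attn_fn=None path of TPUAttentionBlock: no scaling of q @ k^T."""
    return softmax(q @ np.swapaxes(k, -2, -1)) @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 12, 16, 64)) for _ in range(3))

print(np.allclose(sdpa_reference(q, k, v, scale=1.0), native_branch(q, k, v)))  # True
print(np.allclose(sdpa_reference(q, k, v), native_branch(q, k, v)))             # False
```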

The comparative results are captured in the table below:

Step time for different attention blocks (lower is better) - by Author

Although our Pallas kernel clearly underperforms when compared to its alternatives, we should not be discouraged:

  1. It is likely that these results could be improved with appropriate tuning.
  2. These results are specific to the model and runtime environment that we chose. The Pallas kernel may exhibit wholly different comparative results in other use cases.
  3. The real power of Pallas is in the ability to create and adjust low level operators to our specific needs. Although runtime performance is important, a 23% performance penalty (as in our example) may be a small price to pay for this flexibility. Moreover, the opportunity for customization may open up possibilities for optimizations that are not supported by the native framework operations.

Enhancing Existing Kernels

Oftentimes it may be easier to tweak an existing Pallas kernel to your specific needs, rather than creating one from scratch. This is especially recommended if the kernel has already been optimized, as performance tuning can be tedious and time-consuming. The official matrix multiplication tutorial includes a few examples of how to extend and enhance an existing kernel. Here we undertake one of the suggested exercises: we implement int8 matrix multiplication and assess its performance advantage over its bfloat16 alternative.

Example – Int8 Matrix Multiplication

In the code block below we implement an int8 version of the matrix multiplication example.

import functools, timeit
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# set to True to develop/debug on CPU
interpret = False

def matmul_kernel_int8(x_ref, y_ref, z_ref, acc_ref, *, nsteps):
    @pl.when(pl.program_id(2) == 0)
    def _():
        acc_ref[...] = jnp.zeros_like(acc_ref)

    acc_ref[...] += jnp.dot(
        x_ref[...], y_ref[...], preferred_element_type=jnp.int32
    )

    @pl.when(pl.program_id(2) == nsteps - 1)
    def _():
        z_ref[...] = acc_ref[...]

@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
def matmul_int8(
        x: jax.Array,
        y: jax.Array,
        *,
        bm: int = 128,
        bk: int = 128,
        bn: int = 128,
):
    m, k = x.shape
    _, n = y.shape
    return pl.pallas_call(
        functools.partial(matmul_kernel_int8, nsteps=k // bk),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec(block_shape=(bm, bk), 
                             index_map=lambda i, j, k: (i, k)),
                pl.BlockSpec(block_shape=(bk, bn),
                             index_map=lambda i, j, k: (k, j)),
            ],
            out_specs=pl.BlockSpec(block_shape=(bm, bn), 
                                   index_map=lambda i, j, k: (i, j)),
            scratch_shapes=[pltpu.VMEM((bm, bn), jnp.int32)],
            grid=(m // bm, n // bn, k // bk),
        ),
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32),
        compiler_params=dict(mosaic=dict(
            dimension_semantics=("parallel", "parallel", "arbitrary"))),
        interpret=interpret
    )(x, y)
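The grid and block sizes determine how the work is divided. The benchmark below uses an 8192x8192x8192 problem with bm=512, bk=1024, bn=1024. For that case, the sketch below computes the grid, the number of reduction steps (nsteps), and the bytes occupied by one set of blocks. It uses plain arithmetic, and the helper names are ours. One set of blocks consists of the int8 x and y blocks, the int32 output block, and the int32 accumulator. The Pallas pipeline may keep extra copies of the input and output blocks for double buffering, so treat the total as a lower bound.

```python
def matmul_grid(m, k, n, bm, bk, bn):
    """Grid used by matmul_int8: (row blocks, column blocks, reduction steps)."""
    assert m % bm == 0 and k % bk == 0 and n % bn == 0, "blocks must tile the inputs"
    return (m // bm, n // bn, k // bk)

def block_bytes(bm, bk, bn, in_itemsize=1, acc_itemsize=4):
    """Bytes for one x block, one y block, one output block and the accumulator."""
    x_block = bm * bk * in_itemsize
    y_block = bk * bn * in_itemsize
    out_block = bm * bn * acc_itemsize
    acc = bm * bn * acc_itemsize
    return x_block + y_block + out_block + acc

grid = matmul_grid(8192, 8192, 8192, bm=512, bk=1024, bn=1024)
print("grid:", grid, "nsteps:", grid[2])
print(f"one set of blocks: {block_bytes(512, 1024, 1024) / 2**20:.1f} MiB")
```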

Note our use of an int32 accumulation matrix to guard against overflow. Also note our use of the interpret flag for debugging Pallas kernels on CPU (as recommended here).
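How much headroom does the int32 accumulator give? The largest magnitude of a single int8 product is (-128) * (-128) = 2**14, so the absolute value of a dot product of length K is bounded by K * 2**14. The sketch below works out that bound for the K = 8192 used in our benchmark and checks it against NumPy (the helper names are ours). The inputs are cast to int32 to mirror preferred_element_type=jnp.int32.

```python
import numpy as np

def worst_case_dot(k: int) -> int:
    """Largest |x . y| for length-k int8 vectors: every product is (-128) * (-128)."""
    return k * 128 * 128

def max_safe_k(acc_bits: int = 32) -> int:
    """Longest contraction whose worst case still fits a signed accumulator."""
    return (2 ** (acc_bits - 1) - 1) // (128 * 128)

k = 8192
x = np.full((1, k), -128, dtype=np.int8)
y = np.full((k, 1), -128, dtype=np.int8)
# accumulate in int32, as the kernel does
acc32 = (x.astype(np.int32) @ y.astype(np.int32))[0, 0]

print(int(acc32), worst_case_dot(k))          # both 134217728
print(np.iinfo(np.int16).max, np.iinfo(np.int32).max)
print("max safe K for int32:", max_safe_k())  # 131071
```

An int16 accumulator would already overflow at this size. An int32 accumulator leaves room for contraction lengths up to 131,071.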

To assess our kernel, we introduce a slight modification to the benchmarking utilities defined in the tutorial and compare the runtime results to both the bfloat16 Pallas matmul kernel and the built-in JAX matmul API:

def benchmark(f, ntrials: int = 100):
    def run(*args, **kwargs):
        # Compile function first
        jax.block_until_ready(f(*args, **kwargs))
        # Time function
        res = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
                            number=ntrials)
        time = res/ntrials
        # print(f"Time: {time}")
        return time

    return run

def analyze_matmul(m: int, k: int, n: int, dtype: jnp.dtype,
                   mm_func):
    x = jnp.ones((m, k), dtype=dtype)
    y = jnp.ones((k, n), dtype=dtype)
    time = benchmark(mm_func)(x, y)
    print("Matmul time: ", time)
    mm_ops = 2*m*k*n/time
    v5e_ops = 394e12 if dtype == jnp.int8 else 197e12
    print(f"OP/s utilization: {mm_ops / v5e_ops * 100:.4f}%")
    print()

print("bfloat16 Pallas matmul")
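# `matmul` is the bfloat16 kernel from the official Pallas matmul tutorial;
# it is not defined in this post
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)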
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)

print("int8 Pallas matmul")
mm = functools.partial(matmul_int8, bm=512, bk=1024, bn=1024)
analyze_matmul(8192, 8192, 8192, jnp.int8, mm)

print("XLA int8 matmul")
mm = functools.partial(jnp.matmul, preferred_element_type=jnp.int32)
analyze_matmul(8192, 8192, 8192, jnp.int8, mm)
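The utilization figure printed by analyze_matmul compares achieved throughput with the TPU v5e peak values hard-coded above (197e12 for bfloat16 and 394e12 for int8). Inverting that formula gives the best possible time for each configuration, which is a useful yardstick when reading the results. A small sketch with our own helper names:

```python
def matmul_ops(m: int, k: int, n: int) -> int:
    """One multiply and one add per (row, column, reduction) triple."""
    return 2 * m * k * n

def ideal_time(m: int, k: int, n: int, peak_ops: float) -> float:
    """Runtime at 100% utilization of the given peak rate."""
    return matmul_ops(m, k, n) / peak_ops

def utilization_pct(m: int, k: int, n: int, seconds: float, peak_ops: float) -> float:
    """Same formula as analyze_matmul."""
    return matmul_ops(m, k, n) / seconds / peak_ops * 100

dims = (8192, 8192, 8192)
for name, peak in [("bfloat16", 197e12), ("int8", 394e12)]:
    print(f"{name}: ideal time {ideal_time(*dims, peak) * 1e3:.2f} ms")
```

For example, an int8 run that takes twice its ideal time corresponds to 50% utilization.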

The results of our experiment are captured in the table below:

Matmul time and utilization (by Author)

By using int8 matrices (rather than bfloat16 matrices) on TPU v5e, we can boost the runtime performance of our custom matrix multiplication kernel by 71%. However, as in the case of the bfloat16 example, additional tuning is required to match the performance of the built-in matmul operator. The potential for improvement is highlighted by the drop in system utilization when compared to bfloat16.

Creating a Kernel from Scratch

While leveraging existing kernels can be greatly beneficial, it is unlikely to solve all of your problems. Inevitably, you may need to implement an operation that is either unsupported on TPU or exhibits suboptimal performance. Here we demonstrate the creation of a relatively simple pixel-wise kernel. For the sake of continuity, we choose the same Generalized Intersection Over Union (GIOU) operation as in our previous posts.

Example – A GIOU Pallas Kernel

In the code block below we define a Pallas kernel that implements GIOU on pairs of batches of bounding boxes, each of dimension BxNx4 (where we denote the batch size by B and the number of boxes per sample by N). The function returns a tensor of scores of dimension BxN. We choose a block size of 128 on both the batch axis and the boxes axis. In other words, we divide each of the input tensors into blocks of 128x128x4. After we unstack the last axis, each such block reaches our kernel function as four 128x128 blocks, one per box coordinate. The grid and BlockSpec index_map are defined accordingly.

import timeit
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

# set to True to develop/debug on CPU
interpret = False

# perform giou on a single block
def giou_kernel(preds_left_ref,
                preds_top_ref,
                preds_right_ref,
                preds_bottom_ref,
                targets_left_ref,
                targets_top_ref,
                targets_right_ref,
                targets_bottom_ref,
                output_ref):
    epsilon = 1e-5

    # copy tensors into local memory
    preds_left = preds_left_ref[...]
    preds_top = preds_top_ref[...]
    preds_right = preds_right_ref[...]
    preds_bottom = preds_bottom_ref[...]

    gt_left = targets_left_ref[...]
    gt_top = targets_top_ref[...]
    gt_right = targets_right_ref[...]
    gt_bottom = targets_bottom_ref[...]

    # Compute the area of each box
    area1 = (preds_right - preds_left) * (preds_bottom - preds_top)
    area2 = (gt_right - gt_left) * (gt_bottom - gt_top)

    # Compute the intersection
    left = jnp.maximum(preds_left, gt_left)
    top = jnp.maximum(preds_top, gt_top)
    right = jnp.minimum(preds_right, gt_right)
    bottom = jnp.minimum(preds_bottom, gt_bottom)

    # intersection width and height
    inter_w = jnp.maximum(right - left, 0)
    inter_h = jnp.maximum(bottom - top, 0)

    # intersection area
    inter_area = inter_w * inter_h

    # union of two boxes
    union_area = area1 + area2 - inter_area

    iou_val = inter_area / jnp.maximum(union_area, epsilon)

    # Compute the smallest enclosing box
    enclose_left = jnp.minimum(preds_left, gt_left)
    enclose_top = jnp.minimum(preds_top, gt_top)
    enclose_right = jnp.maximum(preds_right, gt_right)
    enclose_bottom = jnp.maximum(preds_bottom, gt_bottom)

    # enclosing box width and height
    enclose_w = jnp.maximum(enclose_right - enclose_left, 0)
    enclose_h = jnp.maximum(enclose_bottom - enclose_top, 0)

    # enclosing box area
    enclose_area = enclose_w * enclose_h

    # Compute GIOU
    delta_area = (enclose_area - union_area)
    enclose_area = jnp.maximum(enclose_area, epsilon)
    output_ref[...] = iou_val - delta_area / enclose_area

@jax.jit
def batch_giou(preds, targets):
    m, n, _ = preds.shape
    output = pl.pallas_call(
        giou_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), preds.dtype),
        in_specs=[pl.BlockSpec(block_shape=(128, 128),
                               index_map=lambda i, j: (i, j))]*8,
        out_specs=pl.BlockSpec(block_shape=(128, 128),
                                index_map=lambda i, j: (i, j)),
        grid=(m // 128, n // 128),
        compiler_params=dict(mosaic=dict(
            dimension_semantics=("parallel", "parallel"))),
        interpret=interpret
    )(*jnp.unstack(preds, axis=-1), *jnp.unstack(targets, axis=-1))
    return output
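The blocking scheme of batch_giou can be mimicked on the host to confirm that every (batch, box) element is visited exactly once. In the NumPy sketch below, taking preds[..., c] stands in for jnp.unstack(preds, axis=-1). The hypothetical `block` helper returns the 128x128 tile that a BlockSpec with index_map=lambda i, j: (i, j) would hand to the kernel at grid position (i, j). Note that Pallas index maps return block indices, not element offsets. With B=1024 and N=256, as in the benchmark below, the grid is (8, 2).

```python
import numpy as np

B, N, BLOCK = 1024, 256, 128
rng = np.random.default_rng(0)
preds = rng.random((B, N, 4), dtype=np.float32)

# counterpart of jnp.unstack(preds, axis=-1): four (B, N) coordinate arrays
coords = [preds[..., c] for c in range(preds.shape[-1])]

# one grid cell per 128x128 tile, as in batch_giou
grid = (B // BLOCK, N // BLOCK)

def block(arr: np.ndarray, i: int, j: int) -> np.ndarray:
    """Hypothetical helper: the tile that index_map=lambda i, j: (i, j)
    selects at grid position (i, j); block indices are scaled by BLOCK."""
    return arr[i * BLOCK:(i + 1) * BLOCK, j * BLOCK:(j + 1) * BLOCK]

# stitching the tiles back together in grid order recovers the full array,
# so each (batch, box) element is handled by exactly one kernel invocation
rebuilt = np.block([[block(coords[0], i, j) for j in range(grid[1])]
                    for i in range(grid[0])])

print("grid:", grid, "tile shape:", block(coords[0], 0, 0).shape)
print("tiles cover the array exactly:", np.array_equal(rebuilt, coords[0]))
```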

Although the creation of a new TPU kernel is certainly cause for celebration (especially if it enables a previously blocked ML workload), our work is not done. A critical part of Pallas kernel development is tuning the operator (e.g., the block size) for optimal runtime performance. We omit this stage in the interest of brevity.
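A first step in such tuning is to enumerate the block shapes worth trying. The sketch below lists (bm, bn) pairs that evenly tile a (B, N) array. It assumes the usual TPU constraint that the last two block dimensions be multiples of 8 and 128 for 32-bit data; check the Pallas TPU documentation for the exact rules in your version. Timing each candidate would also require parameterizing batch_giou on the block size (for example, as a static argument), which the version above does not do.

```python
def divisors(n: int):
    return [d for d in range(1, n + 1) if n % d == 0]

def candidate_blocks(m: int, n: int, sublane: int = 8, lane: int = 128):
    """(bm, bn) pairs that evenly tile an (m, n) array and satisfy the assumed
    (8, 128) divisibility rule for the last two block dimensions."""
    bms = [d for d in divisors(m) if d % sublane == 0]
    bns = [d for d in divisors(n) if d % lane == 0]
    return [(bm, bn) for bm in bms for bn in bns]

cands = candidate_blocks(1024, 256)
print(len(cands), cands[:4])
```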

To assess the performance of our kernel, we compare it to the following native JAX GIOU implementation:

def batched_box_iou(boxes1, boxes2):
    epsilon = 1e-5

    # Compute areas of both sets of boxes
    area1 = (boxes1[..., 2]-boxes1[..., 0])*(boxes1[..., 3]-boxes1[..., 1])
    area2 = (boxes2[..., 2]-boxes2[..., 0])*(boxes2[..., 3]-boxes2[..., 1])

    # corners of intersection
    lt = jnp.maximum(boxes1[..., :2], boxes2[..., :2])
    rb = jnp.minimum(boxes1[..., 2:], boxes2[..., 2:])

    # width and height of intersection
    wh = jnp.clip(rb - lt, a_min=0)

    # area of the intersection
    inter = wh[..., 0] * wh[..., 1]

    # union of the two boxes
    union = area1 + area2 - inter
    iou = inter / jnp.clip(union, a_min=epsilon)

    # corners of enclosing box
    lti = jnp.minimum(boxes1[..., :2], boxes2[..., :2])
    rbi = jnp.maximum(boxes1[..., 2:], boxes2[..., 2:])

    # Width and height of the enclosing box
    whi = jnp.clip(rbi - lti, a_min=0)

    # Area of the enclosing box
    areai = jnp.clip(whi[..., 0] * whi[..., 1], a_min=epsilon)

    # Generalized IoU
    return iou - (areai - union) / areai
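As a sanity check on the formula GIOU = IoU - (area(C) - area(A ∪ B)) / area(C), where C is the smallest box enclosing A and B, here is a scalar NumPy transcription of batched_box_iou. It is our own version, for a single pair of boxes in (x1, y1, x2, y2) format. Take A = (0, 0, 2, 2) and B = (1, 1, 3, 3). The intersection is 1, the union is 7, and the enclosing box has area 9, so GIOU = 1/7 - 2/9 = -5/63.

```python
import numpy as np

def giou_pair(b1, b2, epsilon: float = 1e-5) -> float:
    """GIOU of two boxes in (x1, y1, x2, y2) format, mirroring batched_box_iou."""
    b1, b2 = np.asarray(b1, dtype=np.float64), np.asarray(b2, dtype=np.float64)
    area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
    # width and height of the intersection (zero if the boxes do not overlap)
    wh = np.clip(np.minimum(b1[2:], b2[2:]) - np.maximum(b1[:2], b2[:2]), 0, None)
    inter = wh[0] * wh[1]
    union = area1 + area2 - inter
    iou = inter / max(union, epsilon)
    # width and height of the smallest enclosing box
    whi = np.clip(np.maximum(b1[2:], b2[2:]) - np.minimum(b1[:2], b2[:2]), 0, None)
    enclose = max(whi[0] * whi[1], epsilon)
    return float(iou - (enclose - union) / enclose)

print(giou_pair([0, 0, 2, 2], [1, 1, 3, 3]))  # about -0.0794, i.e. -5/63
```

Identical boxes give exactly 1, and the score approaches -1 as disjoint boxes move apart.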

We generate two batches of random bounding boxes and measure the performance of our functions using the benchmark function defined above.

from jax import random

batch_size = 1024
n_boxes = 256
img_size = 256
boxes = []
for i in range(2):
    k1, k2 = random.split(random.key(i), 2)

    # Randomly generate box sizes and positions
    box_sizes = random.randint(k1, shape=(batch_size, n_boxes, 2), minval=1, maxval=img_size)
    top_left = random.randint(k2, shape=(batch_size, n_boxes, 2), minval=0, maxval=img_size - 1)
    bottom_right = jnp.clip(top_left + box_sizes, 0, img_size - 1)

    # Concatenate top-left and bottom-right coordinates
    rand_boxes = jnp.concatenate((top_left, bottom_right), axis=2)

    boxes.append(rand_boxes.astype(jnp.float32))

time = benchmark(batch_giou)(boxes[0], boxes[1])
print(f'Pallas kernel: {time}')
time = benchmark(batched_box_iou)(boxes[0], boxes[1])
print(f'JAX function: {time}')
time = benchmark(jax.jit(batched_box_iou))(boxes[0], boxes[1])
print(f'Jitted function: {time}')

The comparative results appear in the table below:

Avg time of different GIOU implementations (lower is better) - by Author

We can see that JIT-compiling our naive JAX implementation results in slightly better performance than our Pallas kernel. Once again, matching or surpassing the performance of JIT compilation (and its inherent kernel fusion) would require fine-tuning our custom kernel.

Utilizing the Sequential Nature of TPUs

While the ability to develop custom kernels for TPU offers great potential, our examples thus far have demonstrated that reaching optimal runtime performance could be challenging. One way to overcome this is to seek opportunities to utilize the unique properties of the TPU architecture. One example of this is the sequential nature of the TPU processor. Although deep learning workloads tend to rely on operations that are easily parallelizable (e.g., matrix multiplication), on occasion they require algorithms that are inherently sequential. These can pose a serious challenge for the SIMT (single instruction, multiple threads) model of GPUs and can sometimes have a disproportionate impact on runtime performance. In a sequel to this post, we demonstrate how we can implement sequential algorithms in a way that takes advantage of the TPU’s sequential processor and in a manner that minimizes their performance penalty.
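Non-maximum suppression (NMS), the subject of that sequel, is a good example of such an algorithm. The NumPy sketch below is a plain greedy reference version of our own, not the TPU kernel. Each iteration selects the highest-scoring box that is still alive and suppresses every box that overlaps it by more than the threshold. Which box is chosen next depends on what the earlier iterations suppressed, so the loop cannot simply be parallelized across iterations. The parameter names echo nms_pallas, but the function itself is hypothetical.

```python
import numpy as np

def iou_matrix(boxes: np.ndarray) -> np.ndarray:
    """Pairwise IoU for boxes in (x1, y1, x2, y2) format."""
    x1, y1, x2, y2 = boxes.T
    area = (x2 - x1) * (y2 - y1)
    w = np.minimum(x2[:, None], x2[None, :]) - np.maximum(x1[:, None], x1[None, :])
    h = np.minimum(y2[:, None], y2[None, :]) - np.maximum(y1[:, None], y1[None, :])
    inter = np.clip(w, 0, None) * np.clip(h, 0, None)
    union = area[:, None] + area[None, :] - inter
    return inter / np.maximum(union, 1e-5)

def greedy_nms(boxes, scores, max_output_size, threshold=0.1):
    """Return indices of the kept boxes, highest score first."""
    iou = iou_matrix(boxes)
    alive = np.ones(len(scores), dtype=bool)
    keep = []
    for _ in range(max_output_size):
        if not alive.any():
            break
        # highest-scoring box that has not been suppressed yet
        best = int(np.argmax(np.where(alive, scores, -np.inf)))
        keep.append(best)
        # this step's suppressions determine every later choice
        alive &= iou[best] <= threshold
        alive[best] = False
    return keep

boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
print(greedy_nms(boxes, scores, max_output_size=3, threshold=0.5))  # [0, 2]
```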

Summary

The introduction of Pallas marks an important milestone in the evolution of TPUs. By enabling customization of TPU operations it can potentially unlock new opportunities for TPU programmability, particularly in the world of ML. Our intention in this post was to demonstrate the accessibility of this powerful new feature. While our examples have indeed shown this, they have also highlighted the effort required to reach optimal runtime performance.

This post has merely scratched the surface of Pallas kernel development. Be sure to see the official documentation to learn more about automatic differentiation in Pallas, developing sparse kernels, and more.


Training AI Models on CPU

Revisiting CPU for ML in an Era of GPU Scarcity

https://towardsdatascience.com/training-ai-models-on-cpu-3903adc9f388/ (Sun, 01 Sep 2024 18:59:40 +0000)

The post Training AI Models on CPU appeared first on Towards Data Science.

The recent successes in AI are often attributed to the emergence and evolution of the GPU. The GPU’s architecture, which typically includes thousands of multi-processors, high-speed memory, dedicated tensor cores, and more, is particularly well-suited to meet the intensive demands of AI/ML workloads. Unfortunately, the rapid growth in AI development has led to a surge in the demand for GPUs, making them difficult to obtain. As a result, ML developers are increasingly exploring alternative hardware options for training and running their models. In previous posts, we discussed the possibility of training on dedicated AI ASICs such as Google Cloud TPU, Habana Gaudi, and AWS Trainium. While these options offer significant cost-saving opportunities, they do not suit all ML models and can, like the GPU, also suffer from limited availability. In this post we return to the good old-fashioned CPU and revisit its relevance to ML applications. Although CPUs are generally less suited to ML workloads compared to GPUs, they are much easier to acquire. The ability to run (at least some of) our workloads on CPU could have significant implications for development productivity.

In previous posts (e.g., here) we emphasized the importance of analyzing and optimizing the runtime performance of AI/ML workloads as a means of accelerating development and minimizing costs. While this is crucial regardless of the compute engine used, the profiling tools and optimization techniques can vary greatly between platforms. In this post, we will discuss some of the performance optimization options that pertain to CPU. Our focus will be on Intel® Xeon® processors (with Intel® AVX-512) and on the PyTorch (version 2.4) framework (although similar techniques can be applied to other CPUs and frameworks as well). More specifically, we will run our experiments on an Amazon EC2 c7i instance with an AWS Deep Learning AMI. Please do not view our choice of Cloud platform, CPU version, ML framework, or any other tool or library we mention, as an endorsement over their alternatives.

Our goal will be to demonstrate that although ML development on CPU may not be our first choice, there are ways to "soften the blow" and – in some cases – perhaps even make it a viable alternative.

Disclaimers

Our intention in this post is to demonstrate just a few of the ML optimization opportunities available on CPU. Unlike most of the online tutorials on the topic of ML optimization on CPU, we will focus on a training workload rather than an inference workload. There are a number of optimization tools focused specifically on inference that we will not cover (e.g., see here (https://pytorch.org/blog/accelerated-cpu-inference/) and here).

Please do not view this post as a replacement for the official documentation on any of the tools or techniques that we mention. Keep in mind that given the rapid pace of AI/ML development, some of the content, libraries, and/or instructions that we mention may become outdated by the time you read this. Please be sure to refer to the most up-to-date documentation available.

Importantly, the impact of the optimizations that we discuss on runtime performance is likely to vary greatly based on the model and the details of the environment (e.g., see the high degree of variance between models on the official PyTorch TorchInductor CPU Inference Performance Dashboard). The comparative performance numbers we will share are specific to the toy model and runtime environment that we will use. Be sure to reevaluate all of the proposed optimizations on your own model and runtime environment.

Lastly, our focus will be solely on throughput performance (as measured in samples per second) – not on training convergence. However, it should be noted that some optimization techniques (e.g., batch size tuning, mixed precision, and more) could have a negative effect on the convergence of certain models. In some cases, this can be overcome through appropriate hyperparameter tuning.

Toy Example – ResNet-50

We will run our experiments on a simple image classification model with a ResNet-50 backbone (from Deep Residual Learning for Image Recognition). We will train the model on a fake dataset. The full training script appears in the code block below (loosely based on this example):

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import time

# A dataset with random images and labels
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
        label = torch.tensor(data=index % 10, dtype=torch.uint8)
        return rand_image, label

train_set = FakeDataset()

batch_size=128
num_workers=0

train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    num_workers=num_workers
)

model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters())
model.train()

t0 = time.perf_counter()
summ = 0
count = 0

for idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    batch_time = time.perf_counter() - t0
    if idx > 10:  # skip first steps
        summ += batch_time
        count += 1
    t0 = time.perf_counter()
    if idx > 100:
        break

print(f'average step time: {summ/count}')
print(f'throughput: {count*batch_size/summ}')

Running this script on a c7i.2xlarge (with 8 vCPUs) with the CPU version of PyTorch 2.4 results in a throughput of 9.12 samples per second. For the sake of comparison, we note that the throughput of the same (unoptimized) script on an Amazon EC2 g5.2xlarge instance (with 1 GPU and 8 vCPUs) is 340 samples per second. Taking into account the comparative costs of these two instance types ($0.357 per hour for a c7i.2xlarge and $1.212 for a g5.2xlarge, as of the time of this writing), we find that training on the GPU instance gives roughly eleven(!!) times better price performance. Based on these results, the preference for using GPUs to train ML models is very well founded. Let’s assess some of the possibilities for reducing this gap.
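The price-performance comparison above is simple arithmetic: throughput converted to samples per hour, divided by the hourly price. The sketch below reproduces it using only the throughput figures and On-Demand prices quoted in this post; the helper name samples_per_dollar is our own, hypothetical choice.

```python
def samples_per_dollar(throughput: float, hourly_price: float) -> float:
    """Samples processed per dollar, for a throughput in samples/sec and a price in $/hour."""
    return throughput * 3600 / hourly_price


# Unoptimized-script throughput and On-Demand prices quoted above
cpu = samples_per_dollar(9.12, 0.357)   # c7i.2xlarge
gpu = samples_per_dollar(340, 1.212)    # g5.2xlarge
print(f"c7i.2xlarge: {cpu:,.0f} samples/$")
print(f"g5.2xlarge:  {gpu:,.0f} samples/$")
print(f"GPU advantage: {gpu / cpu:.1f}x")
```

This works out to roughly 92 thousand samples per dollar on the CPU instance versus about 1.01 million on the GPU instance, a ratio of about 11.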

PyTorch Performance Optimizations

In this section we will explore some basic methods for increasing the runtime performance of our training workload. Although you may recognize some of these from our post on GPU optimization, it is important to highlight a significant difference between training optimization on CPU and GPU platforms. On GPU platforms much of our effort was dedicated to maximizing the parallelization between (the training data preprocessing on) the CPU and (the model training on) the GPU. On CPU platforms all of the processing occurs on the CPU and our goal will be to allocate its resources most effectively.

Batch Size

Increasing the training batch size can potentially increase performance by reducing the frequency of the model parameter updates. (On GPUs it has the added benefit of reducing the overhead of CPU-GPU transactions such as kernel loading.) However, while on GPU we aimed for a batch size that would maximize the utilization of the GPU memory, the same strategy might hurt performance on CPU. For reasons beyond the scope of this post, CPU memory is more complicated, and the best approach for discovering the optimal batch size may be through trial and error. Keep in mind that changing the batch size could affect training convergence.
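Since the optimal batch size is found by trial and error, a small timing harness makes the sweep repeatable. The following is a minimal sketch, assuming a hypothetical RandomImages dataset (a stand-in for FakeDataset above) and a hypothetical measure_throughput helper; like the original script, it skips the first few (warm-up) steps and averages over the rest.

```python
import time

import torch
from torch.utils.data import DataLoader, Dataset


class RandomImages(Dataset):
    """Hypothetical stand-in for FakeDataset: random images with 10 cyclic labels."""

    def __init__(self, length: int, shape=(3, 224, 224)):
        self.length, self.shape = length, shape

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        return torch.randn(self.shape), index % 10


def measure_throughput(model, batch_size, dataset, warmup=2, steps=5, num_workers=0):
    """Average training throughput (samples/sec), excluding the first `warmup` steps."""
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    model.train()
    elapsed, samples, measured = 0.0, 0, 0
    t0 = time.perf_counter()
    for idx, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        t1 = time.perf_counter()
        if idx >= warmup:  # step time includes data loading, as in the original script
            elapsed += t1 - t0
            samples += data.shape[0]
            measured += 1
            if measured == steps:
                break
        t0 = t1
    if measured == 0:
        raise ValueError("dataset too small for the requested warm-up and measured steps")
    return samples / elapsed
```

For example, calling measure_throughput(torchvision.models.resnet50(), bs, RandomImages(bs * 20)) for each candidate bs would produce the kind of sweep summarized in the table below, although the absolute numbers will depend on your instance.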

The table below summarizes the throughput of our training workload for a few (arbitrary) choices of batch size:

Training Throughput as Function of Batch Size (by Author)

Contrary to our findings on GPU, on the c7i.2xlarge instance type our model appears to prefer lower batch sizes.

Multi-process Data Loading

A common technique on GPUs is to assign multiple processes to the data loader so as to reduce the likelihood of starvation of the GPU. On GPU platforms, a general rule of thumb is to set the number of workers according to the number of CPU cores. However, on CPU platforms, where the model training uses the same resources as the data loader, this approach could backfire. Once again, the best approach for choosing the optimal number of workers may be trial and error. The table below shows the average throughput for different choices of num_workers:
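Because data-loading workers and the model's intra-op threads compete for the same cores, it can help to measure the data-loading side in isolation before sweeping num_workers in the full training loop. A minimal sketch, with a hypothetical RandomImages stand-in for FakeDataset and a hypothetical loader_throughput helper (worker start-up time is deliberately included in the measurement):

```python
import time

import torch
from torch.utils.data import DataLoader, Dataset


class RandomImages(Dataset):
    """Hypothetical stand-in for FakeDataset: random images with 10 cyclic labels."""

    def __init__(self, length: int, shape=(3, 224, 224)):
        self.length, self.shape = length, shape

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        return torch.randn(self.shape), index % 10


def loader_throughput(num_workers, batch_size=32, num_batches=20, shape=(3, 224, 224)):
    """Samples/sec delivered by the DataLoader alone (no model), including worker start-up."""
    dataset = RandomImages(batch_size * num_batches, shape)
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    samples = 0
    t0 = time.perf_counter()
    for data, _ in loader:
        samples += data.shape[0]
    return samples / (time.perf_counter() - t0)
```

Comparing loader_throughput(n) for a few values of n against torch.get_num_threads() (the size of PyTorch's intra-op thread pool) gives a rough sense of how much of the CPU the loader consumes; the end-to-end training throughput remains the number that matters.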

Training Throughput as Function of the Number of Data Loading Workers (by Author)

Mixed Precision

Another popular technique is to use lower-precision floating-point datatypes such as torch.float16 or torch.bfloat16, with the dynamic range of torch.bfloat16 generally considered to be more amenable to ML training. Naturally, reducing the datatype precision can have adverse effects on convergence and should be done carefully. PyTorch comes with torch.amp, an automatic mixed precision package for optimizing the use of these datatypes. Intel® AVX-512 includes support for the bfloat16 datatype. The modified training step appears below:

for idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    with torch.amp.autocast('cpu',dtype=torch.bfloat16):
        output = model(data)
        loss = criterion(output, target)
    loss.backward()
    optimizer.step()

The throughput following this optimization is 24.34 samples per second, an increase of 86%!!
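It can be instructive to confirm what autocast actually changes. In the following minimal sketch (a small hypothetical MLP rather than ResNet-50), eligible ops such as linear layers produce bfloat16 outputs, while the parameters, and therefore their gradients, remain float32. Because bfloat16 shares float32's exponent range, loss scaling is generally unnecessary, which is why this loop, like the one above, uses no gradient scaler.

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 10))
criterion = torch.nn.CrossEntropyLoss()
data, target = torch.randn(8, 16), torch.randint(0, 10, (8,))

with torch.amp.autocast('cpu', dtype=torch.bfloat16):
    output = model(data)              # linear layers run in bfloat16 under CPU autocast
    loss = criterion(output, target)
loss.backward()                       # backward runs outside the autocast context, as above

print(output.dtype)                                        # torch.bfloat16
print(model[0].weight.dtype, model[0].weight.grad.dtype)   # torch.float32 torch.float32
```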

Channels Last Memory Format

Channels last memory format is a beta-level optimization (at the time of this writing), pertaining primarily to vision models, that supports storing four-dimensional (NCHW) tensors in memory such that the channels are the last dimension. This results in all of the data of each pixel being stored together. Considered to be more "friendly to Intel platforms", this memory format reportedly boosts the performance of a ResNet-50 on an Intel® Xeon® CPU. The adjusted training step appears below:

for idx, (data, target) in enumerate(train_loader):
    data = data.to(memory_format=torch.channels_last)
    optimizer.zero_grad()
    with torch.amp.autocast('cpu',dtype=torch.bfloat16):
        output = model(data)
        loss = criterion(output, target)
    loss.backward()
    optimizer.step()

The resulting throughput is 37.93 samples per second – an additional 56% improvement, bringing us to 415% of our baseline throughput. We are on a roll!!
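The layout change is invisible at the level of tensor shapes and values; only the strides change. The sketch below illustrates this on a small tensor. Note that PyTorch's channels-last documentation also recommends converting the model (model.to(memory_format=torch.channels_last)); the training step above converts only the input, so converting the model is an additional option to evaluate, not something we measured.

```python
import torch

x = torch.randn(2, 3, 4, 4)                       # NCHW tensor in the default (contiguous) layout
x_cl = x.to(memory_format=torch.channels_last)    # same logical shape, channels stored innermost

print(x.shape == x_cl.shape)                      # True: the logical shape is unchanged
print(x.stride(), x_cl.stride())                  # (48, 16, 4, 1) (48, 1, 12, 3)
print(x_cl.is_contiguous(memory_format=torch.channels_last))  # True
print(torch.equal(x, x_cl))                       # True: the values are unchanged

# Optionally convert the model as well, so conv weights share the input's layout
conv = torch.nn.Conv2d(3, 8, kernel_size=3).to(memory_format=torch.channels_last)
y = conv(x_cl)
print(y.shape)                                    # torch.Size([2, 8, 2, 2])
```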

Torch Compilation

In a previous post we covered the virtues of PyTorch’s support for graph compilation and its potential impact on runtime performance. Unlike the default eager execution mode, in which each operation is run independently (a.k.a., "eagerly"), the compile API converts the model into an intermediate computation graph which is then JIT-compiled into low-level machine code in a manner that is optimal for the underlying training engine. The API supports compilation via different backend libraries and with multiple configuration options. Here we will limit our evaluation to the default (TorchInductor) backend and the ipex backend from the Intel® Extension for PyTorch, a library with dedicated optimizations for Intel hardware. Please see the documentation for appropriate installation and usage instructions. The updated model definition appears below:

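import torch
import torchvision
# importing IPEX registers the 'ipex' compile backend (only needed for backend='ipex')
import intel_extension_for_pytorch as ipex

model = torchvision.models.resnet50()
backend = 'inductor'  # optionally change to 'ipex'
model = torch.compile(model, backend=backend)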

In the case of our toy model, the impact of torch compilation is only apparent when the "channels last" optimization is disabled (an increase of ~27% for each of the backends). When "channels last" is applied, the performance actually drops. As a result, we drop this optimization from our subsequent experiments.
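Since the two optimizations did not combine well on our toy model, one option is to make each of them a separate switch so they can be benchmarked independently. A minimal sketch; RunConfig, build_model, and prepare_batch are hypothetical names, and the warning simply encodes the observation above rather than a general rule:

```python
import warnings
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class RunConfig:
    # None = eager execution; otherwise a torch.compile backend such as "inductor",
    # or "ipex" (which requires importing intel_extension_for_pytorch first)
    backend: Optional[str] = None
    channels_last: bool = False  # convert each input batch to channels-last layout

    def __post_init__(self):
        if self.backend is not None and self.channels_last:
            # Encodes the toy-model observation above; re-measure on your own model
            warnings.warn("compile + channels_last was slower on the toy model; benchmark each separately")


def build_model(model: torch.nn.Module, cfg: RunConfig) -> torch.nn.Module:
    # torch.compile is lazy: the graph is captured and compiled on the first forward call
    return torch.compile(model, backend=cfg.backend) if cfg.backend is not None else model


def prepare_batch(data: torch.Tensor, cfg: RunConfig) -> torch.Tensor:
    return data.to(memory_format=torch.channels_last) if cfg.channels_last else data
```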

Memory and Thread Optimizations

There are a number of opportunities for optimizing the use of the underlying CPU resources. These include optimizing memory management and thread allocation to the structure of the underlying CPU hardware. Memory management can be improved through the use of advanced memory allocators (such as Jemalloc and TCMalloc) and/or by reducing memory accesses that are slower (i.e., across NUMA nodes). Thread allocation can be improved through appropriate configuration of the OpenMP threading library and/or use of Intel’s OpenMP library.

Generally speaking, these kinds of optimizations require a deep understanding of the CPU architecture and the features of its supporting SW stack. To simplify matters, PyTorch offers the torch.backends.xeon.run_cpu script for automatically configuring the memory and threading libraries so as to optimize runtime performance. The command below will result in the use of the dedicated memory and threading libraries. We will return to the topic of NUMA nodes when we discuss the option of distributed training.

We verify the installation of TCMalloc (conda install conda-forge::gperftools) and Intel’s OpenMP library (pip install intel-openmp), and run the following command:

python -m torch.backends.xeon.run_cpu train.py

The use of the run_cpu script (https://pytorch.org/tutorials/recipes/xeon_run_cpu.html) further boosts our runtime performance to 39.05 samples per second. Note that the run_cpu script includes many controls for further tuning performance. Be sure to check out the documentation in order to maximize its use.
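To check what a launcher such as run_cpu actually changed, one option is to print the relevant settings from inside train.py and compare runs with and without the launcher. A minimal sketch; the helper name is hypothetical, and the environment variables listed are ones commonly associated with allocator preloading and OpenMP thread tuning, not an exhaustive list of what run_cpu may set:

```python
import os

import torch


def cpu_runtime_config() -> dict:
    """Collect threading/allocator settings that CPU launchers commonly adjust."""
    return {
        "torch_num_threads": torch.get_num_threads(),  # intra-op thread pool size
        "cpu_count": os.cpu_count(),                   # logical CPUs visible to the process
        "OMP_NUM_THREADS": os.environ.get("OMP_NUM_THREADS"),
        "KMP_AFFINITY": os.environ.get("KMP_AFFINITY"),  # Intel OpenMP thread pinning
        "LD_PRELOAD": os.environ.get("LD_PRELOAD"),      # e.g., a preloaded TCMalloc / Intel OpenMP
    }


for key, value in cpu_runtime_config().items():
    print(f"{key}: {value}")
print(torch.__config__.parallel_info())  # PyTorch's own summary of its parallel backend
```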

The Intel Extension for PyTorch

The Intel® Extension for PyTorch includes additional opportunities for training optimization via its ipex.optimize function. Here we demonstrate its default use. Please see the documentation to learn of its full capabilities.

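import torch
import torchvision
import intel_extension_for_pytorch as ipex

model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters())
model.train()
# apply IPEX's default training optimizations, targeting bfloat16
model, optimizer = ipex.optimize(
    model,
    optimizer=optimizer,
    dtype=torch.bfloat16
)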

Combined with the memory and thread optimizations discussed above, the resultant throughput is 40.73 samples per second. (Note that a similar result is reached when disabling the "channels last" configuration.)
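The snippet above shows only the model preparation. The following minimal sketch pairs it with a training step, under two assumptions: that IPEX may not be installed (in which case the model is left unchanged), and that the forward pass is wrapped in bfloat16 autocast, as in our mixed-precision loop and in IPEX's own bfloat16 training examples. prepare and train_step are hypothetical helper names.

```python
import torch

try:
    import intel_extension_for_pytorch as ipex  # optional dependency
except ImportError:
    ipex = None


def prepare(model, optimizer, dtype=torch.bfloat16):
    """Apply ipex.optimize when IPEX is installed; otherwise return the inputs unchanged."""
    model.train()
    if ipex is not None:
        model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=dtype)
    return model, optimizer


def train_step(model, optimizer, criterion, data, target, dtype=torch.bfloat16):
    """One optimization step with the forward pass under CPU autocast."""
    optimizer.zero_grad()
    with torch.amp.autocast('cpu', dtype=dtype):
        loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()
    return loss.item()
```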

Distributed Training on CPU

Intel® Xeon® processors are designed with Non-Uniform Memory Access (NUMA) in which the CPU memory is divided into groups, a.k.a., NUMA nodes, and each of the CPU cores is assigned to one node. Although any CPU core can access the memory of any NUMA node, the access to its own node (i.e., its local memory) is much faster. This gives rise to the notion of distributing training across NUMA nodes, where the CPU cores assigned to each NUMA node act as a single process in a distributed process group and data distribution across nodes is managed by Intel® oneCCL, Intel’s dedicated collective communications library.
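Before launching a NUMA-aware job it is worth confirming how many NUMA nodes the instance exposes and which cores belong to each. On Linux this information is available under /sys/devices/system/node (lscpu reports it as well). A minimal sketch that reads it directly; the helper names are hypothetical, and on non-Linux systems it simply returns an empty mapping:

```python
from pathlib import Path


def parse_cpulist(text: str) -> list:
    """Expand a Linux cpulist string such as '0-3,8-11' into a list of CPU ids."""
    cpus = []
    for part in text.strip().split(","):
        if not part:
            continue
        if "-" in part:
            start, end = part.split("-")
            cpus.extend(range(int(start), int(end) + 1))
        else:
            cpus.append(int(part))
    return cpus


def numa_topology(root: str = "/sys/devices/system/node") -> dict:
    """Map each NUMA node id to its CPU ids (Linux only; {} if the sysfs path is absent)."""
    topology = {}
    for node_dir in Path(root).glob("node[0-9]*"):
        node_id = int(node_dir.name[len("node"):])
        topology[node_id] = parse_cpulist((node_dir / "cpulist").read_text())
    return dict(sorted(topology.items()))


print(numa_topology())
```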

We can easily run distributed data-parallel training across NUMA nodes using the ipexrun utility. In the following code block (loosely based on this example) we adapt our script to run distributed data-parallel training (according to the usage detailed here):

import os, time
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torchvision
import oneccl_bindings_for_pytorch as torch_ccl  # registers the "ccl" backend used below
import intel_extension_for_pytorch as ipex

os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANK"] = os.environ.get("PMI_RANK", "0")
os.environ["WORLD_SIZE"] = os.environ.get("PMI_SIZE", "1")
dist.init_process_group(backend="ccl", init_method="env://")
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

batch_size = 128
num_workers = 0

# define dataset and dataloader
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
        label = torch.tensor(data=index % 10, dtype=torch.uint8)
        return rand_image, label

train_dataset = FakeDataset()
dist_sampler = DistributedSampler(train_dataset)
train_loader = DataLoader(
    dataset=train_dataset, 
    batch_size=batch_size,
    num_workers=num_workers,
    sampler=dist_sampler
)

# define model artifacts
model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters())
model.train()
model, optimizer = ipex.optimize(
    model, 
    optimizer=optimizer,
    dtype=torch.bfloat16
)

# configure DDP
model = torch.nn.parallel.DistributedDataParallel(model)

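# run training loop (same bfloat16 autocast step as in the single-process script)
for idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    with torch.amp.autocast('cpu', dtype=torch.bfloat16):
        output = model(data)
        loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    if idx > 100:
        break
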
# destroy the process group
dist.destroy_process_group()
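In the script above, DistributedSampler is what splits the dataset between the ranks (one per NUMA node). The following sketch shows the resulting partition on a ten-element dataset by passing num_replicas and rank explicitly, which avoids the need for an initialized process group:

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))
world_size = 2  # e.g., one rank per NUMA node on a two-socket instance

# Each rank iterates over a disjoint, interleaved subset of the dataset indices
shards = [
    list(DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False))
    for rank in range(world_size)
]
print(shards)  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
```

With the default shuffle=True, each rank's order is reshuffled between epochs only if sampler.set_epoch(epoch) is called at the start of every epoch.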

Unfortunately, as of the time of this writing, the Amazon EC2 c7i instance family does not include a multi-NUMA instance type. To test our distributed training script, we revert to an Amazon EC2 c6i.32xlarge instance with 128 vCPUs (64 physical cores) and 2 NUMA nodes. We verify the installation of Intel® oneCCL Bindings for PyTorch and run the following command (as documented here):

source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl;print(torch_ccl.cwd)")/env/setvars.sh

# This example command would utilize all the numa sockets of the processor, taking each socket as a rank.
ipexrun cpu --nnodes 1 --omp_runtime intel train.py 

The following table compares the performance results on the c6i.32xlarge instance with and without distributed training:

Distributed Training Across NUMA Nodes (by Author)

In our experiment, distributing the training across NUMA nodes did not boost the runtime performance. Please see the ipexrun documentation for additional performance tuning options.

CPU Training with Torch/XLA

In previous posts (e.g., here: https://docs.google.com/document/d/1ZzMcrjxITJeN2IjjgbzUjHh-4W1YgDUus3j25Dvn9ng/edit#heading=h.w9ztr841aqk8) we discussed the PyTorch/XLA library (https://pytorch.org/xla/release/r2.4/index.html#) and its use of XLA compilation to enable PyTorch-based training on XLA devices such as TPU, GPU, and CPU. Similar to torch compilation, XLA uses graph compilation to generate machine code that is optimized for the target device. With the establishment of the OpenXLA Project, one of the stated goals was to support high performance across all hardware backends, including CPU (see the CPU RFC here). The code block below demonstrates the adjustments to our original (unoptimized) script required to train using PyTorch/XLA:

import time
import torch
import torchvision
import torch_xla
import torch_xla.core.xla_model as xm

# FakeDataset, train_loader, and the timing logic are unchanged from the original script

device = xm.xla_device()

model = torchvision.models.resnet50().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters())
model.train()

for idx, (data, target) in enumerate(train_loader):
    data = data.to(device)
    target = target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    xm.mark_step()

Unfortunately (as of the time of this writing), the XLA results on our toy model seem far inferior to the (unoptimized) results we saw above, by as much as 7X. We expect this to improve as PyTorch/XLA’s CPU support matures.

Results

We summarize the results of a subset of our experiments in the table below. For the sake of comparison, we add the throughput of training our model on an Amazon EC2 g5.2xlarge GPU instance following the optimization steps discussed in this post. The samples per dollar were calculated based on the Amazon EC2 On-Demand pricing page ($0.357 per hour for a c7i.2xlarge and $1.212 for a g5.2xlarge, as of the time of this writing).

Performance Optimization Results (by Author)

Although we succeeded in boosting the training performance of our toy model on the CPU instance by a considerable margin (446%), it remains inferior to the (optimized) performance on the GPU instance. Based on our results, training on GPU would be ~6.7 times cheaper. It is likely that with additional performance tuning and/or additional optimization strategies, we could further close the gap. Once again, we emphasize that the comparative performance results we have reached are unique to this model and runtime environment.

Amazon EC2 Spot Instances Discounts

The increased availability of cloud-based CPU instance types (compared to GPU instance types) may imply greater opportunity for obtaining compute power at discounted rates, e.g., through Spot Instance utilization. Amazon EC2 Spot Instances are instances from surplus cloud service capacity that are offered for a discount of as much as 90% off the On-Demand pricing. In exchange for the discounted price, AWS maintains the right to preempt the instance with little to no warning. Given the high demand for GPUs, you may find CPU Spot Instances easier to get hold of than their GPU counterparts. At the time of this writing, the c7i.2xlarge Spot Instance price is $0.1291 per hour, which would improve our samples per dollar result to 1135.76 thousand and further reduce the gap between the optimized GPU and CPU price performances (to 2.43X).

While the runtime performance results of the optimized CPU training of our toy model (and our chosen environment) were lower than the GPU results, it is likely that the same optimization steps applied to other model architectures (e.g., ones that include components that are not supported by GPU) may result in the CPU performance matching or beating that of the GPU. And even in cases where the performance gap is not bridged, there may very well be cases where the shortage of GPU compute capacity would justify running some of our ML workloads on CPU.

Summary

Given the ubiquity of the CPU, the ability to use it effectively for training and/or running ML workloads could have huge implications for development productivity and for end-product deployment strategy. While the nature of the CPU architecture is less amenable to many ML applications when compared to the GPU, there are many tools and techniques available for boosting its performance – a select few of which we have discussed and demonstrated in this post.

In this post we focused on optimizing training on CPU. Please be sure to check out our many other posts on Medium covering a wide variety of topics pertaining to performance analysis and optimization of Machine Learning workloads.
