## Faster training with mixed precision

By default, tensors and model parameters in PyTorch are stored in 32-bit floating point precision. Training neural networks with 32-bit floats is usually stable and doesn't cause major numerical issues. However, neural networks have been shown to perform quite well in 16-bit and even lower precisions, and computing in lower precision can be significantly faster on modern GPUs. It also has the added benefit of using less memory, which enables training larger models and/or using larger batch sizes, boosting performance further. The problem, though, is that training purely in 16-bit precision is often unstable, because 16 bits are usually not enough for numerically sensitive operations like accumulation.

To help with this problem, PyTorch supports training in mixed precision. In a nutshell, mixed-precision training performs the expensive operations (like convolutions and matrix multiplications) in 16-bit by casting down their inputs, while keeping numerically sensitive operations like accumulation in 32-bit. This way we get the benefits of 16-bit computation without its drawbacks. Next we talk about using `autocast` and `GradScaler` to do automatic mixed-precision training.

### Autocast

`autocast` helps improve runtime performance by automatically casting data down to 16-bit for some computations. To understand how it works, let's look at an example:

```python
import torch

x = torch.rand([32, 32]).cuda()
y = torch.rand([32, 32]).cuda()

with torch.cuda.amp.autocast():
    a = x + y
    b = x @ y

print(a.dtype)  # prints torch.float32
print(b.dtype)  # prints torch.float16
```

Note that both `x` and `y` are 32-bit tensors, but `autocast` performs the matrix multiplication in 16-bit while keeping the addition in 32-bit. What if one of the operands is in 16-bit?

```python
import torch

x = torch.rand([32, 32]).cuda()
y = torch.rand([32, 32]).cuda().half()

with torch.cuda.amp.autocast():
    a = x + y
    b = x @ y

print(a.dtype)  # prints torch.float32
print(b.dtype)  # prints torch.float16
```

Again, `autocast` casts the 32-bit operand down to 16-bit to perform the matrix multiplication, but it doesn't touch the addition; by default, adding tensors of different precisions in PyTorch promotes the result to the higher precision. In practice, you can trust `autocast` to do the right casting to improve runtime efficiency. The important thing is to keep all of your forward pass computations under the `autocast` context:

```python
model = ...
loss_fn = ...

with torch.cuda.amp.autocast():
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
```

This may be all you need if you have a relatively stable optimization problem and use a relatively low learning rate. Adding this one extra line of code can cut your training time by up to half on modern hardware.

### GradScaler

As we mentioned at the beginning of this section, 16-bit precision may not always be enough for some computations. One case of particular interest is representing gradients, a large portion of which are usually small values. Representing them with 16-bit floats often leads to underflow (i.e. they get flushed to zero), which makes training very unstable. `GradScaler` is designed to resolve this issue. It takes your loss value and multiplies it by a large scale factor, inflating the gradient values and therefore making them representable in 16-bit precision. It then scales the gradients back down before the parameter update to ensure the parameters are updated correctly. This is, in general terms, what `GradScaler` does.
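To make the underflow issue concrete, here is a small standalone sketch (independent of any training loop) showing how a gradient-sized value vanishes when cast down to 16-bit but survives once scaled up; the factor `2 ** 16` used here happens to be `GradScaler`'s default initial scale:

```python
import torch

# The smallest positive float16 value is about 6e-8, so a tiny gradient
# like 1e-8 underflows to zero when cast down to 16-bit.
g = torch.tensor(1e-8)
print(g.half())            # tensor(0., dtype=torch.float16)

# Scaling the loss by a large factor scales every gradient by the same
# factor, keeping it representable in 16-bit. GradScaler divides the
# gradients by this factor again before the parameter update.
scale = 2.0 ** 16
print((g * scale).half())  # roughly 6.55e-4, no longer zero
```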
Under the hood, however, `GradScaler` is a bit smarter than that. Inflating the gradients may actually cause overflows, which is equally bad. So `GradScaler` monitors the gradient values and, if it detects overflows, it skips the update and scales down the scale factor according to a configurable schedule. (The default schedule usually works, but you may need to adjust it for your use case.) Using `GradScaler` is very easy in practice:

```python
scaler = torch.cuda.amp.GradScaler()

loss = ...
optimizer = ...  # an instance of torch.optim.Optimizer

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```

Note that we first create an instance of `GradScaler`. In the training loop we call `GradScaler.scale` to scale the loss before calling `backward`, which produces the inflated gradients. We then call `GradScaler.step`, which (possibly) updates the model parameters, and finally `GradScaler.update`, which updates the scale factor if needed. That's all!
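If you want to see the step-skipping behavior in action, one option (a small sketch reusing the `scaler` and `optimizer` names from the snippet above) is to compare the scale factor before and after `update()`; the scale is lowered whenever an overflow was detected in that iteration:

```python
prev_scale = scaler.get_scale()
scaler.step(optimizer)   # skipped internally if the gradients overflowed
scaler.update()          # lowers the scale factor after an overflow
if scaler.get_scale() < prev_scale:
    print("Overflow detected: this step was skipped and the scale was reduced.")
```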
The following sample code showcases mixed-precision training on a synthetic problem: learning to generate a checkerboard pattern from image coordinates. You can paste it into a [Google Colab](https://colab.research.google.com/), set the runtime to GPU, and compare single-precision and mixed-precision performance. Note that this is a small toy example; in practice, with larger networks, you may see larger performance boosts from mixed precision.

## An Example

### Generating a checkerboard

```python
import torch
import matplotlib.pyplot as plt
import time


def grid(width, height):
    hrange = torch.arange(width).unsqueeze(0).repeat([height, 1]).div(width)
    vrange = torch.arange(height).unsqueeze(1).repeat([1, width]).div(height)
    output = torch.stack([hrange, vrange], 0)
    return output


def checker(width, height, freq):
    hrange = torch.arange(width).reshape([1, width]).mul(freq / width / 2.0).fmod(1.0).gt(0.5)
    vrange = torch.arange(height).reshape([height, 1]).mul(freq / height / 2.0).fmod(1.0).gt(0.5)
    output = hrange.logical_xor(vrange).float()
    return output


# Note the inputs are grid coordinates and the target is a checkerboard.
inputs = grid(512, 512).unsqueeze(0).cuda()
targets = checker(512, 512, 8).unsqueeze(0).unsqueeze(1).cuda()
```

### Defining a convolutional neural network

```python
class Net(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(2, 256, 1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, 1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, 1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 1, 1))

    @torch.jit.script_method
    def forward(self, x):
        return self.net(x)
```

### Single precision training

```python
net = Net().cuda()
loss_fn = torch.nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), 0.001)

start_time = time.time()
for i in range(500):
    opt.zero_grad()
    outputs = net(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    opt.step()
print(loss)
print(time.time() - start_time)

plt.subplot(1, 2, 1); plt.imshow(outputs.squeeze().detach().cpu())
plt.subplot(1, 2, 2); plt.imshow(targets.squeeze().cpu())
plt.show()
```

### Mixed precision training

```python
net = Net().cuda()
loss_fn = torch.nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), 0.001)
scaler = torch.cuda.amp.GradScaler()

start_time = time.time()
for i in range(500):
    opt.zero_grad()
    with torch.cuda.amp.autocast():
        outputs = net(inputs)
        loss = loss_fn(outputs, targets)
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
print(loss)
print(time.time() - start_time)

plt.subplot(1, 2, 1); plt.imshow(outputs.squeeze().detach().cpu().float())
plt.subplot(1, 2, 2); plt.imshow(targets.squeeze().cpu().float())
plt.show()
```

## Reference

- https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html