## Numerical stability in PyTorch

When using any numerical computation library such as NumPy or PyTorch, it's important to note that writing mathematically correct code doesn't necessarily lead to correct results. You also need to make sure that the computations are stable.

Let's start with a simple example. Mathematically, it's easy to see that `x * y / y = x` for any non-zero value of `y`. But let's see if that's always true in practice:

```python
import numpy as np

x = np.float32(1)
y = np.float32(1e-50)  # y would be stored as zero
z = x * y / y
print(z)  # prints nan
```

The reason for the incorrect result is that `y` is simply too small for the float32 type. A similar problem occurs when `y` is too large:

```python
y = np.float32(1e39)  # y would be stored as inf
z = x * y / y
print(z)  # prints nan
```

The smallest positive value that the float32 type can represent is 1.4013e-45, and anything below that is stored as zero. Also, any number beyond 3.40282e+38 is stored as inf.

```python
print(np.nextafter(np.float32(0), np.float32(1)))  # prints 1.4013e-45
print(np.finfo(np.float32).max)  # prints 3.40282e+38
```

To make sure that your computations are stable, you want to avoid values with very small or very large absolute values. This may sound obvious, but these kinds of problems can become extremely hard to debug, especially when doing gradient descent in PyTorch. This is because you not only need to make sure that all the values in the forward pass are within the valid range of your data types, but you also need to make sure of the same for the backward pass (during gradient computation).

Let's look at a real example. We want to compute the softmax over a vector of logits. A naive implementation would look something like this:

```python
import torch

def unstable_softmax(logits):
    exp = torch.exp(logits)
    return exp / torch.sum(exp)

print(unstable_softmax(torch.tensor([1000., 0.])).numpy())  # prints [nan, 0.]
```

Note that the exponential of even relatively small logits produces gigantic values that fall outside the float32 range. The largest valid logit for our naive softmax implementation is `ln(3.40282e+38) = 88.7`; anything beyond that leads to a nan outcome.

But how can we make this more stable? The solution is rather simple. It's easy to see that `exp(x - c) / Σ exp(x - c) = exp(x) / Σ exp(x)` for any constant `c`. Therefore we can subtract any constant from the logits and the result remains the same. We choose this constant to be the maximum of the logits. This way the domain of the exponential function is limited to `[-inf, 0]`, and consequently its range is limited to `[0.0, 1.0]`, which is desirable:

```python
import torch

def softmax(logits):
    exp = torch.exp(logits - torch.max(logits))
    return exp / torch.sum(exp)

print(softmax(torch.tensor([1000., 0.])).numpy())  # prints [1., 0.]
```

Let's look at a more complicated case. Consider we have a classification problem, and we use the softmax function to produce probabilities from our logits. We then define our loss function to be the cross entropy between our predictions and the labels. Recall that cross entropy for a categorical distribution can be simply defined as `xe(p, q) = -Σ p_i log(q_i)`.
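For instance, plugging purely illustrative distributions into this formula (the numbers below are hypothetical and chosen to stay well away from zero):

```python
import torch

# Illustrative distributions for the formula xe(p, q) = -Σ p_i log(q_i).
p = torch.tensor([0.5, 0.5])  # target distribution (labels)
q = torch.tensor([0.8, 0.2])  # predicted distribution, away from zero so log is safe
xe = -torch.sum(p * torch.log(q))
print(xe.numpy())  # prints ~0.9163
```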
So a naive implementation of the softmax cross entropy would look like this:

```python
def unstable_softmax_cross_entropy(labels, logits):
    logits = torch.log(softmax(logits))
    return -torch.sum(labels * logits)

labels = torch.tensor([0.5, 0.5])
logits = torch.tensor([1000., 0.])

xe = unstable_softmax_cross_entropy(labels, logits)
print(xe.numpy())  # prints inf
```

Note that in this implementation, as the softmax output approaches zero, the log's output approaches negative infinity, which causes instability in our computation. We can rewrite this by expanding the softmax and doing some simplifications: since `log(softmax(x)_i) = x_i - log Σ_j exp(x_j)`, we can compute the log-probabilities directly with `logsumexp` and never take the log of a value close to zero:

```python
def softmax_cross_entropy(labels, logits, dim=-1):
    scaled_logits = logits - torch.max(logits)
    normalized_logits = scaled_logits - torch.logsumexp(scaled_logits, dim)
    return -torch.sum(labels * normalized_logits)

labels = torch.tensor([0.5, 0.5])
logits = torch.tensor([1000., 0.])

xe = softmax_cross_entropy(labels, logits)
print(xe.numpy())  # prints 500.0
```

We can also verify that the gradients are computed correctly:

```python
logits.requires_grad_(True)

xe = softmax_cross_entropy(labels, logits)
g = torch.autograd.grad(xe, logits)[0]
print(g.numpy())  # prints [0.5, -0.5]
```

Let me remind you again that extra care must be taken when doing gradient descent to make sure that the outputs of your functions, as well as the gradients of each layer, stay within a valid range. Exponential and logarithmic functions, when used naively, are especially problematic because they can map small numbers to enormous ones and the other way around.
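As a quick sanity check, PyTorch also ships a numerically stable primitive for this pattern: `torch.nn.functional.log_softmax` computes log-probabilities without overflowing on extreme logits. A minimal sketch (assuming a reasonably recent PyTorch version) that reproduces the result above:

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0.5, 0.5])
logits = torch.tensor([1000., 0.])

# log_softmax computes log-probabilities in a numerically stable way,
# so the extreme logit does not produce inf or nan.
log_probs = F.log_softmax(logits, dim=-1)
xe = -torch.sum(labels * log_probs)
print(xe.numpy())  # prints 500.0, matching our manual implementation
```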