Exponentials in 3 Instructions

This post expands on an algorithm shown in the book I wrote on floating-point math.

It is very common in computing to want to compute $e^x$ very quickly without caring very much about how accurately you compute it. This is increasingly true in ML and AI algorithms, which can be very tolerant of noise from numerical error and often use low bit precision anyway. It also shows up in exponentially weighted moving averages and similar functions in signal processing. Thankfully, you can replace $e^x$ in these situations with the following function:

#include <bit>

float approx_exp(float x) {
    float xb = x * 12102203;    // Magic
    int i = xb;                 // Don't worry, we'll be back in float soon
    return std::bit_cast<float>(i + 1064986823);    // WTF?
}

This function's output deviates from $e^x$ by less than 3% across essentially the entire range of $x$ for which $e^x$ fits in a float. Visually, a plot of this function (orange line) against an exact exponential (blue line) looks like this, and the trend continues throughout the range of $x$ values:

That comes at a cost of only 3 full-throughput instructions on most CPUs or GPUs: a floating-point multiplication, a conversion from float to integer, and an integer addition. The bit cast is usually free, since all of these operations can run on the vector side of most microprocessors, so nothing has to move between register files, and GPUs have no such distinction to worry about in the first place. In return for 3 instructions, you get about 5 bits of accuracy on a useful transcendental function.

This particular form of bit-hacking is pretty old, and is best-known in relation to a particular formula for reciprocal square root found in Quake III. As far as I can tell, the earliest publication of the above approximation for exponentiation comes from a 1998 paper by Nicol Schraudolph, who wanted to optimize the computation of $e^x$ in (ironically) neural networks. The idea may be older than that; apocryphal stories of similar tricks date back to the 1980's with people in the orbit of William Kahan.

Breaking it Down

Here is a much clearer piece of code that describes what is going on:

 1  #include <bit>
 2
 3  const int K_TUNED_MINIMAX = 366393;     // Max relative error: 2.98%
 4  const int K_TUNED_L1 = 545948;          // Mean relative error: 1.48%
 5  const int K_TUNED_L2 = 486412;          // Root mean squared relative error: 1.77%
 6
 7  const int K = K_TUNED_MINIMAX;          // Change based on desired error metric
 8
 9  float approx_exp(float x) {
10      // log2(e) * 2^23 for e^x in a format with 23 mantissa bits
11      const float BASE = 1.44269504089 * 8388608.0;
12
13      // Multiply by the base and convert to Q9.23 fixed point, truncating toward zero
14      float x_base = x * BASE;
15      int i_base = x_base;
16
17      // Perform a piecewise linear approximation of 2^x
18      // (127 << 23) is the exponent bias of 32-bit float and K is a tuned constant
19      int i_result = i_base + (127 << 23) - K;
20      return std::bit_cast<float>(i_result);
21  }

Both float and int represent 32-bit binary values. A plot mapping the integer value of a bit pattern to the floating-point value the same bits represent looks like this:

This shape is driven by the fact that the exponent field is more significant than the mantissa field. 32-bit floats have 23 mantissa bits in the LSBs of the integer, followed by 8 exponent bits, so adding $2^{23}$ to the number in integer space corresponds to adding one to the exponent field, doubling the floating point value of the number.

Thus, for positive numbers, the mapping from integer to floating point is a piecewise linear fixed-point exponential, offset by a bias. That bias is the floating-point exponent bias (127), shifted up into the exponent field: exponents of floating-point numbers are stored as unsigned values with a bias, while integers are two's complement. So we can compute a quasi-exponential of an integer by simply treating the bit layout of that integer as a floating-point number, and the addition of (127 << 23) handles the exponent bias.
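If you want to see this concretely, a small standalone check (not part of the approximation itself) makes the point:

#include <bit>
#include <cstdio>

int main() {
    float f = 3.5f;
    int i = std::bit_cast<int>(f);

    // Adding 1 << 23 bumps the exponent field by one, doubling the value
    std::printf("%g\n", std::bit_cast<float>(i + (1 << 23)));    // prints 7

    // 127 << 23 is exactly the bit pattern of 1.0f (zero mantissa, biased exponent 127)
    std::printf("%g\n", std::bit_cast<float>(127 << 23));        // prints 1
    return 0;
}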

Approximations of nonlinear scalar functions like $e^x$ usually work in four steps:

  1. Deal with infinity, NaN, or other weird inputs
  2. Use algebra to reduce the range of approximation as far as possible
  3. Perform an approximation on the nonlinear function in this reduced range
  4. Undo your algebra from step 2

The magic of functions like this one is that they handle multiple parts of this process implicitly. In our approx_exp function, the conversion to integer on line 15 is where several special-case inputs will fail, but adding error handling is only a few extra if statements. Steps 2 through 4 are all done in two assembly instructions: adding the exponent bias and the correction factor gives the linear approximation of step 3, and the bit cast gives us exactly the numerical behavior we want for steps 2 and 4: a piecewise function whose slope doubles each time its input increases by $2^{23}$.
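As an illustration of what that error handling might look like, here is a sketch of a guarded version; the exact cutoffs and fallback values are my choice, not something the original function prescribes:

#include <bit>
#include <cmath>

float approx_exp_guarded(float x) {
    // Cutoffs chosen roughly where e^x would overflow or underflow a 32-bit
    // float (and where the integer conversion below stops making sense)
    if (std::isnan(x)) return x;        // NaN in, NaN out
    if (x >  88.0f) return INFINITY;    // e^x overflows a float here
    if (x < -87.0f) return 0.0f;        // e^x underflows to (nearly) zero here

    float xb = x * 12102203;
    int i = xb;
    return std::bit_cast<float>(i + 1064986823);
}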

The K constant is the last piece of the approximation. Alone, it handles step 3 by shifting the piecewise linear approximation to the right a little, so that we cross the exponential at strategically chosen points rather than riding on top of it. I have given 3 options for K here; in general use, the minimax constant is likely the one you want, and it is the one used in the first code snippet. If you need to optimize for other error metrics, the best constants for mean error and root-mean-squared error are also in the code snippet above.
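If you want to reproduce these error figures or tune K for your own metric, a brute-force scan is enough. This is only a sketch: it assumes the approx_exp listing above is in scope, and the input range and step size are arbitrary choices.

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    double max_rel = 0.0, sum_rel = 0.0, sum_sq = 0.0;
    long n = 0;
    for (float x = -20.0f; x <= 20.0f; x += 1e-4f) {
        double rel = std::fabs(approx_exp(x) / std::exp((double)x) - 1.0);
        max_rel = std::max(max_rel, rel);
        sum_rel += rel;
        sum_sq  += rel * rel;
        ++n;
    }
    std::printf("max %.3f%%  mean %.3f%%  rms %.3f%%\n",
                100 * max_rel, 100 * sum_rel / n, 100 * std::sqrt(sum_sq / n));
    return 0;
}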

A Second-order Correction

If you care a little bit about accuracy, it is possible to do better by adding some curvature to the mantissa. Our linear approximation has an error pattern that repeats identically in every binade ("binary decade": a group of floating-point numbers that share a sign and exponent) of the output. Thus, we can run the first two steps of the approximation as before, then use the mantissa of our first-order output to build a correction factor for a second-order approximation.

The ratio of $2^x$ to a base-2 version of the approx_exp function (replace the BASE multiplier with (1 << 23), dropping the $\log_2 e$ factor) looks like this:

That means we can construct a correction factor for approx_exp by approximating this ratio within a single binade. Looking at the shape of the function, the correction factor should be roughly parabolic. And if you remember step 2 above, restricting the range over which we approximate, we get that range reduction for free here: the ratio is the same in every binade, and some bit manipulation that installs a fixed exponent in a floating-point number maps every first-order output back into a single binade.

We will add the following function evaluation:

#include <bit>
#include <cmath>

// K, P0, P1, and P2 are tuned constants; their values are given below.

float corrected_exp(float x) {
    // Piecewise linear approximation from before
    float x_base = x * 12102203;        // Multiply by log2(e) * 2^23
    int i_base = x_base;                // Convert to Q9.23 fixed point, truncating toward zero
    int i_result = i_base + (127 << 23) - K;
    float first_order = std::bit_cast<float>(i_result);

    // Install 127 as the exponent of i_result, placing correction_x in [1, 2)
    int correction_xi = (i_result & 0x7fffff) | (127 << 23);
    float correction_x = std::bit_cast<float>(correction_xi);

    // 2nd order correction factor, evaluated with a Horner scheme
    float correction = std::fmaf(correction_x, P2, P1);    // x * P2 + P1
    correction = std::fmaf(correction, correction_x, P0);  // (x * P2 + P1) * x + P0
    return first_order * correction;
}

This is 5 additional assembly instructions: two bitwise logical instructions to install the new exponent, two FMAs to evaluate the polynomial, and a final multiplication to apply the correction. The FMA-based evaluation, a Horner scheme, is very efficient on most chips because full-throughput FMA units are now ubiquitous (on x86, going back to Intel's Haswell). FMAs have high latency, though, so the latency of this version is significantly longer than that of the first-order version.

Now we need to figure out our constants. The perfect set of constants is now a 4-dimensional integer optimization problem (over K and the three $P_n$ values), which makes it difficult to establish whether we have constructed an optimal approximation or not. Instead of setting up a solver for a blog post, I will keep K fixed as before (although this is likely not optimal) and use a Remez approximation from Sollya to find the P constants by approximating the following function over $[1, 2]$:

$$ f(x) = \frac{2^{x + \frac{K}{2^{23}}}}{2x} $$

The denominator here is the shape of the linear approximation within a binade, with $x$ (the mantissa value) running from 1 to 2, and the numerator is a shifted version of the exponential covering one binade of our output. Our quick approximation yields the following:

  • P0 = 1.469318866729736328125
  • P1 = -0.671999752521514892578125
  • P2 = 0.22670517861843109130859375

Those coefficients still yield a biased result, so afterwards, I calculated a new K to fix the bias, getting K = 345088 for minimized maximum error. This is still a suboptimal result compared to what could be obtained from better methods.
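As a rough sanity check of those numbers (a sketch, not the fitting procedure itself), you can scan the quadratic against the target ratio over one binade using the final constants:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    const double K  = 345088;
    const double P0 =  1.469318866729736328125;
    const double P1 = -0.671999752521514892578125;
    const double P2 =  0.22670517861843109130859375;

    // Worst-case deviation of the quadratic from the target ratio over one binade
    double worst = 0.0;
    for (double m = 1.0; m <= 2.0; m += 1e-5) {
        double target = std::exp2(m + K / 8388608.0) / (2.0 * m);
        double poly   = (P2 * m + P1) * m + P0;
        worst = std::max(worst, std::fabs(poly / target - 1.0));
    }
    std::printf("max relative error of the correction: %.3f%%\n", 100.0 * worst);
    return 0;
}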

This results in the following corrected approximation, shown below in orange against an exact curve in blue. If you can't see any blue, that's because the two lines overlap almost exactly.

This is the ratio of $2^x$ to our corrected approximation (light orange) alongside our previous error (blue):

We get an accuracy for $2^x$ of 7.7 bits (maximum 0.48% error), and the error curve within each binade now looks third-order. We also now have a function that is very slightly non-monotonic, which is not ideal for every application. The non-monotonicity comes from the fact that our error is now third order in each binade, so it sits at a local minimum at one end of the binade and a local maximum at the other. Adding an additional term, making our error fourth order, might restore monotonicity at the cost of a third FMA operation (and would also decrease total error).
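The extra term would just extend the Horner chain by one FMA. As a sketch only: q0 through q3 here are hypothetical placeholders that would have to be refit (and K re-tuned along with them).

#include <bit>
#include <cmath>

float corrected_exp_cubic(float x, float q0, float q1, float q2, float q3) {
    const int K = 345088;                   // would also need re-tuning for a cubic fit
    float x_base = x * 12102203;
    int i_base = x_base;
    int i_result = i_base + (127 << 23) - K;
    float first_order = std::bit_cast<float>(i_result);

    int correction_xi = (i_result & 0x7fffff) | (127 << 23);
    float m = std::bit_cast<float>(correction_xi);

    float c = std::fmaf(m, q3, q2);         // the one extra FMA vs. the quadratic
    c = std::fmaf(c, m, q1);
    c = std::fmaf(c, m, q0);
    return first_order * c;
}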

Some additional joint tweaking of the polynomial coefficients and K should balance the error a bit more and get to about 8 bits of accuracy. For a total of 8 micro-ops on a modern CPU or GPU, we can compute an exponential that is accurate enough to give a faithful BF16 result. Higher-order polynomials will get closer still.

Performance Benchmarks

Some benchmarks from an M3 MacBook and an AMD Zen 3 server are below. Both of these approximations have much lower reciprocal throughput (i.e., fewer cycles per result) than the standard exponential, and the fast approximation has lower latency as well. The correction factor takes long enough that, despite its advantage in reciprocal throughput, the corrected exponential has similar or worse latency than the standard library function.

Our benchmark loop has some overhead, which shows up in the reciprocal throughput of our fast approximations. These approximations also have the benefit of vectorizing better than the standard library function, although for fairness, I made sure that each benchmark was a scalar version of the function. When we unroll by 16 lanes to allow auto-vectorization, the Zen 3 throughput results shift significantly:

The vectorized versions have a huge throughput advantage, taking a fraction of a nanosecond per exponential. This is because our approximations vectorize well, running 16 lanes at the same speed as 1. A Zen 3 CPU can compute 16 FMAs and 16 float-to-int conversions per cycle. Conversely, the standard exponential function is a scalar function and does not vectorize well.
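The unrolled kernel itself is not shown here, but a loop of roughly this shape is what gives the compiler 16 independent lanes to work with. This is a sketch that assumes the approx_exp from earlier is in scope; a real benchmark harness would also need to keep the compiler from optimizing the work away.

#include <cstddef>

void approx_exp_block(const float* in, float* out, std::size_t n) {
    std::size_t i = 0;
    // Blocks of 16 give the compiler independent lanes to auto-vectorize across
    for (; i + 16 <= n; i += 16) {
        for (std::size_t lane = 0; lane < 16; ++lane) {
            out[i + lane] = approx_exp(in[i + lane]);
        }
    }
    for (; i < n; ++i) {    // remaining elements
        out[i] = approx_exp(in[i]);
    }
}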

Conclusions

If you're willing to be wrong, you can definitely do it quickly! Functions like this are useful any time you want a roughly exponential shape but don't really care much about the precise value. These show up in a lot of computational tasks, from monitoring systems to AI, and in these cases, we can take advantage of the bit layout of floating-point numbers to compute a blazing fast approximate exponential. Just like the inverse square root trick, we get a nice speedup by thinking outside the box and adding some bit manipulations to our math.

If you find this kind of work interesting, you might want to think about applying to Anthropic. Send your best performance optimization story to [email protected] and we will be in touch.