# Numerically stable way to compute sqrt((b²*c²) / (1-c²)) for c in [-1, 1]

The most interesting part of this stability-wise is the denominator, `sqrt(1 - c*c)`. For that, all you need to do is expand it as `sqrt(1 - c) * sqrt(1 + c)`. I don't think this really qualifies as a "clever trick", but it's all that's needed.

For a typical binary floating-point format (for example IEEE 754 binary64, but other common formats should behave equally well, with the possible exception of unpleasant things like the double-double format), if `c` is close to `1` then `1 - c` will be computed exactly, by Sterbenz' Lemma, while `1 + c` doesn't have any stability issues. Similarly, if `c` is close to `-1` then `1 + c` will be computed exactly, and `1 - c` will be computed accurately. The square root and multiplication operations will not introduce significant new error.
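The exactness claim can be checked directly with exact rational arithmetic. The following is a minimal sketch, assuming IEEE 754 binary64 floats; the variable names are mine, and the sample value of `c` is just one value slightly below `1`.

```python
from fractions import Fraction

# One sample value just below 1 (an arbitrary choice for illustration).
c = float.fromhex('0x1.ffffffff24190p-1')

# Fraction(x) converts a float to its exact rational value, so the
# right-hand sides below are computed without any rounding.
diff_is_exact = Fraction(1 - c) == 1 - Fraction(c)

# 1 + c is allowed to round, but its relative error is at most 2**-53.
exact_sum = 1 + Fraction(c)
sum_rel_error = abs(Fraction(1 + c) - exact_sum) / exact_sum

print("1 - c computed exactly:", diff_is_exact)
print("1 + c within half an ulp:", sum_rel_error <= Fraction(1, 2**53))
```

This only checks a single value of `c`; it illustrates Sterbenz' Lemma rather than proving it.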

Here's a numerical demonstration, using Python on a machine with IEEE 754 binary64 floating-point and a correctly-rounded `sqrt` operation.

Let's take a `c` close to (but smaller than) `1`:

```
>>> c = float.fromhex('0x1.ffffffff24190p-1')
>>> c
0.9999999999
```

We have to be a little bit careful here: note that the decimal value shown, `0.9999999999`, is an *approximation* to the exact value of `c`. The exact value is as shown in the construction from the hexadecimal string, or in fraction form, `562949953365017/562949953421312`, and it's that exact value that we care about getting good results for.
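As a small aside, the fraction form can be recovered from the float itself. This sketch assumes nothing beyond the standard library; the variable names are mine.

```python
from fractions import Fraction

c = float.fromhex('0x1.ffffffff24190p-1')

# as_integer_ratio() returns the exact value of the float as a pair of
# integers in lowest terms; no rounding is involved.
numerator, denominator = c.as_integer_ratio()
exact = Fraction(c)

print(f"{numerator}/{denominator}")
print("denominator is 2**49:", denominator == 2**49)
```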

The exact value of the expression `sqrt(1 - c*c)`, rounded to 100 decimal places after the point, is:

```
0.0000141421362084401590649378320134409069878639187055610216016949959890888003204161068184484972504813
```

I computed this using Python's `decimal` module, and double-checked the result using Pari/GP. Here's the Python calculation:

```
>>> from decimal import Decimal, getcontext
>>> getcontext().prec = 1000
>>> good = (1 - Decimal(c) * Decimal(c)).sqrt().quantize(Decimal("1e-100"))
>>> print(good)
0.0000141421362084401590649378320134409069878639187055610216016949959890888003204161068184484972504813
```

If we compute naively, we get this result:

```
>>> from math import sqrt
>>> naive = sqrt(1 - c*c)
>>> naive
1.4142136208793713e-05
```

We can easily compute the approximate number of ulps error (with apologies for the amount of type conversion going on - `float` and `Decimal` instances can't be mixed directly in arithmetic operations):

```
>>> from math import ulp
>>> float((Decimal(naive) - good) / Decimal(ulp(float(good))))
208701.28298527992
```

So the naive result is out by a couple of hundred thousand ulps - roughly speaking, we've lost around 5 decimal places of accuracy.

Now let's try with the expanded version:

```
>>> better = sqrt(1 - c) * sqrt(1 + c)
>>> better
1.4142136208440158e-05
>>> float((Decimal(better) - good) / Decimal(ulp(float(good))))
-0.7170147200803595
```

So here we're accurate to better than 1 ulp error. Not perfectly correctly rounded, but the next best thing.

With some more work, it ought to be possible to state and prove an absolute upper bound on the number of ulps error in the expression `sqrt(1 - c) * sqrt(1 + c)`, over the domain `-1 < c < 1`, assuming IEEE 754 binary floating-point, round-ties-to-even rounding mode, and correctly-rounded operations throughout. I haven't done that, but I'd be very surprised if that upper bound turned out to be more than 10 ulps.
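Short of a proof, a randomized scan gives some feel for how the two forms behave across the domain. The sketch below is my own illustration, not a bound: the helper names `ulps_error` and `worst_errors`, the sample count, the seed, and the log-uniform sampling of `1 - |c|` are all assumptions chosen for the example.

```python
import random
from decimal import Decimal, getcontext
from math import sqrt, ulp

getcontext().prec = 200  # far more digits than binary64 carries


def ulps_error(approx, c):
    """Error of `approx` against sqrt(1 - c*c), in ulps of the true value."""
    exact = (1 - Decimal(c) * Decimal(c)).sqrt()
    return float(abs(Decimal(approx) - exact) / Decimal(ulp(float(exact))))


def worst_errors(samples=5000, seed=12345):
    """Return (worst naive error, worst expanded error) over random c."""
    rng = random.Random(seed)
    worst_naive = worst_expanded = 0.0
    for _ in range(samples):
        # Spread 1 - |c| log-uniformly over [1e-15, 1] so that values
        # close to +/-1 are well represented; the sign is random.
        c = rng.choice((-1.0, 1.0)) * (1.0 - 10.0 ** rng.uniform(-15.0, 0.0))
        worst_naive = max(worst_naive, ulps_error(sqrt(1 - c * c), c))
        worst_expanded = max(
            worst_expanded, ulps_error(sqrt(1 - c) * sqrt(1 + c), c)
        )
    return worst_naive, worst_expanded


worst_naive, worst_expanded = worst_errors()
print("worst naive error (ulps):   ", worst_naive)
print("worst expanded error (ulps):", worst_expanded)
```

A scan like this can only fail to find a counterexample; it says nothing about inputs it did not sample.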

Mark Dickinson provides a good answer for the general case; I will add to that with a somewhat more specialized approach.

Many computing environments these days provide an operation called a fused multiply-add, or FMA for short, which was specifically designed with situations like this in mind. In the computation of `fma(a, b, c)` the full product `a * b` (untruncated and unrounded) enters into the addition with `c`, then a single rounding is applied at the end.
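The effect of the single rounding can be illustrated in Python with exact rational arithmetic. This is a sketch only: `fma_emulated` is a hypothetical helper of mine that emulates FMA for finite inputs via `fractions.Fraction`, not a hardware FMA, and the sample values are chosen so that the low-order bits of the product are lost when the multiplication is rounded separately.

```python
from fractions import Fraction


def fma_emulated(a, b, c):
    """Emulate fma(a, b, c) for finite floats.

    The product and sum are formed exactly as rationals, and the
    conversion back to float performs the one and only rounding.
    """
    return float(Fraction(a) * Fraction(b) + Fraction(c))


a = 1.0 + 2.0**-30   # exactly representable in binary64
t = 1.0 + 2.0**-29   # a*a rounded to binary64

# Separate multiply and subtract: a*a is rounded first, discarding the
# 2**-60 term of the exact product 1 + 2**-29 + 2**-60.
separate = a * a - t

# Fused: the exact product enters the addition, so 2**-60 survives.
fused = fma_emulated(a, a, -t)

print("separate:", separate)
print("fused:   ", fused)
```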

Currently shipping GPUs and CPUs, including those based on the ARM64, x86-64, and Power architectures, typically include a fast hardware implementation of FMA, which is exposed in programming languages of the C and C++ families as well as many others as a standard math function `fma()`. Some -- usually older -- software environments use software emulation of FMA, and some of these emulations have been found to be faulty. In addition, such emulations tend to be pretty slow.

Where FMA is available, the expression can be evaluated in a numerically stable way and without risk of premature overflow and underflow as `fabs (b * c) / sqrt (fma (c, -c, 1.0))`, where `fabs()` is the absolute value operation for floating-point operands and `sqrt()` computes the square root. Some environments also offer a reciprocal square root operation, often called `rsqrt()`, in which case a potential alternative is to use `fabs (b * c) * rsqrt (fma (c, -c, 1.0))`. The use of `rsqrt()` avoids the relatively expensive division and is therefore typically faster. However, many implementations of `rsqrt()` are not correctly rounded like `sqrt()`, so accuracy may be somewhat worse.
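To make the overflow point concrete, here is a small Python sketch comparing the literal formula with the FMA-based form. The names `fma_emulated`, `naive_expr` and `fma_expr` are mine, and the FMA is emulated with exact rationals (finite inputs only) so that the example does not depend on a native FMA being available.

```python
import math
from fractions import Fraction


def fma_emulated(a, b, c):
    """Emulated fma(a, b, c): exact a*b + c, rounded once (finite inputs)."""
    return float(Fraction(a) * Fraction(b) + Fraction(c))


def naive_expr(b, c):
    """Literal evaluation of sqrt((b^2 * c^2) / (1 - c^2))."""
    return math.sqrt((b * b * c * c) / (1.0 - c * c))


def fma_expr(b, c):
    """The rewritten form: |b*c| / sqrt(fma(c, -c, 1))."""
    return abs(b * c) / math.sqrt(fma_emulated(c, -c, 1.0))


# b*b overflows to infinity in the naive form, although the true
# result is comfortably within binary64 range.
print("naive:", naive_expr(1e200, 0.5))
print("fma:  ", fma_expr(1e200, 0.5))

# Near c = 1 the FMA-based form should agree closely with the
# sqrt(1 - c) * sqrt(1 + c) expansion of the denominator.
c = float.fromhex('0x1.ffffffff24190p-1')
print("fma:     ", fma_expr(3.0, c))
print("expanded:", abs(3.0 * c) / (math.sqrt(1 - c) * math.sqrt(1 + c)))
```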

A quick experiment with the code below seems to indicate that the maximum error of the FMA-based expression is about 3 ulps, as long as `b` is a *normal* floating-point number. I stress that this does *not* prove any error bound. The automated Herbie tool, which tries to find numerically advantageous rewrites of a given floating-point expression, suggests using `fabs (b * c) * sqrt (1.0 / fma (c, -c, 1.0))`. This seems to be a spurious result, however, as I can neither think of any particular advantage nor find one experimentally.

```
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <math.h>

#define USE_ORIGINAL (0)
#define USE_HERBIE   (1)

/* function under test */
float func (float b, float c)
{
#if USE_HERBIE
    return fabsf (b * c) * sqrtf (1.0f / fmaf (c, -c, 1.0f));
#else // USE_HERBIE
    return fabsf (b * c) / sqrtf (fmaf (c, -c, 1.0f));
#endif // USE_HERBIE
}

/* reference */
double funcd (double b, double c)
{
#if USE_ORIGINAL
    double b2 = b * b;
    double c2 = c * c;
    return sqrt ((b2 * c2) / (1.0 - c2));
#else
    return fabs (b * c) / sqrt (fma (c, -c, 1.0));
#endif
}

/* reinterpret the bits of a float as a 32-bit unsigned integer */
uint32_t float_as_uint32 (float a)
{
    uint32_t r;
    memcpy (&r, &a, sizeof r);
    return r;
}

/* reinterpret the bits of a 32-bit unsigned integer as a float */
float uint32_as_float (uint32_t a)
{
    float r;
    memcpy (&r, &a, sizeof r);
    return r;
}

/* reinterpret the bits of a double as a 64-bit unsigned integer */
uint64_t double_as_uint64 (double a)
{
    uint64_t r;
    memcpy (&r, &a, sizeof r);
    return r;
}

/* error of float result `res` versus double reference `ref`, in float ulps */
double floatUlpErr (float res, double ref)
{
    uint64_t i, j, err, refi;
    int expoRef;

    /* ulp error cannot be computed if either operand is NaN, infinity, zero */
    if (isnan (res) || isnan (ref) || isinf (res) || isinf (ref) ||
        (res == 0.0f) || (ref == 0.0f)) {
        return 0.0;
    }
    /* Convert the float result to an "extended float". This is like a float
       with 56 instead of 24 effective mantissa bits.
    */
    i = ((uint64_t)float_as_uint32(res)) << 32;
    /* Convert the double reference to an "extended float". If the reference is
       >= 2^129, we need to clamp to the maximum "extended float". If reference
       is < 2^-126, we need to denormalize because of the float type's limited
       exponent range.
    */
    refi = double_as_uint64(ref);
    expoRef = (int)(((refi >> 52) & 0x7ff) - 1023);
    if (expoRef >= 129) {
        j = 0x7fffffffffffffffULL;
    } else if (expoRef < -126) {
        j = ((refi << 11) | 0x8000000000000000ULL) >> 8;
        j = j >> (-(expoRef + 126));
    } else {
        j = ((refi << 11) & 0x7fffffffffffffffULL) >> 8;
        j = j | ((uint64_t)(expoRef + 127) << 55);
    }
    j = j | (refi & 0x8000000000000000ULL);
    err = (i < j) ? (j - i) : (i - j);
    return err / 4294967296.0;
}

// Fixes via: Greg Rose, KISS: A Bit Too Simple. http://eprint.iacr.org/2011/007
static unsigned int z=362436069,w=521288629,jsr=362436069,jcong=123456789;
#define znew (z=36969*(z&0xffff)+(z>>16))
#define wnew (w=18000*(w&0xffff)+(w>>16))
#define MWC  ((znew<<16)+wnew)
#define SHR3 (jsr^=(jsr<<13),jsr^=(jsr>>17),jsr^=(jsr<<5)) /* 2^32-1 */
#define CONG (jcong=69069*jcong+13579)                     /* 2^32 */
#define KISS ((MWC^CONG)+SHR3)

#define N (20)

int main (void)
{
    float b, c, errloc_b = 0.0f, errloc_c = 0.0f, res;
    double ref, err, maxerr = 0;

    /* step through every float value of `c` in [-1, 1] */
    c = -1.0f;
    while (c <= 1.0f) {
        /* try N random values of `b` per every value of `c` */
        for (int i = 0; i < N; i++) {
            /* allow only normals */
            do {
                b = uint32_as_float (KISS);
            } while (!isnormal (b));
            res = func (b, c);
            ref = funcd ((double)b, (double)c);
            err = floatUlpErr (res, ref);
            if (err > maxerr) {
                maxerr = err;
                errloc_b = b;
                errloc_c = c;
            }
        }
        c = nextafterf (c, INFINITY);
    }
#if USE_HERBIE
    printf ("HERBIE max ulp err = %.5f  @ (b=% 15.8e c=% 15.8e)\n", maxerr, errloc_b, errloc_c);
#else // USE_HERBIE
    printf ("SIMPLE max ulp err = %.5f  @ (b=% 15.8e c=% 15.8e)\n", maxerr, errloc_b, errloc_c);
#endif // USE_HERBIE
    return EXIT_SUCCESS;
}
```