Overloading Haskell numbers, part 2, Forward Automatic Differentiation.
I will continue my overloading with some examples that have been nicely illustrated in an article by Jerzy Karczmarczuk, and blogged about by sigfpe.
But at least I'll end this entry with a small twist I've not seen before.
When computing the derivative of a function you normally do it either by symbolic differentiation or by numerical approximation.
Say that you have a function
f(x) = x^2 + 1
and you want to know the derivative at
x=5. Doing it symbolically you first get
f'(x) = 2x
using high school calculus (maybe they don't teach it in high school anymore?), and then you plug in 5
f'(5) = 2*5 = 10
Computing it by numeric differentiation you compute
f'(x) = (f(x+h) - f(x)) / h
for some small h. Let's pick h=1e-5, and we get
f'(5) = 10.000009999444615. Close, but not that good.
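In Haskell the numeric approximation is a one-liner; here's a minimal sketch (ndiff is just an illustrative name, not from any library):
ndiff :: (Double -> Double) -> Double -> Double
ndiff f x = (f (x+h) - f x) / h
  where h = 1e-5
Running ndiff (\x -> x*x + 1) 5 should reproduce the number above.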
So why don't we always use the symbolic method? Well, some functions are not that easy to differentiate. Take this one
g x = if abs (x - 0.7) < 0.4 then x else g (cos x)
What's the derivative? Well, it's tricky because this is not really a proper definition of
g. It's an equation that if solved will yield a definition of
g. And like equations in general, it could have zero, one, or many solutions.
(If we happen to use CPOs there is always a unique smallest solution which is what programs compute, as if by magic.)
If you think
g is contrived, let's pick a different example: computing the square root with Newton-Raphson.
sqr x = convAbs $ iterate improve 1
  where improve r = (r + x/r) / 2
        convAbs (x1:x2:_) | abs (x1-x2) < 1e-10 = x2
        convAbs (_:xs) = convAbs xs
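Trying it out in a session should give
Main> sqr 9
3.0
since the iteration converges to 3.0 well within the 1e-10 tolerance.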
So the symbolic method is not so easy here, and numeric differentiation is not very accurate.
But there is a third way!
Automatic differentiation.
The idea behind AD is that instead of computing with just numbers, we compute with pairs of numbers. The first component is the normal number, and the second component is the derivative.
What are the rules for these numbers? Let's look at addition
(x, x') + (y, y') = (x+y, x'+y')
To add two numbers you just add the regular part and the derivatives.
For multiplication you have to remember how to compute the derivative of a product:
(f(x)*g(x))' = f(x)*g'(x) + f'(x)*g(x)
So for our pairs we get
(x, x') * (y, y') = (x*y, x*y' + x'*y)
i.e., first the regular product, then the derivative according to the recipe above.
Let's see how it works on
f(x) = x^2 + 1
We want the derivative at
x=5. So what is the pair we use for
x? It is (5, 1).
Why? Well it has to be 5 for the regular part, and since this represents
x and the derivative of
x is 1, the pair is (5, 1).
On the right-hand side of f we need to replace the 1 by (1,0), since the derivative of a constant is 0.
So then we get
f (5,1) = (5,1)*(5,1) + (1,0) = (26,10)
using the rules above. And look! There is the normal result, 26, as well as the derivative, 10.
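Spelling out the steps: (5,1)*(5,1) = (5*5, 5*1 + 1*5) = (25,10), and then (25,10) + (1,0) = (26,10).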
Let's turn this into Haskell, using the type PD to hold a pair of Doubles
data PD = P Double Double deriving (Eq, Ord, Show)
instance Num PD where
    P x x' + P y y' = P (x+y) (x'+y')
    P x x' - P y y' = P (x-y) (x'-y')
    P x x' * P y y' = P (x*y) (x*y' + x'*y)
    fromInteger i   = P (fromInteger i) 0
A first observation is that there is nothing Double specific in these definitions; they would work for any Num. So we can change it to
data PD a = P a a deriving (Eq, Ord, Show)
instance Num a => Num (PD a) where ...
Let's also add abs and signum, and the Fractional instance
    ...
    abs (P x x')    = P (abs x) (signum x * x')
    signum (P x x') = P (signum x) 0

instance Fractional a => Fractional (PD a) where
    P x x' / P y y' = P (x/y) ((x'*y - x*y') / (y*y))
    fromRational r  = P (fromRational r) 0
We can now try the sqr example
Main> sqr (P 9 1)
P 3.0 0.16666666666666666
The derivative of x**0.5 is 0.5*x**(-0.5), i.e., 0.5*9**(-0.5) = 0.5/3 = 0.16666666666666666.
So we got the right answer.
BTW, if you want to be picky, the derivative of signum is not 0. The signum function makes a jump from -1 to 1 at 0, so the "proper" value would be 2*dirac, if dirac is a
Dirac pulse. But since we don't have numbers with Dirac pulses (yet), I'll just pretend the derivative is 0 everywhere.
The very clever insight that Jerzy had was that when doing these numbers in Haskell there is no need to limit yourself to just the first derivative. Since Haskell is lazy we can easily keep an infinite list of all derivatives instead of just the first one.
Let's see what that definition looks like. It's very similar to what we just did, but instead of the derivative being just a number, it's now itself one of our new numbers, with a value and all the derivatives...
Since we are now dealing with an infinite data structure we need to define our own show, (==), etc.
data Dif a = D a (Dif a)

val (D x _)  = x
df  (D _ x') = x'

dVar x = D x 1

instance (Show a) => Show (Dif a) where
    show x = show (val x)

instance (Eq a) => Eq (Dif a) where
    x == y = val x == val y

instance (Ord a) => Ord (Dif a) where
    x `compare` y = val x `compare` val y

instance (Num a) => Num (Dif a) where
    D x x' + D y y' = D (x + y) (x' + y')
    D x x' - D y y' = D (x - y) (x' - y')
    p@(D x x') * q@(D y y') = D (x * y) (x' * q + p * y')
    fromInteger i = D (fromInteger i) 0
    abs p@(D x x') = D (abs x) (signum p * x')
    signum (D x _) = D (signum x) 0

instance (Fractional a) => Fractional (Dif a) where
    recip (D x x') = ip  where ip = D (recip x) (-x' * ip * ip)
    fromRational r = D (fromRational r) 0
This looks simple, but it's rather subtle. For instance, take the 0 in the definition of fromInteger. It's actually of type Dif a, so it's a recursive call to fromInteger.
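To see the recursion at work, this is how fromInteger 1 unfolds (just an illustration, not code from the module):
fromInteger 1
  = D 1 0                       -- the 0 is again of type Dif a
  = D 1 (fromInteger 0)
  = D 1 (D 0 (D 0 (D 0 ...)))   -- an infinite tower of zero derivatives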
So let's try our sqr function again, this time computing up to the third derivative.
The dVar function creates a value for the "variable" with respect to which we differentiate.
Main> sqr $ dVar 9
3.0
Main> df $ sqr $ dVar 9
0.16666666666666669
Main> df $ df $ sqr $ dVar 9
-9.259259259259259e-3
Main> df $ df $ df $ sqr $ dVar 9
1.5432098765432098e-3
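These agree with the closed forms: the second derivative of x**0.5 is -0.25*x**(-1.5), i.e., -0.25/27 ≈ -9.259e-3, and the third is 0.375*x**(-2.5), i.e., 0.375/243 ≈ 1.543e-3.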
And the transcendentals in a similar way:
lift (f : f') p@(D x x') = D (f x) (x' * lift f' p)

instance (Floating a) => Floating (Dif a) where
    pi = D pi 0
    exp (D x x') = r  where r = D (exp x) (x' * r)
    log p@(D x x') = D (log x) (x' / p)
    sqrt (D x x') = r  where r = D (sqrt x) (x' / (2 * r))
    sin = lift (cycle [sin, cos, negate . sin, negate . cos])
    cos = lift (cycle [cos, negate . sin, negate . cos, sin])
    acos p@(D x x') = D (acos x) (-x' / sqrt (1 - p*p))
    asin p@(D x x') = D (asin x) ( x' / sqrt (1 - p*p))
    atan p@(D x x') = D (atan x) ( x' / (p*p + 1))
    sinh x = (exp x - exp (-x)) / 2
    cosh x = (exp x + exp (-x)) / 2
    asinh x = log (x + sqrt (x*x + 1))
    acosh x = log (x + sqrt (x*x - 1))
    atanh x = (log (1 + x) - log (1 - x)) / 2
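As a quick sanity check: every derivative of exp is exp itself, so we should see the same value at every level (session output assumed):
Main> df $ exp $ dVar 1
2.718281828459045
Main> df $ df $ exp $ dVar 1
2.718281828459045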
And why not try the function g we defined above?
Main> g 10
0.6681539175313869
Main> g (dVar 10)
0.6681539175313869
Main> df $ g (dVar 10)
0.4047642621121782
Main> df $ df $ g (dVar 10)
0.4265424381635987
Main> df $ df $ df $ g (dVar 10)
-1.4395397945007182
It all works very nicely. So now that we can compute the derivative of a function, let's define something somewhat more interesting with it.
Let's revisit the sqr function. It uses Newton-Raphson to find the square root. How does
Newton-Raphson actually work? Given a differentiable function,
f(x), it finds a zero by starting with some x_1 and then iterating
x_{n+1} = x_n - f(x_n) / f'(x_n)
until we meet some convergence criterion.
Using this, let's define a function that finds a zero of another function:
findZero f = convRel $ cut $ iterate step start
  where step x = x - val fx / val (df fx)  where fx = f (dVar x)
        start = 1 -- just some value
        epsilon = 1e-10
        cut = (++ error "No convergence in 1000 steps") . take 1000
        convRel (x1:x2:_) | x1 == x2 || abs (x1+x2) / abs (x1-x2) > 1/epsilon = x2
        convRel (_:xs) = convRel xs
The only interesting part is the step function, which does one Newton-Raphson iteration. It computes
f x and then divides the normal value by the derivative.
We then produce the infinite list of approximations using step, cut it off at some point (in case it doesn't converge), and look down the list for two consecutive values that are within some relative epsilon.
And it even seems to work.
Main> findZero (\x -> x*x - 9)
3.0
Main> findZero (\x -> sin x - 0.5)
0.5235987755982989
Main> sin it
0.5
Main> findZero (\x -> x*x + 9)
*** Exception: No convergence in 1000 steps
Main> findZero (\x -> sqr x - 3)
9.0
Note how it finds a zero of the sqr function which is actually using recursion internally to compute the square root.
So now we can compute numerical derivatives. But wait! We also have symbolic numbers.
Can we combine them? Of course, that is the power of polymorphism.
Let's load up both modules:
Data.Number.Symbolic Dif3> let x :: Num a => Dif (Sym a); x = dVar (var "x")
Data.Number.Symbolic Dif3> df $ x*x
x+x
Data.Number.Symbolic Dif3> df $ sin x
cos x
Data.Number.Symbolic Dif3> df $ sin (exp (x - 4) * x)
(exp (-4.0+x)*x+exp (-4.0+x))*cos (exp (-4.0+x)*x)
Data.Number.Symbolic Dif3> df $ df $ sin (exp (x - 4) * x)
(exp (-4.0+x)*x+exp (-4.0+x)+exp (-4.0+x))*cos (exp (-4.0+x)*x)+(exp (-4.0+x)*x+exp (-4.0+x))*(exp (-4.0+x)*x+exp (-4.0+x))*(-sin (exp (-4.0+x)*x))
We define x to be a differentiable number, "the variable", over symbolic numbers, over some numbers. And then we just happily use df to get the differentiated versions.
So we set out to compute numeric derivatives, and we got these for free. Not too bad.
One final note: the Dif type as defined above can be made more efficient by not keeping all the infinite tails of 0 derivatives around. In a real module for this, you'd want to make that optimization.
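One possible sketch of such an optimization (hypothetical, not the code from this post) is to add a constructor for constants that carries no derivative tower at all:
data Dif a = C a | D a (Dif a)

val (C x)   = x
val (D x _) = x

df (C _)    = C 0
df (D _ x') = x'

dVar x = D x (C 1)
The arithmetic then pattern matches on C to avoid building towers of zeros.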
[Edit: fixed typo.]
Just a minor typo, but shouldn't that be:
f'(x) = (f(x + h) - f(x)) / h
To make a derivative-taking operator fully general, it is important to be able to apply the derivative-taking function to functions that themselves internally take derivatives. This is difficult in Haskell; try using the code in the post and you'll see the problem. The issue is discussed in some detail in an IFL-2005 presentation Perturbation Confusion and Referential Transparency: Correct Functional Implementation of Forward-Mode AD by Jeff Siskind and Barak Pearlmutter (me). We also discussed using lazy towers of higher-order derivatives in the multivariate case in a POPL-2007 paper, Lazy Multivariate Higher-Order Forward-Mode AD, which includes working code.
The above all concerns forward-mode automatic differentiation. The reverse-mode AD construct is the one that is useful for taking gradients in high dimensions, and we've also been working on incorporating that into functional languages; see our web page on this FP/AD/Stalingrad stuff.
barak,
Can you give an example of the difficulty using derivatives internally? I've not read the IFL paper.
I have read your POPL paper and I have a Haskell version of that code. I have many parts to these posts left. :)
I also have a version of Jerzy's code for the adjoint AD in Haskell. It comes out quite nicely, with the adjoints forming a vector space.
barak,
I looked at your example in the IFL paper. I see your point. There's indeed sort of a variable capture problem with this approach, since there's only one variable.
I don't think this is a fundamental problem; as you point out, you just need some clever tagging.
barak,
So I tried your example
deriv (\ x -> x * deriv (\ y -> x + y) 1) 1
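(Here deriv is assumed to be a wrapper along the lines of deriv f x = val $ df $ f (dVar x), and dCon in the lifted version below is assumed to be dCon x = D x 0; neither is defined in the post above.)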
This doesn't type check because of the level confusion inside the inner lambda. I think that's acceptable behaviour.
The lifted version
deriv (\ x -> x * deriv (\ y -> dCon x + y) 1) 1
yields 1 as it should.
Yes, manual insertion of lift operators can bail you out some of the time. But they are fragile and delicate and easy to put in the wrong place. And because they are a static solution, but the problem is at root a dynamic one, they won't work when the correct placement of lift operators cannot be determined statically. For instance, they can't bail you out when the same function is called in multiple differential contexts. For example, consider finding a saddlepoint using min/max where max is defined in terms of min and min uses a gradient method.
ReplyDeleteLennart, Barak,
I don't know if you noticed this thread. I think it at least takes care of the "fragile, delicate and easy to put in the wrong place" objections.
-Bjorn