Saturday, January 10, 2009

LLVM arithmetic

So we want to compute x^2-5x+6 using the Haskell LLVM bindings. It would look something like this.
    xsq <- mul x x
    x5  <- mul 5 x
    r1  <- sub xsq x5
    r   <- add r1 6
That is not very readable; it would be nicer to write
    r   <- x^2 - 5*x + 6
But, e.g., the add instruction has the (simplified) type Value a -> Value a -> CodeGenFunction r (Value a), where CodeGenFunction is the monad for generating code for a function. (BTW, the r type variable is used to keep track of the return type of the function, used by the ret instruction.)

We can't change the return type of add, but we can change the argument type.

type TValue r a = CodeGenFunction r (Value a)
add' :: TValue r a -> TValue r a -> TValue r a
add' x y = do x' <- x; y' <- y; add x' y'
Now the type fits what the Num class wants. And we can make an instance declaration.
instance (Num a) => Num (TValue r a) where
    (+) = add'
    (-) = sub'
    (*) = mul'
    fromInteger = return . valueOf . fromInteger
We are getting close. The only remaining wrinkle is that the x in our original LLVM code has type Value a rather than TValue r a, but return takes care of that. So:
    let x' = return x
    r <- x'^2 - 5*x' + 6
And a quick look at the generated LLVM code (for Double) shows us that all is well.
; x in %0
 %1 = mul double %0, %0
 %2 = mul double 5.000000e+00, %0
 %3 = sub double %1, %2
 %4 = add double %3, 6.000000e+00
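The lifting trick can be tried without the LLVM bindings installed. What follows is only a minimal sketch under stated assumptions, not the library's implementation: Gen, Value, binop, lift2, codeFor and poly are made-up names, with Gen standing in for CodeGenFunction and lift2 playing the role of add', sub' and mul'. Instructions are just strings and there are no types on the registers.

```haskell
{-# LANGUAGE FlexibleInstances #-}

-- A value is either a numbered virtual register or an integer literal.
-- (Hypothetical stand-in for the bindings' Value a.)
data Value = Reg Int | Lit Integer

showValue :: Value -> String
showValue (Reg n) = "%" ++ show n
showValue (Lit i) = show i

-- Toy stand-in for CodeGenFunction: threads the next free register
-- number and the instructions emitted so far (most recent first).
newtype Gen a = Gen { runGen :: (Int, [String]) -> (a, (Int, [String])) }

instance Functor Gen where
  fmap f (Gen g) = Gen $ \s -> let (a, s') = g s in (f a, s')

instance Applicative Gen where
  pure a = Gen $ \s -> (a, s)
  Gen f <*> Gen g = Gen $ \s ->
    let (h, s')  = f s
        (a, s'') = g s'
    in (h a, s'')

instance Monad Gen where
  Gen g >>= k = Gen $ \s -> let (a, s') = g s in runGen (k a) s'

-- Emit one binary instruction and return the register holding its result.
binop :: String -> Value -> Value -> Gen Value
binop op a b = Gen $ \(n, code) ->
  let line = "%" ++ show n ++ " = " ++ op ++ " "
             ++ showValue a ++ ", " ++ showValue b
  in (Reg n, (n + 1, line : code))

-- Same shape as add' in the text: run both argument actions,
-- then emit the instruction on their results.
lift2 :: (Value -> Value -> Gen Value) -> Gen Value -> Gen Value -> Gen Value
lift2 f x y = do x' <- x; y' <- y; f x' y'

instance Num (Gen Value) where
  (+) = lift2 (binop "add")
  (-) = lift2 (binop "sub")
  (*) = lift2 (binop "mul")
  fromInteger = pure . Lit
  abs    = error "abs: not needed for this sketch"
  signum = error "signum: not needed for this sketch"

-- Generate code for a one-argument expression whose argument is in %0.
codeFor :: (Gen Value -> Gen Value) -> [String]
codeFor f = reverse code
  where (_, (_, code)) = runGen (f (pure (Reg 0))) (1, [])

poly :: Gen Value -> Gen Value
poly x = x^2 - 5*x + 6
```

In GHCi, mapM_ putStrLn (codeFor poly) prints four instructions in the same order as the LLVM listing above: two mul, one sub, one add.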

All kinds of numeric instances and some other goodies are available in the LLVM.Util.Arithmetic module. Here is a complete Fibonacci (again) using this.

import Data.Int
import LLVM.Core
import LLVM.ExecutionEngine
import LLVM.Util.Arithmetic

mFib :: CodeGenModule (Function (Int32 -> IO Int32))
mFib = recursiveFunction $ \ rfib n ->
    n %< 2 ? (1, rfib (n-1) + rfib (n-2))

main :: IO ()
main = do
    let fib = unsafeGenerateFunction mFib
    print (fib 22)
The %< operator compares values, returning a TValue r Bool. The expression c ? (t, e) is a conditional, like C's c ? t : e.
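To get a feeling for how such operators can be shaped, here is a small sketch that does not use LLVM at all. It assumes a toy expression type (Expr, Cond, eval and example are hypothetical names) in which a comparison builds a deferred condition instead of returning a Bool. The fixities are my own choice for the sketch, picked so that n %< 2 ? (...) parses without parentheses; they are not taken from LLVM.Util.Arithmetic.

```haskell
-- Expressions over a single argument; nothing is computed until eval.
data Expr
  = Arg
  | Lit Integer
  | Add Expr Expr
  | Sub Expr Expr
  | Mul Expr Expr
  | If Cond Expr Expr

-- A deferred comparison: the toy counterpart of TValue r Bool.
data Cond = Lt Expr Expr

-- Comparison that builds a condition instead of returning a Bool.
infix 4 %<
(%<) :: Expr -> Expr -> Cond
(%<) = Lt

-- Conditional expression, written like C's c ? t : e.
infix 0 ?
(?) :: Cond -> (Expr, Expr) -> Expr
c ? (t, e) = If c t e

instance Num Expr where
  (+) = Add
  (-) = Sub
  (*) = Mul
  fromInteger = Lit
  abs x    = x %< 0 ? (negate x, x)
  signum x = x %< 0 ? (-1, 0 %< x ? (1, 0))

-- Interpret an expression for a given argument value.
eval :: Integer -> Expr -> Integer
eval x expr = case expr of
  Arg             -> x
  Lit i           -> i
  Add a b         -> eval x a + eval x b
  Sub a b         -> eval x a - eval x b
  Mul a b         -> eval x a * eval x b
  If (Lt a b) t e -> if eval x a < eval x b then eval x t else eval x e

-- Same shape as the body of mFib, minus the recursion.
example :: Expr -> Expr
example n = n %< 2 ? (1, n*n - 1)
```

With these definitions eval 1 (example Arg) is 1 and eval 5 (example Arg) is 24. The condition is only decided inside eval, which is the point: when the expression is built there is no Bool to branch on.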

The type given to mFib is not the most general one, of course. The most general one can accept any numeric type for argument and result. Here it is:

mFib :: (Num a, Cmp a, IsConst a,
         Num b, Cmp b, IsConst b, FunctionRet b) =>
        CodeGenModule (Function (a -> IO b))
mFib = recursiveFunction $ \ rfib n ->
    n %< 2 ? (1, rfib (n-1) + rfib (n-2))
It's impossible to generate code for mFib directly; code can only be generated for monomorphic types. So a type signature is needed when generating code to force a monomorphic type.
main = do
    let fib :: Int32 -> Double
        fib = unsafeGenerateFunction mFib
        fib' :: Int16 -> Int64
        fib' = unsafeGenerateFunction mFib
    print (fib 22, fib' 22)

Another example

Let's try a more complex example. To pick something mathematical with lots of formulae, we'll do the cumulative distribution function of the standard normal distribution. To the precision of a Float it can be coded like this in Haskell (normal Haskell, no LLVM):
normcdf x = if x < 0 then 1 - w else w
  where w = 1 - 1 / sqrt (2 * pi) * exp(-l*l / 2) * poly k
        k = 1 / (1 + 0.2316419 * l)
        l = abs x
        poly = horner coeff 
        coeff = [0.0,0.31938153,-0.356563782,1.781477937,-1.821255978,1.330274429] 

horner coeff base = foldr1 multAdd coeff
  where multAdd x y = y*base + x
We cannot use this directly, because it has type normcdf :: (Floating a, Ord a) => a -> a. The Ord context is a problem, because there is no instance of Ord for LLVM types. The comparison is the root of the problem, since it returns a Bool rather than a TValue r Bool.

It's possible to hide the Prelude and overload the comparisons, but you cannot overload the if construct. So a little rewriting is necessary regardless. Let's just bite the bullet and change the first line.

normcdf x = x %< 0 ? (1 - w, w)
And with that change, all we need to do is
mNormCDF = createFunction ExternalLinkage $ arithFunction $ normcdf

main :: IO ()
main = do
    writeFunction "CDF.bc" (mNormCDF :: CodeGenModule (Function (Float -> IO Float)))
So what happened here? Looking at normcdf, it contains things that LLVM cannot handle, like lists. But all the list operations happen when the Haskell code runs, and nothing of that is left in the LLVM code.
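The same effect can be seen in miniature without LLVM. The sketch below assumes a toy symbolic type (Expr, render and unrolled are made-up names) and runs the horner function from above over it. The foldr1 over the coefficient list runs in Haskell, and what comes out is a plain expression with no list in sight.

```haskell
-- Toy symbolic expressions; a stand-in for generated code.
data Expr = Var String | Lit Integer | Add Expr Expr | Mul Expr Expr

instance Num Expr where
  (+) = Add
  (*) = Mul
  fromInteger = Lit
  negate x = Mul (Lit (-1)) x
  abs    = error "abs: not needed for this sketch"
  signum = error "signum: not needed for this sketch"

-- Fully parenthesized rendering, so the structure is unambiguous.
render :: Expr -> String
render (Var v)   = v
render (Lit i)   = show i
render (Add a b) = "(" ++ render a ++ " + " ++ render b ++ ")"
render (Mul a b) = "(" ++ render a ++ " * " ++ render b ++ ")"

-- horner exactly as in normcdf, with a type signature added.
horner :: Num a => [a] -> a -> a
horner coeff base = foldr1 multAdd coeff
  where multAdd x y = y*base + x

-- The list [1, 2, 3] is consumed here, while the expression is built.
unrolled :: String
unrolled = render (horner [1, 2, 3] (Var "x"))
```

Here unrolled is the string "((((3 * x) + 2) * x) + 1)", while horner [1, 2, 3] (2 :: Integer) is simply 17. The same polymorphic function computes a number or produces straight-line code, depending on the type at which it is used.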

If you optimize and generate code for normcdf it looks like this (leaving out constants and half the code):

__fun1:
        subl    $28, %esp
        pxor    %xmm0, %xmm0
        ucomiss 32(%esp), %xmm0
        jbe     LBB1_2
        movss   32(%esp), %xmm0
        mulss   %xmm0, %xmm0
        divss   LCPI1_0, %xmm0
        movss   %xmm0, (%esp)
        call    _expf
        fstps   24(%esp)
        movss   32(%esp), %xmm0
        mulss   LCPI1_1, %xmm0
        movss   %xmm0, 8(%esp)
        movss   LCPI1_2, %xmm0
        movss   8(%esp), %xmm1
        addss   %xmm0, %xmm1
        movss   %xmm1, 8(%esp)
        movaps  %xmm0, %xmm1
        divss   8(%esp), %xmm1
        movaps  %xmm1, %xmm2
        mulss   LCPI1_3, %xmm2
        addss   LCPI1_4, %xmm2
        mulss   %xmm1, %xmm2
        addss   LCPI1_5, %xmm2
        mulss   %xmm1, %xmm2
        addss   LCPI1_6, %xmm2
        mulss   %xmm1, %xmm2
        addss   LCPI1_7, %xmm2
        mulss   %xmm1, %xmm2
        pxor    %xmm1, %xmm1
        addss   %xmm1, %xmm2
        movss   24(%esp), %xmm1
        mulss   LCPI1_8, %xmm1
        mulss   %xmm2, %xmm1
        addss   %xmm0, %xmm1
        subss   %xmm1, %xmm0
        movss   %xmm0, 20(%esp)
        flds    20(%esp)
        addl    $28, %esp
        ret
LBB1_2:
        ...
And that looks quite nice, in my opinion.

Black-Scholes

I work at a bank these days, so let's do the most famous formula in financial maths, the Black-Scholes formula for pricing options. Here's a Haskell rendition of it.
blackscholes iscall s x t r v = if iscall then call else put
  where call = s * normcdf d1 - x*exp (-r*t) * normcdf d2
        put  = x * exp (-r*t) * normcdf (-d2) - s * normcdf (-d1)
        d1 = (log(s/x) + (r+v*v/2)*t) / (v*sqrt t)
        d2 = d1 - v*sqrt t
Again, it uses an if, so it needs a little fix.
blackscholes iscall s x t r v  = iscall ? (call, put)
With that fix, code can be generated directly from blackscholes. The calls to normcdf are expanded inline, but it's easy to make some small changes to the code so that it actually does function calls.
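Before generating code it is worth checking the formulas themselves in ordinary Haskell. The sketch below is the if version from above specialised to Double (the type signatures are my addition), together with two made-up helper names, sampleCall and samplePut, for one set of illustrative inputs: spot 100, strike 100, one year, 5% rate, 20% volatility.

```haskell
-- Standard normal CDF, as in the previous section, at type Double.
normcdf :: Double -> Double
normcdf x = if x < 0 then 1 - w else w
  where w = 1 - 1 / sqrt (2 * pi) * exp (-l*l / 2) * poly k
        k = 1 / (1 + 0.2316419 * l)
        l = abs x
        poly = horner coeff
        coeff = [0.0, 0.31938153, -0.356563782, 1.781477937,
                 -1.821255978, 1.330274429]

horner :: Num a => [a] -> a -> a
horner coeff base = foldr1 multAdd coeff
  where multAdd x y = y*base + x

-- Arguments: call/put flag, spot, strike, time, rate, volatility.
blackscholes :: Bool -> Double -> Double -> Double -> Double -> Double -> Double
blackscholes iscall s x t r v = if iscall then call else put
  where call = s * normcdf d1 - x * exp (-r*t) * normcdf d2
        put  = x * exp (-r*t) * normcdf (-d2) - s * normcdf (-d1)
        d1 = (log (s/x) + (r + v*v/2) * t) / (v * sqrt t)
        d2 = d1 - v * sqrt t

-- Illustrative inputs only; not taken from the benchmark portfolio.
sampleCall, samplePut :: Double
sampleCall = blackscholes True  100 100 1 0.05 0.2
samplePut  = blackscholes False 100 100 1 0.05 0.2
```

For these inputs the call comes out at about 10.45 and the put at about 5.57, and the two satisfy put-call parity: sampleCall - samplePut equals 100 - 100 * exp (-0.05) up to rounding.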

Some figures

To test the speed of the generated code, I ran blackscholes over a portfolio of 10,000,000 options and summed their values. The time excludes the time to set up the portfolio array, because that is done in Haskell. I also ran the code in pure Haskell on a list with 10,000,000 elements.
Unoptimized LLVM   17.5s
Optimized LLVM      8.2s
Pure Haskell        9.3s
The only surprise here is how well pure Haskell (with -O2) performs. This is a very good example for Haskell though, because the list gets fused away and everything is strict. A lot of the time is spent in log and exp in this code, so perhaps the similar results are not so surprising after all.
