Wednesday, April 20, 2011

Ugly memoization

Here's a problem that I recently ran into. I have a function taking a string and computing some value. I call this function a lot, but a lot of the time the argument has occurred before. The function is reasonably expensive, about 10 us. Only about 1/5 of the calls to the function have a new argument.

So naturally I want to memoize the function. Luckily Hackage has a couple of packages for memoization. I found data-memocombinators and MemoTrie and decided to try them. The basic idea with memoization is that you have a function like

  memo :: (a->b) -> (a->b)

I.e., you give a function to memo and you get a new function of the
same type back.  This new function behaves like the original one, but
it remembers every time it is used and the next time it gets the same
argument it will just return the remembered result.
This is only safe in a pure language, but luckily Haskell is pure.
In an imperative language you can use a mutable memo table that stores all the argument-result pairs and updates the memo table each time the function is used. But how is it even possible to implement that in a pure language? The idea is to lazily construct the whole memo table in the call to memo, and it will then be lazily filled in.
If all values of the argument type a can be enumerated by a method enumerate, we could write memo like this:

  memo f =
      let table = [ (x, f x) | x <- enumerate ]
      in  \ y -> let Just r = lookup y table in r

Note how the memo table is constructed given just f, and this memo
table is then used in the returned function.
The type of this function would be something like

  memo :: (Enumerate a, Eq a) => (a->b) -> (a->b)

assuming that the class Enumerate has the magic method enumerate.
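To make the list-based idea concrete, here is a small runnable sketch. The Enumerate class is hypothetical, exactly as in the text; the Word8 instance and the slowDouble example are assumptions chosen only so that the table stays small (256 entries).

```haskell
import Data.Word (Word8)

-- Hypothetical class from the text: a list of every value of the type.
class Enumerate a where
  enumerate :: [a]

-- Assumed instance, picked because Word8 has only 256 values.
instance Enumerate Word8 where
  enumerate = [minBound .. maxBound]

-- The list-based memo: the table is built once per call of memo, and
-- each f x inside it is a lazy thunk, evaluated at most once.
memo :: (Enumerate a, Eq a) => (a -> b) -> (a -> b)
memo f =
  let table = [ (x, f x) | x <- enumerate ]
  in  \ y -> case lookup y table of
               Just r  -> r
               Nothing -> error "memo: argument not enumerated"

-- Hypothetical example; the body stands in for an expensive function.
slowDouble :: Word8 -> Int
slowDouble = memo (\ x -> 2 * fromIntegral x)
```

Since slowDouble is a top-level binding, all calls share one table, but every lookup still walks the list linearly.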

This is just a very simplified example; if you tried to use it, it would be terrible, because the returned function does a linear lookup in a list. Instead we want some kind of search tree, which is what the two packages I mentioned implement. The MemoTrie package does this in a really beautiful way; I recommend reading Conal's blog post about it.
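As an illustration of the search-tree idea, here is a hand-rolled sketch for non-negative Int arguments. It is not the MemoTrie or data-memocombinators API; Tree, nats, index, memoNat and fastFib are hypothetical names. The memo table is an infinite binary tree built lazily, so a lookup takes O(log n) steps instead of a linear scan.

```haskell
-- An infinite binary tree with one value per node.
data Tree a = Tree (Tree a) a (Tree a)

instance Functor Tree where
  fmap g (Tree l m r) = Tree (fmap g l) (g m) (fmap g r)

-- Value stored for a non-negative index n. Index 0 is the root; the
-- left subtree holds the odd indices 2q+1 and the right subtree the
-- positive even indices 2q+2, each at position q.
index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q, 0) -> index l q
  (q, _) -> index r q

-- The tree whose node for index n contains n itself.
nats :: Tree Int
nats = Tree (fmap (\ q -> 2 * q + 1) nats) 0 (fmap (\ q -> 2 * q + 2) nats)

-- Memoize a function on non-negative Ints: the tree of results is built
-- lazily, once per call of memoNat, and each f n is computed at most once.
memoNat :: (Int -> b) -> (Int -> b)
memoNat f = index (fmap f nats)

-- Example: Fibonacci with the recursive calls going through the table.
fastFib :: Int -> Integer
fastFib = memoNat fib
  where
    fib 0 = 0
    fib 1 = 1
    fib n = fastFib (n - 1) + fastFib (n - 2)
```

Without the memo table fastFib 50 would take exponential time; with it, each value is computed once.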
OK, enough preliminaries. I used criterion to perform the benchmarking, and I tried with no memoization (none), memo-combinators (comb), and MemoTrie (beau). I had a test function taking about 10 us, and then I called this function with different numbers of repeated arguments: 1, 2, 5, and 10. I.e., 5 means that each argument occurred 5 times as the memoized function was called. The table shows the time per call in us.

  repetitions      1      2      5     10
  none          10.7   10.7   10.7   10.7
  comb          62.6   52.2   45.8   43.4
  beau          27.6   17.0   10.4    8.1

So with no memoization the time per call was 10.7 us all the time, no surprise there. With the memo combinators it was much slower than no memoization; the overhead for looking something up is bigger than the cost of computing the result. So that was a failure. MemoTrie does better: at an argument repetition of about five it starts to break even, and at ten it's a little faster to memoize.

Since I estimated my repetition factor in the real code to be about five, even the fastest memoization would not be any better than recomputation. So now what? Give up? Of course not! It's time to get dirty.

Once you know a function can be implemented in a pure way, there's no harm in implementing the same function in an impure way as long as it presents the pure interface. So let's write the memo function the way it would be done in, e.g., Scheme or ML. We will use a reference to hold a memo table that gets updated on each call. Here's the code, with the type that the function gets.

import Data.IORef
import qualified Data.Map as M

memoIO :: (Ord a) => (a -> b) -> IO (a -> IO b)
memoIO f = do
  v <- newIORef M.empty
  let f' x = do
        m <- readIORef v
        case M.lookup x m of
          Nothing -> do let { r = f x }; writeIORef v (M.insert x r m); return r
          Just r  -> return r
  return f'

memoIO allocates a reference holding an empty memo table.
We then define a new function, f', which, when it's called,
gets the memo table and looks up the argument. If the argument is
in the table then we just return the result; if it's not, then we
compute the result, store it in the table, and return it.
Good old imperative programming (see below why this code is not good
imperative code).

But, horror, now the type is all wrong: there's IO in two places. The function we want to implement is actually pure. So what to do? Well, if you have a function involving the IO type, but you can prove it is actually pure, then (and only then) you are allowed to use unsafePerformIO.

I'll wave my hands instead of giving a proof (but more on that later), and here we go:

  import System.IO.Unsafe (unsafePerformIO)

  memo :: (Ord a) => (a -> b) -> (a -> b)
  memo f = let f' = unsafePerformIO (memoIO f) in \ x -> unsafePerformIO (f' x)

Wow, two unsafePerformIO on the same line.  It doesn't get
much less safe than that.
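For experimenting, here is a minimal sketch that puts the IORef-based memoIO and the ugly memo into one module. The import lists, the NOINLINE pragma and the countVowels example are assumptions added here; countVowels is only a cheap stand-in for the real 10 us function. The pragma is a common precaution so that GHC does not inline the top-level memoized function and end up with more than one table. Like the IORef version above, it is not thread safe.

```haskell
import Data.IORef (newIORef, readIORef, writeIORef)
import qualified Data.Map as M
import System.IO.Unsafe (unsafePerformIO)

-- IORef-based memoIO: one mutable table per call of memoIO.
memoIO :: Ord a => (a -> b) -> IO (a -> IO b)
memoIO f = do
  v <- newIORef M.empty
  let f' x = do
        m <- readIORef v
        case M.lookup x m of
          Nothing -> do
            let r = f x                     -- lazy: stored unevaluated
            writeIORef v (M.insert x r m)   -- not atomic with the read
            return r
          Just r -> return r
  return f'

-- The "ugly" pure wrapper around memoIO.
memo :: Ord a => (a -> b) -> (a -> b)
memo f = let f' = unsafePerformIO (memoIO f) in \ x -> unsafePerformIO (f' x)

-- Hypothetical example, defined at top level so all calls share a table.
{-# NOINLINE countVowels #-}
countVowels :: String -> Int
countVowels = memo (length . filter (`elem` "aeiou"))
```

In a single-threaded program, only the first call of countVowels with a given string computes the count; later calls find it in the table.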
Let's benchmark again:

  repetitions      1      2      5     10
  none          10.7   10.7   10.7   10.7
  comb          62.6   52.2   45.8   43.4
  beau          27.6   17.0   10.4    8.1
  ugly          13.9    7.7    3.9    2.7

Not too shabby, using the ugly memoization is actually a win already at two, and just a small overhead if the argument occurs once.  We have a winner!
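As a sanity check on these numbers, here is a simple cost model. It is an assumption fitted to the table above, not something measured separately. Suppose a hit costs h us and a miss costs h + c us, where c covers calling the function and updating the table. With repetition factor k, each distinct argument causes one miss and k - 1 hits, so the average time per call is t(k) below, and the 1 and 10 columns determine h and c:

```latex
\begin{aligned}
t(k) &= \frac{(h + c) + (k-1)\,h}{k} = h + \frac{c}{k} \\
c &= \tfrac{10}{9}\bigl(t(1) - t(10)\bigr), \qquad h = t(1) - c \\[4pt]
\text{ugly:}\quad & c \approx 12.4,\ h \approx 1.46,\ t(2) \approx 7.7,\ t(5) \approx 3.9 \\
\text{beau:}\quad & c \approx 21.7,\ h \approx 5.9,\ t(2) \approx 16.8,\ t(5) \approx 10.3 \\
\text{comb:}\quad & c \approx 21.3,\ h \approx 41.3 \\[4pt]
k^{*} &= \frac{c}{t_{\text{none}} - h}\quad (h < t_{\text{none}} = 10.7):\qquad
k^{*}_{\text{ugly}} \approx 1.35,\quad k^{*}_{\text{beau}} \approx 4.5
\end{aligned}
```

The model reproduces the 2 and 5 columns to within about 0.3 us and agrees with the prose: MemoTrie breaks even at roughly 4.5 repetitions, the ugly version between one and two, and the combinators never, because their lookup overhead h alone is nearly four times the 10.7 us of just calling the function.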

Not so fast, there's

A snag

My real code can actually be multi-threaded, so the memo function had better work in a multi-threaded setting. Well, it doesn't. A readIORef followed by a writeIORef is not atomic, so two threads can read the same table, and the second writeIORef silently throws away the entry added by the first.
So we have to rewrite it. Actually, the code I first wrote is the one below; I hardly ever use IORef because I want it to work with multi-threading.

import Control.Concurrent.MVar

memoIO :: (Ord a) => (a -> b) -> IO (a -> IO b)
memoIO f = do
  v <- newMVar M.empty
  let f' x = do
        m <- takeMVar v
        case M.lookup x m of
          Nothing -> do let { r = f x }; putMVar v (M.insert x r m); return r
          Just r  -> do                  putMVar v m;                return r
  return f'

So now we use an MVar instead. This makes it thread safe.
Only one thread can execute between the takeMVar and the
putMVar. This guarantees that only one thread can update the
memo table at a time. If two threads try at the same time, one has to
wait a little. How long? The time it takes for the lookup, plus some
small constant. Remember that Haskell is lazy, so the (f x)
is not actually computed with the lock held, which is good.
So I think this is a perfectly reasonable memoIO. And we can do the same unsafe trick as before and make it pure. Performance of this version is the same as with the IORef version.

Ahhhh, bliss. But wait, there's

Another snag

That might look reasonable, but in fact the memo function is broken now. It appears to work, but here's a simple use that fails:

  sid :: String -> String
  sid = memo id

  fcn s = sid (sid s)

What will happen here? The outer call to sid will execute
the takeMVar and then do the lookup. Doing the lookup will
evaluate the argument, x. But this argument is another call
to sid, which will try to execute the takeMVar on the MVar the outer call is still holding.
Disaster has struck: deadlock.

What happened here? The introduction of unsafePerformIO ruined the sequencing guaranteed by the IO monad that would have prevented the deadlock if we had used memoIO. I got what I deserved for using unsafePerformIO.

Can it be repaired? Well, we could make sure x is fully evaluated before grabbing the lock. I settled for a different repair, where the lock is held in a shorter portion of the code.

memoIO :: (Ord a) => (a -> b) -> IO (a -> IO b)
memoIO f = do
  v <- newMVar M.empty
  let f' x = do
        m <- readMVar v
        case M.lookup x m of
          Nothing -> do let { r = f x }; m <- takeMVar v; putMVar v (M.insert x r m); return r
          Just r  -> return r
  return f'

This solution has its own problem.  It's now possible for several threads
to compute (f x) for the same x and the result of
all but one of those will be lost by overwriting the table.  This is a
price I'm willing to pay for this application.
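The other repair mentioned above, making sure x is fully evaluated before grabbing the lock, could look roughly like the sketch below. It is not the version the post settled on. It assumes the argument type has an NFData instance (from the deepseq package) and uses modifyMVar, which also puts the old table back if something inside throws. sid and fcn are the example from the text.

```haskell
import Control.Concurrent.MVar (newMVar, modifyMVar)
import Control.DeepSeq (NFData, force)
import Control.Exception (evaluate)
import qualified Data.Map as M
import System.IO.Unsafe (unsafePerformIO)

-- Fully evaluate the argument first, then hold the lock only for the
-- lookup and the insert.
memoIO :: (Ord a, NFData a) => (a -> b) -> IO (a -> IO b)
memoIO f = do
  v <- newMVar M.empty
  let f' x = do
        x' <- evaluate (force x)          -- may call memoized functions; no lock held
        modifyMVar v $ \ m ->
          case M.lookup x' m of
            Just r  -> return (m, r)
            Nothing ->
              let r  = f x'               -- lazy: f x' is not computed under the lock
                  m' = M.insert x' r m
              in  m' `seq` return (m', r)
  return f'

memo :: (Ord a, NFData a) => (a -> b) -> (a -> b)
memo f = let f' = unsafePerformIO (memoIO f) in \ x -> unsafePerformIO (f' x)

{-# NOINLINE sid #-}
sid :: String -> String
sid = memo id

-- The nested call is forced before the outer call takes the lock.
fcn :: String -> String
fcn s = sid (sid s)
```

Because lookup and insert happen under one lock, each argument ends up with a single stored thunk, so the function is usually evaluated only once per argument. The price is the NFData constraint and forcing the whole argument up front.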


Yes, you can use imperative programming to implement pure functions. But the onus is on you to prove that it is safe. This is not as easy as you might think. I believe my final version is correct (with the multiple computation caveat), but I'm not 100% sure.

Sunday, April 10, 2011

Phew! Cleaned out a lot of spam comments in my blog. Hopefully my new settings will prevent the crazy onslaught of spammers.