Wednesday, April 20, 2011

Ugly memoization

Here's a problem that I recently ran into. I have a function taking a string and computing some value. I call this function a lot, but a lot of the time the argument has occurred before. The function is reasonably expensive, about 10 us. Only about 1/5 of the calls to the function have a new argument.

So naturally I want to memoize the function. Luckily Hackage has a couple of packages for memoization. I found data-memocombinators and MemoTrie and decided to try them. The basic idea with memoization is that you have a function like

  memo :: (a->b) -> (a->b)

I.e., you give a function to memo and you get a new function of the
same type back.  This new function behaves like the original one, but
it remembers every time it is used and the next time it gets the same
argument it will just return the remembered result.
This is only safe in a pure language, but luckily Haskell is pure.
In an imperative language you can use a mutable memo table that stores all the argument-result pairs and updates the memo table each time the function is used. But how is it even possible to implement that in a pure language? The idea is to lazily construct the whole memo table in the call to memo, and it will then be lazily filled in.
Assuming that all values of the argument type a can be enumerated by the method enumerate, we could then write memo like this:

  memo f =
      let table = [ (x, f x) | x <- enumerate ]
      in  \ y -> let Just r = lookup y table in r

Note how the memo table is constructed given just f, and this memo
table is then used in the returned function.
The type of this function would be something like

  memo :: (Enumerate a, Eq a) => (a->b) -> (a->b)

assuming that the class Enumerate has the magic method enumerate.
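A runnable version of this sketch follows. The post only assumes that some Enumerate class exists. The class below, its Bool and Word8 instances, and the example function square are illustrative assumptions. The lookup is still a linear scan of a list.

```haskell
import Data.Word (Word8)

-- Hypothetical class: 'enumerate' lists every value of the type.
class Enumerate a where
  enumerate :: [a]

instance Enumerate Bool where
  enumerate = [False, True]

-- Word8 has only 256 values, so its table is finite.
instance Enumerate Word8 where
  enumerate = [minBound .. maxBound]

-- The table is built once per application of memo; each f x in it
-- stays an unevaluated thunk until a lookup first demands it.
memo :: (Enumerate a, Eq a) => (a -> b) -> (a -> b)
memo f =
  let table = [ (x, f x) | x <- enumerate ]
  in  \ y -> case lookup y table of
               Just r  -> r
               Nothing -> error "memo: argument missing from enumerate"

-- Example: squares of bytes, each computed at most once.
square :: Word8 -> Int
square = memo (\ w -> let n = fromIntegral w in n * n)
```

Because square is bound once at top level, all of its calls share one table. Applying memo afresh at every call site would build a new table each time.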

This is just a very simplified example. If you tried to use it, it would be terrible, because the returned function does a linear lookup in a list. Instead we want some kind of search tree, which is what the two packages I mentioned implement. The MemoTrie package does this in a really beautiful way; I recommend reading Conal's blog post about it.
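Here is a hand-rolled sketch of the search-tree idea that does not depend on either package's API. It memoizes functions on non-negative Int arguments using a lazy, infinite binary tree. Tree, buildTree, indexTree, memoNat and fibMemo are hypothetical names. The packages build comparable lazy structures, but they do it generically over many argument types.

```haskell
-- A lazy, infinite binary tree. The root holds the value for index 0,
-- the left subtree holds indices 2q+1 and the right subtree 2q+2.
data Tree b = Node (Tree b) b (Tree b)

buildTree :: (Int -> b) -> Tree b
buildTree f =
  Node (buildTree (\ q -> f (2 * q + 1)))
       (f 0)
       (buildTree (\ q -> f (2 * q + 2)))

-- Walks O(log n) levels down to the node for n.
indexTree :: Tree b -> Int -> b
indexTree (Node l v r) n
  | n < 0     = error "indexTree: negative index"
  | n == 0    = v
  | otherwise = case (n - 1) `divMod` 2 of
      (q, 0) -> indexTree l q
      (q, _) -> indexTree r q

-- The tree is built once, when memoNat is applied to f,
-- and filled in lazily as lookups force its nodes.
memoNat :: (Int -> b) -> (Int -> b)
memoNat f = let t = buildTree f in \ n -> indexTree t n

-- Example: Fibonacci whose recursive calls go through the memo table.
fibMemo :: Int -> Integer
fibMemo = memoNat fib
  where
    fib 0 = 0
    fib 1 = 1
    fib n = fibMemo (n - 1) + fibMemo (n - 2)
```

A lookup now costs O(log n) steps instead of a linear scan. Because fibMemo's recursive calls go through the shared tree, each Fibonacci number is computed only once.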
OK, enough preliminaries. I used criterion to perform the benchmarking, and I tried with no memoization (none), memo-combinators (comb), and MemoTrie (beau). I had a test function taking about 10 us, and I called this function with different numbers of repeated arguments: 1, 2, 5, and 10. I.e., 5 means that each argument occurred 5 times as the memoized function was called.

  Time per call (us), by number of repetitions per argument:

            1      2      5     10
  none   10.7   10.7   10.7   10.7
  comb   62.6   52.2   45.8   43.4
  beau   27.6   17.0   10.4    8.1

So with no memoization the time per call was 10.7 us all the time, no surprise there. With the memo combinators it was much slower than no memoization; the overhead of looking something up is bigger than the cost of computing the result. So that was a failure. MemoTrie does better: at an argument repetition of about five it starts to break even, and at ten it's a little faster to memoize.

Since I estimated my repetition factor in the real code to be about five, even the fastest memoization would not be any better than recomputation. So now what? Give up? Of course not! It's time to get dirty.

Once you know a function can be implemented in a pure way, there's no harm in implementing the same function in an impure way as long as it presents the pure interface. So let's write the memo function the way it would be done in, e.g., Scheme or ML. We will use a reference to hold a memo table that gets updated on each call. Here's the code, with the type that the function gets.

import Data.IORef
import qualified Data.Map as M

memoIO :: (Ord a) => (a -> b) -> IO (a -> IO b)
memoIO f = do
  v <- newIORef M.empty
  let f' x = do
        m <- readIORef v
        case M.lookup x m of
          Nothing -> do let { r = f x }; writeIORef v (M.insert x r m); return r
          Just r  -> return r
  return f'

memoIO allocates a reference holding an empty memo table.
We then define a new function, f', which, when it's called,
will get the memo table and look up the argument.  If the argument is
in the table then we just return the result; if it's not, then we
compute the result, store it in the table, and return it.
Good old imperative programming (see below for why this code is not good
imperative code).
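The variant of memoIO below also returns an action that reports the current table size, so you can watch the table fill up. memoIOCounted and demo are hypothetical names used only for this illustration. The lookup logic is the same as in memoIO.

```haskell
import Data.IORef (newIORef, readIORef, writeIORef)
import qualified Data.Map as M

-- Same idea as memoIO, but also returns an action that reports
-- how many arguments are currently cached.
memoIOCounted :: (Ord a) => (a -> b) -> IO (a -> IO b, IO Int)
memoIOCounted f = do
  v <- newIORef M.empty
  let f' x = do
        m <- readIORef v
        case M.lookup x m of
          Nothing -> do
            let r = f x                      -- a thunk; not computed yet
            writeIORef v (M.insert x r m)    -- remember it for next time
            return r
          Just r -> return r                 -- cache hit
  return (f', M.size <$> readIORef v)

-- Calls the memoized function on a list with repeats and reports
-- the results together with the final table size.
demo :: IO ([Int], Int)
demo = do
  (lenM, size) <- memoIOCounted length
  rs <- mapM lenM ["apple", "fig", "apple", "apple", "fig"]
  n  <- size
  return (rs, n)
```

demo returns ([5,3,5,5,3],2). There are five calls, but only the two distinct arguments end up in the table.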

But, horror, now the type is all wrong: there's IO in two places. The function we want to implement is actually pure. So what to do? Well, if you have a function involving the IO type, but you can prove it is actually pure, then (and only then) you are allowed to use unsafePerformIO.

I'll wave my hands instead of a proof (but more on that later), and here we go:

  import System.IO.Unsafe (unsafePerformIO)

  memo :: (Ord a) => (a -> b) -> (a -> b)
  memo f = let f' = unsafePerformIO (memoIO f) in \ x -> unsafePerformIO (f' x)

Wow, two unsafePerformIO on the same line.  It doesn't get
much less safe than that.
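The usage sketch below shows the pure memo. It repeats memoIO so that the block compiles on its own. score is a hypothetical stand-in for the expensive String function. The point to notice is that the memoized function is bound once, so all calls share one table. Like the post's code at this stage, this version is only meant for single-threaded use.

```haskell
import Data.IORef (newIORef, readIORef, writeIORef)
import qualified Data.Map as M
import System.IO.Unsafe (unsafePerformIO)

-- memoIO as in the post (single-threaded IORef version).
memoIO :: (Ord a) => (a -> b) -> IO (a -> IO b)
memoIO f = do
  v <- newIORef M.empty
  let f' x = do
        m <- readIORef v
        case M.lookup x m of
          Nothing -> do
            let r = f x
            writeIORef v (M.insert x r m)
            return r
          Just r -> return r
  return f'

-- The pure wrapper from the post.
memo :: (Ord a) => (a -> b) -> (a -> b)
memo f = let f' = unsafePerformIO (memoIO f) in \ x -> unsafePerformIO (f' x)

-- Hypothetical stand-in for an expensive String function.
score :: String -> Int
score s = sum [ fromEnum c * i | (c, i) <- zip s [1 ..] ]

-- Bound once at top level, so every call shares one table. Writing
-- 'memo score s' inside another function could build a fresh,
-- empty table on every call instead.
scoreMemo :: String -> Int
scoreMemo = memo score
```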
Let's benchmark again:

  Time per call (us), by number of repetitions per argument:

            1      2      5     10
  none   10.7   10.7   10.7   10.7
  comb   62.6   52.2   45.8   43.4
  beau   27.6   17.0   10.4    8.1
  ugly   13.9    7.7    3.9    2.7

Not too shabby: the ugly memoization is already a win at two repetitions, with just a small overhead when each argument occurs once.  We have a winner!

Not so fast, there's

A snag

My real code can actually be multi-threaded, so the memo function had better work in a multi-threaded setting. Well, it doesn't. A readIORef followed by a writeIORef is not atomic, and IORef operations come with few guarantees when several threads share the reference.
So we have to rewrite it. Actually, the code I first wrote is the one below. I hardly ever use IORef, because I want my code to work with multi-threading.

import Control.Concurrent.MVar

memoIO :: (Ord a) => (a -> b) -> IO (a -> IO b)
memoIO f = do
  v <- newMVar M.empty
  let f' x = do
        m <- takeMVar v
        case M.lookup x m of
          Nothing -> do let { r = f x }; putMVar v (M.insert x r m); return r
          Just r  -> do                  putMVar v m;                return r
  return f'

So now we use an MVar instead.  This makes it thread safe.
Only one thread can execute between the takeMVar and the
putMVar.  This guarantees that only one thread can update the
memo table at a time.  If two threads try at the same time, one has to
wait a little.  How long?  The time it takes for the lookup, plus some
small constant.  Remember that Haskell is lazy, so the (f x)
is not actually computed with the lock held, which is good.
So I think this is a perfectly reasonable memoIO. And we can do the same unsafe trick as before and make it pure. Performance of this version is the same as with the IORef.

Ahhhh, bliss. But wait, there's

Another snag

That might look reasonable, but in fact the memo function is broken now. It appears to work, but here's a simple use that fails:

  sid :: String -> String
  sid = memo id

  fcn s = sid (sid s)

What will happen here?  The outer call to sid will execute
the takeMVar and then do the lookup.  Doing the lookup will
evaluate the argument, x.  But this argument is another call
to sid, which will try to execute the takeMVar.
Disaster has struck: deadlock.

What happened here? The introduction of unsafePerformIO ruined the sequencing guaranteed by the IO monad that would have prevented the deadlock if we had used memoIO. I got what I deserved for using unsafePerformIO.

Can it be repaired? Well, we could make sure x is fully evaluated before grabbing the lock. I settled for a different repair, where the lock is held for a shorter portion of the code.

memoIO f = do
  v <- newMVar M.empty
  let f' x = do
        m <- readMVar v
        case M.lookup x m of
          Nothing -> do let { r = f x }; m <- takeMVar v; putMVar v (M.insert x r m); return r
          Just r  -> return r
  return f'

This solution has its own problem.  It's now possible for several threads
to compute (f x) for the same x and the result of
all but one of those will be lost by overwriting the table.  This is a
price I'm willing to pay for this application.
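For comparison, the other repair mentioned earlier could look like the sketch below. It fully evaluates x before taking the lock. It assumes the deepseq package (shipped with GHC) and an NFData instance for the argument type. The names memoIOStrict, memoStrict, sid and fcn are illustrative. Here the check and the insert happen under one lock, so every caller gets the same thunk for a given x and the table is never overwritten. The price is forcing every argument to normal form.

```haskell
import Control.Concurrent.MVar (newMVar, putMVar, takeMVar)
import Control.DeepSeq (NFData, force)
import Control.Exception (evaluate)
import qualified Data.Map as M
import System.IO.Unsafe (unsafePerformIO)

-- Variant of the MVar-based memoIO that fully evaluates the argument
-- before taking the lock, so M.lookup cannot run memoized code while
-- the MVar is empty.
memoIOStrict :: (Ord a, NFData a) => (a -> b) -> IO (a -> IO b)
memoIOStrict f = do
  v <- newMVar M.empty
  let f' x0 = do
        x <- evaluate (force x0)   -- no lock held yet
        m <- takeMVar v
        case M.lookup x m of
          Nothing -> do
            let r = f x            -- still a thunk, so f x is computed outside the lock
            putMVar v (M.insert x r m)
            return r
          Just r -> do
            putMVar v m
            return r
  return f'

memoStrict :: (Ord a, NFData a) => (a -> b) -> (a -> b)
memoStrict f = let f' = unsafePerformIO (memoIOStrict f) in \ x -> unsafePerformIO (f' x)

-- The nested use from the post.
sid :: String -> String
sid = memoStrict id

fcn :: String -> String
fcn s = sid (sid s)
```

With the argument forced first, fcn "abc" returns "abc" instead of deadlocking, because the inner sid call finishes before the outer one takes the lock.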


Moral

Yes, you can use imperative programming to implement pure functions. But the onus is on you to prove that it is safe. This is not as easy as you might think. I believe my final version is correct (with the multiple computation caveat), but I'm not 100% sure.
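Two comments below note that a bare takeMVar/putMVar pair is not safe against asynchronous exceptions. An exception arriving in between leaves the MVar empty, and every later call blocks. The sketch below does the final version's insert with modifyMVar_, which puts the old contents back if the update is interrupted. It also includes the atomicModifyIORef' variant suggested in another comment. memoIOSafe and memoIORef are hypothetical names, and both keep the duplicate-computation caveat.

```haskell
import Control.Concurrent.MVar (modifyMVar_, newMVar, readMVar)
import Data.IORef (atomicModifyIORef', newIORef, readIORef)
import qualified Data.Map as M

-- The final MVar version, with the insert done by modifyMVar_.
-- If an exception interrupts the update, modifyMVar_ puts the old
-- table back instead of leaving the MVar empty.
memoIOSafe :: (Ord a) => (a -> b) -> IO (a -> IO b)
memoIOSafe f = do
  v <- newMVar M.empty
  let f' x = do
        m <- readMVar v                  -- look up without holding the lock
        case M.lookup x m of
          Just r  -> return r
          Nothing -> do
            let r = f x                  -- a thunk; computed later by the caller
            -- The new map is stored unevaluated, so no key comparisons
            -- (and hence no evaluation of x) happen while the lock is held.
            modifyMVar_ v (return . M.insert x r)
            return r
  return f'

-- The same structure using an IORef updated by atomicModifyIORef'.
-- There is no MVar here, so there is no lock to deadlock on.
memoIORef :: (Ord a) => (a -> b) -> IO (a -> IO b)
memoIORef f = do
  v <- newIORef M.empty
  let f' x = do
        m <- readIORef v
        case M.lookup x m of
          Just r  -> return r
          Nothing -> do
            let r = f x
            atomicModifyIORef' v (\ m' -> (M.insert x r m', ()))
            return r
  return f'
```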

24 comments:

  1. Wouldn't atomicModifyIORef do instead of using MVars? I understand it's much faster.

    You'd still have the race where multiple threads could calculate the value, but I think that's unavoidable in this kind of memoisation.

  2. Ah, you said that using MVars was the same observed speed as IORefs anyway, so I guess it doesn't matter.

  3. Using Data.HashMap from unordered-containers might be faster than using Data.Map, as long as your functions arguments are hashable.

  4. I was more worried that your first definition didn't permit memoization of recursive functions (as it would deadlock during the recursive call)!

    Now: how do you do evaluate-just-once memoization that permits recursion or parallelism? With the pure solution it's easy but evidently slow. It seems like we ought to be able to insert a thunk and then return the retrieved thunk, evaluating it after we're done mutating the table.

  5. @Jan Recursion is not as bad as you might think in the first MVar version. The only crucial evaluation happening with the lock held is of the argument to the function. The actual call to f is just a thunk.

  6. Have you released this library on hackage? Really fast (and not just beautiful) memoization seems like a worthy endeavor, and I have some code which I'd like to try it on... I've tried both MemoTrie and memocombinators and gotten some good speedups, but it's still not as fast as I like and the main bottleneck is in the memo code.

  7. Now I am confused... doesn't Haskell's default behavior imply some sort of memoization?

    (I'm a Haskell newbie in this context)

    In fact, I was convinced that this was one of the biggest advantages of "pureness".

  8. @Tom I'll put the code on hackage, even if it's just 10 lines.

  9. @Francisco Haskell does not memoize functions.

  10. @Lennart And why not? I would have thought it was straightforward to do...

    Can you point me to some design documents or something similar where I can read about this?

  11. @Francisco It's not straightforward. First, to memoize you need to be able to build some kind of table indexed by the function argument. How would you do that? Haskell values do not come with any equality or ordering or hashing. And even if they did, how far would you evaluate the argument?

    But much worse, there's a cost associated with memoization. Both in time for the lookup and insertion in the table, and also in space for storing the table. This isn't a cost you want to pay by default.

  12. Could you please elaborate on what might go wrong with the IORef version in a multithreaded program? Is it just that the function might be recomputed on the same arguments multiple times, or could something worse happen?

  13. @Bernie From what I understand, there are no atomicity guarantees whatsoever for IORef, meaning that not even read & write are guaranteed to be atomic. So then even those could give the wrong results. I'm not sure if that currently happens on any platform.

    The documentation says atomicModifyIORef is the safe way to access and modify an IORef in a multi-threaded setting.

  14. @Johan Yes, Data.HashMap speeds up the 10x case from 2.7us to 1.7us. Thanks for the tip.

  15. @francisco: what got me confused in the past is the distinction between memoization (you call a function twice with the same arguments, and it is only evaluated the first time) and the funny treatment of names in haskell: when you write something like "x = f 4 2", you cannot change the meaning of x in the future (you can just define a new x to shadow the old one). the meaning is not calculated where it is defined, but lazily, where it is required for the result of the program. the analogy to memoization is that you can use it several times and it is only calculated once. the difference is that it's a name, not a complex expression, ie. a function call.

    not sure if this is useful...

  16. @mf I think I understood everything you said except this:

    the difference is that it's a name, not a complex expression, ie. a function call.

  17. A functor to create memoized functions is very cool, but perhaps unnecessary.

    Check out memoized_fib on the Haskell Wiki.

  18. takeMVar/putMVar and asynchronous exceptions interact badly, don't they, so how about using modifyMVar_ instead?

  19. Excuse me for this simple question. Does the fact that the Map is referenced by an IORef make it mutable? In other words, does the insert operation create a new map (copying memory), or does it just modify the old one?

  20. @francisco: Haskell's call-by-need (lazy evaluation) memoizes function arguments, but not functions.

    =====verbose explanation======

    The arguments to a function are thunked, meaning the arguments get evaluated only once, and only when they are needed inside the function. This is not the same as checking if a function call was previously called with the same arguments.

    If the argument is a function, the thunk will call it once without checking if that function had been called else where with the same arguments.

    Thunks are conceptually similar to parameterless anonymous functions with a closure on the argument, a boolean, and a variable to store the result of the argument evaluation. Thus thunks incur no lookup costs, because they are parameterless. The cost of the thunk is the check on the boolean.

    Thunks give the same amount of memoization as call-by-value (which doesn't use thunks). Neither call-by-need nor call-by-value memoize function calls. Rather both do not evaluate the same argument more than once. Call-by-need delays that evaluation with a thunk until the argument is first used within the function.

    Apologies for being so verbose, but Google didn't find an explanation that made this clear, so I figured I would cover all the angles in this comment.

  21. Is this implementation asynchronous-exception-safe? I'm worried there could be an asynchronous exception between the takeMVar and the putMVar, and the MVar would be empty forever. Perhaps we need to use protect or modifyMVar, as in http://community.haskell.org/~simonmar/par-tutorial.pdf

  22. Oh, sorry, I see I'm not the first person to say this!
