Continuations From the Ground Up
6 June, 2017
It’s difficult to learn functional programming without hearing about continuations. Often they’re mentioned while talking about boosting the performance of pure functional code, sometimes there’s talk of control flow, and occasionally with ‘time-travel’ thrown in there to make it all seem more obscure. It’s all true, but let’s start from the beginning.
This post was generated from a Literate Haskell file using Pandoc, so you can load it up into GHCI and play around if you want.
Continuation Passing Style
module CPS where
import Control.Applicative
import Control.Monad.Trans.Class
import Data.IORef
import Data.Maybe
import qualified Data.Map as M
import qualified System.Exit as E
The main idea of this style is that the called function has control over how its return value is used. Usually, the caller will pass a function that tells the callee how to use its return value. Here’s what that looks like:
add :: Num a => a -> a -> (a -> r) -> r
= \k -> k (a + b) add a b
add
takes two numbers, plus a function that will take the result and do something (returning an unknown answer), then pass the result to this function. We call this ‘extra function’ a continuation because it specifies how the program should continue.
It’s possible to write any program using this style. I’m not going to prove it. As a challenge, let’s restrict ourselves to write every function this way, with two exceptions:
exitSuccess :: a -> IO ()
= E.exitSuccess
exitSuccess _
exitFailure :: Int -> IO ()
= E.exitWith . E.ExitFailure exitFailure
exitSuccess
and exitFailure
do not take a continuation, because the program always ends when they are called.
Let’s define mul
and dvd
:
mul :: Num a => a -> a -> (a -> r) -> r
= \k -> k (a * b)
mul a b
dvd :: Fractional a => a -> a -> (a -> r) -> r
= \k -> k (a / b) dvd a b
Now we can write some programs using this style.
-- Exits with status code 5
prog_1 :: IO ()
= add 2 3 exitFailure
prog_1
-- Exits successfully after multiplying 10 by 10
prog_2 :: IO ()
= mul 10 10 exitSuccess
prog_2
-- Exits with status code (2 + 3) * 5 = 25
prog_3 :: IO ()
= add 2 3 (\two_plus_three -> mul two_plus_three 5 exitFailure) prog_3
We can factor out the continuation to make our program more modular:
-- Equivalent to \k -> k ((2 + 3) * 5)
prog_4 :: (Int -> r) -> r
= \k -> add 2 3 (\two_plus_three -> mul two_plus_three 5 k)
prog_4
-- Equivalent to \k -> k ((2 + 3) * 5 + 5)
prog_5 :: (Int -> r) -> r
= \k -> prog_4 (\res -> add res 5 k) prog_5
In these kind of definitions, we’ll call the k
the current continuation to stand for how the program will (currently) continue execution.
Here’s a more complex expression:
-- (2 + 3) * (7 + 9) + 5
prog_6 :: Num a => (a -> r) -> r
= \k ->
prog_6 2 3 (\five ->
add 7 9 (\sixteen ->
add ->
mul five sixteen (\eighty 5 k))) add eighty
In translating programs to continuation passing style, we transform a tree of computations into a sequence of computations. In doing so, we have reified the flow of the program. We now have a data structure in memory that represents the computations that make up the program. In this case, the data structure is a lot like a linked list- there is a head: ‘the computation that will be performed next’, and a tail: ‘the computations that will be performed on the result’. It’s this ability to represent the flow of the program as a data structure that sets CPS programs apart from regular programs, which we will see later.
Continuation Passing is Monadic
Right now, writing CPS programs in Haskell is too verbose. Fortunately there are some familiar abstractions that will make it elegant:
newtype Cont r a = Cont { runCont :: (a -> r) -> r }
add' :: Num a => a -> a -> Cont r a
= Cont $ add a b
add' a b
mul' :: Num a => a -> a -> Cont r a
= Cont $ mul a b
mul' a b
dvd' :: Num a => a -> a -> Cont r a
= Cont $ dvd a b
dvd' a b
instance Functor (Cont r) where
fmap f c = Cont $ \k -> runCont c (k . f)
instance Applicative (Cont r) where
pure a = Cont ($ a)
<*> ca = Cont $ \k -> runCont cf (\f -> runCont ca (\a -> k (f a)))
cf
instance Monad (Cont r) where
>>= f = Cont $ \k -> runCont ca (\a -> runCont (f a) k) ca
It turns out that the return type of these CPS programs, (a -> r) -> r
, is a Monad. If you don’t understand these implementations, meditate on them until you do. Here some hand-wave-y English explanations that may help:
Functor
fmap
: Continue with the result of c
by changing the result from an a
to a b
then sending that result to the current continuation.
Applicative
pure
: Send an a
to the current continuation
<*>
: Continue with the result of cf
by continuing with the result of ca
by sending (the result of cf
) applied to (the result of ca
) to the current continuation.
Monad
>>=
: Continue with the result of ca
by applying it to f
and passing the current continuation on to the value f
returned.
So now we can rewrite our previous verbose example:
prog_6' :: Cont r Int
= do
prog_6' <- add' 2 3
five <- add' 7 9
sixteen <- mul' five sixteen
eighty 5 add' eighty
and run it:
prog_7 :: IO ()
= runCont prog_6' exitSuccess prog_7
callCC
Consider the following CPS program:
prog_8 :: (Eq a, Fractional a) => a -> a -> a -> (Maybe a -> r) -> r
= \k ->
prog_8 a b c
add b c->
(\b_plus_c if b_plus_c == 0
then k Nothing
else dvd a b_plus_c (k . Just))
It adds b
to c
, then if b + c
is zero, sends Nothing
to the current continuation, otherwise divides a
by b + c
then continues by wrapping that in a Just
and sending the Just
result to the current continuation.
Because the current continuation is ‘how the program will continue with the result of this function’, sending a result to the current continuation early cause the function to exit early. In this sense, it’s a bit like like a jmp
or a goto
.
It is conceivable that somehow we can write a program like this using the Cont
monad. This is where callCC
comes in.
callCC
stands for ‘call with current continuation’, and is the way we’re going to bring the current continuation into scope when writing CPS programs. Here’s an example of how the previous code snippet should look using callCC
:
prog_8' :: (Eq a, Fractional a) => a -> a -> a -> Cont r (Maybe a)
= callCC $
prog_8' a b c -> do
\k <- add' b c
b_plus_c if b_plus_c == 0
then k Nothing
else fmap Just $ dvd' a b_plus_c
Here’s how callCC
is defined:
callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a
= Cont $ \k -> runCont (f (\a -> Cont $ const (k a))) k callCC f
We can see that the current continuation is permanently captured when it is used in the function passed to f
, but it is also used when running the final result of f
. So k
might be called somewhere inside f
, causing f
to exit early, or it might not, in which case k
is guaranteed to be called after f
has finished.
Earlier I said that invoking the current continuation earlier is like jumping. This is a lot easier to show now that we can use it in our Cont
monad. Calling the continuation provided by callCC
will jump the program execution to immediately after the call to callCC
, and set the result of the callCC
continuation to the argument that was passed to the current continuation.
= do
prog_9 <- add' 2 3
five <- callCC $ \k ->
res -- current continuation is never used, so `callCC` is redundant
4 5
mul' -- `res` = 20
add' five res
= do
prog_10 <- add' 2 3
five <- callCC $ \k -> do
res 5
k 4 5 -- this computation is never run
mul' -- program jumps to here, `res` = 5
add' five res
= do
prog_11 <- add' 2 3
five <- callCC $ \k -> do
res if five > 10
then k 10 -- branch A
else mul' 4 5 -- branch B
-- if branch A was reached, `res` = 10
-- if branch B was reached, `res` = 20
add' five res
Another level of abstraction
We can also embed arbitrary effects in the return type of Cont
. In other words, we can create a monad transformer.
newtype ContT r m a = ContT { runContT :: (a -> m r) -> m r }
callCC' :: ((a -> ContT r m b) -> ContT r m a) -> ContT r m a
= ContT $ \k -> runContT (f (\a -> ContT $ const (k a))) k
callCC' f
instance Functor (ContT r m) where
fmap f c = ContT $ \k -> runContT c (k . f)
instance Applicative (ContT r m) where
pure a = ContT ($ a)
<*> ca = ContT $ \k -> runContT cf (\f -> runContT ca (\a -> k (f a)))
cf
instance Monad (ContT r m) where
>>= f = ContT $ \k -> runContT ca (\a -> runContT (f a) k)
ca
instance MonadTrans (ContT r) where
= ContT $ \k -> ma >>= k lift ma
Notice that the Functor
, Applicative
and Monad
instances for ContT r m
don’t place any constraints on the m
. This means that any type constructor of kind (* -> *)
can be in the m
position. The MonadTrans
instance, however, does require m
is a monad. It’s a very simple definition- the result of running the lifted action is piped into the current continuation using >>=
.
Now that we have a fully-featured CPS monad, we can start doing magic.
The future, the past and alternate timelines
The continuation that callCC
provides access to is the current future of program execution as a single function. That’s why this program-as-a-linear-sequence is so powerful. If you could save the current continuation and call it at a later time somewhere else in your (CPS) program, it would jump ‘back in time’ to the point after that particular callCC
.
To demonstrate this, and end with a bang, here’s a simple boolean SAT solver.
-- Language of boolean expressions
data Expr
= Implies Expr Expr
| Iff Expr Expr
| And Expr Expr
| Or Expr Expr
| Not Expr
| Val Bool
| Var String
deriving (Eq, Show)
-- Reduces a boolean expression to normal form, substituting variables
-- where possible. There are also some equivalences that are necessary to get
-- the SAT solver working e.g. Not (Not x) = x (I said it was a simple one!)
eval :: M.Map String Expr -- ^ Bound variables
-> Expr -- ^ Input expression
-> Expr
=
eval env expr case expr of
Implies p q -> eval env $ Or (Not p) q
Iff p q -> eval env $ Or (And p q) (And (Not p) (Not q))
And a b ->
case (eval env a, eval env b) of
Val False, _) -> Val False
(Val False) -> Val False
(_, Val True, b') -> b'
(Val True) -> a'
(a', -> And a' b'
(a', b') Or a b ->
case (eval env a, eval env b) of
Val True, _) -> Val True
(Val True) -> Val True
(_, Val False, b') -> b'
(Val False) -> a'
(a',
(a', b')| a' == eval env (Not b') -> Val True
| otherwise -> Or a' b'
Not a ->
case eval env a of
Val True -> Val False
Val False -> Val True
Not a' -> a'
-> Not a'
a' Val b -> Val b
Var name -> fromMaybe (Var name) (M.lookup name env)
-- Returns `Nothing` if the expression is not satisfiable
-- If the epxression is satisfiable returns `Just mapping` with a valid
-- variable mapping
sat :: Expr -> ContT r IO (Maybe (M.Map String Expr))
= do
sat expr -- A stack of continuations
<- lift $ newIORef []
try_next_ref
$ \exit -> do
callCC' -- Run `go` after reducing the expression to normal form without any
-- variable values
<- go (eval M.empty expr) try_next_ref exit
res case res of
-- If there was a failure, backtrack and try again
Nothing -> backtrack try_next_ref exit
Just vars -> case eval vars expr of
-- If the expression evaluates to true with the results of `go`, finish
Val True -> exit res
-- Otherwise, backtrack and try again
-> backtrack try_next_ref exit
_
where
-- To backtrack: try to pop a continuation from the stack. If there are
-- none left, exit with failure. If there is a continuation then enter it.
= do
backtrack try_next_ref exit <- lift $ readIORef try_next_ref
try_next case try_next of
-> exit Nothing
[] :rest -> do
next$ writeIORef try_next_ref rest
lift
next
-- It's a tree traversal, but with some twists
=
go expr try_next_ref exit case expr of
-- Twist 1: When we encounter a variable, we first continue as if it's
-- true, but also push a continuation on the stack where it is set to false
Var name -> do
<- callCC' $ \k -> do
res $ modifyIORef try_next_ref (k (Val False) :)
lift pure $ Val True
-- When this program is first run, `res` = True. But if we pop and
-- enter the result of `k (Val False)`, we would end up back here
-- again, with `res` = False
pure $ Just (M.singleton name res)
Val b -> pure $ if b then Just M.empty else Nothing
-- Twist 2: When we get to an Or, only one of the sides needs to be
-- satisfied. So we first continue by checking the left side, but also
-- push a continuation where we check the right side instead.
Or a b -> do
<- callCC' $ \k -> do
side $ modifyIORef try_next_ref (k b :)
lift pure a
-- Similar to the `Var` example. First ruvn, `side` = a. But if later
-- we enter the saved continuation then we will return to this point
-- in the program with `side` = b
go side try_next_ref exit And a b -> do
<- go a try_next_ref exit
a_res <- go b try_next_ref exit
b_res pure $ liftA2 M.union a_res b_res
Not a -> go a try_next_ref exit
-> go (eval M.empty expr) try_next_ref exit _
The solver sets all the variables to True
, and if the full expression evaluates to False
it flips one to False
and automatically re-evaluates the expression, repeating the process untill either it finally evaluates to True
or all possible combinations of boolean values have been tested. It’s not efficient, but it’s a wonderful illustration of the elegance that CPS enables.