跳至内容

Haskell/Continuation Passing Style

来自维基教科书,开放的书籍,开放的世界
(重定向自 Haskell/CPS)
Continuation Passing Style (简称 CPS) 是一种编程风格,其中函数不返回值;而是将控制权传递给一个 continuation,它指定了接下来会发生什么。在本章中,我们将考虑这在 Haskell 中是如何实现的,特别是如何用单子来表达 CPS。
高级 Haskell

幺半群
Applicative 函子
可折叠的
可遍历的
箭头教程
理解箭头 75% developed
Continuation Passing Style
Zippers 75% developed
透镜
共单子 0% developed
值递归 (MonadFix)
有副作用的流
可变对象 50% developed
并发 0% developed
模板 Haskell 0% developed
类型族 0% developed

Continuation Passing Style (简称 CPS) 是一种编程风格,其中函数不返回值;而是将控制权传递给一个 continuation,它指定了接下来会发生什么。在本章中,我们将考虑这在 Haskell 中是如何实现的,特别是如何用单子来表达 CPS。

什么是Continuation?

[edit | edit source]

为了消除困惑,我们将再次回顾书中很早之前的一个例子,当我们介绍 ($) 运算符时

> map ($ 2) [(2*), (4*), (8*)]
[4,8,16]

上面的表达式没有什么特别之处,除了用它而不是 map (*2) [2, 4, 8] 来写有点古怪。($) 部分使代码看起来像反过来的,就好像我们把值应用到函数而不是相反。现在,重点来了:这种看起来无害的反转正是 Continuation Passing Style 的核心!

从 CPS 的角度来看,($ 2) 是一个 挂起的计算:一个具有通用类型 (a -> r) -> r 的函数,它接受另一个函数作为参数,并产生一个最终结果。a -> r 参数是 continuation;它指定了计算如何结束。在本例中,列表中的函数通过 map 作为 continuation 提供,产生三个不同的结果。请注意,挂起的计算在很大程度上可以与普通值互换:flip ($) [1] 将任何值转换为挂起的计算,并将 id 作为其 continuation 传递会返回原始值。

它们有什么用处?

[edit | edit source]

Continuation 不仅仅是用来给 Haskell 新手留下深刻印象的技巧。它们使程序能够显式地操作和显著地改变程序的控制流。例如,可以利用 continuation 从过程提前返回。异常和失败也可以用 continuation 处理 - 传递一个 continuation 表示成功,另一个 continuation 表示失败,并调用相应的 continuation。其他可能性包括“挂起”计算并在以后返回,以及实现简单的并发形式(特别是,Haskell 的一个实现 Hugs 使用 continuation 来实现协作式并发)。

在 Haskell 中,continuation 可以以类似的方式使用,用于在单子中实现有趣的控制流。请注意,对于此类用例通常存在替代技术,尤其是在与惰性结合使用的情况下。在某些情况下,CPS 可以通过消除某些构造模式匹配序列来提高性能(例如,函数返回一个复杂结构,调用者会在某个时刻将其分解),尽管足够聪明的编译器应该能够消除此类序列[2]

传递Continuation

[edit | edit source]

利用 continuation 的一种基本方法是对我们的函数进行修改,使它们返回挂起的计算而不是普通值。我们将通过两个简单的例子来说明如何做到这一点。

pythagoras

[edit | edit source]

示例:一个简单的模块,没有 continuation

-- We assume some primitives add and square for the example:

add :: Int -> Int -> Int
add x y = x + y

square :: Int -> Int
square x = x * x

pythagoras :: Int -> Int -> Int
pythagoras x y = add (square x) (square y)

修改为返回挂起的计算,pythagoras 看起来像这样

示例:一个简单的模块,使用 continuation

-- We assume CPS versions of the add and square primitives,
-- (note: the actual definitions of add_cps and square_cps are not
-- in CPS form, they just have the correct type)

add_cps :: Int -> Int -> ((Int -> r) -> r)
add_cps x y = \k -> k (add x y)

square_cps :: Int -> ((Int -> r) -> r)
square_cps x = \k -> k (square x)

pythagoras_cps :: Int -> Int -> ((Int -> r) -> r)
pythagoras_cps x y = \k ->
 square_cps x $ \x_squared ->
 square_cps y $ \y_squared ->
 add_cps x_squared y_squared $ k

pythagoras_cps 示例是如何工作的

  1. 对 x 进行平方并将结果传递到 (\x_squared -> ...) continuation 中
  2. 对 y 进行平方并将结果传递到 (\y_squared -> ...) continuation 中
  3. 将 x_squared 和 y_squared 相加并将结果传递到顶层/程序 continuation k 中。

我们可以通过将 print 作为程序 continuation 传递来在 GHCi 中尝试它

*Main> pythagoras_cps 3 4 print
25

如果我们查看 pythagoras_cps 的类型(没有在 (Int -> r) -> r 周围加上可选的括号),并将其与 pythagoras 的原始类型进行比较,我们会注意到,continuation 实际上被添加为一个额外的参数,因此证明了“continuation passing style”这个名字的合理性。

示例: 一个简单的更高阶函数,没有延续

thrice :: (a -> a) -> a -> a
thrice f x = f (f (f x))
*Main> thrice tail "foobar"
"bar"

thrice 这样的更高阶函数,当转换为 CPS 时,也会以 CPS 形式将函数作为参数。因此,f :: a -> a 将变成 f_cps :: a -> ((a -> r) -> r),最终类型将是 thrice_cps :: (a -> ((a -> r) -> r)) -> a -> ((a -> r) -> r)。定义的其余部分很自然地遵循类型——我们用 CPS 版本替换 f,并传递现有的延续。

示例: 一个简单的更高阶函数,带延续

thrice_cps :: (a -> ((a -> r) -> r)) -> a -> ((a -> r) -> r)
thrice_cps f_cps x = \k ->
 f_cps x $ \fx ->
 f_cps fx $ \ffx ->
 f_cps ffx $ k


Cont 单子

[编辑 | 编辑源代码]

有了延续传递函数,下一步是提供一种简洁的方式来组合它们,最好是能避免我们上面看到的长串嵌套 lambda。一个好的开始是为将 CPS 函数应用于挂起计算提供一个组合子。它可能的类型是

chainCPS :: ((a -> r) -> r) -> (a -> ((b -> r) -> r)) -> ((b -> r) -> r)

(你可能想在继续阅读之前尝试实现它。提示:首先说明结果是一个函数,它接受一个 b -> r 延续;然后,让类型来指导你。)

这是实现

chainCPS s f = \k -> s $ \x -> f x $ k

我们用一个新的挂起计算 (由 f 生成) 来提供原始的挂起计算 s,并将最终的延续 k 传递给它。不出所料,它与前面示例中嵌套的 lambda 模式非常相似。

chainCPS 的类型看起来很熟悉吗?如果我们将 (a -> r) -> r 替换为 (Monad m) => m a,并将 (b -> r) -> r 替换为 (Monad m) => m b,我们将得到 (>>=) 签名。此外,我们的老朋友 flip ($) 扮演着类似 return 的角色,因为它以一种微不足道的方式从一个值中生成一个挂起计算。瞧!我们得到了一个单子!现在我们只需要 [3] 一个 Cont r a 类型来包装挂起计算,以及通常的包装器和解包器函数。

cont :: ((a -> r) -> r) -> Cont r a
runCont :: Cont r a -> (a -> r) -> r

Cont 的单子实例直接来自我们的演示,唯一的区别是包装和解包的繁琐。

instance Monad (Cont r) where
    return x = cont ($ x)
    s >>= f  = cont $ \c -> runCont s $ \x -> runCont (f x) c

最终的结果是,单子实例使得延续传递 (以及因此的 lambda 链) 变得隐式。单子绑定将 CPS 函数应用于挂起计算,而 runCont 用于提供最终的延续。例如,毕达哥拉斯示例变为

示例: 使用 Cont 单子的 pythagoras 示例

-- Using the Cont monad from the transformers package.
import Control.Monad.Trans.Cont

add_cont :: Int -> Int -> Cont r Int
add_cont x y = return (add x y)

square_cont :: Int -> Cont r Int
square_cont x = return (square x)

pythagoras_cont :: Int -> Int -> Cont r Int
pythagoras_cont x y = do
    x_squared <- square_cont x
    y_squared <- square_cont y
    add_cont x_squared y_squared

虽然看到一个单子自然地出现总是令人高兴,但此时可能还会有一丝失望。CPS 的承诺之一是通过延续来精确地控制流程操作。然而,在将我们的函数转换为 CPS 后,我们立即将延续隐藏在一个单子后面。为了纠正这一点,我们将介绍 callCC,这个函数让我们能够明确控制延续——但只有在我们想要的地方。

callCC 是一个非常奇特的函数;最好用例子来介绍它。让我们从一个简单的例子开始

示例: 使用 callCCsquare

-- Without callCC
square :: Int -> Cont r Int
square n = return (n ^ 2)

-- With callCC
squareCCC :: Int -> Cont r Int
squareCCC n = callCC $ \k -> k (n ^ 2)

传递给 callCC 的参数是一个函数,其结果是一个挂起计算 (通用类型 Cont r a),我们将其称为“callCC 计算”。原则上callCC 计算是整个 callCC 表达式的求值结果。需要注意的是,使 callCC 如此特殊的是 k,即参数的参数。它是一个函数,充当弹出按钮:在任何地方调用它都会导致传递给它的值变成一个挂起计算,然后该计算会在 callCC 调用的地方插入到控制流中。这是无条件发生的;特别是,callCC 计算中 k 调用后的任何内容都会被直接丢弃。从另一个角度来看,k 捕获了紧随callCC剩余计算;调用它会将一个值抛入那个特定点 (“callCC” 代表 “call with current continuation”) 的延续中。虽然在这个简单的示例中,效果仅仅是普通 return 的效果,但 callCC 打开了许多可能性,我们现在将开始探索这些可能性。

决定何时使用 k

[编辑 | 编辑源代码]

callCC 为我们提供了对抛入延续的内容以及何时抛入的额外控制权。下面的示例开始展示如何使用这种额外控制权。

示例: 我们的第一个真正的 callCC 函数

foo :: Int -> Cont r String
foo x = callCC $ \k -> do
    let y = x ^ 2 + 3
    when (y > 20) $ k "over twenty"
    return (show $ y - 4)

foo 是一个有点病态的函数,它计算输入的平方并加上 3;如果此计算的结果大于 20,那么我们立即从 callCC 计算 (在本例中,从整个函数) 中返回,并将字符串 "over twenty" 抛入传递给 foo 的延续中。否则,我们将从之前的计算中减去 4,将其 show,并将其抛入延续中。值得注意的是,这里的 k 与命令式语言中的 return 语句的使用方式相同,即立即退出函数。然而,由于这是 Haskell,k 只是一个普通的头等函数,因此你可以将其传递给其他函数(如 when),将其存储在 Reader 中等等。

当然,你可以在 do 块中嵌入对 callCC 的调用

示例: 涉及 do 块的更完善的 callCC 示例

bar :: Char -> String -> Cont r Int
bar c s = do
    msg <- callCC $ \k -> do
        let s0 = c : s
        when (s0 == "hello") $ k "They say hello."
        let s1 = show s0
        return ("They appear to be saying " ++ s1)
    return (length msg)

当你用一个值调用 k 时,整个 callCC 调用会获取该值。实际上,这使得 k 非常类似于其他语言中的 “goto” 语句:当我们在示例中调用 k 时,它会将执行弹出到第一次调用 callCC 的地方,即 msg <- callCC $ ... 行。不会再执行 callCC 的参数 (内部 do 块)。因此,下面的示例包含一行无用的代码

示例: 弹出函数,引入一行无用的代码

quux :: Cont r Int
quux = callCC $ \k -> do
    let n = 5
    k n
    return 25

quux 将返回 5,而不是 25,因为我们在到达 return 25 行之前就从 quux 弹出了。

我们故意在这里打破了一种趋势:通常,当我们引入一个函数时,我们立即给出它的类型,但在这种情况下,我们选择不这样做。原因很简单:该类型非常复杂,它不会立即让人洞悉该函数的功能或工作原理。然而,在最初介绍 callCC 之后,我们更有能力处理它。深呼吸……

callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a

我们可以根据我们已经知道的关于 callCC 的内容来理解这一点。总体结果类型和参数的结果类型必须相同 (即 Cont r a),因为在没有调用 k 的情况下,对应的结果值是同一个值。那么,k 的类型呢?如上所述,k 的参数被转化为一个挂起计算,该计算被插入到 callCC 调用的地方;因此,如果后者具有类型 Cont r ak 的参数必须具有类型 a。至于 k 的结果类型,有趣的是,只要它被包装在相同的 Cont r 单子中,它实际上并不重要;换句话说,b 代表一个任意类型。这是因为,由 a 参数生成的挂起计算将接收紧随 callCC 的任何延续,因此 k 的结果所采用的延续是无关紧要的。

注意

k 的结果类型的任意性解释了为什么下面的无用代码示例变体会导致类型错误

quux :: Cont r Int
quux = callCC $ \k -> do
   let n = 5
   when True $ k n
   k 25

k 的结果类型可以是任何形式为 Cont r b 的类型;然而,when 将其限制为 Cont r (),因此结尾的 k 25quux 的结果类型不匹配。解决方法非常简单:将最后的 k 替换为一个普通的 return


为了结束本节,以下是 callCC 的实现。你能在其中识别出 k 吗?

callCC f = cont $ \h -> runCont (f (\a -> cont $ \_ -> h a)) h

虽然代码远非显而易见,但一个令人惊奇的事实是,callCCreturn(>>=)Cont 实现可以从它们的类型签名自动生成——Lennart Augustsson 的 Djinn [1] 是一个可以为你执行此操作的程序。有关 Djinn 背后理论的背景信息,请参见 Phil Gossett 的 Google 技术讲座:[2];有关使用 Djinn 推导延续传递风格的信息,请参见 Dan Piponi 的文章:[3]

示例:一个复杂的控制结构

[编辑 | 编辑源代码]

现在我们将研究一些更真实的控制流操作示例。下面的第一个示例最初取自 关于单子的所有教程 的“Continuation 单子”部分,经许可使用。

示例: 使用 Cont 来构建一个复杂的控制结构

{- We use the continuation monad to perform "escapes" from code blocks.
This function implements a complicated control structure to process
numbers:

Input (n)     Output                    List Shown
=========     ======                    ==========
0-9           n                         none
10-199        number of digits in (n/2) digits of (n/2)
200-19999     n                         digits of (n/2)
20000-1999999 (n/2) backwards           none
>= 2000000    sum of digits of (n/2)    digits of (n/2)
-} 
fun :: Int -> String
fun n = (`runCont` id) $ do
    str <- callCC $ \exit1 -> do                            -- define "exit1"
        when (n < 10) (exit1 (show n))
        let ns = map digitToInt (show (n `div` 2))
        n' <- callCC $ \exit2 -> do                         -- define "exit2"
            when ((length ns) < 3) (exit2 (length ns))
            when ((length ns) < 5) (exit2 n)
            when ((length ns) < 7) $ do
                let ns' = map intToDigit (reverse ns)
                exit1 (dropWhile (=='0') ns')               --escape 2 levels
            return $ sum ns
        return $ "(ns = " ++ (show ns) ++ ") " ++ (show n')
    return $ "Answer: " ++ str

fun 是一个接受整数 n 的函数。实现使用 ContcallCC 来设置一个使用 ContcallCC 的控制结构,根据 n 所处的范围执行不同的操作,如顶部的注释所述。让我们来剖析它。

  1. 首先,顶部的 (`runCont` id) 只是意味着我们运行后面的 Cont 块,并使用 id 作为最终的延续(或者,换句话说,我们从挂起的计算中提取值而不改变它)。这是必要的,因为 fun 的结果类型没有提到 Cont
  2. 我们将 str 绑定到以下 callCC do 块的结果。
    1. 如果 n 小于 10,我们直接退出,只显示 n
    2. 否则,我们继续执行。我们构造一个 ns 列表,它包含 n `div` 2 的数字。
    3. n'(一个 Int)被绑定到以下内部 callCC do 块的结果。
      1. 如果 length ns < 3,即如果 n `div` 2 的位数少于 3 位,我们将从这个内部 do 块退出,返回位数作为结果。
      2. 如果 n `div` 2 的位数少于 5 位,我们将从内部 do 块退出,返回原始的 n
      3. 如果 n `div` 2 的位数少于 7 位,我们将从内部外部两个 do 块退出,结果为 n `div` 2 的数字以逆序排列(一个 String)。
      4. 否则,我们结束内部 do 块,返回 n `div` 2 的数字之和。
    4. 我们结束这个 do 块,返回字符串 "(ns = X) Y",其中 X 是 ns,即 n `div` 2 的数字,而 Y 是内部 do 块的结果 n'
  3. 最后,我们从整个函数退出,结果是字符串 "Answer: Z",其中 Z 是我们从 callCC do 块得到的字符串。

示例:异常

[编辑 | 编辑源代码]

延续的一个用途是模拟异常。为此,我们保留两个延续:一个用于在发生异常时将我们带到处理程序,另一个用于在成功时将我们带到处理程序后的代码。这是一个简单的函数,它接受两个数字并在它们之间进行整数除法,当除数为零时会失败。

示例: 抛出异常的 div

divExcpt :: Int -> Int -> (String -> Cont r Int) -> Cont r Int
divExcpt x y handler = callCC $ \ok -> do
    err <- callCC $ \notOk -> do
        when (y == 0) $ notOk "Denominator 0"
        ok $ x `div` y
    handler err

{- For example,
runCont (divExcpt 10 2 error) id --> 5
runCont (divExcpt 10 0 error) id --> *** Exception: Denominator 0
-}

它是如何工作的?我们使用两次嵌套的 callCC 调用。第一个标记一个延续,当没有问题时将使用它。第二个标记一个延续,当我们想要抛出异常时将使用它。如果除数不为 0,x `div` y 将被抛入 ok 延续,因此执行会直接跳回到 divExcpt 的顶层。然而,如果我们传递了一个零除数,我们将错误消息抛入 notOk 延续,这将把我们弹出到内部 do 块中,该字符串将被分配给 err 并传递给 handler

一个更通用的处理异常的方法可以通过以下函数看到。将计算作为第一个参数传递(更准确地说,是一个函数,它接受一个抛出异常的函数并导致计算),并将错误处理程序作为第二个参数传递。这个例子利用了泛型 MonadCont[4],它默认情况下涵盖 Cont 和相应的 ContT 变换器,以及任何其他实例化它的延续单子。

示例: 使用延续的通用 try

import Control.Monad.Cont

tryCont :: MonadCont m => ((err -> m a) -> m a) -> (err -> m a) -> m a
tryCont c h = callCC $ \ok -> do
    err <- callCC $ \notOk -> do
        x <- c notOk
        ok x
    h err

这是我们的 try 的实际应用

示例: 使用 try

data SqrtException = LessThanZero deriving (Show, Eq)

sqrtIO :: (SqrtException -> ContT r IO ()) -> ContT r IO ()
sqrtIO throw = do 
    ln <- lift (putStr "Enter a number to sqrt: " >> readLn)
    when (ln < 0) (throw LessThanZero)
    lift $ print (sqrt ln)

main = runContT (tryCont sqrtIO (lift . print)) return

在这个例子中,抛出错误意味着从一个封闭的 callCC 中跳出。sqrtIO 中的 throw 会跳出 tryCont 的内部 callCC

示例:协程

[编辑 | 编辑源代码]

在本节中,我们创建了一个 CoroutineT 单子,它提供了一个具有 fork(它将一个新的挂起协程入队)和 yield(它挂起当前线程)的单子。

{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-- We use GeneralizedNewtypeDeriving to avoid boilerplate. As of GHC 7.8, it is safe.

import Control.Applicative
import Control.Monad.Cont
import Control.Monad.State

-- The CoroutineT monad is just ContT stacked with a StateT containing the suspended coroutines.
newtype CoroutineT r m a = CoroutineT {runCoroutineT' :: ContT r (StateT [CoroutineT r m ()] m) a}
    deriving (Functor,Applicative,Monad,MonadCont,MonadIO)

-- Used to manipulate the coroutine queue.
getCCs :: Monad m => CoroutineT r m [CoroutineT r m ()]
getCCs = CoroutineT $ lift get

putCCs :: Monad m => [CoroutineT r m ()] -> CoroutineT r m ()
putCCs = CoroutineT . lift . put

-- Pop and push coroutines to the queue.
dequeue :: Monad m => CoroutineT r m ()
dequeue = do
    current_ccs <- getCCs
    case current_ccs of
        [] -> return ()
        (p:ps) -> do
            putCCs ps
            p

queue :: Monad m => CoroutineT r m () -> CoroutineT r m ()
queue p = do
    ccs <- getCCs
    putCCs (ccs++[p])

-- The interface.
yield :: Monad m => CoroutineT r m ()
yield = callCC $ \k -> do
    queue (k ())
    dequeue

fork :: Monad m => CoroutineT r m () -> CoroutineT r m ()
fork p = callCC $ \k -> do
    queue (k ())
    p
    dequeue

-- Exhaust passes control to suspended coroutines repeatedly until there isn't any left.
exhaust :: Monad m => CoroutineT r m ()
exhaust = do
    exhausted <- null <$> getCCs
    if not exhausted
        then yield >> exhaust
        else return ()

-- Runs the coroutines in the base monad.
runCoroutineT :: Monad m => CoroutineT r m r -> m r
runCoroutineT = flip evalStateT [] . flip runContT return . runCoroutineT' . (<* exhaust)

一些示例用法

printOne n = do
    liftIO (print n)
    yield

example = runCoroutineT $ do
    fork $ replicateM_ 3 (printOne 3)
    fork $ replicateM_ 4 (printOne 4)
    replicateM_ 2 (printOne 2)

输出

3
4
3
2
4
3
2
4
4

示例:实现模式匹配

[编辑 | 编辑源代码]

CPS 函数的一个有趣的用法是实现我们自己的模式匹配。我们将通过一些示例来说明如何做到这一点。

示例: 内置模式匹配

check :: Bool -> String
check b = case b of
    True  -> "It's True"
    False -> "It's False"

现在我们已经学习了 CPS,我们可以像这样重构代码。

示例: CPS 中的模式匹配

type BoolCPS r = r -> r -> r

true :: BoolCPS r
true x _ = x

false :: BoolCPS r
false _ x = x

check :: BoolCPS String -> String
check b = b "It's True" "It's False"
*Main> check true
"It's True"
*Main> check false
"It's False"

这里发生的事情是,我们不是使用普通的值,而是使用函数来表示 TrueFalse,这些函数会选择它们被传递的第一个或第二个参数。由于 truefalse 的行为不同,我们可以实现与模式匹配相同的效果。此外,TrueFalsetruefalse 可以通过 \b -> b True False\b -> if b then true else false 来相互转换。

我们应该看看这个更复杂的例子是如何与 CPS 相关的。

示例: 更复杂的模式匹配及其 CPS 等价物

data Foobar = Zero | One Int | Two Int Int

type FoobarCPS r = r -> (Int -> r) -> (Int -> Int -> r) -> r

zero :: FoobarCPS r
zero x _ _ = x

one :: Int -> FoobarCPS r
one x _ f _ = f x

two :: Int -> Int -> FoobarCPS r
two x y _ _ f = f x y

fun :: Foobar -> Int
fun x = case x of
    Zero -> 0
    One a -> a + 1
    Two a b -> a + b + 2

funCPS :: FoobarCPS Int -> Int
funCPS x = x 0 (+1) (\a b -> a + b + 2)
*Main> fun Zero
0
*Main> fun $ One 3
4
*Main> fun $ Two 3 4
9
*Main> funCPS zero
0
*Main> funCPS $ one 3
4
*Main> funCPS $ two 3 4
9

与前面的例子类似,我们使用函数来表示值。这些函数值选择它们被传递的相应(即匹配)延续,并将存储在它们中的值传递给后者。有趣的是,这个过程不涉及任何比较。如我们所知,模式匹配可以对不是 Eq 实例的类型进行操作:函数值“知道”它们的模式是什么,并将自动选择正确的延续。如果这是从外部完成的,例如,通过一个 pattern_match :: [(pattern, result)] -> value -> result 函数,它将不得不检查和比较模式和值以查看它们是否匹配——因此需要 Eq 实例。

注释

  1. 也就是说,\x -> ($ x),完全写出来是 \x -> \k -> k x
  2. attoparsec 是 CPS 在性能驱动方面的应用的一个例子。
  3. 除了验证单子定律是否成立之外,这留给读者作为练习。
  4. 位于 mtl 包中,模块 Control.Monad.Cont
华夏公益教科书