首页 文章

如何将可变向量放入状态Monad

提问于
浏览
8

我在haskell中编写了一个小程序,使用State Monad with Vector计算Tree中所有的Int值的出现次数:

import Data.Vector
import Control.Monad.State
import Control.Monad.Identity

data Tree a = Null | Node (Tree a) a (Tree a) deriving Show
main :: IO ()
main = do 
    print $ runTraverse (Node Null 5 Null)


type MyMon a = StateT (Vector Int) Identity a

runTraverse :: Tree Int -> ((),Vector Int)
runTraverse t =  runIdentity (runStateT (traverse t) (Data.Vector.replicate 7 0))

traverse :: Tree Int -> MyMon ()
traverse Null = return ()
traverse (Node l v r) = do
    s <- get
    put (s // [(v, (s ! v) + 1)]) -- s[v] := s[v] + 1
    traverse l
    traverse r
    return ()

但是,不可变向量的“更新”以O(n)复杂度完成 . 我正在寻找O(1)中的更新和O(1)中的访问 . 据我所知,Mutable Vectors做我想要的 . 要使用它们,我需要使用ST或IO . 因为我想做一些UnitTests,我更喜欢ST monad,但我不想在函数调用中传递该向量 . 我需要继续使用Monad变形金刚,因为我将添加像ErrorT和WriterT这样的变换器 .

问题:如何使用Monad变换器将Mutable Vector放入State Monad?

我想出了以下不编译的代码:

import Data.Vector
import Control.Monad.State
import Control.Monad.Identity
import qualified Data.Vector.Mutable as VM
import Control.Monad.ST
import Control.Monad.ST.Trans
type MyMon2 s a = StateT (VM.MVector s Int) (STT s Identity) a

data Tree a = Null | Node (Tree a) a (Tree a) deriving Show
main :: IO ()
main = do 
    print $ runTraverse (Node Null 5 Null)

runTraverse :: Tree Int -> ((),Vector Int)
runTraverse t = runIdentity (Control.Monad.ST.Trans.runST $ do
        emp <- VM.replicate 7 0
        (_,x) <- (runStateT (traverse t) emp)
        v <- Data.Vector.freeze x
        return ((), v)
    )
traverse :: Tree Int -> MyMon2 s ()
traverse Null = return ()
traverse (Node l v r) = do
    d <- get
    a <- (VM.read d v)
    VM.write d v (a + 1)
    put d
    return ()

编译错误是:

TranformersExample: line 16, column 16:
  Couldn't match type `s'
                  with `primitive-0.5.2.1:Control.Monad.Primitive.PrimState
                          (STT s Identity)'
      `s' is a rigid type variable bound by
          a type expected by the context: STT s Identity ((), Vector Int)
          at test/ExecutingTest.hs:15:30
    Expected type: STT s Identity (MVector s Int)
      Actual type: STT
                     s
                     Identity
                     (MVector
                        (primitive-0.5.2.1:Control.Monad.Primitive.PrimState
                           (STT s Identity))
                        Int)
    In the return type of a call of `VM.new'
    In a stmt of a 'do' block: emp <- VM.new 7
    In the second argument of `($)', namely
      `do { emp <- VM.new 7;
            (_, x) <- (runStateT (traverse t) emp);
            v <- freeze x;
            return ((), v) }'
TranformersExample: line 26, column 14:
  Couldn't match type `s'
                  with `primitive-0.5.2.1:Control.Monad.Primitive.PrimState
                          (StateT (MVector s Int) (STT s Identity))'
      `s' is a rigid type variable bound by
          the type signature for traverse :: Tree Int -> MyMon2 s ()
          at test/ExecutingTest.hs:21:13
    Expected type: MVector
                     (primitive-0.5.2.1:Control.Monad.Primitive.PrimState
                        (StateT (MVector s Int) (STT s Identity)))
                     Int
      Actual type: MVector s Int
    In the first argument of `VM.write', namely `d'
    In a stmt of a 'do' block: VM.write d v (a + 1)
    In the expression:
      do { d <- get;
           a <- (VM.read d v);
           VM.write d v (a + 1);
           put d;
           .... }

注意:我知道没有检查边界 .

1 回答

  • 13

    当使用 ST 状态时,你永远不会明确地传递向量(它始终隐藏在 s 参数中),而是对它的引用 . 该引用是不可变的而不是复制的,因此您不需要 State 而只需要一个读者来隐式传递它 .

    import Data.Vector
    import Control.Monad.Reader
    import qualified Data.Vector.Mutable as VM
    import Control.Monad.ST
    
    type MyMon3 s = ReaderT (VM.MVector s Int) (ST s)
    
    data Tree a = Null | Node (Tree a) a (Tree a) deriving Show
    main :: IO ()
    main = do 
        print $ runTraverse (Node Null 5 Null)
    
    runTraverse :: Tree Int -> Vector Int
    runTraverse t = runST $ do
            emp <- VM.replicate 7 0
            runReaderT (traverse t) emp
            Data.Vector.freeze emp
    
    traverse :: Tree Int -> MyMon3 s ()
    traverse Null = return ()
    traverse (Node l v r) = do
        d <- ask
        a <- lift $ VM.read d v
        lift $ VM.write d v (a + 1)
    

相关问题