Scalaでモナド

前回(第11回)の名古屋scala勉強会の前半はfor式を肴にモナド三昧。『ふつうのHaskell』を発売直後*1に買ったものの途中で投げ出した自分はやっぱりチンプンカンプン。でもまぁチンプンカンプンなりにいろいろ教えてもらったので、それを頼りになんとかモナドのさわりだけでも学んでみようかと。


教科書はこちら。
モナドのすべて

残念ながらというか、当たり前というかコード例がHaskell。自分のようにHaskellの基本的な文法さえ記憶の彼方という人や、Haskell知らないという人にはとりあえずここ。
Haskell基礎文法最速マスター

これで『モナドのすべて』にあるHaskellコードはなんとなく理解できるようになるはず。

モナドの定義

定義だけならいたって簡単。

「returnとbind(>>=)が定義されていて、モナド則を満たす」

これだけ。ScalaだとOptionやListがモナドの例として挙げられているのでこれらが上の定義を満たしているかみてみることにする。

returnとbind

Haskellでのreturnとbindは、Scalaではそれぞれ単一引数コンストラクタとflatMapが対応するとのこと。

returnとbindのシグネチャはこう。

-- モナド m の型
data m a = ... 

-- return はモナドのインスタンスを作る型構築子 
return :: a -> m a

-- bind はモナドのインスタンス m a と、 
-- a から別のモナドのインスタンス m b を作る計算と
-- を組み合わせて新しいモナドのインスタンス m b を作る
(>>=) :: m a -> (a -> m b) -> m b

returnはそれぞれOptionならSome関数が、ListならList関数が相当するのはなんとなくわかる。
bindの定義がOptionやListのflatMapと対応するか確認してみる。

// Option
sealed abstract class Option[+A] extends Product {
  def flatMap[B](f: A => Option[B]): Option[B] = ...
}

// List(TraversableLike)
trait TraversableLike[+A, +Repr] extends HasNewBuilder[A, Repr] with TraversableOnce[A] { 
  def flatMap[B, That](f: A => Traversable[B])(implicit bf: CanBuildFrom[Repr, B, That]): That = ...
}

bindの第一引数(m a)をレシーバと考えればOptionのflatMapはbindそのもの。Listはなんかいろいろ付いてるけどコードを追ってみると、こちらも同じようなものと見てよさそうだ。

モナド

次にモナド則。3つあってHaskellではこう

(return x) >>= f == f x
m >>= return == m
(m >>= f) >>= g == m >>= (\x -> f x >>= g)

同様にOptionとListがモナド則を満たしているか見てみる。証明はできないので具体的な値で確認する。

1番目の式

Optionから。上に書いたとおりSomeがreturnに相当。

scala> def f(x:Int) = Some(x*2)
f: (x: Int)Some[Int]

scala> Some(1).flatMap(f)      
res6: Option[Int] = Some(2)

scala> f(1)              
res7: Some[Int] = Some(2)

なんでflatMapのほうはOption[Int]型でf(1)はSome[Int]型なのかわからないけど、とりあえず同じ値Some(2)が返る。

Listでも確かめてみる。同じくListがreturnに相当。

scala> def f(x:Int) = List(x*2)
f: (x: Int)List[Int]

scala> List(1).flatMap(f)        
res10: List[Int] = List(2)

scala> f(1)              
res11: List[Int] = List(2)
2番目の式
scala> Some(1).flatMap(Some(_)) == Some(1)
res13: Boolean = true

scala> List(1).flatMap(List(_)) == List(1)
res14: Boolean = true

蛇足だけどflatMapじゃなくてmapだとこうなってしまう。

scala> List(1).map(List(_))
res2: List[List[Int]] = List(List(1))
3番目の式
scala> def f(x:Int) = List(x+1)
f: (x: Int)List[Int]

scala> def g(x:Int) = List(x*3)
g: (x: Int)List[Int]

scala> List(1).flatMap(f).flatMap(g)
res15: List[Int] = List(6)

scala> List(1).flatMap(f(_).flatMap(g))     
res16: List[Int] = List(6)

ちょっと分かりにくけど、flatMapが連続で呼ばれている(3つめの式)のと、flatMapが入れ子になっている(4つめの式)のが同じ値を返している。

MonadPlus

モナドがさらに以下を満たすmzeroとmplusを持つとMonadPlusと呼ばれるらしい。
(ただ、『モナドのすべて』ではMonadPlusについて、ここ以外ではあまり出てこないためMonadPlusだとどう嬉しいのかといったことはモナド以上によくわからなかった)

mzero >>= f == mzero
m >>= (\x -> mzero) == mzero
mzero `mplus` m == m
m `mplus` mzero == m

Optionの場合はmzeroがNone、mplusがorElseに相当するとみなせばMaybeと同じようにMonadPlusといえる(はず)。
これも具体的に確かめてみる。

scala> val a = Some(1)
a: Some[Int] = Some(1)

scala> None.flatMap(x => Some(1))
res20: Option[Int] = None

scala> a.flatMap(x => None)
res21: Option[Nothing] = None

scala> None orElse a
res22: Option[Int] = Some(1)

scala> a orElse None
res23: Option[Int] = Some(1)

Listの場合は空リストがmzero、++がmplusになる。

ScalaでStateモナド

モナドの定義はわかったので、今度はScalaで実際に実装してみる。実装したのはStateモナド

まずHaskellでの定義から

newtype State s a = State { runState :: (s -> (a,s)) } 
 
instance Monad (State s) where 
    return a        = State $ \s -> (a,s)
    (State x) >>= f = State $ \s -> let (v,s') = x s in runState (f v) s' 

class MonadState m s | m -> s where 
    get :: m s
    put :: s -> m ()

instance MonadState (State s) s where 
    get   = State $ \s -> (s,s) 
    put s = State $ \_ -> ((),s) 

Scalaで定義するとこんな感じになるはず。

object State {
  def get[S]:State[S,S] = State[S,S]( (s:S) => (s,s))
  def put[S](s:S):State[S,Unit] = State[S,Unit]( _ => ((),s))

  def mkState[S,A](a:A) = State((s:S) => (a, s)) // return

}

case class State[S, A](runState: S => (A, S)) {
  import State._
  def flatMap[B](f: A => State[S, B]): State[S,B] = State { (s: S) =>
    runState(s) match { case (v, sd) => f(v).runState(sd) }
  }

  def map[B](f: A=>B):State[S,B] = flatMap((a:A) => mkState[S,B](f(a)))
}

putの定義にでてくる()はUnit値。Scalaにもちゃんとある。コップ本で見かけたときいったいいつ使うのかと思っていたけど、こんなところで再会。

for文で利用するためにはflatMapだけでなくmapも必要。標準的なmapはflatMapで実装できる。

def map[A,B](f: A=>B) = flatMap(x => new C(f(x))) //Cは自分自身のコンストラクタ

これは名古屋Scala勉強会で教えてもらった。


実際に使ってみる。『モナドのすべて』ではRandomを使っていたけどScalaのRandomはimmutableじゃないので面白くない。
簡単ではあるけどList[Int]を状態とみなして使ってみる。

object ListState{
  import State._
  def getN(n:Int) = for ( st1 <- get[List[Int]];
		    _ <- put(st1 ++ List(n*2))) yield n*2

  def mkListValueST = (for( n1 <- getN(1);
		   n2 <- getN(n1)) yield (n1, n2)).runState(_)

  def main(args:Array[String]){
    println(mkListValueST(List(0)))
  }
}

実行してみる。

$ scala ListState
((2,4),List(0, 2, 4))

immutableであるList[Int]をあたかもmutableであるかのように引き回せるのは面白い。

最後に

OptionやListがモナドであることを見て、Stateモナドを実装してみた。
ここまで学んでみての感想は、名古屋Scala勉強会でも誰かが言ってた気がするけど
知っているに越したことはないがScalaに必須とは思えないという感じかな。

ただし、『モナドのすべて』の第Ⅲ部はまだ読んでないので、読むとまた違った感想になるかも。

*1:奥付見たらもう4年も前のことらしく衝撃を受けた