
Saturday 13 December 2014

Demystify the State Monad with Scala 2/2

In the previous post we explored how to define a Stack data structure, along with methods to push onto it and pop from it. All the methods accept the current stack and return a tuple containing the result of the operation and the new stack. We captured that signature with the following type alias:

type State[S, A] = S => (A, S)

So here is where we are:

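package fp

object StackApp extends App {

  type Stack[A] = List[A]
  // A state transition: takes the current state S, returns a result A and the new state
  type State[S, A] = S => (A, S)
  
  // push has no meaningful result (Unit), only a new stack with `a` on top
  def push[A](a:A) : State[Stack[A], Unit] = stack => ((), a :: stack)
  
  // pop returns the top element (if any) together with the remaining stack
  def pop[A]: State[Stack[A], Option[A]] = {
    case a :: tail => (Some(a), tail)
    case Nil       => (None, Nil)
  }
  
  // each pop must receive the stack returned by the previous one
  def popPairs[A]: State[Stack[A],(Option[A], Option[A])] = stack => {
    val (opt1, stack1) = pop[A](stack)
    val (opt2, stack2) = pop[A](stack1)
    ((opt1, opt2), stack2)
  }
}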
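
To make the plumbing problem concrete, here is a small sketch that runs these functions by hand. The object name ManualThreading is hypothetical, and the aliases, push and pop are repeated from the listing above so that the snippet compiles on its own. Every call has to be given exactly the stack produced by the previous call.

```scala
// Hypothetical demo object; aliases, push and pop are copied from StackApp above.
object ManualThreading {
  type Stack[A] = List[A]
  type State[S, A] = S => (A, S)

  def push[A](a: A): State[Stack[A], Unit] = stack => ((), a :: stack)

  def pop[A]: State[Stack[A], Option[A]] = {
    case a :: tail => (Some(a), tail)
    case Nil       => (None, Nil)
  }

  // Each step must receive the stack produced by the previous step;
  // passing the original empty list instead of s1 would silently lose the first push.
  val (_, s1)   = push(1)(List.empty[Int])
  val (_, s2)   = push(2)(s1)
  val (top, s3) = pop[Int](s2)

  def main(args: Array[String]): Unit =
    println((top, s3)) // (Some(2),List(1))
}
```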

Now that we have defined State, our goal is to get rid of the explicit plumbing inside popPairs, where we have to carefully pass the right instance of the stack to each expression.

To do that, we need methods that compose State values together, so we are going to introduce the two key methods: map and flatMap.


def map[S, A, B](sa:State[S, A])(f:A=>B): State[S, B] = ???
def flatMap[S, A, B](sa:State[S, A])(f:A=>State[S, B]): State[S, B] = ???

OK, let's see how we can implement these methods.
Let's start with map. The first thing to consider is the return type: it has to be a State[S, B], so in essence it has to be a function that, given an initial state, returns a tuple of the form (b, newState):


def map[S, A, B](sa:State[S, A])(f:A=>B): State[S, B] = state => {
  ...
  (b, newState)
}

Let's look at the full implementation and walk through it:


def map[S, A, B](sa:State[S, A])(f:A=>B): State[S, B] = state => {
    val (a, newState) = sa(state)
    (f(a), newState)
  }

We simply run sa on the input state; the result of the first statement is a tuple holding a value a: A and a newState: S.
We then apply f: A => B to a and return the new tuple (f(a), newState).
The result is a function that takes a state as input and returns a tuple (b: B, newState: S), which matches the State[S, B] signature.
Note: it is interesting how abstract this function is. map is a higher-order function (it takes f as an argument), and since S, A and B are fully generic we cannot reason about any concrete implementation; writing it becomes a game of matching the types. Surprisingly, there are not many possible implementations of map: the type signature itself restricts them.
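
As a quick check of that behaviour, the sketch below uses the standalone map to turn pop into a hypothetical popOrZero that defaults to 0 on an empty stack. Only the result is transformed; the stack returned by pop is passed through untouched. MapExample and popOrZero are illustrative names, not part of the post.

```scala
// Hypothetical demo object reusing the type-alias versions of State, map and pop.
object MapExample {
  type Stack[A] = List[A]
  type State[S, A] = S => (A, S)

  // Same map as above: run sa, then transform only the result, keeping the new state
  def map[S, A, B](sa: State[S, A])(f: A => B): State[S, B] = state => {
    val (a, newState) = sa(state)
    (f(a), newState)
  }

  def pop[A]: State[Stack[A], Option[A]] = {
    case a :: tail => (Some(a), tail)
    case Nil       => (None, Nil)
  }

  // Hypothetical helper: pop the top element, defaulting to 0 on an empty stack
  val popOrZero: State[Stack[Int], Int] = map(pop[Int])(_.getOrElse(0))

  def main(args: Array[String]): Unit = {
    println(popOrZero(List(5, 6))) // (5,List(6))
    println(popOrZero(Nil))        // (0,List())
  }
}
```
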
OK, let's now move on to flatMap. This method is essential for composability.
The only difference in the signature compared to map is that f returns the final type, State[S, B], directly instead of a plain B.
Why is this so important?
Well, let's assume we only have the map method and we call map inside another map:


map(pop[A]) { opt1 =>
    map(pop[A]) { opt2 =>
      (opt1, opt2)
    }
  }

By nesting one map inside the other, the final return type becomes:

State[S, State[S, B]]

which is not what we want. The idea is to flatten the inner State[S, B] so that the final result is just State[S, B].
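
The sketch below makes the problem visible using the standalone map and pop from earlier. NestedMapExample, nested, inner, pair and the stack names are hypothetical. The nested call typechecks, but to get the pair out we have to run the outer State and then, by hand, run the inner one on the intermediate stack.

```scala
// Hypothetical demo: nesting map produces a State of a State.
object NestedMapExample {
  type Stack[A] = List[A]
  type State[S, A] = S => (A, S)

  def map[S, A, B](sa: State[S, A])(f: A => B): State[S, B] = state => {
    val (a, newState) = sa(state)
    (f(a), newState)
  }

  def pop[A]: State[Stack[A], Option[A]] = {
    case a :: tail => (Some(a), tail)
    case Nil       => (None, Nil)
  }

  // The result type is State[S, State[S, B]], not a flat State[S, B]
  val nested: State[Stack[Int], State[Stack[Int], (Option[Int], Option[Int])]] =
    map(pop[Int])(opt1 => map(pop[Int])(opt2 => (opt1, opt2)))

  // Running it requires manual plumbing again: outer State first, then the inner one
  val (inner, stack1) = nested(List(1, 2, 3))
  val (pair, stack2)  = inner(stack1)

  def main(args: Array[String]): Unit =
    println((pair, stack2)) // ((Some(1),Some(2)),List(3))
}
```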

To do that we need a function that performs a map and then flattens the result, which is why it is called flatMap:


def flatMap[S, A, B](sa:State[S, A])(f:A=>State[S, B]): State[S, B] = state => {
    val (a, newState) = sa(state)
    f(a)(newState)
  }

Here we go: as with map, we run sa on the state, then take the value a and pass it to f.
This time f(a) returns a State[S, B], which is not yet the right return type, because we need to return a tuple (b: B, s: S). But since f(a): State[S, B] is itself a function S => (B, S), we can simply apply it to newState to get the result we want, on the last line of the body.
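
To see why the name fits, flatMap can also be written literally as "map, then flatten". This is a sketch: flatten is a hypothetical helper that the post does not define, and the demo uses an Int counter as the state (next and twoTicks are illustrative names) rather than the stack.

```scala
// Hypothetical demo: flatMap expressed as map followed by flatten.
object FlattenExample {
  type State[S, A] = S => (A, S)

  def map[S, A, B](sa: State[S, A])(f: A => B): State[S, B] = state => {
    val (a, newState) = sa(state)
    (f(a), newState)
  }

  // Hypothetical helper: run the outer State, then run the inner State
  // it produced on the intermediate state
  def flatten[S, A](ssa: State[S, State[S, A]]): State[S, A] = state => {
    val (inner, newState) = ssa(state)
    inner(newState)
  }

  // Behaves like the flatMap defined above
  def flatMap[S, A, B](sa: State[S, A])(f: A => State[S, B]): State[S, B] =
    flatten[S, B](map(sa)(f))

  // Counter state: return the current value and increment it
  val next: State[Int, Int] = n => (n, n + 1)
  val twoTicks: State[Int, (Int, Int)] = flatMap(next)(a => map(next)(b => (a, b)))

  def main(args: Array[String]): Unit =
    println(twoTicks(10)) // ((10,11),12)
}
```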

Now that we have flatMap and map, we can use them in popPairs:


def popPairs[A]: State[Stack[A],(Option[A], Option[A])] =
    flatMap(pop[A]) ( opt1 => map(pop[A]) ( opt2 => (opt1, opt2) ) )

Here we go: with flatMap and map we manipulate the State values returned by pop, and we no longer have to pass the intermediate stacks around ourselves.
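
Putting the pieces together, a self-contained sketch of the type-alias version might look like the following. AliasStackApp is a hypothetical name; map, flatMap and pop are the definitions above. Note that the empty-stack case still works: the second pop simply returns None.

```scala
// Hypothetical self-contained version of the type-alias approach.
object AliasStackApp {
  type Stack[A] = List[A]
  type State[S, A] = S => (A, S)

  def map[S, A, B](sa: State[S, A])(f: A => B): State[S, B] = state => {
    val (a, newState) = sa(state)
    (f(a), newState)
  }

  def flatMap[S, A, B](sa: State[S, A])(f: A => State[S, B]): State[S, B] = state => {
    val (a, newState) = sa(state)
    f(a)(newState)
  }

  def pop[A]: State[Stack[A], Option[A]] = {
    case a :: tail => (Some(a), tail)
    case Nil       => (None, Nil)
  }

  // No intermediate stacks are named here: flatMap and map thread them for us
  def popPairs[A]: State[Stack[A], (Option[A], Option[A])] =
    flatMap(pop[A])(opt1 => map(pop[A])(opt2 => (opt1, opt2)))

  def main(args: Array[String]): Unit = {
    println(popPairs[Int](List(1, 2, 3))) // ((Some(1),Some(2)),List(3))
    println(popPairs[Int](List(1)))       // ((Some(1),None),List())
  }
}
```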

Let's try to improve the syntax.
The next step is to promote State[S, A] from a simple type alias to a proper trait, with map and flatMap as its methods.


trait State[S, A] {
    def apply(s:S): (A, S)
    def map[B](f:A =>B): State[S, B]
    def flatMap[B](f:A=>State[S, B]): State[S, B]
  }

Now let's define a companion object for State with a factory method, and implement map and flatMap in the trait.


trait State[S, A] {
    def apply(s:S): (A, S)
    def map[B](f:A =>B): State[S, B] = State { s => 
     val (a, newState) = this(s)
     (f(a), newState)
    }
    def flatMap[B](f:A=>State[S, B]): State[S, B] = State { s => 
     val (a, newState) = this(s)
     f(a)(newState)
    }
  }
  
  object State {
    def apply[S, A](r: S => (A, S)): State[S, A] = new State[S, A] {
      def apply(s:S) = r(s)
    }
  }

So now map and flatMap are methods of the State trait, and the companion object State has a factory method that creates a new State from a function of type S => (A, S).
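
To see the trait and the factory in isolation, here is a minimal sketch using an Int counter as the state instead of a stack. CounterExample, next, doubled and pair are hypothetical names. The explicit type arguments in State[S, B] { ... } are a small deviation from the post's code, added so that the lambda's parameter type is stated rather than inferred from the expected type.

```scala
// Hypothetical demo of the trait-based State with a counter as the state.
object CounterExample {
  trait State[S, A] {
    def apply(s: S): (A, S)

    def map[B](f: A => B): State[S, B] = State[S, B] { s =>
      val (a, newState) = this(s)
      (f(a), newState)
    }

    def flatMap[B](f: A => State[S, B]): State[S, B] = State[S, B] { s =>
      val (a, newState) = this(s)
      f(a)(newState)
    }
  }

  object State {
    // Factory: wraps a plain function S => (A, S) into a State
    def apply[S, A](r: S => (A, S)): State[S, A] = new State[S, A] {
      def apply(s: S): (A, S) = r(s)
    }
  }

  // Counter: return the current value and increment the state
  val next: State[Int, Int] = State[Int, Int](n => (n, n + 1))

  // map transforms only the result
  val doubled: State[Int, Int] = next.map(_ * 2)

  // flatMap sequences two reads, threading the counter automatically
  val pair: State[Int, (Int, Int)] = next.flatMap(a => next.map(b => (a, b)))

  def main(args: Array[String]): Unit = {
    println(doubled(5)) // (10,6)
    println(pair(5))    // ((5,6),7)
  }
}
```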

Our code now looks like this:


package fp

object StackApp extends App {

  type Stack[A] = List[A]
  
  trait State[S, A] {
    def apply(s:S): (A, S)
    def map[B](f:A =>B): State[S, B] = State { s => 
     val (a, newState) = this(s)
     (f(a), newState)
    }
    def flatMap[B](f:A=>State[S, B]): State[S, B] = State { s => 
     val (a, newState) = this(s)
     f(a)(newState)
    }
  }
  
  object State {
    def apply[S, A](r: S => (A, S)): State[S, A] = new State[S, A] {
      def apply(s:S) = r(s)
    }
  }
  
  def push[A](a:A) : State[Stack[A], Unit] = State { stack => ((), a :: stack) }
  
  def pop[A]: State[Stack[A], Option[A]] = State {
    case a :: tail => (Some(a), tail)
    case Nil       => (None, Nil)
  }
  
  def popPairs[A]: State[Stack[A],(Option[A], Option[A])] =
    pop[A].flatMap(opt1 => pop[A].map(opt2 => (opt1, opt2) ))
    
}

Fantastic: popPairs now uses dot notation and looks closer to the object-oriented style, but it is still not optimal.

In Scala, any expression built by composing these calls (flatMap of flatMap of .... map) can be written as a for comprehension.


def popPairs[A]: State[Stack[A],(Option[A], Option[A])] =
  pop[A].flatMap(opt1 => pop[A].map(opt2 => (opt1, opt2) ))

/**
 * IS EQUIVALENT TO :
 */
    
def popPairs[A]: State[Stack[A], (Option[A], Option[A])] = for {
  opt1 <- pop[A]
  opt2 <- pop[A]
} yield (opt1, opt2)

The second version reads like imperative code, but it is really just syntactic sugar for the first version: the compiler translates the for comprehension back into the same flatMap and map calls.
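
To check the equivalence concretely, the sketch below writes a longer for comprehension (push a value, then pop twice) next to a hand-written flatMap/map chain of the shape the compiler generates for it (the compiler's generated parameter names differ), and compares the two. ForComprehensionExample, program and desugared are hypothetical names; the trait, companion, push and pop follow the post's definitions, with explicit type arguments on the factory calls.

```scala
// Hypothetical demo: a for comprehension and its hand-desugared form.
object ForComprehensionExample {
  type Stack[A] = List[A]

  trait State[S, A] {
    def apply(s: S): (A, S)

    def map[B](f: A => B): State[S, B] = State[S, B] { s =>
      val (a, newState) = this(s)
      (f(a), newState)
    }

    def flatMap[B](f: A => State[S, B]): State[S, B] = State[S, B] { s =>
      val (a, newState) = this(s)
      f(a)(newState)
    }
  }

  object State {
    def apply[S, A](r: S => (A, S)): State[S, A] = new State[S, A] {
      def apply(s: S): (A, S) = r(s)
    }
  }

  def push[A](a: A): State[Stack[A], Unit] =
    State[Stack[A], Unit](stack => ((), a :: stack))

  def pop[A]: State[Stack[A], Option[A]] = State[Stack[A], Option[A]] {
    case a :: tail => (Some(a), tail)
    case Nil       => (None, Nil)
  }

  // Push 0, then pop twice
  val program: State[Stack[Int], (Option[Int], Option[Int])] = for {
    _    <- push(0)
    opt1 <- pop[Int]
    opt2 <- pop[Int]
  } yield (opt1, opt2)

  // Hand-written flatMap/map chain corresponding to the for comprehension above
  val desugared: State[Stack[Int], (Option[Int], Option[Int])] =
    push(0).flatMap(_ => pop[Int].flatMap(opt1 => pop[Int].map(opt2 => (opt1, opt2))))

  def main(args: Array[String]): Unit = {
    println(program(List(7)))   // ((Some(0),Some(7)),List())
    println(desugared(List(7))) // ((Some(0),Some(7)),List())
  }
}
```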

So finally our code is:


package fp

object StackApp extends App {

  type Stack[A] = List[A]
  
  trait State[S, A] {
    def apply(s:S): (A, S)
    def map[B](f:A =>B): State[S, B] = State { s => 
     val (a, newState) = this(s)
     (f(a), newState)
    }
    def flatMap[B](f:A=>State[S, B]): State[S, B] = State { s => 
     val (a, newState) = this(s)
     f(a)(newState)
    }
  }
  
  object State {
    def apply[S, A](r: S => (A, S)): State[S, A] = new State[S, A] {
      def apply(s:S) = r(s)
    }
  }
  
  def push[A](a:A) : State[Stack[A], Unit] = State { stack => ((), a :: stack) }
  
  def pop[A]: State[Stack[A], Option[A]] = State {
    case a :: tail => (Some(a), tail)
    case Nil       => (None, Nil)
  }
  
  def popPairs[A]: State[Stack[A],(Option[A], Option[A])] = for {
    opt1 <- pop[A]
    opt2 <- pop[A]
  } yield (opt1, opt2)
  
}

As you can see, popPairs no longer deals with the state explicitly, and the code is quite intuitive and close to how it would look in an imperative style.

The benefit of using State[S, A] is that none of these functions mutates the state: they push the state change up to the outermost layer of your program (whoever finally runs the State), so all these functions are PURE. This gives your program an important property, referential transparency, which in distributed systems can be extended to location transparency, very important for systems that require high availability and fault tolerance. We will cover these aspects in the next post, where we will use the State monad to implement a key-value store.
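
To actually see the results, here is a small sketch that runs the final popPairs and prints what it returns. It uses a plain main method rather than extends App because, in Scala 2, the body of an object extending App only runs when main is called, so its vals are not initialized if the object is used from elsewhere. It also illustrates referential transparency: evaluating popPairs[Int](initial) twice gives equal results, and initial is never modified. PurityExample, initial, first and second are hypothetical names.

```scala
// Hypothetical demo: running the final popPairs and checking it is pure.
object PurityExample {
  type Stack[A] = List[A]

  trait State[S, A] {
    def apply(s: S): (A, S)

    def map[B](f: A => B): State[S, B] = State[S, B] { s =>
      val (a, newState) = this(s)
      (f(a), newState)
    }

    def flatMap[B](f: A => State[S, B]): State[S, B] = State[S, B] { s =>
      val (a, newState) = this(s)
      f(a)(newState)
    }
  }

  object State {
    def apply[S, A](r: S => (A, S)): State[S, A] = new State[S, A] {
      def apply(s: S): (A, S) = r(s)
    }
  }

  def pop[A]: State[Stack[A], Option[A]] = State[Stack[A], Option[A]] {
    case a :: tail => (Some(a), tail)
    case Nil       => (None, Nil)
  }

  def popPairs[A]: State[Stack[A], (Option[A], Option[A])] = for {
    opt1 <- pop[A]
    opt2 <- pop[A]
  } yield (opt1, opt2)

  val initial: Stack[Int] = List(1, 2, 3)

  // Same State value, same input: same result. `initial` is never mutated;
  // the "new" stack is simply a different immutable List.
  val first  = popPairs[Int](initial)
  val second = popPairs[Int](initial)

  def main(args: Array[String]): Unit = {
    val ((opt1, opt2), rest) = first
    println(s"popped $opt1 and $opt2, remaining stack: $rest")
    println(s"original stack is still: $initial")
  }
}
```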



3 comments:

  1. How would one actually print the result? Could you please show a println that demonstrates how to access the final value(s)?

    1. If you want to stay pure, you should use the IO monad. However, then you need to compose the State monad and the IO monad. To do this, you need monad transformers...

  2. You must have created a Stack object somewhere:

    val stackDouble: Stack[Double] = List(1.0, 2.0, 3.0)

    that's the stack you're interacting with:

    val (_, newStack) = push(42.0)(stackDouble)

    and newStack is still a list underneath so you can print it:

    newStack.foreach(println)
