
Programming, Haskell, and games

A State Monad Tutorial


Code that is intended to be typed in a ghci session, or that is output from ghci, is contained in a green background.

Code that is intended to be loaded from a source file by ghci or other Haskell code is contained in a blue background.

Monads tend to be a difficult concept for beginners to understand. In this guide, I try a different approach by explaining non-monadic values and building from there. If you want to understand monads, please don’t breeze through this tutorial, as you will need to think to understand them.

I will first explain the State type. Let us import it and look at its type:

> :m Control.Monad.State.Strict

For performance reasons, we will almost always want to use the strict version. Often, the functionality of the lazy version will only be useful when we do some debugging.

> :i State

 

newtype State s a = State {runState :: s -> (a, s)}

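There should be more lines, but for the purposes of this tutorial, we only care about the first line, which tells us the type of the State constructor. (This is what older versions of the mtl package print. Since mtl 2, State s is a type synonym for StateT s Identity, and a lowercase function, state, takes the place of the State data constructor; the newtype shown here is still the right mental model.)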
Even though State is defined in a sub-module of Control.Monad, we will not be using its monadic properties yet.

In case you’re confused about the type, we’ll go over it. To make things less confusing, we’ll give the data constructor a different name from the type constructor, for now:

newtype StateType s a = State {runState :: s -> (a, s)}

As we don’t need the functionality or flexibility of “data”, we use newtype to avoid the overhead “data” would give us, although data would work too. If it helps you more clearly understand it, read “newtype” as “data”.

It looks like State wraps a function. Yes, that is what it does. It just wraps a function that takes an initial thing of type s, and returns a tuple that contains some other value that can be pretty much anything and a new value of the same type s.

Let us review the braces. They are just record syntax, which saves us from pattern matching by automatically providing convenient accessor functions. Here’s an example of how Maybe could be written with record syntax:

data AnotherMaybe a = Nothing | Just {fromJust :: a}
-- fromJust is actually manually defined in Data.Maybe; see
-- below example of how to write such a function that unwraps
-- something
data MaybeButOnlyMaybeInt = IntNothing | IntJust {fromIntJust :: Int}
    -- MaybeButOnlyMaybeInt is without type variables

State could be written without record syntax:

newtype StateTypeWithoutRecordSyntax s a = State (s -> (a, s))

Since we frequently need to unwrap State, though, we like being able to call runState instead of pattern matching. Even if we don’t use record syntax, we can still define runState ourselves:

runStateWithoutRecordSyntax :: StateTypeWithoutRecordSyntax s a -> (s -> (a, s))
-- equivalently: StateTypeWithoutRecordSyntax s a -> s -> (a, s)
-- (these two type signatures are the same)
runStateWithoutRecordSyntax (State function) = function

Here we define a type for a simple counter:

data Counter = Counter {value :: Integer, numberOfIncrements :: Int} deriving (Show)

We can define a counter whose value is 42 and whose value has been incremented 3 times like this in ghci:

> let ourCounter = Counter 42 3

Or, if you prefer the other syntax:

> let ourCounter = Counter {value = 42, numberOfIncrements = 3}

On a side note, we can construct a new counter by modifying an existing counter:

> ourCounter{value = 24}
Counter {value = 24, numberOfIncrements = 3}

Let us try writing a State function. Start a text editor, and just write the import statement and one function. We have an example of one such function here, but, for experience, you should write your own:

incrementCounter :: Counter -> (Integer, Counter)
incrementCounter (Counter { value = oldValue
                          , numberOfIncrements = oldNumOfIncrements
                          }) =
    (oldValue, Counter { value = succ oldValue
                       , numberOfIncrements = succ oldNumOfIncrements
                       })  -- increment a counter, and return the value of the old
                           -- counter along with another (the new) counter.

(The indentation is a bit odd here; I had to make room)

Really, it’s not very difficult; just write it!

incrementCounter is a function that takes an initial counter, and returns the value of the initial counter and an incremented counter, both contained in a tuple. It’s important that you understand what the record syntax does here: on the left-hand side, oldValue and oldNumOfIncrements are bound to the value and the number of increments of the initial counter; on the right-hand side, the same syntax simply constructs a new counter.

In case you have not yet met the “succ” function: here it returns the successor of a number, so we could also have written “+ 1” instead of “succ”:

incrementCounter :: Counter -> (Integer, Counter)
incrementCounter (Counter { value = oldValue
                          , numberOfIncrements = oldNumOfIncrements
                          }) =
                     ( oldValue
                     , Counter { value = oldValue + 1
                               , numberOfIncrements = oldNumOfIncrements + 1
                               }
                     )  -- increment a counter, and return the value of the old counter along with another (the new) counter.

Our state type is Counter, and our result type is Integer. I don’t consider that function to be a State function since we haven’t wrapped it in the State constructor yet. Let’s do so:

incrementCounterState :: StateType Counter Integer  -- in real Haskell code, we would write "incrementCounterState :: State Counter Integer"
incrementCounterState = State incrementCounter
    where incrementCounter :: Counter -> (Integer, Counter)
          incrementCounter (Counter { value = oldValue
                                    , numberOfIncrements = oldNumOfIncrements
                                    }) =
                               (oldValue, Counter { value = succ oldValue
                                                  , numberOfIncrements = succ oldNumOfIncrements
                                                  })  -- increment a counter, and return the value of the old counter along with another (the new) counter.

The only difference between incrementCounterState and incrementCounter is that incrementCounterState is wrapped in the State constructor, and in order to get our State function, we just call runState on incrementCounterState.

> :t incrementCounterState
incrementCounterState :: StateType Counter Integer
> :t runState incrementCounterState
runState incrementCounterState :: Counter -> (Integer, Counter)

As you have noticed, runState just unwraps the function.

At this point, I’ve stopped calling the State type constructor StateType.

Let us try to increment ourCounter using our new wrapped state computation, incrementCounterState:

> incrementCounterState ourCounter
<interactive>:1:0:
Couldn't match expected type `Counter -> t'
against inferred type `State Counter Integer'
In the expression: incrementCounterState ourCounter
In the definition of `it': it = incrementCounterState ourCounter

Uh oh! Can you figure out why this is happening? GHCi expects incrementCounterState to be a function (one that takes a Counter and returns something, GHCi doesn’t care what), but incrementCounterState is really of type “State Counter Integer”. We forgot to unwrap the function contained in the State value! Since incrementCounterState is really just a value that contains a function, we’ll unwrap that value to get that function by calling runState. Let’s see what the type of runState incrementCounterState is:

> :t runState incrementCounterState
runState incrementCounterState :: Counter -> (Integer, Counter)

OK. Let us try this again.

> runState incrementCounterState $ ourCounter
(42,Counter {value = 43, numberOfIncrements = 4})

This can also be written without the ‘$’:

> runState incrementCounterState ourCounter
(42,Counter {value = 43, numberOfIncrements = 4})

Hooray! It worked! We incremented our counter and retrieved the value of the first counter!

Right, so, State pretty much just wraps a function that takes some value, which we call the initial state, and returns a tuple of a value of *any* type and another “state”, which needs to have the same type as the initial state.

Up to this point, we have not used any of the monadic properties of State. Before we step there, let’s first exercise our skills with State a bit.

How about a function that doubles the value of a counter and increments its numberOfIncrements? Can you write that?
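If you get stuck on the exercise, here is one possible solution. It is only a sketch under a few assumptions: the State newtype and Counter are declared locally so that the snippet loads on its own, the name doubleCounterState is made up, and since the exercise doesn’t say what to return, we return the old value, the way incrementCounterState does.

```haskell
-- Local copy of the tutorial's State type, so this file needs no imports.
newtype State s a = State { runState :: s -> (a, s) }

data Counter = Counter { value :: Integer, numberOfIncrements :: Int }
    deriving (Show, Eq)

-- Hypothetical name. Doubles the value, bumps numberOfIncrements,
-- and returns the value the counter had before doubling.
doubleCounterState :: State Counter Integer
doubleCounterState = State $ \counter ->
    ( value counter
    , counter { value = value counter * 2
              , numberOfIncrements = succ (numberOfIncrements counter)
              }
    )
```

In ghci, runState doubleCounterState (Counter 42 3) should give (42,Counter {value = 84, numberOfIncrements = 4}).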

Now for the monadic properties of State. Let’s jump right in! For now, I won’t worry you with do notation or the usage of “gets”, “put”, “>>”, and the like; we will only focus on bind (>>=) and return.

What is (>>=)? Let us look at its type:

> :t (>>=)
(>>=) :: (Monad m) => m a -> (a -> m b) -> m b

It takes a monadic value (which is, in our case, a State computation: a function that is wrapped with the State constructor) and a function, and returns a monadic value that “binds” those two together. The function just takes an unwrapped value and returns a new monadic value. How values are unwrapped depends on the monad: each monad defines how bind unwraps the result of its first argument.

Let’s put this in the context of State:

(>>=) :: (State s) a -> (a -> (State s) b) -> (State s) b

Which can be written as

(>>=) :: State s a -> (a -> State s b) -> State s b

Notice how we just replaced m with (State s).

So, (>>=) pretty much combines two State computations. How does this work for State? Let’s take a look at its implementation along with return.
Since bind and return are methods of the Monad class, the definition needs to begin with an “instance” line:

instance Monad (State s) where
    return a = State $ \s -> (a, s)
    m >>= k = State $ \s -> let (a, s') = runState m s
                            in  runState (k a) s'

But, yikes! That piece of code can look confusing, because it certainly confused me! Well, it takes one State, which can be incrementCounterState, and a function. What does that function do, exactly? We know it needs to give us the final State computation, which is a function that takes an initial state and returns a value and a new state. At some point, you should think about State’s bind and think about how it works.
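One way to think about State’s bind is to expand it by hand. The sketch below is self-contained and rests on some assumptions: it declares its own State newtype instead of importing one, it adds the Functor and Applicative instances that current GHC requires before it accepts a Monad instance, and the names twiceWithBind and twiceByHand are made up for this illustration.

```haskell
-- Local copy of the tutorial's State type (assumption: not imported from mtl).
newtype State s a = State { runState :: s -> (a, s) }

-- Current GHC requires these two instances before a Monad instance.
instance Functor (State s) where
    fmap f m = State $ \s -> let (a, s') = runState m s in (f a, s')

instance Applicative (State s) where
    pure a    = State $ \s -> (a, s)
    mf <*> ma = State $ \s -> let (f, s')  = runState mf s
                                  (a, s'') = runState ma s'
                              in  (f a, s'')

instance Monad (State s) where
    m >>= k = State $ \s -> let (a, s') = runState m s
                            in  runState (k a) s'

data Counter = Counter { value :: Integer, numberOfIncrements :: Int }
    deriving (Show, Eq)

incrementCounterState :: State Counter Integer
incrementCounterState = State $ \(Counter v n) -> (v, Counter (succ v) (succ n))

-- Using bind: the first result is ignored, the state is threaded for us.
twiceWithBind :: State Counter Integer
twiceWithBind = incrementCounterState >>= \_ -> incrementCounterState

-- The same computation with bind expanded by hand.
twiceByHand :: State Counter Integer
twiceByHand = State $ \s ->
    let (_firstResult, s') = runState incrementCounterState s  -- run the first step
    in  runState incrementCounterState s'                       -- feed its state to the second
```

Both versions should behave identically: run on Counter 42 3, each gives (43,Counter {value = 44, numberOfIncrements = 5}). The returned value is 43 because the second step reports the counter’s value before its own increment.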

We already have a State computation, incrementCounterState, which increments a counter. Can we write a function that increments a counter twice? Sure! Besides the obvious method of writing a new state computation and using succ twice, we can combine incrementCounterState:

incrementCounterTwice :: State Counter Integer
-- we remember that incrementCounterState has the same type as incrementCounterTwice
incrementCounterTwice = incrementCounterState >>=
                        (\theValueOfTheOriginalCounter -> incrementCounterState)

incrementCounterTwice' :: State Counter Integer
incrementCounterTwice' = State $ \counter ->
                             runState (incrementCounterState >>=
                                       (\theValueOfTheOriginalCounter ->
                                            incrementCounterState))
                                      counter
-- incrementCounterTwice' is another way of writing incrementCounterTwice

(Again, the indentation of incrementCounterTwice’ is a bit awkward, so you might want to pay attention to the parentheses. As long as you understand the first State computation, it is not necessary that you understand the other one)

So, using (>>=), we were able to combine two State computations, which both happen to be incrementCounterState, and end up with another State computation. In this example, we ignored theValueOfTheOriginalCounter, although we could have used it for whatever purpose.

This is our first time using bind, so let’s take a good look at what we’re doing. By binding incrementCounterState with a function that ignores the result of that first incrementCounterState and returns a State computation, we end up with another State computation.

We can even combine incrementCounterTwice and another State computation:

incrementCounterThrice :: State Counter Integer
incrementCounterThrice = incrementCounterTwice >>= (\_ -> incrementCounterState)

That is why it is possible to chain combinations:

incrementCounterALotOfTimes :: State Counter Integer
incrementCounterALotOfTimes = incrementCounterTwice >>=
                              (\_ -> incrementCounterTwice >>=
                              (\_ -> incrementCounterTwice))

We can make it look even nicer by omitting the redundant parentheses:

incrementCounterALotOfTimes :: State Counter Integer
incrementCounterALotOfTimes = incrementCounterTwice >>=
                              \_ -> incrementCounterTwice >>=
                              \_ -> incrementCounterTwice

Keep in mind that the returned values of chained computations don’t all have to be of the same type. Consider

aCombinedThing :: State GameState String
aCombinedThing = foo >>= bar >>= quux

In this computation, a value of type GameState is being threaded through. foo can be of type (State GameState a), where a can be anything, as long as it is the same type that bar takes as its parameter. bar would be of type (a -> State GameState b), so the returned type of bar can differ from the returned type of foo. When the whole thing is run, bar is called with foo’s returned value and gives a State computation, which is run with the state foo left behind and produces a returned value and a new state. quux is then called with that returned value and returns a State computation. quux needs to be of type (b -> State GameState String), since the returned value type of aCombinedThing needs to be the same as the returned value type of the computation quux produces.

In case you’re confused about how this works, think of it as a wrapped function that takes an initial unwrapped state value and passes it on to foo; foo returns a new state, which is passed on to bar, and the new state from bar is passed on to quux. On top of that, the returned value of each step is handed to the function that comes after it.
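To make the threading concrete, here is a sketch with invented pieces. Everything about GameState, foo, bar, and quux below is an assumption made up for illustration (the tutorial never defines them); the State type is again declared locally, with the Functor and Applicative instances that current GHC requires.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
    fmap f m = State $ \s -> let (a, s') = runState m s in (f a, s')

instance Applicative (State s) where
    pure a    = State $ \s -> (a, s)
    mf <*> ma = State $ \s -> let (f, s')  = runState mf s
                                  (a, s'') = runState ma s'
                              in  (f a, s'')

instance Monad (State s) where
    m >>= k = State $ \s -> let (a, s') = runState m s
                            in  runState (k a) s'

-- Hypothetical state type.
data GameState = GameState { score :: Int, turn :: Int }
    deriving (Show, Eq)

-- Returns the current score (an Int) and advances the turn.
foo :: State GameState Int
foo = State $ \g -> (score g, g { turn = turn g + 1 })

-- Takes foo's Int, returns a Bool, and changes the score.
bar :: Int -> State GameState Bool
bar oldScore = State $ \g -> (oldScore > 10, g { score = oldScore + 5 })

-- Takes bar's Bool, returns a String, and leaves the state alone.
quux :: Bool -> State GameState String
quux wasHigh = State $ \g ->
    ((if wasHigh then "high" else "low") ++ " score, turn " ++ show (turn g), g)

aCombinedThing :: State GameState String
aCombinedThing = foo >>= bar >>= quux
```

Running runState aCombinedThing (GameState 12 0) should give ("high score, turn 1",GameState {score = 17, turn = 1}): the three returned values have three different types, while the GameState is threaded through all three steps.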

What is this “>>” function? It’s actually quite simple: it behaves just like >>=, only it doesn’t pass on the returned value of the first computation. (This function is actually part of the Monad class, but its default definition rarely needs to be changed. When this was written there was one more function in the Monad class, known as fail, which was pretty unpopular and commonly advised against; in current GHC it lives in a separate MonadFail class, so I won’t explain it.) It can be defined as:
a >> b = a >>= \_ -> b

Let’s use it:

incrementCounterALotOfTimes'  :: State Counter Integer
incrementCounterALotOfTimes'  = incrementCounterTwice >> incrementCounterTwice
incrementCounterEvenMoreTimes :: State Counter Integer
incrementCounterEvenMoreTimes = incrementCounterTwice >>
                                incrementCounterTwice >>
                                incrementCounterTwice

Much nicer!

Let’s use it in one of our previous examples of using combined State:

aCombinedThing :: State GameState String
aCombinedThing = foo >> bar >> quux

This will, when passed an initial state value, pass the state onto foo, discard foo’s returned value, but take the new state and give it to bar, discard bar’s returned value, and pass bar’s new state to quux. quux returns a returned value of type String and a new GameState. Sometimes, State will not even need a return value, or it won’t make sense; we can use Haskell’s void, known as unit, which is written as ().

aCombinedThing :: State GameState ()
aCombinedThing = foo >> bar >> quux

In this case, we really don’t care about the returned values of any of these. Since the returned values of foo and bar are discarded (even if any are “()”), they can be of any type, but since the type of aCombinedThing is State GameState (), the returned value type of quux needs to be (), so quux is of type State GameState () too.

How about a function that increments a counter twice, but uses the intermediate value (the returned value after the initial counter was incremented once, which is the initial value of the initial counter because of how we defined our State computations) to do something, perhaps constructing a GameState from that value at the end? We’ll do something even simpler: we will write a State computation that increments a counter twice; the return value of this computation is going to be the string representation of the intermediate counter value.

kindOfIncrementCounter :: State Counter String
kindOfIncrementCounter = incrementCounterState >>=
                         \initialCounterValue -> incrementCounterState >>=
                         \_ -> State $ (\s -> (show initialCounterValue, s))

Notice that the type of the returned value of incrementCounterState is not String, but Integer. That’s OK, because it’s just passed to our lambda function, which gives it the name “initialCounterValue”. The final State computation in the chain does need to return a value of type String, which it does.

So we see that kindOfIncrementCounter is a State computation that, when “run” with an initial state, will increment the counter twice and return the string representation of the initial counter value along with the new counter. initialCounterValue is bound to the returned value of the first incrementCounterState; remember that we defined it to return the value of the counter it is given?

What is return? return turns a value into a monadic value. It wraps a value into a monad, lifts a value into a monad, raises a value into a monad, or injects a value into a monad, depending on your preferred terminology. Which State computation does return produce? Let’s try it on a list of strings:

> return ["grass", "potato"] :: State Counter [String]
No instance for (Show (State Counter [String]))

Hmm. Let’s try deducing its functionality by applying it to a Counter.

> let a = return ["grass", "potato"] :: State Counter [String]
> :t a
a :: State Counter [String]
> :t runState a
runState a :: Counter -> ([String], Counter)
> (runState a) ourCounter
(["grass","potato"],Counter {value = 42, numberOfIncrements = 3})
> runState a ourCounter
(["grass","potato"],Counter {value = 42, numberOfIncrements = 3})

It seems that “return x” creates a State computation that leaves the state unchanged (the state might be passed to a later state computation) and returns it together with x as the returned value.

How can we be sure? By looking at how return is defined for the State monad, of course!

instance Monad (State s) where
    return a = State $ \s -> (a, s)
    …

We were right! “return x” wraps, in the State constructor, a function that returns x along with whatever state is passed to it.

Now, let’s look at a few useful functions and how they’re defined for the State monad, before we learn about “do” notation.

get :: State s s
get = State $ \s -> (s, s)

Keep in mind there’s nothing magical about these. Even though they are indirectly defined in Control.Monad.State.Class, which is exported by Control.Monad.State.Lazy and Control.Monad.State.Strict; we can still purely define them ourselves, perhaps as I have done here.

get takes a state, leaves the state unchanged, and also returns that state as the returned value. This is useful because we can pass the current state to the next function. Because we can pass the current state to the next function, we can also use that state later.

Consider writing a State computation that can increment a counter, halve its value, increment the counter, and return the value of the counter after the first increment (along with the final state). We can do this without any helper functions:

incrementHalveAndIncrementCounter :: State Counter Integer
incrementHalveAndIncrementCounter =
    (      State $ \counter -> ((), counter{ value = succ $
                                                    value counter
                                              , numberOfIncrements = succ $
                                                    numberOfIncrements counter
                                              })) >>=
    \_ -> (State $ \s -> (s, s)) >>=
    \thatOneState -> (State $
                          \counter -> ((), counter{ value = value counter `div` 2
                                                  , numberOfIncrements = succ $
                                                        numberOfIncrements counter
                                                  })) >>=  -- we cheat a bit here by
                                                           -- incrementing the number of
                                                           -- the increments of the
                                                           -- counter even though we are
                                                           -- really halving its value.
    \_ -> (State $
               \counter -> (value thatOneState, counter{ value = succ $
                                                              value counter
                                                        , numberOfIncrements = succ $
                                                              numberOfIncrements counter
                                                        }))

Sure, we could have avoided referring to a previous state by using the returned value of incrementCounterState, which is all that we want out of thatOneState, instead of our own State computation here, which does the very same thing except that it doesn’t return anything useful. But for the purposes of this part of the tutorial, we’ll pretend that we never defined incrementCounterState to return that.

Let’s simplify the code by substituting our common code by a simple value (function wrapped in State constructor):

incrementHalveAndIncrementCounter :: State Counter Integer
incrementHalveAndIncrementCounter = incrementCounterState >>=
                                    \_ -> (State $ \s -> (s, s)) >>=
                                    \thatOneState -> (State $ \counter -> ((), counter{value = value counter `div` 2, numberOfIncrements = succ $ numberOfIncrements counter})) >>=
                                    \_ -> incrementCounterState >>=
                                    \_ -> return $ value thatOneState  -- remember that return doesn't affect the state; it creates a State computation that returns the specified value along with the unchanged state

Even though, unlike our previous code, incrementCounterState does return a meaningful value (the value of the initial counter), the behaviour is no different, since the returned value is discarded. Do you remember a function that discards the unwrapped value of a monad? Yes, (>>)! Let’s try that:

incrementHalveAndIncrementCounter :: State Counter Integer
incrementHalveAndIncrementCounter = incrementCounterState >>
                                    (State $ \s -> (s, s)) >>=
                                    \thatOneState -> (State $ \counter -> ((), counter{value = value counter `div` 2, numberOfIncrements = succ $ numberOfIncrements counter})) >>
                                    incrementCounterState >>
                                    return (value thatOneState)  -- remember that return doesn't affect the state; it creates a State computation that returns the specified value along with the unchanged state

Let’s use get:

incrementHalveAndIncrementCounter :: State Counter Integer
incrementHalveAndIncrementCounter = incrementCounterState >>
                                    get >>=
                                    \thatOneState -> (State $ \counter -> ((), counter{value = value counter `div` 2, numberOfIncrements = succ $ numberOfIncrements counter})) >>
                                    incrementCounterState >>
                                    return (value thatOneState)  -- remember that return doesn't affect the state; it creates a State computation that returns the specified value along with the unchanged state

On top of get, we have another very useful function, gets:

gets :: (s -> a) -> State s a
gets f = State $ \s -> (f s, s)

gets takes a function, and behaves like get, only it applies the function to the state before returning it. This is especially useful when our state is a type defined with record syntax, such as Counter, as we can extract individual fields from the state. Let’s use gets instead of manually applying “value” to get the value out of thatOneState:

incrementHalveAndIncrementCounter :: State Counter Integer
incrementHalveAndIncrementCounter = incrementCounterState >>
                                    gets value >>=
                                    \thatOneValue -> (State $ \counter -> ((), counter{value = value counter `div` 2, numberOfIncrements = succ $ numberOfIncrements counter})) >>
                                    incrementCounterState >>
                                    return thatOneValue  -- remember that return doesn't affect the state; it creates a State computation that returns the specified value along with the unchanged state

We should also look at modify:

modify :: (s -> s) -> State s ()
modify f = State $ \s -> ((), f s)

modify takes a function and returns a State computation that takes an initial state and returns a tuple of unit, (), and the result of applying the function to the state. In other words, modify modifies the state. Since nothing really meaningful can come out of this, the returned value of this computation is unit (“()”), to which the only sensible response is to ignore it, as only one value is of type “()”.

We have a State computation that does what modify does, so we can simplify our function further:

incrementHalveAndIncrementCounter :: State Counter Integer
incrementHalveAndIncrementCounter =
    incrementCounterState >>
    gets value >>=
    \thatOneValue -> (modify $
                             \counter -> counter{ value =
                                                         value counter `div` 2
                                                   , numberOfIncrements = succ $
                                                         numberOfIncrements counter
                                                   }) >>
    incrementCounterState >>
    return thatOneValue  -- remember that return doesn't affect the state; it creates a State computation that returns the specified value along with the unchanged state

Finally, let’s look at put.

put :: s -> State s ()
put s = State $ \_ -> ((), s)

In case it isn’t obvious, put is a function that takes a state and returns a State computation, that takes an initial state, ignores it, and sets a new state that will continue in the chain or thread of State computations. Since no return value would really make sense, it doesn’t have one.
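As a sketch of how put fits together with get, the snippet below rebuilds modify out of the two and uses put to reset a counter. The names modifyViaGetPut and resetCounter are made up for this example, and the State type is once more declared locally (with the Functor and Applicative instances current GHC requires) so that the file loads on its own.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
    fmap f m = State $ \s -> let (a, s') = runState m s in (f a, s')

instance Applicative (State s) where
    pure a    = State $ \s -> (a, s)
    mf <*> ma = State $ \s -> let (f, s')  = runState mf s
                                  (a, s'') = runState ma s'
                              in  (f a, s'')

instance Monad (State s) where
    m >>= k = State $ \s -> let (a, s') = runState m s
                            in  runState (k a) s'

data Counter = Counter { value :: Integer, numberOfIncrements :: Int }
    deriving (Show, Eq)

get :: State s s
get = State $ \s -> (s, s)

put :: s -> State s ()
put s = State $ \_ -> ((), s)

-- Hypothetical helper: modify rebuilt from get and put.
-- Read the current state, then replace it with f applied to it.
modifyViaGetPut :: (s -> s) -> State s ()
modifyViaGetPut f = get >>= \s -> put (f s)

-- Hypothetical example: discard the current counter and start over,
-- returning the value the old counter had.
resetCounter :: State Counter Integer
resetCounter = get >>= \old -> put (Counter 0 0) >> return (value old)
```

Here runState resetCounter (Counter 42 3) should give (42,Counter {value = 0, numberOfIncrements = 0}): the old state was ignored by put, but we could still return something from it because get had already handed it to us.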

It seems that there are common patterns with Monads. This is indeed the case, and Haskell acknowledges this by supplying the programmer with the convenience known as do notation.

do blocks are simply “syntactic sugar”, meaning that they are directly translated into expressions using (>>=) and (>>). The translation is simple:

do foo
   bar

translates to

foo >>
bar

What about (>>=)?

Well,

do returnedValue <- foo
   bar

translates to

foo >>=
\returnedValue -> bar

Let’s try simplifying our State computation, incrementHalveAndIncrementCounter, using do notation:

incrementHalveAndIncrementCounter :: State Counter Integer
incrementHalveAndIncrementCounter = do
    incrementCounterState
    thatOneValue <- gets value
    modify $ \counter -> counter{ value = value counter `div` 2
                                , numberOfIncrements = succ $ numberOfIncrements counter
                                }
    incrementCounterState
    return thatOneValue

Much, much nicer! We have simplified our code using (>>), helper functions, and do notation to the point at which it is effortlessly readable by a person who has not written our code.
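For readers who want something they can load straight into ghci, here is a sketch that assembles the pieces of this tutorial into one file. It rests on a few assumptions: State is declared locally rather than imported (recent versions of mtl no longer export a State data constructor), the Functor and Applicative instances are additions that current GHC demands before it accepts the Monad instance, and halve and withExplicitBinds are made-up names, the latter being the do block translated back by hand.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
    fmap f m = State $ \s -> let (a, s') = runState m s in (f a, s')

instance Applicative (State s) where
    pure a    = State $ \s -> (a, s)
    mf <*> ma = State $ \s -> let (f, s')  = runState mf s
                                  (a, s'') = runState ma s'
                              in  (f a, s'')

instance Monad (State s) where
    m >>= k = State $ \s -> let (a, s') = runState m s
                            in  runState (k a) s'

data Counter = Counter { value :: Integer, numberOfIncrements :: Int }
    deriving (Show, Eq)

incrementCounterState :: State Counter Integer
incrementCounterState = State $ \(Counter v n) -> (v, Counter (succ v) (succ n))

gets :: (s -> a) -> State s a
gets f = State $ \s -> (f s, s)

modify :: (s -> s) -> State s ()
modify f = State $ \s -> ((), f s)

-- Hypothetical helper: halve the value and count that as an increment.
halve :: Counter -> Counter
halve c = c { value = value c `div` 2
            , numberOfIncrements = succ (numberOfIncrements c)
            }

incrementHalveAndIncrementCounter :: State Counter Integer
incrementHalveAndIncrementCounter = do
    incrementCounterState
    thatOneValue <- gets value
    modify halve
    incrementCounterState
    return thatOneValue

-- The same computation with the do block translated into (>>) and (>>=).
withExplicitBinds :: State Counter Integer
withExplicitBinds =
    incrementCounterState >>
    gets value >>= \thatOneValue ->
    modify halve >>
    incrementCounterState >>
    return thatOneValue
```

With this loaded, runState incrementHalveAndIncrementCounter (Counter 42 3) should give (43,Counter {value = 22, numberOfIncrements = 6}), and withExplicitBinds should give exactly the same result.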

My goal is to teach others to help them understand the State monad. I hope that a more “conversational” style of writing has helped me achieve my goal.

By now, you should be able to figure out how the State monad works. After this point, you should be able to easily understand the other monads. I hope you will be as overjoyed as I was when I finally understood Monads! Good luck!

Written by bairyn

March 6, 2010 at 17:58


2 Responses


  1. Hi I thought this tutorial was very useful. I summarized it in OCaml-ese for a philosophy & linguistics course we’re running: http://lambda.jimpryor.net/state_monad_tutorial/

    Jim Pryor

    December 12, 2010 at 17:44

  2. I have tried to follow your tutorial, but ”

    D:\development\HaskellWS\Core\src\StateMonadTutorial.hs:184:38:
    Not in scope: type constructor or class `State’

    Can you please upload the haskell file which compiles?

    mohanmca

    August 28, 2011 at 07:46

