After the discussion about Monads, I decided to try to implement a Monad in Scala, possibly without copying from other Monads. My idea was to write something like Option or Try, that is, a container for a single optional value. I started by defining a trait with all the required methods:

  trait MonadTrait[T] {
    // bind is the primitive operation; flatMap delegates to it so that for-comprehensions can use it
    def bind[M](func: T => MonadTrait[M]): MonadTrait[M]
    def flatMap[M](func: T => MonadTrait[M]): MonadTrait[M] = {
      println("flatmapping")
      bind(func)
    }
    // map is derived from flatMap plus the unit constructor (MonadTrait.apply)
    def map[B](f: T => B): MonadTrait[B] = {
      println("mapping")
      flatMap(x => MonadTrait(f(x)))
    }
    def foreach[U](f: T => U): Unit
    def filter(func: T => Boolean): MonadTrait[T]
    // a guard (if p(x)) in a for-comprehension calls withFilter with a predicate on the value itself
    def withFilter(func: T => Boolean): MonadTrait[T] = {
      println("invoking withFilter")
      filter(func)
    }
    def get: T
    def toOption: Option[T]
  }

As you can see, the derived functions (flatMap, map and withFilter) are already defined in the trait. I also added a println in each method to see which one is invoked. Starting from this, I added two classes: one for the monad holding a value and one for the empty monad. I used companion objects for the apply methods.

  class Monad[T] private (unit: T) extends MonadTrait[T] {
    // the non-empty case: bind simply applies the function to the wrapped value
    override def bind[M](func: T => MonadTrait[M]): MonadTrait[M] = {
      println(s"$unit: binding")
      func(unit)
    }
    override def foreach[U](f: T => U): Unit = {
      println(s"$unit: invoking foreach")
      f(unit)
    }
    override def filter(func: T => Boolean): MonadTrait[T] = {
      println(s"$unit: filtering")
      if (func(unit)) this else EmptyMonad()
    }
    override def withFilter(func: T => Boolean): MonadTrait[T] = {
      println(s"$unit: invoking withFilter")
      filter(func)
    }
    override def get: T = {
      println(s"$unit: getting value")
      unit
    }
    override def toOption: Option[T] = Option(unit)

    override def toString: String = s"Monad(${unit.toString})"
  }

  class EmptyMonad[T] private () extends MonadTrait[T] {
    // the empty case: every operation short-circuits to another EmptyMonad
    override def bind[M](func: T => MonadTrait[M]): MonadTrait[M] = {
      println("EmptyMonad: binding")
      EmptyMonad()
    }
    override def map[B](f: T => B): MonadTrait[B] = {
      println("EmptyMonad: mapping")
      EmptyMonad()
    }
    override def foreach[U](f: T => U): Unit = {
      println("EmptyMonad: foreach")
    }
    override def filter(f: T => Boolean): MonadTrait[T] = {
      println("EmptyMonad: filtering")
      EmptyMonad()
    }
    override def withFilter(func: T => Boolean): MonadTrait[T] = {
      println("EmptyMonad: invoking withFilter")
      EmptyMonad()
    }
    override def get: T = {
      println("EmptyMonad: getting value")
      // the same exception type Option.get throws on None
      throw new NoSuchElementException("EmptyMonad.get")
    }
    override def toOption: Option[T] = None

    override def toString: String = "EmptyMonad"
  }

  // unit: wraps a value, using Option to turn null into an EmptyMonad
  object MonadTrait {
    def apply[T](unit: T): MonadTrait[T] = {
      Option(unit) match {
        case Some(x) => Monad[T](x)
        case None => EmptyMonad()
      }
    }
  }
  object EmptyMonad {
    def apply[T](): EmptyMonad[T] = new EmptyMonad[T]()
  }
  object Monad {
    def apply[T](unit: T): Monad[T] = new Monad[T](unit)
  }

Well, now it looks exactly like a Monad. Of course, it is a bit strange to use a Monad (Option) to decide which Monad to create, but that detail doesn't matter here: we want to unleash the power of Monads, not worry about details.
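Looking like a Monad is not quite the same as being one: a Monad also has to satisfy three laws (left identity, right identity and associativity). Below is a minimal sketch of how those laws can be checked. It uses a hypothetical, simplified Maybe type rather than the MonadTrait above. Maybe.unit plays the role of the companion apply, and flatMap plays the role of bind. All names in it (Maybe, Just, Empty, MonadLaws, MonadLawsDemo) are made up for illustration.

```scala
// Hypothetical, simplified single-value container (not the post's MonadTrait).
sealed trait Maybe[+A] {
  // bind: the primitive operation
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a) => f(a)
    case Empty   => Empty
  }
  // map derived from flatMap and unit, as in MonadTrait
  def map[B](f: A => B): Maybe[B] = flatMap(a => Maybe.unit(f(a)))
}
final case class Just[+A](value: A) extends Maybe[A]
case object Empty extends Maybe[Nothing]

object Maybe {
  // unit: wraps any value in Just (no special handling of null)
  def unit[A](a: A): Maybe[A] = Just(a)
}

object MonadLaws {
  // unit(a) flatMap f  ==  f(a)
  def leftIdentity[A, B](a: A, f: A => Maybe[B]): Boolean =
    Maybe.unit(a).flatMap(f) == f(a)

  // m flatMap unit  ==  m
  def rightIdentity[A](m: Maybe[A]): Boolean =
    m.flatMap(a => Maybe.unit(a)) == m

  // (m flatMap f) flatMap g  ==  m flatMap (a => f(a) flatMap g)
  def associativity[A, B, C](m: Maybe[A], f: A => Maybe[B], g: B => Maybe[C]): Boolean =
    m.flatMap(f).flatMap(g) == m.flatMap(a => f(a).flatMap(g))
}

object MonadLawsDemo {
  val half: Int => Maybe[Int] = n => if (n % 2 == 0) Just(n / 2) else Empty
  val label: Int => Maybe[String] = n => Just(s"#$n")
  val samples: List[Maybe[Int]] = List(Just(8), Just(3), Empty)

  // true only if every law holds for every sample
  val allLawsHold: Boolean =
    List(8, 3).forall(a => MonadLaws.leftIdentity(a, half)) &&
      samples.forall(m => MonadLaws.rightIdentity(m)) &&
      samples.forall(m => MonadLaws.associativity(m, half, label))
}
```

One caveat: a unit that turns null into the empty case breaks left identity for a null argument. MonadTrait.apply does this through Option(unit). In that case unit(null).flatMap(f) is always empty, while f(null) need not be.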

As I said in the other post, flatMap is my bind operator. I defined a separate bind method and call it from flatMap only to make this explicit. As per its definition, map relies on flatMap (and on MonadTrait.apply as unit). We have filter operators because in the end we want to use the Monads in a for-comprehension. The two classes then implement the different functions in the expected way, logging every method call.
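A guard in a for-comprehension is translated into a call to withFilter. The predicate it receives takes the element itself, not the wrapper around it. Below is a minimal sketch showing this translation and the signature the compiler expects (withFilter(p: A => Boolean)). It uses a hypothetical Box type with Full and EmptyBox cases; the type and GuardDemo are illustrative only.

```scala
// Hypothetical single-value container, used only to show how guards are translated.
sealed trait Box[+A] {
  def flatMap[B](f: A => Box[B]): Box[B]
  def map[B](f: A => B): Box[B] = flatMap[B](a => Full(f(a)))
  // The guard's predicate receives the element itself, not the Box around it.
  def withFilter(p: A => Boolean): Box[A] = flatMap[A](a => if (p(a)) Full(a) else EmptyBox)
}
final case class Full[+A](value: A) extends Box[A] {
  def flatMap[B](f: A => Box[B]): Box[B] = f(value)
}
case object EmptyBox extends Box[Nothing] {
  def flatMap[B](f: Nothing => Box[B]): Box[B] = this
}

object GuardDemo {
  // Written with a guard...
  val sugared: Box[Int] = for { x <- Full(7) if x > 5 } yield x * 10
  // ...and the calls the compiler generates for it.
  val desugared: Box[Int] = Full(7).withFilter(x => x > 5).map(x => x * 10)
  // A value that fails the guard ends up empty.
  val rejected: Box[Int] = for { x <- Full(3) if x > 5 } yield x * 10
}
```

The predicate could instead be typed on the wrapper (Box[A] => Boolean). In that case, the guard if x > 5 would not compile, because the compiler passes x, an Int, to it.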

So, let's write our test. Here is the code:

  object MonadUsage extends App {
    def stupidFunc[T](monad: MonadTrait[T]): MonadTrait[MonadTrait[T]] = MonadTrait(monad)
    val listMonads = (1 to 10) map (x => MonadTrait(x))
    val filteredMonads = for {x <- listMonads
                              if x.get > 5}
                          yield x
    for {a <- filteredMonads
         c <- a} {
      println(c)
    }
  }
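The logged output can be explained by how the compiler translates for-comprehensions. A loop without yield, like the second one in the test, becomes nested foreach calls: filteredMonads.foreach(a => a.foreach(c => println(c))). That is why the MonadTrait values only log foreach calls. With yield, every generator except the last is translated into flatMap, and the last one into map. Below is a minimal sketch using a hypothetical Traced wrapper (not part of the post's code) that records each call in a buffer instead of printing it.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical wrapper that records which method a for-comprehension calls on it.
final case class Traced[A](value: A, log: ListBuffer[String]) {
  def map[B](f: A => B): Traced[B] = { log += s"$value: map"; Traced(f(value), log) }
  def flatMap[B](f: A => Traced[B]): Traced[B] = { log += s"$value: flatMap"; f(value) }
  def foreach[U](f: A => U): Unit = { log += s"$value: foreach"; f(value); () }
}

object TraceDemo {
  val log: ListBuffer[String] = ListBuffer.empty[String]
  val a: Traced[Int] = Traced(1, log)
  val b: Traced[Int] = Traced(2, log)
  val seen: ListBuffer[Int] = ListBuffer.empty[Int]

  // No yield: becomes a.foreach(x => b.foreach(y => seen += (x + y)))
  for { x <- a; y <- b } seen += (x + y)
  val foreachCalls: List[String] = log.toList

  // With yield: becomes a.flatMap(x => b.map(y => x + y))
  log.clear()
  val result: Traced[Int] = for { x <- a; y <- b } yield x + y
  val yieldCalls: List[String] = log.toList
}
```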

We don't use stupidFunc yet; I created it to force the for-comprehension to flatMap our result. The result of running our test is:

  1: getting value
  2: getting value
  [...]
  9: getting value
  10: getting value
  6: invoking foreach
  6
  [...]
  10: invoking foreach
  10

So we have the "getting value" lines printed by the guard (if x.get > 5) and the "invoking foreach" lines printed by c <- a. Next, I tried to force a flatMap. To do so, I added a level of nesting with stupidFunc. The test becomes:

  object MonadUsage extends App {
    def stupidFunc[T](monad: MonadTrait[T]): MonadTrait[MonadTrait[T]] = MonadTrait(monad)
    val listMonads = (1 to 10) map (x => MonadTrait(x))
    val filteredValues = for {x <- listMonads
                              if x.get > 5}
                          yield x
    val test = for {a <- filteredValues
                    b <- stupidFunc(a)} yield b
    println(test)
  }

 

This leads to a compiler error:

Error:(15, 10) type mismatch;
 found : MonadTrait[MonadTrait[Int]]
 required: scala.collection.GenTraversableOnce[?]
 b <- stupidFunc(a)

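This is because the for-comprehension is translated into filteredValues.flatMap(a => stupidFunc(a).map(b => b)). Since filteredValues is an IndexedSeq, the flatMap being called is the standard one from the Traversable hierarchy, which expects the function to return a GenTraversableOnce. Its definition is the following: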

def flatMap [B, That] (f: (A) ⇒ GenTraversableOnce[B])(implicit bf: CanBuildFrom[Repr, B, That]): That

For example, I could define my Monad by extending FilterMonadic, which declares map, flatMap, foreach and withFilter with the signatures used by the collections library.
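The compile error above can be reproduced in a few lines with a hypothetical Box type (not the post's MonadTrait). The sketch below also shows one possible workaround, among others: convert each monadic value to an Option before using it as an inner generator. Vector.flatMap accepts an Option in both Scala 2.12 and 2.13. In 2.12 this works through the implicit Option.option2Iterable conversion. In 2.13, Option extends IterableOnce. MixingDemo and the Box types are illustrative names only.

```scala
// Hypothetical single-value container with a conversion to Option.
sealed trait Box[+A] {
  def flatMap[B](f: A => Box[B]): Box[B]
  def map[B](f: A => B): Box[B] = flatMap[B](a => Full(f(a)))
  def toOption: Option[A]
}
final case class Full[+A](value: A) extends Box[A] {
  def flatMap[B](f: A => Box[B]): Box[B] = f(value)
  def toOption: Option[A] = Some(value)
}
case object EmptyBox extends Box[Nothing] {
  def flatMap[B](f: Nothing => Box[B]): Box[B] = this
  def toOption: Option[Nothing] = None
}

object MixingDemo {
  val boxes: Vector[Box[Int]] = Vector(Full(6), EmptyBox, Full(7))

  // Does not compile: it becomes boxes.flatMap(a => Full(a).map(b => b)),
  // and Vector.flatMap needs a function returning a collection, not a Box.
  // val nested = for { a <- boxes; b <- Full(a) } yield b

  // Compiles: the inner generator is an Option, which Vector.flatMap accepts.
  val doubled: Vector[Int] = for { box <- boxes; n <- box.toOption } yield n * 2
}
```

Another option is to put a custom monad in the first generator. That way, its own flatMap is the one the compiler calls.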
For today it is enough. Stay tuned!
