1. Introduction
The functional combinators map() and flatMap() are higher-order functions found on RDD, DataFrame, and Dataset in Apache Spark. They let us apply a transformation to every element of a collection and return a new collection containing the results.
Spark’s map() and flatMap() functions are modeled after their equivalents in the Scala collections library, so what we learn in this article applies to those as well.
Let’s go ahead and look at some examples to help understand the difference between map() and flatMap().
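Because Spark’s operators mirror Scala’s, we can preview the difference with plain Scala collections before touching an RDD. The values below are hypothetical, chosen only to illustrate the shapes of the results:

```scala
// Plain Scala collections behave the same way as Spark's RDD operators.
val phrases = List("hello world", "foo bar baz")

// map() keeps one output element per input element, so the result is nested.
val mapped: List[Array[String]] = phrases.map(_.split(" "))

// flatMap() concatenates the per-element results into a single flat list.
val flattened: List[String] = phrases.flatMap(_.split(" "))
```

Here mapped still has two elements (one array per input string), while flattened has five (one per word).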
2. Example of map()
The map() method transforms a collection by applying a function to each element of that collection. It then returns a new collection containing the result.
In the following spark-shell example, we’ll use map() to split each string in the collection on the space character:
val rdd = sc.parallelize(Seq("map vs flatMap", "apache spark"))
rdd.map(_.split(" ")).collect
res1: Array[Array[String]] = Array(Array("map", "vs", "flatMap"), Array("apache", "spark"))
As we can see, the map() method takes the function split(" ") as a parameter and applies it to every element in the RDD. The result of the transformation is then returned as a new collection. It still contains two elements like the original, but each element is now the result of running split(" ") on the initial element.
Therefore, we can conclude that map() transforms a collection of length N into another collection of length N.
3. Example of flatMap()
flatMap() combines mapping and flattening. Conceptually, it first applies map() and then flatten() to generate the result. Flattening collapses a collection of collections into a single collection containing all their elements.
Let’s look at the same example and apply flatMap() to the collection instead:
val rdd = sc.parallelize(Seq("map vs flatMap", "apache spark"))
rdd.flatMap(_.split(" ")).collect
res1: Array[String] = Array("map", "vs", "flatMap", "apache", "spark")
The result is a collection that contains all the elements from the two nested collections from the previous example’s output.
To help understand how this works, let’s take another look at the previous example’s output after calling the map() method:
res1: Array[Array[String]] = Array(Array("map", "vs", "flatMap"), Array("apache", "spark"))
flatMap() runs flatten() on this output, generating a new result that doesn’t contain nested collections but contains each element of the nested collection instead.
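We can check this map-then-flatten equivalence with plain Scala collections (an RDD itself exposes no flatten() method, so in Spark the two steps happen inside flatMap()):

```scala
val words = Seq("map vs flatMap", "apache spark")

// One step: flatMap() maps and flattens together.
val viaFlatMap = words.flatMap(_.split(" "))

// Two steps: map() produces nested arrays, flatten() collapses them.
val viaMapThenFlatten = words.map(_.split(" ")).flatten

// Both yield: List(map, vs, flatMap, apache, spark)
```

Either route produces the same flat sequence of five words.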
Another useful application of flatMap() is when working with a collection of type Option. Flattening the output lets us quickly keep only the elements that contain a value. Assuming, for example, a collection of strings and a helper toInt that returns Some for parsable numbers and None otherwise:
val strings = Seq("3", "5", "a", "8", "b", "c", "12")
def toInt(s: String): Option[Int] = scala.util.Try(s.toInt).toOption
val mapOutput = strings.map(toInt)
mapOutput: Seq[Option[Int]] = List(Some(3), Some(5), None, Some(8), None, None, Some(12))
val flattenOutput = mapOutput.flatten
flattenOutput: Seq[Int] = List(3, 5, 8, 12)
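We can also skip the intermediate map() step entirely, since flatMap() accepts a function returning an Option directly. A self-contained sketch, with the hypothetical toInt parser restated so the snippet stands on its own:

```scala
import scala.util.Try

// Hypothetical parser: Some(n) for numeric strings, None otherwise.
def toInt(s: String): Option[Int] = Try(s.toInt).toOption

val strings = Seq("3", "5", "a", "8", "b", "c", "12")

// One step: map each string to an Option and drop the Nones.
val ints: Seq[Int] = strings.flatMap(toInt)
```

The None results disappear during flattening, leaving only the successfully parsed integers.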
4. Conclusion
To conclude, we can see how map() transforms a collection of length N into another collection of length N.
flatMap() performs the same initial transformation as map(), then runs flatten() on the output. Flattening removes the inner grouping of each element and produces a single flat sequence.