1. Overview

In this tutorial, we’ll learn about the subtle differences between map(), flatMap(), and flatten() in Kotlin.

2. map()

map() is an extension function in the Kotlin standard library, declared as:

public inline fun <T, R> Iterable<T>.map(transform: (T) -> R): List<R>

As shown above, this function iterates over the elements of an Iterable one by one. During this iteration, it transforms each element of type T into another element of type R. In the end, every element of the receiving collection has been transformed, and we end up with a List<R>.
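To isolate the type change, here's a minimal sketch with made-up values (the numbers and labels are illustrative only). The receiver is deliberately a Set: since map() is declared on Iterable<T>, the result is still a List<R>.

```kotlin
// map() turns each element of type T into exactly one element of type R.
// Here T = Int and R = String; the values are illustrative only.
val numbers: Set<Int> = setOf(1, 2, 3)

// Although the receiver is a Set, map() is declared on Iterable<T>,
// so the result is a List<String>.
val labels: List<String> = numbers.map { n -> "item-$n" }

// One-to-one: the output has exactly as many elements as the input.
val sameSize: Boolean = labels.size == numbers.size
```

Here labels is ["item-1", "item-2", "item-3"], in the Set's iteration order.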

This function is usually useful in one-to-one mapping situations. For example, let’s suppose each order consists of several order lines, one for each purchased item:

data class Order(val lines: List<OrderLine>)
data class OrderLine(val name: String, val price: Int)

Now, if we have an Order, we can use map() to find the name of each item:

val order = Order(
  listOf(OrderLine("Tomato", 2), OrderLine("Garlic", 3), OrderLine("Chives", 2))
)
val names = order.lines.map { it.name }

assertThat(names).containsExactly("Tomato", "Garlic", "Chives")

Here, we’re converting a List<OrderLine> to a List<String> by passing a simple transformation function. This function accepts each OrderLine as its input (the implicit it parameter) and converts it to a String. As another example, we can calculate the total price of an Order as follows:

val totalPrice = order.lines.map { it.price }.sum()
assertEquals(7, totalPrice)

Basically, map() is the equivalent of the following imperative coding style:

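// Inside map(): `this` is the receiving Iterable<T>, `transform` is the lambda
val result = mutableListOf<R>()
for (each in this) {
    result += transform(each) // apply the transformation to every element
}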

*When we’re using map(), we only have to write the transform part*. Creating the new collection, iterating over the elements, and adding each transformed element to that collection are boilerplate that map() hides as an implementation detail.
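As a runnable sketch, the loop above can be wrapped in a generic extension function and compared with map(). The name mapImperatively is hypothetical and exists only for this comparison; it is not part of the standard library.

```kotlin
data class OrderLine(val name: String, val price: Int)
data class Order(val lines: List<OrderLine>)

// Hypothetical hand-written equivalent of map(), for illustration only.
fun <T, R> Iterable<T>.mapImperatively(transform: (T) -> R): List<R> {
    val result = mutableListOf<R>()   // boilerplate: the new collection
    for (each in this) {              // boilerplate: the iteration
        result += transform(each)     // the only part map() asks us to write
    }
    return result
}

val order = Order(
    listOf(OrderLine("Tomato", 2), OrderLine("Garlic", 3), OrderLine("Chives", 2))
)

// Both calls use the same transformation; only the boilerplate differs.
val viaLoop: List<String> = order.lines.mapImperatively { it.name }
val viaMap: List<String> = order.lines.map { it.name }
```

Both viaLoop and viaMap hold ["Tomato", "Garlic", "Chives"].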

3. flatMap()

As opposed to map(), flatMap() is usually useful for flattening one-to-many relationships. Accordingly, its signature is:

public inline fun <T, R> Iterable<T>.flatMap(transform: (T) -> Iterable<R>): List<R>

As shown above, it transforms each element of type T into an Iterable<R>. However, instead of ending up with a List<Iterable<R>>, flatMap() flattens each Iterable into its individual elements. Therefore, we end up with a List<R>.
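A small sketch with plain strings makes the types visible. The comma-separated values are invented for illustration: the lambda returns a List<String> for each element, flatMap() concatenates those lists, and map() with the same lambda keeps the nesting.

```kotlin
// Illustrative input: each string holds one or more comma-separated values.
val csvRows = listOf("a,b", "c", "d,e,f")

// The lambda turns each String (T) into a List<String> (an Iterable<R>),
// and flatMap() concatenates those lists into one List<String>.
val cells: List<String> = csvRows.flatMap { row -> row.split(",") }

// With map(), the same lambda keeps the nesting: List<List<String>>.
val nested: List<List<String>> = csvRows.map { row -> row.split(",") }
```

Here cells is ["a", "b", "c", "d", "e", "f"], while nested is [["a", "b"], ["c"], ["d", "e", "f"]].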

As an example, let’s suppose we have a collection of orders, and we’re going to find all the distinct item names:

val orders = listOf(
  Order(listOf(OrderLine("Garlic", 1), OrderLine("Chives", 2))),
  Order(listOf(OrderLine("Tomato", 3), OrderLine("Garlic", 4))),
  Order(listOf(OrderLine("Potato", 5), OrderLine("Chives", 6))),
)

First, we need to convert the List<Order> into a List<OrderLine>. If we use map() here, we end up with a List<List<OrderLine>>, which isn’t what we want:

orders.map { it.lines } // List<List<OrderLine>>

Since we need to flatten each List<OrderLine> into individual OrderLine elements, we can use the flatMap() function:

val lines: List<OrderLine> = orders.flatMap { it.lines }
val names = lines.map { it.name }.distinct()
assertThat(names).containsExactlyInAnyOrder("Garlic", "Chives", "Tomato", "Potato")

As shown above, *the flatMap() function flattens the one-to-many relationship between each Order and its OrderLines*.
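The one-to-many relationship doesn’t have to be uniform. In this sketch (with made-up orders, one of them deliberately empty), each Order contributes however many lines it has, so the flattened size is the sum of the individual sizes and an empty Order adds nothing.

```kotlin
data class OrderLine(val name: String, val price: Int)
data class Order(val lines: List<OrderLine>)

// Illustrative orders with 2, 0, and 1 lines respectively.
val orders = listOf(
    Order(listOf(OrderLine("Garlic", 1), OrderLine("Chives", 2))),
    Order(emptyList()),                      // contributes nothing to the result
    Order(listOf(OrderLine("Tomato", 3))),
)

val allLines: List<OrderLine> = orders.flatMap { it.lines }

// The flattened size equals the sum of each order's line count: 2 + 0 + 1.
val expectedSize: Int = orders.map { it.lines.size }.sum()
```

Here allLines contains three order lines: Garlic, Chives, and Tomato.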

The imperative style equivalent of flatMap() is something like:

val result = mutableListOf<OrderLine>()
for (order in orders) {
    val transformedList = order.lines
    for (individual in transformedList) {
        result += individual
    }
}

Again, the collection initialization, iteration, and flattening are part of the hidden implementation details of flatMap(). All we have to do is provide the transformation function.
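Generalizing the loop above gives a runnable sketch of the same idea for any transformation. The function name flatMapImperatively is hypothetical, used only to compare its output with the real flatMap().

```kotlin
data class OrderLine(val name: String, val price: Int)
data class Order(val lines: List<OrderLine>)

// Hypothetical generic version of the nested loop, for illustration only.
fun <T, R> Iterable<T>.flatMapImperatively(transform: (T) -> Iterable<R>): List<R> {
    val result = mutableListOf<R>()
    for (element in this) {
        // The transformation produces a whole collection per element...
        for (individual in transform(element)) {
            // ...and each of its items is added to the flat result.
            result += individual
        }
    }
    return result
}

val orders = listOf(
    Order(listOf(OrderLine("Garlic", 1), OrderLine("Chives", 2))),
    Order(listOf(OrderLine("Tomato", 3), OrderLine("Garlic", 4))),
    Order(listOf(OrderLine("Potato", 5), OrderLine("Chives", 6))),
)

val viaLoop: List<OrderLine> = orders.flatMapImperatively { it.lines }
val viaFlatMap: List<OrderLine> = orders.flatMap { it.lines }
```

Both lists contain the same six order lines in the same order.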

4. flatten()

Kotlin also offers the flatten() function for nested Iterables and Arrays:

public fun <T> Iterable<Iterable<T>>.flatten(): List<T> {...}

As the name implies, flatten() converts nested Iterable objects into a single, flat List:

val orderLines = listOf(
    listOf(OrderLine("Garlic", 1), OrderLine("Chives", 2)),
    listOf(OrderLine("Tomato", 3), OrderLine("Garlic", 4)),
    listOf(OrderLine("Potato", 5), OrderLine("Chives", 6)),
)
   
val lines: List<OrderLine> = orderLines.flatten()
val expected = listOf(
    OrderLine("Garlic", 1),
    OrderLine("Chives", 2),
    OrderLine("Tomato", 3),
    OrderLine("Garlic", 4),
    OrderLine("Potato", 5),
    OrderLine("Chives", 6),
)
assertThat(lines).hasSize(6).isEqualTo(expected)
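Since flatten() is also available for nested arrays, a minimal sketch with illustrative integer values looks much the same. Note that the result is a List<T> rather than an Array<T>.

```kotlin
// Illustrative nested array: Array<Array<Int>>.
val grid: Array<Array<Int>> = arrayOf(
    arrayOf(1, 2, 3),
    arrayOf(4, 5),
    arrayOf(6),
)

// Array<out Array<out T>>.flatten() returns a List<T>, not an Array<T>.
val flat: List<Int> = grid.flatten()
```

Here flat is [1, 2, 3, 4, 5, 6].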

It’s important to note that, unlike map() and flatMap(), flatten() doesn’t apply any transformation. Its implementation makes this clear:

public fun <T> Iterable<Iterable<T>>.flatten(): List<T> {
    val result = ArrayList<T>()
    for (element in this) {
        result.addAll(element)
    }
    return result
}

As flatten()’s implementation shows, *it flattens one level of the nested Iterables and doesn’t accept any function or lambda parameter to perform transformations.*
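Because only one level is removed per call, deeper structures need repeated calls. In this sketch with arbitrary integers, a single flatten() on a List<List<List<Int>>> still leaves a List<List<Int>>:

```kotlin
// Three levels of nesting: List<List<List<Int>>> (values are arbitrary).
val deeplyNested: List<List<List<Int>>> = listOf(
    listOf(listOf(1, 2), listOf(3)),
    listOf(listOf(4)),
)

// One call to flatten() removes a single level, leaving List<List<Int>>.
val oneLevel: List<List<Int>> = deeplyNested.flatten()

// A second call is needed to reach a flat List<Int>.
val fullyFlat: List<Int> = deeplyNested.flatten().flatten()
```

Here oneLevel is [[1, 2], [3], [4]] and fullyFlat is [1, 2, 3, 4].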

5. map(), flatMap(), and flatten()

We’ve now seen map(), flatMap(), and flatten() in action through examples. In short:

  • map() – transformation only
  • flatten() – flattening only
  • flatMap() – transformation and flattening

So, the logic implemented by flatMap() can also be written with map() and flatten(). In other words: *flatMap() is equivalent to map() followed by flatten()*.

An example might help us quickly understand the differences and relations among these three functions:

val orders = listOf(
    Order(listOf(OrderLine("Garlic", 1), OrderLine("Chives", 2))),
    Order(listOf(OrderLine("Tomato", 3), OrderLine("Garlic", 4))),
    Order(listOf(OrderLine("Potato", 5), OrderLine("Chives", 6))),
)
 
val expected = listOf(
    OrderLine("Garlic", 1),
    OrderLine("Chives", 2),
    OrderLine("Tomato", 3),
    OrderLine("Garlic", 4),
    OrderLine("Potato", 5),
    OrderLine("Chives", 6),
)
 
val resultMapAndFlatten: List<OrderLine> = orders.map { it.lines }.flatten()
val resultFlatMap: List<OrderLine> = orders.flatMap { it.lines }
 
assertThat(resultFlatMap).isEqualTo(resultMapAndFlatten).hasSize(6).isEqualTo(expected)

The test above shows that orders.map { it.lines }.flatten() and orders.flatMap { it.lines } produce the same result. The difference is that flatMap() adds each transformed collection directly to the result, without first building the intermediate List<List<OrderLine>>.

6. Conclusion

In this article, we learned the differences between map(), flatMap(), and flatten() in Kotlin. To sum up, map() is usually useful for one-to-one mappings, while flatMap() is more useful for flattening one-to-many mappings. flatten() only flattens nested Iterable objects without any transformations.

As usual, all the examples are available over on GitHub.