Lambdas

Author: Justin Ji

People who know me might know that I am a lambda fan. Like, number one lambda glazer. Lambda this, lambda that, etc., which results in my code looking kind of weird.

Anyways, what is a lambda?

A lambda expression is a convenient way of defining an often anonymous function object. Typically, lambdas are used to write short functions that are passed into algorithms/other functions.

Okay, that's pretty abstract. Just think of them as functions that you can also treat like variables. As for the syntax, it looks like this:

auto func = [capture type](parameters) -> trailing_return {
    // your code...
};

Most of this should make enough sense. But what the hell is a capture type?

Our capture type (formally called the capture clause) lets us specify which variables from the enclosing local scope the lambda can use.

  • [] means that we don't capture anything
  • [=] means that we capture every local variable we use by value (the lambda gets its own copy)
  • [&] means that we capture every local variable we use by reference

In addition, we can also capture specific variables. However, we normally just choose one of the broader capture types.
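
To make the specific-variable case concrete, here is a minimal sketch that mixes both kinds of capture in one lambda. The function name capture_demo and the variable names are made up for illustration.

```cpp
#include <utility>

// Hypothetical helper for illustration.
// Returns {value returned by the last call, final value of total}.
std::pair<int, int> capture_demo() {
    int offset = 5;
    int total = 0;

    // offset is captured by value (copied when the lambda is created),
    // total is captured by reference (the lambda modifies the original).
    auto add = [offset, &total](int x) {
        total += x + offset;
        return total;
    };

    offset = 100;         // does not change the copy stored in the lambda
    add(1);               // total becomes 0 + 1 + 5 = 6
    int result = add(2);  // total becomes 6 + 2 + 5 = 13
    return {result, total};
}
```

Both members of the returned pair are 13 here: changing offset after the lambda is created has no effect, because the lambda holds its own copy, while total is shared with the enclosing scope.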

Let's consider sorting a vector of pairs in C++.

#include <algorithm>
#include <iostream>
#include <vector>
#include <utility>

int main() {
    int n;
    std::cin >> n;
    std::vector<std::pair<int, int>> arr(n);
    for (auto &[first, second] : arr) {
        std::cin >> first >> second;
    }

    /**
     * Consider sorting the array of pairs in descending
     * order of first value, and then tiebreak by sorting in
     * ascending order of second element.
     */
    for (const auto [first, second] : arr) {
        std::cout << first << " " << second << "\n";
    }
}

What would the easiest way of doing this be? A solution without lambdas would be to write a custom comparator function outside of main().

#include <algorithm>
#include <iostream>
#include <vector>
#include <utility>

bool comp(const std::pair<int, int> &a, const std::pair<int, int> &b) {
    if (a.first == b.first) {
        return a.second < b.second;
    }
    return a.first > b.first;
}

int main() {
    int n;
    std::cin >> n;
    std::vector<std::pair<int, int>> arr(n);
    for (auto &[first, second] : arr) {
        std::cin >> first >> second;
    }

    /**
     * Consider sorting the array of pairs in descending
     * order of first value, and then tiebreak by sorting in
     * ascending order of second element.
     */
    std::sort(std::begin(arr), std::end(arr), comp);
    for (const auto [first, second] : arr) {
        std::cout << first << " " << second << "\n";
    }
}

However, this is frankly quite verbose and pollutes our namespace just for a small function that we only need once. Hence, lambdas offer a cleaner alternative to writing a separate function.

#include <algorithm>
#include <iostream>
#include <vector>
#include <utility>

int main() {
    int n;
    std::cin >> n;
    std::vector<std::pair<int, int>> arr(n);
    for (auto &[first, second] : arr) {
        std::cin >> first >> second;
    }

    /**
     * Consider sorting the array of pairs in descending
     * order of first value, and then tiebreak by sorting in
     * ascending order of second element.
     */
    std::sort(std::begin(arr), std::end(arr),
        [](const std::pair<int, int> &a, const std::pair<int, int> &b) {
            if (a.first == b.first) {
                return a.second < b.second;
            }
            return a.first > b.first;
        });
    for (const auto [first, second] : arr) {
        std::cout << first << " " << second << "\n";
    }
}
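
If you want the comparator to be even shorter, one possible variation (not the only way to do it) is to compare tuples built with std::tie. This sketch wraps the sort in a helper called sort_pairs, which is a name I made up, and it uses a generic lambda, so it assumes C++14 or newer.

```cpp
#include <algorithm>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical helper: sorts by first descending, then by second ascending.
void sort_pairs(std::vector<std::pair<int, int>> &arr) {
    std::sort(arr.begin(), arr.end(), [](const auto &a, const auto &b) {
        // Tuples compare lexicographically. Swapping a.first and b.first
        // reverses the order of the first key only.
        return std::tie(b.first, a.second) < std::tie(a.first, b.second);
    });
}
```

For example, calling sort_pairs on {{1, 2}, {3, 1}, {3, 0}, {2, 5}} leaves the vector as {{3, 0}, {3, 1}, {2, 5}, {1, 2}}.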

Lambdas aren't limited to just writing anonymous functions though. You can actually assign them to a variable, allowing them to act like regular functions. The key difference that makes lambdas convenient is that they can capture variables from the scope they are defined in, making them very useful as small helper functions.
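
As a rough sketch of what a capturing helper can look like, here is a lambda that answers range-sum queries over a prefix sum array. The function range_sum_demo and its data are invented for this example. The point is only that the lambda can read prefix without it being passed as a parameter.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical example: sum of values[1..3] using a local helper lambda.
int range_sum_demo() {
    std::vector<int> values = {3, 1, 4, 1, 5};
    std::vector<int> prefix(values.size() + 1, 0);
    for (std::size_t i = 0; i < values.size(); i++) {
        prefix[i + 1] = prefix[i] + values[i];
    }

    // prefix is captured by reference, so callers only pass the range.
    auto range_sum = [&](int l, int r) {
        return prefix[r + 1] - prefix[l];
    };

    return range_sum(1, 3);  // 1 + 4 + 1 = 6
}
```

A regular function would need prefix as an extra parameter or as a global, which is exactly the kind of clutter the capture avoids.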

Just for the sake of demonstration, let's assume we want to make a function that calculates the least common multiple of two numbers. This is equivalent to the following expression:

\text{lcm}(a, b) = \frac{a \cdot b}{\gcd(a, b)}

Let's assume we needed to find the LCM of two numbers many times, and that we aren't using std::lcm for some reason. In that case, we could use a lambda to write our function.

#include <iostream>
#include <numeric>

int main() {
    auto lcm = [&](int a, int b) -> int {
        return a * b / std::gcd(a, b);
    };

    std::cout << lcm(10, 10) << "\n"; // prints 10
}

What if we wanted to write our own GCD as a recursive lambda? Sadly, there isn't a straightforward way to do this. The best way (in my opinion) is to use something called a generic lambda, which is a lambda with an auto parameter.

The rough idea is that a lambda can't access itself in its function definition. So, to fix this, we just pass the lambda into itself. Some code will probably help in understanding what I mean by this.

#include <iostream>

int main() {
    auto gcd = [](int a, int b, auto self) -> int {
        return b == 0 ? a : self(b, a % b, self);
    };

    std::cout << gcd(20, 30, gcd) << "\n"; // prints 10
}

There are other ways to do this. One way is to use the y_combinator (no, not the startup accelerator) template.

#include <functional>
#include <iostream>
#include <type_traits>
#include <utility>

// Kept out of namespace std: adding our own templates there is undefined behavior.
template <class Fun> class y_combinator_result {
    Fun fun_;

  public:
    template <class T>
    explicit y_combinator_result(T &&fun) : fun_(std::forward<T>(fun)) {}

    template <class... Args> decltype(auto) operator()(Args &&...args) {
        return fun_(std::ref(*this), std::forward<Args>(args)...);
    }
};

template <class Fun> decltype(auto) y_combinator(Fun &&fun) {
    return y_combinator_result<std::decay_t<Fun>>(std::forward<Fun>(fun));
}

int main() {
    std::cout << y_combinator([](auto gcd, int a, int b) -> int {
        return b == 0 ? a : gcd(b, a % b);
    })(20, 30) << "\n";  // prints 10
}

I personally don't like this method because it's verbose and doesn't have any real benefit over the method described previously.

Another method is to use std::function.

#include <iostream>
#include <functional>

int main() {
    std::function<int(int, int)> gcd = [&](int a, int b) {
        return b == 0 ? a : gcd(b, a % b);
    };
    
    std::cout << gcd(20, 30) << '\n';  // outputs 10
}

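This method, along with having no benefits compared to the first method, is also the slowest. This is because std::function relies on something called type erasure: every call goes through an extra layer of indirection that the compiler usually can't inline. I won't get into the details here.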

Example - Lambda DFS

Some people, like me, refuse to be sane and just use global functions. Well, keeping everything local has its niceties, but for competitive programming purposes we're considered wack. I think.

Anyways, this is how most people write DFS.

#include <vector>

const int MAX_N = 200000;  // assumed upper bound on the number of nodes
std::vector<int> adj[MAX_N];
bool vis[MAX_N];

void dfs(int u) {
    if (vis[u]) return;
    vis[u] = true;
    for (int v : adj[u]) {
        dfs(v);
    }
}

However, with lambda DFS, you can write it this way.

int main() {
    // ... assume we already read our graph 
    std::vector<std::vector<int>> adj(n);
    std::vector<bool> vis(n);

    const auto dfs = [&](int u, int p, auto &&self) -> void {
        if (vis[u]) return;
        vis[u] = true;
        for (int v : adj[u]) {
            self(v, u, self);
        }
    };

    // dfs(0, -1, dfs);
}

It actually looks pretty similar to the global version. However, it's a bit more verbose, since the lambda has to be passed into itself on every call. This probably doesn't matter if you are a global variable user, but it is nice if you prefer to keep everything local.
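
To show the whole thing working end to end, here is a small sketch that uses a lambda DFS to count connected components. The helper name count_components and the edge-list input format are my own assumptions for the example, not something you have to follow. It needs C++17 for the structured bindings.

```cpp
#include <utility>
#include <vector>

// Hypothetical helper: counts connected components of an undirected graph
// with n nodes (0-indexed), given as a list of edges.
int count_components(int n, const std::vector<std::pair<int, int>> &edges) {
    std::vector<std::vector<int>> adj(n);
    for (const auto &[a, b] : edges) {
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    std::vector<bool> vis(n, false);
    // self is the lambda itself, passed in so that it can recurse.
    const auto dfs = [&](int u, auto &&self) -> void {
        if (vis[u]) return;
        vis[u] = true;
        for (int v : adj[u]) {
            self(v, self);
        }
    };

    int components = 0;
    for (int u = 0; u < n; u++) {
        if (!vis[u]) {
            components++;  // u starts a component we have not seen yet
            dfs(u, dfs);
        }
    }
    return components;
}
```

For a graph with 5 nodes and edges {0, 1}, {1, 2}, {3, 4}, this returns 2. Notice that adj and vis never leave the function, which is the whole appeal of keeping the DFS local.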