Classically polymorphic visit replaces some uses of dynamic_cast

I’ve informally described this C++ utility to several people by now; it’s time I just wrote it up in a blog post. This utility lets you “visit” a polymorphic base object as if it were a variant, by specifying the domain of possible child types at the point of the call to visit. my::visit<X,Y,Z>(b, f) returns f(static_cast<X&>(b)) if the dynamic type of b is exactly X; f(static_cast<Y&>(b)) if the dynamic type of b is exactly Y; and so on. It perfectly forwards the argument when making the call to f.

In other words, it does the same thing as std::visit, but instead of operating on a std::variant<X,Y,Z>, it operates on a Base& which the programmer promises will refer to an X, Y, or Z at runtime.

A three-argument overload permits the programmer to specify a custom behavior e(b) when the dynamic type of b isn’t exactly X, Y, or Z. The two-argument overload simply throws std::bad_cast in that case, exactly as if we’d failed a dynamic_cast.

Sample use-case

class Connection { ~~~ };
class TCPConnection final : public Connection { ~~~ };
class UDPConnection final : public Connection { ~~~ };

void handle(TCPConnection&);
void handle(UDPConnection&);

void frotz(Connection& conn) {
    if (auto *tcp = dynamic_cast<TCPConnection *>(&conn)) {
        handle(*tcp);
    } else if (auto *udp = dynamic_cast<UDPConnection *>(&conn)) {
        handle(*udp);
    } else {
        throw std::runtime_error("we don't expect this");
    }
}

Given the above class hierarchy, we could rewrite frotz as follows:

void frotz(Connection& conn) {
    my::visit<TCPConnection, UDPConnection>(conn, [](auto& conn) {
        handle(conn);
    });
}

Inside the body of the lambda, conn has the correct static type: the body of the lambda will be instantiated for both TCPConnection& and UDPConnection&. This is the “magic” that lets it call the correct overload of handle in each case.

Here’s another example, using the same class hierarchy, demonstrating the three-argument overload:

bool isTCPorUDP(const Connection& conn) {
    return (dynamic_cast<const TCPConnection *>(&conn) != nullptr)
        || (dynamic_cast<const UDPConnection *>(&conn) != nullptr);
}

Rewriting gives us more lines of code, but two fewer dynamic_casts:

bool isTCPorUDP(const Connection& conn) {
    return my::visit<TCPConnection, UDPConnection>(
        conn,
        [](auto&&) { return true; },  // when TCP or UDP
        [](auto&&) { return false; }  // when something else
    );
}

If you saw my talk “dynamic_cast From Scratch” (CppCon 2017), you may observe that the question “Can conn be dynamic-cast to TCPConnection?” is not synonymous with the question “Is the dynamic type of conn exactly TCPConnection?” True! But a lot of real-world code has shallow hierarchies (no grandchild classes), and often when you see “dynamic_cast” in real-world code, what it really means is “compare typeids”; the programmer was just too lazy to spell it out. For this kind of codebase, switching those wasteful dynamic_casts to visit can be a performance win!
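To make the difference between those two questions concrete, here is a minimal sketch. It assumes a hypothetical grandchild class TLSConnection, which means TCPConnection can't be final as it was in the sample hierarchy above; the helper names castsToTCP, isExactlyTCP, and checkQuestionsDiffer are likewise made up for illustration.

```cpp
#include <cassert>
#include <typeinfo>

// Hypothetical hierarchy for illustration. Unlike the sample use-case,
// TCPConnection is not `final` here, so that a grandchild can exist.
struct Connection { virtual ~Connection() = default; };
struct TCPConnection : Connection {};
struct TLSConnection : TCPConnection {};  // assumed grandchild class

// "Can conn be dynamic-cast to TCPConnection?"
bool castsToTCP(const Connection& conn) {
    return dynamic_cast<const TCPConnection *>(&conn) != nullptr;
}

// "Is the dynamic type of conn exactly TCPConnection?"
bool isExactlyTCP(const Connection& conn) {
    return typeid(conn) == typeid(TCPConnection);
}

// The two questions agree for a TCPConnection but not for its child.
void checkQuestionsDiffer() {
    TCPConnection tcp;
    TLSConnection tls;
    assert(castsToTCP(tcp) && isExactlyTCP(tcp));
    assert(castsToTCP(tls) && !isExactlyTCP(tls));
}
```

In a hierarchy like this one, a visit over <TCPConnection> would send a TLSConnection to the error path, so the typeid-based rewrite is only a drop-in replacement when the hierarchy really is shallow.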

I must also point out that according to the classically polymorphic ideal, both handle() and isTCPorUDP() should simply be virtual member functions of Connection; using dynamic_cast to sniff at the dynamic type of conn is obviously an antipattern. But, again, a lot of real-world code ends up doing it anyway. So we might as well let them do it conveniently and efficiently.

The usual programming caveat applies: If it hurts when you do this, don’t do this.

Implementation

I’ll present the implementation backwards, starting with the two-argument overload that throws bad_cast on failure. Notice that throughout I’m using the namespace my:: where presumably you’d be using something else.

template<class... DerivedClasses, class Base, class F>
auto visit(Base&& b, const F& f)
{
    return my::visit<DerivedClasses...>(
        static_cast<Base&&>(b), f, [](Base&&) {
            throw std::bad_cast();
        }
    );
}

Okay, now for the three-argument entry point. It has two callbacks: not just the generic lambda f, but also an “error” callback e. It just packages up the deduced parameters Base,F,E and passes them off to a struct visit_impl with a static call method:

template<class... DerivedClasses, class Base, class F, class E>
auto visit(Base&& b, const F& f, const E& e)
{
    static_assert(!std::is_pointer_v<std::remove_reference_t<Base>>,
        "Argument must be (reference-to) class type, not pointer type");
    static_assert(std::is_polymorphic_v<std::remove_reference_t<Base>>,
        "Argument must be of polymorphic class type");
    static_assert((std::is_polymorphic_v<DerivedClasses> && ...),
        "Template arguments must all be polymorphic class types");

    return my::visit_impl<Base, F, E>::template call<DerivedClasses...>(
        static_cast<Base&&>(b), f, e
    );
}

Inside visit_impl::call, we’ll do the usual “recursive template” trick, where call<X,Y,Z> calls call<Y,Z> calls call<Z> calls call<>. We must add a few wrinkles at the call<Z> level to deal with the fact that different e(b) callbacks might either return a value (as in our isTCPorUDP example above) or exit without returning (as in the convenience visit above, where e throws bad_cast but claims to return void). Standard C++ doesn’t have any way to mark a function as “not returning.” (There’s [[noreturn]], but that’s just an optimization hint that doesn’t interact with the type system. See “The Ignorable Attributes Rule” (2018-05-15).) So instead, we’ll just assume that if e(b) claims to return void and f((Z&)b) doesn’t, then e(b) must really not return at all.

By the way, notice that in order for decltype(my::visit<X,Y,Z>(b, f, e)) to make sense, the same type T must be returned from f((X&)b), f((Y&)b), and f((Z&)b). Unless it throws, e(b) also must return that same type T.
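The void-means-throws assumption and the same-type requirement can be seen in isolation, without any of the type-matching machinery. The following is only a sketch: call_or_else and checkCallOrElse are hypothetical names, the bool parameter stands in for the typeid comparison, and std::logic_error stands in for the "unreachable" throw.

```cpp
#include <cassert>
#include <stdexcept>
#include <type_traits>

// Hypothetical helper: returns f() when ok is true, else falls back to e().
template<class F, class E>
auto call_or_else(bool ok, const F& f, const E& e) {
    using T = decltype(f());
    using ErrorT = decltype(e());
    if (ok) {
        return f();  // the deduced return type is T
    } else if constexpr (std::is_void_v<ErrorT> && !std::is_void_v<T>) {
        e();  // claims to return void, so we assume it exits by throwing
        throw std::logic_error("error callback returned");  // expected unreachable
    } else {
        return e();  // must have the same type T, or deduction fails
    }
}

void checkCallOrElse() {
    auto f = [] { return 42; };
    auto thrower = [] { throw std::runtime_error("unexpected"); };
    auto fallback = [] { return 7; };
    assert(call_or_else(true, f, thrower) == 42);   // void e, int f: still returns int
    assert(call_or_else(false, f, fallback) == 7);  // e returns the same type as f
    bool threw = false;
    try {
        call_or_else(false, f, thrower);
    } catch (const std::runtime_error&) {
        threw = true;
    }
    assert(threw);
}
```

Because the return type is deduced, a fallback that returned, say, double instead of int would make the two return statements disagree and the function would fail to compile.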

The final oddity is that although we want visit_impl<B,F,E>::call<X,Y,Z>(b,f,e) to do basically the same thing as visit_impl<B,F,E>::call<X>(b,f,e) plus a recursive case, we need visit_impl<B,F,E>::call<>(b,f,e) to do something completely different — namely, call e(b) unconditionally. To create a function template that is callable as call<>(b,f,e) but doesn’t ambiguate either of our other call signatures, we make call<> a template with a non-type parameter pack. I could have used <int=0> instead of <int...>, but I like the mangled names better this way.
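Here is the overloading trick on its own, as a small sketch. The names dispatcher and count are hypothetical, and the functions return counts rather than invoking callbacks, so that the overload selected can be checked at compile time.

```cpp
struct dispatcher {  // hypothetical stand-in for visit_impl
    // Viable only for an empty template argument list: a type argument
    // can't match a non-type parameter, so call<X>() never selects this.
    template<int...>
    static constexpr int call() { return 0; }

    // Viable only when at least one type argument is given, since
    // First can't be deduced from an empty argument list.
    template<class First, class... Rest>
    static constexpr int call() {
        return 1 + static_cast<int>(sizeof...(Rest));
    }
};

// Forwarding a possibly empty pack, as the three-argument visit does.
template<class... Ts>
constexpr int count() { return dispatcher::call<Ts...>(); }

static_assert(dispatcher::call<>() == 0);
static_assert(dispatcher::call<char>() == 1);
static_assert(dispatcher::call<char, long, double>() == 3);
static_assert(count<>() == 0);
static_assert(count<char, long>() == 2);
```

Each explicit template argument list is viable for exactly one of the two templates, so overload resolution never has to choose between them.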

template<class Base, class F, class E>
struct visit_impl {

    template<int...>  // this is call<>(b, f, e)
    static auto call(Base&& b, const F&, const E& e) {
        return e(static_cast<Base&&>(b));
    }

    template<class DerivedClass, class... Rest>
    static auto call(Base&& b, const F& f, const E& e) {
        using Derived = my::match_cvref_t<Base, DerivedClass>;
        using T = decltype(f(static_cast<Derived&&>(b)));
        using ErrorT = decltype(e(static_cast<Base&&>(b)));
        if (typeid(b) == typeid(DerivedClass)) {
            return f(static_cast<Derived&&>(b));
        } else if constexpr (sizeof...(Rest) != 0) {
            return call<Rest...>(static_cast<Base&&>(b), f, e);
        } else if constexpr (std::is_void_v<ErrorT> && !std::is_void_v<T>) {
            // If e(b) has type void, assume it exits by throwing.
            e(static_cast<Base&&>(b));
            throw std::bad_cast();  // we expect this to be unreachable
        } else {
            return e(static_cast<Base&&>(b));
        }
    }
};

Notice the ladder of if/else if constexpr/else if constexpr/else. I’d originally written it with more nested scopes for else blocks, but then I realized that I could do it this way and save some tabstops.

The utility my::match_cvref_t<T, U> simply takes the cvref-qualifiers from T and applies them to the object type U; for example, match_cvref_t<const int&, double> is const double&. The implementation is just a bunch of partial specializations; I won’t bore you with it here.
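For readers who do want to see it, one plausible shape for that utility is sketched below. This is an assumption about how it could be written, not necessarily the code behind the Godbolt link; only the name match_cvref_t and its described behavior come from the post, and the helper struct match_cvref is made up.

```cpp
#include <type_traits>

namespace my {

// Primary template: From has no qualifiers, so To is left unchanged.
template<class From, class To>
struct match_cvref { using type = To; };

// cv-qualifiers: copy them from From onto To.
template<class From, class To>
struct match_cvref<const From, To> { using type = const To; };
template<class From, class To>
struct match_cvref<volatile From, To> { using type = volatile To; };
template<class From, class To>
struct match_cvref<const volatile From, To> { using type = const volatile To; };

// References: strip the reference, match cv-qualifiers, then reapply it.
template<class From, class To>
struct match_cvref<From&, To> {
    using type = typename match_cvref<From, To>::type&;
};
template<class From, class To>
struct match_cvref<From&&, To> {
    using type = typename match_cvref<From, To>::type&&;
};

template<class From, class To>
using match_cvref_t = typename match_cvref<From, To>::type;

}  // namespace my

static_assert(std::is_same_v<my::match_cvref_t<const int&, double>, const double&>);
static_assert(std::is_same_v<my::match_cvref_t<int, double>, double>);
static_assert(std::is_same_v<my::match_cvref_t<volatile int&&, double>, volatile double&&>);
```

With this definition, when Base is deduced as const Connection& the alias Derived becomes const TCPConnection&, and Derived&& collapses back to that same lvalue reference, which is what the static_cast in call relies on.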

You can see the complete code for my::visit, in the proper order, on Godbolt here.

Posted 2020-09-29