29 January 2013

Reimplementation of dynamic_cast in C++

This article introduces one way to replace the functionality of dynamic_cast efficiently. The presented source code uses the latest C++ specification (C++11), but with a few limitations it can be rewritten for C++03 without much trouble.

In C++, a base class can be cast to a derived class in two ways: with static_cast, if we are lucky and the cast can be resolved at compile time, or with dynamic_cast, if we are out of luck and have to rely on the built-in RTTI to perform the cast at runtime. Casting from a virtual base class is one such unlucky case, because static_cast cannot be used there at all.

Many programmers claim that needing dynamic_cast is a sign of bad application design. I slightly disagree, because there are several cases in which dynamic_cast is very useful. Imagine that you are working on a logging function designed to store information about individual events. Every event class can inherit from a common base class, say, Event. Thanks to dynamic_cast we can get further information about an event: we cast it from the base class Event to a particular event class and then query the details specific to that class.


In the introduction I mentioned casting a virtual base class to a derived class. Many people avoid virtual inheritance in C++ as much as they can, but it is useful when you are writing a framework that is rich in interfaces. In my experience, it pays off to follow the rule that each interface (or abstract class, if you prefer) that builds on the properties and/or methods of another interface inherits from it virtually. It is also good practice to have a common base class for all other classes, including the interfaces, such as Object. Many well-known C++ projects are actually designed that way.

Virtual inheritance exists to avoid creating multiple instances of a base class. Consider a project with the classes Object, InputStream, OutputStream and IOStream, where InputStream and OutputStream both inherit from Object and IOStream inherits from both of them to combine their methods. If InputStream and OutputStream did not inherit from Object virtually, every IOStream object would contain two instances of the Object class. Then, if we tried to call GetHandle() (a method of Object) on an IOStream object, the compiler would report an error, because the call is ambiguous: the method is present in both Object instances. With virtual inheritance, each IOStream object contains only one Object instance, shared by its InputStream and OutputStream parts. This arrangement is known as diamond inheritance.

Figure: example without virtual inheritance

Figure: example with virtual inheritance

As already stated, the problem occurs when you need to cast a virtual base class to one of its derived classes, that is, Object to any other class. Because of the virtual inheritance, static_cast cannot do this. The problem is easily solved with dynamic_cast, but for those who want to avoid dynamic_cast for whatever reason, I will present an alternative way to reimplement its behavior, provided that certain rules are followed.

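The following picture shows how an instance of the IOStream class might be laid out in memory. It makes clear that there is no easy way to cast from the Object class to a derived class: the position of the shared Object instance depends on the most-derived class, so the required pointer adjustment is not known at compile time.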

The first of these rules is that there must be a fundamental base class from which all the classes we want to apply the dynamic cast to inherit; in our case, the Object class. Its main purpose is to provide information about the derived class (its name) and a pair of To(const std::string&) methods, one const and one non-const, which perform the cast based on the class name. The heart of the whole principle is a lookup table stored in the Object class (a std::map, so strictly speaking a sorted associative container rather than a hash table). Its task is to map each class name to a pointer to the part of the object that corresponds to that class. The table must be marked as mutable, so that it can be modified from const methods.
class Object
{
    public:
        virtual ~Object();

        static const std::string& GetClassNameStatic();
        virtual const std::string& GetClassName() const;
    
        void* To(const std::string& className);
        const void* To(const std::string& className) const;

    private:
        mutable std::map<std::string, void*> _bases;
};
This declaration alone is obviously not sufficient. We also need a method for registering a derived class and a method that recursively crawls the whole inheritance hierarchy. I split the source code into two parts to make it easier to understand; here is the complete class:
class Object
{
    public:
        virtual ~Object() { }

        static const std::string& GetClassNameStatic()
        {
            static std::string className("Object");
            return className;
        }

        virtual const std::string& GetClassName() const
        {
            return GetClassNameStatic();
        }
    
        void* To(const std::string& className)
        {
            if (_bases.size() == 0)
            {
                RegisterAllSubclasses();
            }

            auto result = _bases.find(className);
            return (result != _bases.end() ? (*result).second : nullptr);
        }

        const void* To(const std::string& className) const
        {
            if (_bases.size() == 0)
            {
                RegisterAllSubclasses();
            }

            auto result = _bases.find(className);
            return (result != _bases.end() ? (*result).second : nullptr);
        }

    protected:
        void RegisterSubclass(const void* ptr, const std::string& className) const
        {
            _bases[className] = const_cast<void*>(ptr);
        }

        virtual void RegisterAllSubclasses() const
        {
            RegisterSubclass(static_cast<const void*>(this), Object::GetClassName());
        }

    private:
        mutable std::map<std::string, void*> _bases;
};
The principle should now be clear. The inheritance table is not built until the To() method is called for the first time, so the table is never created unless it is actually needed. The check if (_bases.size() == 0) is safe, because registration always inserts at least the Object entry, so an empty map can only mean that registration has not run yet.

Each derived class overrides the virtual methods RegisterAllSubclasses() and GetClassName() and provides its own definition of the static method GetClassNameStatic(), so that it registers itself and its bases. The importance of redefining the static method will be explained later. The InputStream class could then look like this:
class InputStream : virtual public Object
{
    public:
        static const std::string& GetClassNameStatic()
        {
            static std::string className("InputStream");
            return className;
        }

        virtual const std::string& GetClassName() const override
        {
            return GetClassNameStatic();
        }

        void Read();

    protected:
        virtual void RegisterAllSubclasses() const override
        {
            RegisterSubclass(static_cast<const void*>(this), InputStream::GetClassName());
            Object::RegisterAllSubclasses();
        }
};
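It is important to qualify the class names explicitly in the RegisterAllSubclasses() method. Without the qualification, GetClassName() would be dispatched virtually to the override of the most-derived class: because IOStream inherits from InputStream, GetClassName() called inside InputStream::RegisterAllSubclasses() on an IOStream object would return the IOStream class name. The method then calls RegisterAllSubclasses() on all its base classes, which in turn register their own bases recursively. These calls must be qualified as well; otherwise virtual dispatch would call the most-derived override again and the recursion would never end. The RegisterAllSubclasses() method in the IOStream class then looks as follows: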
        virtual void RegisterAllSubclasses() const override
        {
            RegisterSubclass(static_cast<const void*>(this), IOStream::GetClassName());
            InputStream::RegisterAllSubclasses();
            OutputStream::RegisterAllSubclasses();
        }
Writing these methods by hand in every class is, of course, inconvenient, so it is better to leave the job to a macro. The macro takes the name of the current class followed by a list of all its base classes (a variadic macro). The actual registration of the base classes is a bit tricky, since it uses the variadic templates introduced in C++11.
#define DEFINE_BASES(class, ...)                                                \
    static const std::string& GetClassNameStatic()                              \
    {                                                                           \
        static std::string className(#class);                                   \
        return className;                                                       \
    }                                                                           \
                                                                                \
    const std::string& GetClassName() const override                            \
    {                                                                           \
        return GetClassNameStatic();                                            \
    }                                                                           \
                                                                                \
    template <typename _empty>                                                  \
    void RegisterAllSubclasses() const                                          \
    {                                                                           \
                                                                                \
    }                                                                           \
                                                                                \
    template <typename _empty, typename T, typename... Args>                    \
    void RegisterAllSubclasses() const                                          \
    {                                                                           \
        T::RegisterAllSubclasses();                                             \
        RegisterAllSubclasses<void, Args...>();                                 \
    }                                                                           \
                                                                                \
    virtual void RegisterAllSubclasses() const override                         \
    {                                                                           \
        RegisterSubclass(static_cast<const void*>(this), class::GetClassName());\
        RegisterAllSubclasses<void, __VA_ARGS__>();                             \
    }
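The principle is simple: the base classes are registered recursively. Once all of them have been processed, the Args pack is empty and the recursion terminates by calling the "blank" method template <typename _empty> void RegisterAllSubclasses(). The dummy _empty parameter is needed because a function template cannot have an empty template parameter list; with it, the call RegisterAllSubclasses<void>() for an empty pack matches only the blank template, since the recursive one cannot deduce T.

For the case above, the IOStream class processes the call as follows. IOStream::RegisterAllSubclasses() registers "IOStream" and calls RegisterAllSubclasses<void, InputStream, OutputStream>(). That calls InputStream::RegisterAllSubclasses(), which registers "InputStream" and, through Object::RegisterAllSubclasses(), "Object", and then calls RegisterAllSubclasses<void, OutputStream>(). That in turn calls OutputStream::RegisterAllSubclasses(), which registers "OutputStream" and "Object" again (harmlessly overwriting the entry with the same pointer), and finally RegisterAllSubclasses<void>(), which does nothing and ends the recursion.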

Finally, only one thing remains: the dynamic cast function itself. This is where the static method that returns the class name pays off. The function obtains the name of the target class, passed as the template parameter, through its static function GetClassNameStatic(), and passes it to the To(const std::string&) method. That method looks up the pointer to the requested class in the map and returns it if it is found, or nullptr otherwise, just as dynamic_cast does for pointers.
template <typename T>
T my_dynamic_cast(Object* ptr)
{
    return static_cast<T>(ptr->To(std::remove_pointer<T>::type::GetClassNameStatic()));
}

template <typename T>
T my_dynamic_cast(const Object* ptr)
{
    return static_cast<T>(ptr->To(std::remove_pointer<T>::type::GetClassNameStatic()));
}
Some of you may know the remove_pointer<T> trait from the Boost library; in C++11 it is part of the standard library as std::remove_pointer<T>, in the <type_traits> header. The reason for using it is simple: the function needs the class type itself in order to call its static GetClassNameStatic(), but the template parameter is a pointer to that class.

In conclusion, there are many ways to extend this project. Apart from speeding up the casting by indexing based on a hash of the class name (which can make this alternative up to 10x faster than the standard dynamic_cast), the project could be extended into a complete reimplementation of RTTI.

The complete source code:
#include <cstdio>
#include <string>
#include <map>
#include <type_traits>

#define DEFINE_BASES(class, ...)                                                \
    static const std::string& GetClassNameStatic()                              \
    {                                                                           \
        static std::string className(#class);                                   \
        return className;                                                       \
    }                                                                           \
                                                                                \
    const std::string& GetClassName() const override                            \
    {                                                                           \
        return GetClassNameStatic();                                            \
    }                                                                           \
                                                                                \
    template <typename _empty>                                                  \
    void RegisterAllSubclasses() const                                          \
    {                                                                           \
                                                                                \
    }                                                                           \
                                                                                \
    template <typename _empty, typename T, typename... Args>                    \
    void RegisterAllSubclasses() const                                          \
    {                                                                           \
        T::RegisterAllSubclasses();                                             \
        RegisterAllSubclasses<void, Args...>();                                 \
    }                                                                           \
                                                                                \
    virtual void RegisterAllSubclasses() const override                         \
    {                                                                           \
        RegisterSubclass(static_cast<const void*>(this), class::GetClassName());\
        RegisterAllSubclasses<void, __VA_ARGS__>();                             \
    }

class Object
{
    public:
        virtual ~Object() { }

        static const std::string& GetClassNameStatic()
        {
            static std::string className("Object");
            return className;
        }

        virtual const std::string& GetClassName() const
        {
            return GetClassNameStatic();
        }

        void* To(const std::string& className) 
        {
            if (_bases.size() == 0)
            {
                RegisterAllSubclasses();
            }

            auto result = _bases.find(className);
            return (result != _bases.end() ? (*result).second : nullptr);
        }

        const void* To(const std::string& className) const
        {
            if (_bases.size() == 0)
            {
                RegisterAllSubclasses();
            }

            auto result = _bases.find(className);
            return (result != _bases.end() ? (*result).second : nullptr);
        }

    protected:
        void RegisterSubclass(const void* ptr, const std::string& className) const
        {
            _bases[className] = const_cast<void*>(ptr);
        }

        virtual void RegisterAllSubclasses() const
        {
            RegisterSubclass(static_cast<const void*>(this), Object::GetClassName());
        }

    private:
        mutable std::map<std::string, void*> _bases;
};

////////////////////////////////////////////////////////////////////////////////
 
template <typename T>
T my_dynamic_cast(Object* ptr)
{
    return static_cast<T>(ptr->To(std::remove_pointer<T>::type::GetClassNameStatic()));
}

template <typename T>
T my_dynamic_cast(const Object* ptr)
{
    return static_cast<T>(ptr->To(std::remove_pointer<T>::type::GetClassNameStatic()));
}
 
////////////////////////////////////////////////////////////////////////////////
 
class InputStream : virtual public Object
{
    public:
        DEFINE_BASES(InputStream, Object);
        void Read() { }
};
 
class OutputStream : virtual public Object
{
    public:
        DEFINE_BASES(OutputStream, Object);
        void Write() { }
};
 
class IOStream : public InputStream, public OutputStream
{
    int _value;

    public:
        DEFINE_BASES(IOStream, InputStream, OutputStream);
        IOStream() : _value(0) { }

        int GetValue() const { return _value; }
        void SetValue(int value) { _value = value; }
};
 
int main()
{
    const Object*   co = new IOStream;
    const IOStream* cd = my_dynamic_cast<const IOStream*>(co);
    
    Object*   o = new IOStream;
    IOStream* d = my_dynamic_cast<IOStream*>(o);

    d->SetValue(42);
    
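    // %p expects void*; co and cd may hold different addresses because Object is a virtual base.
    printf("const:     %i, %p, %p\n", cd->GetValue(), static_cast<const void*>(co), static_cast<const void*>(cd));
    printf("non-const: %i, %p, %p\n", d->GetValue(), static_cast<void*>(o), static_cast<void*>(d));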
    
    delete cd;
    delete d;

    return 0;
}

1 comment:

  1. Interesting concept. Did you compare the performance between your mechanism and dynamic_cast?

    From what I understand, your mechanism adds overhead during instantiation. For a tall class hierarchy, this will take more time. You even use std::map, which costs about O(log n) per lookup and needs additional memory.

    Compared to traditional RTTI, which uses class structures and a pointer in the vtable to the typeid object, I think RTTI will give better performance, with almost zero overhead at class instantiation compared to the method you propose.

    But I think there can be some optimization that you can do such as using linked-list data structure as opposed to map.

    Cheers
