aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/utility/PolymorphicDowncast.hpp
blob: 76b00fa888145c7bbd1d1a2d1a70de672b27ed4a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "Assert.hpp"

#include <armnn/Exceptions.hpp>

#include <memory>
#include <type_traits>

namespace armnn
{

// Selects how a failed polymorphic-cast check is reported:
//  - ARMNN_POLYMORPHIC_CAST_TESTABLE defined: the condition is passed to ConditionalThrow<std::bad_cast>,
//    i.e. when testing an exception is thrown instead of asserting.
//  - otherwise: the condition is passed to the regular ARMNN_ASSERT.
// NOTE(review): std::bad_cast is declared in <typeinfo>, which is not among the includes visible in this
// header - confirm it is reached through one of the other includes.
#if defined(ARMNN_POLYMORPHIC_CAST_TESTABLE)
#   define ARMNN_POLYMORPHIC_CAST_CHECK_METHOD(cond) ConditionalThrow<std::bad_cast>(cond)
#else
#   define ARMNN_POLYMORPHIC_CAST_CHECK_METHOD(cond) ARMNN_ASSERT(cond)
#endif

// Only check the condition in debug builds (NDEBUG not defined) or during testing
// (ARMNN_POLYMORPHIC_CAST_TESTABLE defined). In every other configuration the macro expands to nothing,
// so the condition expression is not even compiled.
#if !defined(NDEBUG) || defined(ARMNN_POLYMORPHIC_CAST_TESTABLE)
#   define ARMNN_POLYMORPHIC_CAST_CHECK(cond)  ARMNN_POLYMORPHIC_CAST_CHECK_METHOD(cond)
#else
#   define ARMNN_POLYMORPHIC_CAST_CHECK(cond) // release builds don't check the cast
#endif


namespace utility
{
/// StaticPointerCast overload for std::shared_ptr.
///
/// Thin wrapper around std::static_pointer_cast; no run-time type check is performed.
///
/// \tparam T1  Element type of the returned shared pointer
/// \tparam T2  Element type of the shared pointer passed in
/// \param sp   Shared pointer to convert
/// \return     Result of std::static_pointer_cast<T1>(sp)
template <class T1, class T2>
std::shared_ptr<T1> StaticPointerCast (const std::shared_ptr<T2>& sp)
{
    std::shared_ptr<T1> converted = std::static_pointer_cast<T1>(sp);
    return converted;
}

/// DynamicPointerCast overload for std::shared_ptr.
///
/// Thin wrapper around std::dynamic_pointer_cast.
///
/// \tparam T1  Element type of the returned shared pointer
/// \tparam T2  Element type of the shared pointer passed in
/// \param sp   Shared pointer to convert
/// \return     Result of std::dynamic_pointer_cast<T1>(sp)
template <class T1, class T2>
std::shared_ptr<T1> DynamicPointerCast (const std::shared_ptr<T2>& sp)
{
    std::shared_ptr<T1> converted = std::dynamic_pointer_cast<T1>(sp);
    return converted;
}

/// StaticPointerCast overload for raw (built-in) pointers.
///
/// Plain static_cast to T1*; no run-time type check is performed.
///
/// \tparam T1  Pointee type of the returned pointer
/// \tparam T2  Pointee type of the pointer passed in
/// \param ptr  Pointer to convert
/// \return     ptr converted to T1* with static_cast
template<class T1, class T2>
inline T1* StaticPointerCast(T2* ptr)
{
    T1* const converted = static_cast<T1*>(ptr);
    return converted;
}

/// DynamicPointerCast overload for raw (built-in) pointers.
///
/// Plain dynamic_cast to T1*.
///
/// \tparam T1  Pointee type of the returned pointer
/// \tparam T2  Pointee type of the pointer passed in
/// \param ptr  Pointer to convert
/// \return     ptr converted to T1* with dynamic_cast
template<class T1, class T2>
inline T1* DynamicPointerCast(T2* ptr)
{
    T1* const converted = dynamic_cast<T1*>(ptr);
    return converted;
}

} // namespace utility

/// Polymorphic downcast for built-in pointers only
///
/// The conversion itself is a static_cast. When ARMNN_POLYMORPHIC_CAST_CHECK is active
/// (debug builds or testing) it is additionally checked that dynamic_cast yields the
/// same pointer as the one passed in.
///
/// Usage: Child* pChild = PolymorphicDowncast<Child*>(pBase);
///
/// \tparam DestType    Pointer type to the target object (Child pointer type);
///                     enforced to be a pointer type at compile time
/// \tparam SourceType  Type of the source object (Base type), deduced from value
/// \param value        Pointer to the source object
/// \return             Pointer of type DestType (Pointer of type child)
template<typename DestType, typename SourceType>
DestType PolymorphicDowncast(SourceType* value)
{
    static_assert(std::is_pointer<DestType>::value,
                  "PolymorphicDowncast only works with pointer types.");

    // Checked builds only: the dynamic_cast result must match the source pointer.
    ARMNN_POLYMORPHIC_CAST_CHECK(dynamic_cast<DestType>(value) == value);

    DestType const downcast = static_cast<DestType>(value);
    return downcast;
}


/// Polymorphic downcast for shared pointers and built-in pointers
///
/// The conversion itself is done by utility::StaticPointerCast. When
/// ARMNN_POLYMORPHIC_CAST_CHECK is active (debug builds or testing) it is additionally
/// checked that utility::DynamicPointerCast yields a pointer equal to the one passed in.
///
/// Usage: auto pChild = PolymorphicPointerDowncast<Child>(pBase)
///
/// \tparam DestType    Type of the target object (Child type)
/// \tparam SourceType  Pointer type to the source object (Base (shared) pointer type)
/// \param value        Pointer to the source object
/// \return             Pointer of type DestType ((Shared) pointer of type child)
template<typename DestType, typename SourceType>
auto PolymorphicPointerDowncast(const SourceType& value)
{
    // Checked builds only: the dynamic cast result must match the source pointer.
    ARMNN_POLYMORPHIC_CAST_CHECK(utility::DynamicPointerCast<DestType>(value)
                                 == value);

    auto downcast = utility::StaticPointerCast<DestType>(value);
    return downcast;
}

} //namespace armnn