aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/LayerSupportCommon.hpp
blob: 9252b3b9a57c0be7bbc979fc98bf56fc8ff2a2e7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/DescriptorsFwd.hpp>
#include <armnn/Types.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Optional.hpp>

#include <string>
#include <utility>

namespace armnn
{

/// Assigns @p val to the object referenced by @p optionalRef, but only when the optional is engaged.
/// When @p optionalRef is empty the call does nothing.
///
/// @param optionalRef Optional reference to the assignment target.
/// @param val         Value to assign; perfectly forwarded so rvalues are move-assigned rather than copied.
template<typename T, typename V>
void SetValueChecked(Optional<T&> optionalRef, V&& val)
{
    if (optionalRef)
    {
        // val is a forwarding reference: forward it so temporaries (e.g. a std::string built by the caller)
        // are moved into the target instead of being copied.
        optionalRef.value() = std::forward<V>(val);
    }
}

/// Dispatches a layer-support query to the handler registered for @p dataType.
///
/// Each handler is invoked as `handler(reasonIfUnsupported, params...)` and its result is returned unchanged.
///
/// @param reasonIfUnsupported Optional string slot passed through to the selected handler. For data types that
///                            have no handler below, it receives a generic explanation.
/// @param dataType            Data type being queried.
/// @param float16FuncPtr      Handler used for DataType::Float16.
/// @param float32FuncPtr      Handler used for DataType::Float32.
/// @param uint8FuncPtr        Handler used for DataType::QAsymmU8.
/// @param int32FuncPtr        Handler used for DataType::Signed32.
/// @param booleanFuncPtr      Handler used for DataType::Boolean.
/// @param params              Extra arguments perfectly forwarded to the selected handler.
/// @return The selected handler's result, or false when @p dataType is none of the five types above.
template<typename Float16Func, typename Float32Func, typename Uint8Func, typename Int32Func, typename BooleanFunc,
         typename ... Params>
bool IsSupportedForDataTypeGeneric(Optional<std::string&> reasonIfUnsupported,
                                   DataType dataType,
                                   Float16Func float16FuncPtr,
                                   Float32Func float32FuncPtr,
                                   Uint8Func uint8FuncPtr,
                                   Int32Func int32FuncPtr,
                                   BooleanFunc booleanFuncPtr,
                                   Params&&... params)
{
    // Exactly one case runs, so forwarding params in each branch never forwards the same argument twice.
    switch(dataType)
    {
        case DataType::Float16:
            return float16FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
        case DataType::Float32:
            return float32FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
        case DataType::QAsymmU8:
            return uint8FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
        case DataType::Signed32:
            return int32FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
        case DataType::Boolean:
            return booleanFuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
        default:
            // No handler exists for this data type. Report it as unsupported and fill the reason slot,
            // matching the explicit rejection handlers, so callers are not left with an empty reason.
            if (reasonIfUnsupported)
            {
                reasonIfUnsupported.value() = "Layer is not supported with this data type";
            }
            return false;
    }
}

/// Accepting handler: reports every query as supported.
///
/// @param reasonIfUnsupported Left untouched, because there is nothing to explain.
/// @param params              Extra arguments accepted for signature compatibility; they are not inspected.
/// @return Always true.
template<typename ... Params>
bool TrueFunc(Optional<std::string&> reasonIfUnsupported, Params&&... params)
{
    // Mark the reason slot and every extra argument as intentionally unused in one call.
    IgnoreUnused(reasonIfUnsupported, params...);

    const bool supported = true;
    return supported;
}

/// Silent rejecting handler: reports every query as unsupported without writing a reason.
///
/// @param reasonIfUnsupported Left untouched; callers that need an explanation should use a FalseFunc* variant.
/// @param params              Extra arguments accepted for signature compatibility; they are not inspected.
/// @return Always false.
template<typename ... Params>
bool FalseFunc(Optional<std::string&> reasonIfUnsupported, Params&&... params)
{
    // Mark the reason slot and every extra argument as intentionally unused in one call.
    IgnoreUnused(reasonIfUnsupported, params...);

    const bool supported = false;
    return supported;
}

/// Rejecting handler for the float16 data type.
///
/// @param reasonIfUnsupported When engaged, receives a message naming the float16 data type.
/// @param params              Extra arguments accepted for signature compatibility; they are not inspected.
/// @return Always false.
template<typename ... Params>
bool FalseFuncF16(Optional<std::string&> reasonIfUnsupported, Params&&... params)
{
    constexpr const char* message = "Layer is not supported with float16 data type";

    IgnoreUnused(params...);
    SetValueChecked(reasonIfUnsupported, message);
    return false;
}

/// Rejecting handler for the float32 data type.
///
/// @param reasonIfUnsupported When engaged, receives a message naming the float32 data type.
/// @param params              Extra arguments accepted for signature compatibility; they are not inspected.
/// @return Always false.
template<typename ... Params>
bool FalseFuncF32(Optional<std::string&> reasonIfUnsupported, Params&&... params)
{
    constexpr const char* message = "Layer is not supported with float32 data type";

    IgnoreUnused(params...);
    SetValueChecked(reasonIfUnsupported, message);
    return false;
}

/// Rejecting handler for 8-bit data types.
///
/// @param reasonIfUnsupported When engaged, receives a message naming the 8-bit data type.
/// @param params              Extra arguments accepted for signature compatibility; they are not inspected.
/// @return Always false.
template<typename ... Params>
bool FalseFuncU8(Optional<std::string&> reasonIfUnsupported, Params&&... params)
{
    constexpr const char* message = "Layer is not supported with 8-bit data type";

    IgnoreUnused(params...);
    SetValueChecked(reasonIfUnsupported, message);
    return false;
}

/// Rejecting handler for the int32 data type.
///
/// @param reasonIfUnsupported When engaged, receives a message naming the int32 data type.
/// @param params              Extra arguments accepted for signature compatibility; they are not inspected.
/// @return Always false.
template<typename ... Params>
bool FalseFuncI32(Optional<std::string&> reasonIfUnsupported, Params&&... params)
{
    constexpr const char* message = "Layer is not supported with int32 data type";

    IgnoreUnused(params...);
    SetValueChecked(reasonIfUnsupported, message);
    return false;
}

/// Rejecting handler for a float32 input.
///
/// @param reasonIfUnsupported When engaged, receives a message naming the float32 input data type.
/// @param params              Extra arguments accepted for signature compatibility; they are not inspected.
/// @return Always false.
template<typename ... Params>
bool FalseInputFuncF32(Optional<std::string&> reasonIfUnsupported, Params&&... params)
{
    constexpr const char* message = "Layer is not supported with float32 data type input";

    IgnoreUnused(params...);
    SetValueChecked(reasonIfUnsupported, message);
    return false;
}

/// Rejecting handler for a float16 input.
///
/// @param reasonIfUnsupported When engaged, receives a message naming the float16 input data type.
/// @param params              Extra arguments accepted for signature compatibility; they are not inspected.
/// @return Always false.
template<typename ... Params>
bool FalseInputFuncF16(Optional<std::string&> reasonIfUnsupported, Params&&... params)
{
    constexpr const char* message = "Layer is not supported with float16 data type input";

    IgnoreUnused(params...);
    SetValueChecked(reasonIfUnsupported, message);
    return false;
}

/// Rejecting handler for a float32 output.
///
/// @param reasonIfUnsupported When engaged, receives a message naming the float32 output data type.
/// @param params              Extra arguments accepted for signature compatibility; they are not inspected.
/// @return Always false.
template<typename ... Params>
bool FalseOutputFuncF32(Optional<std::string&> reasonIfUnsupported, Params&&... params)
{
    constexpr const char* message = "Layer is not supported with float32 data type output";

    IgnoreUnused(params...);
    SetValueChecked(reasonIfUnsupported, message);
    return false;
}

/// Rejecting handler for a float16 output.
///
/// @param reasonIfUnsupported When engaged, receives a message naming the float16 output data type.
/// @param params              Extra arguments accepted for signature compatibility; they are not inspected.
/// @return Always false.
template<typename ... Params>
bool FalseOutputFuncF16(Optional<std::string&> reasonIfUnsupported, Params&&... params)
{
    constexpr const char* message = "Layer is not supported with float16 data type output";

    IgnoreUnused(params...);
    SetValueChecked(reasonIfUnsupported, message);
    return false;
}

}