// Source: include/armnn/ILayerSupport.hpp (blob d63c3a70630b7598f2bff261824f4bfd8a781f9c)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/DescriptorsFwd.hpp>
#include <armnn/Optional.hpp>
#include <cctype>
#include <memory>
#include <string>
#include <vector>

namespace armnn
{

// Forward declaration only: within this header TensorInfo is named solely in
// function parameter declarations (by reference, by pointer, or as a template
// argument), so its full definition is not needed to parse these declarations.
class TensorInfo;

/// Interface declaring one virtual "Is<Layer>Supported" query per layer type.
///
/// Every query is a const member function that returns bool and takes, as its
/// last parameter, an Optional<std::string&> named reasonIfUnsupported that
/// defaults to EmptyOptional().
/// NOTE(review): presumably a query returns true when the described layer
/// configuration is supported and otherwise writes an explanation through
/// reasonIfUnsupported when one is supplied -- the definitions are not part of
/// this header, so confirm against the implementation.
///
/// The queries are declared virtual but not pure; their definitions live
/// outside this header.
///
/// Maintenance notes:
///  - Default arguments of virtual functions are selected from the static type
///    used at the call site, not from the overrider that is actually run, so
///    overriders should repeat the same `= EmptyOptional()` default.
///  - The order of the virtual declarations determines the vtable layout on
///    common ABIs; append new queries rather than reordering existing ones if
///    binary compatibility matters.
class ILayerSupport
{
protected:
    // Construction and destruction are protected: objects can only be created
    // and destroyed through a derived class, and code outside the hierarchy
    // cannot `delete` through an ILayerSupport pointer.
    ILayerSupport() {}
    virtual ~ILayerSupport() {}

public:
    /// Activation: input/output tensor infos plus the activation descriptor.
    virtual bool IsActivationSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       const ActivationDescriptor& descriptor,
                                       Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Addition: two input tensor infos and one output tensor info.
    virtual bool IsAdditionSupported(const TensorInfo& input0,
                                     const TensorInfo& input1,
                                     const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Batch normalization: input/output plus the mean, var, beta and gamma
    /// tensor infos and the batch-normalization descriptor.
    virtual bool IsBatchNormalizationSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const TensorInfo& mean,
                                               const TensorInfo& var,
                                               const TensorInfo& beta,
                                               const TensorInfo& gamma,
                                               const BatchNormalizationDescriptor& descriptor,
                                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Constant: described by its output tensor info only.
    virtual bool IsConstantSupported(const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Fp16 -> Fp32 conversion: input and output tensor infos.
    virtual bool IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Fp32 -> Fp16 conversion: input and output tensor infos.
    virtual bool IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// 2D convolution: input/output, descriptor and weights; biases are passed
    /// as an Optional<TensorInfo> (so the caller can pass an empty Optional).
    virtual bool IsConvolution2dSupported(const TensorInfo& input,
                                          const TensorInfo& output,
                                          const Convolution2dDescriptor& descriptor,
                                          const TensorInfo& weights,
                                          const Optional<TensorInfo>& biases,
                                          Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Depthwise convolution: same parameter shape as IsConvolution2dSupported
    /// but with a DepthwiseConvolution2dDescriptor.
    virtual bool IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const DepthwiseConvolution2dDescriptor& descriptor,
                                                 const TensorInfo& weights,
                                                 const Optional<TensorInfo>& biases,
                                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Division: two input tensor infos and one output tensor info.
    virtual bool IsDivisionSupported(const TensorInfo& input0,
                                     const TensorInfo& input1,
                                     const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Fake quantization: input tensor info and descriptor (no output parameter).
    virtual bool IsFakeQuantizationSupported(const TensorInfo& input,
                                             const FakeQuantizationDescriptor& descriptor,
                                             Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Floor: input and output tensor infos.
    virtual bool IsFloorSupported(const TensorInfo& input,
                                  const TensorInfo& output,
                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Fully connected: input/output, weights, biases and descriptor.
    /// NOTE(review): biases is a mandatory reference here, whereas the
    /// convolution queries take Optional<TensorInfo> -- confirm this
    /// difference is intentional.
    virtual bool IsFullyConnectedSupported(const TensorInfo& input,
                                           const TensorInfo& output,
                                           const TensorInfo& weights,
                                           const TensorInfo& biases,
                                           const FullyConnectedDescriptor& descriptor,
                                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Input: described by its input tensor info only.
    virtual bool IsInputSupported(const TensorInfo& input,
                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// L2 normalization: input/output tensor infos plus descriptor.
    virtual bool IsL2NormalizationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const L2NormalizationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// LSTM: state/scratch/output tensor infos, the LSTM descriptor, and the
    /// weight/bias tensor infos.
    /// The first group of weights and biases is passed by reference; the last
    /// eight (inputToInputWeights .. cellToOutputWeights) are passed by pointer.
    /// NOTE(review): presumably the pointer parameters are optional and may be
    /// null -- confirm against the implementation before dereferencing.
    virtual bool IsLstmSupported(const TensorInfo& input,
                                 const TensorInfo& outputStateIn,
                                 const TensorInfo& cellStateIn,
                                 const TensorInfo& scratchBuffer,
                                 const TensorInfo& outputStateOut,
                                 const TensorInfo& cellStateOut,
                                 const TensorInfo& output,
                                 const LstmDescriptor& descriptor,
                                 const TensorInfo& inputToForgetWeights,
                                 const TensorInfo& inputToCellWeights,
                                 const TensorInfo& inputToOutputWeights,
                                 const TensorInfo& recurrentToForgetWeights,
                                 const TensorInfo& recurrentToCellWeights,
                                 const TensorInfo& recurrentToOutputWeights,
                                 const TensorInfo& forgetGateBias,
                                 const TensorInfo& cellBias,
                                 const TensorInfo& outputGateBias,
                                 const TensorInfo* inputToInputWeights,
                                 const TensorInfo* recurrentToInputWeights,
                                 const TensorInfo* cellToInputWeights,
                                 const TensorInfo* inputGateBias,
                                 const TensorInfo* projectionWeights,
                                 const TensorInfo* projectionBias,
                                 const TensorInfo* cellToForgetWeights,
                                 const TensorInfo* cellToOutputWeights,
                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Mean: input/output tensor infos plus descriptor.
    virtual bool IsMeanSupported(const TensorInfo& input,
                                 const TensorInfo& output,
                                 const MeanDescriptor& descriptor,
                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Merger: a list of input tensor-info pointers plus an OriginsDescriptor.
    /// NOTE(review): `inputs` is a const std::vector taken BY VALUE, so each
    /// call copies the vector of pointers. Switching to a const reference
    /// changes the virtual signature and would have to be applied to every
    /// overrider at the same time.
    virtual bool IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                   const OriginsDescriptor& descriptor,
                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Multiplication: two input tensor infos and one output tensor info.
    virtual bool IsMultiplicationSupported(const TensorInfo& input0,
                                           const TensorInfo& input1,
                                           const TensorInfo& output,
                                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Normalization: input/output tensor infos plus descriptor.
    virtual bool IsNormalizationSupported(const TensorInfo& input,
                                          const TensorInfo& output,
                                          const NormalizationDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Output: described by its output tensor info only.
    virtual bool IsOutputSupported(const TensorInfo& output,
                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Pad: input/output tensor infos plus descriptor.
    virtual bool IsPadSupported(const TensorInfo& input,
                                const TensorInfo& output,
                                const PadDescriptor& descriptor,
                                Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Permute: input/output tensor infos plus descriptor.
    virtual bool IsPermuteSupported(const TensorInfo& input,
                                    const TensorInfo& output,
                                    const PermuteDescriptor& descriptor,
                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// 2D pooling: input/output tensor infos plus descriptor.
    virtual bool IsPooling2dSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const Pooling2dDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Reshape: takes the input tensor info only (no output or descriptor).
    virtual bool IsReshapeSupported(const TensorInfo& input,
                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Resize bilinear: takes the input tensor info only (no output or descriptor).
    virtual bool IsResizeBilinearSupported(const TensorInfo& input,
                                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Softmax: input/output tensor infos plus descriptor.
    virtual bool IsSoftmaxSupported(const TensorInfo& input,
                                    const TensorInfo& output,
                                    const SoftmaxDescriptor& descriptor,
                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Splitter: input tensor info plus a ViewsDescriptor (no output parameter).
    virtual bool IsSplitterSupported(const TensorInfo& input,
                                     const ViewsDescriptor& descriptor,
                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;

    /// Subtraction: two input tensor infos and one output tensor info.
    virtual bool IsSubtractionSupported(const TensorInfo& input0,
                                        const TensorInfo& input1,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
}; // class ILayerSupport

/// Shared-ownership handle to an ILayerSupport implementation.
/// ILayerSupport's destructor is protected, so the handle has to be created
/// from a pointer to a concrete derived type (for example via
/// std::make_shared<Derived>()); the shared_ptr then destroys the object
/// through that derived type.
using ILayerSupportSharedPtr = std::shared_ptr<ILayerSupport>;

} // namespace armnn