aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/TestLayerVisitor.hpp
blob: fe2631fa39d67f8357abac5c4142d5efed78a585 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/ILayerVisitor.hpp>
#include <armnn/Descriptors.hpp>

namespace armnn
{
// Abstract base class with do nothing implementations for all layer visit methods
class TestLayerVisitor : public ILayerVisitor
{
protected:
    virtual ~TestLayerVisitor() {}

    void CheckLayerName(const char* name);

    void CheckLayerPointer(const IConnectableLayer* layer);

    void CheckConstTensors(const ConstTensor& expected, const ConstTensor& actual);

private:
    const char* m_LayerName;

public:
    explicit TestLayerVisitor(const char* name) : m_LayerName(name)
    {
        if (name == nullptr)
        {
            m_LayerName = "";
        }
    }

    virtual void VisitInputLayer(const IConnectableLayer* layer,
                                 LayerBindingId id,
                                 const char* name = nullptr) {}

    virtual void VisitConvolution2dLayer(const IConnectableLayer* layer,
                                         const Convolution2dDescriptor& convolution2dDescriptor,
                                         const ConstTensor& weights,
                                         const char* name = nullptr) {}

    virtual void VisitConvolution2dLayer(const IConnectableLayer* layer,
                                         const Convolution2dDescriptor& convolution2dDescriptor,
                                         const ConstTensor& weights,
                                         const ConstTensor& biases,
                                         const char* name = nullptr) {}

    virtual void VisitDepthwiseConvolution2dLayer(const IConnectableLayer* layer,
                                                  const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
                                                  const ConstTensor& weights,
                                                  const char* name = nullptr) {}

    virtual void VisitDepthwiseConvolution2dLayer(const IConnectableLayer* layer,
                                                  const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
                                                  const ConstTensor& weights,
                                                  const ConstTensor& biases,
                                                  const char* name = nullptr) {}

    virtual void VisitDetectionPostProcessLayer(const IConnectableLayer* layer,
                                                const DetectionPostProcessDescriptor& descriptor,
                                                const ConstTensor& anchors,
                                                const char* name = nullptr) {}

    virtual void VisitFullyConnectedLayer(const IConnectableLayer* layer,
                                          const FullyConnectedDescriptor& fullyConnectedDescriptor,
                                          const ConstTensor& weights,
                                          const char* name = nullptr) {}

    virtual void VisitFullyConnectedLayer(const IConnectableLayer* layer,
                                          const FullyConnectedDescriptor& fullyConnectedDescriptor,
                                          const ConstTensor& weights,
                                          const ConstTensor& biases,
                                          const char* name = nullptr) {}

    virtual void VisitPermuteLayer(const IConnectableLayer* layer,
                                   const PermuteDescriptor& permuteDescriptor,
                                   const char* name = nullptr) {}

    virtual void VisitBatchToSpaceNdLayer(const IConnectableLayer* layer,
                                          const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
                                          const char* name = nullptr) {}

    virtual void VisitPooling2dLayer(const IConnectableLayer* layer,
                                     const Pooling2dDescriptor& pooling2dDescriptor,
                                     const char* name = nullptr) {}

    virtual void VisitActivationLayer(const IConnectableLayer* layer,
                                      const ActivationDescriptor& activationDescriptor,
                                      const char* name = nullptr) {}

    virtual void VisitNormalizationLayer(const IConnectableLayer* layer,
                                         const NormalizationDescriptor& normalizationDescriptor,
                                         const char* name = nullptr) {}

    virtual void VisitSoftmaxLayer(const IConnectableLayer* layer,
                                   const SoftmaxDescriptor& softmaxDescriptor,
                                   const char* name = nullptr) {}

    virtual void VisitSplitterLayer(const IConnectableLayer* layer,
                                    const ViewsDescriptor& splitterDescriptor,
                                    const char* name = nullptr) {}

    virtual void VisitMergerLayer(const IConnectableLayer* layer,
                                  const OriginsDescriptor& mergerDescriptor,
                                  const char* name = nullptr) {}

    virtual void VisitAdditionLayer(const IConnectableLayer* layer,
                                    const char* name = nullptr) {}

    virtual void VisitMultiplicationLayer(const IConnectableLayer* layer,
                                          const char* name = nullptr) {}

    virtual void VisitBatchNormalizationLayer(const IConnectableLayer* layer,
                                              const BatchNormalizationDescriptor& desc,
                                              const ConstTensor& mean,
                                              const ConstTensor& variance,
                                              const ConstTensor& beta,
                                              const ConstTensor& gamma,
                                              const char* name = nullptr) {}

    virtual void VisitResizeBilinearLayer(const IConnectableLayer* layer,
                                          const ResizeBilinearDescriptor& resizeDesc,
                                          const char* name = nullptr) {}

    virtual void VisitL2NormalizationLayer(const IConnectableLayer* layer,
                                           const L2NormalizationDescriptor& desc,
                                           const char* name = nullptr) {}

    virtual void VisitConstantLayer(const IConnectableLayer* layer,
                                    const ConstTensor& input,
                                    const char* name = nullptr) {}

    virtual void VisitReshapeLayer(const IConnectableLayer* layer,
                                   const ReshapeDescriptor& reshapeDescriptor,
                                   const char* name = nullptr) {}

    virtual void VisitSpaceToBatchNdLayer(const IConnectableLayer* layer,
                                          const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
                                          const char* name = nullptr) {}

    virtual void VisitFloorLayer(const IConnectableLayer* layer,
                                 const char* name = nullptr) {}

    virtual void VisitOutputLayer(const IConnectableLayer* layer,
                                  LayerBindingId id,
                                  const char* name = nullptr) {}

    virtual void VisitLstmLayer(const IConnectableLayer* layer,
                                const LstmDescriptor& descriptor,
                                const LstmInputParams& params,
                                const char* name = nullptr) {}

    virtual void VisitDivisionLayer(const IConnectableLayer* layer,
                                    const char* name = nullptr) {}

    virtual void VisitSubtractionLayer(const IConnectableLayer* layer,
                                       const char* name = nullptr) {}

    virtual void VisitMaximumLayer(const IConnectableLayer* layer,
                                   const char* name = nullptr) {}

    virtual void VisitMeanLayer(const IConnectableLayer* layer,
                                const MeanDescriptor& meanDescriptor,
                                const char* name = nullptr) {}

    virtual void VisitPadLayer(const IConnectableLayer* layer,
                               const PadDescriptor& padDescriptor,
                               const char* name = nullptr) {}

    virtual void VisitStridedSliceLayer(const IConnectableLayer* layer,
                                        const StridedSliceDescriptor& stridedSliceDescriptor,
                                        const char* name = nullptr) {}

    virtual void VisitMinimumLayer(const IConnectableLayer* layer,
                                   const char* name = nullptr) {}

    virtual void VisitGreaterLayer(const IConnectableLayer* layer,
                                   const char* name = nullptr) {}

    virtual void VisitEqualLayer(const IConnectableLayer* layer,
                                 const char* name = nullptr) {}

    virtual void VisitRsqrtLayer(const IConnectableLayer* layer,
                                 const char* name = nullptr) {}

    virtual void VisitGatherLayer(const IConnectableLayer* layer,
                                  const char* name = nullptr) {}
};

} //namespace armnn