aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/TestNameOnlyLayerVisitor.hpp
blob: dec0d15a969c083ea11c376c83552b9e395cf5e7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "TestLayerVisitor.hpp"

namespace armnn
{

// Concrete TestLayerVisitor subclasses whose Visit*Layer overrides check only the layer pointer and the layer name.
/// Test visitor for addition layers: verifies only the layer pointer and name.
class TestAdditionLayerVisitor : public TestLayerVisitor
{
public:
    /// @param name Forwarded unchanged to the TestLayerVisitor base constructor.
    explicit TestAdditionLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    /// Calls CheckLayerPointer on @p layer, then CheckLayerName on @p name.
    void VisitAdditionLayer(const IConnectableLayer* layer,
                            const char* name = nullptr) override
    {
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};

/// Test visitor for division layers: verifies only the layer pointer and name.
class TestDivisionLayerVisitor : public TestLayerVisitor
{
public:
    /// @param name Forwarded unchanged to the TestLayerVisitor base constructor.
    explicit TestDivisionLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    /// Calls CheckLayerPointer on @p layer, then CheckLayerName on @p name.
    void VisitDivisionLayer(const IConnectableLayer* layer,
                            const char* name = nullptr) override
    {
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};

/// Test visitor for equal layers: verifies only the layer pointer and name.
class TestEqualLayerVisitor : public TestLayerVisitor
{
public:
    /// @param name Forwarded unchanged to the TestLayerVisitor base constructor.
    explicit TestEqualLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    /// Calls CheckLayerPointer on @p layer, then CheckLayerName on @p name.
    void VisitEqualLayer(const IConnectableLayer* layer,
                         const char* name = nullptr) override
    {
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};

/// Test visitor for floor layers: verifies only the layer pointer and name.
class TestFloorLayerVisitor : public TestLayerVisitor
{
public:
    /// @param name Forwarded unchanged to the TestLayerVisitor base constructor.
    explicit TestFloorLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    /// Calls CheckLayerPointer on @p layer, then CheckLayerName on @p name.
    void VisitFloorLayer(const IConnectableLayer* layer,
                         const char* name = nullptr) override
    {
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};

/// Test visitor for gather layers: verifies only the layer pointer and name.
class TestGatherLayerVisitor : public TestLayerVisitor
{
public:
    /// @param name Forwarded unchanged to the TestLayerVisitor base constructor.
    explicit TestGatherLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    /// Calls CheckLayerPointer on @p layer, then CheckLayerName on @p name.
    void VisitGatherLayer(const IConnectableLayer* layer,
                          const char* name = nullptr) override
    {
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};

/// Test visitor for greater layers: verifies only the layer pointer and name.
class TestGreaterLayerVisitor : public TestLayerVisitor
{
public:
    /// @param name Forwarded unchanged to the TestLayerVisitor base constructor.
    explicit TestGreaterLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    /// Calls CheckLayerPointer on @p layer, then CheckLayerName on @p name.
    void VisitGreaterLayer(const IConnectableLayer* layer,
                           const char* name = nullptr) override
    {
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};

/// Test visitor for multiplication layers: verifies only the layer pointer and name.
class TestMultiplicationLayerVisitor : public TestLayerVisitor
{
public:
    /// @param name Forwarded unchanged to the TestLayerVisitor base constructor.
    explicit TestMultiplicationLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    /// Calls CheckLayerPointer on @p layer, then CheckLayerName on @p name.
    void VisitMultiplicationLayer(const IConnectableLayer* layer,
                                  const char* name = nullptr) override
    {
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};

/// Test visitor for maximum layers: verifies only the layer pointer and name.
class TestMaximumLayerVisitor : public TestLayerVisitor
{
public:
    /// @param name Forwarded unchanged to the TestLayerVisitor base constructor.
    explicit TestMaximumLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    /// Calls CheckLayerPointer on @p layer, then CheckLayerName on @p name.
    void VisitMaximumLayer(const IConnectableLayer* layer,
                           const char* name = nullptr) override
    {
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};

/// Test visitor for minimum layers: verifies only the layer pointer and name.
class TestMinimumLayerVisitor : public TestLayerVisitor
{
public:
    /// @param name Forwarded unchanged to the TestLayerVisitor base constructor.
    explicit TestMinimumLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    /// Calls CheckLayerPointer on @p layer, then CheckLayerName on @p name.
    void VisitMinimumLayer(const IConnectableLayer* layer,
                           const char* name = nullptr) override
    {
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};

/// Test visitor for rsqrt layers: verifies only the layer pointer and name.
class TestRsqrtLayerVisitor : public TestLayerVisitor
{
public:
    /// @param name Forwarded unchanged to the TestLayerVisitor base constructor.
    explicit TestRsqrtLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    /// Calls CheckLayerPointer on @p layer, then CheckLayerName on @p name.
    void VisitRsqrtLayer(const IConnectableLayer* layer,
                         const char* name = nullptr) override
    {
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};

/// Test visitor for slice layers: verifies only the layer pointer and name.
/// The SliceDescriptor passed to VisitSliceLayer is intentionally not inspected.
class TestSliceLayerVisitor : public TestLayerVisitor
{
public:
    /// @param name Forwarded unchanged to the TestLayerVisitor base constructor.
    explicit TestSliceLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    /// Calls CheckLayerPointer on @p layer, then CheckLayerName on @p name.
    /// The descriptor parameter is left unnamed because it is not used here;
    /// naming it would trigger -Wunused-parameter (fatal under -Werror).
    void VisitSliceLayer(const IConnectableLayer* layer,
                         const SliceDescriptor& /*sliceDescriptor*/,
                         const char* name = nullptr) override
    {
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};

/// Test visitor for subtraction layers: verifies only the layer pointer and name.
class TestSubtractionLayerVisitor : public TestLayerVisitor
{
public:
    /// @param name Forwarded unchanged to the TestLayerVisitor base constructor.
    explicit TestSubtractionLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    /// Calls CheckLayerPointer on @p layer, then CheckLayerName on @p name.
    void VisitSubtractionLayer(const IConnectableLayer* layer,
                               const char* name = nullptr) override
    {
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};

} // namespace armnn