aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/TestNameOnlyLayerVisitor.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/TestNameOnlyLayerVisitor.hpp')
-rw-r--r--src/armnn/test/TestNameOnlyLayerVisitor.hpp201
1 files changed, 50 insertions, 151 deletions
diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.hpp b/src/armnn/test/TestNameOnlyLayerVisitor.hpp
index dec0d15a96..1c5ede0802 100644
--- a/src/armnn/test/TestNameOnlyLayerVisitor.hpp
+++ b/src/armnn/test/TestNameOnlyLayerVisitor.hpp
@@ -6,154 +6,53 @@
#include "TestLayerVisitor.hpp"
-namespace armnn
-{
-
-// Concrete TestLayerVisitor subclasses for layers taking Name argument with overridden VisitLayer methods
-class TestAdditionLayerVisitor : public TestLayerVisitor
-{
-public:
- explicit TestAdditionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
-
- void VisitAdditionLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
- CheckLayerPointer(layer);
- CheckLayerName(name);
- };
-};
-
-class TestDivisionLayerVisitor : public TestLayerVisitor
-{
-public:
- explicit TestDivisionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
-
- void VisitDivisionLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
- CheckLayerPointer(layer);
- CheckLayerName(name);
- };
-};
-
-class TestEqualLayerVisitor : public TestLayerVisitor
-{
-public:
- explicit TestEqualLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
-
- void VisitEqualLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
- CheckLayerPointer(layer);
- CheckLayerName(name);
- };
-};
-
-class TestFloorLayerVisitor : public TestLayerVisitor
-{
-public:
- explicit TestFloorLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
-
- void VisitFloorLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
- CheckLayerPointer(layer);
- CheckLayerName(name);
- };
-};
-
-class TestGatherLayerVisitor : public TestLayerVisitor
-{
-public:
- explicit TestGatherLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
-
- void VisitGatherLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
- CheckLayerPointer(layer);
- CheckLayerName(name);
- };
-};
-
-class TestGreaterLayerVisitor : public TestLayerVisitor
-{
-public:
- explicit TestGreaterLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
-
- void VisitGreaterLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
- CheckLayerPointer(layer);
- CheckLayerName(name);
- };
-};
-
-class TestMultiplicationLayerVisitor : public TestLayerVisitor
-{
-public:
- explicit TestMultiplicationLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
-
- void VisitMultiplicationLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
- CheckLayerPointer(layer);
- CheckLayerName(name);
- };
-};
-
-class TestMaximumLayerVisitor : public TestLayerVisitor
-{
-public:
- explicit TestMaximumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
-
- void VisitMaximumLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
- CheckLayerPointer(layer);
- CheckLayerName(name);
- };
-};
-
-class TestMinimumLayerVisitor : public TestLayerVisitor
-{
-public:
- explicit TestMinimumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
-
- void VisitMinimumLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
- CheckLayerPointer(layer);
- CheckLayerName(name);
- };
-};
-
-class TestRsqrtLayerVisitor : public TestLayerVisitor
-{
-public:
- explicit TestRsqrtLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
-
- void VisitRsqrtLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
- CheckLayerPointer(layer);
- CheckLayerName(name);
- };
-};
-
-class TestSliceLayerVisitor : public TestLayerVisitor
-{
-public:
- explicit TestSliceLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
-
- void VisitSliceLayer(const IConnectableLayer* layer,
- const SliceDescriptor& sliceDescriptor,
- const char* name = nullptr) override
- {
- CheckLayerPointer(layer);
- CheckLayerName(name);
- };
-};
-
-class TestSubtractionLayerVisitor : public TestLayerVisitor
-{
-public:
- explicit TestSubtractionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
-
- void VisitSubtractionLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
- CheckLayerPointer(layer);
- CheckLayerName(name);
- };
-};
-
-} // namespace armnn
+namespace
+{
+
+// Defines a visitor function with 1 required parameter to be used
+// with layers that do not have a descriptor
+#define VISIT_METHOD_1_PARAM(name) \
+void Visit##name##Layer(const armnn::IConnectableLayer* layer, const char* layerName = nullptr) override
+
+// Defines a visitor function with 2 required parameters to be used
+// with layers that have a descriptor
+#define VISIT_METHOD_2_PARAM(name) \
+void Visit##name##Layer(const armnn::IConnectableLayer* layer, \
+ const armnn::name##Descriptor&, \
+ const char* layerName = nullptr) override
+
+#define TEST_LAYER_VISITOR(name, numVisitorParams) \
+class Test##name##LayerVisitor : public armnn::TestLayerVisitor \
+{ \
+public: \
+ explicit Test##name##LayerVisitor(const char* layerName = nullptr) : armnn::TestLayerVisitor(layerName) {}; \
+ \
+ VISIT_METHOD_##numVisitorParams##_PARAM(name) \
+ { \
+ CheckLayerPointer(layer); \
+ CheckLayerName(layerName); \
+ } \
+};
+
+// Defines a test layer visitor class for a layer, of a given name,
+// that does not require a descriptor
+#define TEST_LAYER_VISITOR_1_PARAM(name) TEST_LAYER_VISITOR(name, 1)
+
+// Defines a test layer visitor class for a layer, of a given name,
+// that requires a descriptor
+#define TEST_LAYER_VISITOR_2_PARAM(name) TEST_LAYER_VISITOR(name, 2)
+
+} // anonymous namespace
+
+TEST_LAYER_VISITOR_1_PARAM(Addition)
+TEST_LAYER_VISITOR_1_PARAM(Division)
+TEST_LAYER_VISITOR_1_PARAM(Equal)
+TEST_LAYER_VISITOR_1_PARAM(Floor)
+TEST_LAYER_VISITOR_1_PARAM(Gather)
+TEST_LAYER_VISITOR_1_PARAM(Greater)
+TEST_LAYER_VISITOR_1_PARAM(Maximum)
+TEST_LAYER_VISITOR_1_PARAM(Minimum)
+TEST_LAYER_VISITOR_1_PARAM(Multiplication)
+TEST_LAYER_VISITOR_1_PARAM(Rsqrt)
+TEST_LAYER_VISITOR_2_PARAM(Slice)
+TEST_LAYER_VISITOR_1_PARAM(Subtraction)