From 10e0786f15bdb60e1d632c9a368fce2737852ae4 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Mon, 16 Sep 2019 16:30:59 +0100 Subject: IVGCVSW-3877 Reduce code duplication in TestNameOnlyLayerVisitor * Defined macros for common class structure and near-identical test cases Signed-off-by: Aron Virginas-Tar Change-Id: I47a2ece3e1797496c196f63c7fcd71e5748295c6 --- src/armnn/test/TestNameOnlyLayerVisitor.hpp | 201 +++++++--------------------- 1 file changed, 50 insertions(+), 151 deletions(-) (limited to 'src/armnn/test/TestNameOnlyLayerVisitor.hpp') diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.hpp b/src/armnn/test/TestNameOnlyLayerVisitor.hpp index dec0d15a96..1c5ede0802 100644 --- a/src/armnn/test/TestNameOnlyLayerVisitor.hpp +++ b/src/armnn/test/TestNameOnlyLayerVisitor.hpp @@ -6,154 +6,53 @@ #include "TestLayerVisitor.hpp" -namespace armnn -{ - -// Concrete TestLayerVisitor subclasses for layers taking Name argument with overridden VisitLayer methods -class TestAdditionLayerVisitor : public TestLayerVisitor -{ -public: - explicit TestAdditionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - - void VisitAdditionLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { - CheckLayerPointer(layer); - CheckLayerName(name); - }; -}; - -class TestDivisionLayerVisitor : public TestLayerVisitor -{ -public: - explicit TestDivisionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - - void VisitDivisionLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { - CheckLayerPointer(layer); - CheckLayerName(name); - }; -}; - -class TestEqualLayerVisitor : public TestLayerVisitor -{ -public: - explicit TestEqualLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - - void VisitEqualLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { - CheckLayerPointer(layer); - CheckLayerName(name); - }; -}; - -class TestFloorLayerVisitor : public TestLayerVisitor -{ -public: - explicit TestFloorLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - - void VisitFloorLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { - CheckLayerPointer(layer); - CheckLayerName(name); - }; -}; - -class TestGatherLayerVisitor : public TestLayerVisitor -{ -public: - explicit TestGatherLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - - void VisitGatherLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { - CheckLayerPointer(layer); - CheckLayerName(name); - }; -}; - -class TestGreaterLayerVisitor : public TestLayerVisitor -{ -public: - explicit TestGreaterLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - - void VisitGreaterLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { - CheckLayerPointer(layer); - CheckLayerName(name); - }; -}; - -class TestMultiplicationLayerVisitor : public TestLayerVisitor -{ -public: - explicit TestMultiplicationLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - - void VisitMultiplicationLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { - CheckLayerPointer(layer); - CheckLayerName(name); - }; -}; - -class TestMaximumLayerVisitor : public TestLayerVisitor -{ -public: - explicit TestMaximumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - - void VisitMaximumLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { - CheckLayerPointer(layer); - CheckLayerName(name); - }; -}; - -class TestMinimumLayerVisitor : public TestLayerVisitor -{ -public: - explicit TestMinimumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - - void VisitMinimumLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { - CheckLayerPointer(layer); - CheckLayerName(name); - }; -}; - -class TestRsqrtLayerVisitor : public TestLayerVisitor -{ -public: - explicit TestRsqrtLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - - void VisitRsqrtLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { - CheckLayerPointer(layer); - CheckLayerName(name); - }; -}; - -class TestSliceLayerVisitor : public TestLayerVisitor -{ -public: - explicit TestSliceLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - - void VisitSliceLayer(const IConnectableLayer* layer, - const SliceDescriptor& sliceDescriptor, - const char* name = nullptr) override - { - CheckLayerPointer(layer); - CheckLayerName(name); - }; -}; - -class TestSubtractionLayerVisitor : public TestLayerVisitor -{ -public: - explicit TestSubtractionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - - void VisitSubtractionLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { - CheckLayerPointer(layer); - CheckLayerName(name); - }; -}; - -} // namespace armnn +namespace +{ + +// Defines a visitor function with 1 required parameter to be used +// with layers that do not have a descriptor +#define VISIT_METHOD_1_PARAM(name) \ +void Visit##name##Layer(const armnn::IConnectableLayer* layer, const char* layerName = nullptr) override + +// Defines a visitor function with 2 required parameters to be used +// with layers that have a descriptor +#define VISIT_METHOD_2_PARAM(name) \ +void Visit##name##Layer(const armnn::IConnectableLayer* layer, \ + const armnn::name##Descriptor&, \ + const char* layerName = nullptr) override + +#define TEST_LAYER_VISITOR(name, numVisitorParams) \ +class Test##name##LayerVisitor : public armnn::TestLayerVisitor \ +{ \ +public: \ + explicit Test##name##LayerVisitor(const char* layerName = nullptr) : armnn::TestLayerVisitor(layerName) {}; \ + \ + VISIT_METHOD_##numVisitorParams##_PARAM(name) \ + { \ + CheckLayerPointer(layer); \ + CheckLayerName(layerName); \ + } \ +}; + +// Defines a test layer visitor class for a layer, of a given name, +// that does not require a descriptor +#define TEST_LAYER_VISITOR_1_PARAM(name) TEST_LAYER_VISITOR(name, 1) + +// Defines a test layer visitor class for a layer, of a given name, +// that requires a descriptor +#define TEST_LAYER_VISITOR_2_PARAM(name) TEST_LAYER_VISITOR(name, 2) + +} // anonymous namespace + +TEST_LAYER_VISITOR_1_PARAM(Addition) +TEST_LAYER_VISITOR_1_PARAM(Division) +TEST_LAYER_VISITOR_1_PARAM(Equal) +TEST_LAYER_VISITOR_1_PARAM(Floor) +TEST_LAYER_VISITOR_1_PARAM(Gather) +TEST_LAYER_VISITOR_1_PARAM(Greater) +TEST_LAYER_VISITOR_1_PARAM(Maximum) +TEST_LAYER_VISITOR_1_PARAM(Minimum) +TEST_LAYER_VISITOR_1_PARAM(Multiplication) +TEST_LAYER_VISITOR_1_PARAM(Rsqrt) +TEST_LAYER_VISITOR_2_PARAM(Slice) +TEST_LAYER_VISITOR_1_PARAM(Subtraction) -- cgit v1.2.1