// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "TestLayerVisitor.hpp" namespace armnn { // Concrete TestLayerVisitor subclasses for layers taking Name argument with overridden VisitLayer methods class TestAdditionLayerVisitor : public TestLayerVisitor { public: explicit TestAdditionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; void VisitAdditionLayer(const IConnectableLayer* layer, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; class TestMultiplicationLayerVisitor : public TestLayerVisitor { public: explicit TestMultiplicationLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; void VisitMultiplicationLayer(const IConnectableLayer* layer, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; class TestFloorLayerVisitor : public TestLayerVisitor { public: explicit TestFloorLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; void VisitFloorLayer(const IConnectableLayer* layer, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; class TestDivisionLayerVisitor : public TestLayerVisitor { public: explicit TestDivisionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; void VisitDivisionLayer(const IConnectableLayer* layer, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; class TestSubtractionLayerVisitor : public TestLayerVisitor { public: explicit TestSubtractionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; void VisitSubtractionLayer(const IConnectableLayer* layer, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; class TestMaximumLayerVisitor : public TestLayerVisitor { public: explicit TestMaximumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; void VisitMaximumLayer(const IConnectableLayer* layer, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; class TestMinimumLayerVisitor : public TestLayerVisitor { public: explicit TestMinimumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; void VisitMinimumLayer(const IConnectableLayer* layer, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; class TestGreaterLayerVisitor : public TestLayerVisitor { public: explicit TestGreaterLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; void VisitGreaterLayer(const IConnectableLayer* layer, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; class TestEqualLayerVisitor : public TestLayerVisitor { public: explicit TestEqualLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; void VisitEqualLayer(const IConnectableLayer* layer, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; class TestRsqrtLayerVisitor : public TestLayerVisitor { public: explicit TestRsqrtLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; void VisitRsqrtLayer(const IConnectableLayer* layer, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; class TestGatherLayerVisitor : public TestLayerVisitor { public: explicit TestGatherLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; void VisitGatherLayer(const IConnectableLayer* layer, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; } //namespace armnn