// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "TestLayerVisitor.hpp" #include namespace armnn { void CheckLayerBindingId(LayerBindingId visitorId, LayerBindingId id) { CHECK_EQ(visitorId, id); } // Concrete TestLayerVisitor subclasses for layers taking LayerBindingId argument with overridden VisitLayer methods class TestInputLayerVisitor : public TestLayerVisitor { private: LayerBindingId visitorId; public: explicit TestInputLayerVisitor(LayerBindingId id, const char* name = nullptr) : TestLayerVisitor(name) , visitorId(id) {}; void ExecuteStrategy(const armnn::IConnectableLayer* layer, const armnn::BaseDescriptor& descriptor, const std::vector& constants, const char* name, const armnn::LayerBindingId id = 0) override { armnn::IgnoreUnused(descriptor, constants, id); switch (layer->GetType()) { case armnn::LayerType::Input: { CheckLayerPointer(layer); CheckLayerBindingId(visitorId, id); CheckLayerName(name); break; } default: { m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType())); } } } }; class TestOutputLayerVisitor : public TestLayerVisitor { private: LayerBindingId visitorId; public: explicit TestOutputLayerVisitor(LayerBindingId id, const char* name = nullptr) : TestLayerVisitor(name) , visitorId(id) {}; void ExecuteStrategy(const armnn::IConnectableLayer* layer, const armnn::BaseDescriptor& descriptor, const std::vector& constants, const char* name, const armnn::LayerBindingId id = 0) override { armnn::IgnoreUnused(descriptor, constants, id); switch (layer->GetType()) { case armnn::LayerType::Output: { CheckLayerPointer(layer); CheckLayerBindingId(visitorId, id); CheckLayerName(name); break; } default: { m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType())); } } } }; } //namespace armnn