diff options
Diffstat (limited to 'src/armnn/test/TestInputOutputLayerVisitor.hpp')
-rw-r--r-- | src/armnn/test/TestInputOutputLayerVisitor.hpp | 56 |
1 files changed, 42 insertions, 14 deletions
diff --git a/src/armnn/test/TestInputOutputLayerVisitor.hpp b/src/armnn/test/TestInputOutputLayerVisitor.hpp index b89089530e..e812f2f97d 100644 --- a/src/armnn/test/TestInputOutputLayerVisitor.hpp +++ b/src/armnn/test/TestInputOutputLayerVisitor.hpp @@ -27,14 +27,28 @@ public: , visitorId(id) {}; - void VisitInputLayer(const IConnectableLayer* layer, - LayerBindingId id, - const char* name = nullptr) override + void ExecuteStrategy(const armnn::IConnectableLayer* layer, + const armnn::BaseDescriptor& descriptor, + const std::vector<armnn::ConstTensor>& constants, + const char* name, + const armnn::LayerBindingId id = 0) override { - CheckLayerPointer(layer); - CheckLayerBindingId(visitorId, id); - CheckLayerName(name); - }; + armnn::IgnoreUnused(descriptor, constants, id); + switch (layer->GetType()) + { + case armnn::LayerType::Input: + { + CheckLayerPointer(layer); + CheckLayerBindingId(visitorId, id); + CheckLayerName(name); + break; + } + default: + { + m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType())); + } + } + } }; class TestOutputLayerVisitor : public TestLayerVisitor @@ -48,14 +62,28 @@ public: , visitorId(id) {}; - void VisitOutputLayer(const IConnectableLayer* layer, - LayerBindingId id, - const char* name = nullptr) override + void ExecuteStrategy(const armnn::IConnectableLayer* layer, + const armnn::BaseDescriptor& descriptor, + const std::vector<armnn::ConstTensor>& constants, + const char* name, + const armnn::LayerBindingId id = 0) override { - CheckLayerPointer(layer); - CheckLayerBindingId(visitorId, id); - CheckLayerName(name); - }; + armnn::IgnoreUnused(descriptor, constants, id); + switch (layer->GetType()) + { + case armnn::LayerType::Output: + { + CheckLayerPointer(layer); + CheckLayerBindingId(visitorId, id); + CheckLayerName(name); + break; + } + default: + { + m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType())); + } + } + } }; } //namespace armnn |