aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/TestInputOutputLayerVisitor.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/TestInputOutputLayerVisitor.hpp')
-rw-r--r--src/armnn/test/TestInputOutputLayerVisitor.hpp56
1 files changed, 42 insertions, 14 deletions
diff --git a/src/armnn/test/TestInputOutputLayerVisitor.hpp b/src/armnn/test/TestInputOutputLayerVisitor.hpp
index b89089530e..e812f2f97d 100644
--- a/src/armnn/test/TestInputOutputLayerVisitor.hpp
+++ b/src/armnn/test/TestInputOutputLayerVisitor.hpp
@@ -27,14 +27,28 @@ public:
, visitorId(id)
{};
- void VisitInputLayer(const IConnectableLayer* layer,
- LayerBindingId id,
- const char* name = nullptr) override
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- CheckLayerPointer(layer);
- CheckLayerBindingId(visitorId, id);
- CheckLayerName(name);
- };
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Input:
+ {
+ CheckLayerPointer(layer);
+ CheckLayerBindingId(visitorId, id);
+ CheckLayerName(name);
+ break;
+ }
+ default:
+ {
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
+ }
+ }
+ }
};
class TestOutputLayerVisitor : public TestLayerVisitor
@@ -48,14 +62,28 @@ public:
, visitorId(id)
{};
- void VisitOutputLayer(const IConnectableLayer* layer,
- LayerBindingId id,
- const char* name = nullptr) override
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- CheckLayerPointer(layer);
- CheckLayerBindingId(visitorId, id);
- CheckLayerName(name);
- };
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Output:
+ {
+ CheckLayerPointer(layer);
+ CheckLayerBindingId(visitorId, id);
+ CheckLayerName(name);
+ break;
+ }
+ default:
+ {
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
+ }
+ }
+ }
};
} //namespace armnn