aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp')
-rw-r--r--src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp30
1 files changed, 23 insertions, 7 deletions
diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
index a3c1420388..b1f9512655 100644
--- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
+++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
@@ -29,15 +29,31 @@ public: \
: armnn::TestLayerVisitor(layerName) \
, m_Descriptor(descriptor) {}; \
\
- void Visit##name##Layer(const armnn::IConnectableLayer* layer, \
- const Descriptor& descriptor, \
- const char* layerName = nullptr) override \
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer, \
+ const armnn::BaseDescriptor& descriptor, \
+ const std::vector<armnn::ConstTensor>& constants, \
+ const char* layerName, \
+ const armnn::LayerBindingId id = 0) override \
{ \
- CheckLayerPointer(layer); \
- CheckDescriptor(descriptor); \
- CheckLayerName(layerName); \
+ armnn::IgnoreUnused(descriptor, constants, id); \
+ switch (layer->GetType()) \
+ { \
+ case armnn::LayerType::Input: break; \
+ case armnn::LayerType::Output: break; \
+ case armnn::LayerType::name: break; \
+ { \
+ CheckLayerPointer(layer); \
+ CheckDescriptor(static_cast<const Descriptor&>(descriptor)); \
+ CheckLayerName(layerName); \
+ break; \
+ } \
+ default: \
+ { \
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType())); \
+ } \
+ } \
} \
-};
+}; \
} // anonymous namespace