Diffstat (limited to 'src/armnn/test')
-rw-r--r--  src/armnn/test/ConstTensorLayerVisitor.cpp            | 236
-rw-r--r--  src/armnn/test/ConstTensorLayerVisitor.hpp            | 358
-rw-r--r--  src/armnn/test/NetworkTests.cpp                       | 118
-rw-r--r--  src/armnn/test/OptimizerTests.cpp                     |  67
-rw-r--r--  src/armnn/test/TestInputOutputLayerVisitor.cpp        |   8
-rw-r--r--  src/armnn/test/TestInputOutputLayerVisitor.hpp        |  56
-rw-r--r--  src/armnn/test/TestLayerVisitor.cpp                   |  56
-rw-r--r--  src/armnn/test/TestLayerVisitor.hpp                   |  19
-rw-r--r--  src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp  |   4
-rw-r--r--  src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp  |  30
-rw-r--r--  src/armnn/test/TestNameOnlyLayerVisitor.cpp           |   4
-rw-r--r--  src/armnn/test/TestNameOnlyLayerVisitor.hpp           |  24
12 files changed, 592 insertions(+), 388 deletions(-)
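
Every change in this commit follows one pattern: test visitors that derived from LayerVisitorBase and overrode a per-layer Visit*Layer() callback now derive from IStrategy (or StrategyBase) and override the single ExecuteStrategy() entry point, and each layer->Accept(visitor) call site becomes layer->ExecuteStrategy(visitor). For orientation, the interface the tests implement looks roughly like this (a minimal sketch; only the ExecuteStrategy signature is confirmed by the overrides in this diff, the includes and destructor are assumptions):

#include <armnn/Tensor.hpp>   // ConstTensor (assumed header)
#include <armnn/Types.hpp>    // LayerBindingId (assumed header)
#include <vector>

namespace armnn
{
class IConnectableLayer;
struct BaseDescriptor;

class IStrategy
{
public:
    // One callback for every layer; implementations switch on
    // layer->GetType() instead of overriding one virtual per layer type.
    virtual void ExecuteStrategy(const IConnectableLayer* layer,
                                 const BaseDescriptor& descriptor,
                                 const std::vector<ConstTensor>& constants,
                                 const char* name,
                                 const LayerBindingId id = 0) = 0;

protected:
    virtual ~IStrategy() = default;
};
} // namespace armnn
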
diff --git a/src/armnn/test/ConstTensorLayerVisitor.cpp b/src/armnn/test/ConstTensorLayerVisitor.cpp
index d3d8698972..e21e777409 100644
--- a/src/armnn/test/ConstTensorLayerVisitor.cpp
+++ b/src/armnn/test/ConstTensorLayerVisitor.cpp
@@ -58,73 +58,6 @@ void TestLstmLayerVisitor::CheckDescriptor(const LstmDescriptor& descriptor)
CHECK(m_Descriptor.m_ProjectionEnabled == descriptor.m_ProjectionEnabled);
}
-void TestLstmLayerVisitor::CheckConstTensorPtrs(const std::string& name,
- const ConstTensor* expected,
- const ConstTensor* actual)
-{
- if (expected == nullptr)
- {
- CHECK_MESSAGE(actual == nullptr, name + " actual should have been a nullptr");
- }
- else
- {
- CHECK_MESSAGE(actual != nullptr, name + " actual should have been set");
- if (actual != nullptr)
- {
- CheckConstTensors(*expected, *actual);
- }
- }
-}
-
-void TestLstmLayerVisitor::CheckInputParameters(const LstmInputParams& inputParams)
-{
- CheckConstTensorPtrs("ProjectionBias", m_InputParams.m_ProjectionBias, inputParams.m_ProjectionBias);
- CheckConstTensorPtrs("ProjectionWeights", m_InputParams.m_ProjectionWeights, inputParams.m_ProjectionWeights);
- CheckConstTensorPtrs("OutputGateBias", m_InputParams.m_OutputGateBias, inputParams.m_OutputGateBias);
- CheckConstTensorPtrs("InputToInputWeights",
- m_InputParams.m_InputToInputWeights, inputParams.m_InputToInputWeights);
- CheckConstTensorPtrs("InputToForgetWeights",
- m_InputParams.m_InputToForgetWeights, inputParams.m_InputToForgetWeights);
- CheckConstTensorPtrs("InputToCellWeights", m_InputParams.m_InputToCellWeights, inputParams.m_InputToCellWeights);
- CheckConstTensorPtrs(
- "InputToOutputWeights", m_InputParams.m_InputToOutputWeights, inputParams.m_InputToOutputWeights);
- CheckConstTensorPtrs(
- "RecurrentToInputWeights", m_InputParams.m_RecurrentToInputWeights, inputParams.m_RecurrentToInputWeights);
- CheckConstTensorPtrs(
- "RecurrentToForgetWeights", m_InputParams.m_RecurrentToForgetWeights, inputParams.m_RecurrentToForgetWeights);
- CheckConstTensorPtrs(
- "RecurrentToCellWeights", m_InputParams.m_RecurrentToCellWeights, inputParams.m_RecurrentToCellWeights);
- CheckConstTensorPtrs(
- "RecurrentToOutputWeights", m_InputParams.m_RecurrentToOutputWeights, inputParams.m_RecurrentToOutputWeights);
- CheckConstTensorPtrs(
- "CellToInputWeights", m_InputParams.m_CellToInputWeights, inputParams.m_CellToInputWeights);
- CheckConstTensorPtrs(
- "CellToForgetWeights", m_InputParams.m_CellToForgetWeights, inputParams.m_CellToForgetWeights);
- CheckConstTensorPtrs(
- "CellToOutputWeights", m_InputParams.m_CellToOutputWeights, inputParams.m_CellToOutputWeights);
- CheckConstTensorPtrs("InputGateBias", m_InputParams.m_InputGateBias, inputParams.m_InputGateBias);
- CheckConstTensorPtrs("ForgetGateBias", m_InputParams.m_ForgetGateBias, inputParams.m_ForgetGateBias);
- CheckConstTensorPtrs("CellBias", m_InputParams.m_CellBias, inputParams.m_CellBias);
-}
-
-void TestQLstmLayerVisitor::CheckConstTensorPtrs(const std::string& name,
- const ConstTensor* expected,
- const ConstTensor* actual)
-{
- if (expected == nullptr)
- {
- CHECK_MESSAGE(actual == nullptr, name + " actual should have been a nullptr");
- }
- else
- {
- CHECK_MESSAGE(actual != nullptr, name + " actual should have been set");
- if (actual != nullptr)
- {
- CheckConstTensors(*expected, *actual);
- }
- }
-}
-
void TestQLstmLayerVisitor::CheckDescriptor(const QLstmDescriptor& descriptor)
{
CHECK(m_Descriptor.m_CellClip == descriptor.m_CellClip);
@@ -134,95 +67,6 @@ void TestQLstmLayerVisitor::CheckDescriptor(const QLstmDescriptor& descriptor)
CHECK(m_Descriptor.m_ProjectionEnabled == descriptor.m_ProjectionEnabled);
}
-void TestQLstmLayerVisitor::CheckInputParameters(const LstmInputParams& inputParams)
-{
- CheckConstTensorPtrs("InputToInputWeights",
- m_InputParams.m_InputToInputWeights,
- inputParams.m_InputToInputWeights);
-
- CheckConstTensorPtrs("InputToForgetWeights",
- m_InputParams.m_InputToForgetWeights,
- inputParams.m_InputToForgetWeights);
-
- CheckConstTensorPtrs("InputToCellWeights",
- m_InputParams.m_InputToCellWeights,
- inputParams.m_InputToCellWeights);
-
- CheckConstTensorPtrs("InputToOutputWeights",
- m_InputParams.m_InputToOutputWeights,
- inputParams.m_InputToOutputWeights);
-
- CheckConstTensorPtrs("RecurrentToInputWeights",
- m_InputParams.m_RecurrentToInputWeights,
- inputParams.m_RecurrentToInputWeights);
-
- CheckConstTensorPtrs("RecurrentToForgetWeights",
- m_InputParams.m_RecurrentToForgetWeights,
- inputParams.m_RecurrentToForgetWeights);
-
- CheckConstTensorPtrs("RecurrentToCellWeights",
- m_InputParams.m_RecurrentToCellWeights,
- inputParams.m_RecurrentToCellWeights);
-
- CheckConstTensorPtrs("RecurrentToOutputWeights",
- m_InputParams.m_RecurrentToOutputWeights,
- inputParams.m_RecurrentToOutputWeights);
-
- CheckConstTensorPtrs("CellToInputWeights",
- m_InputParams.m_CellToInputWeights,
- inputParams.m_CellToInputWeights);
-
- CheckConstTensorPtrs("CellToForgetWeights",
- m_InputParams.m_CellToForgetWeights,
- inputParams.m_CellToForgetWeights);
-
- CheckConstTensorPtrs("CellToOutputWeights",
- m_InputParams.m_CellToOutputWeights,
- inputParams.m_CellToOutputWeights);
-
- CheckConstTensorPtrs("ProjectionWeights", m_InputParams.m_ProjectionWeights, inputParams.m_ProjectionWeights);
- CheckConstTensorPtrs("ProjectionBias", m_InputParams.m_ProjectionBias, inputParams.m_ProjectionBias);
-
- CheckConstTensorPtrs("InputGateBias", m_InputParams.m_InputGateBias, inputParams.m_InputGateBias);
- CheckConstTensorPtrs("ForgetGateBias", m_InputParams.m_ForgetGateBias, inputParams.m_ForgetGateBias);
- CheckConstTensorPtrs("CellBias", m_InputParams.m_CellBias, inputParams.m_CellBias);
- CheckConstTensorPtrs("OutputGateBias", m_InputParams.m_OutputGateBias, inputParams.m_OutputGateBias);
-
- CheckConstTensorPtrs("InputLayerNormWeights",
- m_InputParams.m_InputLayerNormWeights,
- inputParams.m_InputLayerNormWeights);
-
- CheckConstTensorPtrs("ForgetLayerNormWeights",
- m_InputParams.m_ForgetLayerNormWeights,
- inputParams.m_ForgetLayerNormWeights);
-
- CheckConstTensorPtrs("CellLayerNormWeights",
- m_InputParams.m_CellLayerNormWeights,
- inputParams.m_CellLayerNormWeights);
-
- CheckConstTensorPtrs("OutputLayerNormWeights",
- m_InputParams.m_OutputLayerNormWeights,
- inputParams.m_OutputLayerNormWeights);
-}
-
-void TestQuantizedLstmLayerVisitor::CheckConstTensorPtrs(const std::string& name,
- const ConstTensor* expected,
- const ConstTensor* actual)
-{
- if (expected == nullptr)
- {
- CHECK_MESSAGE(actual == nullptr, name + " actual should have been a nullptr");
- }
- else
- {
- CHECK_MESSAGE(actual != nullptr, name + " actual should have been set");
- if (actual != nullptr)
- {
- CheckConstTensors(*expected, *actual);
- }
- }
-}
-
void TestQuantizedLstmLayerVisitor::CheckInputParameters(const QuantizedLstmInputParams& inputParams)
{
CheckConstTensorPtrs("InputToInputWeights",
@@ -285,7 +129,7 @@ TEST_CASE("CheckConvolution2dLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddConvolution2dLayer(descriptor, weights, EmptyOptional());
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedConvolution2dLayer")
@@ -309,7 +153,7 @@ TEST_CASE("CheckNamedConvolution2dLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddConvolution2dLayer(descriptor, weights, EmptyOptional(), layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckConvolution2dLayerWithBiases")
@@ -338,7 +182,7 @@ TEST_CASE("CheckConvolution2dLayerWithBiases")
NetworkImpl net;
IConnectableLayer* const layer = net.AddConvolution2dLayer(descriptor, weights, optionalBiases);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedConvolution2dLayerWithBiases")
@@ -368,7 +212,7 @@ TEST_CASE("CheckNamedConvolution2dLayerWithBiases")
NetworkImpl net;
IConnectableLayer* const layer = net.AddConvolution2dLayer(descriptor, weights, optionalBiases, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckDepthwiseConvolution2dLayer")
@@ -391,7 +235,7 @@ TEST_CASE("CheckDepthwiseConvolution2dLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddDepthwiseConvolution2dLayer(descriptor, weights, EmptyOptional());
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedDepthwiseConvolution2dLayer")
@@ -418,7 +262,7 @@ TEST_CASE("CheckNamedDepthwiseConvolution2dLayer")
weights,
EmptyOptional(),
layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckDepthwiseConvolution2dLayerWithBiases")
@@ -447,7 +291,7 @@ TEST_CASE("CheckDepthwiseConvolution2dLayerWithBiases")
NetworkImpl net;
IConnectableLayer* const layer = net.AddDepthwiseConvolution2dLayer(descriptor, weights, optionalBiases);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedDepthwiseConvolution2dLayerWithBiases")
@@ -477,7 +321,7 @@ TEST_CASE("CheckNamedDepthwiseConvolution2dLayerWithBiases")
NetworkImpl net;
IConnectableLayer* const layer = net.AddDepthwiseConvolution2dLayer(descriptor, weights, optionalBiases, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckFullyConnectedLayer")
@@ -500,8 +344,8 @@ TEST_CASE("CheckFullyConnectedLayer")
IConnectableLayer* const layer = net.AddFullyConnectedLayer(descriptor);
weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1));
- weightsLayer->Accept(weightsVisitor);
- layer->Accept(visitor);
+ weightsLayer->ExecuteStrategy(weightsVisitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedFullyConnectedLayer")
@@ -525,8 +369,8 @@ TEST_CASE("CheckNamedFullyConnectedLayer")
IConnectableLayer* const layer = net.AddFullyConnectedLayer(descriptor, layerName);
weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1));
- weightsLayer->Accept(weightsVisitor);
- layer->Accept(visitor);
+ weightsLayer->ExecuteStrategy(weightsVisitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckFullyConnectedLayerWithBiases")
@@ -556,9 +400,9 @@ TEST_CASE("CheckFullyConnectedLayerWithBiases")
weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1));
biasesLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2));
- weightsLayer->Accept(weightsVisitor);
- biasesLayer->Accept(biasesVisitor);
- layer->Accept(visitor);
+ weightsLayer->ExecuteStrategy(weightsVisitor);
+ biasesLayer->ExecuteStrategy(biasesVisitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedFullyConnectedLayerWithBiases")
@@ -589,9 +433,9 @@ TEST_CASE("CheckNamedFullyConnectedLayerWithBiases")
weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1));
biasesLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2));
- weightsLayer->Accept(weightsVisitor);
- biasesLayer->Accept(biasesVisitor);
- layer->Accept(visitor);
+ weightsLayer->ExecuteStrategy(weightsVisitor);
+ biasesLayer->ExecuteStrategy(biasesVisitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckBatchNormalizationLayer")
@@ -621,7 +465,7 @@ TEST_CASE("CheckBatchNormalizationLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddBatchNormalizationLayer(descriptor, mean, variance, beta, gamma);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedBatchNormalizationLayer")
@@ -653,7 +497,7 @@ TEST_CASE("CheckNamedBatchNormalizationLayer")
IConnectableLayer* const layer = net.AddBatchNormalizationLayer(
descriptor, mean, variance, beta, gamma, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckConstLayer")
@@ -667,7 +511,7 @@ TEST_CASE("CheckConstLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddConstantLayer(input);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedConstLayer")
@@ -682,7 +526,7 @@ TEST_CASE("CheckNamedConstLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddConstantLayer(input, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckLstmLayerBasic")
@@ -754,7 +598,7 @@ TEST_CASE("CheckLstmLayerBasic")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedLstmLayerBasic")
@@ -827,7 +671,7 @@ TEST_CASE("CheckNamedLstmLayerBasic")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckLstmLayerCifgDisabled")
@@ -918,7 +762,7 @@ TEST_CASE("CheckLstmLayerCifgDisabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedLstmLayerCifgDisabled")
@@ -1010,7 +854,7 @@ TEST_CASE("CheckNamedLstmLayerCifgDisabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
// TODO add one with peephole
@@ -1097,7 +941,7 @@ TEST_CASE("CheckLstmLayerPeephole")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckLstmLayerPeepholeCifgDisabled")
@@ -1211,7 +1055,7 @@ TEST_CASE("CheckLstmLayerPeepholeCifgDisabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedLstmLayerPeephole")
@@ -1298,7 +1142,7 @@ TEST_CASE("CheckNamedLstmLayerPeephole")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
// TODO add one with projection
@@ -1385,7 +1229,7 @@ TEST_CASE("CheckLstmLayerProjection")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedLstmLayerProjection")
@@ -1472,7 +1316,7 @@ TEST_CASE("CheckNamedLstmLayerProjection")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckQLstmLayerBasic")
@@ -1544,7 +1388,7 @@ TEST_CASE("CheckQLstmLayerBasic")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedQLstmLayerBasic")
@@ -1617,7 +1461,7 @@ TEST_CASE("CheckNamedQLstmLayerBasic")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQLstmLayer(descriptor, params, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckQLstmLayerCifgDisabled")
@@ -1712,7 +1556,7 @@ TEST_CASE("CheckQLstmLayerCifgDisabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckQLstmLayerCifgDisabledPeepholeEnabled")
@@ -1829,7 +1673,7 @@ TEST_CASE("CheckQLstmLayerCifgDisabledPeepholeEnabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckQLstmLayerCifgEnabledPeepholeEnabled")
@@ -1919,7 +1763,7 @@ TEST_CASE("CheckQLstmLayerCifgEnabledPeepholeEnabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckQLstmLayerProjectionEnabled")
@@ -2009,7 +1853,7 @@ TEST_CASE("CheckQLstmLayerProjectionEnabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckQLstmLayerCifgDisabledLayerNormEnabled")
@@ -2132,7 +1976,7 @@ TEST_CASE("CheckQLstmLayerCifgDisabledLayerNormEnabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
@@ -2222,7 +2066,7 @@ TEST_CASE("CheckQuantizedLstmLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQuantizedLstmLayer(params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedQuantizedLstmLayer")
@@ -2312,7 +2156,7 @@ TEST_CASE("CheckNamedQuantizedLstmLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQuantizedLstmLayer(params, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
}
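
The call-site change in each test case above is a single line; descriptor setup, tensor construction, and network wiring are untouched. Condensed, the pattern every case now follows is (hypothetical tensor values, mirroring CheckConstLayer):

std::vector<float> data = {5.0f, 6.0f, 7.0f, 8.0f};
armnn::TensorInfo info({2, 2}, armnn::DataType::Float32);
armnn::ConstTensor input(info, data);

TestConstantLayerVisitor visitor(input);   // holds the expected tensor
armnn::NetworkImpl net;
armnn::IConnectableLayer* const layer = net.AddConstantLayer(input);
layer->ExecuteStrategy(visitor);           // previously: layer->Accept(visitor);
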
diff --git a/src/armnn/test/ConstTensorLayerVisitor.hpp b/src/armnn/test/ConstTensorLayerVisitor.hpp
index 35e2e872f7..5538852b60 100644
--- a/src/armnn/test/ConstTensorLayerVisitor.hpp
+++ b/src/armnn/test/ConstTensorLayerVisitor.hpp
@@ -5,9 +5,14 @@
#pragma once
#include "TestLayerVisitor.hpp"
+#include "LayersFwd.hpp"
#include <armnn/Descriptors.hpp>
#include <armnn/LstmParams.hpp>
#include <armnn/QuantizedLstmParams.hpp>
+#include <armnn/utility/PolymorphicDowncast.hpp>
+#include <backendsCommon/TensorHandle.hpp>
+
+#include <doctest/doctest.h>
namespace armnn
{
@@ -27,17 +32,33 @@ public:
virtual ~TestConvolution2dLayerVisitor() {}
- void VisitConvolution2dLayer(const IConnectableLayer* layer,
- const Convolution2dDescriptor& convolution2dDescriptor,
- const ConstTensor& weights,
- const Optional<ConstTensor>& biases,
- const char* name = nullptr) override
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- CheckLayerPointer(layer);
- CheckLayerName(name);
- CheckDescriptor(convolution2dDescriptor);
- CheckConstTensors(m_Weights, weights);
- CheckOptionalConstTensors(m_Biases, biases);
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Convolution2d:
+ {
+ CheckLayerPointer(layer);
+ CheckLayerName(name);
+ CheckDescriptor(static_cast<const armnn::Convolution2dDescriptor&>(descriptor));
+ CheckConstTensors(m_Weights, constants[0]);
+ if (m_Biases.has_value())
+ {
+ CHECK(constants.size() == 2);
+ CheckConstTensors(m_Biases.value(), constants[1]);
+ }
+ break;
+ }
+ default:
+ {
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
+ }
+ }
}
protected:
@@ -64,17 +85,33 @@ public:
virtual ~TestDepthwiseConvolution2dLayerVisitor() {}
- void VisitDepthwiseConvolution2dLayer(const IConnectableLayer* layer,
- const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
- const ConstTensor& weights,
- const Optional<ConstTensor>& biases,
- const char* name = nullptr) override
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- CheckLayerPointer(layer);
- CheckLayerName(name);
- CheckDescriptor(convolution2dDescriptor);
- CheckConstTensors(m_Weights, weights);
- CheckOptionalConstTensors(m_Biases, biases);
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::DepthwiseConvolution2d:
+ {
+ CheckLayerPointer(layer);
+ CheckLayerName(name);
+ CheckDescriptor(static_cast<const armnn::DepthwiseConvolution2dDescriptor&>(descriptor));
+ CheckConstTensors(m_Weights, constants[0]);
+ if (m_Biases.has_value())
+ {
+ CHECK(constants.size() == 2);
+ CheckConstTensors(m_Biases.value(), constants[1]);
+ }
+ break;
+ }
+ default:
+ {
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
+ }
+ }
}
protected:
@@ -97,13 +134,27 @@ public:
virtual ~TestFullyConnectedLayerVistor() {}
- void VisitFullyConnectedLayer(const IConnectableLayer* layer,
- const FullyConnectedDescriptor& fullyConnectedDescriptor,
- const char* name = nullptr) override
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- CheckLayerPointer(layer);
- CheckLayerName(name);
- CheckDescriptor(fullyConnectedDescriptor);
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::FullyConnected:
+ {
+ CheckLayerPointer(layer);
+ CheckLayerName(name);
+ CheckDescriptor(static_cast<const armnn::FullyConnectedDescriptor&>(descriptor));
+ break;
+ }
+ default:
+ {
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
+ }
+ }
}
protected:
@@ -129,21 +180,31 @@ public:
, m_Gamma(gamma)
{}
- void VisitBatchNormalizationLayer(const IConnectableLayer* layer,
- const BatchNormalizationDescriptor& descriptor,
- const ConstTensor& mean,
- const ConstTensor& variance,
- const ConstTensor& beta,
- const ConstTensor& gamma,
- const char* name = nullptr) override
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- CheckLayerPointer(layer);
- CheckLayerName(name);
- CheckDescriptor(descriptor);
- CheckConstTensors(m_Mean, mean);
- CheckConstTensors(m_Variance, variance);
- CheckConstTensors(m_Beta, beta);
- CheckConstTensors(m_Gamma, gamma);
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::BatchNormalization:
+ {
+ CheckLayerPointer(layer);
+ CheckLayerName(name);
+ CheckDescriptor(static_cast<const armnn::BatchNormalizationDescriptor&>(descriptor));
+ CheckConstTensors(m_Mean, constants[0]);
+ CheckConstTensors(m_Variance, constants[1]);
+ CheckConstTensors(m_Beta, constants[2]);
+ CheckConstTensors(m_Gamma, constants[3]);
+ break;
+ }
+ default:
+ {
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
+ }
+ }
}
protected:
@@ -166,81 +227,201 @@ public:
, m_Input(input)
{}
- void VisitConstantLayer(const IConnectableLayer* layer,
- const ConstTensor& input,
- const char* name = nullptr)
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- CheckLayerPointer(layer);
- CheckLayerName(name);
- CheckConstTensors(m_Input, input);
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Constant:
+ {
+ CheckLayerPointer(layer);
+ CheckLayerName(name);
+ CheckConstTensors(m_Input, constants[0]);
+ break;
+ }
+ default:
+ {
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
+ }
+ }
}
private:
ConstTensor m_Input;
};
-class TestLstmLayerVisitor : public TestLayerVisitor
+// Used to supply utility functions to the concrete LSTM test visitors
+class LstmVisitor : public TestLayerVisitor
+{
+public:
+ explicit LstmVisitor(const LstmInputParams& params,
+ const char* name = nullptr)
+ : TestLayerVisitor(name)
+ , m_InputParams(params) {}
+
+protected:
+ template<typename LayerType>
+ void CheckInputParameters(const LayerType* layer, const LstmInputParams& inputParams);
+
+ LstmInputParams m_InputParams;
+};
+
+template<typename LayerType>
+void LstmVisitor::CheckInputParameters(const LayerType* layer, const LstmInputParams& inputParams)
+{
+ CheckConstTensorPtrs("OutputGateBias",
+ inputParams.m_OutputGateBias,
+ layer->m_BasicParameters.m_OutputGateBias);
+ CheckConstTensorPtrs("InputToForgetWeights",
+ inputParams.m_InputToForgetWeights,
+ layer->m_BasicParameters.m_InputToForgetWeights);
+ CheckConstTensorPtrs("InputToCellWeights",
+ inputParams.m_InputToCellWeights,
+ layer->m_BasicParameters.m_InputToCellWeights);
+ CheckConstTensorPtrs("InputToOutputWeights",
+ inputParams.m_InputToOutputWeights,
+ layer->m_BasicParameters.m_InputToOutputWeights);
+ CheckConstTensorPtrs("RecurrentToForgetWeights",
+ inputParams.m_RecurrentToForgetWeights,
+ layer->m_BasicParameters.m_RecurrentToForgetWeights);
+ CheckConstTensorPtrs("RecurrentToCellWeights",
+ inputParams.m_RecurrentToCellWeights,
+ layer->m_BasicParameters.m_RecurrentToCellWeights);
+ CheckConstTensorPtrs("RecurrentToOutputWeights",
+ inputParams.m_RecurrentToOutputWeights,
+ layer->m_BasicParameters.m_RecurrentToOutputWeights);
+ CheckConstTensorPtrs("ForgetGateBias",
+ inputParams.m_ForgetGateBias,
+ layer->m_BasicParameters.m_ForgetGateBias);
+ CheckConstTensorPtrs("CellBias",
+ inputParams.m_CellBias,
+ layer->m_BasicParameters.m_CellBias);
+
+ CheckConstTensorPtrs("InputToInputWeights",
+ inputParams.m_InputToInputWeights,
+ layer->m_CifgParameters.m_InputToInputWeights);
+ CheckConstTensorPtrs("RecurrentToInputWeights",
+ inputParams.m_RecurrentToInputWeights,
+ layer->m_CifgParameters.m_RecurrentToInputWeights);
+ CheckConstTensorPtrs("InputGateBias",
+ inputParams.m_InputGateBias,
+ layer->m_CifgParameters.m_InputGateBias);
+
+ CheckConstTensorPtrs("ProjectionBias",
+ inputParams.m_ProjectionBias,
+ layer->m_ProjectionParameters.m_ProjectionBias);
+ CheckConstTensorPtrs("ProjectionWeights",
+ inputParams.m_ProjectionWeights,
+ layer->m_ProjectionParameters.m_ProjectionWeights);
+
+ CheckConstTensorPtrs("CellToInputWeights",
+ inputParams.m_CellToInputWeights,
+ layer->m_PeepholeParameters.m_CellToInputWeights);
+ CheckConstTensorPtrs("CellToForgetWeights",
+ inputParams.m_CellToForgetWeights,
+ layer->m_PeepholeParameters.m_CellToForgetWeights);
+ CheckConstTensorPtrs("CellToOutputWeights",
+ inputParams.m_CellToOutputWeights,
+ layer->m_PeepholeParameters.m_CellToOutputWeights);
+
+ CheckConstTensorPtrs("InputLayerNormWeights",
+ inputParams.m_InputLayerNormWeights,
+ layer->m_LayerNormParameters.m_InputLayerNormWeights);
+ CheckConstTensorPtrs("ForgetLayerNormWeights",
+ inputParams.m_ForgetLayerNormWeights,
+ layer->m_LayerNormParameters.m_ForgetLayerNormWeights);
+ CheckConstTensorPtrs("CellLayerNormWeights",
+ inputParams.m_CellLayerNormWeights,
+ layer->m_LayerNormParameters.m_CellLayerNormWeights);
+ CheckConstTensorPtrs("OutputLayerNormWeights",
+ inputParams.m_OutputLayerNormWeights,
+ layer->m_LayerNormParameters.m_OutputLayerNormWeights);
+}
+
+class TestLstmLayerVisitor : public LstmVisitor
{
public:
explicit TestLstmLayerVisitor(const LstmDescriptor& descriptor,
const LstmInputParams& params,
const char* name = nullptr)
- : TestLayerVisitor(name)
+ : LstmVisitor(params, name)
, m_Descriptor(descriptor)
- , m_InputParams(params)
{}
- void VisitLstmLayer(const IConnectableLayer* layer,
- const LstmDescriptor& descriptor,
- const LstmInputParams& params,
- const char* name = nullptr)
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- CheckLayerPointer(layer);
- CheckLayerName(name);
- CheckDescriptor(descriptor);
- CheckInputParameters(params);
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Lstm:
+ {
+ CheckLayerPointer(layer);
+ CheckLayerName(name);
+ CheckDescriptor(static_cast<const armnn::LstmDescriptor&>(descriptor));
+ CheckInputParameters<const LstmLayer>(PolymorphicDowncast<const LstmLayer*>(layer), m_InputParams);
+ break;
+ }
+ default:
+ {
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
+ }
+ }
}
protected:
void CheckDescriptor(const LstmDescriptor& descriptor);
- void CheckInputParameters(const LstmInputParams& inputParams);
- void CheckConstTensorPtrs(const std::string& name, const ConstTensor* expected, const ConstTensor* actual);
private:
LstmDescriptor m_Descriptor;
- LstmInputParams m_InputParams;
};
-class TestQLstmLayerVisitor : public TestLayerVisitor
+class TestQLstmLayerVisitor : public LstmVisitor
{
public:
explicit TestQLstmLayerVisitor(const QLstmDescriptor& descriptor,
const LstmInputParams& params,
const char* name = nullptr)
- : TestLayerVisitor(name)
+ : LstmVisitor(params, name)
, m_Descriptor(descriptor)
- , m_InputParams(params)
{}
- void VisitQLstmLayer(const IConnectableLayer* layer,
- const QLstmDescriptor& descriptor,
- const LstmInputParams& params,
- const char* name = nullptr)
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- CheckLayerPointer(layer);
- CheckLayerName(name);
- CheckDescriptor(descriptor);
- CheckInputParameters(params);
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::QLstm:
+ {
+ CheckLayerPointer(layer);
+ CheckLayerName(name);
+ CheckDescriptor(static_cast<const armnn::QLstmDescriptor&>(descriptor));
+ CheckInputParameters<const QLstmLayer>(PolymorphicDowncast<const QLstmLayer*>(layer), m_InputParams);
+ break;
+ }
+ default:
+ {
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
+ }
+ }
}
protected:
void CheckDescriptor(const QLstmDescriptor& descriptor);
- void CheckInputParameters(const LstmInputParams& inputParams);
- void CheckConstTensorPtrs(const std::string& name, const ConstTensor* expected, const ConstTensor* actual);
private:
QLstmDescriptor m_Descriptor;
- LstmInputParams m_InputParams;
};
@@ -253,18 +434,31 @@ public:
, m_InputParams(params)
{}
- void VisitQuantizedLstmLayer(const IConnectableLayer* layer,
- const QuantizedLstmInputParams& params,
- const char* name = nullptr)
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- CheckLayerPointer(layer);
- CheckLayerName(name);
- CheckInputParameters(params);
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::QuantizedLstm:
+ {
+ CheckLayerPointer(layer);
+ CheckLayerName(name);
+ CheckInputParameters(m_InputParams);
+ break;
+ }
+ default:
+ {
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
+ }
+ }
}
protected:
- void CheckInputParameters(const QuantizedLstmInputParams& inputParams);
- void CheckConstTensorPtrs(const std::string& name, const ConstTensor* expected, const ConstTensor* actual);
+ void CheckInputParameters(const QuantizedLstmInputParams& params);
private:
QuantizedLstmInputParams m_InputParams;
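
The three near-identical CheckConstTensorPtrs/CheckInputParameters implementations deleted from the .cpp are replaced by the single LstmVisitor::CheckInputParameters template above, which reads the tensor handles straight off the downcast layer. It serves both LstmLayer and QLstmLayer because the template only requires parameter-struct members with matching names; no common base class is involved. A self-contained toy of the same duck-typed technique (all names invented for illustration):

#include <iostream>

// Stand-ins for LstmLayer / QLstmLayer: unrelated types that expose
// identically named parameter structs.
struct BasicParams   { const float* m_CellBias = nullptr; };
struct ToyLstmLayer  { BasicParams m_BasicParameters; };
struct ToyQLstmLayer { BasicParams m_BasicParameters; };

// One template serves both layer types, as LstmVisitor::CheckInputParameters does.
template <typename LayerType>
void CheckCellBias(const LayerType* layer)
{
    std::cout << (layer->m_BasicParameters.m_CellBias ? "set" : "unset") << '\n';
}

int main()
{
    ToyLstmLayer  lstm;
    ToyQLstmLayer qlstm;
    CheckCellBias(&lstm);
    CheckCellBias(&qlstm);
}
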
diff --git a/src/armnn/test/NetworkTests.cpp b/src/armnn/test/NetworkTests.cpp
index 9acb60df4a..25dab596fd 100644
--- a/src/armnn/test/NetworkTests.cpp
+++ b/src/armnn/test/NetworkTests.cpp
@@ -398,26 +398,44 @@ TEST_CASE("NetworkModification_SplitterMultiplication")
TEST_CASE("Network_AddQuantize")
{
- struct Test : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
+ struct Test : public armnn::IStrategy
{
- void VisitQuantizeLayer(const armnn::IConnectableLayer* layer, const char* name) override
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- m_Visited = true;
-
- CHECK(layer);
-
- std::string expectedName = std::string("quantize");
- CHECK(std::string(layer->GetName()) == expectedName);
- CHECK(std::string(name) == expectedName);
-
- CHECK(layer->GetNumInputSlots() == 1);
- CHECK(layer->GetNumOutputSlots() == 1);
-
- const armnn::TensorInfo& infoIn = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
- CHECK((infoIn.GetDataType() == armnn::DataType::Float32));
-
- const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo();
- CHECK((infoOut.GetDataType() == armnn::DataType::QAsymmU8));
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Input: break;
+ case armnn::LayerType::Output: break;
+ case armnn::LayerType::Quantize:
+ {
+ m_Visited = true;
+
+ CHECK(layer);
+
+ std::string expectedName = std::string("quantize");
+ CHECK(std::string(layer->GetName()) == expectedName);
+ CHECK(std::string(name) == expectedName);
+
+ CHECK(layer->GetNumInputSlots() == 1);
+ CHECK(layer->GetNumOutputSlots() == 1);
+
+ const armnn::TensorInfo& infoIn = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+ CHECK((infoIn.GetDataType() == armnn::DataType::Float32));
+
+ const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo();
+ CHECK((infoOut.GetDataType() == armnn::DataType::QAsymmU8));
+ break;
+ }
+ default:
+ {
+ // nothing
+ }
+ }
}
bool m_Visited = false;
@@ -440,7 +458,7 @@ TEST_CASE("Network_AddQuantize")
quantize->GetOutputSlot(0).SetTensorInfo(infoOut);
Test testQuantize;
- graph->Accept(testQuantize);
+ graph->ExecuteStrategy(testQuantize);
CHECK(testQuantize.m_Visited == true);
@@ -448,29 +466,47 @@ TEST_CASE("Network_AddQuantize")
TEST_CASE("Network_AddMerge")
{
- struct Test : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
+ struct Test : public armnn::IStrategy
{
- void VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name) override
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- m_Visited = true;
-
- CHECK(layer);
-
- std::string expectedName = std::string("merge");
- CHECK(std::string(layer->GetName()) == expectedName);
- CHECK(std::string(name) == expectedName);
-
- CHECK(layer->GetNumInputSlots() == 2);
- CHECK(layer->GetNumOutputSlots() == 1);
-
- const armnn::TensorInfo& infoIn0 = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
- CHECK((infoIn0.GetDataType() == armnn::DataType::Float32));
-
- const armnn::TensorInfo& infoIn1 = layer->GetInputSlot(1).GetConnection()->GetTensorInfo();
- CHECK((infoIn1.GetDataType() == armnn::DataType::Float32));
-
- const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo();
- CHECK((infoOut.GetDataType() == armnn::DataType::Float32));
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Input: break;
+ case armnn::LayerType::Output: break;
+ case armnn::LayerType::Merge:
+ {
+ m_Visited = true;
+
+ CHECK(layer);
+
+ std::string expectedName = std::string("merge");
+ CHECK(std::string(layer->GetName()) == expectedName);
+ CHECK(std::string(name) == expectedName);
+
+ CHECK(layer->GetNumInputSlots() == 2);
+ CHECK(layer->GetNumOutputSlots() == 1);
+
+ const armnn::TensorInfo& infoIn0 = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+ CHECK((infoIn0.GetDataType() == armnn::DataType::Float32));
+
+ const armnn::TensorInfo& infoIn1 = layer->GetInputSlot(1).GetConnection()->GetTensorInfo();
+ CHECK((infoIn1.GetDataType() == armnn::DataType::Float32));
+
+ const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo();
+ CHECK((infoOut.GetDataType() == armnn::DataType::Float32));
+ break;
+ }
+ default:
+ {
+ // nothing
+ }
+ }
}
bool m_Visited = false;
@@ -493,7 +529,7 @@ TEST_CASE("Network_AddMerge")
merge->GetOutputSlot(0).SetTensorInfo(info);
Test testMerge;
- network->Accept(testMerge);
+ network->ExecuteStrategy(testMerge);
CHECK(testMerge.m_Visited == true);
}
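
The two test-local structs above show the general recipe for walking a whole graph under the new API: implement IStrategy once, handle the layer types of interest, and let everything else fall through. A minimal sketch (LayerCounter is hypothetical; the ExecuteStrategy call mirrors the tests above):

struct LayerCounter : public armnn::IStrategy
{
    void ExecuteStrategy(const armnn::IConnectableLayer* layer,
                         const armnn::BaseDescriptor& descriptor,
                         const std::vector<armnn::ConstTensor>& constants,
                         const char* name,
                         const armnn::LayerBindingId id = 0) override
    {
        armnn::IgnoreUnused(layer, descriptor, constants, name, id);
        ++m_Count;   // counts every layer the network reports, inputs and outputs included
    }
    unsigned int m_Count = 0;
};

// Usage, as in the tests above:
//   LayerCounter counter;
//   network->ExecuteStrategy(counter);
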
diff --git a/src/armnn/test/OptimizerTests.cpp b/src/armnn/test/OptimizerTests.cpp
index 66da3ad1ff..8416a8dd0d 100644
--- a/src/armnn/test/OptimizerTests.cpp
+++ b/src/armnn/test/OptimizerTests.cpp
@@ -13,13 +13,12 @@
#include <armnn/BackendHelper.hpp>
#include <armnn/BackendRegistry.hpp>
#include <armnn/INetwork.hpp>
-#include <armnn/LayerVisitorBase.hpp>
+#include <armnn/StrategyBase.hpp>
#include <armnn/utility/Assert.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
-#include <armnnUtils/FloatingPointConverter.hpp>
+#include <armnn/backends/IBackendInternal.hpp>
-#include <backendsCommon/IBackendInternal.hpp>
#include <backendsCommon/LayerSupportBase.hpp>
#include <backendsCommon/TensorHandle.hpp>
@@ -201,10 +200,6 @@ public:
return nullptr;
}
- IBackendInternal::Optimizations GetOptimizations() const override
- {
- return {};
- }
IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
{
return std::make_shared<MockLayerSupport>();
@@ -265,10 +260,6 @@ public:
return nullptr;
}
- IBackendInternal::Optimizations GetOptimizations() const override
- {
- return {};
- }
IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
{
return std::make_shared<MockLayerSupport>();
@@ -707,30 +698,42 @@ TEST_CASE("BackendCapabilityTest")
TEST_CASE("BackendHintTest")
{
- class TestBackendAssignment : public LayerVisitorBase<VisitorNoThrowPolicy>
+ class TestBackendAssignment : public StrategyBase<NoThrowStrategy>
{
public:
- void VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name = nullptr) override
- {
- IgnoreUnused(id, name);
- auto inputLayer = PolymorphicDowncast<const InputLayer*>(layer);
- CHECK((inputLayer->GetBackendId() == "MockBackend"));
- }
-
- void VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name = nullptr) override
- {
- IgnoreUnused(id, name);
- auto outputLayer = PolymorphicDowncast<const OutputLayer*>(layer);
- CHECK((outputLayer->GetBackendId() == "MockBackend"));
- }
- void VisitActivationLayer(const IConnectableLayer* layer,
- const ActivationDescriptor& activationDescriptor,
- const char* name = nullptr) override
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- IgnoreUnused(activationDescriptor, name);
- auto activation = PolymorphicDowncast<const ActivationLayer*>(layer);
- CHECK((activation->GetBackendId() == "CustomBackend"));
+ armnn::IgnoreUnused(descriptor, constants, id, name);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Input:
+ {
+ auto inputLayer = PolymorphicDowncast<const InputLayer*>(layer);
+ CHECK((inputLayer->GetBackendId() == "MockBackend"));
+ break;
+ }
+ case armnn::LayerType::Output:
+ {
+ auto outputLayer = PolymorphicDowncast<const OutputLayer*>(layer);
+ CHECK((outputLayer->GetBackendId() == "MockBackend"));
+ break;
+ }
+ case armnn::LayerType::Activation:
+ {
+ auto activation = PolymorphicDowncast<const ActivationLayer*>(layer);
+ CHECK((activation->GetBackendId() == "CustomBackend"));
+ break;
+ }
+ default:
+ {
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
+ }
+ }
}
};
@@ -802,7 +805,7 @@ TEST_CASE("BackendHintTest")
TestBackendAssignment visitor;
for (auto it = firstLayer; it != lastLayer; ++it)
{
- (*it)->Accept(visitor);
+ (*it)->ExecuteStrategy(visitor);
}
// Clean up the registry for the next test.
backendRegistry.Deregister("MockBackend");
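
TestBackendAssignment now derives from StrategyBase<NoThrowStrategy> instead of LayerVisitorBase<VisitorNoThrowPolicy>; layer types it does not handle are routed to m_DefaultStrategy.Apply(), which under the no-throw policy does nothing. A sketch of the assumed shape of that machinery (inferred from the usage in this diff, not shown by it):

#include <string>

// Assumed sketch of the types behind StrategyBase<NoThrowStrategy>.
struct NoThrowStrategy
{
    void Apply(const std::string& /*layerName*/) {}   // silently skip unhandled layers
};

template <typename DefaultStrategy>
class StrategyBase : public armnn::IStrategy
{
protected:
    DefaultStrategy m_DefaultStrategy;   // invoked from each visitor's default: branch
};
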
diff --git a/src/armnn/test/TestInputOutputLayerVisitor.cpp b/src/armnn/test/TestInputOutputLayerVisitor.cpp
index 8462290f81..3b18e07694 100644
--- a/src/armnn/test/TestInputOutputLayerVisitor.cpp
+++ b/src/armnn/test/TestInputOutputLayerVisitor.cpp
@@ -19,7 +19,7 @@ TEST_CASE("CheckInputLayerVisitorBindingIdAndName")
NetworkImpl net;
IConnectableLayer *const layer = net.AddInputLayer(1, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckInputLayerVisitorBindingIdAndNameNull")
@@ -28,7 +28,7 @@ TEST_CASE("CheckInputLayerVisitorBindingIdAndNameNull")
NetworkImpl net;
IConnectableLayer *const layer = net.AddInputLayer(1);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckOutputLayerVisitorBindingIdAndName")
@@ -38,7 +38,7 @@ TEST_CASE("CheckOutputLayerVisitorBindingIdAndName")
NetworkImpl net;
IConnectableLayer *const layer = net.AddOutputLayer(1, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckOutputLayerVisitorBindingIdAndNameNull")
@@ -47,7 +47,7 @@ TEST_CASE("CheckOutputLayerVisitorBindingIdAndNameNull")
NetworkImpl net;
IConnectableLayer *const layer = net.AddOutputLayer(1);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
}
diff --git a/src/armnn/test/TestInputOutputLayerVisitor.hpp b/src/armnn/test/TestInputOutputLayerVisitor.hpp
index b89089530e..e812f2f97d 100644
--- a/src/armnn/test/TestInputOutputLayerVisitor.hpp
+++ b/src/armnn/test/TestInputOutputLayerVisitor.hpp
@@ -27,14 +27,28 @@ public:
, visitorId(id)
{};
- void VisitInputLayer(const IConnectableLayer* layer,
- LayerBindingId id,
- const char* name = nullptr) override
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- CheckLayerPointer(layer);
- CheckLayerBindingId(visitorId, id);
- CheckLayerName(name);
- };
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Input:
+ {
+ CheckLayerPointer(layer);
+ CheckLayerBindingId(visitorId, id);
+ CheckLayerName(name);
+ break;
+ }
+ default:
+ {
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
+ }
+ }
+ }
};
class TestOutputLayerVisitor : public TestLayerVisitor
@@ -48,14 +62,28 @@ public:
, visitorId(id)
{};
- void VisitOutputLayer(const IConnectableLayer* layer,
- LayerBindingId id,
- const char* name = nullptr) override
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
{
- CheckLayerPointer(layer);
- CheckLayerBindingId(visitorId, id);
- CheckLayerName(name);
- };
+ armnn::IgnoreUnused(descriptor, constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Output:
+ {
+ CheckLayerPointer(layer);
+ CheckLayerBindingId(visitorId, id);
+ CheckLayerName(name);
+ break;
+ }
+ default:
+ {
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
+ }
+ }
+ }
};
} //namespace armnn
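
Note the binding id now arrives through ExecuteStrategy's trailing id parameter rather than through a dedicated Visit argument, and the visitor compares it against the id captured at construction. Usage mirrors the .cpp tests above (values hypothetical):

TestInputLayerVisitor visitor(1, "input");
armnn::NetworkImpl net;
armnn::IConnectableLayer* const layer = net.AddInputLayer(1, "input");
layer->ExecuteStrategy(visitor);   // the layer supplies id == 1 to the Input case
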
diff --git a/src/armnn/test/TestLayerVisitor.cpp b/src/armnn/test/TestLayerVisitor.cpp
index ec405119d1..d5f705f0da 100644
--- a/src/armnn/test/TestLayerVisitor.cpp
+++ b/src/armnn/test/TestLayerVisitor.cpp
@@ -49,6 +49,62 @@ void TestLayerVisitor::CheckConstTensors(const ConstTensor& expected, const Cons
}
}
+void TestLayerVisitor::CheckConstTensors(const ConstTensor& expected, const ConstTensorHandle& actual)
+{
+ auto& actualInfo = actual.GetTensorInfo();
+ CHECK(expected.GetInfo() == actualInfo);
+ CHECK(expected.GetNumDimensions() == actualInfo.GetNumDimensions());
+ CHECK(expected.GetNumElements() == actualInfo.GetNumElements());
+ CHECK(expected.GetNumBytes() == actualInfo.GetNumBytes());
+ if (expected.GetNumBytes() == actualInfo.GetNumBytes())
+ {
+ //check data is the same byte by byte
+ const unsigned char* expectedPtr = static_cast<const unsigned char*>(expected.GetMemoryArea());
+ const unsigned char* actualPtr = static_cast<const unsigned char*>(actual.Map(true));
+ for (unsigned int i = 0; i < expected.GetNumBytes(); i++)
+ {
+ CHECK(*(expectedPtr + i) == *(actualPtr + i));
+ }
+ actual.Unmap();
+ }
+}
+
+void TestLayerVisitor::CheckConstTensorPtrs(const std::string& name,
+ const ConstTensor* expected,
+ const std::shared_ptr<ConstTensorHandle> actual)
+{
+ if (expected == nullptr)
+ {
+ CHECK_MESSAGE(actual == nullptr, name + " actual should have been a nullptr");
+ }
+ else
+ {
+ CHECK_MESSAGE(actual != nullptr, name + " actual should have been set");
+ if (actual != nullptr)
+ {
+ CheckConstTensors(*expected, *actual);
+ }
+ }
+}
+
+void TestLayerVisitor::CheckConstTensorPtrs(const std::string& name,
+ const ConstTensor* expected,
+ const ConstTensor* actual)
+{
+ if (expected == nullptr)
+ {
+ CHECK_MESSAGE(actual == nullptr, name + " actual should have been a nullptr");
+ }
+ else
+ {
+ CHECK_MESSAGE(actual != nullptr, name + " actual should have been set");
+ if (actual != nullptr)
+ {
+ CheckConstTensors(*expected, *actual);
+ }
+ }
+}
+
void TestLayerVisitor::CheckOptionalConstTensors(const Optional<ConstTensor>& expected,
const Optional<ConstTensor>& actual)
{
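
The new ConstTensorHandle overload compares payloads through the handle's mapped pointer: Map(true) requests a blocking read mapping and is paired with Unmap(). Inside that function the byte loop could equally be a single memcmp over the mapped region (an alternative sketch under the same assumptions, with <cstring> included):

// Equivalent to the byte-by-byte loop in CheckConstTensors above.
const auto* expectedPtr = static_cast<const unsigned char*>(expected.GetMemoryArea());
const auto* actualPtr   = static_cast<const unsigned char*>(actual.Map(true));
CHECK(std::memcmp(expectedPtr, actualPtr, expected.GetNumBytes()) == 0);
actual.Unmap();
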
diff --git a/src/armnn/test/TestLayerVisitor.hpp b/src/armnn/test/TestLayerVisitor.hpp
index e43227f520..eaf1667800 100644
--- a/src/armnn/test/TestLayerVisitor.hpp
+++ b/src/armnn/test/TestLayerVisitor.hpp
@@ -4,13 +4,14 @@
//
#pragma once
-#include <armnn/LayerVisitorBase.hpp>
+#include <armnn/StrategyBase.hpp>
#include <armnn/Descriptors.hpp>
+#include <backendsCommon/TensorHandle.hpp>
namespace armnn
{
-// Abstract base class with do nothing implementations for all layer visit methods
-class TestLayerVisitor : public LayerVisitorBase<VisitorNoThrowPolicy>
+// Abstract base class with do-nothing implementations for all layer types
+class TestLayerVisitor : public StrategyBase<NoThrowStrategy>
{
protected:
virtual ~TestLayerVisitor() {}
@@ -19,7 +20,17 @@ protected:
void CheckLayerPointer(const IConnectableLayer* layer);
- void CheckConstTensors(const ConstTensor& expected, const ConstTensor& actual);
+ void CheckConstTensors(const ConstTensor& expected,
+ const ConstTensor& actual);
+ void CheckConstTensors(const ConstTensor& expected,
+ const ConstTensorHandle& actual);
+
+ void CheckConstTensorPtrs(const std::string& name,
+ const ConstTensor* expected,
+ const ConstTensor* actual);
+ void CheckConstTensorPtrs(const std::string& name,
+ const ConstTensor* expected,
+ const std::shared_ptr<ConstTensorHandle> actual);
void CheckOptionalConstTensors(const Optional<ConstTensor>& expected, const Optional<ConstTensor>& actual);
diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
index 39c00f4604..cfdaaf529b 100644
--- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
+++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
@@ -20,7 +20,7 @@ TEST_CASE(#testName) \
Test##name##LayerVisitor visitor(descriptor, layerName); \
armnn::NetworkImpl net; \
armnn::IConnectableLayer *const layer = net.Add##name##Layer(descriptor, layerName); \
- layer->Accept(visitor); \
+ layer->ExecuteStrategy(visitor); \
}
#define TEST_CASE_CHECK_LAYER_VISITOR_NAME_NULLPTR_AND_DESCRIPTOR(name, testName) \
@@ -30,7 +30,7 @@ TEST_CASE(#testName) \
Test##name##LayerVisitor visitor(descriptor); \
armnn::NetworkImpl net; \
armnn::IConnectableLayer *const layer = net.Add##name##Layer(descriptor); \
- layer->Accept(visitor); \
+ layer->ExecuteStrategy(visitor); \
}
template<typename Descriptor> Descriptor GetDescriptor();
diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
index a3c1420388..b1f9512655 100644
--- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
+++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
@@ -29,15 +29,31 @@ public: \
: armnn::TestLayerVisitor(layerName) \
, m_Descriptor(descriptor) {}; \
\
- void Visit##name##Layer(const armnn::IConnectableLayer* layer, \
- const Descriptor& descriptor, \
- const char* layerName = nullptr) override \
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer, \
+ const armnn::BaseDescriptor& descriptor, \
+ const std::vector<armnn::ConstTensor>& constants, \
+ const char* layerName, \
+ const armnn::LayerBindingId id = 0) override \
{ \
- CheckLayerPointer(layer); \
- CheckDescriptor(descriptor); \
- CheckLayerName(layerName); \
+ armnn::IgnoreUnused(descriptor, constants, id); \
+ switch (layer->GetType()) \
+ { \
+ case armnn::LayerType::Input: break; \
+ case armnn::LayerType::Output: break; \
+            case armnn::LayerType::name: \
+ { \
+ CheckLayerPointer(layer); \
+ CheckDescriptor(static_cast<const Descriptor&>(descriptor)); \
+ CheckLayerName(layerName); \
+ break; \
+ } \
+ default: \
+ { \
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType())); \
+ } \
+ } \
} \
-};
+}; \
} // anonymous namespace
diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.cpp b/src/armnn/test/TestNameOnlyLayerVisitor.cpp
index 00d65f8e76..497c36b079 100644
--- a/src/armnn/test/TestNameOnlyLayerVisitor.cpp
+++ b/src/armnn/test/TestNameOnlyLayerVisitor.cpp
@@ -18,7 +18,7 @@ TEST_CASE(#testName) \
Test##name##LayerVisitor visitor("name##Layer"); \
armnn::NetworkImpl net; \
armnn::IConnectableLayer *const layer = net.Add##name##Layer("name##Layer"); \
- layer->Accept(visitor); \
+ layer->ExecuteStrategy(visitor); \
}
#define TEST_CASE_CHECK_LAYER_VISITOR_NAME_NULLPTR(name, testName) \
@@ -27,7 +27,7 @@ TEST_CASE(#testName) \
Test##name##LayerVisitor visitor; \
armnn::NetworkImpl net; \
armnn::IConnectableLayer *const layer = net.Add##name##Layer(); \
- layer->Accept(visitor); \
+ layer->ExecuteStrategy(visitor); \
}
} // anonymous namespace
diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.hpp b/src/armnn/test/TestNameOnlyLayerVisitor.hpp
index 519cbbacc6..c0db857b71 100644
--- a/src/armnn/test/TestNameOnlyLayerVisitor.hpp
+++ b/src/armnn/test/TestNameOnlyLayerVisitor.hpp
@@ -15,12 +15,28 @@ class Test##name##LayerVisitor : public armnn::TestLayerVisitor \
public: \
explicit Test##name##LayerVisitor(const char* layerName = nullptr) : armnn::TestLayerVisitor(layerName) {}; \
\
- void Visit##name##Layer(const armnn::IConnectableLayer* layer, \
- const char* layerName = nullptr) override \
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer, \
+ const armnn::BaseDescriptor& descriptor, \
+ const std::vector<armnn::ConstTensor>& constants, \
+ const char* layerName, \
+ const armnn::LayerBindingId id = 0) override \
{ \
- CheckLayerPointer(layer); \
- CheckLayerName(layerName); \
+ armnn::IgnoreUnused(descriptor, constants, id); \
+ switch (layer->GetType()) \
+ { \
+ case armnn::LayerType::name: \
+ { \
+ CheckLayerPointer(layer); \
+ CheckLayerName(layerName); \
+ break; \
+ } \
+ default: \
+ { \
+ m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType())); \
+ } \
+ } \
} \
+ \
};
} // anonymous namespace
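
Expanded for one concrete layer, the macro above yields an ordinary strategy class (hypothetical expansion for name = Addition, whitespace tidied):

class TestAdditionLayerVisitor : public armnn::TestLayerVisitor
{
public:
    explicit TestAdditionLayerVisitor(const char* layerName = nullptr)
        : armnn::TestLayerVisitor(layerName) {};

    void ExecuteStrategy(const armnn::IConnectableLayer* layer,
                         const armnn::BaseDescriptor& descriptor,
                         const std::vector<armnn::ConstTensor>& constants,
                         const char* layerName,
                         const armnn::LayerBindingId id = 0) override
    {
        armnn::IgnoreUnused(descriptor, constants, id);
        switch (layer->GetType())
        {
            case armnn::LayerType::Addition:
            {
                CheckLayerPointer(layer);
                CheckLayerName(layerName);
                break;
            }
            default:
            {
                m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
            }
        }
    }
};
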