path: root/src/armnn/test/ConstTensorLayerVisitor.cpp
Diffstat (limited to 'src/armnn/test/ConstTensorLayerVisitor.cpp')
-rw-r--r--  src/armnn/test/ConstTensorLayerVisitor.cpp  236
1 file changed, 40 insertions(+), 196 deletions(-)
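
The diff below makes two mechanical changes. First, it deletes three byte-for-byte identical CheckConstTensorPtrs implementations (one each in TestLstmLayerVisitor, TestQLstmLayerVisitor and TestQuantizedLstmLayerVisitor). A single shared helper along the following lines would replace them; the body is exactly the removed code, but its actual destination is not visible in this file's diff, so the enclosing base-class name here is an assumption:

// Sketch of the deduplicated helper. The class name TestLayerVisitorBase is
// hypothetical; this diff only shows the three per-class copies being removed.
void TestLayerVisitorBase::CheckConstTensorPtrs(const std::string& name,
                                                const ConstTensor* expected,
                                                const ConstTensor* actual)
{
    if (expected == nullptr)
    {
        // If no tensor is expected, none must have been captured.
        CHECK_MESSAGE(actual == nullptr, name + " actual should have been a nullptr");
    }
    else
    {
        CHECK_MESSAGE(actual != nullptr, name + " actual should have been set");
        if (actual != nullptr)
        {
            CheckConstTensors(*expected, *actual);
        }
    }
}

Second, every test case swaps the visitor entry point layer->Accept(visitor) for layer->ExecuteStrategy(visitor); a sketch of what a strategy-based test visitor looks like follows the diff.
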
diff --git a/src/armnn/test/ConstTensorLayerVisitor.cpp b/src/armnn/test/ConstTensorLayerVisitor.cpp
index d3d8698972..e21e777409 100644
--- a/src/armnn/test/ConstTensorLayerVisitor.cpp
+++ b/src/armnn/test/ConstTensorLayerVisitor.cpp
@@ -58,73 +58,6 @@ void TestLstmLayerVisitor::CheckDescriptor(const LstmDescriptor& descriptor)
CHECK(m_Descriptor.m_ProjectionEnabled == descriptor.m_ProjectionEnabled);
}
-void TestLstmLayerVisitor::CheckConstTensorPtrs(const std::string& name,
- const ConstTensor* expected,
- const ConstTensor* actual)
-{
- if (expected == nullptr)
- {
- CHECK_MESSAGE(actual == nullptr, name + " actual should have been a nullptr");
- }
- else
- {
- CHECK_MESSAGE(actual != nullptr, name + " actual should have been set");
- if (actual != nullptr)
- {
- CheckConstTensors(*expected, *actual);
- }
- }
-}
-
-void TestLstmLayerVisitor::CheckInputParameters(const LstmInputParams& inputParams)
-{
- CheckConstTensorPtrs("ProjectionBias", m_InputParams.m_ProjectionBias, inputParams.m_ProjectionBias);
- CheckConstTensorPtrs("ProjectionWeights", m_InputParams.m_ProjectionWeights, inputParams.m_ProjectionWeights);
- CheckConstTensorPtrs("OutputGateBias", m_InputParams.m_OutputGateBias, inputParams.m_OutputGateBias);
- CheckConstTensorPtrs("InputToInputWeights",
- m_InputParams.m_InputToInputWeights, inputParams.m_InputToInputWeights);
- CheckConstTensorPtrs("InputToForgetWeights",
- m_InputParams.m_InputToForgetWeights, inputParams.m_InputToForgetWeights);
- CheckConstTensorPtrs("InputToCellWeights", m_InputParams.m_InputToCellWeights, inputParams.m_InputToCellWeights);
- CheckConstTensorPtrs(
- "InputToOutputWeights", m_InputParams.m_InputToOutputWeights, inputParams.m_InputToOutputWeights);
- CheckConstTensorPtrs(
- "RecurrentToInputWeights", m_InputParams.m_RecurrentToInputWeights, inputParams.m_RecurrentToInputWeights);
- CheckConstTensorPtrs(
- "RecurrentToForgetWeights", m_InputParams.m_RecurrentToForgetWeights, inputParams.m_RecurrentToForgetWeights);
- CheckConstTensorPtrs(
- "RecurrentToCellWeights", m_InputParams.m_RecurrentToCellWeights, inputParams.m_RecurrentToCellWeights);
- CheckConstTensorPtrs(
- "RecurrentToOutputWeights", m_InputParams.m_RecurrentToOutputWeights, inputParams.m_RecurrentToOutputWeights);
- CheckConstTensorPtrs(
- "CellToInputWeights", m_InputParams.m_CellToInputWeights, inputParams.m_CellToInputWeights);
- CheckConstTensorPtrs(
- "CellToForgetWeights", m_InputParams.m_CellToForgetWeights, inputParams.m_CellToForgetWeights);
- CheckConstTensorPtrs(
- "CellToOutputWeights", m_InputParams.m_CellToOutputWeights, inputParams.m_CellToOutputWeights);
- CheckConstTensorPtrs("InputGateBias", m_InputParams.m_InputGateBias, inputParams.m_InputGateBias);
- CheckConstTensorPtrs("ForgetGateBias", m_InputParams.m_ForgetGateBias, inputParams.m_ForgetGateBias);
- CheckConstTensorPtrs("CellBias", m_InputParams.m_CellBias, inputParams.m_CellBias);
-}
-
-void TestQLstmLayerVisitor::CheckConstTensorPtrs(const std::string& name,
- const ConstTensor* expected,
- const ConstTensor* actual)
-{
- if (expected == nullptr)
- {
- CHECK_MESSAGE(actual == nullptr, name + " actual should have been a nullptr");
- }
- else
- {
- CHECK_MESSAGE(actual != nullptr, name + " actual should have been set");
- if (actual != nullptr)
- {
- CheckConstTensors(*expected, *actual);
- }
- }
-}
-
void TestQLstmLayerVisitor::CheckDescriptor(const QLstmDescriptor& descriptor)
{
CHECK(m_Descriptor.m_CellClip == descriptor.m_CellClip);
@@ -134,95 +67,6 @@ void TestQLstmLayerVisitor::CheckDescriptor(const QLstmDescriptor& descriptor)
CHECK(m_Descriptor.m_ProjectionEnabled == descriptor.m_ProjectionEnabled);
}
-void TestQLstmLayerVisitor::CheckInputParameters(const LstmInputParams& inputParams)
-{
- CheckConstTensorPtrs("InputToInputWeights",
- m_InputParams.m_InputToInputWeights,
- inputParams.m_InputToInputWeights);
-
- CheckConstTensorPtrs("InputToForgetWeights",
- m_InputParams.m_InputToForgetWeights,
- inputParams.m_InputToForgetWeights);
-
- CheckConstTensorPtrs("InputToCellWeights",
- m_InputParams.m_InputToCellWeights,
- inputParams.m_InputToCellWeights);
-
- CheckConstTensorPtrs("InputToOutputWeights",
- m_InputParams.m_InputToOutputWeights,
- inputParams.m_InputToOutputWeights);
-
- CheckConstTensorPtrs("RecurrentToInputWeights",
- m_InputParams.m_RecurrentToInputWeights,
- inputParams.m_RecurrentToInputWeights);
-
- CheckConstTensorPtrs("RecurrentToForgetWeights",
- m_InputParams.m_RecurrentToForgetWeights,
- inputParams.m_RecurrentToForgetWeights);
-
- CheckConstTensorPtrs("RecurrentToCellWeights",
- m_InputParams.m_RecurrentToCellWeights,
- inputParams.m_RecurrentToCellWeights);
-
- CheckConstTensorPtrs("RecurrentToOutputWeights",
- m_InputParams.m_RecurrentToOutputWeights,
- inputParams.m_RecurrentToOutputWeights);
-
- CheckConstTensorPtrs("CellToInputWeights",
- m_InputParams.m_CellToInputWeights,
- inputParams.m_CellToInputWeights);
-
- CheckConstTensorPtrs("CellToForgetWeights",
- m_InputParams.m_CellToForgetWeights,
- inputParams.m_CellToForgetWeights);
-
- CheckConstTensorPtrs("CellToOutputWeights",
- m_InputParams.m_CellToOutputWeights,
- inputParams.m_CellToOutputWeights);
-
- CheckConstTensorPtrs("ProjectionWeights", m_InputParams.m_ProjectionWeights, inputParams.m_ProjectionWeights);
- CheckConstTensorPtrs("ProjectionBias", m_InputParams.m_ProjectionBias, inputParams.m_ProjectionBias);
-
- CheckConstTensorPtrs("InputGateBias", m_InputParams.m_InputGateBias, inputParams.m_InputGateBias);
- CheckConstTensorPtrs("ForgetGateBias", m_InputParams.m_ForgetGateBias, inputParams.m_ForgetGateBias);
- CheckConstTensorPtrs("CellBias", m_InputParams.m_CellBias, inputParams.m_CellBias);
- CheckConstTensorPtrs("OutputGateBias", m_InputParams.m_OutputGateBias, inputParams.m_OutputGateBias);
-
- CheckConstTensorPtrs("InputLayerNormWeights",
- m_InputParams.m_InputLayerNormWeights,
- inputParams.m_InputLayerNormWeights);
-
- CheckConstTensorPtrs("ForgetLayerNormWeights",
- m_InputParams.m_ForgetLayerNormWeights,
- inputParams.m_ForgetLayerNormWeights);
-
- CheckConstTensorPtrs("CellLayerNormWeights",
- m_InputParams.m_CellLayerNormWeights,
- inputParams.m_CellLayerNormWeights);
-
- CheckConstTensorPtrs("OutputLayerNormWeights",
- m_InputParams.m_OutputLayerNormWeights,
- inputParams.m_OutputLayerNormWeights);
-}
-
-void TestQuantizedLstmLayerVisitor::CheckConstTensorPtrs(const std::string& name,
- const ConstTensor* expected,
- const ConstTensor* actual)
-{
- if (expected == nullptr)
- {
- CHECK_MESSAGE(actual == nullptr, name + " actual should have been a nullptr");
- }
- else
- {
- CHECK_MESSAGE(actual != nullptr, name + " actual should have been set");
- if (actual != nullptr)
- {
- CheckConstTensors(*expected, *actual);
- }
- }
-}
-
void TestQuantizedLstmLayerVisitor::CheckInputParameters(const QuantizedLstmInputParams& inputParams)
{
CheckConstTensorPtrs("InputToInputWeights",
@@ -285,7 +129,7 @@ TEST_CASE("CheckConvolution2dLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddConvolution2dLayer(descriptor, weights, EmptyOptional());
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedConvolution2dLayer")
@@ -309,7 +153,7 @@ TEST_CASE("CheckNamedConvolution2dLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddConvolution2dLayer(descriptor, weights, EmptyOptional(), layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckConvolution2dLayerWithBiases")
@@ -338,7 +182,7 @@ TEST_CASE("CheckConvolution2dLayerWithBiases")
NetworkImpl net;
IConnectableLayer* const layer = net.AddConvolution2dLayer(descriptor, weights, optionalBiases);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedConvolution2dLayerWithBiases")
@@ -368,7 +212,7 @@ TEST_CASE("CheckNamedConvolution2dLayerWithBiases")
NetworkImpl net;
IConnectableLayer* const layer = net.AddConvolution2dLayer(descriptor, weights, optionalBiases, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckDepthwiseConvolution2dLayer")
@@ -391,7 +235,7 @@ TEST_CASE("CheckDepthwiseConvolution2dLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddDepthwiseConvolution2dLayer(descriptor, weights, EmptyOptional());
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedDepthwiseConvolution2dLayer")
@@ -418,7 +262,7 @@ TEST_CASE("CheckNamedDepthwiseConvolution2dLayer")
weights,
EmptyOptional(),
layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckDepthwiseConvolution2dLayerWithBiases")
@@ -447,7 +291,7 @@ TEST_CASE("CheckDepthwiseConvolution2dLayerWithBiases")
NetworkImpl net;
IConnectableLayer* const layer = net.AddDepthwiseConvolution2dLayer(descriptor, weights, optionalBiases);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedDepthwiseConvolution2dLayerWithBiases")
@@ -477,7 +321,7 @@ TEST_CASE("CheckNamedDepthwiseConvolution2dLayerWithBiases")
NetworkImpl net;
IConnectableLayer* const layer = net.AddDepthwiseConvolution2dLayer(descriptor, weights, optionalBiases, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckFullyConnectedLayer")
@@ -500,8 +344,8 @@ TEST_CASE("CheckFullyConnectedLayer")
IConnectableLayer* const layer = net.AddFullyConnectedLayer(descriptor);
weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1));
- weightsLayer->Accept(weightsVisitor);
- layer->Accept(visitor);
+ weightsLayer->ExecuteStrategy(weightsVisitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedFullyConnectedLayer")
@@ -525,8 +369,8 @@ TEST_CASE("CheckNamedFullyConnectedLayer")
IConnectableLayer* const layer = net.AddFullyConnectedLayer(descriptor, layerName);
weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1));
- weightsLayer->Accept(weightsVisitor);
- layer->Accept(visitor);
+ weightsLayer->ExecuteStrategy(weightsVisitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckFullyConnectedLayerWithBiases")
@@ -556,9 +400,9 @@ TEST_CASE("CheckFullyConnectedLayerWithBiases")
weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1));
biasesLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2));
- weightsLayer->Accept(weightsVisitor);
- biasesLayer->Accept(biasesVisitor);
- layer->Accept(visitor);
+ weightsLayer->ExecuteStrategy(weightsVisitor);
+ biasesLayer->ExecuteStrategy(biasesVisitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedFullyConnectedLayerWithBiases")
@@ -589,9 +433,9 @@ TEST_CASE("CheckNamedFullyConnectedLayerWithBiases")
weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1));
biasesLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2));
- weightsLayer->Accept(weightsVisitor);
- biasesLayer->Accept(biasesVisitor);
- layer->Accept(visitor);
+ weightsLayer->ExecuteStrategy(weightsVisitor);
+ biasesLayer->ExecuteStrategy(biasesVisitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckBatchNormalizationLayer")
@@ -621,7 +465,7 @@ TEST_CASE("CheckBatchNormalizationLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddBatchNormalizationLayer(descriptor, mean, variance, beta, gamma);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedBatchNormalizationLayer")
@@ -653,7 +497,7 @@ TEST_CASE("CheckNamedBatchNormalizationLayer")
IConnectableLayer* const layer = net.AddBatchNormalizationLayer(
descriptor, mean, variance, beta, gamma, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckConstLayer")
@@ -667,7 +511,7 @@ TEST_CASE("CheckConstLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddConstantLayer(input);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedConstLayer")
@@ -682,7 +526,7 @@ TEST_CASE("CheckNamedConstLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddConstantLayer(input, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckLstmLayerBasic")
@@ -754,7 +598,7 @@ TEST_CASE("CheckLstmLayerBasic")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedLstmLayerBasic")
@@ -827,7 +671,7 @@ TEST_CASE("CheckNamedLstmLayerBasic")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckLstmLayerCifgDisabled")
@@ -918,7 +762,7 @@ TEST_CASE("CheckLstmLayerCifgDisabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedLstmLayerCifgDisabled")
@@ -1010,7 +854,7 @@ TEST_CASE("CheckNamedLstmLayerCifgDisabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
// TODO add one with peephole
@@ -1097,7 +941,7 @@ TEST_CASE("CheckLstmLayerPeephole")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckLstmLayerPeepholeCifgDisabled")
@@ -1211,7 +1055,7 @@ TEST_CASE("CheckLstmLayerPeepholeCifgDisabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedLstmLayerPeephole")
@@ -1298,7 +1142,7 @@ TEST_CASE("CheckNamedLstmLayerPeephole")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
// TODO add one with projection
@@ -1385,7 +1229,7 @@ TEST_CASE("CheckLstmLayerProjection")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedLstmLayerProjection")
@@ -1472,7 +1316,7 @@ TEST_CASE("CheckNamedLstmLayerProjection")
NetworkImpl net;
IConnectableLayer* const layer = net.AddLstmLayer(descriptor, params, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckQLstmLayerBasic")
@@ -1544,7 +1388,7 @@ TEST_CASE("CheckQLstmLayerBasic")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedQLstmLayerBasic")
@@ -1617,7 +1461,7 @@ TEST_CASE("CheckNamedQLstmLayerBasic")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQLstmLayer(descriptor, params, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckQLstmLayerCifgDisabled")
@@ -1712,7 +1556,7 @@ TEST_CASE("CheckQLstmLayerCifgDisabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckQLstmLayerCifgDisabledPeepholeEnabled")
@@ -1829,7 +1673,7 @@ TEST_CASE("CheckQLstmLayerCifgDisabledPeepholeEnabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckQLstmLayerCifgEnabledPeepholeEnabled")
@@ -1919,7 +1763,7 @@ TEST_CASE("CheckQLstmLayerCifgEnabledPeepholeEnabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckQLstmLayerProjectionEnabled")
@@ -2009,7 +1853,7 @@ TEST_CASE("CheckQLstmLayerProjectionEnabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckQLstmLayerCifgDisabledLayerNormEnabled")
@@ -2132,7 +1976,7 @@ TEST_CASE("CheckQLstmLayerCifgDisabledLayerNormEnabled")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQLstmLayer(descriptor, params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
@@ -2222,7 +2066,7 @@ TEST_CASE("CheckQuantizedLstmLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQuantizedLstmLayer(params);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
TEST_CASE("CheckNamedQuantizedLstmLayer")
@@ -2312,7 +2156,7 @@ TEST_CASE("CheckNamedQuantizedLstmLayer")
NetworkImpl net;
IConnectableLayer* const layer = net.AddQuantizedLstmLayer(params, layerName);
- layer->Accept(visitor);
+ layer->ExecuteStrategy(visitor);
}
}
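
All of the Accept-to-ExecuteStrategy replacements above follow the same pattern: the tests now drive layers through armnn::IStrategy instead of the per-layer visitor interface. Below is a minimal sketch of a strategy-based test visitor. The ExecuteStrategy signature follows armnn::IStrategy (include/armnn/IStrategy.hpp); the class name and the specific checks are illustrative, not classes from this file:

// Hypothetical strategy-based test visitor for a Constant layer.
// CheckConstTensors is the comparison helper already used in this file.
struct TestConstantLayerVisitor : public armnn::IStrategy
{
    explicit TestConstantLayerVisitor(const armnn::ConstTensor& expected)
        : m_Expected(expected) {}

    void ExecuteStrategy(const armnn::IConnectableLayer* layer,
                         const armnn::BaseDescriptor& descriptor,
                         const std::vector<armnn::ConstTensor>& constants,
                         const char* name,
                         const armnn::LayerBindingId id = 0) override
    {
        armnn::IgnoreUnused(descriptor, name, id);
        // The strategy receives the layer's constants directly, so the test
        // checks the captured tensor rather than overriding Visit*Layer hooks.
        CHECK(layer->GetType() == armnn::LayerType::Constant);
        CHECK(constants.size() == 1);
        CheckConstTensors(m_Expected, constants[0]);
    }

    armnn::ConstTensor m_Expected;
};

Used exactly like the visitors in the tests above: construct it with the expected tensor, then call layer->ExecuteStrategy(visitor).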