aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/NetworkTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/NetworkTests.cpp')
-rw-r--r--src/armnn/test/NetworkTests.cpp36
1 files changed, 31 insertions, 5 deletions
diff --git a/src/armnn/test/NetworkTests.cpp b/src/armnn/test/NetworkTests.cpp
index d763a85100..9acb60df4a 100644
--- a/src/armnn/test/NetworkTests.cpp
+++ b/src/armnn/test/NetworkTests.cpp
@@ -86,12 +86,15 @@ TEST_CASE("NetworkModification")
inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0));
armnn::FullyConnectedDescriptor fullyConnectedDesc;
+
+ // Constant layer that now holds weights data for FullyConnected
+ armnn::IConnectableLayer* const constantWeightsLayer = net.AddConstantLayer(weights, "const weights");
armnn::IConnectableLayer* const fullyConnectedLayer = net.AddFullyConnectedLayer(fullyConnectedDesc,
- weights,
- armnn::EmptyOptional(),
"fully connected");
+ CHECK(constantWeightsLayer);
CHECK(fullyConnectedLayer);
+ constantWeightsLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(1));
convLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(0));
armnn::Pooling2dDescriptor pooling2dDesc;
@@ -152,11 +155,12 @@ TEST_CASE("NetworkModification")
multiplicationLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
//Tests that all layers are present in the graph.
- CHECK(net.GetGraph().GetNumLayers() == 11);
+ CHECK(net.GetGraph().GetNumLayers() == 12);
//Tests that the vertices exist and have correct names.
CHECK(GraphHasNamedLayer(net.GetGraph(), "input layer"));
CHECK(GraphHasNamedLayer(net.GetGraph(), "conv layer"));
+ CHECK(GraphHasNamedLayer(net.GetGraph(), "const weights"));
CHECK(GraphHasNamedLayer(net.GetGraph(), "fully connected"));
CHECK(GraphHasNamedLayer(net.GetGraph(), "pooling2d"));
CHECK(GraphHasNamedLayer(net.GetGraph(), "activation"));
@@ -200,6 +204,28 @@ TEST_CASE("NetworkModification")
CHECK(&srcLayer->GetOutputSlot(0) == tgtLayer->GetInputSlot(i).GetConnection());
}
};
+ auto checkOneOutputToTwoInputConnectionForTwoDifferentLayers = []
+ (const armnn::IConnectableLayer* const srcLayer1,
+ const armnn::IConnectableLayer* const srcLayer2,
+ const armnn::IConnectableLayer* const tgtLayer,
+ int expectedSrcNumInputs1 = 1,
+ int expectedSrcNumInputs2 = 1,
+ int expectedDstNumOutputs = 1)
+ {
+ CHECK(srcLayer1->GetNumInputSlots() == expectedSrcNumInputs1);
+ CHECK(srcLayer1->GetNumOutputSlots() == 1);
+ CHECK(srcLayer2->GetNumInputSlots() == expectedSrcNumInputs2);
+ CHECK(srcLayer2->GetNumOutputSlots() == 1);
+ CHECK(tgtLayer->GetNumInputSlots() == 2);
+ CHECK(tgtLayer->GetNumOutputSlots() == expectedDstNumOutputs);
+
+ CHECK(srcLayer1->GetOutputSlot(0).GetNumConnections() == 1);
+ CHECK(srcLayer2->GetOutputSlot(0).GetNumConnections() == 1);
+ CHECK(srcLayer1->GetOutputSlot(0).GetConnection(0) == &tgtLayer->GetInputSlot(0));
+ CHECK(srcLayer2->GetOutputSlot(0).GetConnection(0) == &tgtLayer->GetInputSlot(1));
+ CHECK(&srcLayer1->GetOutputSlot(0) == tgtLayer->GetInputSlot(0).GetConnection());
+ CHECK(&srcLayer2->GetOutputSlot(0) == tgtLayer->GetInputSlot(1).GetConnection());
+ };
CHECK(AreAllLayerInputSlotsConnected(*convLayer));
CHECK(AreAllLayerInputSlotsConnected(*fullyConnectedLayer));
@@ -214,8 +240,8 @@ TEST_CASE("NetworkModification")
// Checks connectivity.
checkOneOutputToOneInputConnection(inputLayer, convLayer, 0);
- checkOneOutputToOneInputConnection(convLayer, fullyConnectedLayer);
- checkOneOutputToOneInputConnection(fullyConnectedLayer, poolingLayer);
+ checkOneOutputToTwoInputConnectionForTwoDifferentLayers(convLayer, constantWeightsLayer, fullyConnectedLayer, 1, 0);
+ checkOneOutputToOneInputConnection(fullyConnectedLayer, poolingLayer, 2, 1);
checkOneOutputToOneInputConnection(poolingLayer, activationLayer);
checkOneOutputToOneInputConnection(activationLayer, normalizationLayer);
checkOneOutputToOneInputConnection(normalizationLayer, softmaxLayer);