aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/SubgraphViewTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/SubgraphViewTests.cpp')
-rw-r--r--src/armnn/test/SubgraphViewTests.cpp93
1 files changed, 93 insertions, 0 deletions
diff --git a/src/armnn/test/SubgraphViewTests.cpp b/src/armnn/test/SubgraphViewTests.cpp
index feeea5d478..e1181004d9 100644
--- a/src/armnn/test/SubgraphViewTests.cpp
+++ b/src/armnn/test/SubgraphViewTests.cpp
@@ -2291,4 +2291,97 @@ TEST_CASE("SubgraphViewWorkingCopyReplaceSlots")
);
}
+TEST_CASE("SubgraphViewWorkingCopyCloneInputAndOutputSlots")
+{
+ Graph graph;
+
+ const TensorInfo inputInfo({ 1, 8, 8, 16 }, DataType::QAsymmU8, 1.0f, 0);
+ const TensorInfo constInfo({ 1, 1, 1, 16 }, DataType::QAsymmU8, 0.9f, 0, true);
+ const TensorInfo outputInfo({ 1, 8, 8, 16 }, DataType::QAsymmU8, 1.0f, 0);
+
+ std::vector<uint8_t> constData(constInfo.GetNumElements(), 0);
+ std::iota(constData.begin(), constData.end(), 0);
+ ConstTensor constTensor(constInfo, constData);
+
+ // Add the original pattern
+ IConnectableLayer* input = graph.AddLayer<InputLayer>(0, "input");
+ auto constant = graph.AddLayer<ConstantLayer>("const");
+
+ constant->m_LayerOutput = std::make_shared<ScopedTensorHandle>(constTensor);
+ IConnectableLayer* mul = graph.AddLayer<MultiplicationLayer>("mul");
+ armnn::ViewsDescriptor splitterDesc(2,4);
+ IConnectableLayer* split = graph.AddLayer<SplitterLayer>(splitterDesc, "split");
+ IConnectableLayer* abs = graph.AddLayer<ActivationLayer>(ActivationFunction::Abs, "abs");
+ IConnectableLayer* relu = graph.AddLayer<ActivationLayer>(ActivationFunction::ReLu, "relu");
+ armnn::OriginsDescriptor concatDesc(2, 4);
+ IConnectableLayer* concat = graph.AddLayer<ConcatLayer>(concatDesc, "constant");
+ IConnectableLayer* output = graph.AddLayer<OutputLayer>(0, "output");
+
+ // Create connections between layers
+ input->GetOutputSlot(0).SetTensorInfo(inputInfo);
+ constant->GetOutputSlot(0).SetTensorInfo(constInfo);
+ mul->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ input->GetOutputSlot(0).Connect(mul->GetInputSlot(1));
+ constant->GetOutputSlot(0).Connect(mul->GetInputSlot(0));
+ mul->GetOutputSlot(0).Connect(split->GetInputSlot(0));
+ split->GetOutputSlot(0).Connect(abs->GetInputSlot(0));
+ split->GetOutputSlot(1).Connect(relu->GetInputSlot(0));
+ abs->GetOutputSlot(0).Connect(concat->GetInputSlot(0));
+ relu->GetOutputSlot(0).Connect(concat->GetInputSlot(1));
+ concat->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+ // constant input //
+ // \ / //
+ // mul //
+ // | //
+ // splitter //
+ // / \ //
+ // abs relu //
+ // \ / //
+ // concat //
+ // | //
+ // output //
+ // //
+ // SubgraphView layers: constant mul splitter abs
+
+ // Add just the InputSlot connected to the InputLayer to the SubgraphView's InputSlots
+ SubgraphView::IInputSlots inputSlots;
+ inputSlots.push_back(&mul->GetInputSlot(1));
+
+ // Add just the OutputSlot connected to the splitter and abs to the SubgraphView's InputSlots
+ SubgraphView::IOutputSlots outputSlots;
+ outputSlots.push_back(&split->GetOutputSlot(1));
+ outputSlots.push_back(&abs->GetOutputSlot(0));
+
+ //Add in out of order
+ auto view = CreateSubgraphViewFrom({constant, mul, split, abs},
+ std::move(inputSlots),
+ std::move(outputSlots));
+
+ SubgraphView workingCopy = view->GetWorkingCopy();
+
+ // Check that only 1 input slot is added.
+ CHECK(workingCopy.GetIInputSlots().size() == 1);
+ CHECK(workingCopy.GetIInputSlots()[0]->GetSlotIndex() == 1);
+
+ CHECK(workingCopy.GetIOutputSlots().size() == 2);
+ CHECK(workingCopy.GetIOutputSlots()[0]->GetOwningIConnectableLayer().GetType() == armnn::LayerType::Splitter);
+ CHECK(workingCopy.GetIOutputSlots()[1]->GetOwningIConnectableLayer().GetType() == armnn::LayerType::Activation);
+
+ // Check the WorkingCopy is as expected before replacement
+ CHECK(workingCopy.GetIConnectableLayers().size() == 4);
+ int idx=0;
+ LayerType expectedSorted[] = {LayerType::Constant,
+ LayerType::Multiplication,
+ LayerType::Splitter,
+ LayerType::Activation};
+ workingCopy.ForEachIConnectableLayer([&idx, &expectedSorted](const IConnectableLayer* l)
+ {
+ CHECK((expectedSorted[idx] == l->GetType()));
+ idx++;
+ }
+ );
+}
+
}