aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/SubgraphViewTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/SubgraphViewTests.cpp')
-rw-r--r--src/armnn/test/SubgraphViewTests.cpp29
1 files changed, 29 insertions, 0 deletions
diff --git a/src/armnn/test/SubgraphViewTests.cpp b/src/armnn/test/SubgraphViewTests.cpp
index 4ce67b0fec..9bb5e69bbb 100644
--- a/src/armnn/test/SubgraphViewTests.cpp
+++ b/src/armnn/test/SubgraphViewTests.cpp
@@ -2063,6 +2063,35 @@ TEST_CASE("SubgraphViewWorkingCopySubstituteSubgraph")
CHECK_THROWS_AS(workingCopy.GetWorkingCopy(), Exception);
}
+TEST_CASE("SubgraphViewPartialWorkingCopySubstituteSubgraph")
+{
+ Graph graph;
+
+ auto input = graph.AddLayer<InputLayer>(0, "Input");
+ auto activation = graph.AddLayer<ActivationLayer>(ActivationDescriptor{}, "Activation");
+ auto output = graph.AddLayer<OutputLayer>(1, "Output");
+
+ input->GetOutputSlot(0).Connect(activation->GetInputSlot(0));
+ activation->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+ //Add in out of order
+ auto view = CreateSubgraphViewFrom({activation},
+ {&activation->GetInputSlot(0)},
+ {&activation->GetOutputSlot(0)});
+
+ auto workingCopy = view->GetWorkingCopy();
+
+ // First (and only) layer in the subgraph is the Activation
+ CHECK(std::string((*workingCopy.beginIConnectable())->GetName()) == "Activation");
+
+ // Substitute the "Activation" layer for an equivalent layer
+ auto activation2 = graph.AddLayer<ActivationLayer>(ActivationDescriptor{}, "Activation2");
+ SubgraphView pattern(*workingCopy.beginIConnectable());
+ workingCopy.SubstituteSubgraph(pattern, activation2);
+
+ CHECK(std::string((*workingCopy.beginIConnectable())->GetName()) == "Activation2");
+}
+
TEST_CASE("SubgraphViewWorkingCopyOptimizationViews")
{
Graph graph;