aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/OptimizeSubgraphViewTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test/OptimizeSubgraphViewTests.cpp')
-rw-r--r--src/backends/backendsCommon/test/OptimizeSubgraphViewTests.cpp75
1 files changed, 40 insertions, 35 deletions
diff --git a/src/backends/backendsCommon/test/OptimizeSubgraphViewTests.cpp b/src/backends/backendsCommon/test/OptimizeSubgraphViewTests.cpp
index 7cb5ded773..ca3c563757 100644
--- a/src/backends/backendsCommon/test/OptimizeSubgraphViewTests.cpp
+++ b/src/backends/backendsCommon/test/OptimizeSubgraphViewTests.cpp
@@ -877,8 +877,13 @@ void PartiallySupportedSubgraphTestImpl()
// Check the substitutions
// -----------------------
- const OptimizationViews::Substitutions& substitutions = optimizationViews.GetSubstitutions();
+ OptimizationViews::Substitutions substitutions = optimizationViews.GetSubstitutions();
BOOST_TEST(substitutions.size() == 2);
+ // Sort into a consistent order
+ std::sort(substitutions.begin(), substitutions.end(), [](auto s1, auto s2) {
+ return strcmp(s1.m_SubstitutableSubgraph.GetLayers().front()->GetName(),
+ s2.m_SubstitutableSubgraph.GetLayers().front()->GetName()) < 0;
+ });
std::vector<ExpectedSubgraphSize> expectedSubstitutableSubgraphSizes{ { 1, 1, 1 },
{ 1, 1, 1 } };
@@ -914,8 +919,12 @@ void PartiallySupportedSubgraphTestImpl()
// Check the failed subgraphs
// --------------------------
- const OptimizationViews::Subgraphs& failedSubgraphs = optimizationViews.GetFailedSubgraphs();
+ OptimizationViews::Subgraphs failedSubgraphs = optimizationViews.GetFailedSubgraphs();
BOOST_TEST(failedSubgraphs.size() == 2);
+ // Sort into a consistent order
+ std::sort(failedSubgraphs.begin(), failedSubgraphs.end(), [](auto s1, auto s2) {
+ return strcmp(s1.GetLayers().front()->GetName(), s2.GetLayers().front()->GetName()) < 0;
+ });
std::vector<ExpectedSubgraphSize> expectedFailedSubgraphSizes{ { 1, 1, 2 },
{ 1, 1, 1 } };
@@ -1060,8 +1069,12 @@ void PartiallyOptimizableSubgraphTestImpl1()
// Check the substitutions
// -----------------------
- const OptimizationViews::Substitutions& substitutions = optimizationViews.GetSubstitutions();
+ OptimizationViews::Substitutions substitutions = optimizationViews.GetSubstitutions();
BOOST_TEST(substitutions.size() == 3);
+ // Sort into a consistent order
+ std::sort(substitutions.begin(), substitutions.end(),
+ [](auto s1, auto s2) { return strcmp(s1.m_SubstitutableSubgraph.GetLayers().front()->GetName(),
+ s2.m_SubstitutableSubgraph.GetLayers().front()->GetName()) < 0; });
std::vector<ExpectedSubgraphSize> expectedSubstitutableSubgraphSizes{ { 1, 1, 1 },
{ 1, 1, 1 },
@@ -1108,8 +1121,12 @@ void PartiallyOptimizableSubgraphTestImpl1()
// Check the untouched subgraphs
// -----------------------------
- const OptimizationViews::Subgraphs& untouchedSubgraphs = optimizationViews.GetUntouchedSubgraphs();
+ OptimizationViews::Subgraphs untouchedSubgraphs = optimizationViews.GetUntouchedSubgraphs();
BOOST_TEST(untouchedSubgraphs.size() == 2);
+ // Sort into a consistent order
+ std::sort(untouchedSubgraphs.begin(), untouchedSubgraphs.end(), [](auto s1, auto s2) {
+ return strcmp(s1.GetLayers().front()->GetName(), s2.GetLayers().front()->GetName()) < 0;
+ });
std::vector<ExpectedSubgraphSize> expectedUntouchedSubgraphSizes{ { 1, 1, 1 },
{ 1, 1, 1 } };
@@ -1146,7 +1163,7 @@ void PartiallyOptimizableSubgraphTestImpl2()
Graph graph;
LayerNameToLayerMap layersInGraph;
- // Create a fully optimizable subgraph
+ // Create a partially optimizable subgraph
SubgraphViewSelector::SubgraphViewPtr subgraphPtr = BuildPartiallyOptimizableSubgraph2(graph, layersInGraph);
BOOST_TEST((subgraphPtr != nullptr));
@@ -1185,44 +1202,32 @@ void PartiallyOptimizableSubgraphTestImpl2()
// -----------------------
const OptimizationViews::Substitutions& substitutions = optimizationViews.GetSubstitutions();
- BOOST_TEST(substitutions.size() == 2);
-
- std::vector<ExpectedSubgraphSize> expectedSubstitutableSubgraphSizes{ { 1, 1, 1 },
- { 2, 1, 2 } };
- std::vector<ExpectedSubgraphSize> expectedReplacementSubgraphSizes{ { 1, 1, 1 },
- { 2, 1, 1 } };
+ BOOST_TEST(substitutions.size() == 1);
- SubgraphView::InputSlots expectedSubstitutableSubgraph2InputSlots =
- ConvertReferenceTypeToPointerType(layersInGraph.at("conv3 layer")->GetInputSlots());
- expectedSubstitutableSubgraph2InputSlots.push_back(
- ConvertReferenceTypeToPointerType(layersInGraph.at("add layer")->GetInputSlot(0)));
+ ExpectedSubgraphSize expectedSubstitutableSubgraphSizes{ 2, 1, 3 };
+ ExpectedSubgraphSize expectedReplacementSubgraphSizes{ 2, 1, 1 };
- std::vector<SubgraphView::InputSlots> expectedSubstitutableInputSlots
- {
- ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetInputSlots()),
- expectedSubstitutableSubgraph2InputSlots
+ SubgraphView::InputSlots expectedSubstitutableInputSlots = {
+ ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetInputSlots()[0]),
+ ConvertReferenceTypeToPointerType(layersInGraph.at("conv3 layer")->GetInputSlots()[0])
};
- std::vector<SubgraphView::OutputSlots> expectedSubstitutableOutputSlots
+ SubgraphView::OutputSlots expectedSubstitutableOutputSlots =
{
- ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetOutputSlots()),
- ConvertReferenceTypeToPointerType(layersInGraph.at("add layer")->GetOutputSlots())
+ ConvertReferenceTypeToPointerType(layersInGraph.at("add layer")->GetOutputSlots()[0])
};
- std::vector<SubgraphView::Layers> expectedSubstitutableLayers
+ SubgraphView::Layers expectedSubstitutableLayers
{
- { layersInGraph.at("conv1 layer") },
- { layersInGraph.at("conv3 layer"),
- layersInGraph.at("add layer") }
+ layersInGraph.at("conv1 layer"),
+ layersInGraph.at("conv3 layer"),
+ layersInGraph.at("add layer")
};
- for (size_t substitutionIndex = 0; substitutionIndex < substitutions.size(); substitutionIndex++)
- {
- CheckSubstitution(substitutions.at(substitutionIndex),
- expectedSubstitutableSubgraphSizes.at(substitutionIndex),
- expectedReplacementSubgraphSizes.at(substitutionIndex),
- expectedSubstitutableInputSlots.at(substitutionIndex),
- expectedSubstitutableOutputSlots.at(substitutionIndex),
- expectedSubstitutableLayers.at(substitutionIndex));
- }
+ CheckSubstitution(substitutions[0],
+ expectedSubstitutableSubgraphSizes,
+ expectedReplacementSubgraphSizes,
+ expectedSubstitutableInputSlots,
+ expectedSubstitutableOutputSlots,
+ expectedSubstitutableLayers);
// --------------------------
// Check the failed subgraphs