diff options
Diffstat (limited to 'src/backends/backendsCommon/test/OptimizeSubgraphViewTests.cpp')
-rw-r--r-- | src/backends/backendsCommon/test/OptimizeSubgraphViewTests.cpp | 75 |
1 files changed, 40 insertions, 35 deletions
diff --git a/src/backends/backendsCommon/test/OptimizeSubgraphViewTests.cpp b/src/backends/backendsCommon/test/OptimizeSubgraphViewTests.cpp index 7cb5ded773..ca3c563757 100644 --- a/src/backends/backendsCommon/test/OptimizeSubgraphViewTests.cpp +++ b/src/backends/backendsCommon/test/OptimizeSubgraphViewTests.cpp @@ -877,8 +877,13 @@ void PartiallySupportedSubgraphTestImpl() // Check the substitutions // ----------------------- - const OptimizationViews::Substitutions& substitutions = optimizationViews.GetSubstitutions(); + OptimizationViews::Substitutions substitutions = optimizationViews.GetSubstitutions(); BOOST_TEST(substitutions.size() == 2); + // Sort into a consistent order + std::sort(substitutions.begin(), substitutions.end(), [](auto s1, auto s2) { + return strcmp(s1.m_SubstitutableSubgraph.GetLayers().front()->GetName(), + s2.m_SubstitutableSubgraph.GetLayers().front()->GetName()) < 0; + }); std::vector<ExpectedSubgraphSize> expectedSubstitutableSubgraphSizes{ { 1, 1, 1 }, { 1, 1, 1 } }; @@ -914,8 +919,12 @@ void PartiallySupportedSubgraphTestImpl() // Check the failed subgraphs // -------------------------- - const OptimizationViews::Subgraphs& failedSubgraphs = optimizationViews.GetFailedSubgraphs(); + OptimizationViews::Subgraphs failedSubgraphs = optimizationViews.GetFailedSubgraphs(); BOOST_TEST(failedSubgraphs.size() == 2); + // Sort into a consistent order + std::sort(failedSubgraphs.begin(), failedSubgraphs.end(), [](auto s1, auto s2) { + return strcmp(s1.GetLayers().front()->GetName(), s2.GetLayers().front()->GetName()) < 0; + }); std::vector<ExpectedSubgraphSize> expectedFailedSubgraphSizes{ { 1, 1, 2 }, { 1, 1, 1 } }; @@ -1060,8 +1069,12 @@ void PartiallyOptimizableSubgraphTestImpl1() // Check the substitutions // ----------------------- - const OptimizationViews::Substitutions& substitutions = optimizationViews.GetSubstitutions(); + OptimizationViews::Substitutions substitutions = optimizationViews.GetSubstitutions(); BOOST_TEST(substitutions.size() == 3); + // Sort into a consistent order + std::sort(substitutions.begin(), substitutions.end(), + [](auto s1, auto s2) { return strcmp(s1.m_SubstitutableSubgraph.GetLayers().front()->GetName(), + s2.m_SubstitutableSubgraph.GetLayers().front()->GetName()) < 0; }); std::vector<ExpectedSubgraphSize> expectedSubstitutableSubgraphSizes{ { 1, 1, 1 }, { 1, 1, 1 }, @@ -1108,8 +1121,12 @@ void PartiallyOptimizableSubgraphTestImpl1() // Check the untouched subgraphs // ----------------------------- - const OptimizationViews::Subgraphs& untouchedSubgraphs = optimizationViews.GetUntouchedSubgraphs(); + OptimizationViews::Subgraphs untouchedSubgraphs = optimizationViews.GetUntouchedSubgraphs(); BOOST_TEST(untouchedSubgraphs.size() == 2); + // Sort into a consistent order + std::sort(untouchedSubgraphs.begin(), untouchedSubgraphs.end(), [](auto s1, auto s2) { + return strcmp(s1.GetLayers().front()->GetName(), s2.GetLayers().front()->GetName()) < 0; + }); std::vector<ExpectedSubgraphSize> expectedUntouchedSubgraphSizes{ { 1, 1, 1 }, { 1, 1, 1 } }; @@ -1146,7 +1163,7 @@ void PartiallyOptimizableSubgraphTestImpl2() Graph graph; LayerNameToLayerMap layersInGraph; - // Create a fully optimizable subgraph + // Create a partially optimizable subgraph SubgraphViewSelector::SubgraphViewPtr subgraphPtr = BuildPartiallyOptimizableSubgraph2(graph, layersInGraph); BOOST_TEST((subgraphPtr != nullptr)); @@ -1185,44 +1202,32 @@ void PartiallyOptimizableSubgraphTestImpl2() // ----------------------- const OptimizationViews::Substitutions& substitutions = optimizationViews.GetSubstitutions(); - BOOST_TEST(substitutions.size() == 2); - - std::vector<ExpectedSubgraphSize> expectedSubstitutableSubgraphSizes{ { 1, 1, 1 }, - { 2, 1, 2 } }; - std::vector<ExpectedSubgraphSize> expectedReplacementSubgraphSizes{ { 1, 1, 1 }, - { 2, 1, 1 } }; + BOOST_TEST(substitutions.size() == 1); - SubgraphView::InputSlots expectedSubstitutableSubgraph2InputSlots = - ConvertReferenceTypeToPointerType(layersInGraph.at("conv3 layer")->GetInputSlots()); - expectedSubstitutableSubgraph2InputSlots.push_back( - ConvertReferenceTypeToPointerType(layersInGraph.at("add layer")->GetInputSlot(0))); + ExpectedSubgraphSize expectedSubstitutableSubgraphSizes{ 2, 1, 3 }; + ExpectedSubgraphSize expectedReplacementSubgraphSizes{ 2, 1, 1 }; - std::vector<SubgraphView::InputSlots> expectedSubstitutableInputSlots - { - ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetInputSlots()), - expectedSubstitutableSubgraph2InputSlots + SubgraphView::InputSlots expectedSubstitutableInputSlots = { + ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetInputSlots()[0]), + ConvertReferenceTypeToPointerType(layersInGraph.at("conv3 layer")->GetInputSlots()[0]) }; - std::vector<SubgraphView::OutputSlots> expectedSubstitutableOutputSlots + SubgraphView::OutputSlots expectedSubstitutableOutputSlots = { - ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetOutputSlots()), - ConvertReferenceTypeToPointerType(layersInGraph.at("add layer")->GetOutputSlots()) + ConvertReferenceTypeToPointerType(layersInGraph.at("add layer")->GetOutputSlots()[0]) }; - std::vector<SubgraphView::Layers> expectedSubstitutableLayers + SubgraphView::Layers expectedSubstitutableLayers { - { layersInGraph.at("conv1 layer") }, - { layersInGraph.at("conv3 layer"), - layersInGraph.at("add layer") } + layersInGraph.at("conv1 layer"), + layersInGraph.at("conv3 layer"), + layersInGraph.at("add layer") }; - for (size_t substitutionIndex = 0; substitutionIndex < substitutions.size(); substitutionIndex++) - { - CheckSubstitution(substitutions.at(substitutionIndex), - expectedSubstitutableSubgraphSizes.at(substitutionIndex), - expectedReplacementSubgraphSizes.at(substitutionIndex), - expectedSubstitutableInputSlots.at(substitutionIndex), - expectedSubstitutableOutputSlots.at(substitutionIndex), - expectedSubstitutableLayers.at(substitutionIndex)); - } + CheckSubstitution(substitutions[0], + expectedSubstitutableSubgraphSizes, + expectedReplacementSubgraphSizes, + expectedSubstitutableInputSlots, + expectedSubstitutableOutputSlots, + expectedSubstitutableLayers); // -------------------------- // Check the failed subgraphs |