412 TEST_CASE(
"SubgraphForEmptyGraph")
417 CHECK(subgraph.GetInputSlots().empty());
418 CHECK(subgraph.GetOutputSlots().empty());
419 CHECK(subgraph.GetLayers().empty());
422 TEST_CASE(
"SubgraphForEntireGraph")
437 CHECK(subgraph.GetInputSlots().empty());
438 CHECK(subgraph.GetOutputSlots().empty());
439 CHECK(subgraph.GetLayers().size() == graph.
GetNumLayers());
442 TEST_CASE(
"NoSubgraphsForNoMatch")
450 SubgraphViewSelector::SelectSubgraphs(graph, [](
const Layer &) {
return false; });
452 CHECK(subgraphs.empty());
455 TEST_CASE(
"OneSubgraphsSelectedASingleMatch")
463 SubgraphViewSelector::SelectSubgraphs(
468 bool isOutput = l.
GetNameStr().compare(
"output") == 0;
472 CHECK(subgraphs.size() == 1);
473 if (subgraphs.size() == 1)
480 CompareSubgraphViews(subgraphs[0], expected);
484 TEST_CASE(
"MultipleLayersSelectedInTheMiddle")
498 SubgraphViewSelector::SelectSubgraphs(
507 CHECK(subgraphs.size() == 1);
508 if (subgraphs.size() == 1)
514 CompareSubgraphViews(subgraphs[0], expected);
518 TEST_CASE(
"DisjointGraphs")
534 SubgraphViewSelector::SelectSubgraphs(graph,
543 CHECK(subgraphs.size() == 2);
544 if (subgraphs.size() == 2)
546 CHECK((subgraphs[0] !=
nullptr));
547 CHECK((subgraphs[1] !=
nullptr));
548 if (subgraphs[0].
get() !=
nullptr && subgraphs[1].
get() !=
nullptr)
550 if (std::find(subgraphs[0]->GetLayers().begin(), subgraphs[0]->GetLayers().end(), i0) !=
551 subgraphs[0]->GetLayers().end())
553 CompareSubgraphViews(subgraphs[0], expected1);
554 CompareSubgraphViews(subgraphs[1], expected2);
558 CompareSubgraphViews(subgraphs[0], expected2);
559 CompareSubgraphViews(subgraphs[1], expected1);
565 TEST_CASE(
"IslandInTheMiddle")
611 SubgraphViewSelector::SelectSubgraphs(
616 bool toSelect = std::string(l.
GetName())[0] ==
'm';
623 { m0, m1, m2, m3, m4 });
625 auto smallerSubgraph =
627 std::vector<OutputSlot*>{},
630 CHECK(subgraphs.size() == 2);
631 if (subgraphs.size() == 2)
634 CHECK((subgraphs[0] !=
nullptr));
635 CHECK((subgraphs[1] !=
nullptr));
637 if (subgraphs[0].
get() !=
nullptr && subgraphs[1].
get() !=
nullptr)
640 std::sort(subgraphs.begin(), subgraphs.end(),
643 return (lhs->GetLayers().size() < rhs->GetLayers().size());
647 CHECK(subgraphs[0]->GetLayers().size() == 2);
648 CHECK(subgraphs[1]->GetLayers().size() == 5);
650 CompareSubgraphViews(subgraphs[0], smallerSubgraph);
651 CompareSubgraphViews(subgraphs[1], largerSubgraph);
656 TEST_CASE(
"MultipleSimpleSubgraphs")
687 SubgraphViewSelector::SelectSubgraphs(
705 CHECK(subgraphs.size() == 2);
706 if (subgraphs.size() == 2)
709 CHECK((subgraphs[0] !=
nullptr));
710 CHECK((subgraphs[1] !=
nullptr));
712 if (subgraphs[0].
get() !=
nullptr && subgraphs[1].
get() !=
nullptr)
715 std::sort(subgraphs.begin(), subgraphs.end(),
718 return (lhs->GetLayers().size() < rhs->GetLayers().size());
722 CHECK(subgraphs[0]->GetLayers().size() == 1);
723 CHECK(subgraphs[1]->GetLayers().size() == 2);
725 CompareSubgraphViews(subgraphs[0], smallerSubgraph);
726 CompareSubgraphViews(subgraphs[1], largerSubgraph);
731 TEST_CASE(
"SimpleLinearTest")
757 SubgraphViewSelector::SelectSubgraphs(
766 CHECK(subgraphs.size() == 1);
767 if(subgraphs.size() == 1)
773 CompareSubgraphViews(subgraphs[0], expected);
777 TEST_CASE(
"MultiInputSingleOutput")
811 SubgraphViewSelector::SelectSubgraphs(
817 || l.
GetType() == LayerType::Addition);
821 CHECK(subgraphs.size() == 1);
822 if (subgraphs.size() == 1)
826 {layerM1, layerM2, layerM3});
828 CompareSubgraphViews(subgraphs[0], expected);
832 TEST_CASE(
"SingleInputMultiOutput")
861 layerM1->GetOutputSlot(0).Connect(layerM2->GetInputSlot(0));
862 layerM1->GetOutputSlot(1).Connect(layerM3->GetInputSlot(0));
863 layerM2->GetOutputSlot(0).Connect(layerX2->GetInputSlot(0));
864 layerM3->GetOutputSlot(0).Connect(layerX3->GetInputSlot(0));
867 SubgraphViewSelector::SelectSubgraphs(
877 CHECK(subgraphs.size() == 1);
878 if(subgraphs.size() == 1)
882 {layerM1, layerM2, layerM3});
884 CompareSubgraphViews(subgraphs[0], expected);
888 TEST_CASE(
"MultiInputMultiOutput")
930 SubgraphViewSelector::SelectSubgraphs(
936 || l.
GetType() == LayerType::Concat);
941 CHECK(subgraphs.size() == 1);
942 if (subgraphs.size() == 1)
946 {m1, m2, m3, m4, m5});
948 CompareSubgraphViews(subgraphs[0], expected);
952 TEST_CASE(
"ValidMerge")
987 return std::string(l.
GetName())[0] ==
'm';
991 auto expectedSubgraph0 =
1002 CHECK(subgraphs.size() == 2);
1003 if (subgraphs.size() == 2)
1006 CHECK((subgraphs[0] !=
nullptr));
1007 CHECK((subgraphs[1] !=
nullptr));
1009 if (subgraphs[0].
get() !=
nullptr && subgraphs[1].
get() !=
nullptr)
1011 if (subgraphs[0]->GetInputSlots().size() == 1)
1013 CompareSubgraphViews(subgraphs[0], expectedSubgraph0);
1014 CompareSubgraphViews(subgraphs[1], expectedSubgraph1);
1018 CompareSubgraphViews(subgraphs[0], expectedSubgraph1);
1019 CompareSubgraphViews(subgraphs[1], expectedSubgraph0);
1025 TEST_CASE(
"PropagatedDependencies")
1076 SubgraphViewSelector::SelectSubgraphs(
1081 bool toSelect = std::string(l.
GetName())[0] ==
'm';
1088 { m0, m1, m2, m3, m4 });
1091 std::vector<OutputSlot*>{}, { m5, m6 });
1093 auto smallerSubgraph =
1096 CHECK(subgraphs.size() == 3);
1097 if (subgraphs.size() == 3)
1100 CHECK((subgraphs[0] !=
nullptr));
1101 CHECK((subgraphs[1] !=
nullptr));
1102 CHECK((subgraphs[2] !=
nullptr));
1104 if (subgraphs[0].
get() !=
nullptr && subgraphs[1].
get() !=
nullptr && subgraphs[2].
get() !=
nullptr)
1107 std::sort(subgraphs.begin(), subgraphs.end(),
1110 return (lhs->GetLayers().size() < rhs->GetLayers().size());
1114 CompareSubgraphViews(subgraphs[0], smallerSubgraph);
1115 CompareSubgraphViews(subgraphs[1], mediumSubgraph);
1116 CompareSubgraphViews(subgraphs[2], largerSubgraph);
1126 constexpr
bool debug =
false;
1128 std::mt19937 randomGenerator;
1131 auto GetRandom = [&randomGenerator](
auto maxExclusive) {
1135 std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
1136 return static_cast<decltype(maxExclusive)
>(uniform(randomGenerator) *
static_cast<float>(maxExclusive));
1139 auto GetRandomFlag = [&randomGenerator](
float trueProb) {
1140 std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
1141 return uniform(randomGenerator) < trueProb;
1144 constexpr uint32_t numTests = 100;
1145 for (uint32_t testIdx = 0; testIdx < numTests; ++testIdx)
1147 randomGenerator.seed(testIdx);
1155 uint32_t numInputs = 1 + GetRandom(4u);
1156 uint32_t numConstants = 1 + GetRandom(4u);
1157 uint32_t numOutputs = 1 + GetRandom(4u);
1158 uint32_t numConcats = 0 + GetRandom(500u);
1159 uint32_t numSplits = 0 + GetRandom(500u);
1160 float supportedProb = 0.7f;
1162 for (uint32_t i = 0; i < numInputs; ++i)
1164 std::string name =
"input" + std::to_string(i) + (GetRandomFlag(supportedProb) ?
"S" :
"N");
1167 for (uint32_t i = 0; i < numConstants; ++i)
1169 std::string name =
"constant" + std::to_string(i) + (GetRandomFlag(supportedProb) ?
"S" :
"N");
1172 for (uint32_t i = 0; i < numOutputs; ++i)
1174 std::string name =
"output" + std::to_string(i) + (GetRandomFlag(supportedProb) ?
"S" :
"N");
1177 for (uint32_t i = 0; i < numConcats; ++i)
1179 std::string name =
"concat" + std::to_string(i) + (GetRandomFlag(supportedProb) ?
"S" :
"N");
1180 uint32_t numInputs = 1 + GetRandom(3u);
1184 for (uint32_t i = 0; i < numSplits; ++i)
1186 std::string name =
"split" + std::to_string(i) + (GetRandomFlag(supportedProb) ?
"S" :
"N");
1187 uint32_t numOutputs = 1 + GetRandom(3u);
1197 uint32_t maxLayerDepth = 5 + GetRandom(2000u);
1198 std::map<Layer*, uint32_t> layerDepths;
1199 std::map<uint32_t, std::vector<Layer*>> layersAtDepth;
1200 for (
Layer* layer : graph)
1203 if (layer->GetType() == LayerType::Input || layer->GetType() == LayerType::Constant)
1212 depth = 1 + GetRandom(maxLayerDepth);
1214 layerDepths[layer] = depth;
1215 layersAtDepth[depth].push_back(layer);
1220 for (
Layer* layer : graph)
1222 for (uint32_t inputSlotIdx = 0; inputSlotIdx < layer->GetNumInputSlots(); ++inputSlotIdx)
1224 InputSlot& inputSlot = layer->GetInputSlot(inputSlotIdx);
1225 uint32_t maxLayerDepthToConnectTo = layerDepths[layer];
1229 uint32_t layerDepth = GetRandom(maxLayerDepthToConnectTo);
1230 const std::vector<Layer*>& layersToChooseFrom = layersAtDepth[layerDepth];
1231 if (layersToChooseFrom.size() == 0)
1235 Layer* layerToConnectWith = layersToChooseFrom[GetRandom(layersToChooseFrom.size())];
1249 std::ofstream f(
"INPUT_" + std::to_string(testIdx) +
".dot");
1250 graph.SerializeToDot(f);
1254 auto startTime = std::chrono::high_resolution_clock::now();
1257 SubgraphViewSelector::SelectSubgraphs(graph,
1258 [](
const Layer& l) {
return std::string(l.
GetName()).back() ==
'S'; });
1260 auto endTime = std::chrono::high_resolution_clock::now();
1261 auto duration = std::chrono::duration_cast<std::chrono::microseconds>(endTime - startTime);
1264 std::cout <<
"Test " << testIdx <<
": " << duration.count() <<
" microseconds" << std::endl;
1269 std::map<Layer*, SubgraphView*> layerToSubgraph;
1270 for (
Layer* layer : graph)
1273 for (std::unique_ptr<SubgraphView>& subgraph : subgraphs)
1275 std::string name = std::to_string(i++);
1276 if (std::find(subgraph->begin(), subgraph->end(), layer) != subgraph->end())
1278 layerToSubgraph[layer] = subgraph.get();
1288 for (
Layer* layer : graph)
1290 std::string name =
"NotAssigned";
1291 auto subgraphIt = layerToSubgraph.find(layer);
1292 if (subgraphIt != layerToSubgraph.end())
1294 auto subgraphIdx = std::distance(subgraphs.begin(),
1295 std::find_if(subgraphs.begin(), subgraphs.end(),
1296 [&](
auto& s) {
return s.get() == subgraphIt->second; }));
1297 name = std::to_string(subgraphIdx);
1302 std::ofstream f(
"GRAPH_" + std::to_string(testIdx) +
".dot");
1303 graph.SerializeToDot(f);
1309 for (std::unique_ptr<SubgraphView>& subgraph : subgraphs)
1311 for (
InputSlot* inputSlot : subgraph->GetInputSlots())
1313 std::queue<Layer*> toProcess;
1315 while (toProcess.size() > 0)
1317 Layer* l = toProcess.front();
1320 CHECK(layerToSubgraph[l] != subgraph.get());
1324 toProcess.push(&is.GetConnectedOutputSlot()->GetOwningLayer());
A layer that the constant data can be bound to.
This layer represents a split operation.
A ViewsDescriptor for the SplitterLayer.
void Splitter(const SplitterQueueDescriptor &data, std::vector< ITensorHandle *> inputs, std::vector< ITensorHandle *> outputs)
LayerT * AddLayer(Args &&... args)
Adds a new layer, of type LayerType, to the graph constructed with the arguments passed.
A Convolution2dDescriptor for the Convolution2dLayer.
Layer & GetOwningLayer() const
int Connect(InputSlot &destination)
This layer represents an activation operation with the specified activation function.
const std::vector< InputSlot > & GetInputSlots() const
unsigned int GetNumOutputSlots() const override
Returns the number of connectable output slots.
int LayerBindingId
Type of identifiers for bindable layers (inputs, outputs).
The SubgraphView class represents a subgraph of a Graph.
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
A layer user-provided data can be bound to (e.g. inputs, outputs).
SubgraphView::InputSlots CreateInputsFrom(const std::vector< Layer *> &layers)
An OriginsDescriptor for the ConcatLayer.
This layer represents a merge operation.
const std::string & GetNameStr() const
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
float Activation(float in, ActivationFunction function, float a, float b)
std::vector< SubgraphViewPtr > Subgraphs
An ActivationDescriptor for the ActivationLayer.
This layer represents an addition operation.
std::unique_ptr< SubgraphView > SubgraphViewPtr
SubgraphView::SubgraphViewPtr CreateSubgraphViewFrom(SubgraphView::InputSlots &&inputs, SubgraphView::OutputSlots &&outputs, SubgraphView::Layers &&layers)
SubgraphView::OutputSlots CreateOutputsFrom(const std::vector< Layer *> &layers)
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
const char * GetName() const override
Returns the name of the layer.
This layer represents a convolution 2d operation.
size_t GetNumLayers() const
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itse...