diff options
Diffstat (limited to 'src/armnn/test/SubgraphViewTests.cpp')
-rw-r--r-- | src/armnn/test/SubgraphViewTests.cpp | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/src/armnn/test/SubgraphViewTests.cpp b/src/armnn/test/SubgraphViewTests.cpp index 212ae0ee01..048c4f51fd 100644 --- a/src/armnn/test/SubgraphViewTests.cpp +++ b/src/armnn/test/SubgraphViewTests.cpp @@ -1928,6 +1928,7 @@ bool ReplaceConstantMultiplicationWithDepthwise(SubgraphView& subgraph, if (layer->GetType() == LayerType::Multiplication) { IInputSlot* patternSubgraphInput = &layer->GetInputSlot(0); + IInputSlot* patternSubgraphConstant = &layer->GetInputSlot(1); const IConnectableLayer* inputLayer = &patternSubgraphInput->GetConnection()->GetOwningIConnectableLayer(); const IConnectableLayer* constantLayer = &layer->GetInputSlot(1).GetConnection()->GetOwningIConnectableLayer(); @@ -1935,7 +1936,7 @@ bool ReplaceConstantMultiplicationWithDepthwise(SubgraphView& subgraph, // Figure out which of the two inputs is the constant if (constantLayer->GetType() != LayerType::Constant) { - patternSubgraphInput = &layer->GetInputSlot(1); + std::swap(patternSubgraphInput, patternSubgraphConstant); std::swap(inputLayer, constantLayer); } @@ -1965,7 +1966,7 @@ bool ReplaceConstantMultiplicationWithDepthwise(SubgraphView& subgraph, ConstTensor weights(weightsInfo, weightData); const auto depthwiseLayer = replacementGraph->AddDepthwiseConvolution2dLayer( - desc, weights, armnn::EmptyOptional(), "Replacement for Constant-Multiplication"); + desc, "Replacement for Constant-Multiplication"); auto& outslot = layer->GetOutputSlot(0); SubgraphView::IOutputSlots outputs{ &outslot }; @@ -1973,7 +1974,9 @@ bool ReplaceConstantMultiplicationWithDepthwise(SubgraphView& subgraph, layers.push_back(layer); layers.push_back(const_cast<IConnectableLayer*>(constantLayer)); - SubgraphView patternSubgraph(std::move(layers), {patternSubgraphInput}, {&layer->GetOutputSlot(0)}); + SubgraphView patternSubgraph(std::move(layers), + {patternSubgraphInput, patternSubgraphConstant}, + {&layer->GetOutputSlot(0)}); subgraph.SubstituteSubgraph(patternSubgraph, depthwiseLayer ); |