aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/SubgraphViewTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/SubgraphViewTests.cpp')
-rw-r--r--src/armnn/test/SubgraphViewTests.cpp9
1 files changed, 6 insertions, 3 deletions
diff --git a/src/armnn/test/SubgraphViewTests.cpp b/src/armnn/test/SubgraphViewTests.cpp
index 212ae0ee01..048c4f51fd 100644
--- a/src/armnn/test/SubgraphViewTests.cpp
+++ b/src/armnn/test/SubgraphViewTests.cpp
@@ -1928,6 +1928,7 @@ bool ReplaceConstantMultiplicationWithDepthwise(SubgraphView& subgraph,
if (layer->GetType() == LayerType::Multiplication)
{
IInputSlot* patternSubgraphInput = &layer->GetInputSlot(0);
+ IInputSlot* patternSubgraphConstant = &layer->GetInputSlot(1);
const IConnectableLayer* inputLayer = &patternSubgraphInput->GetConnection()->GetOwningIConnectableLayer();
const IConnectableLayer* constantLayer = &layer->GetInputSlot(1).GetConnection()->GetOwningIConnectableLayer();
@@ -1935,7 +1936,7 @@ bool ReplaceConstantMultiplicationWithDepthwise(SubgraphView& subgraph,
// Figure out which of the two inputs is the constant
if (constantLayer->GetType() != LayerType::Constant)
{
- patternSubgraphInput = &layer->GetInputSlot(1);
+ std::swap(patternSubgraphInput, patternSubgraphConstant);
std::swap(inputLayer, constantLayer);
}
@@ -1965,7 +1966,7 @@ bool ReplaceConstantMultiplicationWithDepthwise(SubgraphView& subgraph,
ConstTensor weights(weightsInfo, weightData);
const auto depthwiseLayer = replacementGraph->AddDepthwiseConvolution2dLayer(
- desc, weights, armnn::EmptyOptional(), "Replacement for Constant-Multiplication");
+ desc, "Replacement for Constant-Multiplication");
auto& outslot = layer->GetOutputSlot(0);
SubgraphView::IOutputSlots outputs{ &outslot };
@@ -1973,7 +1974,9 @@ bool ReplaceConstantMultiplicationWithDepthwise(SubgraphView& subgraph,
layers.push_back(layer);
layers.push_back(const_cast<IConnectableLayer*>(constantLayer));
- SubgraphView patternSubgraph(std::move(layers), {patternSubgraphInput}, {&layer->GetOutputSlot(0)});
+ SubgraphView patternSubgraph(std::move(layers),
+ {patternSubgraphInput, patternSubgraphConstant},
+ {&layer->GetOutputSlot(0)});
subgraph.SubstituteSubgraph(patternSubgraph, depthwiseLayer );