aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2020-07-06 19:24:15 +0100
committermike.kelly <mike.kelly@arm.com>2020-07-06 19:00:14 +0000
commitc5789ca2e432075e2c92e7e0d99139c5329280e6 (patch)
treef5a19ee78352f6e3b97566a24bd07378605b18f3
parentf78c767d229a9fc7a692fe4e7b71e36710966615 (diff)
downloadarmnn-c5789ca2e432075e2c92e7e0d99139c5329280e6.tar.gz
GitHub #418 AddBroadcastReshapeLayer can cause inputs to be connected incorrectly
* Fixed issue where AddBroadcastReshapeLayer would always connect the Reshaped input to the first input slot and the other input to the first input slot. Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: Ifd2745a819eb0f72ff9433690afc92a6a34f2ec3
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp12
1 files changed, 9 insertions, 3 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 1b93aadc5b..a690e5325e 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -563,6 +563,9 @@ void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex,
armnn::TensorInfo reshapedTensorInfo = ToTensorInfo(tensorPtr);
armnn::TensorInfo inputTensorInfo = ToTensorInfo(tensorPtr1);
+ uint32_t inputSlotId = 1;
+ uint32_t reshapeSlotId = 0;
+
if (inputTensorInfo.GetNumDimensions() < reshapedTensorInfo.GetNumDimensions())
{
uint32_t id = reshapedInputId;
@@ -571,6 +574,9 @@ void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex,
reshapedTensorInfo = ToTensorInfo(tensorPtr1);
inputTensorInfo = ToTensorInfo(tensorPtr);
+
+ inputSlotId = 0;
+ reshapeSlotId = 1;
}
uint32_t numDimensions = inputTensorInfo.GetNumDimensions();
@@ -592,11 +598,11 @@ void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex,
armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, layerName.c_str());
reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
- reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
+ reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(reshapeSlotId));
RegisterInputSlots(subgraphIndex, operatorIndex, reshapeLayer, {reshapedInputId});
- armnn::IInputSlot* input1Slot = &(layer->GetInputSlot(1));
+ armnn::IInputSlot* input1Slot = &(layer->GetInputSlot(inputSlotId));
RegisterConsumerOfTensor(subgraphIndex, inputId, input1Slot);
}