diff options
author | Mike Kelly <mike.kelly@arm.com> | 2020-07-06 19:24:15 +0100 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2020-07-06 19:00:14 +0000 |
commit | c5789ca2e432075e2c92e7e0d99139c5329280e6 (patch) | |
tree | f5a19ee78352f6e3b97566a24bd07378605b18f3 /src/armnnTfLiteParser/TfLiteParser.cpp | |
parent | f78c767d229a9fc7a692fe4e7b71e36710966615 (diff) | |
download | armnn-c5789ca2e432075e2c92e7e0d99139c5329280e6.tar.gz |
GitHub #418 AddBroadcastReshapeLayer can cause inputs to be connected incorrectly
* Fixed issue where AddBroadcastReshapeLayer would always connect the Reshaped input to the first input slot and the other input to the first input slot.
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: Ifd2745a819eb0f72ff9433690afc92a6a34f2ec3
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 1b93aadc5b..a690e5325e 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -563,6 +563,9 @@ void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex, armnn::TensorInfo reshapedTensorInfo = ToTensorInfo(tensorPtr); armnn::TensorInfo inputTensorInfo = ToTensorInfo(tensorPtr1); + uint32_t inputSlotId = 1; + uint32_t reshapeSlotId = 0; + if (inputTensorInfo.GetNumDimensions() < reshapedTensorInfo.GetNumDimensions()) { uint32_t id = reshapedInputId; @@ -571,6 +574,9 @@ void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex, reshapedTensorInfo = ToTensorInfo(tensorPtr1); inputTensorInfo = ToTensorInfo(tensorPtr); + + inputSlotId = 0; + reshapeSlotId = 1; } uint32_t numDimensions = inputTensorInfo.GetNumDimensions(); @@ -592,11 +598,11 @@ void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex, armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, layerName.c_str()); reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo); - reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0)); + reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(reshapeSlotId)); RegisterInputSlots(subgraphIndex, operatorIndex, reshapeLayer, {reshapedInputId}); - armnn::IInputSlot* input1Slot = &(layer->GetInputSlot(1)); + armnn::IInputSlot* input1Slot = &(layer->GetInputSlot(inputSlotId)); RegisterConsumerOfTensor(subgraphIndex, inputId, input1Slot); } |