From c5789ca2e432075e2c92e7e0d99139c5329280e6 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Mon, 6 Jul 2020 19:24:15 +0100 Subject: GitHub #418 AddBroadcastReshapeLayer can cause inputs to be connected incorrectly * Fixed issue where AddBroadcastReshapeLayer would always connect the Reshaped input to the first input slot and the other input to the first input slot. Signed-off-by: Mike Kelly Change-Id: Ifd2745a819eb0f72ff9433690afc92a6a34f2ec3 --- src/armnnTfLiteParser/TfLiteParser.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 1b93aadc5b..a690e5325e 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -563,6 +563,9 @@ void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex, armnn::TensorInfo reshapedTensorInfo = ToTensorInfo(tensorPtr); armnn::TensorInfo inputTensorInfo = ToTensorInfo(tensorPtr1); + uint32_t inputSlotId = 1; + uint32_t reshapeSlotId = 0; + if (inputTensorInfo.GetNumDimensions() < reshapedTensorInfo.GetNumDimensions()) { uint32_t id = reshapedInputId; @@ -571,6 +574,9 @@ void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex, reshapedTensorInfo = ToTensorInfo(tensorPtr1); inputTensorInfo = ToTensorInfo(tensorPtr); + + inputSlotId = 0; + reshapeSlotId = 1; } uint32_t numDimensions = inputTensorInfo.GetNumDimensions(); @@ -592,11 +598,11 @@ void TfLiteParser::AddBroadcastReshapeLayer(size_t subgraphIndex, armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, layerName.c_str()); reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo); - reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0)); + reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(reshapeSlotId)); RegisterInputSlots(subgraphIndex, operatorIndex, reshapeLayer, {reshapedInputId}); - armnn::IInputSlot* input1Slot = &(layer->GetInputSlot(1)); + armnn::IInputSlot* input1Slot = &(layer->GetInputSlot(inputSlotId)); RegisterConsumerOfTensor(subgraphIndex, inputId, input1Slot); } -- cgit v1.2.1