diff options
author | Mike Kelly <mike.kelly@arm.com> | 2023-01-19 18:29:40 +0000 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2023-01-24 22:57:16 +0000 |
commit | 04d8229bb3e78d1b1dd21eed41e47aabc25d8e2f (patch) | |
tree | 478dbaf477eaa59fac838e6e73b56843f80b52d0 /delegate/src/DelegateUtils.hpp | |
parent | 0e3fe10bfe1b4f006f6e0c5c2fae8fb5515c7544 (diff) | |
download | armnn-04d8229bb3e78d1b1dd21eed41e47aabc25d8e2f.tar.gz |
IVGCVSW-7277 Fixed issues with FullyConnected on certain TFLite models
* TFLite Parser:
* Fixed issue in ParseReshape where the targetShape wasn't always calculated correctly
* Fixed issue in ParseFullyConnected where the wrong name was used for the ReshapeLayer
* Added an ExpandDims to the FullyConnected to ensure that we reshape the output correctly
* TFLite Delegate:
* Added an ExpandDims to the FullyConnected to ensure that we reshape the output correctly
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I129dfcb8543f8a3a297c0589c841be20ef3b6407
Diffstat (limited to 'delegate/src/DelegateUtils.hpp')
-rw-r--r-- | delegate/src/DelegateUtils.hpp | 48 |
1 files changed, 47 insertions, 1 deletions
diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp index 850b279fea..91447576d0 100644 --- a/delegate/src/DelegateUtils.hpp +++ b/delegate/src/DelegateUtils.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -367,6 +367,52 @@ TfLiteStatus FusedActivation(TfLiteContext* tfLiteContext, return kTfLiteOk; } +armnn::IConnectableLayer* AddReshapeLayer(TfLiteContext* tfLiteContext, + TfLiteNode* tfLiteNode, + armnn::IConnectableLayer* prevLayer, + armnn::TensorInfo reshapedOutputTensorInfo, + armnn::TensorInfo outputTensorInfo, + armnnDelegate::DelegateData& data) +{ + armnn::ReshapeDescriptor desc; + desc.m_TargetShape = outputTensorInfo.GetShape(); + + bool isSupported = false; + armnn::BackendId setBackend; + FORWARD_LAYER_SUPPORT_FUNC("RESHAPE", + tfLiteContext, + IsReshapeSupported, + data.m_Backends, + isSupported, + setBackend, + reshapedOutputTensorInfo, + outputTensorInfo, + desc); + + if (!isSupported) + { + return nullptr; + } + + armnn::IConnectableLayer* reshapeLayer = data.m_Network->AddReshapeLayer(desc); + reshapeLayer->SetBackendId(setBackend); + ARMNN_ASSERT(reshapeLayer != nullptr); + + prevLayer->GetOutputSlot(0).SetTensorInfo(reshapedOutputTensorInfo); + reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + // Connect and prepare output slots + for (unsigned int outputIndex = 0; outputIndex < reshapeLayer->GetNumOutputSlots(); ++outputIndex) + { + data.m_OutputSlotForNode[static_cast<unsigned long>( + tfLiteNode->outputs->data[outputIndex])]->Connect(reshapeLayer->GetInputSlot(0)); + armnn::IOutputSlot& outputSlot = reshapeLayer->GetOutputSlot(outputIndex); + data.m_OutputSlotForNode[static_cast<unsigned long>( + tfLiteNode->outputs->data[outputIndex])] = &outputSlot; + } + return reshapeLayer; +} + armnn::DataType GetDataType(const TfLiteTensor& tfLiteTensor) { switch (tfLiteTensor.type) |