diff options
author | Mike Kelly <mike.kelly@arm.com> | 2023-01-19 18:29:40 +0000 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2023-01-24 22:57:16 +0000 |
commit | 04d8229bb3e78d1b1dd21eed41e47aabc25d8e2f (patch) | |
tree | 478dbaf477eaa59fac838e6e73b56843f80b52d0 /delegate/src | |
parent | 0e3fe10bfe1b4f006f6e0c5c2fae8fb5515c7544 (diff) | |
download | armnn-04d8229bb3e78d1b1dd21eed41e47aabc25d8e2f.tar.gz |
IVGCVSW-7277 Fixed issues with FullyConnected on certain TFLite models
* TFLite Parser:
* Fixed issue in ParseReshape where the targetShape wasn't always calculated correctly
* Fixed issue in ParseFullyConnected where the wrong name was used for the ReshapeLayer
* Added an ExpandDims to the FullyConnected to ensure that we reshape the output correctly
* TFLite Delegate:
* Added an ExpandDims to the FullyConnected to ensure that we reshape the output correctly
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I129dfcb8543f8a3a297c0589c841be20ef3b6407
Diffstat (limited to 'delegate/src')
-rw-r--r-- | delegate/src/DelegateUtils.hpp | 48 | ||||
-rw-r--r-- | delegate/src/FullyConnected.hpp | 42 | ||||
-rw-r--r-- | delegate/src/test/FullyConnectedTest.cpp | 2 |
3 files changed, 87 insertions, 5 deletions
diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp index 850b279fea..91447576d0 100644 --- a/delegate/src/DelegateUtils.hpp +++ b/delegate/src/DelegateUtils.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -367,6 +367,52 @@ TfLiteStatus FusedActivation(TfLiteContext* tfLiteContext, return kTfLiteOk; } +armnn::IConnectableLayer* AddReshapeLayer(TfLiteContext* tfLiteContext, + TfLiteNode* tfLiteNode, + armnn::IConnectableLayer* prevLayer, + armnn::TensorInfo reshapedOutputTensorInfo, + armnn::TensorInfo outputTensorInfo, + armnnDelegate::DelegateData& data) +{ + armnn::ReshapeDescriptor desc; + desc.m_TargetShape = outputTensorInfo.GetShape(); + + bool isSupported = false; + armnn::BackendId setBackend; + FORWARD_LAYER_SUPPORT_FUNC("RESHAPE", + tfLiteContext, + IsReshapeSupported, + data.m_Backends, + isSupported, + setBackend, + reshapedOutputTensorInfo, + outputTensorInfo, + desc); + + if (!isSupported) + { + return nullptr; + } + + armnn::IConnectableLayer* reshapeLayer = data.m_Network->AddReshapeLayer(desc); + reshapeLayer->SetBackendId(setBackend); + ARMNN_ASSERT(reshapeLayer != nullptr); + + prevLayer->GetOutputSlot(0).SetTensorInfo(reshapedOutputTensorInfo); + reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + // Connect and prepare output slots + for (unsigned int outputIndex = 0; outputIndex < reshapeLayer->GetNumOutputSlots(); ++outputIndex) + { + data.m_OutputSlotForNode[static_cast<unsigned long>( + tfLiteNode->outputs->data[outputIndex])]->Connect(reshapeLayer->GetInputSlot(0)); + armnn::IOutputSlot& outputSlot = reshapeLayer->GetOutputSlot(outputIndex); + data.m_OutputSlotForNode[static_cast<unsigned long>( + tfLiteNode->outputs->data[outputIndex])] = &outputSlot; + } + return reshapeLayer; +} + armnn::DataType GetDataType(const TfLiteTensor& tfLiteTensor) { switch (tfLiteTensor.type) diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp index a2960e299b..2243ad0e0c 100644 --- a/delegate/src/FullyConnected.hpp +++ b/delegate/src/FullyConnected.hpp @@ -1,11 +1,12 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "DelegateUtils.hpp" +#include "armnnUtils/TensorUtils.hpp" #include <armnn/utility/IgnoreUnused.hpp> #include <tensorflow/lite/builtin_ops.h> @@ -103,6 +104,25 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() }); } + armnn::TensorInfo reshapedOutputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor); + + if (outputTensorInfo.GetNumDimensions() > 2) + { + // Calculate reshape to flatten to 2D [batch_size, input_size] + std::vector<unsigned int> reshapedDimensions(2); + reshapedDimensions[1] = weightsTensorInfo.GetShape()[0]; + reshapedDimensions[0] = outputTensorInfo.GetNumElements() / reshapedDimensions[1]; + + if (outputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0) + { + TF_LITE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnDelegate: Failed to deduce output tensor shape from filter size #%d #%d node #%d: ", + reshapedDimensions[1], operatorCode, nodeIndex); + return kTfLiteError; + } + reshapedOutputTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() }); + } armnn::FullyConnectedDescriptor descriptor; descriptor.m_TransposeWeightMatrix = true; @@ -113,6 +133,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, armnn::BackendId setBackend; auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported) { + FORWARD_LAYER_SUPPORT_FUNC("FULLY_CONNECTED", tfLiteContext, IsFullyConnectedSupported, @@ -128,7 +149,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, if (!delegateData.m_Network) { - validateFunc(outputTensorInfo, isSupported); + validateFunc(reshapedOutputTensorInfo, isSupported); return isSupported ? kTfLiteOk : kTfLiteError; } @@ -202,12 +223,27 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, } auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data); + + if (outputTensorInfo.GetNumDimensions() > 2) + { + layer = AddReshapeLayer(tfLiteContext, tfLiteNode, layer, reshapedOutputTensorInfo, outputTensorInfo, + delegateData); + if (!layer) + { + TF_LITE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnDelegate: Failed to add reshape for FullyConnected #%d node #%d: ", + operatorCode, + nodeIndex); + return kTfLiteError; + } + } + if (!tfLiteNodeParameters) { // No Activation return kTfLiteOk; } - // Check Activation TfLiteFusedActivation activationType = tfLiteNodeParameters->activation; return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData); diff --git a/delegate/src/test/FullyConnectedTest.cpp b/delegate/src/test/FullyConnectedTest.cpp index c300bc72bf..3ef5cedbd7 100644 --- a/delegate/src/test/FullyConnectedTest.cpp +++ b/delegate/src/test/FullyConnectedTest.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2020-2021,2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // |