diff options
author | Mike Kelly <mike.kelly@arm.com> | 2023-01-19 18:29:40 +0000 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2023-01-24 22:57:16 +0000 |
commit | 04d8229bb3e78d1b1dd21eed41e47aabc25d8e2f (patch) | |
tree | 478dbaf477eaa59fac838e6e73b56843f80b52d0 /delegate/src/FullyConnected.hpp | |
parent | 0e3fe10bfe1b4f006f6e0c5c2fae8fb5515c7544 (diff) | |
download | armnn-04d8229bb3e78d1b1dd21eed41e47aabc25d8e2f.tar.gz |
IVGCVSW-7277 Fixed issues with FullyConnected on certain TFLite models
* TFLite Parser:
* Fixed issue in ParseReshape where the targetShape wasn't always calculated correctly
* Fixed issue in ParseFullyConnected where the wrong name was used for the ReshapeLayer
* Added an ExpandDims to the FullyConnected to ensure that we reshape the output correctly
* TFLite Delegate:
* Added an ExpandDims to the FullyConnected to ensure that we reshape the output correctly
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I129dfcb8543f8a3a297c0589c841be20ef3b6407
Diffstat (limited to 'delegate/src/FullyConnected.hpp')
-rw-r--r-- | delegate/src/FullyConnected.hpp | 42 |
1 files changed, 39 insertions, 3 deletions
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp index a2960e299b..2243ad0e0c 100644 --- a/delegate/src/FullyConnected.hpp +++ b/delegate/src/FullyConnected.hpp @@ -1,11 +1,12 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "DelegateUtils.hpp" +#include "armnnUtils/TensorUtils.hpp" #include <armnn/utility/IgnoreUnused.hpp> #include <tensorflow/lite/builtin_ops.h> @@ -103,6 +104,25 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() }); } + armnn::TensorInfo reshapedOutputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor); + + if (outputTensorInfo.GetNumDimensions() > 2) + { + // Calculate reshape to flatten to 2D [batch_size, input_size] + std::vector<unsigned int> reshapedDimensions(2); + reshapedDimensions[1] = weightsTensorInfo.GetShape()[0]; + reshapedDimensions[0] = outputTensorInfo.GetNumElements() / reshapedDimensions[1]; + + if (outputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0) + { + TF_LITE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnDelegate: Failed to deduce output tensor shape from filter size #%d #%d node #%d: ", + reshapedDimensions[1], operatorCode, nodeIndex); + return kTfLiteError; + } + reshapedOutputTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() }); + } armnn::FullyConnectedDescriptor descriptor; descriptor.m_TransposeWeightMatrix = true; @@ -113,6 +133,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, armnn::BackendId setBackend; auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported) { + FORWARD_LAYER_SUPPORT_FUNC("FULLY_CONNECTED", tfLiteContext, IsFullyConnectedSupported, @@ -128,7 +149,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, if (!delegateData.m_Network) { - validateFunc(outputTensorInfo, isSupported); + validateFunc(reshapedOutputTensorInfo, isSupported); return isSupported ? kTfLiteOk : kTfLiteError; } @@ -202,12 +223,27 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, } auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data); + + if (outputTensorInfo.GetNumDimensions() > 2) + { + layer = AddReshapeLayer(tfLiteContext, tfLiteNode, layer, reshapedOutputTensorInfo, outputTensorInfo, + delegateData); + if (!layer) + { + TF_LITE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnDelegate: Failed to add reshape for FullyConnected #%d node #%d: ", + operatorCode, + nodeIndex); + return kTfLiteError; + } + } + if (!tfLiteNodeParameters) { // No Activation return kTfLiteOk; } - // Check Activation TfLiteFusedActivation activationType = tfLiteNodeParameters->activation; return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData); |