aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/FullyConnected.hpp
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2023-01-19 18:29:40 +0000
committermike.kelly <mike.kelly@arm.com>2023-01-24 22:57:16 +0000
commit04d8229bb3e78d1b1dd21eed41e47aabc25d8e2f (patch)
tree478dbaf477eaa59fac838e6e73b56843f80b52d0 /delegate/src/FullyConnected.hpp
parent0e3fe10bfe1b4f006f6e0c5c2fae8fb5515c7544 (diff)
downloadarmnn-04d8229bb3e78d1b1dd21eed41e47aabc25d8e2f.tar.gz
IVGCVSW-7277 Fixed issues with FullyConnected on certain TFLite models
* TFLite Parser: * Fixed issue in ParseReshape where the targetShape wasn't always calculated correctly * Fixed issue in ParseFullyConnected where the wrong name was used for the ReshapeLayer * Added an ExpandDims to the FullyConnected to ensure that we reshape the output correctly * TFLite Delegate: * Added an ExpandDims to the FullyConnected to ensure that we reshape the output correctly Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I129dfcb8543f8a3a297c0589c841be20ef3b6407
Diffstat (limited to 'delegate/src/FullyConnected.hpp')
-rw-r--r--delegate/src/FullyConnected.hpp42
1 files changed, 39 insertions, 3 deletions
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp
index a2960e299b..2243ad0e0c 100644
--- a/delegate/src/FullyConnected.hpp
+++ b/delegate/src/FullyConnected.hpp
@@ -1,11 +1,12 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "DelegateUtils.hpp"
+#include "armnnUtils/TensorUtils.hpp"
#include <armnn/utility/IgnoreUnused.hpp>
#include <tensorflow/lite/builtin_ops.h>
@@ -103,6 +104,25 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
}
+ armnn::TensorInfo reshapedOutputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
+
+ if (outputTensorInfo.GetNumDimensions() > 2)
+ {
+ // Calculate reshape to flatten to 2D [batch_size, input_size]
+ std::vector<unsigned int> reshapedDimensions(2);
+ reshapedDimensions[1] = weightsTensorInfo.GetShape()[0];
+ reshapedDimensions[0] = outputTensorInfo.GetNumElements() / reshapedDimensions[1];
+
+ if (outputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
+ {
+ TF_LITE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnDelegate: Failed to deduce output tensor shape from filter size #%d #%d node #%d: ",
+ reshapedDimensions[1], operatorCode, nodeIndex);
+ return kTfLiteError;
+ }
+ reshapedOutputTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
+ }
armnn::FullyConnectedDescriptor descriptor;
descriptor.m_TransposeWeightMatrix = true;
@@ -113,6 +133,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
armnn::BackendId setBackend;
auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
{
+
FORWARD_LAYER_SUPPORT_FUNC("FULLY_CONNECTED",
tfLiteContext,
IsFullyConnectedSupported,
@@ -128,7 +149,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
if (!delegateData.m_Network)
{
- validateFunc(outputTensorInfo, isSupported);
+ validateFunc(reshapedOutputTensorInfo, isSupported);
return isSupported ? kTfLiteOk : kTfLiteError;
}
@@ -202,12 +223,27 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
}
auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data);
+
+ if (outputTensorInfo.GetNumDimensions() > 2)
+ {
+ layer = AddReshapeLayer(tfLiteContext, tfLiteNode, layer, reshapedOutputTensorInfo, outputTensorInfo,
+ delegateData);
+ if (!layer)
+ {
+ TF_LITE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnDelegate: Failed to add reshape for FullyConnected #%d node #%d: ",
+ operatorCode,
+ nodeIndex);
+ return kTfLiteError;
+ }
+ }
+
if (!tfLiteNodeParameters)
{
// No Activation
return kTfLiteOk;
}
-
// Check Activation
TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);