aboutsummaryrefslogtreecommitdiff
path: root/delegate/src
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2023-01-19 18:29:40 +0000
committermike.kelly <mike.kelly@arm.com>2023-01-24 22:57:16 +0000
commit04d8229bb3e78d1b1dd21eed41e47aabc25d8e2f (patch)
tree478dbaf477eaa59fac838e6e73b56843f80b52d0 /delegate/src
parent0e3fe10bfe1b4f006f6e0c5c2fae8fb5515c7544 (diff)
downloadarmnn-04d8229bb3e78d1b1dd21eed41e47aabc25d8e2f.tar.gz
IVGCVSW-7277 Fixed issues with FullyConnected on certain TFLite models
* TFLite Parser: * Fixed issue in ParseReshape where the targetShape wasn't always calculated correctly * Fixed issue in ParseFullyConnected where the wrong name was used for the ReshapeLayer * Added an ExpandDims to the FullyConnected to ensure that we reshape the output correctly * TFLite Delegate: * Added an ExpandDims to the FullyConnected to ensure that we reshape the output correctly Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I129dfcb8543f8a3a297c0589c841be20ef3b6407
Diffstat (limited to 'delegate/src')
-rw-r--r--delegate/src/DelegateUtils.hpp48
-rw-r--r--delegate/src/FullyConnected.hpp42
-rw-r--r--delegate/src/test/FullyConnectedTest.cpp2
3 files changed, 87 insertions, 5 deletions
diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp
index 850b279fea..91447576d0 100644
--- a/delegate/src/DelegateUtils.hpp
+++ b/delegate/src/DelegateUtils.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -367,6 +367,52 @@ TfLiteStatus FusedActivation(TfLiteContext* tfLiteContext,
return kTfLiteOk;
}
+armnn::IConnectableLayer* AddReshapeLayer(TfLiteContext* tfLiteContext,
+ TfLiteNode* tfLiteNode,
+ armnn::IConnectableLayer* prevLayer,
+ armnn::TensorInfo reshapedOutputTensorInfo,
+ armnn::TensorInfo outputTensorInfo,
+ armnnDelegate::DelegateData& data)
+{
+ armnn::ReshapeDescriptor desc;
+ desc.m_TargetShape = outputTensorInfo.GetShape();
+
+ bool isSupported = false;
+ armnn::BackendId setBackend;
+ FORWARD_LAYER_SUPPORT_FUNC("RESHAPE",
+ tfLiteContext,
+ IsReshapeSupported,
+ data.m_Backends,
+ isSupported,
+ setBackend,
+ reshapedOutputTensorInfo,
+ outputTensorInfo,
+ desc);
+
+ if (!isSupported)
+ {
+ return nullptr;
+ }
+
+ armnn::IConnectableLayer* reshapeLayer = data.m_Network->AddReshapeLayer(desc);
+ reshapeLayer->SetBackendId(setBackend);
+ ARMNN_ASSERT(reshapeLayer != nullptr);
+
+ prevLayer->GetOutputSlot(0).SetTensorInfo(reshapedOutputTensorInfo);
+ reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ // Connect and prepare output slots
+ for (unsigned int outputIndex = 0; outputIndex < reshapeLayer->GetNumOutputSlots(); ++outputIndex)
+ {
+ data.m_OutputSlotForNode[static_cast<unsigned long>(
+ tfLiteNode->outputs->data[outputIndex])]->Connect(reshapeLayer->GetInputSlot(0));
+ armnn::IOutputSlot& outputSlot = reshapeLayer->GetOutputSlot(outputIndex);
+ data.m_OutputSlotForNode[static_cast<unsigned long>(
+ tfLiteNode->outputs->data[outputIndex])] = &outputSlot;
+ }
+ return reshapeLayer;
+}
+
armnn::DataType GetDataType(const TfLiteTensor& tfLiteTensor)
{
switch (tfLiteTensor.type)
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp
index a2960e299b..2243ad0e0c 100644
--- a/delegate/src/FullyConnected.hpp
+++ b/delegate/src/FullyConnected.hpp
@@ -1,11 +1,12 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "DelegateUtils.hpp"
+#include "armnnUtils/TensorUtils.hpp"
#include <armnn/utility/IgnoreUnused.hpp>
#include <tensorflow/lite/builtin_ops.h>
@@ -103,6 +104,25 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
}
+ armnn::TensorInfo reshapedOutputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
+
+ if (outputTensorInfo.GetNumDimensions() > 2)
+ {
+ // Calculate reshape to flatten to 2D [batch_size, input_size]
+ std::vector<unsigned int> reshapedDimensions(2);
+ reshapedDimensions[1] = weightsTensorInfo.GetShape()[0];
+ reshapedDimensions[0] = outputTensorInfo.GetNumElements() / reshapedDimensions[1];
+
+ if (outputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
+ {
+ TF_LITE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnDelegate: Failed to deduce output tensor shape from filter size #%d #%d node #%d: ",
+ reshapedDimensions[1], operatorCode, nodeIndex);
+ return kTfLiteError;
+ }
+ reshapedOutputTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
+ }
armnn::FullyConnectedDescriptor descriptor;
descriptor.m_TransposeWeightMatrix = true;
@@ -113,6 +133,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
armnn::BackendId setBackend;
auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
{
+
FORWARD_LAYER_SUPPORT_FUNC("FULLY_CONNECTED",
tfLiteContext,
IsFullyConnectedSupported,
@@ -128,7 +149,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
if (!delegateData.m_Network)
{
- validateFunc(outputTensorInfo, isSupported);
+ validateFunc(reshapedOutputTensorInfo, isSupported);
return isSupported ? kTfLiteOk : kTfLiteError;
}
@@ -202,12 +223,27 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
}
auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data);
+
+ if (outputTensorInfo.GetNumDimensions() > 2)
+ {
+ layer = AddReshapeLayer(tfLiteContext, tfLiteNode, layer, reshapedOutputTensorInfo, outputTensorInfo,
+ delegateData);
+ if (!layer)
+ {
+ TF_LITE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnDelegate: Failed to add reshape for FullyConnected #%d node #%d: ",
+ operatorCode,
+ nodeIndex);
+ return kTfLiteError;
+ }
+ }
+
if (!tfLiteNodeParameters)
{
// No Activation
return kTfLiteOk;
}
-
// Check Activation
TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
diff --git a/delegate/src/test/FullyConnectedTest.cpp b/delegate/src/test/FullyConnectedTest.cpp
index c300bc72bf..3ef5cedbd7 100644
--- a/delegate/src/test/FullyConnectedTest.cpp
+++ b/delegate/src/test/FullyConnectedTest.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2020-2021,2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//