aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/FullyConnected.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/FullyConnected.hpp')
-rw-r--r--delegate/src/FullyConnected.hpp42
1 files changed, 39 insertions, 3 deletions
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp
index a2960e299b..2243ad0e0c 100644
--- a/delegate/src/FullyConnected.hpp
+++ b/delegate/src/FullyConnected.hpp
@@ -1,11 +1,12 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "DelegateUtils.hpp"
+#include "armnnUtils/TensorUtils.hpp"
#include <armnn/utility/IgnoreUnused.hpp>
#include <tensorflow/lite/builtin_ops.h>
@@ -103,6 +104,25 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
}
+ armnn::TensorInfo reshapedOutputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
+
+ if (outputTensorInfo.GetNumDimensions() > 2)
+ {
+ // Calculate reshape to flatten to 2D [batch_size, input_size]
+ std::vector<unsigned int> reshapedDimensions(2);
+ reshapedDimensions[1] = weightsTensorInfo.GetShape()[0];
+ reshapedDimensions[0] = outputTensorInfo.GetNumElements() / reshapedDimensions[1];
+
+ if (outputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
+ {
+ TF_LITE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnDelegate: Failed to deduce output tensor shape from filter size #%d #%d node #%d: ",
+ reshapedDimensions[1], operatorCode, nodeIndex);
+ return kTfLiteError;
+ }
+ reshapedOutputTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
+ }
armnn::FullyConnectedDescriptor descriptor;
descriptor.m_TransposeWeightMatrix = true;
@@ -113,6 +133,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
armnn::BackendId setBackend;
auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
{
+
FORWARD_LAYER_SUPPORT_FUNC("FULLY_CONNECTED",
tfLiteContext,
IsFullyConnectedSupported,
@@ -128,7 +149,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
if (!delegateData.m_Network)
{
- validateFunc(outputTensorInfo, isSupported);
+ validateFunc(reshapedOutputTensorInfo, isSupported);
return isSupported ? kTfLiteOk : kTfLiteError;
}
@@ -202,12 +223,27 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
}
auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data);
+
+ if (outputTensorInfo.GetNumDimensions() > 2)
+ {
+ layer = AddReshapeLayer(tfLiteContext, tfLiteNode, layer, reshapedOutputTensorInfo, outputTensorInfo,
+ delegateData);
+ if (!layer)
+ {
+ TF_LITE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnDelegate: Failed to add reshape for FullyConnected #%d node #%d: ",
+ operatorCode,
+ nodeIndex);
+ return kTfLiteError;
+ }
+ }
+
if (!tfLiteNodeParameters)
{
// No Activation
return kTfLiteOk;
}
-
// Check Activation
TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);