diff options
Diffstat (limited to 'delegate/src/FullyConnected.hpp')
-rw-r--r-- | delegate/src/FullyConnected.hpp | 42 |
1 files changed, 39 insertions, 3 deletions
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp index a2960e299b..2243ad0e0c 100644 --- a/delegate/src/FullyConnected.hpp +++ b/delegate/src/FullyConnected.hpp @@ -1,11 +1,12 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "DelegateUtils.hpp" +#include "armnnUtils/TensorUtils.hpp" #include <armnn/utility/IgnoreUnused.hpp> #include <tensorflow/lite/builtin_ops.h> @@ -103,6 +104,25 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() }); } + armnn::TensorInfo reshapedOutputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor); + + if (outputTensorInfo.GetNumDimensions() > 2) + { + // Calculate reshape to flatten to 2D [batch_size, input_size] + std::vector<unsigned int> reshapedDimensions(2); + reshapedDimensions[1] = weightsTensorInfo.GetShape()[0]; + reshapedDimensions[0] = outputTensorInfo.GetNumElements() / reshapedDimensions[1]; + + if (outputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0) + { + TF_LITE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnDelegate: Failed to deduce output tensor shape from filter size #%d #%d node #%d: ", + reshapedDimensions[1], operatorCode, nodeIndex); + return kTfLiteError; + } + reshapedOutputTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() }); + } armnn::FullyConnectedDescriptor descriptor; descriptor.m_TransposeWeightMatrix = true; @@ -113,6 +133,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, armnn::BackendId setBackend; auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported) { + FORWARD_LAYER_SUPPORT_FUNC("FULLY_CONNECTED", tfLiteContext, IsFullyConnectedSupported, @@ -128,7 +149,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, if (!delegateData.m_Network) { - validateFunc(outputTensorInfo, isSupported); + validateFunc(reshapedOutputTensorInfo, isSupported); return isSupported ? kTfLiteOk : kTfLiteError; } @@ -202,12 +223,27 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, } auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data); + + if (outputTensorInfo.GetNumDimensions() > 2) + { + layer = AddReshapeLayer(tfLiteContext, tfLiteNode, layer, reshapedOutputTensorInfo, outputTensorInfo, + delegateData); + if (!layer) + { + TF_LITE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnDelegate: Failed to add reshape for FullyConnected #%d node #%d: ", + operatorCode, + nodeIndex); + return kTfLiteError; + } + } + if (!tfLiteNodeParameters) { // No Activation return kTfLiteOk; } - // Check Activation TfLiteFusedActivation activationType = tfLiteNodeParameters->activation; return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData); |