From 66da7510362d00c6d5b6e8c1fe7f10145efe764b Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Fri, 20 Nov 2020 14:50:54 +0000 Subject: IVGCVSW-5544 Fix FullyConnected Delegate tests * Correct input shape Signed-off-by: Narumol Prangnawarat Change-Id: I9d1fe4c8ef32a9dfba7f7fdd6af314e9a522fce8 --- delegate/src/FullyConnected.hpp | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) (limited to 'delegate/src/FullyConnected.hpp') diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp index b79f6a2bb2..53251f7c55 100644 --- a/delegate/src/FullyConnected.hpp +++ b/delegate/src/FullyConnected.hpp @@ -129,6 +129,27 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor)); } + armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor); + + if (inputTensorInfo.GetNumDimensions() > 2) + { + // Calculate reshape to flatten to 2D [batch_size, input_size] + std::vector reshapedDimensions(2); + reshapedDimensions[1] = weightsTensorInfo.GetShape()[1]; + reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1]; + + if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0) + { + TF_LITE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnDelegate: Failed to deduce input tensor shape from filter size #%d #%d node #%d: ", + reshapedDimensions[1], operatorCode, nodeIndex); + return kTfLiteError; + } + + reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() }); + } + armnn::FullyConnectedDescriptor descriptor; descriptor.m_TransposeWeightMatrix = true; descriptor.m_BiasEnabled = biasEnabled; @@ -141,7 +162,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, IsFullyConnectedSupported, delegateData.m_Backends, isSupported, - inputTensorInfo, + reshapedTensorInfo, outputTensorInfo, weightsTensorInfo, biasTensorInfo, @@ -184,22 +205,6 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, if (inputTensorInfo.GetNumDimensions() > 2) { // Add reshape to flatten to 2D [batch_size, input_size] - std::vector reshapedDimensions(2); - reshapedDimensions[1] = weightsTensorInfo.GetShape()[1]; - reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1]; - - if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0) - { - TF_LITE_MAYBE_KERNEL_LOG( - tfLiteContext, - "TfLiteArmnnDelegate: Failed to deduce input tensor shape from filter size #%d #%d node #%d: ", - reshapedDimensions[1], operatorCode, nodeIndex); - return kTfLiteError; - } - - armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor); - reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() }); - armnn::ReshapeDescriptor reshapeDescriptor; reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape(); reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor); @@ -210,7 +215,6 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, // Connect delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0)); reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0)); - armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0); delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot; } -- cgit v1.2.1