aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/FullyConnected.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/FullyConnected.hpp')
-rw-r--r--delegate/src/FullyConnected.hpp40
1 files changed, 22 insertions, 18 deletions
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp
index b79f6a2bb2..53251f7c55 100644
--- a/delegate/src/FullyConnected.hpp
+++ b/delegate/src/FullyConnected.hpp
@@ -129,6 +129,27 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
}
+ armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
+
+ if (inputTensorInfo.GetNumDimensions() > 2)
+ {
+ // Calculate reshape to flatten to 2D [batch_size, input_size]
+ std::vector<unsigned int> reshapedDimensions(2);
+ reshapedDimensions[1] = weightsTensorInfo.GetShape()[1];
+ reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];
+
+ if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
+ {
+ TF_LITE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnDelegate: Failed to deduce input tensor shape from filter size #%d #%d node #%d: ",
+ reshapedDimensions[1], operatorCode, nodeIndex);
+ return kTfLiteError;
+ }
+
+ reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
+ }
+
armnn::FullyConnectedDescriptor descriptor;
descriptor.m_TransposeWeightMatrix = true;
descriptor.m_BiasEnabled = biasEnabled;
@@ -141,7 +162,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
IsFullyConnectedSupported,
delegateData.m_Backends,
isSupported,
- inputTensorInfo,
+ reshapedTensorInfo,
outputTensorInfo,
weightsTensorInfo,
biasTensorInfo,
@@ -184,22 +205,6 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
if (inputTensorInfo.GetNumDimensions() > 2)
{
// Add reshape to flatten to 2D [batch_size, input_size]
- std::vector<unsigned int> reshapedDimensions(2);
- reshapedDimensions[1] = weightsTensorInfo.GetShape()[1];
- reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];
-
- if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
- {
- TF_LITE_MAYBE_KERNEL_LOG(
- tfLiteContext,
- "TfLiteArmnnDelegate: Failed to deduce input tensor shape from filter size #%d #%d node #%d: ",
- reshapedDimensions[1], operatorCode, nodeIndex);
- return kTfLiteError;
- }
-
- armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
- reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
-
armnn::ReshapeDescriptor reshapeDescriptor;
reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();
reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor);
@@ -210,7 +215,6 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
// Connect
delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0));
reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
- armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot;
}