aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/FullyConnected.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/FullyConnected.hpp')
-rw-r--r--delegate/src/FullyConnected.hpp43
1 files changed, 27 insertions, 16 deletions
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp
index e94304fb21..49686d6eaf 100644
--- a/delegate/src/FullyConnected.hpp
+++ b/delegate/src/FullyConnected.hpp
@@ -130,30 +130,39 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
return isSupported ? kTfLiteOk : kTfLiteError;
}
- armnn::Optional<armnn::ConstTensor> optionalWeights = armnn::EmptyOptional();
- armnn::Optional<armnn::ConstTensor> optionalBiases = armnn::EmptyOptional();
- if(descriptor.m_ConstantWeights)
+ armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor);
+ ARMNN_ASSERT(layer != nullptr);
+
+ // Add a constant layer for weights and biases if inputs are constant.
+ if (isConstantWeights)
{
auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor,
weightsTensorInfo,
armnn::Optional<armnn::PermutationVector&>());
- optionalWeights = armnn::Optional<armnn::ConstTensor>(weightsTensor);
- if (biasEnabled)
+ armnn::IConnectableLayer* weightsLayer = delegateData.m_Network->AddConstantLayer(weightsTensor);
+
+ weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
+ weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsTensorInfo);
+ }
+
+ if (biasEnabled)
+ {
+ const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
+ if(tflite::IsConstantTensor(&tfLiteBiasTensor))
{
- const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
auto biasTensor = CreateConstTensor(&tfLiteBiasTensor,
biasTensorInfo,
armnn::Optional<armnn::PermutationVector&>());
- optionalBiases = armnn::Optional<armnn::ConstTensor>(biasTensor);
+
+ armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biasTensor);
+ ARMNN_ASSERT(biasLayer != nullptr);
+
+ biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
+ biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensorInfo);
}
}
- armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor,
- optionalWeights,
- optionalBiases);
- ARMNN_ASSERT(layer != nullptr);
-
armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
outputSlot.SetTensorInfo(outputTensorInfo);
@@ -171,13 +180,15 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
// Connect
delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0));
reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
+
if (!descriptor.m_ConstantWeights)
{
delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(1));
- if (biasEnabled)
- {
- delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2));
- }
+ }
+
+ if (biasEnabled && !tflite::IsConstantTensor(&tfLiteTensors[tfLiteNode->inputs->data[2]]))
+ {
+ delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2));
}
delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot;
}