diff options
Diffstat (limited to 'delegate/src/FullyConnected.hpp')
-rw-r--r-- | delegate/src/FullyConnected.hpp | 43 |
1 files changed, 27 insertions, 16 deletions
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp index e94304fb21..49686d6eaf 100644 --- a/delegate/src/FullyConnected.hpp +++ b/delegate/src/FullyConnected.hpp @@ -130,30 +130,39 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, return isSupported ? kTfLiteOk : kTfLiteError; } - armnn::Optional<armnn::ConstTensor> optionalWeights = armnn::EmptyOptional(); - armnn::Optional<armnn::ConstTensor> optionalBiases = armnn::EmptyOptional(); - if(descriptor.m_ConstantWeights) + armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor); + ARMNN_ASSERT(layer != nullptr); + + // Add a constant layer for weights and biases if inputs are constant. + if (isConstantWeights) { auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor, weightsTensorInfo, armnn::Optional<armnn::PermutationVector&>()); - optionalWeights = armnn::Optional<armnn::ConstTensor>(weightsTensor); - if (biasEnabled) + armnn::IConnectableLayer* weightsLayer = delegateData.m_Network->AddConstantLayer(weightsTensor); + + weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u)); + weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsTensorInfo); + } + + if (biasEnabled) + { + const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]]; + if(tflite::IsConstantTensor(&tfLiteBiasTensor)) { - const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]]; auto biasTensor = CreateConstTensor(&tfLiteBiasTensor, biasTensorInfo, armnn::Optional<armnn::PermutationVector&>()); - optionalBiases = armnn::Optional<armnn::ConstTensor>(biasTensor); + + armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biasTensor); + ARMNN_ASSERT(biasLayer != nullptr); + + biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u)); + biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensorInfo); } } - armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor, - optionalWeights, - optionalBiases); - ARMNN_ASSERT(layer != nullptr); - armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0); outputSlot.SetTensorInfo(outputTensorInfo); @@ -171,13 +180,15 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, // Connect delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0)); reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0)); + if (!descriptor.m_ConstantWeights) { delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(1)); - if (biasEnabled) - { - delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2)); - } + } + + if (biasEnabled && !tflite::IsConstantTensor(&tfLiteTensors[tfLiteNode->inputs->data[2]])) + { + delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2)); } delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot; } |