diff options
Diffstat (limited to 'delegate/src/FullyConnected.hpp')
-rw-r--r-- | delegate/src/FullyConnected.hpp | 35 |
1 files changed, 23 insertions, 12 deletions
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp index 337f1153a1..1129951104 100644 --- a/delegate/src/FullyConnected.hpp +++ b/delegate/src/FullyConnected.hpp @@ -54,7 +54,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, } const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor); - armnn::TensorInfo weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor); + const armnn::TensorInfo& weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor); const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true); // Check that we support fused activation before we attempt to create a layer @@ -82,8 +82,6 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, return kTfLiteError; } - bool isConstantWeights = tflite::IsConstantTensor(&tfLiteWeightsTensor); - armnn::TensorInfo biasTensorInfo; if (biasEnabled) { @@ -141,7 +139,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, armnn::FullyConnectedDescriptor descriptor; descriptor.m_TransposeWeightMatrix = true; descriptor.m_BiasEnabled = biasEnabled; - descriptor.m_ConstantWeights = isConstantWeights; + descriptor.m_ConstantWeights = weightsTensorInfo.IsConstant(); bool isSupported = false; armnn::BackendId setBackend; @@ -172,11 +170,10 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, ARMNN_ASSERT(layer != nullptr); // Add a constant layer for weights and biases if inputs are constant. - if (isConstantWeights) + if (weightsTensorInfo.IsConstant()) { auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor, - weightsTensorInfo, - armnn::Optional<armnn::PermutationVector&>()); + weightsTensorInfo); armnn::IConnectableLayer* weightsLayer = delegateData.m_Network->AddConstantLayer(weightsTensor); @@ -187,11 +184,10 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, if (biasEnabled) { const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]]; - if(tflite::IsConstantTensor(&tfLiteBiasTensor)) + if(biasTensorInfo.IsConstant()) { auto biasTensor = CreateConstTensor(&tfLiteBiasTensor, - biasTensorInfo, - armnn::Optional<armnn::PermutationVector&>()); + biasTensorInfo); armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biasTensor); ARMNN_ASSERT(biasLayer != nullptr); @@ -201,6 +197,18 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, } } + // The data input can also be constant, so we must check that this is also allocated to an input slot + if(inputTensorInfo.IsConstant()) + { + auto input = + CreateConstTensor(&tfLiteContext->tensors[tfLiteNode->inputs->data[0]], + inputTensorInfo); + + armnn::IConnectableLayer *inputLayer = delegateData.m_Network->AddConstantLayer(input); + inputLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + } + armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0); outputSlot.SetTensorInfo(outputTensorInfo); @@ -224,7 +232,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(1)); } - if (biasEnabled && !tflite::IsConstantTensor(&tfLiteTensors[tfLiteNode->inputs->data[2]])) + if (biasEnabled && !biasTensorInfo.IsConstant()) { delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2)); } @@ -233,7 +241,10 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData, if (reshapeLayer == nullptr) { - Connect(layer, tfLiteNode, delegateData); + if(Connect(layer, tfLiteNode, delegateData) != kTfLiteOk) + { + return kTfLiteError; + } } if (outputTensorInfo.GetNumDimensions() > 2) |