about summary refs log tree commit diff
path: root/delegate/src/FullyConnected.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/FullyConnected.hpp')
-rw-r--r-- delegate/src/FullyConnected.hpp | 35
1 file changed, 23 insertions, 12 deletions
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp
index 337f1153a1..1129951104 100644
--- a/delegate/src/FullyConnected.hpp
+++ b/delegate/src/FullyConnected.hpp
@@ -54,7 +54,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
}
const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
- armnn::TensorInfo weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor);
+ const armnn::TensorInfo& weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor);
const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
// Check that we support fused activation before we attempt to create a layer
@@ -82,8 +82,6 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
return kTfLiteError;
}
- bool isConstantWeights = tflite::IsConstantTensor(&tfLiteWeightsTensor);
-
armnn::TensorInfo biasTensorInfo;
if (biasEnabled)
{
@@ -141,7 +139,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
armnn::FullyConnectedDescriptor descriptor;
descriptor.m_TransposeWeightMatrix = true;
descriptor.m_BiasEnabled = biasEnabled;
- descriptor.m_ConstantWeights = isConstantWeights;
+ descriptor.m_ConstantWeights = weightsTensorInfo.IsConstant();
bool isSupported = false;
armnn::BackendId setBackend;
@@ -172,11 +170,10 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
ARMNN_ASSERT(layer != nullptr);
// Add a constant layer for weights and biases if inputs are constant.
- if (isConstantWeights)
+ if (weightsTensorInfo.IsConstant())
{
auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor,
- weightsTensorInfo,
- armnn::Optional<armnn::PermutationVector&>());
+ weightsTensorInfo);
armnn::IConnectableLayer* weightsLayer = delegateData.m_Network->AddConstantLayer(weightsTensor);
@@ -187,11 +184,10 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
if (biasEnabled)
{
const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
- if(tflite::IsConstantTensor(&tfLiteBiasTensor))
+ if(biasTensorInfo.IsConstant())
{
auto biasTensor = CreateConstTensor(&tfLiteBiasTensor,
- biasTensorInfo,
- armnn::Optional<armnn::PermutationVector&>());
+ biasTensorInfo);
armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biasTensor);
ARMNN_ASSERT(biasLayer != nullptr);
@@ -201,6 +197,18 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
}
}
+ // The data input can also be constant, so we must check that this is also allocated to an input slot
+ if(inputTensorInfo.IsConstant())
+ {
+ auto input =
+ CreateConstTensor(&tfLiteContext->tensors[tfLiteNode->inputs->data[0]],
+ inputTensorInfo);
+
+ armnn::IConnectableLayer *inputLayer = delegateData.m_Network->AddConstantLayer(input);
+ inputLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u));
+ inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+ }
+
armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
outputSlot.SetTensorInfo(outputTensorInfo);
@@ -224,7 +232,7 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(1));
}
- if (biasEnabled && !tflite::IsConstantTensor(&tfLiteTensors[tfLiteNode->inputs->data[2]]))
+ if (biasEnabled && !biasTensorInfo.IsConstant())
{
delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2));
}
@@ -233,7 +241,10 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
if (reshapeLayer == nullptr)
{
- Connect(layer, tfLiteNode, delegateData);
+ if(Connect(layer, tfLiteNode, delegateData) != kTfLiteOk)
+ {
+ return kTfLiteError;
+ }
}
if (outputTensorInfo.GetNumDimensions() > 2)