aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/FullyConnected.hpp
diff options
context:
space:
mode:
authorMatthew Sloyan <matthew.sloyan@arm.com>2021-07-13 19:46:11 +0100
committerMatthew Sloyan <matthew.sloyan@arm.com>2021-08-06 09:25:26 +0000
commit81beae3a870004795275e9266bc43d845b9f78db (patch)
tree70af86f3c36c8e330c72770e6f1419ca7b2a4bb8 /delegate/src/FullyConnected.hpp
parent95e9efc28ce70a8cda93e722f5ce90ebc96bdd95 (diff)
downloadarmnn-81beae3a870004795275e9266bc43d845b9f78db.tar.gz
IVGCVSW-6119 ConstTensorsAsInput: FullyConnected
* Constant weights and biases are now stored as Constant layers. * Updated Serializer, Deserializer and unit tests to reflect this. * Updated TfLiteDelegate, TfLiteParser and OnnxParser. * Updated Schema with IsConstant and ConstantTensorsAsInputs. * Updated Ref backend to handle constant weights and bias as inputs rather than reading from member variables. * Added dynamic or constant input EndToEnd tests. !android-nn-driver:5959 Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com> Change-Id: Ibf3cf437df1100e4b322b0d303c575c6339f9696
Diffstat (limited to 'delegate/src/FullyConnected.hpp')
-rw-r--r--delegate/src/FullyConnected.hpp43
1 files changed, 27 insertions, 16 deletions
diff --git a/delegate/src/FullyConnected.hpp b/delegate/src/FullyConnected.hpp
index e94304fb21..49686d6eaf 100644
--- a/delegate/src/FullyConnected.hpp
+++ b/delegate/src/FullyConnected.hpp
@@ -130,30 +130,39 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
return isSupported ? kTfLiteOk : kTfLiteError;
}
- armnn::Optional<armnn::ConstTensor> optionalWeights = armnn::EmptyOptional();
- armnn::Optional<armnn::ConstTensor> optionalBiases = armnn::EmptyOptional();
- if(descriptor.m_ConstantWeights)
+ armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor);
+ ARMNN_ASSERT(layer != nullptr);
+
+ // Add a constant layer for weights and biases if inputs are constant.
+ if (isConstantWeights)
{
auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor,
weightsTensorInfo,
armnn::Optional<armnn::PermutationVector&>());
- optionalWeights = armnn::Optional<armnn::ConstTensor>(weightsTensor);
- if (biasEnabled)
+ armnn::IConnectableLayer* weightsLayer = delegateData.m_Network->AddConstantLayer(weightsTensor);
+
+ weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
+ weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsTensorInfo);
+ }
+
+ if (biasEnabled)
+ {
+ const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
+ if(tflite::IsConstantTensor(&tfLiteBiasTensor))
{
- const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
auto biasTensor = CreateConstTensor(&tfLiteBiasTensor,
biasTensorInfo,
armnn::Optional<armnn::PermutationVector&>());
- optionalBiases = armnn::Optional<armnn::ConstTensor>(biasTensor);
+
+ armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biasTensor);
+ ARMNN_ASSERT(biasLayer != nullptr);
+
+ biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
+ biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensorInfo);
}
}
- armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor,
- optionalWeights,
- optionalBiases);
- ARMNN_ASSERT(layer != nullptr);
-
armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
outputSlot.SetTensorInfo(outputTensorInfo);
@@ -171,13 +180,15 @@ TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
// Connect
delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0));
reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
+
if (!descriptor.m_ConstantWeights)
{
delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(1));
- if (biasEnabled)
- {
- delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2));
- }
+ }
+
+ if (biasEnabled && !tflite::IsConstantTensor(&tfLiteTensors[tfLiteNode->inputs->data[2]]))
+ {
+ delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2));
}
delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot;
}