diff options
author | Matthew Sloyan <matthew.sloyan@arm.com> | 2021-07-13 19:46:11 +0100 |
---|---|---|
committer | Matthew Sloyan <matthew.sloyan@arm.com> | 2021-08-06 09:25:26 +0000 |
commit | 81beae3a870004795275e9266bc43d845b9f78db (patch) | |
tree | 70af86f3c36c8e330c72770e6f1419ca7b2a4bb8 /src/armnnTfLiteParser/TfLiteParser.cpp | |
parent | 95e9efc28ce70a8cda93e722f5ce90ebc96bdd95 (diff) | |
download | armnn-81beae3a870004795275e9266bc43d845b9f78db.tar.gz |
IVGCVSW-6119 ConstTensorsAsInput: FullyConnected
* Constant weights and biases are now stored as Constant layers.
* Updated Serializer, Deserializer and unit tests to reflect this.
* Updated TfLiteDelegate, TfLiteParser and OnnxParser.
* Updated Schema with IsConstant and ConstantTensorsAsInputs.
* Updated Ref backend to handle constant weights and
bias as inputs rather than reading from member variables.
* Added dynamic or constant input EndToEnd tests.
!android-nn-driver:5959
Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com>
Change-Id: Ibf3cf437df1100e4b322b0d303c575c6339f9696
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 80 |
1 files changed, 42 insertions, 38 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index b669ae4efa..3e59244753 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -555,6 +555,9 @@ CreateConstTensorImpl(TfLiteParserImpl::BufferRawPtr bufferPtr, ::memcpy(data.get(), bufferPtr->data.data(), tensorInfo.GetNumBytes()); } + // Make sure isConstant flag is set. + tensorInfo.SetConstant(); + return std::make_pair(ConstTensor(tensorInfo, data.get()), std::move(data)); } @@ -2571,42 +2574,26 @@ void TfLiteParserImpl::ParseFullyConnected(size_t subgraphIndex, size_t operator armnn::IConnectableLayer* layer = nullptr; auto layerName = fmt::format("FullyConnected:{}:{}", subgraphIndex, operatorIndex); - Optional<ConstTensor> filterOptionalConstTensor; + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + // Add the first input tensor to the registration list + std::vector<unsigned int> tensorIndexesToRegister = {inputTensorIndexes[0]}; + std::vector<unsigned int> ignoreInputWhenRegister = {}; desc.m_ConstantWeights = IsConstTensor(inputs[1]); - auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); - std::vector<unsigned int> tensorIndexesToRegister = {inputTensorIndexes[0]}; - if (desc.m_ConstantWeights) - { - filterOptionalConstTensor = Optional<ConstTensor>(CreateConstTensorNonPermuted(inputs[1], filterTensorInfo)); - } - else - { - // Non const weights will need to be registered as inputs - tensorIndexesToRegister.emplace_back(inputTensorIndexes[1]); - } + // Add the weights input to the registration list, constant layers will be added by SetupConstantLayers if constant. + tensorIndexesToRegister.emplace_back(inputTensorIndexes[1]); - Optional<ConstTensor> biasOptionalConstTensor; if (inputs.size() == 3) { desc.m_BiasEnabled = true; - if (desc.m_ConstantWeights) - { - TensorInfo biasTensorInfo = ToTensorInfo(inputs[2]); - biasOptionalConstTensor = Optional<ConstTensor>(CreateConstTensorNonPermuted(inputs[2], biasTensorInfo)); - } - else - { - // Non const biases will need to be registered as inputs - tensorIndexesToRegister.emplace_back(inputTensorIndexes[2]); - } + + // Add the biases input to the registration list, constant layer will be added by SetupConstantLayers. + tensorIndexesToRegister.emplace_back(inputTensorIndexes[2]); } - layer = m_Network->AddFullyConnectedLayer(desc, - filterOptionalConstTensor, - biasOptionalConstTensor, - layerName.c_str()); + // Filters and biases are always passed to fully connected as inputs + layer = m_Network->AddFullyConnectedLayer(desc, layerName.c_str()); ARMNN_ASSERT(layer != nullptr); armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); @@ -3732,6 +3719,7 @@ void TfLiteParserImpl::RegisterInputSlots(size_t subgraphIndex, { CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); ARMNN_ASSERT(layer != nullptr); + if (tensorIndexes.size() + startingSlotIndex != layer->GetNumInputSlots()) { throw ParseException( @@ -3831,19 +3819,27 @@ void TfLiteParserImpl::SetupConstantLayers(size_t subgraphIndex) m_SubgraphConnections[subgraphIndex][tensorIndex].inputSlots.size() > 0) { TensorRawPtr tensorPtr = subgraphPtr->tensors[tensorIndex].get(); - armnn::TensorInfo tensorInfo = ToTensorInfo(tensorPtr); - auto tensorAndData = CreateConstTensorNonPermuted(tensorPtr, tensorInfo); - std::string layerName = fmt::format("Constant:{}", tensorPtr->name); - IConnectableLayer *layer = - m_Network->AddConstantLayer(tensorAndData, layerName.c_str()); + if(IsConstTensor(tensorPtr)) + { + armnn::TensorInfo tensorInfo = ToTensorInfo(tensorPtr); + auto tensorAndData = CreateConstTensorNonPermuted(tensorPtr, tensorInfo); - layer->GetOutputSlot(0).SetTensorInfo(tensorInfo); - RegisterOutputSlots(subgraphIndex, - VIRTUAL_OPERATOR_ID, - layer, - { tensorIndex }); + std::string layerName = fmt::format("Constant:{}", tensorPtr->name); + IConnectableLayer *layer = m_Network->AddConstantLayer(tensorAndData, layerName.c_str()); + layer->GetOutputSlot(0).SetTensorInfo(tensorInfo); + RegisterOutputSlots(subgraphIndex, + VIRTUAL_OPERATOR_ID, + layer, + { tensorIndex }); + } + else + { + throw ParseException( + fmt::format("Invalid Tensor: Tensor should be constant. {}", + CHECK_LOCATION().AsString())); + } } } } @@ -3863,6 +3859,9 @@ TfLiteParserImpl::CreateConstTensorAndStoreData(TfLiteParserImpl::BufferRawPtr b armnn::TensorInfo& tensorInfo, armnn::Optional<armnn::PermutationVector&> permutationVector) { + // Make sure isConstant flag is set. + tensorInfo.SetConstant(); + auto constData = CreateConstTensorImpl<T>(bufferPtr, tensorPtr, tensorInfo, @@ -3885,7 +3884,6 @@ bool TfLiteParserImpl::IsConstTensor(TensorRawPtr tensorPtr) return isConst; } - std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage> TfLiteParserImpl::CreateConstTensorPermuted(TensorRawPtr tensorPtr, armnn::TensorInfo& tensorInfo, @@ -3895,6 +3893,9 @@ TfLiteParserImpl::CreateConstTensorPermuted(TensorRawPtr tensorPtr, auto bufferPtr = GetBuffer(m_Model, tensorPtr->buffer); CHECK_BUFFER_SIZE(bufferPtr, tensorInfo, tensorPtr->buffer); + // Make sure isConstant flag is set. + tensorInfo.SetConstant(); + switch (tensorInfo.GetDataType()) { case armnn::DataType::Float32: @@ -3941,6 +3942,9 @@ armnn::ConstTensor TfLiteParserImpl::CreateConstTensorNonPermuted(TensorRawPtr t auto bufferPtr = GetBuffer(m_Model, tensorPtr->buffer); CHECK_BUFFER_SIZE(bufferPtr, tensorInfo, tensorPtr->buffer); + // Make sure isConstant flag is set. + tensorInfo.SetConstant(); + return ConstTensor(tensorInfo, bufferPtr->data.data()); } |