diff options
author | Mike Kelly <mike.kelly@arm.com> | 2022-01-28 16:18:54 +0000 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2022-04-22 15:46:50 +0000 |
commit | 5880b911bf4b7fd8308c93e299d77ac78f282c19 (patch) | |
tree | b256346d6fc78e78735cc50ec822286f809dd37f /src/armnnTfLiteParser/TfLiteParser.hpp | |
parent | 4dae5794644b44be8c93bc6db553a205551bc077 (diff) | |
download | armnn-5880b911bf4b7fd8308c93e299d77ac78f282c19.tar.gz |
MLCE-604 Add Unidirectional Sequence Lstm support to TFLite
* Added Unidirectional Sequence Lstm support to TFLite Parser
* Added support for float operations with int8 weights to TFLite Parser
* Added to Conv2d, Conv3D, DepthwiseConv2D, FullyConnected,
TransposeConv and UnidirectionalSequenceLstm
* Renamed subgraphIndex to subgraph to fix name-shadowing warning.
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I818976ab88abc05dcb4bad246fb4108e6e879283
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.hpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.hpp | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp index 474393cbe6..8c9674a5a6 100644 --- a/src/armnnTfLiteParser/TfLiteParser.hpp +++ b/src/armnnTfLiteParser/TfLiteParser.hpp @@ -183,6 +183,7 @@ private: void ParseTanH(size_t subgraphIndex, size_t operatorIndex); void ParseTranspose(size_t subgraphIndex, size_t operatorIndex); void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex); + void ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, size_t operatorIndex); void ParseUnpack(size_t subgraphIndex, size_t operatorIndex); void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot); @@ -234,13 +235,19 @@ private: std::unique_ptr<int32_t[]> m_Int32Data; }; + bool ShouldConstantTensorBeCreated(unsigned int tensorIndex); bool IsConstTensor(TensorRawPtr tensorPtr); armnn::ConstTensor CreateConstTensorNonPermuted(TensorRawPtr tensorPtr, armnn::TensorInfo& tensorInfo); + std::pair<armnn::ConstTensor, SupportedDataStorage> CreateConstTensorPermuted(TensorRawPtr tensorPtr, armnn::TensorInfo& tensorInfo, armnn::Optional<armnn::PermutationVector&> permutationVector); + std::pair<armnn::ConstTensor, std::unique_ptr<float[]>> + CreateConstTensorNonPermuted(TensorRawPtr tensorPtr, + armnn::TensorInfo& tensorInfo, + armnn::DataType inputDataType); template<typename T> std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage> @@ -248,6 +255,9 @@ private: TfLiteParserImpl::TensorRawPtr tensorPtr, armnn::TensorInfo& tensorInfo, armnn::Optional<armnn::PermutationVector&> permutationVector); + std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>> + CreateConstTensorPtr(TensorRawPtr tensorPtr, + armnn::TensorInfo& inputTensorInfo); // Settings for configuring the TfLiteParser armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options; @@ -274,9 +284,12 @@ private: /// The first index is the subgraph ID, the second index is the tensor ID std::vector<TensorConnections> m_SubgraphConnections; - /// This is used in case that the model does not speciry the output. + /// This is used in case that the model does not specify the output. /// The shape can be calculated from the options. std::vector<std::vector<unsigned int>> m_OverridenOutputShapes; + + std::vector<unsigned int> m_ConstantsToDequantize; + std::vector<unsigned int> m_ConstantsToBeCreated; }; } |