diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.hpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.hpp | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp index 474393cbe6..8c9674a5a6 100644 --- a/src/armnnTfLiteParser/TfLiteParser.hpp +++ b/src/armnnTfLiteParser/TfLiteParser.hpp @@ -183,6 +183,7 @@ private: void ParseTanH(size_t subgraphIndex, size_t operatorIndex); void ParseTranspose(size_t subgraphIndex, size_t operatorIndex); void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex); + void ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, size_t operatorIndex); void ParseUnpack(size_t subgraphIndex, size_t operatorIndex); void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot); @@ -234,13 +235,19 @@ private: std::unique_ptr<int32_t[]> m_Int32Data; }; + bool ShouldConstantTensorBeCreated(unsigned int tensorIndex); bool IsConstTensor(TensorRawPtr tensorPtr); armnn::ConstTensor CreateConstTensorNonPermuted(TensorRawPtr tensorPtr, armnn::TensorInfo& tensorInfo); + std::pair<armnn::ConstTensor, SupportedDataStorage> CreateConstTensorPermuted(TensorRawPtr tensorPtr, armnn::TensorInfo& tensorInfo, armnn::Optional<armnn::PermutationVector&> permutationVector); + std::pair<armnn::ConstTensor, std::unique_ptr<float[]>> + CreateConstTensorNonPermuted(TensorRawPtr tensorPtr, + armnn::TensorInfo& tensorInfo, + armnn::DataType inputDataType); template<typename T> std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage> @@ -248,6 +255,9 @@ private: TfLiteParserImpl::TensorRawPtr tensorPtr, armnn::TensorInfo& tensorInfo, armnn::Optional<armnn::PermutationVector&> permutationVector); + std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>> + CreateConstTensorPtr(TensorRawPtr tensorPtr, + armnn::TensorInfo& inputTensorInfo); // Settings for configuring the TfLiteParser armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options; @@ -274,9 +284,12 @@ private: /// The first index is the subgraph ID, the second index is the tensor ID std::vector<TensorConnections> m_SubgraphConnections; - /// This is used in case that the model does not speciry the output. + /// This is used in case that the model does not specify the output. /// The shape can be calculated from the options. std::vector<std::vector<unsigned int>> m_OverridenOutputShapes; + + std::vector<unsigned int> m_ConstantsToDequantize; + std::vector<unsigned int> m_ConstantsToBeCreated; }; } |