aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.hpp
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2022-01-28 16:18:54 +0000
committermike.kelly <mike.kelly@arm.com>2022-04-22 15:46:50 +0000
commit5880b911bf4b7fd8308c93e299d77ac78f282c19 (patch)
treeb256346d6fc78e78735cc50ec822286f809dd37f /src/armnnTfLiteParser/TfLiteParser.hpp
parent4dae5794644b44be8c93bc6db553a205551bc077 (diff)
downloadarmnn-5880b911bf4b7fd8308c93e299d77ac78f282c19.tar.gz
MLCE-604 Add Unidirectional Sequence Lstm support to TFLite
* Added Unidirectional Sequence Lstm support to TFLite Parser * Added support for float operations with int8 weights to TFLite Parser * Added to Conv2d, Conv3D, DepthwiseConv2D, FullyConnected, TransposeConv and UnidirectionalSequenceLstm * Renamed subgraphIndex to subgraph to fix name-shadowing warning. Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I818976ab88abc05dcb4bad246fb4108e6e879283
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.hpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.hpp15
1 files changed, 14 insertions, 1 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 474393cbe6..8c9674a5a6 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -183,6 +183,7 @@ private:
void ParseTanH(size_t subgraphIndex, size_t operatorIndex);
void ParseTranspose(size_t subgraphIndex, size_t operatorIndex);
void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex);
+ void ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, size_t operatorIndex);
void ParseUnpack(size_t subgraphIndex, size_t operatorIndex);
void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
@@ -234,13 +235,19 @@ private:
std::unique_ptr<int32_t[]> m_Int32Data;
};
+ bool ShouldConstantTensorBeCreated(unsigned int tensorIndex);
bool IsConstTensor(TensorRawPtr tensorPtr);
armnn::ConstTensor CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
armnn::TensorInfo& tensorInfo);
+
std::pair<armnn::ConstTensor, SupportedDataStorage>
CreateConstTensorPermuted(TensorRawPtr tensorPtr,
armnn::TensorInfo& tensorInfo,
armnn::Optional<armnn::PermutationVector&> permutationVector);
+ std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
+ CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
+ armnn::TensorInfo& tensorInfo,
+ armnn::DataType inputDataType);
template<typename T>
std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage>
@@ -248,6 +255,9 @@ private:
TfLiteParserImpl::TensorRawPtr tensorPtr,
armnn::TensorInfo& tensorInfo,
armnn::Optional<armnn::PermutationVector&> permutationVector);
+ std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>>
+ CreateConstTensorPtr(TensorRawPtr tensorPtr,
+ armnn::TensorInfo& inputTensorInfo);
// Settings for configuring the TfLiteParser
armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options;
@@ -274,9 +284,12 @@ private:
/// The first index is the subgraph ID, the second index is the tensor ID
std::vector<TensorConnections> m_SubgraphConnections;
- /// This is used in case that the model does not speciry the output.
+ /// This is used in case that the model does not specify the output.
/// The shape can be calculated from the options.
std::vector<std::vector<unsigned int>> m_OverridenOutputShapes;
+
+ std::vector<unsigned int> m_ConstantsToDequantize;
+ std::vector<unsigned int> m_ConstantsToBeCreated;
};
}