diff options
author | James Conroy <james.conroy@arm.com> | 2020-05-13 10:27:58 +0100 |
---|---|---|
committer | James Conroy <james.conroy@arm.com> | 2020-05-13 23:06:38 +0000 |
commit | 8d33318a7ac33d90ed79701ff717de8d9940cc67 (patch) | |
tree | 2cf4140ec37b5b0a43b9618bab7f4f8076b5f4ab /src/armnnDeserializer/Deserializer.hpp | |
parent | 5061601fb6833dda20a6097af6a92e5e07310f25 (diff) | |
download | armnn-8d33318a7ac33d90ed79701ff717de8d9940cc67.tar.gz |
IVGCVSW-4777 Add QLstm serialization support
* Adds serialization/deserilization for QLstm.
* 3 unit tests: basic, layer norm and advanced.
Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: I97d825e06b0d4a1257713cdd71ff06afa10d4380
Diffstat (limited to 'src/armnnDeserializer/Deserializer.hpp')
-rw-r--r-- | src/armnnDeserializer/Deserializer.hpp | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp index f7e47cc8c2..d6ceced7c6 100644 --- a/src/armnnDeserializer/Deserializer.hpp +++ b/src/armnnDeserializer/Deserializer.hpp @@ -24,6 +24,7 @@ public: using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *; using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *; using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *; + using QLstmDescriptorPtr = const armnnSerializer::QLstmDescriptor *; using QunatizedLstmInputParamsPtr = const armnnSerializer::QuantizedLstmInputParams *; using TensorRawPtrVector = std::vector<TensorRawPtr>; using LayerRawPtr = const armnnSerializer::LayerBase *; @@ -62,6 +63,7 @@ public: static armnn::LstmDescriptor GetLstmDescriptor(LstmDescriptorPtr lstmDescriptor); static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor, LstmInputParamsPtr lstmInputParams); + static armnn::QLstmDescriptor GetQLstmDescriptor(QLstmDescriptorPtr qLstmDescriptorPtr); static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo, const std::vector<uint32_t> & targetDimsIn); @@ -113,6 +115,7 @@ private: void ParsePermute(GraphPtr graph, unsigned int layerIndex); void ParsePooling2d(GraphPtr graph, unsigned int layerIndex); void ParsePrelu(GraphPtr graph, unsigned int layerIndex); + void ParseQLstm(GraphPtr graph, unsigned int layerIndex); void ParseQuantize(GraphPtr graph, unsigned int layerIndex); void ParseReshape(GraphPtr graph, unsigned int layerIndex); void ParseResize(GraphPtr graph, unsigned int layerIndex); |