aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer/Deserializer.hpp
diff options
context:
space:
mode:
authorJim Flynn <jim.flynn@arm.com>2019-03-19 17:22:29 +0000
committerJim Flynn <jim.flynn@arm.com>2019-03-21 16:09:19 +0000
commit11af375a5a6bf88b4f3b933a86d53000b0d91ed0 (patch)
treef4f4db5192b275be44d96d96c7f3c8c10f15b3f1 /src/armnnDeserializer/Deserializer.hpp
parentdb059fd50f9afb398b8b12cd4592323fc8f60d7f (diff)
downloadarmnn-11af375a5a6bf88b4f3b933a86d53000b0d91ed0.tar.gz
IVGCVSW-2694: serialize/deserialize LSTM
* added serialize/deserialize methods for LSTM and tests Change-Id: Ic59557f03001c496008c4bef92c2e0406e1fbc6c Signed-off-by: Nina Drozd <nina.drozd@arm.com> Signed-off-by: Jim Flynn <jim.flynn@arm.com>
Diffstat (limited to 'src/armnnDeserializer/Deserializer.hpp')
-rw-r--r--src/armnnDeserializer/Deserializer.hpp6
1 files changed, 6 insertions, 0 deletions
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index effc7ae144..6454643f98 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -22,6 +22,8 @@ public:
using TensorRawPtr = const armnnSerializer::TensorInfo *;
using PoolingDescriptor = const armnnSerializer::Pooling2dDescriptor *;
using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *;
+ using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *;
+ using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *;
using TensorRawPtrVector = std::vector<TensorRawPtr>;
using LayerRawPtr = const armnnSerializer::LayerBase *;
using LayerBaseRawPtr = const armnnSerializer::LayerBase *;
@@ -58,6 +60,9 @@ public:
unsigned int layerIndex);
static armnn::NormalizationDescriptor GetNormalizationDescriptor(
NormalizationDescriptorPtr normalizationDescriptor, unsigned int layerIndex);
+ static armnn::LstmDescriptor GetLstmDescriptor(LstmDescriptorPtr lstmDescriptor);
+ static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor,
+ LstmInputParamsPtr lstmInputParams);
static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
const std::vector<uint32_t> & targetDimsIn);
@@ -94,6 +99,7 @@ private:
void ParseMerger(GraphPtr graph, unsigned int layerIndex);
void ParseMultiplication(GraphPtr graph, unsigned int layerIndex);
void ParseNormalization(GraphPtr graph, unsigned int layerIndex);
+ void ParseLstm(GraphPtr graph, unsigned int layerIndex);
void ParsePad(GraphPtr graph, unsigned int layerIndex);
void ParsePermute(GraphPtr graph, unsigned int layerIndex);
void ParsePooling2d(GraphPtr graph, unsigned int layerIndex);