aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer/Deserializer.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnDeserializer/Deserializer.hpp')
-rw-r--r--src/armnnDeserializer/Deserializer.hpp6
1 files changed, 6 insertions, 0 deletions
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index effc7ae144..6454643f98 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -22,6 +22,8 @@ public:
using TensorRawPtr = const armnnSerializer::TensorInfo *;
using PoolingDescriptor = const armnnSerializer::Pooling2dDescriptor *;
using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *;
+ using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *;
+ using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *;
using TensorRawPtrVector = std::vector<TensorRawPtr>;
using LayerRawPtr = const armnnSerializer::LayerBase *;
using LayerBaseRawPtr = const armnnSerializer::LayerBase *;
@@ -58,6 +60,9 @@ public:
unsigned int layerIndex);
static armnn::NormalizationDescriptor GetNormalizationDescriptor(
NormalizationDescriptorPtr normalizationDescriptor, unsigned int layerIndex);
+ static armnn::LstmDescriptor GetLstmDescriptor(LstmDescriptorPtr lstmDescriptor);
+ static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor,
+ LstmInputParamsPtr lstmInputParams);
static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
const std::vector<uint32_t> & targetDimsIn);
@@ -94,6 +99,7 @@ private:
void ParseMerger(GraphPtr graph, unsigned int layerIndex);
void ParseMultiplication(GraphPtr graph, unsigned int layerIndex);
void ParseNormalization(GraphPtr graph, unsigned int layerIndex);
+ void ParseLstm(GraphPtr graph, unsigned int layerIndex);
void ParsePad(GraphPtr graph, unsigned int layerIndex);
void ParsePermute(GraphPtr graph, unsigned int layerIndex);
void ParsePooling2d(GraphPtr graph, unsigned int layerIndex);