aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer/Deserializer.hpp
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-07-23 14:47:49 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-07-28 12:03:02 +0100
commita0162e17c56538ee6d72ecce4c3e0836cbb34c56 (patch)
treec47230c4024d7e79cacb39dafe179cdcf4571ade /src/armnnDeserializer/Deserializer.hpp
parent996f0f59e5b8a9ac73503814f7aadff4ef74cd35 (diff)
downloadarmnn-a0162e17c56538ee6d72ecce4c3e0836cbb34c56.tar.gz
MLCE-530 Add Serializer and Deserializer for UnidirectionalSequenceLstm
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: Ic1c56a57941ebede19ab8b9032e7f9df1221be7a
Diffstat (limited to 'src/armnnDeserializer/Deserializer.hpp')
-rw-r--r--src/armnnDeserializer/Deserializer.hpp4
1 files changed, 4 insertions, 0 deletions
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index 0b05e16849..b1362c44b6 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -28,6 +28,7 @@ using TensorRawPtrVector = std::vector<TensorRawPtr>;
using LayerRawPtr = const armnnSerializer::LayerBase *;
using LayerBaseRawPtr = const armnnSerializer::LayerBase *;
using LayerBaseRawPtrVector = std::vector<LayerBaseRawPtr>;
+using UnidirectionalSequenceLstmDescriptorPtr = const armnnSerializer::UnidirectionalSequenceLstmDescriptor *;
class IDeserializer::DeserializerImpl
{
@@ -67,6 +68,8 @@ public:
static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor,
LstmInputParamsPtr lstmInputParams);
static armnn::QLstmDescriptor GetQLstmDescriptor(QLstmDescriptorPtr qLstmDescriptorPtr);
+ static armnn::UnidirectionalSequenceLstmDescriptor GetUnidirectionalSequenceLstmDescriptor(
+ UnidirectionalSequenceLstmDescriptorPtr descriptor);
static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
const std::vector<uint32_t> & targetDimsIn);
@@ -138,6 +141,7 @@ private:
void ParseSwitch(GraphPtr graph, unsigned int layerIndex);
void ParseTranspose(GraphPtr graph, unsigned int layerIndex);
void ParseTransposeConvolution2d(GraphPtr graph, unsigned int layerIndex);
+ void ParseUnidirectionalSequenceLstm(GraphPtr graph, unsigned int layerIndex);
void RegisterInputSlots(GraphPtr graph, uint32_t layerIndex,
armnn::IConnectableLayer* layer);