aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer/Deserializer.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnDeserializer/Deserializer.hpp')
-rw-r--r--src/armnnDeserializer/Deserializer.hpp8
1 files changed, 6 insertions, 2 deletions
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index d2291c07a7..8de492ed5f 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -18,7 +18,8 @@ namespace armnnDeserializer
using ConstTensorRawPtr = const armnnSerializer::ConstTensor *;
using GraphPtr = const armnnSerializer::SerializedGraph *;
using TensorRawPtr = const armnnSerializer::TensorInfo *;
-using PoolingDescriptor = const armnnSerializer::Pooling2dDescriptor *;
+using Pooling2dDescriptor = const armnnSerializer::Pooling2dDescriptor *;
+using Pooling3dDescriptor = const armnnSerializer::Pooling3dDescriptor *;
using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *;
using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *;
using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *;
@@ -60,7 +61,9 @@ public:
static LayerBaseRawPtr GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex);
static int32_t GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex);
static std::string GetLayerName(const GraphPtr& graph, unsigned int index);
- static armnn::Pooling2dDescriptor GetPoolingDescriptor(PoolingDescriptor pooling2dDescriptor,
+ static armnn::Pooling2dDescriptor GetPooling2dDescriptor(Pooling2dDescriptor pooling2dDescriptor,
+ unsigned int layerIndex);
+ static armnn::Pooling3dDescriptor GetPooling3dDescriptor(Pooling3dDescriptor pooling3dDescriptor,
unsigned int layerIndex);
static armnn::NormalizationDescriptor GetNormalizationDescriptor(
NormalizationDescriptorPtr normalizationDescriptor, unsigned int layerIndex);
@@ -121,6 +124,7 @@ private:
void ParsePad(GraphPtr graph, unsigned int layerIndex);
void ParsePermute(GraphPtr graph, unsigned int layerIndex);
void ParsePooling2d(GraphPtr graph, unsigned int layerIndex);
+ void ParsePooling3d(GraphPtr graph, unsigned int layerIndex);
void ParsePrelu(GraphPtr graph, unsigned int layerIndex);
void ParseQLstm(GraphPtr graph, unsigned int layerIndex);
void ParseQuantize(GraphPtr graph, unsigned int layerIndex);