aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer/Deserializer.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnDeserializer/Deserializer.hpp')
-rw-r--r--src/armnnDeserializer/Deserializer.hpp63
1 files changed, 31 insertions, 32 deletions
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index 2a8639f7fc..e232feed9b 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -5,50 +5,53 @@
#pragma once
-#include "armnn/INetwork.hpp"
-#include "armnnDeserializer/IDeserializer.hpp"
+#include <armnn/INetwork.hpp>
+#include <armnnDeserializer/IDeserializer.hpp>
#include <ArmnnSchema_generated.h>
#include <unordered_map>
namespace armnnDeserializer
{
-class Deserializer : public IDeserializer
-{
-public:
- // Shorthands for deserializer types
- using ConstTensorRawPtr = const armnnSerializer::ConstTensor *;
- using GraphPtr = const armnnSerializer::SerializedGraph *;
- using TensorRawPtr = const armnnSerializer::TensorInfo *;
- using PoolingDescriptor = const armnnSerializer::Pooling2dDescriptor *;
- using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *;
- using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *;
- using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *;
- using QLstmDescriptorPtr = const armnnSerializer::QLstmDescriptor *;
- using QunatizedLstmInputParamsPtr = const armnnSerializer::QuantizedLstmInputParams *;
- using TensorRawPtrVector = std::vector<TensorRawPtr>;
- using LayerRawPtr = const armnnSerializer::LayerBase *;
- using LayerBaseRawPtr = const armnnSerializer::LayerBase *;
- using LayerBaseRawPtrVector = std::vector<LayerBaseRawPtr>;
+// Shorthands for deserializer types
+using ConstTensorRawPtr = const armnnSerializer::ConstTensor *;
+using GraphPtr = const armnnSerializer::SerializedGraph *;
+using TensorRawPtr = const armnnSerializer::TensorInfo *;
+using PoolingDescriptor = const armnnSerializer::Pooling2dDescriptor *;
+using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *;
+using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *;
+using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *;
+using QLstmDescriptorPtr = const armnnSerializer::QLstmDescriptor *;
+using QunatizedLstmInputParamsPtr = const armnnSerializer::QuantizedLstmInputParams *;
+using TensorRawPtrVector = std::vector<TensorRawPtr>;
+using LayerRawPtr = const armnnSerializer::LayerBase *;
+using LayerBaseRawPtr = const armnnSerializer::LayerBase *;
+using LayerBaseRawPtrVector = std::vector<LayerBaseRawPtr>;
+
+class IDeserializer::DeserializerImpl
+{
public:
/// Create an input network from binary file contents
- armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) override;
+ armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent);
/// Create an input network from a binary input stream
- armnn::INetworkPtr CreateNetworkFromBinary(std::istream& binaryContent) override;
+ armnn::INetworkPtr CreateNetworkFromBinary(std::istream& binaryContent);
/// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
- BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId, const std::string& name) const override;
+ BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId, const std::string& name) const;
/// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
- BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId, const std::string& name) const override;
+ BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId, const std::string& name) const;
- Deserializer();
- ~Deserializer() {}
+ DeserializerImpl();
+ ~DeserializerImpl() = default;
+
+ // No copying allowed until it is wanted and properly implemented
+ DeserializerImpl(const DeserializerImpl&) = delete;
+ DeserializerImpl& operator=(const DeserializerImpl&) = delete;
-public:
// testable helpers
static GraphPtr LoadGraphFromBinary(const uint8_t* binaryContent, size_t len);
static TensorRawPtrVector GetInputs(const GraphPtr& graph, unsigned int layerIndex);
@@ -68,15 +71,11 @@ public:
const std::vector<uint32_t> & targetDimsIn);
private:
- // No copying allowed until it is wanted and properly implemented
- Deserializer(const Deserializer&) = delete;
- Deserializer& operator=(const Deserializer&) = delete;
-
/// Create the network from an already loaded flatbuffers graph
armnn::INetworkPtr CreateNetworkFromGraph(GraphPtr graph);
// signature for the parser functions
- using LayerParsingFunction = void(Deserializer::*)(GraphPtr graph, unsigned int layerIndex);
+ using LayerParsingFunction = void(DeserializerImpl::*)(GraphPtr graph, unsigned int layerIndex);
void ParseUnsupportedLayer(GraphPtr graph, unsigned int layerIndex);
void ParseAbs(GraphPtr graph, unsigned int layerIndex);
@@ -188,4 +187,4 @@ private:
std::unordered_map<unsigned int, Connections> m_GraphConnections;
};
-} // namespace armnnDeserializer
+} // namespace armnnDeserializer \ No newline at end of file