diff options
Diffstat (limited to 'src/armnnDeserializer/Deserializer.hpp')
-rw-r--r-- | src/armnnDeserializer/Deserializer.hpp | 63 |
1 files changed, 31 insertions, 32 deletions
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp index 2a8639f7fc..e232feed9b 100644 --- a/src/armnnDeserializer/Deserializer.hpp +++ b/src/armnnDeserializer/Deserializer.hpp @@ -5,50 +5,53 @@ #pragma once -#include "armnn/INetwork.hpp" -#include "armnnDeserializer/IDeserializer.hpp" +#include <armnn/INetwork.hpp> +#include <armnnDeserializer/IDeserializer.hpp> #include <ArmnnSchema_generated.h> #include <unordered_map> namespace armnnDeserializer { -class Deserializer : public IDeserializer -{ -public: - // Shorthands for deserializer types - using ConstTensorRawPtr = const armnnSerializer::ConstTensor *; - using GraphPtr = const armnnSerializer::SerializedGraph *; - using TensorRawPtr = const armnnSerializer::TensorInfo *; - using PoolingDescriptor = const armnnSerializer::Pooling2dDescriptor *; - using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *; - using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *; - using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *; - using QLstmDescriptorPtr = const armnnSerializer::QLstmDescriptor *; - using QunatizedLstmInputParamsPtr = const armnnSerializer::QuantizedLstmInputParams *; - using TensorRawPtrVector = std::vector<TensorRawPtr>; - using LayerRawPtr = const armnnSerializer::LayerBase *; - using LayerBaseRawPtr = const armnnSerializer::LayerBase *; - using LayerBaseRawPtrVector = std::vector<LayerBaseRawPtr>; +// Shorthands for deserializer types +using ConstTensorRawPtr = const armnnSerializer::ConstTensor *; +using GraphPtr = const armnnSerializer::SerializedGraph *; +using TensorRawPtr = const armnnSerializer::TensorInfo *; +using PoolingDescriptor = const armnnSerializer::Pooling2dDescriptor *; +using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *; +using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *; +using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *; +using QLstmDescriptorPtr = const armnnSerializer::QLstmDescriptor *; +using QunatizedLstmInputParamsPtr = const armnnSerializer::QuantizedLstmInputParams *; +using TensorRawPtrVector = std::vector<TensorRawPtr>; +using LayerRawPtr = const armnnSerializer::LayerBase *; +using LayerBaseRawPtr = const armnnSerializer::LayerBase *; +using LayerBaseRawPtrVector = std::vector<LayerBaseRawPtr>; + +class IDeserializer::DeserializerImpl +{ public: /// Create an input network from binary file contents - armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) override; + armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent); /// Create an input network from a binary input stream - armnn::INetworkPtr CreateNetworkFromBinary(std::istream& binaryContent) override; + armnn::INetworkPtr CreateNetworkFromBinary(std::istream& binaryContent); /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name - BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId, const std::string& name) const override; + BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId, const std::string& name) const; /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name - BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId, const std::string& name) const override; + BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId, const std::string& name) const; - Deserializer(); - ~Deserializer() {} + DeserializerImpl(); + ~DeserializerImpl() = default; + + // No copying allowed until it is wanted and properly implemented + DeserializerImpl(const DeserializerImpl&) = delete; + DeserializerImpl& operator=(const DeserializerImpl&) = delete; -public: // testable helpers static GraphPtr LoadGraphFromBinary(const uint8_t* binaryContent, size_t len); static TensorRawPtrVector GetInputs(const GraphPtr& graph, unsigned int layerIndex); @@ -68,15 +71,11 @@ public: const std::vector<uint32_t> & targetDimsIn); private: - // No copying allowed until it is wanted and properly implemented - Deserializer(const Deserializer&) = delete; - Deserializer& operator=(const Deserializer&) = delete; - /// Create the network from an already loaded flatbuffers graph armnn::INetworkPtr CreateNetworkFromGraph(GraphPtr graph); // signature for the parser functions - using LayerParsingFunction = void(Deserializer::*)(GraphPtr graph, unsigned int layerIndex); + using LayerParsingFunction = void(DeserializerImpl::*)(GraphPtr graph, unsigned int layerIndex); void ParseUnsupportedLayer(GraphPtr graph, unsigned int layerIndex); void ParseAbs(GraphPtr graph, unsigned int layerIndex); @@ -188,4 +187,4 @@ private: std::unordered_map<unsigned int, Connections> m_GraphConnections; }; -} // namespace armnnDeserializer +} // namespace armnnDeserializer
\ No newline at end of file |