diff options
author | telsoa01 <telmo.soares@arm.com> | 2018-08-31 09:22:23 +0100 |
---|---|---|
committer | telsoa01 <telmo.soares@arm.com> | 2018-08-31 09:22:23 +0100 |
commit | c577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch) | |
tree | bd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnnTfLiteParser/TfLiteParser.hpp | |
parent | 4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff) | |
download | armnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz |
Release 18.08
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.hpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.hpp | 156 |
1 files changed, 156 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp new file mode 100644 index 0000000000..91585af5d0 --- /dev/null +++ b/src/armnnTfLiteParser/TfLiteParser.hpp @@ -0,0 +1,156 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "armnn/INetwork.hpp" +#include "armnnTfLiteParser/ITfLiteParser.hpp" + +#include <schema_generated.h> +#include <functional> +#include <vector> + +namespace armnnTfLiteParser +{ + +class TfLiteParser : public ITfLiteParser +{ +public: + // Shorthands for TfLite types + using ModelPtr = std::unique_ptr<tflite::ModelT>; + using SubGraphPtr = std::unique_ptr<tflite::SubGraphT>; + using OperatorPtr = std::unique_ptr<tflite::OperatorT>; + using OperatorCodePtr = std::unique_ptr<tflite::OperatorCodeT>; + using TensorPtr = std::unique_ptr<tflite::TensorT>; + using TensorRawPtr = const tflite::TensorT *; + using TensorRawPtrVector = std::vector<TensorRawPtr>; + using TensorIdRawPtr = std::pair<size_t, TensorRawPtr>; + using TensorIdRawPtrVector = std::vector<TensorIdRawPtr>; + using BufferPtr = std::unique_ptr<tflite::BufferT>; + using BufferRawPtr = const tflite::BufferT *; + +public: + /// Create the network from a flatbuffers binary file on disk + virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) override; + + /// Create the network from a flatbuffers binary + virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent) override; + + + /// Retrieve binding info (layer id and tensor info) for the network input identified by + /// the given layer name and subgraph id + virtual BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId, + const std::string& name) const override; + + /// Retrieve binding info (layer id and tensor info) for the network output identified by + /// the given layer name and subgraph id + virtual BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId, + const std::string& name) const override; + + /// Return the number of subgraphs in the parsed model + virtual size_t GetSubgraphCount() const override; + + /// Return the input tensor names for a given subgraph + virtual std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const override; + + /// Return the output tensor names for a given subgraph + virtual std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const override; + + TfLiteParser(); + virtual ~TfLiteParser() {} + +public: + // testable helpers + static ModelPtr LoadModelFromFile(const char * fileName); + static ModelPtr LoadModelFromBinary(const uint8_t * binaryContent, size_t len); + static TensorRawPtrVector GetInputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex); + static TensorRawPtrVector GetOutputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex); + static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr & model, size_t subgraphIndex); + static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr & model, size_t subgraphIndex); + static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); + static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); + + static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex); + static armnn::TensorInfo OutputShapeOfSqueeze(const std::vector<uint32_t> & squeezeDims, + const armnn::TensorInfo & inputTensorInfo); + + +private: + // No copying allowed until it is wanted and properly implemented + TfLiteParser(const TfLiteParser &) = delete; + TfLiteParser & operator=(const TfLiteParser &) = delete; + + /// Create the network from an already loaded flatbuffers model + armnn::INetworkPtr CreateNetworkFromModel(); + + // signature for the parser functions + using OperatorParsingFunction = void(TfLiteParser::*)(size_t subgraphIndex, size_t operatorIndex); + + void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex); + void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex); + void ParseConv2D(size_t subgraphIndex, size_t operatorIndex); + void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex); + void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex); + void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex); + + void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot); + void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot); + void RegisterInputSlots(size_t subgraphIndex, + size_t operatorIndex, + armnn::IConnectableLayer* layer, + const std::vector<unsigned int>& tensorIndexes); + void RegisterOutputSlots(size_t subgraphIndex, + size_t operatorIndex, + armnn::IConnectableLayer* layer, + const std::vector<unsigned int>& tensorIndexes); + + void SetupInputLayers(size_t subgraphIndex); + void SetupOutputLayers(size_t subgraphIndex); + + void ResetParser(); + + /// Attach an activation layer to the one passed as a parameter + armnn::IConnectableLayer* AddActivationLayer(armnn::IConnectableLayer* layer, + unsigned int outputSlot, + tflite::ActivationFunctionType activationType); + + // SupportedDataStorage's purpose is to hold data till we pass over to the network. + // We don't care about the content, and we want a single datatype to simplify the code. + struct SupportedDataStorage + { + std::unique_ptr<float[]> m_FloatData; + std::unique_ptr<uint8_t[]> m_Uint8Data; + std::unique_ptr<int32_t[]> m_Int32Data; + + SupportedDataStorage(std::unique_ptr<float[]> && data); + SupportedDataStorage(std::unique_ptr<uint8_t[]> && data); + SupportedDataStorage(std::unique_ptr<int32_t[]> && data); + }; + + std::pair<armnn::ConstTensor, SupportedDataStorage> CreateConstTensor(TensorRawPtr tensorPtr, + armnn::TensorInfo & tensorInfo, + bool convertFromTfToArmnnFormat); + + /// The network we're building. Gets cleared after it is passed to the user + armnn::INetworkPtr m_Network; + std::vector<OperatorParsingFunction> m_ParserFunctions; + ModelPtr m_Model; + + /// A mapping of an output slot to each of the input slots it should be connected to + /// The outputSlot is from the layer that creates this tensor as one of its ouputs + /// The inputSlots are from the layers that use this tensor as one of their inputs + struct TensorSlots + { + armnn::IOutputSlot* outputSlot; + std::vector<armnn::IInputSlot*> inputSlots; + + TensorSlots() : outputSlot(nullptr) { } + }; + typedef std::vector<TensorSlots> TensorConnections; + /// Connections for tensors in each subgraph + /// The first index is the subgraph ID, the second index is the tensor ID + std::vector<TensorConnections> m_SubgraphConnections; +}; + +} |