diff options
author | surmeh01 <surabhi.mehta@arm.com> | 2018-03-29 16:29:27 +0100 |
---|---|---|
committer | surmeh01 <surabhi.mehta@arm.com> | 2018-03-29 16:29:27 +0100 |
commit | bceff2fb3fc68bb0aa88b886900c34b77340c826 (patch) | |
tree | d867d3e090d58d3012dfbbac456e9ea8c7f789bc /src/armnnTfParser/TfParser.hpp | |
parent | 4fcda0101ec3d110c1d6d7bee5c83416b645528a (diff) | |
download | armnn-bceff2fb3fc68bb0aa88b886900c34b77340c826.tar.gz |
Release 18.03
Diffstat (limited to 'src/armnnTfParser/TfParser.hpp')
-rw-r--r-- | src/armnnTfParser/TfParser.hpp | 199 |
1 files changed, 199 insertions, 0 deletions
diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp new file mode 100644 index 0000000000..c5b4bce8ac --- /dev/null +++ b/src/armnnTfParser/TfParser.hpp @@ -0,0 +1,199 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "armnnTfParser/ITfParser.hpp" + +#include "armnn/Types.hpp" +#include "armnn/Tensor.hpp" +#include "armnn/INetwork.hpp" + +#include <map> +#include <memory> +#include <unordered_map> +#include <vector> + +namespace armnn +{ +class TensorInfo; +} + +namespace tensorflow +{ +class GraphDef; +class NodeDef; +} + +namespace armnnTfParser +{ + +using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>; + +class ParsedTfOperation; +using ParsedTfOperationPtr = std::unique_ptr<ParsedTfOperation>; + +/// +/// WithOutputTensorIndex wraps a value and an index. The purpose of +/// this template is to signify that in Tensorflow the input name of +/// a layer has the convention of 'inputTensorName:#index' where the +/// #index can be omitted and it implicitly means the 0. output of +/// the referenced layer. By supporting this notation we can handle +/// layers with multiple outputs, such as Split. +/// +template <typename T> +struct WithOutputTensorIndex +{ + T m_IndexedValue; + unsigned int m_Index; + + WithOutputTensorIndex(const T & value, unsigned int index) + : m_IndexedValue{value} + , m_Index{index} {} + + WithOutputTensorIndex(T && value, unsigned int index) + : m_IndexedValue{value} + , m_Index{index} {} +}; + +using OutputOfParsedTfOperation = WithOutputTensorIndex<ParsedTfOperation *>; +using OutputOfConstNodeDef = WithOutputTensorIndex<const tensorflow::NodeDef*>; +using OutputId = WithOutputTensorIndex<std::string>; + +class TfParser : public ITfParser +{ +public: + /// Create the network from a protobuf text file on disk + virtual armnn::INetworkPtr CreateNetworkFromTextFile( + const char* graphFile, + const std::map<std::string, armnn::TensorShape>& inputShapes, + const std::vector<std::string>& requestedOutputs) override; + + /// Create the network from a protobuf binary file on disk + virtual armnn::INetworkPtr CreateNetworkFromBinaryFile( + const char* graphFile, + const std::map<std::string, armnn::TensorShape>& inputShapes, + const std::vector<std::string>& requestedOutputs) override; + + /// Create the network directly from protobuf text in a string. Useful for debugging/testing + virtual armnn::INetworkPtr CreateNetworkFromString( + const char* protoText, + const std::map<std::string, armnn::TensorShape>& inputShapes, + const std::vector<std::string>& requestedOutputs) override; + + /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name + virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override; + + /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name + virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override; + +public: + TfParser(); + +private: + template <typename T> + friend class ParsedConstTfOperation; + friend class ParsedMatMulTfOperation; + + /// Parses a GraphDef loaded into memory from one of the other CreateNetwork* + armnn::INetworkPtr CreateNetworkFromGraphDef(const tensorflow::GraphDef& graphDef, + const std::map<std::string, armnn::TensorShape>& inputShapes, + const std::vector<std::string>& requestedOutputs); + + /// sets up variables and then performs BFS to parse all nodes + void LoadGraphDef(const tensorflow::GraphDef& graphDef); + + /// parses a given node, assuming nodes before it in graph have been done + void LoadNodeDef(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + + /// Handling identity layers as the input for Conv2D layer + const tensorflow::NodeDef* ResolveIdentityNode(const tensorflow::NodeDef* nodeDef); + /// Finds the nodes connected as inputs of the given node in the graph. + std::vector<OutputOfConstNodeDef> GetTfInputNodes(const tensorflow::NodeDef& nodeDef) const; + /// Finds the IParsedTfOperations for the nodes connected as inputs of the given node in the graph, + /// and throws an exception if the number of inputs does not match the expected one. + /// This will automatically resolve any identity nodes. The result vector contains the parsed operation + /// together with the output tensor index to make the connection unambiguous. + std::vector<OutputOfParsedTfOperation> GetInputParsedTfOperationsChecked(const tensorflow::NodeDef& nodeDef, + std::size_t expectedNumInputs); + + ParsedTfOperationPtr ParseConst(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + + /// Checks if there is a pre-parsed const tensor is available with the given name and Type + template<typename Type> + bool HasParsedConstTensor(const std::string & nodeName) const; + + ParsedTfOperationPtr ParseAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseDepthwiseConv2D(const tensorflow::NodeDef& nodeDef,const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseConcat(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseLrn(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseMatMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParsePlaceholder(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseRelu(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseRelu6(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseReshape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseResizeBilinear(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseShape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseSqueeze(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseSigmoid(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseSoftmax(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseSoftplus(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseMaxPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParsePooling2d(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef, + armnn::PoolingAlgorithm pooltype); + ParsedTfOperationPtr AddActivationLayer(const tensorflow::NodeDef& nodeDef, armnn::ActivationDescriptor& desc); + ParsedTfOperationPtr AddAdditionLayer(const tensorflow::NodeDef& nodeDef, bool isBiasAdd = false); + armnn::IConnectableLayer* AddFullyConnectedLayer(const tensorflow::NodeDef& matMulNodeDef, + const tensorflow::NodeDef* addNodeDef, const char* armnnLayerName); + + static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(const std::string& layerName, + const char* bindingPointDesc, + const std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo); + + void TrackInputBinding(armnn::IConnectableLayer* layer, + armnn::LayerBindingId id, + const armnn::TensorInfo& tensorInfo); + + void TrackOutputBinding(armnn::IConnectableLayer* layer, + armnn::LayerBindingId id, + const armnn::TensorInfo& tensorInfo); + + static void TrackBindingPoint(armnn::IConnectableLayer* layer, armnn::LayerBindingId id, + const armnn::TensorInfo& tensorInfo, + const char* bindingPointDesc, + std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo); + + void Cleanup(); + + /// The network we're building. Gets cleared after it is passed to the user + armnn::INetworkPtr m_Network; + + using OperationParsingFunction = ParsedTfOperationPtr(TfParser::*)(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef); + + /// map of TensorFlow operation names to parsing member functions + static const std::map<std::string, OperationParsingFunction> ms_OperationNameToParsingFunctions; + + std::map<std::string, armnn::TensorShape> m_InputShapes; + std::vector<std::string> m_RequestedOutputs; + + /// map of nodes extracted from the GraphDef to speed up parsing + std::unordered_map<std::string, const tensorflow::NodeDef*> m_NodesByName; + + std::unordered_map<std::string, ParsedTfOperationPtr> m_ParsedTfOperations; + + /// maps input layer names to their corresponding ids and tensor infos + std::unordered_map<std::string, BindingPointInfo> m_NetworkInputsBindingInfo; + + /// maps output layer names to their corresponding ids and tensor infos + std::unordered_map<std::string, BindingPointInfo> m_NetworkOutputsBindingInfo; +}; +} |