diff options
Diffstat (limited to 'src/armnnTfParser/TfParser.hpp')
-rw-r--r-- | src/armnnTfParser/TfParser.hpp | 48 |
1 files changed, 30 insertions, 18 deletions
diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp index c5b4bce8ac..75cd3a5bd0 100644 --- a/src/armnnTfParser/TfParser.hpp +++ b/src/armnnTfParser/TfParser.hpp @@ -36,9 +36,9 @@ using ParsedTfOperationPtr = std::unique_ptr<ParsedTfOperation>; /// /// WithOutputTensorIndex wraps a value and an index. The purpose of -/// this template is to signify that in Tensorflow the input name of -/// a layer has the convention of 'inputTensorName:#index' where the -/// #index can be omitted and it implicitly means the 0. output of +/// this template is to signify that, in Tensorflow, the input name of +/// a layer has the convention of 'inputTensorName:#index', where the +/// #index can be omitted and it implicitly means the 0 output of /// the referenced layer. By supporting this notation we can handle /// layers with multiple outputs, such as Split. /// @@ -64,28 +64,28 @@ using OutputId = WithOutputTensorIndex<std::string>; class TfParser : public ITfParser { public: - /// Create the network from a protobuf text file on disk + /// Creates the network from a protobuf text file on the disk. virtual armnn::INetworkPtr CreateNetworkFromTextFile( const char* graphFile, const std::map<std::string, armnn::TensorShape>& inputShapes, const std::vector<std::string>& requestedOutputs) override; - /// Create the network from a protobuf binary file on disk + /// Creates the network from a protobuf binary file on the disk. virtual armnn::INetworkPtr CreateNetworkFromBinaryFile( const char* graphFile, const std::map<std::string, armnn::TensorShape>& inputShapes, const std::vector<std::string>& requestedOutputs) override; - /// Create the network directly from protobuf text in a string. Useful for debugging/testing + /// Creates the network directly from protobuf text in a string. Useful for debugging/testing. virtual armnn::INetworkPtr CreateNetworkFromString( const char* protoText, const std::map<std::string, armnn::TensorShape>& inputShapes, const std::vector<std::string>& requestedOutputs) override; - /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name + /// Retrieves binding info (layer id and tensor info) for the network input identified by the given layer name. virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override; - /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name + /// Retrieves binding info (layer id and tensor info) for the network output identified by the given layer name. virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override; public: @@ -95,19 +95,20 @@ private: template <typename T> friend class ParsedConstTfOperation; friend class ParsedMatMulTfOperation; + friend class ParsedMulTfOperation; - /// Parses a GraphDef loaded into memory from one of the other CreateNetwork* + /// Parses a GraphDef loaded into memory from one of the other CreateNetwork*. armnn::INetworkPtr CreateNetworkFromGraphDef(const tensorflow::GraphDef& graphDef, const std::map<std::string, armnn::TensorShape>& inputShapes, const std::vector<std::string>& requestedOutputs); - /// sets up variables and then performs BFS to parse all nodes + /// Sets up variables and then performs BFS to parse all nodes. void LoadGraphDef(const tensorflow::GraphDef& graphDef); - /// parses a given node, assuming nodes before it in graph have been done + /// Parses a given node, assuming nodes before it in the graph have been done. void LoadNodeDef(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); - /// Handling identity layers as the input for Conv2D layer + /// Handling identity layers as the input for Conv2D layer. const tensorflow::NodeDef* ResolveIdentityNode(const tensorflow::NodeDef* nodeDef); /// Finds the nodes connected as inputs of the given node in the graph. std::vector<OutputOfConstNodeDef> GetTfInputNodes(const tensorflow::NodeDef& nodeDef) const; @@ -120,7 +121,7 @@ private: ParsedTfOperationPtr ParseConst(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); - /// Checks if there is a pre-parsed const tensor is available with the given name and Type + /// Checks if there is a pre-parsed const tensor available with the given name and Type. template<typename Type> bool HasParsedConstTensor(const std::string & nodeName) const; @@ -149,11 +150,22 @@ private: ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParsePooling2d(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef, armnn::PoolingAlgorithm pooltype); + ParsedTfOperationPtr ParseMaximum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr AddActivationLayer(const tensorflow::NodeDef& nodeDef, armnn::ActivationDescriptor& desc); ParsedTfOperationPtr AddAdditionLayer(const tensorflow::NodeDef& nodeDef, bool isBiasAdd = false); + +private: + armnn::IConnectableLayer* AddMultiplicationLayer(const tensorflow::NodeDef& nodeDef); + armnn::IConnectableLayer* AddFullyConnectedLayer(const tensorflow::NodeDef& matMulNodeDef, const tensorflow::NodeDef* addNodeDef, const char* armnnLayerName); + bool IsSupportedLeakyReluPattern(const tensorflow::NodeDef& mulNodeDef, + size_t alphaLayerIndex, + const OutputOfParsedTfOperation& otherOp, + armnn::IOutputSlot** outputOfLeakyRelu, + armnn::ActivationDescriptor & desc); + static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(const std::string& layerName, const char* bindingPointDesc, const std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo); @@ -173,27 +185,27 @@ private: void Cleanup(); - /// The network we're building. Gets cleared after it is passed to the user + /// The network we're building. Gets cleared after it is passed to the user. armnn::INetworkPtr m_Network; using OperationParsingFunction = ParsedTfOperationPtr(TfParser::*)(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); - /// map of TensorFlow operation names to parsing member functions + /// Map of TensorFlow operation names to parsing member functions. static const std::map<std::string, OperationParsingFunction> ms_OperationNameToParsingFunctions; std::map<std::string, armnn::TensorShape> m_InputShapes; std::vector<std::string> m_RequestedOutputs; - /// map of nodes extracted from the GraphDef to speed up parsing + /// Map of nodes extracted from the GraphDef to speed up parsing. std::unordered_map<std::string, const tensorflow::NodeDef*> m_NodesByName; std::unordered_map<std::string, ParsedTfOperationPtr> m_ParsedTfOperations; - /// maps input layer names to their corresponding ids and tensor infos + /// Maps input layer names to their corresponding ids and tensor info. std::unordered_map<std::string, BindingPointInfo> m_NetworkInputsBindingInfo; - /// maps output layer names to their corresponding ids and tensor infos + /// Maps output layer names to their corresponding ids and tensor info. std::unordered_map<std::string, BindingPointInfo> m_NetworkOutputsBindingInfo; }; } |