diff options
Diffstat (limited to 'include/armnnTfParser/ITfParser.hpp')
-rw-r--r-- | include/armnnTfParser/ITfParser.hpp | 37 |
1 files changed, 27 insertions, 10 deletions
diff --git a/include/armnnTfParser/ITfParser.hpp b/include/armnnTfParser/ITfParser.hpp index b0ffc0d379..91e4cb39bf 100644 --- a/include/armnnTfParser/ITfParser.hpp +++ b/include/armnnTfParser/ITfParser.hpp @@ -30,31 +30,48 @@ public: static void Destroy(ITfParser* parser); /// Create the network from a protobuf text file on the disk. - virtual armnn::INetworkPtr CreateNetworkFromTextFile( + armnn::INetworkPtr CreateNetworkFromTextFile( const char* graphFile, const std::map<std::string, armnn::TensorShape>& inputShapes, - const std::vector<std::string>& requestedOutputs) = 0; + const std::vector<std::string>& requestedOutputs); /// Create the network from a protobuf binary file on the disk. - virtual armnn::INetworkPtr CreateNetworkFromBinaryFile( + armnn::INetworkPtr CreateNetworkFromBinaryFile( const char* graphFile, const std::map<std::string, armnn::TensorShape>& inputShapes, - const std::vector<std::string>& requestedOutputs) = 0; + const std::vector<std::string>& requestedOutputs); /// Create the network directly from protobuf text in a string. Useful for debugging/testing. - virtual armnn::INetworkPtr CreateNetworkFromString( + armnn::INetworkPtr CreateNetworkFromString( const char* protoText, const std::map<std::string, armnn::TensorShape>& inputShapes, - const std::vector<std::string>& requestedOutputs) = 0; + const std::vector<std::string>& requestedOutputs); /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name. - virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const = 0; + BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const; /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name. - virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const = 0; + BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const; -protected: - virtual ~ITfParser() {}; +private: + template <typename T> + friend class ParsedConstTfOperation; + friend class ParsedMatMulTfOperation; + friend class ParsedMulTfOperation; + friend class ParsedTfOperation; + friend class SingleLayerParsedTfOperation; + friend class DeferredSingleLayerParsedTfOperation; + friend class ParsedIdentityTfOperation; + + template <template<typename> class OperatorType, typename T> + friend struct MakeTfOperation; + + + ITfParser(); + ~ITfParser(); + + struct TfParserImpl; + std::unique_ptr<TfParserImpl> pTfParserImpl; }; } |