diff options
Diffstat (limited to 'include/armnnTfParser/ITfParser.hpp')
-rw-r--r-- | include/armnnTfParser/ITfParser.hpp | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/include/armnnTfParser/ITfParser.hpp b/include/armnnTfParser/ITfParser.hpp new file mode 100644 index 0000000000..a6f56c8a19 --- /dev/null +++ b/include/armnnTfParser/ITfParser.hpp @@ -0,0 +1,60 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "armnn/Types.hpp" +#include "armnn/Tensor.hpp" +#include "armnn/INetwork.hpp" + +#include <map> +#include <memory> +#include <unordered_map> +#include <vector> + +namespace armnnTfParser +{ + +using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>; + +class ITfParser; +using ITfParserPtr = std::unique_ptr<ITfParser, void(*)(ITfParser* parser)>; + +/// parses a directed acyclic graph from a tensorflow protobuf file +class ITfParser +{ +public: + static ITfParser* CreateRaw(); + static ITfParserPtr Create(); + static void Destroy(ITfParser* parser); + + /// Create the network from a protobuf text file on disk + virtual armnn::INetworkPtr CreateNetworkFromTextFile( + const char* graphFile, + const std::map<std::string, armnn::TensorShape>& inputShapes, + const std::vector<std::string>& requestedOutputs) = 0; + + /// Create the network from a protobuf binary file on disk + virtual armnn::INetworkPtr CreateNetworkFromBinaryFile( + const char* graphFile, + const std::map<std::string, armnn::TensorShape>& inputShapes, + const std::vector<std::string>& requestedOutputs) = 0; + + /// Create the network directly from protobuf text in a string. Useful for debugging/testing + virtual armnn::INetworkPtr CreateNetworkFromString( + const char* protoText, + const std::map<std::string, armnn::TensorShape>& inputShapes, + const std::vector<std::string>& requestedOutputs) = 0; + + /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name + virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const = 0; + + /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name + virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const = 0; + +protected: + virtual ~ITfParser() {}; +}; + +} |