diff options
Diffstat (limited to 'python/pyarmnn/src/pyarmnn/swig/armnn_tfparser.i')
-rw-r--r-- | python/pyarmnn/src/pyarmnn/swig/armnn_tfparser.i | 102 |
1 files changed, 102 insertions, 0 deletions
diff --git a/python/pyarmnn/src/pyarmnn/swig/armnn_tfparser.i b/python/pyarmnn/src/pyarmnn/swig/armnn_tfparser.i new file mode 100644 index 0000000000..3438492d26 --- /dev/null +++ b/python/pyarmnn/src/pyarmnn/swig/armnn_tfparser.i @@ -0,0 +1,102 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +%module pyarmnn_tfparser +%{ +#define SWIG_FILE_WITH_INIT +#include "armnnTfParser/ITfParser.hpp" +#include "armnn/INetwork.hpp" +%} + +//typemap definitions and other common stuff +%include "standard_header.i" + +namespace std { + %template(BindingPointInfo) pair<int, armnn::TensorInfo>; + %template(MapStringTensorShape) map<std::string, armnn::TensorShape>; + %template(StringVector) vector<string>; +} + +namespace armnnTfParser +{ +%feature("docstring", +" +Interface for creating a parser object using TensorFlow (https://www.tensorflow.org/) frozen pb files. + +Parsers are used to automatically construct Arm NN graphs from model files. + +") ITfParser; +%nodefaultctor ITfParser; +class ITfParser +{ +public: + %feature("docstring", + " + Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name. + + Args: + name (str): Name of the input. + + Returns: + tuple: (`int`, `TensorInfo`). + ") GetNetworkInputBindingInfo; + std::pair<int, armnn::TensorInfo> GetNetworkInputBindingInfo(const std::string& name); + + %feature("docstring", + " + Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name. + + Args: + name (str): Name of the output. + + Returns: + tuple: (`int`, `TensorInfo`). + ") GetNetworkOutputBindingInfo; + std::pair<int, armnn::TensorInfo> GetNetworkOutputBindingInfo(const std::string& name); +}; + +%extend ITfParser { + // This is not a substitution of the default constructor of the Armnn class. It tells swig to create custom __init__ + // method for ITfParser python object that will use static factory method to do the job. + + ITfParser() { + return armnnTfParser::ITfParser::CreateRaw(); + } + + // The following does not replace a real destructor of the Armnn class. + // It creates a functions that will be called when swig object goes out of the scope to clean resources. + // so the user doesn't need to call ITfParser::Destroy himself. + // $self` is a pointer to extracted ArmNN ITfParser object. + + ~ITfParser() { + armnnTfParser::ITfParser::Destroy($self); + } + + %feature("docstring", + " + Create the network from a pb Protocol buffer file. + + Args: + graphFile (str): Path to the tf model to be parsed. + inputShapes (dict): A dict containing the input name as a key & TensorShape as a value. + requestedOutputs (list of str): A list of the output tensor names. + + Returns: + INetwork: Parsed network. + + Raises: + RuntimeError: If model file was not found. + ") CreateNetworkFromBinaryFile; + %newobject CreateNetworkFromBinaryFile; + armnn::INetwork* CreateNetworkFromBinaryFile(const char* graphFile, + const std::map<std::string, armnn::TensorShape>& inputShapes, + const std::vector<std::string>& requestedOutputs) { + return $self->CreateNetworkFromBinaryFile(graphFile, inputShapes, requestedOutputs).release(); + } + +} + +} +// Clear exception typemap. +%exception; |