//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
%module pyarmnn_tfparser
%{
#define SWIG_FILE_WITH_INIT
#include "armnnTfParser/ITfParser.hpp"
#include "armnn/INetwork.hpp"
%}

// Typemap definitions and other common stuff.
%include "standard_header.i"

// Instantiate the standard-library templates that appear in the signatures below so SWIG can
// convert them to/from Python objects. The template arguments must match those signatures
// exactly; otherwise SWIG will not apply the generated wrappers/typemaps to them.
namespace std {
   %template(BindingPointInfo)     pair<int, armnn::TensorInfo>;
   %template(MapStringTensorShape) map<std::string, armnn::TensorShape>;
   %template(StringVector)         vector<string>;
}

namespace armnnTfParser
{
%feature("docstring",
"
Interface for creating a parser object using TensorFlow (https://www.tensorflow.org/) frozen pb files.

Parsers are used to automatically construct Arm NN graphs from model files.

") ITfParser;
// No default constructor is generated; construction goes through the %extend block below.
%nodefaultctor ITfParser;
class ITfParser
{
public:
    %feature("docstring",
    "
    Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.

    Args:
        name (str): Name of the input.

    Returns:
        tuple: (`int`, `TensorInfo`).
    ") GetNetworkInputBindingInfo;
    std::pair<int, armnn::TensorInfo> GetNetworkInputBindingInfo(const std::string& name);

    %feature("docstring",
    "
    Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.

    Args:
        name (str): Name of the output.

    Returns:
        tuple: (`int`, `TensorInfo`).
    ") GetNetworkOutputBindingInfo;
    std::pair<int, armnn::TensorInfo> GetNetworkOutputBindingInfo(const std::string& name);
};

%extend ITfParser {
    // This is not a substitution of the default constructor of the Armnn class. It tells SWIG to create a custom
    // __init__ method for the ITfParser Python object that uses the static factory method to do the job.
    ITfParser() {
        return armnnTfParser::ITfParser::CreateRaw();
    }

    // The following does not replace the real destructor of the Armnn class.
    // It creates a function that is called when the SWIG object goes out of scope to clean up resources,
    // so the user does not need to call ITfParser::Destroy themselves.
    // `$self` is a pointer to the wrapped ArmNN ITfParser object.
    ~ITfParser() {
        armnnTfParser::ITfParser::Destroy($self);
    }

    %feature("docstring",
        "
        Create the network from a pb Protocol buffer file.

        Args:
            graphFile (str): Path to the tf model to be parsed.
            inputShapes (dict): A dict containing the input name as a key & TensorShape as a value.
            requestedOutputs (list of str): A list of the output tensor names.

        Returns:
            INetwork: Parsed network.

        Raises:
            RuntimeError: If model file was not found.
        ") CreateNetworkFromBinaryFile;

    // The parser returns an owning smart pointer; release() hands the raw pointer to SWIG, and
    // %newobject marks the returned object as owned by the Python wrapper.
    %newobject CreateNetworkFromBinaryFile;
    armnn::INetwork* CreateNetworkFromBinaryFile(const char* graphFile,
                                                 const std::map<std::string, armnn::TensorShape>& inputShapes,
                                                 const std::vector<std::string>& requestedOutputs) {
        return $self->CreateNetworkFromBinaryFile(graphFile, inputShapes, requestedOutputs).release();
    }
}

}

// Clear exception typemap.
%exception;