//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
%module pyarmnn_tfliteparser
%{
#include "armnnTfLiteParser/ITfLiteParser.hpp"
#include "armnn/Types.hpp"
#include "armnn/INetwork.hpp"
%}

//typemap definitions and other common stuff
%include "standard_header.i"

// Explicit instantiations of the std containers so SWIG generates Python proxy types for them.
// BindingPointInfo and StringVector back the return types of the ITfLiteParser methods declared
// below. MapStringTensorShape is not referenced in this file.
// NOTE(review): presumably kept for parity with the other parser modules; confirm before removing.
namespace std {
   %template(BindingPointInfo)     pair<int, armnn::TensorInfo>;
   %template(MapStringTensorShape) map<std::string, armnn::TensorShape>;
   %template(StringVector)         vector<string>;
}

namespace armnnTfLiteParser
{
%feature("docstring",
"
Interface for creating a parser object using TfLite (https://www.tensorflow.org/lite) tflite files.

Parsers are used to automatically construct Arm NN graphs from model files.

") ITfLiteParser;

// Stop SWIG from generating its own default constructor. Construction goes through the
// %extend'ed constructor below, which delegates to the static factory.
%nodefaultctor ITfLiteParser;
class ITfLiteParser
{
public:
    %feature("docstring",
        "
        Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name and subgraph id.

        Args:
            subgraphId (int): The subgraph id.
            name (str): Name of the input.

        Returns:
            tuple: (`int`, `TensorInfo`).
        ") GetNetworkInputBindingInfo;
    std::pair<int, armnn::TensorInfo> GetNetworkInputBindingInfo(size_t subgraphId, const std::string& name);

    %feature("docstring",
        "
        Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name and subgraph id.

        Args:
            subgraphId (int): The subgraph id.
            name (str): Name of the output.

        Returns:
            tuple: (`int`, `TensorInfo`).
        ") GetNetworkOutputBindingInfo;
    std::pair<int, armnn::TensorInfo> GetNetworkOutputBindingInfo(size_t subgraphId, const std::string& name);

    %feature("docstring",
        "
        Return the number of subgraphs in the parsed model.

        Returns:
            int: The number of subgraphs.
        ") GetSubgraphCount;
    size_t GetSubgraphCount();

    %feature("docstring",
        "
        Return the input tensor names for a given subgraph.

        Args:
            subgraphId (int): The subgraph id.

        Returns:
            list: A list of the input tensor names for the given subgraph.
        ") GetSubgraphInputTensorNames;
    std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId);

    %feature("docstring",
        "
        Return the output tensor names for a given subgraph.

        Args:
            subgraphId (int): The subgraph id.

        Returns:
            list: A list of the output tensor names for the given subgraph.
        ") GetSubgraphOutputTensorNames;
    std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId);
};

%extend ITfLiteParser {
    // This does not replace the default constructor of the Arm NN class. It tells SWIG to create a custom
    // __init__ method for the ITfLiteParser Python object that uses the static factory method to do the job.
    ITfLiteParser() {
        return armnnTfLiteParser::ITfLiteParser::CreateRaw();
    }

    // This does not replace the real destructor of the Arm NN class.
    // It creates a function that is called when the SWIG object goes out of scope, to clean up resources,
    // so the user doesn't need to call ITfLiteParser::Destroy themselves.
    // `$self` is a pointer to the wrapped Arm NN ITfLiteParser object.
    ~ITfLiteParser() {
        armnnTfLiteParser::ITfLiteParser::Destroy($self);
    }

    %feature("docstring",
        "
        Create the network from a flatbuffers binary file.

        Args:
            graphFile (str): Path to the tflite model to be parsed.

        Returns:
            INetwork: Parsed network.

        Raises:
            RuntimeError: If model file was not found.
        ") CreateNetworkFromBinaryFile;

    // The network is released from the smart pointer returned by the C++ API, and %newobject
    // tells SWIG that the Python proxy now owns it.
    // NOTE(review): confirm that INetwork's SWIG wrapper (presumably in standard_header.i)
    // frees it through INetwork::Destroy rather than a plain delete.
    %newobject CreateNetworkFromBinaryFile;
    armnn::INetwork* CreateNetworkFromBinaryFile(const char* graphFile) {
        return $self->CreateNetworkFromBinaryFile(graphFile).release();
    }
}

}

// Clear exception typemap.
%exception;