aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/src/pyarmnn/swig/armnn_tfliteparser.i
blob: fbe5fd77206aedc98883fa91ed5f41d0d50eb794 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
%module pyarmnn_tfliteparser
%{
#include "armnnTfLiteParser/ITfLiteParser.hpp"
#include "armnn/Types.hpp"
#include "armnn/INetwork.hpp"
%}

//typemap definitions and other common stuff
%include "standard_header.i"

namespace std {
   // Python-visible names for the standard-library instantiations this module returns or accepts.
   // BindingPointInfo pairs a layer binding id with its TensorInfo (surfaces as a tuple in Python).
   %template(BindingPointInfo)     pair<int, armnn::TensorInfo>;
   %template(MapStringTensorShape) map<std::string, armnn::TensorShape>;
   // List of tensor names, as returned by the GetSubgraph*TensorNames methods.
   %template(StringVector)         vector<std::string>;
}

namespace armnnTfLiteParser
{
// Python wrapper for armnnTfLiteParser::ITfLiteParser. SWIG exposes only the members declared below
// (plus the %extend additions further down), not the full C++ interface.
%feature("docstring",
"
Interface for creating a parser object using TfLite (https://www.tensorflow.org/lite) tflite files.

Parsers are used to automatically construct Arm NN graphs from model files.

") ITfLiteParser;
// No default constructor is generated; construction goes through the factory in %extend below.
%nodefaultctor ITfLiteParser;
class ITfLiteParser
{
public:
    %feature("docstring",
        "
        Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name and subgraph id.

        Args:
            subgraphId (int): The subgraph id.
            name (str): Name of the input.

        Returns:
            tuple: (`int`, `TensorInfo`).
        ") GetNetworkInputBindingInfo;
    std::pair<int, armnn::TensorInfo> GetNetworkInputBindingInfo(size_t subgraphId, const std::string& name);

    %feature("docstring",
        "
        Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name and subgraph id.

        Args:
            subgraphId (int): The subgraph id.
            name (str): Name of the output.

        Returns:
            tuple: (`int`, `TensorInfo`).
        ") GetNetworkOutputBindingInfo;
    std::pair<int, armnn::TensorInfo> GetNetworkOutputBindingInfo(size_t subgraphId, const std::string& name);

    %feature("docstring",
        "
        Return the number of subgraphs in the parsed model.

        Returns:
            int: The number of subgraphs.
        ") GetSubgraphCount;
    size_t GetSubgraphCount();

    %feature("docstring",
        "
        Return the input tensor names for a given subgraph.

        Args:
            subgraphId (int): The subgraph id.

        Returns:
            list: A list of the input tensor names for the given subgraph.
        ") GetSubgraphInputTensorNames;
    std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId);

    %feature("docstring",
        "
        Return the output tensor names for a given subgraph.

        Args:
            subgraphId (int): The subgraph id.

        Returns:
            list: A list of the output tensor names for the given subgraph.
        ") GetSubgraphOutputTensorNames;
    std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId);
};

%extend ITfLiteParser {
// This is not a substitution of the default constructor of the Armnn class. It tells swig to create a custom __init__
// method for the ITfLiteParser python object that uses the static factory method to do the job.

    ITfLiteParser() {
        return armnnTfLiteParser::ITfLiteParser::CreateRaw();
    }

// The following does not replace the real destructor of the Armnn class.
// It creates a function that is called when the swig object goes out of scope, to clean up resources,
// so the user doesn't need to call ITfLiteParser::Destroy themselves.
// `$self` is a pointer to the extracted ArmNN ITfLiteParser object.

    ~ITfLiteParser() {
        armnnTfLiteParser::ITfLiteParser::Destroy($self);
    }

    %feature("docstring",
    "
    Create the network from a flatbuffers binary file.

    Args:
        graphFile (str): Path to the tflite model to be parsed.

    Returns:
        INetwork: Parsed network.

    Raises:
        RuntimeError: If model file was not found.
    ") CreateNetworkFromBinaryFile;

    // release() detaches the raw INetwork pointer from the parser's owning return value, and
    // %newobject tells SWIG that the Python caller now owns the returned object.
    // NOTE(review): assumes the INetwork proxy's destructor (defined in another interface file)
    // frees it via INetwork::Destroy rather than plain delete — confirm.
    %newobject CreateNetworkFromBinaryFile;
    armnn::INetwork* CreateNetworkFromBinaryFile(const char* graphFile) {
        return $self->CreateNetworkFromBinaryFile(graphFile).release();
    }

}

}
// Clear the %exception handler (a SWIG feature, not a typemap) so it does not apply to
// declarations processed after this point. NOTE(review): presumably the handler being cleared
// is the one set up in standard_header.i — confirm.
%exception;