aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/src/pyarmnn/swig/armnn_tfparser.i
blob: 3438492d264fa51ac21052bda33443743bae2779 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
%module pyarmnn_tfparser
%{
#define SWIG_FILE_WITH_INIT
#include "armnnTfParser/ITfParser.hpp"
#include "armnn/Exceptions.hpp"
#include "armnn/INetwork.hpp"
%}

//typemap definitions and other common stuff
%include "standard_header.i"

// Explicit template instantiations so SWIG generates Python proxy types for the STL containers
// that appear in the ITfParser interface of this module:
//  - BindingPointInfo:     return type of GetNetworkInputBindingInfo / GetNetworkOutputBindingInfo.
//  - MapStringTensorShape: the inputShapes argument of CreateNetworkFromBinaryFile.
//  - StringVector:         the requestedOutputs argument of CreateNetworkFromBinaryFile.
namespace std {
   %template(BindingPointInfo)     pair<int, armnn::TensorInfo>;
   %template(MapStringTensorShape) map<std::string, armnn::TensorShape>;
   %template(StringVector)         vector<string>;
}

namespace armnnTfParser
{
%feature("docstring",
"
Interface for creating a parser object using TensorFlow (https://www.tensorflow.org/) frozen pb files.

Parsers are used to automatically construct Arm NN graphs from model files.

") ITfParser;
// Suppress SWIG's generated default constructor: construction goes through the custom
// constructor in the %extend block below, which uses the static factory method.
%nodefaultctor ITfParser;
class ITfParser
{
public:
    %feature("docstring",
        "
        Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.

        Args:
            name (str): Name of the input.

        Returns:
            tuple: (`int`, `TensorInfo`).
        ") GetNetworkInputBindingInfo;
    std::pair<int, armnn::TensorInfo> GetNetworkInputBindingInfo(const std::string& name);

    %feature("docstring",
        "
        Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.

        Args:
            name (str): Name of the output.

        Returns:
            tuple: (`int`, `TensorInfo`).
        ") GetNetworkOutputBindingInfo;
    std::pair<int, armnn::TensorInfo> GetNetworkOutputBindingInfo(const std::string& name);
};

%extend ITfParser {
    // This is not a substitute for the default constructor of the Arm NN class. It tells SWIG to create a custom
    // __init__ method for the ITfParser Python object that uses the static factory method to do the job.

    ITfParser() {
        return armnnTfParser::ITfParser::CreateRaw();
    }

    // The following does not replace the real destructor of the Arm NN class.
    // It creates a function that is called when the SWIG object goes out of scope, to clean up resources,
    // so the user doesn't need to call ITfParser::Destroy themselves.
    // `$self` is a pointer to the wrapped Arm NN ITfParser object.

    ~ITfParser() {
        armnnTfParser::ITfParser::Destroy($self);
    }

    %feature("docstring",
    "
    Create the network from a pb Protocol buffer file.

    Args:
        graphFile (str): Path to the tf model to be parsed.
        inputShapes (dict): A dict containing the input name as a key & TensorShape as a value.
        requestedOutputs (list of str): A list of the output tensor names.

    Returns:
        INetwork: Parsed network.

    Raises:
        RuntimeError: If `graphFile` is None or the model file was not found.
     ") CreateNetworkFromBinaryFile;
    %newobject CreateNetworkFromBinaryFile;
    armnn::INetwork* CreateNetworkFromBinaryFile(const char* graphFile,
                                                 const std::map<std::string, armnn::TensorShape>& inputShapes,
                                                 const std::vector<std::string>& requestedOutputs) {
        // SWIG's char* typemap can pass Python None through as a null pointer; reject it here
        // rather than handing a null path to the underlying parser.
        if (graphFile == nullptr)
        {
            throw armnn::InvalidArgumentException("CreateNetworkFromBinaryFile: graphFile must not be None");
        }
        // release() gives up the unique ownership held by the returned smart pointer; %newobject above
        // marks the returned pointer as a newly created object owned by the Python proxy.
        return $self->CreateNetworkFromBinaryFile(graphFile, inputShapes, requestedOutputs).release();
    }

}

}
// Clear the global %exception handler (presumably installed by standard_header.i) so that it
// no longer applies to declarations SWIG processes after this point.
%exception;