aboutsummaryrefslogtreecommitdiff
path: root/ModelToINetworkConverter.cpp
blob: 24fb49038f51ee1cdd2d425df7a63a567cc30f7b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#define LOG_TAG "ArmnnDriver"

#include "ModelToINetworkConverter.hpp"
#include "Utils.hpp"

#include <log/log.h>
#include <type_traits>

namespace armnn_driver
{

// Constructs the converter and immediately runs Convert() on the given model.
//
// backends                    - forwarded to the constructor of m_Data.
// model                       - HAL model to convert; kept in m_Model.
// forcedUnsupportedOperations - operation indices that Convert() reports as unsupported
//                               without attempting to convert them.
//
// Exceptions derived from std::exception that escape Convert() are not propagated to the
// caller: the conversion result is set to ConversionResult::UnsupportedFeature, the error is
// logged, and assert(false) flags the situation in builds where assertions are enabled.
template<typename HalPolicy>
ModelToINetworkConverter<HalPolicy>::ModelToINetworkConverter(const std::vector<armnn::BackendId>& backends,
    const HalModel& model,
    const std::set<unsigned int>& forcedUnsupportedOperations)
    : m_Data(backends)
    , m_Model(model)
    , m_ForcedUnsupportedOperations(forcedUnsupportedOperations)
    , m_ConversionResult(ConversionResult::Success)
{
    try
    {
        Convert();
    }
    catch (const std::exception& e) // caught by const reference: the handler only reads what()
    {
        m_ConversionResult = ConversionResult::UnsupportedFeature;
        ALOGE("%s: Unexpected exception: %s", __func__, e.what());
        assert(false);
    }
}

// Builds m_Data.m_Network from m_Model and records the outcome in m_ConversionResult.
//
// Steps, in order:
//  1. Map the model's memory pools into m_Data.m_MemPools. On failure the result becomes
//     ConversionResult::ErrorMappingPools and the function returns immediately.
//  2. Add one input layer per model input and remember its output slot in
//     m_Data.m_OutputSlotForOperand so that later layers can connect to it.
//  3. Convert every operation through HalPolicy::ConvertOperation and store the per-operation
//     outcome in m_OperationSupported. A failing operation marks the whole conversion as
//     ConversionResult::UnsupportedFeature, but the remaining operations are still checked.
//  4. Only if everything succeeded so far, add one output layer per model output and connect
//     it to the slot that produces the corresponding operand.
template<typename HalPolicy>
void ModelToINetworkConverter<HalPolicy>::Convert()
{
    using HalModel       = typename HalPolicy::Model;
    using HalOperand     = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    ALOGV("ModelToINetworkConverter::Convert(): %s", GetModelSummary<HalModel>(m_Model).c_str());

    // Map the model's memory pools into m_Data.m_MemPools.
    m_Data.m_MemPools.clear();
    if (!setRunTimePoolInfosFromHidlMemories(&m_Data.m_MemPools, m_Model.pools))
    {
        Fail("%s: Setting of run time pool infos from Hidl Memories has failed.", __func__);
        m_ConversionResult = ConversionResult::ErrorMappingPools;
        return;
    }

    // Looked up once; every operand/operation/input/output access below goes through it.
    const auto& mainModel = getMainModel(m_Model);

    // Create armnn::INetwork
    m_Data.m_Network = armnn::INetwork::Create();

    // Track which layer output slot produces each operand (nullptr until a producer is known).
    ALOGV("ModelToINetworkConverter::Convert(): m_OutputSlotForOperand");
    m_Data.m_OutputSlotForOperand = std::vector<armnn::IOutputSlot*>(mainModel.operands.size(), nullptr);
    try
    {
        ALOGV("ModelToINetworkConverter::Convert(): for getMainModel(m_Model).inputIndexes.size()");
        for (uint32_t i = 0; i < mainModel.inputIndexes.size(); i++)
        {
            ALOGV("ModelToINetworkConverter::Convert(): getMainModel(m_Model).inputIndexes[i]");
            // inputs in android nn are represented by operands
            const uint32_t inputIndex = mainModel.inputIndexes[i];
            ALOGV("ModelToINetworkConverter::Convert(): getMainModel(m_Model).operands[inputIndex];");
            const HalOperand& operand = mainModel.operands[inputIndex];
            ALOGV("ModelToINetworkConverter::Convert(): GetTensorInfoForOperand(operand)");
            const armnn::TensorInfo& tensor = GetTensorInfoForOperand(operand);
            ALOGV("ModelToINetworkConverter::Convert(): m_Data.m_Network->AddInputLayer(i)");
            armnn::IConnectableLayer* layer = m_Data.m_Network->AddInputLayer(i);

            ALOGV("ModelToINetworkConverter::Convert(): layer->GetOutputSlot(0)");
            armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
            ALOGV("ModelToINetworkConverter::Convert(): outputSlot.SetTensorInfo(GetTensorInfoForOperand(operand))");
            // Reuse the tensor info obtained above instead of computing it a second time.
            outputSlot.SetTensorInfo(tensor);

            ALOGV("ModelToINetworkConverter::Convert(): m_Data.m_OutputSlotForOperand[inputIndex] = &outputSlot");
            // store for later layers
            m_Data.m_OutputSlotForOperand[inputIndex] = &outputSlot;
        }
    }
    catch (const UnsupportedOperand<HalOperandType>& e)
    {
        Fail("%s: Operand type %s not supported in ArmnnDriver", __func__, toString(e.m_type).c_str());
        m_ConversionResult = ConversionResult::UnsupportedFeature;
    }
    catch (const armnn::InvalidArgumentException& e)
    {
        Fail("%s: Failed to convert input operand to TensorShape: %s", __func__, e.what());
        m_ConversionResult = ConversionResult::UnsupportedFeature;
    }

    // No early return after an input failure: the loop below still runs so that
    // m_OperationSupported gets an entry for every operation.
    for (uint32_t operationIdx = 0; operationIdx < mainModel.operations.size(); operationIdx++)
    {
        const auto& operation = mainModel.operations[operationIdx];

        bool ok = true;
        if (m_ForcedUnsupportedOperations.find(operationIdx) != m_ForcedUnsupportedOperations.end())
        {
            Fail("%s: Operation at index %i has been forced to be unsupported.", __func__, operationIdx);
            ok = false;
        }

        if (ok)
        {
            try
            {
                ok = HalPolicy::ConvertOperation(operation, m_Model, m_Data);
            }
            catch (const UnsupportedOperand<HalOperandType>& e)
            {
                Fail("%s: Operand type %s not supported in ArmnnDriver", __func__, toString(e.m_type).c_str());
                ok = false;
            }
            catch (const armnn::InvalidArgumentException& e)
            {
                Fail("%s: Failed to convert operation in %s", __func__, e.what());
                ok = false;
            }
        }

        // Store whether this operation was successfully converted.
        m_OperationSupported.emplace(operationIdx, ok);

        // Any single operation failing will fail the entire conversion.
        // We still need to continue and check the other ones.
        if (!ok)
        {
            m_ConversionResult = ConversionResult::UnsupportedFeature;
        }
    }

    // Output layers are only added to a network that converted cleanly.
    if (m_ConversionResult != ConversionResult::Success)
    {
        return;
    }

    try
    {
        for (uint32_t i = 0; i < mainModel.outputIndexes.size(); i++)
        {
            // outputs in android nn are represented by operands
            const uint32_t outputIndex = mainModel.outputIndexes[i];
            const HalOperand& operand = mainModel.operands[outputIndex];

            // The result is not needed here. The call is kept because it can throw for an
            // operand it cannot describe (see the handlers below), which rejects the model.
            static_cast<void>(GetTensorInfoForOperand(operand));

            armnn::IOutputSlot* const producerSlot = m_Data.m_OutputSlotForOperand[outputIndex];
            assert(producerSlot);
            if (producerSlot == nullptr)
            {
                // With assertions compiled out, fail the conversion instead of dereferencing
                // a null slot.
                Fail("%s: No layer produces output operand %u", __func__, outputIndex);
                m_ConversionResult = ConversionResult::UnsupportedFeature;
                return;
            }

            armnn::IConnectableLayer* layer = m_Data.m_Network->AddOutputLayer(i);
            producerSlot->Connect(layer->GetInputSlot(0));
        }
    }
    catch (const UnsupportedOperand<HalOperandType>& e)
    {
        // Handled the same way as for input operands rather than escaping to the caller.
        Fail("%s: Operand type %s not supported in ArmnnDriver", __func__, toString(e.m_type).c_str());
        m_ConversionResult = ConversionResult::UnsupportedFeature;
    }
    catch (const armnn::InvalidArgumentException& e)
    {
        Fail("%s: Failed to convert output operand to TensorShape: %s", __func__, e.what());
        m_ConversionResult = ConversionResult::UnsupportedFeature;
    }
}

// Returns the conversion outcome that Convert() recorded for the operation at operationIndex.
//
// Asking about an index that was never recorded is a caller error and trips the assert in
// builds where assertions are enabled. With assertions compiled out the function reports the
// operation as unsupported instead of dereferencing the end iterator.
template<typename HalPolicy>
bool ModelToINetworkConverter<HalPolicy>::IsOperationSupported(uint32_t operationIndex) const
{
    const auto it = m_OperationSupported.find(operationIndex);
    assert(it != m_OperationSupported.end());
    if (it == m_OperationSupported.end())
    {
        return false;
    }
    return it->second;
}

///
/// Explicit class template instantiations
///
/// The member definitions above live in this translation unit, so the converter is
/// instantiated here for each HAL policy enabled by the build's ARMNN_ANDROID_NN_V1_x macro.
///

// The 1.0 policy is instantiated unconditionally.
template class ModelToINetworkConverter<hal_1_0::HalPolicy>;

#if defined(ARMNN_ANDROID_NN_V1_1)
template class ModelToINetworkConverter<hal_1_1::HalPolicy>;
#endif

#if defined(ARMNN_ANDROID_NN_V1_2)
template class ModelToINetworkConverter<hal_1_1::HalPolicy>;
template class ModelToINetworkConverter<hal_1_2::HalPolicy>;
#endif

#if defined(ARMNN_ANDROID_NN_V1_3)
template class ModelToINetworkConverter<hal_1_1::HalPolicy>;
template class ModelToINetworkConverter<hal_1_2::HalPolicy>;
template class ModelToINetworkConverter<hal_1_3::HalPolicy>;
#endif

} // armnn_driver