path: root/src/backends/tosaReference/TosaRefLayerSupport.cpp
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "TosaRefLayerSupport.hpp"

#include <tosaCommon/TosaMappings.hpp>

#include <armnn/Types.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <graph_status.h>
#include <model_runner.h>

#include <vector>

namespace armnn
{

// Checks whether the TOSA Reference backend can support the given layer by mapping it to
// TOSA operators and validating the resulting graph with the TOSA Reference Model.
bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type,
                                           const std::vector<TensorInfo>& infos,
                                           const BaseDescriptor& descriptor,
                                           const Optional<LstmInputParamsInfo>& lstmParamsInfo,
                                           const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
                                           Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(lstmParamsInfo);
    IgnoreUnused(quantizedLstmInputParamsInfo);
    IgnoreUnused(reasonIfUnsupported);

    std::vector<const TensorInfo*> inputInfos;
    std::vector<const TensorInfo*> outputInfos;
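
    // The layout of 'infos' differs per layer type, so each case below selects the entries
    // that GetTosaMapping expects as the layer's inputs and outputs.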

    switch (type)
    {
        case LayerType::Input:
        case LayerType::Output:
            // Input and Output layers are always supported.
            return true;
        case LayerType::Addition:
        case LayerType::Multiplication:
        case LayerType::Subtraction:
        case LayerType::ElementwiseBinary:
            // Setup inputs and outputs
            inputInfos.push_back(&infos[0]);
            inputInfos.push_back(&infos[1]);
            outputInfos.push_back(&infos[2]);
            break;
        case LayerType::Concat:
            for (unsigned int i = 0; i < infos.size() - 1; ++i)
            {
                inputInfos.push_back(&infos[i]);
            }
            outputInfos.push_back(&infos.back());
            break;
        case LayerType::Constant:
            outputInfos.push_back(&infos[0]);
            break;
        case LayerType::Convolution2d:
        {
            inputInfos.push_back(&infos[0]); // input
            outputInfos.push_back(&infos[1]); // output
            inputInfos.push_back(&infos[2]); // weights

            auto conv2dDesc = PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor);
            if(conv2dDesc->m_BiasEnabled)
            {
                inputInfos.push_back(&infos[3]); // bias
            }
            break;
        }
        case LayerType::ElementwiseUnary:
        case LayerType::Pooling2d:
        case LayerType::Quantize:
        case LayerType::Reshape:
        case LayerType::Resize:
        case LayerType::Slice:
        case LayerType::Transpose:
        {
            inputInfos.push_back(&infos[0]);
            outputInfos.push_back(&infos[1]);
            break;
        }
        case LayerType::Splitter:
        {
            inputInfos.push_back(&infos[0]);
            for (unsigned int i = 1; i < infos.size(); ++i)
            {
                outputInfos.push_back(&infos[i]);
            }
            break;
        }
        case LayerType::TransposeConvolution2d:
        {
            inputInfos.push_back(&infos[0]); // input
            outputInfos.push_back(&infos[1]); // output
            inputInfos.push_back(&infos[2]); // weights

            auto conv2dDesc = PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor);
            if(conv2dDesc->m_BiasEnabled)
            {
                inputInfos.push_back(&infos[3]); // bias
            }
            break;
        }
        default:
            // Default to false for all unsupported layers.
            return false;
    }

    auto mappings = GetTosaMapping(nullptr, type, inputInfos, outputInfos, descriptor);
    if (mappings->GetName() == "")
    {
        // The default mapping was returned, meaning there is currently no TOSA mapping for this layer.
        return false;
    }
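
    // Wrap the mapped operators in the handler -> region -> basic block structure expected by
    // the TOSA serialization library, so the graph can be validated by the reference model.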

    TosaSerializationHandler handler;

    // Add all mappings to main block.
    auto* block = new TosaSerializationBasicBlock("main",
                                                  "main",
                                                  mappings->GetOperators(),
                                                  mappings->GetTensors(),
                                                  mappings->GetInputs(),
                                                  mappings->GetOutputs());

    std::vector<TosaSerializationBasicBlock*> blocks;
    blocks.emplace_back(block);

    // Add blocks to the main region.
    auto* region = new TosaSerializationRegion("main", blocks);
    handler.GetRegions().emplace_back(region);

    GraphStatus status;
    TosaReference::IModelRunner runner;

#if !defined(TOSA_REFERENCE_MODEL_OUTPUT)
    // There currently isn't a way to disable the output from the TOSA Reference Model, but it does have a file pointer
    // to write debug output to, so set this to /dev/null (if it exists on the system) to hide the output.
    func_debug_t funcDebug;

    FILE* file = fopen("/dev/null", "w");
    funcDebug.func_debug_file = (file == nullptr) ? stderr : file;

    runner.setFuncDebug(funcDebug);
#endif

    // Initialise the model runner with the TosaSerializationHandler, which runs validation on the mapping.
    status = runner.initialize(handler);

#if !defined(TOSA_REFERENCE_MODEL_OUTPUT)
    // Reset the FuncDebug file pointer, as it can persist across multiple IModelRunner instances,
    // and close the /dev/null handle if one was opened.
    funcDebug.func_debug_file = stderr;
    runner.setFuncDebug(funcDebug);
    if (file != nullptr)
    {
        fclose(file);
    }
#endif

    // Any status other than an error or an unpredictable result means the mapping passed the
    // reference model's validation, so the layer is supported.
    if (status == GraphStatus::TOSA_ERROR || status == GraphStatus::TOSA_UNPREDICTABLE)
    {
        return false;
    }

    return true;
}

} // namespace armnn