aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp
blob: f51b2109b48dd465381049d8496a3df809f850ea (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <tosa_generated.h>

#include <cstdint>
#include <string>
#include <vector>

using namespace armnn;
using namespace tosa;

// Maps an ArmNN DataType onto the corresponding TOSA serialization DType.
// Any ArmNN type with no TOSA counterpart (including Signed64) yields DType_UNKNOWN.
inline DType ArmNNToDType(const DataType& type)
{
    DType tosaType = DType_UNKNOWN;

    switch (type)
    {
        case DataType::Boolean:
            tosaType = DType_BOOL;
            break;
        case DataType::QAsymmU8:
            tosaType = DType_UINT8;
            break;
        case DataType::QAsymmS8:
        case DataType::QSymmS8:
            tosaType = DType_INT8;
            break;
        case DataType::QSymmS16:
            tosaType = DType_INT16;
            break;
        case DataType::Signed32:
            tosaType = DType_INT32;
            break;
        case DataType::Float16:
        case DataType::BFloat16:
            // NOTE(review): BFloat16 is reported as DType_FP16 although the two formats differ;
            // confirm whether the TOSA schema in use provides a dedicated bfloat16 type.
            tosaType = DType_FP16;
            break;
        case DataType::Float32:
            tosaType = DType_FP32;
            break;
        case DataType::Signed64:
            // TOSA has no signed 64-bit type (its widest integer is INT48), so keep UNKNOWN.
        default:
            break;
    }

    return tosaType;
}

// Function to return Tosa tensor shape from input ArmNN tensor shape.
inline std::vector<int32_t> GetTosaTensorShape(const TensorShape& shape)
{
    std::vector<int32_t> returnShape;
    for (u_int32_t i = 0; i < shape.GetNumDimensions(); i++)
    {
        returnShape.push_back(static_cast<int32_t>(shape[i]));
    }
    return returnShape;
}

// Returns a fresh integer (1, 2, 3, ...) as a string, used to keep input, output and block
// names unique.
//
// The counter is a function-local static of this inline function. A namespace-scope 'static'
// variable in a header would give every translation unit its own private counter, so different
// .cpp files would hand out the same IDs (and the inline function would refer to a different
// object in each TU). An inline function's local static is one object shared by every
// translation unit in the program.
// NOTE(review): the increment is not synchronised; concurrent callers would need std::atomic.
inline std::string GetUniqueTosaMappingID()
{
    static int uniqueTosaMappingID = 0;
    return std::to_string(++uniqueTosaMappingID);
}

// Returns the enumerator name of a TOSA Op as text, e.g. Op_ADD -> "Op_ADD".
// An Op value not handled by the switch below yields an empty string.
inline std::string TosaOpToString(Op tosaOp)
{
    // Each case returns the stringized enumerator, so the text can never drift from the label.
    // The switch intentionally has no 'default', letting -Wswitch flag any unlisted Op enumerator.
#define ARMNN_TOSA_OP_NAME_CASE(op) \
    case op:                        \
        return #op;

    switch (tosaOp)
    {
        ARMNN_TOSA_OP_NAME_CASE(Op_ADD)
        ARMNN_TOSA_OP_NAME_CASE(Op_AVG_POOL2D)
        ARMNN_TOSA_OP_NAME_CASE(Op_MAX_POOL2D)
        ARMNN_TOSA_OP_NAME_CASE(Op_PAD)
        ARMNN_TOSA_OP_NAME_CASE(Op_UNKNOWN)
        ARMNN_TOSA_OP_NAME_CASE(Op_ARGMAX)
        ARMNN_TOSA_OP_NAME_CASE(Op_CONV2D)
        ARMNN_TOSA_OP_NAME_CASE(Op_CONV3D)
        ARMNN_TOSA_OP_NAME_CASE(Op_DEPTHWISE_CONV2D)
        ARMNN_TOSA_OP_NAME_CASE(Op_FULLY_CONNECTED)
        ARMNN_TOSA_OP_NAME_CASE(Op_MATMUL)
        ARMNN_TOSA_OP_NAME_CASE(Op_TRANSPOSE_CONV2D)
        ARMNN_TOSA_OP_NAME_CASE(Op_CLAMP)
        ARMNN_TOSA_OP_NAME_CASE(Op_RESERVED)
        ARMNN_TOSA_OP_NAME_CASE(Op_SIGMOID)
        ARMNN_TOSA_OP_NAME_CASE(Op_TANH)
        ARMNN_TOSA_OP_NAME_CASE(Op_ARITHMETIC_RIGHT_SHIFT)
        ARMNN_TOSA_OP_NAME_CASE(Op_BITWISE_AND)
        ARMNN_TOSA_OP_NAME_CASE(Op_BITWISE_OR)
        ARMNN_TOSA_OP_NAME_CASE(Op_BITWISE_XOR)
        ARMNN_TOSA_OP_NAME_CASE(Op_INTDIV)
        ARMNN_TOSA_OP_NAME_CASE(Op_LOGICAL_AND)
        ARMNN_TOSA_OP_NAME_CASE(Op_LOGICAL_LEFT_SHIFT)
        ARMNN_TOSA_OP_NAME_CASE(Op_LOGICAL_RIGHT_SHIFT)
        ARMNN_TOSA_OP_NAME_CASE(Op_LOGICAL_OR)
        ARMNN_TOSA_OP_NAME_CASE(Op_LOGICAL_XOR)
        ARMNN_TOSA_OP_NAME_CASE(Op_MAXIMUM)
        ARMNN_TOSA_OP_NAME_CASE(Op_MINIMUM)
        ARMNN_TOSA_OP_NAME_CASE(Op_MUL)
        ARMNN_TOSA_OP_NAME_CASE(Op_POW)
        ARMNN_TOSA_OP_NAME_CASE(Op_SUB)
        ARMNN_TOSA_OP_NAME_CASE(Op_TABLE)
        ARMNN_TOSA_OP_NAME_CASE(Op_ABS)
        ARMNN_TOSA_OP_NAME_CASE(Op_BITWISE_NOT)
        ARMNN_TOSA_OP_NAME_CASE(Op_CEIL)
        ARMNN_TOSA_OP_NAME_CASE(Op_CLZ)
        ARMNN_TOSA_OP_NAME_CASE(Op_EXP)
        ARMNN_TOSA_OP_NAME_CASE(Op_FLOOR)
        ARMNN_TOSA_OP_NAME_CASE(Op_LOG)
        ARMNN_TOSA_OP_NAME_CASE(Op_LOGICAL_NOT)
        ARMNN_TOSA_OP_NAME_CASE(Op_NEGATE)
        ARMNN_TOSA_OP_NAME_CASE(Op_RECIPROCAL)
        ARMNN_TOSA_OP_NAME_CASE(Op_RSQRT)
        ARMNN_TOSA_OP_NAME_CASE(Op_SELECT)
        ARMNN_TOSA_OP_NAME_CASE(Op_EQUAL)
        ARMNN_TOSA_OP_NAME_CASE(Op_GREATER)
        ARMNN_TOSA_OP_NAME_CASE(Op_GREATER_EQUAL)
        ARMNN_TOSA_OP_NAME_CASE(Op_REDUCE_ANY)
        ARMNN_TOSA_OP_NAME_CASE(Op_REDUCE_ALL)
        ARMNN_TOSA_OP_NAME_CASE(Op_REDUCE_MAX)
        ARMNN_TOSA_OP_NAME_CASE(Op_REDUCE_MIN)
        ARMNN_TOSA_OP_NAME_CASE(Op_REDUCE_PRODUCT)
        ARMNN_TOSA_OP_NAME_CASE(Op_REDUCE_SUM)
        ARMNN_TOSA_OP_NAME_CASE(Op_CONCAT)
        ARMNN_TOSA_OP_NAME_CASE(Op_RESHAPE)
        ARMNN_TOSA_OP_NAME_CASE(Op_REVERSE)
        ARMNN_TOSA_OP_NAME_CASE(Op_SLICE)
        ARMNN_TOSA_OP_NAME_CASE(Op_TILE)
        ARMNN_TOSA_OP_NAME_CASE(Op_TRANSPOSE)
        ARMNN_TOSA_OP_NAME_CASE(Op_GATHER)
        ARMNN_TOSA_OP_NAME_CASE(Op_SCATTER)
        ARMNN_TOSA_OP_NAME_CASE(Op_RESIZE)
        ARMNN_TOSA_OP_NAME_CASE(Op_CAST)
        ARMNN_TOSA_OP_NAME_CASE(Op_RESCALE)
        ARMNN_TOSA_OP_NAME_CASE(Op_CONST)
        ARMNN_TOSA_OP_NAME_CASE(Op_IDENTITY)
        ARMNN_TOSA_OP_NAME_CASE(Op_CUSTOM)
        ARMNN_TOSA_OP_NAME_CASE(Op_COND_IF)
        ARMNN_TOSA_OP_NAME_CASE(Op_WHILE_LOOP)
    }

#undef ARMNN_TOSA_OP_NAME_CASE

    return "";
}