aboutsummaryrefslogtreecommitdiff
path: root/1.2/HalPolicy.hpp
blob: 4d77dfe5b6f0063c8c44c1d0897c46f29c2755ec (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
//
// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "../ConversionUtils.hpp"
#include "../ConversionUtils_1_2.hpp"

#include <HalInterfaces.h>

#include <armnn/Types.hpp>

namespace V1_2 = ::android::hardware::neuralnetworks::V1_2;

namespace armnn_driver
{
class DriverOptions;
namespace hal_1_2
{

class HalPolicy
{
public:
    using Model                     = V1_2::Model;
    using Operand                   = V1_2::Operand;
    using OperandLifeTime           = V1_0::OperandLifeTime;
    using OperandType               = V1_2::OperandType;
    using Operation                 = V1_2::Operation;
    using OperationType             = V1_2::OperationType;
    using ExecutionCallback         = V1_2::IExecutionCallback;
    using getSupportedOperations_cb = V1_2::IDevice::getSupportedOperations_1_2_cb;
    using ErrorStatus               = V1_0::ErrorStatus;
    using DeviceType                = V1_2::DeviceType;

    static DeviceType GetDeviceTypeFromOptions(const DriverOptions& options);

    static bool ConvertOperation(const Operation& operation, const Model& model, ConversionData& data);

private:
    static bool ConvertArgMinMax(const Operation& operation,
                                 const Model& model,
                                 ConversionData& data,
                                 armnn::ArgMinMaxFunction argMinMaxFunction);

    static bool ConvertAveragePool2d(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertBatchToSpaceNd(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertCast(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertChannelShuffle(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertComparison(const Operation& operation,
                                  const Model& model,
                                  ConversionData& data,
                                  armnn::ComparisonOperation comparisonOperation);

    static bool ConvertConcatenation(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertConv2d(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertDepthToSpace(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertDepthwiseConv2d(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertDequantize(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertExpandDims(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertElementwiseBinary(const Operation& operation,
                                         const Model& model,
                                         ConversionData& data,
                                         armnn::BinaryOperation binaryOperation);

    static bool ConvertElementwiseUnary(const Operation& operation,
                                        const Model& model,
                                        ConversionData& data,
                                        armnn::UnaryOperation unaryOperation);

    static bool ConvertFloor(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertFullyConnected(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertGather(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertGroupedConv2d(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertInstanceNormalization(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertL2Normalization(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertL2Pool2d(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertLocalResponseNormalization(const Operation& operation,
                                                  const Model& model,
                                                  ConversionData& data);

    static bool ConvertLogistic(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertLogSoftmax(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertLstm(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertMaxPool2d(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertMean(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertPad(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertPadV2(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertPrelu(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertQuantize(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertQuantized16BitLstm(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertReduce(const Operation& operation,
                              const Model& model,
                              ConversionData& data,
                              ReduceOperation reduce_operation);

    static bool ConvertReLu(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertReLu1(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertReLu6(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertReshape(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertResize(const Operation& operation,
                              const Model& model,
                              ConversionData& data,
                              armnn::ResizeMethod resizeMethod);

    static bool ConvertSoftmax(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertSpaceToBatchNd(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertSpaceToDepth(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertSplit(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertSqrt(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertSqueeze(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertStridedSlice(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertTanH(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertTranspose(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertTransposeConv2d(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertTile(const Operation& operation, const Model& model, ConversionData& data);

    static bool ConvertUnidirectionalSequenceLstm(const Operation& operation,
                                                  const Model& model,
                                                  ConversionData& data);
};

} // namespace hal_1_2
} // namespace armnn_driver