aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefElementwiseWorkload.hpp
blob: 371904977af9ac668e54b216e9d8bf2613d49912 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/Types.hpp>
#include <backendsCommon/Workload.hpp>
#include <backendsCommon/WorkloadData.hpp>
#include "Maximum.hpp"
#include "Minimum.hpp"
#include "StringMapping.hpp"

#include <functional>

namespace armnn
{

template <typename Functor,
          typename armnn::DataType DataType,
          typename ParentDescriptor,
          typename armnn::StringMapping::Id DebugString>
class RefElementwiseWorkload
    : public TypedWorkload<ParentDescriptor, DataType>
{
public:

    using TypedWorkload<ParentDescriptor, DataType>::m_Data;
    using TypedWorkload<ParentDescriptor, DataType>::TypedWorkload;

    void Execute() const override;
};

// Concrete reference workloads. Every variant instantiates its functor for float,
// including the QuantisedAsymm8 ones; how the quantised data is handled is up to
// the out-of-line Execute() and is not visible in this header.
// Note: the std:: functors used below come from <functional>.

// Addition.
using RefAdditionFloat32Workload =
    RefElementwiseWorkload<std::plus<float>,
                          DataType::Float32,
                          AdditionQueueDescriptor,
                          StringMapping::RefAdditionWorkload_Execute>;

using RefAdditionUint8Workload =
    RefElementwiseWorkload<std::plus<float>,
                          DataType::QuantisedAsymm8,
                          AdditionQueueDescriptor,
                          StringMapping::RefAdditionWorkload_Execute>;

// Subtraction.
using RefSubtractionFloat32Workload =
    RefElementwiseWorkload<std::minus<float>,
                          DataType::Float32,
                          SubtractionQueueDescriptor,
                          StringMapping::RefSubtractionWorkload_Execute>;

using RefSubtractionUint8Workload =
    RefElementwiseWorkload<std::minus<float>,
                          DataType::QuantisedAsymm8,
                          SubtractionQueueDescriptor,
                          StringMapping::RefSubtractionWorkload_Execute>;

// Multiplication.
using RefMultiplicationFloat32Workload =
    RefElementwiseWorkload<std::multiplies<float>,
                          DataType::Float32,
                          MultiplicationQueueDescriptor,
                          StringMapping::RefMultiplicationWorkload_Execute>;

using RefMultiplicationUint8Workload =
    RefElementwiseWorkload<std::multiplies<float>,
                          DataType::QuantisedAsymm8,
                          MultiplicationQueueDescriptor,
                          StringMapping::RefMultiplicationWorkload_Execute>;

// Division.
using RefDivisionFloat32Workload =
    RefElementwiseWorkload<std::divides<float>,
                          DataType::Float32,
                          DivisionQueueDescriptor,
                          StringMapping::RefDivisionWorkload_Execute>;

using RefDivisionUint8Workload =
    RefElementwiseWorkload<std::divides<float>,
                          DataType::QuantisedAsymm8,
                          DivisionQueueDescriptor,
                          StringMapping::RefDivisionWorkload_Execute>;

// Maximum (functor declared in Maximum.hpp).
using RefMaximumFloat32Workload =
    RefElementwiseWorkload<armnn::maximum<float>,
                          DataType::Float32,
                          MaximumQueueDescriptor,
                          StringMapping::RefMaximumWorkload_Execute>;

using RefMaximumUint8Workload =
    RefElementwiseWorkload<armnn::maximum<float>,
                          DataType::QuantisedAsymm8,
                          MaximumQueueDescriptor,
                          StringMapping::RefMaximumWorkload_Execute>;

// Minimum (functor declared in Minimum.hpp); qualified the same way as maximum.
using RefMinimumFloat32Workload =
    RefElementwiseWorkload<armnn::minimum<float>,
                          DataType::Float32,
                          MinimumQueueDescriptor,
                          StringMapping::RefMinimumWorkload_Execute>;

using RefMinimumUint8Workload =
    RefElementwiseWorkload<armnn::minimum<float>,
                          DataType::QuantisedAsymm8,
                          MinimumQueueDescriptor,
                          StringMapping::RefMinimumWorkload_Execute>;
} // armnn