aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefElementwiseWorkload.hpp
blob: 4b108e436375f636ffce17d957c92f4b1625375f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/Types.hpp>
#include <armnn/backends/Workload.hpp>
#include <armnn/backends/WorkloadData.hpp>
#include "BaseIterator.hpp"
#include "ElementwiseFunction.hpp"
#include "Maximum.hpp"
#include "Minimum.hpp"
#include "StringMapping.hpp"

#include <functional>
#include <vector>

namespace armnn
{

/// Reference-backend workload template shared by the elementwise operations
/// declared in this header (addition, subtraction, multiplication, division,
/// maximum, minimum).
///
/// @tparam Functor          Operation applied to each pair of elements, e.g.
///                          std::plus<float>. The input/output element types are
///                          taken from ElementwiseBinaryFunction<Functor>.
/// @tparam ParentDescriptor Queue descriptor type handed to BaseWorkload, e.g.
///                          AdditionQueueDescriptor.
/// @tparam DebugString      StringMapping id associated with this workload;
///                          presumably used as the profiling/debug label when
///                          executing (definition is out of line — confirm there).
template <class Functor, class ParentDescriptor, typename armnn::StringMapping::Id DebugString>
class RefElementwiseWorkload : public BaseWorkload<ParentDescriptor>
{
public:
    // Element types consumed and produced by the wrapped functor.
    using InType  = typename ElementwiseBinaryFunction<Functor>::InType;
    using OutType = typename ElementwiseBinaryFunction<Functor>::OutType;

    // m_Data lives in a dependent base class, so unqualified lookup inside this
    // template would not find it; the using-declaration brings it into scope
    // (and, sitting in the public section, makes it publicly accessible).
    using BaseWorkload<ParentDescriptor>::m_Data;

    RefElementwiseWorkload(const ParentDescriptor& descriptor, const WorkloadInfo& info);

    // Synchronous entry point; presumably operates on the tensor handles held
    // in m_Data (definition not visible here).
    void Execute() const override;

    // Entry point that receives its working memory from the caller; presumably
    // operates on the tensor handles carried by workingMemDescriptor.
    void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;

private:
    // Implementation overload taking explicit input/output handle lists; the
    // public entry points are assumed to forward here (definition is out of line).
    // NOTE(review): the vectors are passed by value; switching to const
    // references must be mirrored in the out-of-line definition.
    void Execute(std::vector<ITensorHandle*> inputHandles,
                 std::vector<ITensorHandle*> outputHandles) const;
};

/// Reference addition workload: applies std::plus<DataType> elementwise,
/// driven by an AdditionQueueDescriptor. DataType defaults to float.
template <class DataType = float>
using RefAdditionWorkload = RefElementwiseWorkload<std::plus<DataType>,
                                                   AdditionQueueDescriptor,
                                                   StringMapping::RefAdditionWorkload_Execute>;

/// Reference subtraction workload: applies std::minus<DataType> elementwise,
/// driven by a SubtractionQueueDescriptor. DataType defaults to float.
template <class DataType = float>
using RefSubtractionWorkload = RefElementwiseWorkload<std::minus<DataType>,
                                                      SubtractionQueueDescriptor,
                                                      StringMapping::RefSubtractionWorkload_Execute>;

/// Reference multiplication workload: applies std::multiplies<DataType>
/// elementwise, driven by a MultiplicationQueueDescriptor. DataType defaults to float.
template <class DataType = float>
using RefMultiplicationWorkload = RefElementwiseWorkload<std::multiplies<DataType>,
                                                         MultiplicationQueueDescriptor,
                                                         StringMapping::RefMultiplicationWorkload_Execute>;

/// Reference division workload: applies std::divides<DataType> elementwise,
/// driven by a DivisionQueueDescriptor. DataType defaults to float.
/// NOTE(review): for an integer DataType, std::divides performs integer
/// division, where a zero divisor is undefined behaviour — confirm that callers
/// or validation guard against it.
template <class DataType = float>
using RefDivisionWorkload = RefElementwiseWorkload<std::divides<DataType>,
                                                   DivisionQueueDescriptor,
                                                   StringMapping::RefDivisionWorkload_Execute>;

/// Reference maximum workload: applies armnn::maximum<DataType> (declared in
/// Maximum.hpp) elementwise, driven by a MaximumQueueDescriptor.
/// DataType defaults to float.
template <class DataType = float>
using RefMaximumWorkload = RefElementwiseWorkload<armnn::maximum<DataType>,
                                                  MaximumQueueDescriptor,
                                                  StringMapping::RefMaximumWorkload_Execute>;

/// Reference minimum workload: applies armnn::minimum<DataType> (declared in
/// Minimum.hpp) elementwise, driven by a MinimumQueueDescriptor.
/// DataType defaults to float.
template <class DataType = float>
using RefMinimumWorkload = RefElementwiseWorkload<armnn::minimum<DataType>,
                                                  MinimumQueueDescriptor,
                                                  StringMapping::RefMinimumWorkload_Execute>;

} // armnn