aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/Broadcast.hpp
blob: 5bf6be8939c27dd422a0bfb242a1e61c2d055952 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "BaseIterator.hpp"
#include <armnn/Tensor.hpp>

#include <functional>
#include <vector>

namespace armnn
{

// Drives an element-wise binary operation over two input iterators and one
// output iterator, one nested loop per entry in m_DimData. Each entry supplies
// the trip count for its loop level and the amount each iterator is advanced
// per step at that level.
struct BroadcastLoop
{
    // Declared here, defined out of line.
    // NOTE(review): presumably fills m_DimData from the three shapes - confirm
    // against the definition in the corresponding .cpp.
    BroadcastLoop(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape);

    // Returns the number of per-dimension entries held in m_DimData.
    unsigned int GetNumDimensions() const
    {
        return static_cast<unsigned int>(m_DimData.size());
    }

    // Recursively visits every position described by m_DimData, starting at
    // loop level `dimension`, and writes operationFunc(in0, in1) to the output.
    //
    // operationFunc : callable invoked as operationFunc(inData0.Get(), inData1.Get());
    //                 its result is handed to outData.Set().
    // dimension     : loop level to start from; callers pass 0 to cover all levels.
    //                 Any value >= GetNumDimensions() processes the single current element.
    // inData0/1     : input iterators; must support Get(), += and -= with unsigned int.
    // outData       : output iterator; must support Set(), += and -= with unsigned int.
    //
    // On return all three iterators have been moved back by exactly the amount
    // this call advanced them.
    template <typename Func, typename DecoderOp, typename EncoderOp>
    void Unroll(Func operationFunc,
                unsigned int dimension,
                DecoderOp& inData0,
                DecoderOp& inData1,
                EncoderOp& outData) const
    {
        // Past the innermost level: compute and store one element.
        if (dimension >= GetNumDimensions())
        {
            outData.Set(operationFunc(inData0.Get(), inData1.Get()));
            return;
        }

        // m_DimData is never modified while unrolling, so the entry for this
        // level can be looked up once and reused across all iterations.
        const BroadcastDimensionData& dim = m_DimData[dimension];

        for (unsigned int i = 0; i < dim.m_DimSize; ++i)
        {
            // Handle everything below this level at the current position...
            Unroll(operationFunc, dimension + 1, inData0, inData1, outData);

            // ...then step each iterator to the next position at this level.
            inData0 += dim.m_Stride1;
            inData1 += dim.m_Stride2;
            outData += dim.m_StrideOut;
        }

        // Move the iterators back to the start. The loop advanced each one
        // m_DimSize times by its stride, so the total is their product.
        inData0 -= dim.m_DimSize * dim.m_Stride1;
        inData1 -= dim.m_DimSize * dim.m_Stride2;
        outData -= dim.m_DimSize * dim.m_StrideOut;
    }

private:
    // Struct to hold the dimension data.
    struct BroadcastDimensionData
    {
        unsigned int m_DimSize;   // Number of iterations performed at this loop level.
        unsigned int m_StrideOut; // Per-iteration advance applied to the output iterator.
        unsigned int m_Stride1;   // Per-iteration advance applied to inData0 (first input).
        unsigned int m_Stride2;   // Per-iteration advance applied to inData1 (second input).
    };

    // One entry per loop level, outermost first (Unroll recurses from index 0 upwards).
    std::vector<BroadcastDimensionData> m_DimData;
};

} //namespace armnn