aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/RefWorkloads/Broadcast.hpp
blob: b65b57f7a1a77890d03f1e6c233bcb1996acb7c7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

#pragma once

#include <armnn/Tensor.hpp>

#include <functional>
#include <vector>

namespace armnn
{

/// Walks an output tensor shape and applies a binary element-wise operation to
/// two input buffers. Each buffer, including the output, is advanced with its
/// own per-dimension stride.
///
/// The per-dimension sizes and strides are filled in by the constructor, which
/// is defined out of line.
/// NOTE(review): an input presumably gets a stride of 0 along any dimension on
/// which it is broadcast — confirm against the constructor implementation.
struct BroadcastLoop
{
    /// @param inShape0 Shape of the first input tensor.
    /// @param inShape1 Shape of the second input tensor.
    /// @param outShape Shape of the output tensor.
    BroadcastLoop(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape);

    /// @return The number of dimensions the loop iterates over.
    unsigned int GetNumDimensions() const
    {
        return static_cast<unsigned int>(m_DimData.size());
    }

    /// Applies @p operationFunc to every output element reachable from @p dimension inwards.
    ///
    /// For each index along @p dimension, this recurses into the next dimension.
    /// It then advances inData0, inData1 and outData by that dimension's
    /// m_Stride1, m_Stride2 and m_StrideOut respectively. Once every dimension
    /// has been consumed, it stores operationFunc(*inData0, *inData1) into *outData.
    ///
    /// @param operationFunc Binary callable whose result is assignable to U.
    /// @param dimension     Dimension to start iterating from.
    /// @param inData0       Pointer to the current element of the first input.
    /// @param inData1       Pointer to the current element of the second input.
    /// @param outData       Pointer to the current element of the output.
    template <typename T0, typename T1, typename U, typename Func>
    void Unroll(Func operationFunc,
                unsigned int dimension,
                const T0* inData0,
                const T1* inData1,
                U* outData)
    {
        // Recurse through a reference so the functor is not copied at every
        // recursion step (costly for callables such as std::function).
        UnrollImpl(operationFunc, dimension, inData0, inData1, outData);
    }

private:
    // Per-dimension iteration data: the extent of the dimension and how far each
    // pointer advances (in elements) per step along it.
    struct BroadcastDimensionData
    {
        unsigned int m_DimSize;
        unsigned int m_StrideOut;
        unsigned int m_Stride1;
        unsigned int m_Stride2;
    };

    // Recursive worker for Unroll; takes the functor by reference to avoid a copy per call.
    template <typename T0, typename T1, typename U, typename Func>
    void UnrollImpl(Func& operationFunc,
                    unsigned int dimension,
                    const T0* inData0,
                    const T1* inData1,
                    U* outData) const
    {
        if (dimension >= GetNumDimensions())
        {
            // All dimensions consumed: the pointers now address a single element each.
            *outData = operationFunc(*inData0, *inData1);
            return;
        }

        const BroadcastDimensionData& dimData = m_DimData[dimension];
        for (unsigned int i = 0; i < dimData.m_DimSize; i++)
        {
            UnrollImpl(operationFunc, dimension + 1, inData0, inData1, outData);

            inData0 += dimData.m_Stride1;
            inData1 += dimData.m_Stride2;
            outData += dimData.m_StrideOut;
        }
    }

    std::vector<BroadcastDimensionData> m_DimData;
};

} //namespace armnn