aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefChannelShuffleWorkload.cpp
blob: 6571715c635e65905595b30b1cee959bfe8772a6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <backendsCommon/test/DataTypeUtils.hpp>
#include <armnn/backends/ITensorHandleFactory.hpp>
#include <armnnUtils/Transpose.hpp>
#include "RefChannelShuffleWorkload.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"
#include "Decoders.hpp"
#include "Encoders.hpp"

namespace armnn
{
// Synchronous entry point: shuffles the tensors bound to this workload's queue descriptor (m_Data).
void RefChannelShuffleWorkload::Execute() const
{
    const auto& boundInputs  = m_Data.m_Inputs;
    const auto& boundOutputs = m_Data.m_Outputs;
    Execute(boundInputs, boundOutputs);
}

// Asynchronous entry point: shuffles the tensor handles supplied by the caller's working memory
// descriptor instead of the handles stored in m_Data.
void RefChannelShuffleWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor)
{
    const auto& asyncInputs  = workingMemDescriptor.m_Inputs;
    const auto& asyncOutputs = workingMemDescriptor.m_Outputs;
    Execute(asyncInputs, asyncOutputs);
}

// Reference implementation for channel shuffle, adapted from
// https://android.googlesource.com/platform/frameworks/ml/+/refs/heads/master/nn/common/operations/ChannelShuffle.cpp
//
// The input is viewed as [outerSize, numChannels, innerSize] around the shuffle axis (m_Axis), and the
// channel dimension is split into numGroups groups of groupSize channels each. For every
// (outer, inner) position, output channel (i * numGroups + j) receives input channel (i + j * groupSize).
//
// @param inputs   inputs[0] is the tensor to shuffle.
// @param outputs  outputs[0] receives the shuffled tensor; it is indexed using the input's shape.
void RefChannelShuffleWorkload::Execute(std::vector<ITensorHandle*> inputs,
                                        std::vector<ITensorHandle*> outputs) const
{
    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefChannelShuffleWorkload_Execute");

    const TensorInfo& inputInfo  = GetTensorInfo(inputs[0]);
    const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
    std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputInfo, inputs[0]->Map());
    Decoder<float>& decoder = *decoderPtr;

    std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
    Encoder<float>& encoder = *encoderPtr;

    // Product of the dimensions in the half-open axis range [startAxis, lastAxis); 1 for an empty range.
    auto getNumberOfElements = [](const TensorShape& shape, uint32_t startAxis, uint32_t lastAxis)
    {
        uint32_t count = 1;
        for (uint32_t i = startAxis; i < lastAxis; ++i)
        {
            count *= shape[i];
        }
        return count;
    };

    const TensorShape tensorShape = inputInfo.GetShape();
    const uint32_t channelsAxis   = m_Data.m_Parameters.m_Axis; // axis to perform channel shuffle on
    const uint32_t numChannels    = tensorShape[channelsAxis];

    // NOTE(review): assumes m_NumGroups is non-zero and divides numChannels evenly; presumably this is
    // enforced when the queue descriptor is validated - confirm.
    const uint32_t numGroups = m_Data.m_Parameters.m_NumGroups;
    const uint32_t groupSize = numChannels / numGroups;

    const uint32_t outerSize = getNumberOfElements(tensorShape, 0, channelsAxis);
    const uint32_t innerSize = getNumberOfElements(tensorShape, channelsAxis + 1, tensorShape.GetNumDimensions());

    for (uint32_t outer = 0; outer < outerSize; ++outer)
    {
        for (uint32_t inner = 0; inner < innerSize; ++inner)
        {
            // Flat index of channel 0 at this (outer, inner) position; consecutive channels are innerSize apart.
            const uint32_t base = outer * numChannels * innerSize + inner;
            for (uint32_t i = 0; i < groupSize; ++i)
            {
                for (uint32_t j = 0; j < numGroups; ++j)
                {
                    const uint32_t inputChannel  = i + j * groupSize;
                    const uint32_t outputChannel = i * numGroups + j;

                    // operator[] seeks each iterator to an absolute element index. Unlike the previous
                    // relative "+= innerSize after every write" scheme, this never moves the underlying
                    // pointer beyond the end of the tensor buffer.
                    decoder[base + inputChannel * innerSize];
                    encoder[base + outputChannel * innerSize];
                    encoder.Set(decoder.Get());
                }
            }
        }
    }
}
}