aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/Softmax.cpp
blob: ec4fdb8839fec8803b2aa93cb7a9fd0d477caa3f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "Softmax.hpp"

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

namespace armnn
{

/// Returns the product of the dimension sizes of @p shape over the half-open axis range
/// [@p firstAxisInclusive, @p lastAxisExclusive). An empty range (first == last) yields 1.
///
/// @param shape              Shape whose dimensions are multiplied.
/// @param firstAxisInclusive First axis included in the product.
/// @param lastAxisExclusive  One past the last axis included; must not exceed shape.GetNumDimensions().
unsigned int GetNumElementsBetween(const TensorShape& shape,
                                   unsigned int firstAxisInclusive,
                                   unsigned int lastAxisExclusive)
{
    // No lower-bound check on firstAxisInclusive: it is unsigned, so `0 <= firstAxisInclusive`
    // is always true and only provokes -Wtype-limits warnings.
    BOOST_ASSERT(firstAxisInclusive <= lastAxisExclusive);
    BOOST_ASSERT(lastAxisExclusive <= shape.GetNumDimensions());
    unsigned int count = 1;
    for (unsigned int i = firstAxisInclusive; i < lastAxisExclusive; i++)
    {
        count *= shape[i];
    }
    return count;
}

/// Computes the softmax function on some inputs, into outputs, with a shape given by tensorInfo.
///
/// For every 1-D slice of the input taken along @p axis, writes
///     out[i] = exp((in[i] - ref) * beta) / sum_j exp((in[j] - ref) * beta)
/// where ref is a per-slice reference value subtracted purely for numerical stability
/// (it cancels out mathematically). Output values are written at the same flat indices
/// as the corresponding input values.
///
/// @param in               Decoder over the input tensor; positioned with operator[] and read with Get().
/// @param out              Encoder over the output tensor; positioned with operator[] and written with Set().
/// @param inputTensorInfo  Provides the input shape used to derive the slice layout.
/// @param beta             Scale applied to the (shifted) inputs before exponentiation.
/// @param axis             Axis to normalise over; negative values count back from the last dimension.
void Softmax(Decoder<float>& in, Encoder<float>& out, const TensorInfo& inputTensorInfo, float beta, int axis)
{
    const unsigned int numDims = inputTensorInfo.GetNumDimensions();
    BOOST_ASSERT_MSG(axis < static_cast<int>(numDims),
                     "Required axis index greater than number of dimensions.");
    BOOST_ASSERT_MSG(axis >= -static_cast<int>(numDims),
                     "Required axis index lower than negative of the number of dimensions");

    // Normalise a negative axis (e.g. -1 is the last dimension). Signed addition avoids the
    // unqualified abs(), which needs <cstdlib> and overflows for INT_MIN when asserts are off.
    const unsigned int uAxis = axis < 0
                               ? static_cast<unsigned int>(axis + static_cast<int>(numDims))
                               : static_cast<unsigned int>(axis);

    const TensorShape& inputShape = inputTensorInfo.GetShape();
    const unsigned int outerSize  = GetNumElementsBetween(inputShape, 0, uAxis);
    const unsigned int axisSize   = inputShape[uAxis];
    const unsigned int innerSize  = GetNumElementsBetween(inputShape, uAxis + 1, numDims);

    // Shift by the slice maximum when beta >= 0 and by the slice minimum when beta < 0, so every
    // exponent (x - ref) * beta is <= 0 and std::exp cannot overflow to inf (which would turn the
    // result into inf/inf = NaN). For beta >= 0 this is exactly the classic max-subtraction.
    const bool shiftByMax = beta >= 0.0f;

    for (unsigned int outer = 0; outer < outerSize; ++outer)
    {
        const unsigned int outerOffset = outer * axisSize * innerSize;

        for (unsigned int inner = 0; inner < innerSize; ++inner)
        {
            // Consecutive elements of this slice lie innerSize apart, starting at sliceBegin.
            const unsigned int sliceBegin = outerOffset + inner;
            const unsigned int sliceEnd   = sliceBegin + axisSize * innerSize;

            // Find the stabilising reference value (max or min of the slice).
            float reference = shiftByMax ? std::numeric_limits<float>::lowest()
                                         : std::numeric_limits<float>::max();
            for (unsigned int idx = sliceBegin; idx < sliceEnd; idx += innerSize)
            {
                in[idx];
                const float value = in.Get();
                reference = shiftByMax ? std::max(reference, value) : std::min(reference, value);
            }

            // Compute the normalising sum.
            float sum = 0.0f;
            for (unsigned int idx = sliceBegin; idx < sliceEnd; idx += innerSize)
            {
                in[idx];
                sum += std::exp((in.Get() - reference) * beta);
            }

            // Write the normalised values.
            for (unsigned int idx = sliceBegin; idx < sliceEnd; idx += innerSize)
            {
                out[idx];
                in[idx];
                out.Set(std::exp((in.Get() - reference) * beta) / sum);
            }
        }
    }
}

} //namespace armnn