aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/LogSoftmax.cpp
blob: 3fa3dc0d8c0f8e218115e4c363cb0b0e27498b0d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
//
// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "LogSoftmax.hpp"

#include <TensorUtils.hpp>

#include <algorithm>
#include <cmath>

#include <boost/assert.hpp>
#include <boost/core/ignore_unused.hpp>
#include <boost/numeric/conversion/cast.hpp>

namespace
{

// Returns true when axis lies in the half-open range [-numDimensions, numDimensions),
// and false otherwise.
inline bool ValidateAxis(int axis, unsigned int numDimensions)
{
    // Work with a signed copy of the rank so it can be negated for the lower bound.
    const int rank = boost::numeric_cast<int>(numDimensions);

    // Upper bound is exclusive.
    if (axis >= rank)
    {
        return false;
    }

    // Lower bound is inclusive.
    return -rank <= axis;
}

} // anonymous namespace

namespace armnn
{

// Computes log-softmax along a single axis of the input tensor.
//
// For every one-dimensional slice taken along the chosen axis the result is
//
//     output[i] = (input[i] - max) * beta - log(sum_j exp((input[j] - max) * beta))
//
// where max is the largest input value in that slice and beta is descriptor.m_Beta.
//
// input      - decoder over the source data; repositioned with operator[] and read with Get().
// output     - encoder over the destination data; repositioned with operator[] and written with Set().
// inputInfo  - provides the number of dimensions and the shape used to derive the loop extents.
// descriptor - m_Axis selects the axis (a negative value is offset by the number of
//              dimensions), m_Beta scales the shifted input values.
//
// The axis is only checked with BOOST_ASSERT_MSG; an out-of-range axis is a caller error.
// If the selected axis has extent zero there is nothing to compute and the function
// returns without touching input or output.
void LogSoftmax(Decoder<float>& input,
                Encoder<float>& output,
                const TensorInfo& inputInfo,
                const LogSoftmaxDescriptor& descriptor)
{
    const unsigned int numDimensions = inputInfo.GetNumDimensions();

    const bool axisIsValid = ValidateAxis(descriptor.m_Axis, numDimensions);
    BOOST_ASSERT_MSG(axisIsValid,
        "Axis index is not in range [-numDimensions, numDimensions).");
    boost::ignore_unused(axisIsValid);

    // Map a negative axis onto its non-negative equivalent.
    const unsigned int uAxis = descriptor.m_Axis < 0  ?
        numDimensions - boost::numeric_cast<unsigned int>(std::abs(descriptor.m_Axis)) :
        boost::numeric_cast<unsigned int>(descriptor.m_Axis);

    // View the tensor as [outerSize, axisSize, innerSize].
    const TensorShape& inputShape = inputInfo.GetShape();
    const unsigned int outerSize  = armnnUtils::GetNumElementsBetween(inputShape, 0, uAxis);
    const unsigned int axisSize   = inputShape[uAxis];
    const unsigned int innerSize  = armnnUtils::GetNumElementsBetween(inputShape,
                                                                      uAxis + 1,
                                                                      inputShape.GetNumDimensions());

    // With an empty axis there are no elements to read; bail out before the
    // "find max" step would otherwise decode an element that does not exist.
    if (axisSize == 0u)
    {
        return;
    }

    const float beta = descriptor.m_Beta;

    for (unsigned int outer = 0; outer < outerSize; ++outer)
    {
        for (unsigned int inner = 0; inner < innerSize; ++inner)
        {
            // Flat index of element 0 of this slice; element i sits at base + i * innerSize.
            const unsigned int base = outer * axisSize * innerSize + inner;

            // Find max
            input[base];
            float maxValue = input.Get();
            for (unsigned int i = 1u; i < axisSize; ++i)
            {
                input[base + i * innerSize];
                maxValue = std::max(maxValue, input.Get());
            }

            // Compute sum of the exponentials of the shifted, scaled values
            float sum = 0.0f;
            for (unsigned int i = 0u; i < axisSize; ++i)
            {
                input[base + i * innerSize];
                sum += std::exp((input.Get() - maxValue) * beta);
            }

            // Compute log sum
            const float logSum = std::log(sum);

            // Compute result
            for (unsigned int i = 0u; i < axisSize; ++i)
            {
                const unsigned int index = base + i * innerSize;

                input [index];
                output[index];

                output.Set((input.Get() - maxValue) * beta - logSum);
            }
        }
    }
}

} // namespace armnn