diff options
Diffstat (limited to 'src/backends/reference/workloads/LogSoftmax.cpp')
-rw-r--r-- | src/backends/reference/workloads/LogSoftmax.cpp | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/LogSoftmax.cpp b/src/backends/reference/workloads/LogSoftmax.cpp new file mode 100644 index 0000000000..3fa3dc0d8c --- /dev/null +++ b/src/backends/reference/workloads/LogSoftmax.cpp @@ -0,0 +1,91 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "LogSoftmax.hpp" + +#include <TensorUtils.hpp> + +#include <cmath> + +#include <boost/assert.hpp> +#include <boost/core/ignore_unused.hpp> +#include <boost/numeric/conversion/cast.hpp> + +namespace +{ + +inline bool ValidateAxis(int axis, unsigned int numDimensions) +{ + const int sNumDimensions = boost::numeric_cast<int>(numDimensions); + return axis < sNumDimensions && axis >= -sNumDimensions; +} + +} // anonymous namespace + +namespace armnn +{ + +void LogSoftmax(Decoder<float>& input, + Encoder<float>& output, + const TensorInfo& inputInfo, + const LogSoftmaxDescriptor& descriptor) +{ + const unsigned int numDimensions = inputInfo.GetNumDimensions(); + + bool axisIsValid = ValidateAxis(descriptor.m_Axis, numDimensions); + BOOST_ASSERT_MSG(axisIsValid, + "Axis index is not in range [-numDimensions, numDimensions)."); + boost::ignore_unused(axisIsValid); + + unsigned int uAxis = descriptor.m_Axis < 0 ? + numDimensions - boost::numeric_cast<unsigned int>(std::abs(descriptor.m_Axis)) : + boost::numeric_cast<unsigned int>(descriptor.m_Axis); + + const TensorShape& inputShape = inputInfo.GetShape(); + const unsigned int outerSize = armnnUtils::GetNumElementsBetween(inputShape, 0, uAxis); + const unsigned int axisSize = inputShape[uAxis]; + const unsigned int innerSize = armnnUtils::GetNumElementsBetween(inputShape, + uAxis + 1, + inputShape.GetNumDimensions()); + + for (unsigned int outer = 0; outer < outerSize; ++outer) + { + for (unsigned int inner = 0; inner < innerSize; ++inner) + { + // Find max + input[outer * axisSize * innerSize + inner]; + float maxValue = input.Get(); + for (unsigned int i = 1u; i < axisSize; ++i) + { + input[(outer * axisSize + i) * innerSize + inner]; + maxValue = std::max(maxValue, input.Get()); + } + + // Compute sum + float sum = 0.0f; + for (unsigned int i = 0u; i < axisSize; ++i) + { + input[(outer * axisSize + i) * innerSize + inner]; + sum += std::exp((input.Get() - maxValue) * descriptor.m_Beta); + } + + // Compute log sum + const float logSum = std::log(sum); + + // Compute result + for (unsigned int i = 0u; i < axisSize; ++i) + { + const unsigned int index = (outer * axisSize + i) * innerSize + inner; + + input [index]; + output[index]; + + output.Set((input.Get() - maxValue) * descriptor.m_Beta - logSum); + } + } + } +} + +} // namespace armnn |