aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/LogSoftmax.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/LogSoftmax.cpp')
-rw-r--r--src/backends/reference/workloads/LogSoftmax.cpp91
1 files changed, 91 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/LogSoftmax.cpp b/src/backends/reference/workloads/LogSoftmax.cpp
new file mode 100644
index 0000000000..3fa3dc0d8c
--- /dev/null
+++ b/src/backends/reference/workloads/LogSoftmax.cpp
@@ -0,0 +1,91 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "LogSoftmax.hpp"
+
+#include <TensorUtils.hpp>
+
+#include <cmath>
+
+#include <boost/assert.hpp>
+#include <boost/core/ignore_unused.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+
+namespace
+{
+
+inline bool ValidateAxis(int axis, unsigned int numDimensions)
+{
+ const int sNumDimensions = boost::numeric_cast<int>(numDimensions);
+ return axis < sNumDimensions && axis >= -sNumDimensions;
+}
+
+} // anonymous namespace
+
+namespace armnn
+{
+
+void LogSoftmax(Decoder<float>& input,
+ Encoder<float>& output,
+ const TensorInfo& inputInfo,
+ const LogSoftmaxDescriptor& descriptor)
+{
+ const unsigned int numDimensions = inputInfo.GetNumDimensions();
+
+ bool axisIsValid = ValidateAxis(descriptor.m_Axis, numDimensions);
+ BOOST_ASSERT_MSG(axisIsValid,
+ "Axis index is not in range [-numDimensions, numDimensions).");
+ boost::ignore_unused(axisIsValid);
+
+ unsigned int uAxis = descriptor.m_Axis < 0 ?
+ numDimensions - boost::numeric_cast<unsigned int>(std::abs(descriptor.m_Axis)) :
+ boost::numeric_cast<unsigned int>(descriptor.m_Axis);
+
+ const TensorShape& inputShape = inputInfo.GetShape();
+ const unsigned int outerSize = armnnUtils::GetNumElementsBetween(inputShape, 0, uAxis);
+ const unsigned int axisSize = inputShape[uAxis];
+ const unsigned int innerSize = armnnUtils::GetNumElementsBetween(inputShape,
+ uAxis + 1,
+ inputShape.GetNumDimensions());
+
+ for (unsigned int outer = 0; outer < outerSize; ++outer)
+ {
+ for (unsigned int inner = 0; inner < innerSize; ++inner)
+ {
+ // Find max
+ input[outer * axisSize * innerSize + inner];
+ float maxValue = input.Get();
+ for (unsigned int i = 1u; i < axisSize; ++i)
+ {
+ input[(outer * axisSize + i) * innerSize + inner];
+ maxValue = std::max(maxValue, input.Get());
+ }
+
+ // Compute sum
+ float sum = 0.0f;
+ for (unsigned int i = 0u; i < axisSize; ++i)
+ {
+ input[(outer * axisSize + i) * innerSize + inner];
+ sum += std::exp((input.Get() - maxValue) * descriptor.m_Beta);
+ }
+
+ // Compute log sum
+ const float logSum = std::log(sum);
+
+ // Compute result
+ for (unsigned int i = 0u; i < axisSize; ++i)
+ {
+ const unsigned int index = (outer * axisSize + i) * innerSize + inner;
+
+ input [index];
+ output[index];
+
+ output.Set((input.Get() - maxValue) * descriptor.m_Beta - logSum);
+ }
+ }
+ }
+}
+
+} // namespace armnn