aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/Softmax.cpp
diff options
context:
space:
mode:
authorFrancis Murtagh <francis.murtagh@arm.com>2019-07-23 09:50:50 +0100
committerFrancis Murtagh <francis.murtagh@arm.com>2019-07-23 09:50:56 +0100
commit07f2121feeeeae36a7e67eeb8a6965df63b848f3 (patch)
tree836ebdec3ca1305f6533bde3f6410780a5173781 /src/backends/reference/workloads/Softmax.cpp
parent6f3785d4f612e06854ab63dffbd2cd3d484c2e14 (diff)
downloadarmnn-07f2121feeeeae36a7e67eeb8a6965df63b848f3.tar.gz
IVGCVSW-3536 Add Axis parameter to reference Softmax implementation
* Add Axis parameter to Softmax Descriptor * Add new reference implementation for Softmax using Axis parameter * Add unit tests to cover each Axis Change-Id: Iafac2275d2212337456f2b1b56b0f76f77fb9543 Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Diffstat (limited to 'src/backends/reference/workloads/Softmax.cpp')
-rw-r--r--src/backends/reference/workloads/Softmax.cpp83
1 files changed, 56 insertions, 27 deletions
diff --git a/src/backends/reference/workloads/Softmax.cpp b/src/backends/reference/workloads/Softmax.cpp
index 6cb219a6cc..ec4fdb8839 100644
--- a/src/backends/reference/workloads/Softmax.cpp
+++ b/src/backends/reference/workloads/Softmax.cpp
@@ -11,42 +11,71 @@
namespace armnn
{
+unsigned int GetNumElementsBetween(const TensorShape& shape,
+ unsigned int firstAxisInclusive,
+ unsigned int lastAxisExclusive)
+{
+ BOOST_ASSERT(0 <= firstAxisInclusive);
+ BOOST_ASSERT(firstAxisInclusive <= lastAxisExclusive);
+ BOOST_ASSERT(lastAxisExclusive <= shape.GetNumDimensions());
+ unsigned int count = 1;
+ for (unsigned int i = firstAxisInclusive; i < lastAxisExclusive; i++)
+ {
+ count *= shape[i];
+ }
+ return count;
+}
+
/// Computes the softmax function on some inputs, into outputs, with a shape given by tensorInfo.
-void Softmax(Decoder<float>& in, Encoder<float>& out, const TensorInfo& inputTensorInfo, float beta)
+void Softmax(Decoder<float>& in, Encoder<float>& out, const TensorInfo& inputTensorInfo, float beta, int axis)
{
- unsigned int numChannels = inputTensorInfo.GetShape()[1];
+ BOOST_ASSERT_MSG(axis < static_cast<int>(inputTensorInfo.GetNumDimensions()),
+ "Required axis index greater than number of dimensions.");
+ BOOST_ASSERT_MSG(axis >= -static_cast<int>(inputTensorInfo.GetNumDimensions()),
+ "Required axis index lower than negative of the number of dimensions");
+
+ unsigned int uAxis = axis < 0 ?
+ inputTensorInfo.GetNumDimensions() - static_cast<unsigned int>(abs(axis))
+ : static_cast<unsigned int>(axis);
- for (unsigned int n = 0; n < inputTensorInfo.GetShape()[0]; n++)
+ const TensorShape& inputShape = inputTensorInfo.GetShape();
+ const unsigned int outerSize = GetNumElementsBetween(inputShape, 0, uAxis);
+ const unsigned int axisSize = inputShape[uAxis];
+ const unsigned int innerSize = GetNumElementsBetween(inputShape, uAxis + 1, inputShape.GetNumDimensions());
+
+ for (unsigned int outer = 0; outer < outerSize; ++outer)
{
- // Find maximum channel.
- in[n * numChannels];
- float max = in.Get();
- for (unsigned int c = 1; c < numChannels; c++)
+ unsigned int inputBeginIdx = outer * axisSize * innerSize;
+ unsigned int inputEndIdx = inputBeginIdx + axisSize * innerSize;
+ unsigned int outputBeginIdx = outer * axisSize * innerSize;
+
+ for (unsigned int inner = 0; inner < innerSize; ++inner, ++inputBeginIdx, ++inputEndIdx, ++outputBeginIdx)
{
- in[n * numChannels + c];
- float val = in.Get();
- if (val > max)
+ // Find max
+ float maxValue = std::numeric_limits<float>::lowest();
+ for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize)
{
- max = val;
+ in[iter];
+ maxValue = std::max(maxValue, in.Get());
}
- }
- // Exponentiate all values and sum.
- std::vector<float> exponentials(numChannels);
- float sum = 0.0f;
- for (unsigned int c = 0; c < numChannels; c++)
- {
- in[n * numChannels + c];
- float val = in.Get();
- exponentials[c] = expf((val - max) * beta);
- sum += exponentials[c];
- }
+ // Compute sum
+ float sum = 0.0f;
+ for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize)
+ {
+ in[iter];
+ sum += std::exp((in.Get() - maxValue) * beta);
+ }
- // Divide exponentials by sum to give outputs.
- for (unsigned int c = 0; c < numChannels; c++)
- {
- out[n * numChannels + c];
- out.Set(exponentials[c] / sum);
+ // Compute result
+ unsigned int outputIter = outputBeginIdx;
+ out[outputIter];
+ for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize, outputIter += innerSize)
+ {
+ out[outputIter];
+ in[iter];
+ out.Set(std::exp((in.Get() - maxValue) * beta) / sum);
+ }
}
}
}