diff options
author | Francis Murtagh <francis.murtagh@arm.com> | 2019-07-23 09:50:50 +0100 |
---|---|---|
committer | Francis Murtagh <francis.murtagh@arm.com> | 2019-07-23 09:50:56 +0100 |
commit | 07f2121feeeeae36a7e67eeb8a6965df63b848f3 (patch) | |
tree | 836ebdec3ca1305f6533bde3f6410780a5173781 /src/backends/reference/workloads/Softmax.cpp | |
parent | 6f3785d4f612e06854ab63dffbd2cd3d484c2e14 (diff) | |
download | armnn-07f2121feeeeae36a7e67eeb8a6965df63b848f3.tar.gz |
IVGCVSW-3536 Add Axis parameter to reference Softmax implementation
* Add Axis parameter to Softmax Descriptor
* Add new reference implementation for Softmax using Axis parameter
* Add unit tests to cover each Axis
Change-Id: Iafac2275d2212337456f2b1b56b0f76f77fb9543
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Diffstat (limited to 'src/backends/reference/workloads/Softmax.cpp')
-rw-r--r-- | src/backends/reference/workloads/Softmax.cpp | 83 |
1 file changed, 56 insertions, 27 deletions
diff --git a/src/backends/reference/workloads/Softmax.cpp b/src/backends/reference/workloads/Softmax.cpp index 6cb219a6cc..ec4fdb8839 100644 --- a/src/backends/reference/workloads/Softmax.cpp +++ b/src/backends/reference/workloads/Softmax.cpp @@ -11,42 +11,71 @@ namespace armnn { +unsigned int GetNumElementsBetween(const TensorShape& shape, + unsigned int firstAxisInclusive, + unsigned int lastAxisExclusive) +{ + BOOST_ASSERT(0 <= firstAxisInclusive); + BOOST_ASSERT(firstAxisInclusive <= lastAxisExclusive); + BOOST_ASSERT(lastAxisExclusive <= shape.GetNumDimensions()); + unsigned int count = 1; + for (unsigned int i = firstAxisInclusive; i < lastAxisExclusive; i++) + { + count *= shape[i]; + } + return count; +} + /// Computes the softmax function on some inputs, into outputs, with a shape given by tensorInfo. -void Softmax(Decoder<float>& in, Encoder<float>& out, const TensorInfo& inputTensorInfo, float beta) +void Softmax(Decoder<float>& in, Encoder<float>& out, const TensorInfo& inputTensorInfo, float beta, int axis) { - unsigned int numChannels = inputTensorInfo.GetShape()[1]; + BOOST_ASSERT_MSG(axis < static_cast<int>(inputTensorInfo.GetNumDimensions()), + "Required axis index greater than number of dimensions."); + BOOST_ASSERT_MSG(axis >= -static_cast<int>(inputTensorInfo.GetNumDimensions()), + "Required axis index lower than negative of the number of dimensions"); + + unsigned int uAxis = axis < 0 ? 
+ inputTensorInfo.GetNumDimensions() - static_cast<unsigned int>(abs(axis)) + : static_cast<unsigned int>(axis); - for (unsigned int n = 0; n < inputTensorInfo.GetShape()[0]; n++) + const TensorShape& inputShape = inputTensorInfo.GetShape(); + const unsigned int outerSize = GetNumElementsBetween(inputShape, 0, uAxis); + const unsigned int axisSize = inputShape[uAxis]; + const unsigned int innerSize = GetNumElementsBetween(inputShape, uAxis + 1, inputShape.GetNumDimensions()); + + for (unsigned int outer = 0; outer < outerSize; ++outer) { - // Find maximum channel. - in[n * numChannels]; - float max = in.Get(); - for (unsigned int c = 1; c < numChannels; c++) + unsigned int inputBeginIdx = outer * axisSize * innerSize; + unsigned int inputEndIdx = inputBeginIdx + axisSize * innerSize; + unsigned int outputBeginIdx = outer * axisSize * innerSize; + + for (unsigned int inner = 0; inner < innerSize; ++inner, ++inputBeginIdx, ++inputEndIdx, ++outputBeginIdx) { - in[n * numChannels + c]; - float val = in.Get(); - if (val > max) + // Find max + float maxValue = std::numeric_limits<float>::lowest(); + for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize) { - max = val; + in[iter]; + maxValue = std::max(maxValue, in.Get()); } - } - // Exponentiate all values and sum. - std::vector<float> exponentials(numChannels); - float sum = 0.0f; - for (unsigned int c = 0; c < numChannels; c++) - { - in[n * numChannels + c]; - float val = in.Get(); - exponentials[c] = expf((val - max) * beta); - sum += exponentials[c]; - } + // Compute sum + float sum = 0.0f; + for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize) + { + in[iter]; + sum += std::exp((in.Get() - maxValue) * beta); + } - // Divide exponentials by sum to give outputs. 
- for (unsigned int c = 0; c < numChannels; c++) - { - out[n * numChannels + c]; - out.Set(exponentials[c] / sum); + // Compute result + unsigned int outputIter = outputBeginIdx; + out[outputIter]; + for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize, outputIter += innerSize) + { + out[outputIter]; + in[iter]; + out.Set(std::exp((in.Get() - maxValue) * beta) / sum); + } } } } |