diff options
Diffstat (limited to 'src/backends/aclCommon/ArmComputeUtils.hpp')
-rw-r--r-- | src/backends/aclCommon/ArmComputeUtils.hpp | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/src/backends/aclCommon/ArmComputeUtils.hpp b/src/backends/aclCommon/ArmComputeUtils.hpp index 0ee13b3e7f..eae152dc20 100644 --- a/src/backends/aclCommon/ArmComputeUtils.hpp +++ b/src/backends/aclCommon/ArmComputeUtils.hpp @@ -153,18 +153,21 @@ inline arm_compute::InterpolationPolicy ConvertResizeMethodToAclInterpolationPol template<typename T> inline T ComputeSoftmaxAclAxis(const SoftmaxDescriptor& softmaxDesc, const armnn::TensorInfo& tensor) { - // Detect the Android default value of -1 and return the ACL default value of 1. + // Detect the Android default value of -1 and return the ACL default value of 0. if (softmaxDesc.m_Axis == -1) { - return 1; + return 0; } - unsigned int dim = tensor.GetNumDimensions(); + unsigned int dim = tensor.GetNumDimensions(); ARMNN_ASSERT(dim != 0); // Currently ArmNN support axis 1. - return static_cast<T>(dim) - 1; + auto aclAxis = (static_cast<T>(dim) - 1); + aclAxis = aclAxis > 0 ? aclAxis -1 : aclAxis; + + return aclAxis; } inline std::set<unsigned int> ComputeSplitAxis(const armnn::SplitterDescriptor& desc, const TensorShape& input) |