From 369d8fcd93213e37781a58fb9805a59bf14691db Mon Sep 17 00:00:00 2001 From: Nikhil Raj Date: Thu, 24 Nov 2022 13:12:36 +0000 Subject: IVGCVSW-4926 Add support in CpuRef implementation for Gather for axis different to 0 !android-nn-driver:8727 Signed-off-by: Nikhil Raj Signed-off-by: Matthew Sloyan Signed-off-by: Teresa Charlin Change-Id: I4336007ad5a8552f7893ce6253f93cf9d1f5474f --- src/backends/reference/RefLayerSupport.cpp | 8 +--- src/backends/reference/test/RefLayerTests.cpp | 5 ++- src/backends/reference/workloads/Gather.cpp | 58 ++++++++++++++++----------- 3 files changed, 41 insertions(+), 30 deletions(-) (limited to 'src/backends/reference') diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 669c91d628..a5015a7376 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -1573,11 +1573,7 @@ bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0, DataType::Signed32 }; - if (descriptor.m_Axis != 0) - { - reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n"); - supported &= false; - } + IgnoreUnused(descriptor); supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported, "Reference Gather: input type not supported"); diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 750da8fba2..0e228dbea9 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -2215,6 +2215,9 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesFloat16, ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesUint8, GatherMultiDimParamsMultiDimIndicesUint8Test) ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesInt16, GatherMultiDimParamsMultiDimIndicesInt16Test) ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesInt32, GatherMultiDimParamsMultiDimIndicesInt32Test) +ARMNN_AUTO_TEST_CASE_WITH_THF(Gather1dParamsAxis, Gather1dParamsAxisTest) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesAxis1, GatherMultiDimParamsMultiDimIndicesAxis1Test) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesAxis2, GatherMultiDimParamsMultiDimIndicesAxis2Test) // GatherNd diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp index 1624052819..48039432a5 100644 --- a/src/backends/reference/workloads/Gather.cpp +++ b/src/backends/reference/workloads/Gather.cpp @@ -1,14 +1,11 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "Gather.hpp" -#include "RefWorkloadUtils.hpp" - #include -#include #include namespace armnn @@ -20,40 +17,55 @@ void Gather(const TensorInfo& paramsInfo, Decoder& params, const int32_t* indices, Encoder& output, - const int32_t axis) + const int32_t axis_int) { IgnoreUnused(outputInfo); - IgnoreUnused(axis); + + const int paramsRank = static_cast(paramsInfo.GetNumDimensions()); + ARMNN_ASSERT(-1 * paramsRank <= axis_int && axis_int < paramsRank); + const unsigned int axis = (axis_int < 0) ? static_cast(paramsRank + axis_int) + : static_cast(axis_int); const TensorShape& paramsShape = paramsInfo.GetShape(); - unsigned int paramsProduct = 1; - for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i) + // Product of all dimensions to the left side of the axis + unsigned int paramsOuterProduct = 1; + for (unsigned int i = 0; i < axis; ++i) + { + paramsOuterProduct *= paramsShape[i]; + } + // Product of all dimensions to the right side of the axis + unsigned int paramsInnerProduct = 1; + for (unsigned int k = 1 + axis; k < paramsInfo.GetNumDimensions(); ++k) { - paramsProduct = paramsProduct * paramsShape[i]; + paramsInnerProduct *= paramsShape[k]; } + unsigned int offset = 0; unsigned int outIndex = 0; - for (unsigned int i = 0; i < indicesInfo.GetNumElements(); ++i) + for (unsigned int i = 0; i < paramsOuterProduct; ++i) { - unsigned int indx = armnn::numeric_cast(indices[i]); - - ARMNN_ASSERT(indices[i] >= 0 && indx < paramsShape[0]); + for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j) + { + unsigned int index = armnn::numeric_cast(indices[j]); + ARMNN_ASSERT(indices[j] >= 0 && index < paramsShape[axis]); - unsigned int startOffset = indx * paramsProduct; - unsigned int endOffset = startOffset + paramsProduct; + unsigned int startOffset = (paramsInnerProduct * index) + offset; + unsigned int endOffset = startOffset + paramsInnerProduct; - for (unsigned int j = startOffset; j < endOffset; ++j) - { - params[j]; - float outputValue = params.Get(); - output[outIndex]; - output.Set(outputValue); - ++outIndex; + for (unsigned int k = startOffset; k < endOffset; ++k) + { + params[k]; + float outputValue = params.Get(); + output[outIndex]; + output.Set(outputValue); + ++outIndex; + } } + offset += paramsShape[axis] * paramsInnerProduct; } ARMNN_ASSERT(outIndex == outputInfo.GetNumElements()); } -} //namespace armnn +} //namespace armnn \ No newline at end of file -- cgit v1.2.1