// // Copyright © 2017,2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "Gather.hpp" #include #include namespace armnn { void Gather(const TensorInfo& paramsInfo, const TensorInfo& indicesInfo, const TensorInfo& outputInfo, Decoder& params, const int32_t* indices, Encoder& output, const int32_t axis_int) { IgnoreUnused(outputInfo); const int paramsRank = static_cast(paramsInfo.GetNumDimensions()); if((axis_int < -1 * paramsRank) || (paramsRank <= axis_int)) { throw InvalidArgumentException((fmt::format("Gather: Axis {} is not within [-{}, {}) range", axis_int, paramsRank, paramsRank))); } const unsigned int axis = (axis_int < 0) ? static_cast(paramsRank + axis_int) : static_cast(axis_int); const TensorShape& paramsShape = paramsInfo.GetShape(); // Product of all dimensions to the left side of the axis unsigned int paramsOuterProduct = 1; for (unsigned int i = 0; i < axis; ++i) { paramsOuterProduct *= paramsShape[i]; } // Product of all dimensions to the right side of the axis unsigned int paramsInnerProduct = 1; for (unsigned int k = 1 + axis; k < paramsInfo.GetNumDimensions(); ++k) { paramsInnerProduct *= paramsShape[k]; } unsigned int offset = 0; unsigned int outIndex = 0; for (unsigned int i = 0; i < paramsOuterProduct; ++i) { for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j) { unsigned int index = (indices[j] < 0) ? static_cast(static_cast(paramsShape[axis]) + indices[j]) : static_cast(indices[j]); if (index >= paramsShape[axis]) { throw InvalidArgumentException((fmt::format("Gather: index >= paramsShape[axis]: {} >= {}", index, paramsShape[axis] ))); } unsigned int startOffset = (paramsInnerProduct * index) + offset; unsigned int endOffset = startOffset + paramsInnerProduct; for (unsigned int k = startOffset; k < endOffset; ++k) { params[k]; float outputValue = params.Get(); output[outIndex]; output.Set(outputValue); ++outIndex; } } offset += paramsShape[axis] * paramsInnerProduct; } if (outIndex != outputInfo.GetNumElements()) { throw InvalidArgumentException((fmt::format("Gather: Invalid outIndex {} ", outIndex))); } } } //namespace armnn