18 const int32_t* indices,
20 const int32_t axis_int)
25 ARMNN_ASSERT(-1 * paramsRank <= axis_int && axis_int < paramsRank);
26 const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
27 :
static_cast<unsigned int>(axis_int);
32 unsigned int paramsOuterProduct = 1;
33 for (
unsigned int i = 0; i < axis; ++i)
35 paramsOuterProduct *= paramsShape[i];
38 unsigned int paramsInnerProduct = 1;
41 paramsInnerProduct *= paramsShape[k];
44 unsigned int offset = 0;
45 unsigned int outIndex = 0;
46 for (
unsigned int i = 0; i < paramsOuterProduct; ++i)
50 unsigned int index = armnn::numeric_cast<unsigned int>(indices[j]);
51 ARMNN_ASSERT(indices[j] >= 0 && index < paramsShape[axis]);
53 unsigned int startOffset = (paramsInnerProduct * index) + offset;
54 unsigned int endOffset = startOffset + paramsInnerProduct;
56 for (
unsigned int k = startOffset; k < endOffset; ++k)
59 float outputValue = params.
Get();
61 output.
Set(outputValue);
65 offset += paramsShape[axis] * paramsInnerProduct;