ArmNN
 23.02
Gather.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include "Gather.hpp"

#include <armnn/Exceptions.hpp>

#include <string>

11 namespace armnn
12 {
13 
14 void Gather(const TensorInfo& paramsInfo,
15  const TensorInfo& indicesInfo,
16  const TensorInfo& outputInfo,
17  Decoder<float>& params,
18  const int32_t* indices,
19  Encoder<float>& output,
20  const int32_t axis_int)
21 {
22  IgnoreUnused(outputInfo);
23 
24  const int paramsRank = static_cast<int>(paramsInfo.GetNumDimensions());
25  ARMNN_ASSERT(-1 * paramsRank <= axis_int && axis_int < paramsRank);
26  const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
27  : static_cast<unsigned int>(axis_int);
28 
29  const TensorShape& paramsShape = paramsInfo.GetShape();
30 
31  // Product of all dimensions to the left side of the axis
32  unsigned int paramsOuterProduct = 1;
33  for (unsigned int i = 0; i < axis; ++i)
34  {
35  paramsOuterProduct *= paramsShape[i];
36  }
37  // Product of all dimensions to the right side of the axis
38  unsigned int paramsInnerProduct = 1;
39  for (unsigned int k = 1 + axis; k < paramsInfo.GetNumDimensions(); ++k)
40  {
41  paramsInnerProduct *= paramsShape[k];
42  }
43 
44  unsigned int offset = 0;
45  unsigned int outIndex = 0;
46  for (unsigned int i = 0; i < paramsOuterProduct; ++i)
47  {
48  for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j)
49  {
50  unsigned int index = armnn::numeric_cast<unsigned int>(indices[j]);
51  ARMNN_ASSERT(indices[j] >= 0 && index < paramsShape[axis]);
52 
53  unsigned int startOffset = (paramsInnerProduct * index) + offset;
54  unsigned int endOffset = startOffset + paramsInnerProduct;
55 
56  for (unsigned int k = startOffset; k < endOffset; ++k)
57  {
58  params[k];
59  float outputValue = params.Get();
60  output[outIndex];
61  output.Set(outputValue);
62  ++outIndex;
63  }
64  }
65  offset += paramsShape[axis] * paramsInnerProduct;
66  }
67 
68  ARMNN_ASSERT(outIndex == outputInfo.GetNumElements());
69 }
70 
71 } //namespace armnn
armnn::Encoder< float >
armnn::IgnoreUnused
void IgnoreUnused(Ts &&...)
Definition: IgnoreUnused.hpp:14
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::TensorShape
Definition: Tensor.hpp:20
armnn::TensorInfo::GetNumDimensions
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:195
armnn::TensorInfo::GetNumElements
unsigned int GetNumElements() const
Definition: Tensor.hpp:196
armnn::Decoder< float >
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::Gather
void Gather(const TensorInfo &paramsInfo, const TensorInfo &indicesInfo, const TensorInfo &outputInfo, Decoder< float > &params, const int32_t *indices, Encoder< float > &output, const int32_t axis_int)
Definition: Gather.cpp:14
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
armnn::Encoder::Set
virtual void Set(IType right)=0
ARMNN_ASSERT
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
Gather.hpp
armnn::Decoder::Get
virtual IType Get() const =0
NumericCast.hpp
WorkloadData.hpp