ArmNN
 20.08
Gather.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "Gather.hpp"
7 
8 #include "RefWorkloadUtils.hpp"
9 
12 
13 #include <boost/numeric/conversion/cast.hpp>
14 
15 namespace armnn
16 {
17 
18 void Gather(const TensorInfo& paramsInfo,
19  const TensorInfo& indicesInfo,
20  const TensorInfo& outputInfo,
21  Decoder<float>& params,
22  const int32_t* indices,
23  Encoder<float>& output,
24  const int32_t axis)
25 {
26  IgnoreUnused(outputInfo);
27  IgnoreUnused(axis);
28 
29  const TensorShape& paramsShape = paramsInfo.GetShape();
30 
31  unsigned int paramsProduct = 1;
32  for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i)
33  {
34  paramsProduct = paramsProduct * paramsShape[i];
35  }
36 
37  unsigned int outIndex = 0;
38  for (unsigned int i = 0; i < indicesInfo.GetNumElements(); ++i)
39  {
40  unsigned int indx = boost::numeric_cast<unsigned int>(indices[i]);
41 
42  ARMNN_ASSERT(indices[i] >= 0 && indx < paramsShape[0]);
43 
44  unsigned int startOffset = indx * paramsProduct;
45  unsigned int endOffset = startOffset + paramsProduct;
46 
47  for (unsigned int j = startOffset; j < endOffset; ++j)
48  {
49  params[j];
50  float outputValue = params.Get();
51  output[outIndex];
52  output.Set(outputValue);
53  ++outIndex;
54  }
55  }
56 
57  ARMNN_ASSERT(outIndex == outputInfo.GetNumElements());
58 }
59 
60 } //namespace armnn
const TensorShape & GetShape() const
Definition: Tensor.hpp:187
virtual void Set(IType right)=0
Copyright (c) 2020 ARM Limited.
void IgnoreUnused(Ts &&...)
virtual IType Get() const =0
void Gather(const TensorInfo &paramsInfo, const TensorInfo &indicesInfo, const TensorInfo &outputInfo, Decoder< float > &params, const int32_t *indices, Encoder< float > &output, const int32_t axis)
Definition: Gather.cpp:18
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
Definition: NumericCast.hpp:33
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:191
unsigned int GetNumElements() const
Definition: Tensor.hpp:192