From 6f92c8e9f8bb38dcf5dccf8deeff5112ecd8e37c Mon Sep 17 00:00:00 2001 From: Nikhil Raj Date: Wed, 22 Nov 2023 11:41:15 +0000 Subject: Update Doxygen for 23.11 Signed-off-by: Nikhil Raj Change-Id: I47cd933f5002cb94a73aa97689d7b3d9c93cb849 --- 23.11/_ref_gather_nd_workload_8cpp_source.html | 238 +++++++++++++++++++++++++ 1 file changed, 238 insertions(+) create mode 100644 23.11/_ref_gather_nd_workload_8cpp_source.html (limited to '23.11/_ref_gather_nd_workload_8cpp_source.html') diff --git a/23.11/_ref_gather_nd_workload_8cpp_source.html b/23.11/_ref_gather_nd_workload_8cpp_source.html new file mode 100644 index 0000000000..3d3097707f --- /dev/null +++ b/23.11/_ref_gather_nd_workload_8cpp_source.html @@ -0,0 +1,238 @@ + + + + + + + + +Arm NN: src/backends/reference/workloads/RefGatherNdWorkload.cpp Source File + + + + + + + + + + + + + + + + +
+
+ + + + ArmNN + + + +
+
+  23.11 +
+
+
+ + + + + + + +
+
+ +
+
+
+ +
+ +
+
+ + +
+ +
+ +
+
+
RefGatherNdWorkload.cpp
+
+
+Go to the documentation of this file.
1 //
+
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
+
3 // SPDX-License-Identifier: MIT
+
4 //
+
5 
+
6 #include <fmt/format.h>
+ +
8 
+
9 #include "Gather.hpp"
+
10 #include "Profiling.hpp"
+
11 #include "RefWorkloadUtils.hpp"
+ +
13 
+
14 namespace armnn
+
15 {
+
16 
+ +
18 {
+ +
20 }
+
21 
+ +
23 {
+
24  WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
+
25  Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
+
26 }
+
27 
+
28 void RefGatherNdWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
+
29 {
+
30  ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefGatherNdWorkload_Execute");
+
31 
+
32  const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]);
+
33  const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]);
+
34  const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
+
35 
+
36  std::unique_ptr<Decoder<float>> params_decoderPtr = MakeDecoder<float>(inputInfo0, inputs[0]->Map());
+
37 
+
38  const int32_t* indicesDataPtr = reinterpret_cast<int32_t*>(inputs[1]->Map());
+
39  std::vector<int32_t> indices(indicesDataPtr, indicesDataPtr + inputInfo1.GetNumElements());
+
40  // Check for negative indices, it could not be checked in validate as we do not have access to the values there
+
41  for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i)
+
42  {
+
43  if (indices[i] < 0)
+
44  {
+
45  throw InvalidArgumentException((fmt::format("GatherNd: indices[{}] < 0", i)));
+
46  }
+
47  }
+
48 
+
49  std::unique_ptr<Encoder<float>> output_encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
+
50 
+
51  std::map<std::string, unsigned int> keyIndices = CalculateGatherNdKeyIndices(inputInfo0, inputInfo1);
+
52 
+
53  /// Calculate flattened indices: flattenedIndices = indices * flattenedCoefficients
+
54  // Calculate the flattened coefficients to use in the multiplication
+
55  // to calculate the flattened indices needed by gather
+
56  TensorShape paramsShape = inputInfo0.GetShape();
+
57  std::vector<unsigned int> flattenedCoeff(keyIndices["ND"], 1);
+
58  for (unsigned int i = 1; i < keyIndices["ND"]; ++i)
+
59  {
+
60  flattenedCoeff[i-1] = paramsShape[i];
+
61  }
+
62  for (unsigned int i = keyIndices["ND"]-1; i > 0; --i)
+
63  {
+
64  flattenedCoeff[i-1] *= flattenedCoeff[i];
+
65  }
+
66 
+
67  // Prepare the vector to store the output of the matrix multiplication,
+
68  // which will represent the flattened indices needed by gather
+
69  armnn::TensorInfo flattenedIndices_Info = inputInfo1;
+
70  flattenedIndices_Info.SetShape({ keyIndices["W"] });
+
71  std::vector<int32_t> flattenedIndices(flattenedIndices_Info.GetNumElements(), 0);
+
72 
+
73  // Multiplication to calculate the flattened indices, which are the indices needed by gather.
+
74  for (unsigned int i = 0; i < keyIndices["W"]; ++i)
+
75  {
+
76  for (unsigned int j = 0; j < keyIndices["ND"]; ++j)
+
77  {
+
78  flattenedIndices[i] += indices[i * keyIndices["ND"] + j] * static_cast<int32_t>(flattenedCoeff[j]);
+
79  }
+
80  }
+
81 
+
82  /// Call Gather with adequate shapes
+
83  // Reshape params into {K, C}
+
84  armnn::TensorInfo params_K_C_Info = inputInfo0;
+
85  params_K_C_Info.SetShape({ keyIndices["K"], keyIndices["C"] });
+
86 
+
87  // Reshape indices into {N, W}
+
88  armnn::TensorInfo indices_N_W_Info = inputInfo1;
+
89  indices_N_W_Info.SetShape({ keyIndices["N"], keyIndices["W"] });
+
90 
+
91  // Reshape output to have the shape given by gather {N, W, C}
+
92  // (the original outputInfo has the shape given by gatherNd)
+
93  armnn::TensorInfo outputGather_Info = outputInfo;
+
94  outputGather_Info.SetShape({ keyIndices["N"], keyIndices["W"], keyIndices["C"] });
+
95 
+
96  // output_gather = gather(params_K_C, indices_N_W)
+
97  Gather(params_K_C_Info, indices_N_W_Info, outputGather_Info,
+
98  *params_decoderPtr, flattenedIndices.data(), *output_encoderPtr, 0);
+
99 }
+
100 
+
101 } //namespace armnn
+
+
+
unsigned int GetNumElements() const
Definition: Tensor.hpp:196
+
void ExecuteAsync(ExecutionData &executionData) override
+ + + + +
#define ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID(label)
Creates a profiling event that uses GetGuid() and GetName() from the calling class.
+
std::map< std::string, unsigned int > CalculateGatherNdKeyIndices(TensorInfo inputInfo0, TensorInfo inputInfo1)
Calculates the key index values needed for GatherNd: N, ND, K, W, C (N is always 1)
+ +
void Execute() const override
+
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
+ +
std::vector< ITensorHandle * > m_Outputs
+ + +
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
+ + +
void SetShape(const TensorShape &newShape)
Definition: Tensor.hpp:193
+
Copyright (c) 2021 ARM Limited and Contributors.
+ + + +
std::vector< ITensorHandle * > m_Inputs
+
void Gather(const TensorInfo &paramsInfo, const TensorInfo &indicesInfo, const TensorInfo &outputInfo, Decoder< float > &params, const int32_t *indices, Encoder< float > &output, const int32_t axis_int)
Definition: Gather.cpp:15
+ + + + + -- cgit v1.2.1