aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/Gather.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/Gather.cpp')
-rw-r--r--src/backends/reference/workloads/Gather.cpp58
1 files changed, 35 insertions, 23 deletions
diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp
index 1624052819..48039432a5 100644
--- a/src/backends/reference/workloads/Gather.cpp
+++ b/src/backends/reference/workloads/Gather.cpp
@@ -1,14 +1,11 @@
//
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "Gather.hpp"
-#include "RefWorkloadUtils.hpp"
-
#include <armnn/backends/WorkloadData.hpp>
-#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/NumericCast.hpp>
namespace armnn
@@ -20,40 +17,55 @@ void Gather(const TensorInfo& paramsInfo,
Decoder<float>& params,
const int32_t* indices,
Encoder<float>& output,
- const int32_t axis)
+ const int32_t axis_int)
{
IgnoreUnused(outputInfo);
- IgnoreUnused(axis);
+
+ const int paramsRank = static_cast<int>(paramsInfo.GetNumDimensions());
+ ARMNN_ASSERT(-1 * paramsRank <= axis_int && axis_int < paramsRank);
+ const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
+ : static_cast<unsigned int>(axis_int);
const TensorShape& paramsShape = paramsInfo.GetShape();
- unsigned int paramsProduct = 1;
- for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i)
+ // Product of all dimensions to the left side of the axis
+ unsigned int paramsOuterProduct = 1;
+ for (unsigned int i = 0; i < axis; ++i)
+ {
+ paramsOuterProduct *= paramsShape[i];
+ }
+ // Product of all dimensions to the right side of the axis
+ unsigned int paramsInnerProduct = 1;
+ for (unsigned int k = 1 + axis; k < paramsInfo.GetNumDimensions(); ++k)
{
- paramsProduct = paramsProduct * paramsShape[i];
+ paramsInnerProduct *= paramsShape[k];
}
+ unsigned int offset = 0;
unsigned int outIndex = 0;
- for (unsigned int i = 0; i < indicesInfo.GetNumElements(); ++i)
+ for (unsigned int i = 0; i < paramsOuterProduct; ++i)
{
- unsigned int indx = armnn::numeric_cast<unsigned int>(indices[i]);
-
- ARMNN_ASSERT(indices[i] >= 0 && indx < paramsShape[0]);
+ for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j)
+ {
+ unsigned int index = armnn::numeric_cast<unsigned int>(indices[j]);
+ ARMNN_ASSERT(indices[j] >= 0 && index < paramsShape[axis]);
- unsigned int startOffset = indx * paramsProduct;
- unsigned int endOffset = startOffset + paramsProduct;
+ unsigned int startOffset = (paramsInnerProduct * index) + offset;
+ unsigned int endOffset = startOffset + paramsInnerProduct;
- for (unsigned int j = startOffset; j < endOffset; ++j)
- {
- params[j];
- float outputValue = params.Get();
- output[outIndex];
- output.Set(outputValue);
- ++outIndex;
+ for (unsigned int k = startOffset; k < endOffset; ++k)
+ {
+ params[k];
+ float outputValue = params.Get();
+ output[outIndex];
+ output.Set(outputValue);
+ ++outIndex;
+ }
}
+ offset += paramsShape[axis] * paramsInnerProduct;
}
ARMNN_ASSERT(outIndex == outputInfo.GetNumElements());
}
-} //namespace armnn
+} //namespace armnn \ No newline at end of file