aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/ReverseV2Impl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/ReverseV2Impl.cpp')
-rw-r--r--src/backends/reference/workloads/ReverseV2Impl.cpp133
1 files changed, 133 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/ReverseV2Impl.cpp b/src/backends/reference/workloads/ReverseV2Impl.cpp
new file mode 100644
index 0000000000..f6d5fd74d1
--- /dev/null
+++ b/src/backends/reference/workloads/ReverseV2Impl.cpp
@@ -0,0 +1,133 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ReverseV2Impl.hpp"
+
+#include <armnn/backends/WorkloadData.hpp>
+#include <armnn/Logging.hpp>
+#include <armnnUtils/Permute.hpp>
+
+namespace armnn
+{
+
+// Get multi-dimensional index for input tensor
+std::vector<unsigned int> ReverseGetMultIdx(const unsigned int idx,
+ unsigned int inputRank,
+ std::vector<unsigned int>& elementNumInner)
+{
+ std::vector<unsigned int> indexList(inputRank);
+
+ unsigned int mIdx = idx;
+
+ for (unsigned int iDim = 0; iDim < inputRank; ++iDim)
+ {
+ indexList[iDim] = static_cast<unsigned int>(mIdx / elementNumInner[iDim]);
+ mIdx %= elementNumInner[iDim];
+ }
+
+ return indexList;
+}
+
+// Get flattened index for output encoder
+unsigned int ReverseGetFlatIdx(const std::vector<unsigned int>& idxList,
+ unsigned int inputRank,
+ std::vector<unsigned int>& elementNumInner)
+{
+ unsigned int idx = 0;
+
+ for (unsigned int iDim = 0; iDim < inputRank; ++iDim)
+ {
+ idx += idxList[iDim] * elementNumInner[iDim];
+ }
+
+ return idx;
+}
+
+// Relocate the coordinate to the reversed tensor
+unsigned int ReverseRelocateIdx(unsigned int idx,
+ unsigned int inputRank,
+ std::vector<bool>& axisFlag,
+ std::vector<unsigned int>& dimSize,
+ std::vector<unsigned int>& elementNumInner)
+{
+ // Get the multidimensional index list for input
+ auto inputIdxList = ReverseGetMultIdx(idx, inputRank, elementNumInner);
+
+ std::vector<unsigned int> outputIdxList(inputRank);
+
+ // Relocate the input index to the output one
+ for (unsigned int iDim = 0; iDim < inputRank; ++iDim)
+ {
+ if (axisFlag[iDim])
+ {
+ outputIdxList[iDim] = dimSize[iDim] - inputIdxList[iDim] - 1;
+ }
+ else
+ {
+ outputIdxList[iDim] = inputIdxList[iDim];
+ }
+ }
+
+ // Get the 1-dimensional flattened index for output
+ unsigned int outputIdx = ReverseGetFlatIdx(outputIdxList, inputRank, elementNumInner);
+ return outputIdx;
+}
+
+void ReverseV2(const ReverseV2Descriptor& params,
+ const TensorInfo& inputInfo,
+ Decoder<float>& inputDecoder,
+ Encoder<float>& outputEncoder)
+{
+ // Empty axis and empty tensor case: copy input to output
+ if (params.m_Axis.empty() || inputInfo.GetNumElements() == 0)
+ {
+ for (unsigned idx = 0; idx < inputInfo.GetNumElements(); idx++)
+ {
+ float inputValue = inputDecoder.Get();
+ inputDecoder += 1;
+ outputEncoder.Set(inputValue);
+ outputEncoder += 1;
+ }
+ return;
+ }
+
+ unsigned int inputRank = static_cast<unsigned int>(inputInfo.GetNumDimensions());
+
+ std::vector<bool>axisFlag(inputRank, false);
+ std::vector<unsigned int>dimSize(inputRank, 0);
+
+ // Make sure the axes are positive
+ for (int32_t axisElement: params.m_Axis)
+ {
+ axisElement = axisElement < 0 ? axisElement + static_cast<int32_t>(inputRank) : axisElement;
+ axisFlag[static_cast<uint32_t>(axisElement)] = true;
+ }
+
+ const TensorShape &inputShape = inputInfo.GetShape();
+
+ unsigned int elementNum = inputInfo.GetNumElements();
+ unsigned int baseDimSize = 1;
+
+ std::vector<unsigned int> elementNumInner;
+
+ // Get the number of element within the specific dimension
+ for (unsigned int iDim = 0; iDim < inputRank; ++iDim) {
+ dimSize[iDim] = inputShape[iDim];
+ baseDimSize *= dimSize[iDim];
+ elementNumInner.push_back(static_cast<unsigned int>(elementNum / baseDimSize));
+ }
+
+ // Iterate through all elements
+ for (unsigned int idx = 0; idx < elementNum; ++idx)
+ {
+ float inputValue = inputDecoder.Get();
+ inputDecoder += 1;
+ auto outputIdx = ReverseRelocateIdx(idx, inputRank, axisFlag, dimSize, elementNumInner);
+ outputEncoder[outputIdx];
+ outputEncoder.Set(inputValue);
+ }
+}
+
+} // namespace armnn \ No newline at end of file