aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/ReverseV2Impl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/ReverseV2Impl.cpp')
-rw-r--r--src/backends/reference/workloads/ReverseV2Impl.cpp23
1 files changed, 17 insertions, 6 deletions
diff --git a/src/backends/reference/workloads/ReverseV2Impl.cpp b/src/backends/reference/workloads/ReverseV2Impl.cpp
index f6d5fd74d1..896f9050f5 100644
--- a/src/backends/reference/workloads/ReverseV2Impl.cpp
+++ b/src/backends/reference/workloads/ReverseV2Impl.cpp
@@ -75,13 +75,16 @@ unsigned int ReverseRelocateIdx(unsigned int idx,
return outputIdx;
}
-void ReverseV2(const ReverseV2Descriptor& params,
- const TensorInfo& inputInfo,
+void ReverseV2(const TensorInfo& inputInfo,
+ const TensorInfo& axisInfo,
Decoder<float>& inputDecoder,
+ Decoder<int>& axisDecoder,
Encoder<float>& outputEncoder)
{
+ unsigned int axesRank = static_cast<unsigned int>(axisInfo.GetNumElements());
+
// Empty axis and empty tensor case: copy input to output
- if (params.m_Axis.empty() || inputInfo.GetNumElements() == 0)
+ if ((axesRank == 0) || inputInfo.GetNumElements() == 0)
{
for (unsigned idx = 0; idx < inputInfo.GetNumElements(); idx++)
{
@@ -95,11 +98,19 @@ void ReverseV2(const ReverseV2Descriptor& params,
unsigned int inputRank = static_cast<unsigned int>(inputInfo.GetNumDimensions());
- std::vector<bool>axisFlag(inputRank, false);
- std::vector<unsigned int>dimSize(inputRank, 0);
+ std::vector<bool> axisFlag(inputRank, false);
+ std::vector<unsigned int> dimSize(inputRank, 0);
+ std::vector<int32_t> axis(axesRank, 0);
+
+ // Decode the axis information
+ for (unsigned int i=0; i < axesRank; i++)
+ {
+ axis[i] = axisDecoder.Get();
+ axisDecoder += 1;
+ }
// Make sure the axes are positive
- for (int32_t axisElement: params.m_Axis)
+ for (int32_t axisElement: axis)
{
axisElement = axisElement < 0 ? axisElement + static_cast<int32_t>(inputRank) : axisElement;
axisFlag[static_cast<uint32_t>(axisElement)] = true;