diff options
Diffstat (limited to 'src/backends/reference/workloads/ReverseV2Impl.cpp')
-rw-r--r-- | src/backends/reference/workloads/ReverseV2Impl.cpp | 23 |
1 files changed, 17 insertions, 6 deletions
diff --git a/src/backends/reference/workloads/ReverseV2Impl.cpp b/src/backends/reference/workloads/ReverseV2Impl.cpp index f6d5fd74d1..896f9050f5 100644 --- a/src/backends/reference/workloads/ReverseV2Impl.cpp +++ b/src/backends/reference/workloads/ReverseV2Impl.cpp @@ -75,13 +75,16 @@ unsigned int ReverseRelocateIdx(unsigned int idx, return outputIdx; } -void ReverseV2(const ReverseV2Descriptor& params, - const TensorInfo& inputInfo, +void ReverseV2(const TensorInfo& inputInfo, + const TensorInfo& axisInfo, Decoder<float>& inputDecoder, + Decoder<int>& axisDecoder, Encoder<float>& outputEncoder) { + unsigned int axesRank = static_cast<unsigned int>(axisInfo.GetNumElements()); + // Empty axis and empty tensor case: copy input to output - if (params.m_Axis.empty() || inputInfo.GetNumElements() == 0) + if ((axesRank == 0) || inputInfo.GetNumElements() == 0) { for (unsigned idx = 0; idx < inputInfo.GetNumElements(); idx++) { @@ -95,11 +98,19 @@ void ReverseV2(const ReverseV2Descriptor& params, unsigned int inputRank = static_cast<unsigned int>(inputInfo.GetNumDimensions()); - std::vector<bool>axisFlag(inputRank, false); - std::vector<unsigned int>dimSize(inputRank, 0); + std::vector<bool> axisFlag(inputRank, false); + std::vector<unsigned int> dimSize(inputRank, 0); + std::vector<int32_t> axis(axesRank, 0); + + // Decode the axis information + for (unsigned int i=0; i < axesRank; i++) + { + axis[i] = axisDecoder.Get(); + axisDecoder += 1; + } // Make sure the axes are positive - for (int32_t axisElement: params.m_Axis) + for (int32_t axisElement: axis) { axisElement = axisElement < 0 ? axisElement + static_cast<int32_t>(inputRank) : axisElement; axisFlag[static_cast<uint32_t>(axisElement)] = true; |