aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
authorTracy Narine <tracy.narine@arm.com>2023-07-13 16:50:54 +0100
committerTracy Narine <tracy.narine@arm.com>2023-07-17 14:19:36 +0100
commitbb8d7591a35bd95480b39001f8b7e41a6671f3a6 (patch)
treeabf2871aa1bb86378f423df405164b0d4521db3f /src/backends/reference/workloads
parent688268328c69e7d4181cdd31fe4717c80a6d1685 (diff)
downloadarmnn-bb8d7591a35bd95480b39001f8b7e41a6671f3a6.tar.gz
IVGCVSW-7879 Change REVERSE_V2 from LayerWithParameters with 1 input, to Layer with 2 inputs
* Changing ReverseV2 to use two inputs * This is required by the backends * The ReverseV2Descriptor was removed * Tests updated * Added a Run<> templatefor inputs with different data types Signed-off-by: Tracy Narine <tracy.narine@arm.com> Change-Id: I22f947de829b4b3da6bda3a74f4ffdef4052cc25
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/RefReverseV2Workload.cpp9
-rw-r--r--src/backends/reference/workloads/ReverseV2Impl.cpp23
-rw-r--r--src/backends/reference/workloads/ReverseV2Impl.hpp7
3 files changed, 28 insertions, 11 deletions
diff --git a/src/backends/reference/workloads/RefReverseV2Workload.cpp b/src/backends/reference/workloads/RefReverseV2Workload.cpp
index cd2d9f930b..22d5449466 100644
--- a/src/backends/reference/workloads/RefReverseV2Workload.cpp
+++ b/src/backends/reference/workloads/RefReverseV2Workload.cpp
@@ -32,16 +32,21 @@ namespace armnn
ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefReverseV2Workload_Execute");
const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
+ const TensorInfo& axisInfo = GetTensorInfo(inputs[1]);
std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]),
inputs[0]->Map());
+ std::unique_ptr<Decoder<int>> axisDecoder = MakeDecoder<int>(GetTensorInfo(inputs[1]),
+ inputs[1]->Map());
+
std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]),
outputs[0]->Map());
- ReverseV2(m_Data.m_Parameters,
- inputInfo,
+ ReverseV2(inputInfo,
+ axisInfo,
*inputDecoder,
+ *axisDecoder,
*outputEncoder);
}
diff --git a/src/backends/reference/workloads/ReverseV2Impl.cpp b/src/backends/reference/workloads/ReverseV2Impl.cpp
index f6d5fd74d1..896f9050f5 100644
--- a/src/backends/reference/workloads/ReverseV2Impl.cpp
+++ b/src/backends/reference/workloads/ReverseV2Impl.cpp
@@ -75,13 +75,16 @@ unsigned int ReverseRelocateIdx(unsigned int idx,
return outputIdx;
}
-void ReverseV2(const ReverseV2Descriptor& params,
- const TensorInfo& inputInfo,
+void ReverseV2(const TensorInfo& inputInfo,
+ const TensorInfo& axisInfo,
Decoder<float>& inputDecoder,
+ Decoder<int>& axisDecoder,
Encoder<float>& outputEncoder)
{
+ unsigned int axesRank = static_cast<unsigned int>(axisInfo.GetNumElements());
+
// Empty axis and empty tensor case: copy input to output
- if (params.m_Axis.empty() || inputInfo.GetNumElements() == 0)
+ if ((axesRank == 0) || inputInfo.GetNumElements() == 0)
{
for (unsigned idx = 0; idx < inputInfo.GetNumElements(); idx++)
{
@@ -95,11 +98,19 @@ void ReverseV2(const ReverseV2Descriptor& params,
unsigned int inputRank = static_cast<unsigned int>(inputInfo.GetNumDimensions());
- std::vector<bool>axisFlag(inputRank, false);
- std::vector<unsigned int>dimSize(inputRank, 0);
+ std::vector<bool> axisFlag(inputRank, false);
+ std::vector<unsigned int> dimSize(inputRank, 0);
+ std::vector<int32_t> axis(axesRank, 0);
+
+ // Decode the axis information
+ for (unsigned int i=0; i < axesRank; i++)
+ {
+ axis[i] = axisDecoder.Get();
+ axisDecoder += 1;
+ }
// Make sure the axes are positive
- for (int32_t axisElement: params.m_Axis)
+ for (int32_t axisElement: axis)
{
axisElement = axisElement < 0 ? axisElement + static_cast<int32_t>(inputRank) : axisElement;
axisFlag[static_cast<uint32_t>(axisElement)] = true;
diff --git a/src/backends/reference/workloads/ReverseV2Impl.hpp b/src/backends/reference/workloads/ReverseV2Impl.hpp
index bc1fe1d432..59407d4a4e 100644
--- a/src/backends/reference/workloads/ReverseV2Impl.hpp
+++ b/src/backends/reference/workloads/ReverseV2Impl.hpp
@@ -13,9 +13,10 @@
namespace armnn
{
-void ReverseV2(const ReverseV2Descriptor& params,
- const TensorInfo& inputInfo,
+void ReverseV2(const TensorInfo& inputInfo,
+ const TensorInfo& axisInfo,
Decoder<float>& inputDecoder,
+ Decoder<int>& axisDecoder,
Encoder<float>& outputEncoder);
-} // namespace armnn \ No newline at end of file
+} // namespace armnn