aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp23
-rw-r--r--src/backends/reference/RefLayerSupport.hpp4
-rw-r--r--src/backends/reference/workloads/RefReverseV2Workload.cpp9
-rw-r--r--src/backends/reference/workloads/ReverseV2Impl.cpp23
-rw-r--r--src/backends/reference/workloads/ReverseV2Impl.hpp7
5 files changed, 45 insertions, 21 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 1d5fab1adc..e94478f088 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -344,7 +344,7 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type,
case LayerType::ReverseV2:
return IsReverseV2Supported(infos[0],
infos[1],
- *(PolymorphicDowncast<const ReverseV2Descriptor*>(&descriptor)),
+ infos[2],
reasonIfUnsupported);
case LayerType::Reduce:
return IsReduceSupported(infos[0],
@@ -2361,12 +2361,11 @@ bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
return supported;
}
-bool RefLayerSupport::IsReverseV2Supported(const TensorInfo& input,
+bool RefLayerSupport::IsReverseV2Supported(const TensorInfo& input0,
+ const TensorInfo& input1,
const TensorInfo& output,
- const ReverseV2Descriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
- IgnoreUnused(descriptor);
bool supported = true;
// ReverseV2 is data type agnostic so it can support all the types in the Reference backend
std::array<DataType,6> supportedTypes =
@@ -2379,14 +2378,22 @@ bool RefLayerSupport::IsReverseV2Supported(const TensorInfo& input,
DataType::QSymmS16
};
- supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
- "Reference ReverseV2: input type not supported");
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference ReverseV2: input0 type not supported");
supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
"Reference ReverseV2: output type not supported");
- supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
- "Reference ReverseV2: input and output types not matching");
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference ReverseV2: input0 and output types not matching");
+
+ std::array<DataType,6> input2SupportedTypes =
+ {
+ DataType::Signed32
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, input2SupportedTypes), reasonIfUnsupported,
+ "Reference ReverseV2: input1 type not supported");
return supported;
}
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 0afb9c2c94..21d59e27fc 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -299,9 +299,9 @@ public:
const ResizeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
- bool IsReverseV2Supported(const TensorInfo& input,
+ bool IsReverseV2Supported(const TensorInfo& input0,
+ const TensorInfo& input1,
const TensorInfo& output,
- const ReverseV2Descriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
bool IsShapeSupported(const TensorInfo& input,
diff --git a/src/backends/reference/workloads/RefReverseV2Workload.cpp b/src/backends/reference/workloads/RefReverseV2Workload.cpp
index cd2d9f930b..22d5449466 100644
--- a/src/backends/reference/workloads/RefReverseV2Workload.cpp
+++ b/src/backends/reference/workloads/RefReverseV2Workload.cpp
@@ -32,16 +32,21 @@ namespace armnn
ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefReverseV2Workload_Execute");
const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
+ const TensorInfo& axisInfo = GetTensorInfo(inputs[1]);
std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]),
inputs[0]->Map());
+ std::unique_ptr<Decoder<int>> axisDecoder = MakeDecoder<int>(GetTensorInfo(inputs[1]),
+ inputs[1]->Map());
+
std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]),
outputs[0]->Map());
- ReverseV2(m_Data.m_Parameters,
- inputInfo,
+ ReverseV2(inputInfo,
+ axisInfo,
*inputDecoder,
+ *axisDecoder,
*outputEncoder);
}
diff --git a/src/backends/reference/workloads/ReverseV2Impl.cpp b/src/backends/reference/workloads/ReverseV2Impl.cpp
index f6d5fd74d1..896f9050f5 100644
--- a/src/backends/reference/workloads/ReverseV2Impl.cpp
+++ b/src/backends/reference/workloads/ReverseV2Impl.cpp
@@ -75,13 +75,16 @@ unsigned int ReverseRelocateIdx(unsigned int idx,
return outputIdx;
}
-void ReverseV2(const ReverseV2Descriptor& params,
- const TensorInfo& inputInfo,
+void ReverseV2(const TensorInfo& inputInfo,
+ const TensorInfo& axisInfo,
Decoder<float>& inputDecoder,
+ Decoder<int>& axisDecoder,
Encoder<float>& outputEncoder)
{
+ unsigned int axesRank = static_cast<unsigned int>(axisInfo.GetNumElements());
+
// Empty axis and empty tensor case: copy input to output
- if (params.m_Axis.empty() || inputInfo.GetNumElements() == 0)
+ if ((axesRank == 0) || inputInfo.GetNumElements() == 0)
{
for (unsigned idx = 0; idx < inputInfo.GetNumElements(); idx++)
{
@@ -95,11 +98,19 @@ void ReverseV2(const ReverseV2Descriptor& params,
unsigned int inputRank = static_cast<unsigned int>(inputInfo.GetNumDimensions());
- std::vector<bool>axisFlag(inputRank, false);
- std::vector<unsigned int>dimSize(inputRank, 0);
+ std::vector<bool> axisFlag(inputRank, false);
+ std::vector<unsigned int> dimSize(inputRank, 0);
+ std::vector<int32_t> axis(axesRank, 0);
+
+ // Decode the axis information
+ for (unsigned int i=0; i < axesRank; i++)
+ {
+ axis[i] = axisDecoder.Get();
+ axisDecoder += 1;
+ }
// Make sure the axes are positive
- for (int32_t axisElement: params.m_Axis)
+ for (int32_t axisElement: axis)
{
axisElement = axisElement < 0 ? axisElement + static_cast<int32_t>(inputRank) : axisElement;
axisFlag[static_cast<uint32_t>(axisElement)] = true;
diff --git a/src/backends/reference/workloads/ReverseV2Impl.hpp b/src/backends/reference/workloads/ReverseV2Impl.hpp
index bc1fe1d432..59407d4a4e 100644
--- a/src/backends/reference/workloads/ReverseV2Impl.hpp
+++ b/src/backends/reference/workloads/ReverseV2Impl.hpp
@@ -13,9 +13,10 @@
namespace armnn
{
-void ReverseV2(const ReverseV2Descriptor& params,
- const TensorInfo& inputInfo,
+void ReverseV2(const TensorInfo& inputInfo,
+ const TensorInfo& axisInfo,
Decoder<float>& inputDecoder,
+ Decoder<int>& axisDecoder,
Encoder<float>& outputEncoder);
-} // namespace armnn \ No newline at end of file
+} // namespace armnn