diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 23 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.hpp | 4 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefReverseV2Workload.cpp | 9 | ||||
-rw-r--r-- | src/backends/reference/workloads/ReverseV2Impl.cpp | 23 | ||||
-rw-r--r-- | src/backends/reference/workloads/ReverseV2Impl.hpp | 7 |
5 files changed, 45 insertions, 21 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 1d5fab1adc..e94478f088 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -344,7 +344,7 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type, case LayerType::ReverseV2: return IsReverseV2Supported(infos[0], infos[1], - *(PolymorphicDowncast<const ReverseV2Descriptor*>(&descriptor)), + infos[2], reasonIfUnsupported); case LayerType::Reduce: return IsReduceSupported(infos[0], @@ -2361,12 +2361,11 @@ bool RefLayerSupport::IsResizeSupported(const TensorInfo& input, return supported; } -bool RefLayerSupport::IsReverseV2Supported(const TensorInfo& input, +bool RefLayerSupport::IsReverseV2Supported(const TensorInfo& input0, + const TensorInfo& input1, const TensorInfo& output, - const ReverseV2Descriptor& descriptor, Optional<std::string&> reasonIfUnsupported) const { - IgnoreUnused(descriptor); bool supported = true; // ReverseV2 is data type agnostic so it can support all the types in the Reference backend std::array<DataType,6> supportedTypes = @@ -2379,14 +2378,22 @@ bool RefLayerSupport::IsReverseV2Supported(const TensorInfo& input, DataType::QSymmS16 }; - supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, - "Reference ReverseV2: input type not supported"); + supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported, + "Reference ReverseV2: input0 type not supported"); supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, "Reference ReverseV2: output type not supported"); - supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, - "Reference ReverseV2: input and output types not matching"); + supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported, + "Reference ReverseV2: input0 and output types not matching"); + + std::array<DataType,6> input2SupportedTypes = + { + DataType::Signed32 + }; + + supported &= CheckSupportRule(TypeAnyOf(input1, input2SupportedTypes), reasonIfUnsupported, + "Reference ReverseV2: input1 type not supported"); return supported; } diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 0afb9c2c94..21d59e27fc 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -299,9 +299,9 @@ public: const ResizeDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; - bool IsReverseV2Supported(const TensorInfo& input, + bool IsReverseV2Supported(const TensorInfo& input0, + const TensorInfo& input1, const TensorInfo& output, - const ReverseV2Descriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const; bool IsShapeSupported(const TensorInfo& input, diff --git a/src/backends/reference/workloads/RefReverseV2Workload.cpp b/src/backends/reference/workloads/RefReverseV2Workload.cpp index cd2d9f930b..22d5449466 100644 --- a/src/backends/reference/workloads/RefReverseV2Workload.cpp +++ b/src/backends/reference/workloads/RefReverseV2Workload.cpp @@ -32,16 +32,21 @@ namespace armnn ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefReverseV2Workload_Execute"); const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); + const TensorInfo& axisInfo = GetTensorInfo(inputs[1]); std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]), inputs[0]->Map()); + std::unique_ptr<Decoder<int>> axisDecoder = MakeDecoder<int>(GetTensorInfo(inputs[1]), + inputs[1]->Map()); + std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]), outputs[0]->Map()); - ReverseV2(m_Data.m_Parameters, - inputInfo, + ReverseV2(inputInfo, + axisInfo, *inputDecoder, + *axisDecoder, *outputEncoder); } diff --git a/src/backends/reference/workloads/ReverseV2Impl.cpp b/src/backends/reference/workloads/ReverseV2Impl.cpp index f6d5fd74d1..896f9050f5 100644 --- a/src/backends/reference/workloads/ReverseV2Impl.cpp +++ b/src/backends/reference/workloads/ReverseV2Impl.cpp @@ -75,13 +75,16 @@ unsigned int ReverseRelocateIdx(unsigned int idx, return outputIdx; } -void ReverseV2(const ReverseV2Descriptor& params, - const TensorInfo& inputInfo, +void ReverseV2(const TensorInfo& inputInfo, + const TensorInfo& axisInfo, Decoder<float>& inputDecoder, + Decoder<int>& axisDecoder, Encoder<float>& outputEncoder) { + unsigned int axesRank = static_cast<unsigned int>(axisInfo.GetNumElements()); + // Empty axis and empty tensor case: copy input to output - if (params.m_Axis.empty() || inputInfo.GetNumElements() == 0) + if ((axesRank == 0) || inputInfo.GetNumElements() == 0) { for (unsigned idx = 0; idx < inputInfo.GetNumElements(); idx++) { @@ -95,11 +98,19 @@ void ReverseV2(const ReverseV2Descriptor& params, unsigned int inputRank = static_cast<unsigned int>(inputInfo.GetNumDimensions()); - std::vector<bool>axisFlag(inputRank, false); - std::vector<unsigned int>dimSize(inputRank, 0); + std::vector<bool> axisFlag(inputRank, false); + std::vector<unsigned int> dimSize(inputRank, 0); + std::vector<int32_t> axis(axesRank, 0); + + // Decode the axis information + for (unsigned int i=0; i < axesRank; i++) + { + axis[i] = axisDecoder.Get(); + axisDecoder += 1; + } // Make sure the axes are positive - for (int32_t axisElement: params.m_Axis) + for (int32_t axisElement: axis) { axisElement = axisElement < 0 ? axisElement + static_cast<int32_t>(inputRank) : axisElement; axisFlag[static_cast<uint32_t>(axisElement)] = true; diff --git a/src/backends/reference/workloads/ReverseV2Impl.hpp b/src/backends/reference/workloads/ReverseV2Impl.hpp index bc1fe1d432..59407d4a4e 100644 --- a/src/backends/reference/workloads/ReverseV2Impl.hpp +++ b/src/backends/reference/workloads/ReverseV2Impl.hpp @@ -13,9 +13,10 @@ namespace armnn { -void ReverseV2(const ReverseV2Descriptor& params, - const TensorInfo& inputInfo, +void ReverseV2(const TensorInfo& inputInfo, + const TensorInfo& axisInfo, Decoder<float>& inputDecoder, + Decoder<int>& axisDecoder, Encoder<float>& outputEncoder); -} // namespace armnn
\ No newline at end of file +} // namespace armnn |