diff options
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r-- | src/backends/backendsCommon/LayerSupportBase.cpp | 7 | ||||
-rw-r--r-- | src/backends/backendsCommon/LayerSupportBase.hpp | 4 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 17 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.hpp | 5 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 16 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.hpp | 3 | ||||
-rw-r--r-- | src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp | 2 |
7 files changed, 54 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp index 137e77eebe..04f822cea9 100644 --- a/src/backends/backendsCommon/LayerSupportBase.cpp +++ b/src/backends/backendsCommon/LayerSupportBase.cpp @@ -115,6 +115,13 @@ bool LayerSupportBase::IsDepthwiseConvolutionSupported(const TensorInfo& input, return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } +bool LayerSupportBase::IsDequantizeSupported(const TensorInfo& input, + const TensorInfo& output, + Optional<std::string&> reasonIfUnsupported) const +{ + return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); +} + bool LayerSupportBase::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0, const armnn::TensorInfo& input1, const armnn::DetectionPostProcessDescriptor& descriptor, diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp index ceb3b2768e..7d64095667 100644 --- a/src/backends/backendsCommon/LayerSupportBase.hpp +++ b/src/backends/backendsCommon/LayerSupportBase.hpp @@ -67,6 +67,10 @@ public: const Optional<TensorInfo>& biases, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsDequantizeSupported(const TensorInfo& input, + const TensorInfo& output, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsDetectionPostProcessSupported(const TensorInfo& input0, const TensorInfo& input1, const DetectionPostProcessDescriptor& descriptor, diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index e30a3f36b7..91b1c5790b 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1153,6 +1153,23 @@ void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadI } } +void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateSingleInput(workloadInfo, "DequantizeQueueDescriptor"); + ValidateSingleOutput(workloadInfo, "DequantizeQueueDescriptor"); + + if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::QuantisedAsymm8 && + workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::QuantisedSymm16) + { + throw InvalidArgumentException("Input to dequantize layer must be quantized type."); + } + + if (workloadInfo.m_OutputTensorInfos[0].GetDataType() != DataType::Float32) + { + throw InvalidArgumentException("Output of dequantize layer must be Float32 type."); + } +} + void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { // This is internally generated so it should not need validation. diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index 9250ceac43..5640701d82 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -416,4 +416,9 @@ struct PreCompiledQueueDescriptor : QueueDescriptorWithParameters<PreCompiledDes void Validate(const WorkloadInfo& workloadInfo) const; }; +struct DequantizeQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + } //namespace armnn diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 833f3b894e..6534a00343 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -229,6 +229,16 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, reason); break; } + case LayerType::Dequantize: + { + const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo(); + + result = layerSupportObject->IsDequantizeSupported(OverrideDataType(input, dataType), + OverrideDataType(output, DataType::Float32), + reason); + break; + } case LayerType::DetectionPostProcess: { const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); @@ -821,6 +831,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d( return std::unique_ptr<IWorkload>(); } +std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize( + const DequantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const +{ + return std::unique_ptr<IWorkload>(); +} + std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess( const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const { diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp index 2aa3854c4a..ed7303cf33 100644 --- a/src/backends/backendsCommon/WorkloadFactory.hpp +++ b/src/backends/backendsCommon/WorkloadFactory.hpp @@ -79,6 +79,9 @@ public: virtual std::unique_ptr<IWorkload> CreateDepthwiseConvolution2d( const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const; + virtual std::unique_ptr<IWorkload> CreateDequantize(const DequantizeQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + virtual std::unique_ptr<IWorkload> CreateDetectionPostProcess( const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const; diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp index 8f86132274..26fb03f55d 100644 --- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp +++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp @@ -336,6 +336,8 @@ DECLARE_LAYER_POLICY_1_PARAM(Debug) DECLARE_LAYER_POLICY_2_PARAM(DepthwiseConvolution2d) +DECLARE_LAYER_POLICY_1_PARAM(Dequantize) + DECLARE_LAYER_POLICY_2_PARAM(DetectionPostProcess) DECLARE_LAYER_POLICY_1_PARAM(Equal) |