diff options
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 31 |
1 files changed, 22 insertions, 9 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 491081dbac..ee6462dfa3 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -437,11 +437,14 @@ bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input, const DataType inputType = input.GetDataType(); if (inputType == DataType::QAsymmU8) { - std::array<DataType, 2> supportedWeightTypes = + ARMNN_NO_DEPRECATE_WARN_BEGIN + std::array<DataType, 3> supportedWeightTypes = { DataType::QAsymmU8, - DataType::QuantizedSymm8PerAxis + DataType::QSymmS8, + DataType::QuantizedSymm8PerAxis // deprecated }; + ARMNN_NO_DEPRECATE_WARN_END supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported, "Reference convolution2d: weights type not supported for quantized input."); @@ -554,14 +557,18 @@ bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input, supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, "Reference DepthwiseConvolution2d: input and output types mismatched."); - const DataType inputType = input.GetDataType(); - if (inputType == DataType::QAsymmU8) - { - std::array<DataType, 2> supportedWeightTypes = + ARMNN_NO_DEPRECATE_WARN_BEGIN + std::array<DataType, 3> supportedWeightTypes = { DataType::QAsymmU8, - DataType::QuantizedSymm8PerAxis + DataType::QSymmS8, + DataType::QuantizedSymm8PerAxis // deprecated }; + ARMNN_NO_DEPRECATE_WARN_END + + const DataType inputType = input.GetDataType(); + if (inputType == DataType::QAsymmU8) + { supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported, "Reference convolution2d: weights type not supported for quantized input."); @@ -607,6 +614,9 @@ bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input, supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported, "Reference dequantize: input type not supported."); + supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported, + "Reference dequantize: per-axis quantized input not support ."); + std::array<DataType,2> supportedOutputTypes = { DataType::Float32, DataType::Float16 @@ -1836,11 +1846,14 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input, const DataType inputType = input.GetDataType(); if (inputType == DataType::QAsymmU8) { - std::array<DataType, 2> supportedWeightTypes = + ARMNN_NO_DEPRECATE_WARN_BEGIN + std::array<DataType, 3> supportedWeightTypes = { DataType::QAsymmU8, - DataType::QuantizedSymm8PerAxis + DataType::QSymmS8, + DataType::QuantizedSymm8PerAxis //Deprecated }; + ARMNN_NO_DEPRECATE_WARN_END supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported, "Reference TransposeConvolution2d: weights type not supported for " |