Diffstat (limited to 'src/backends/reference')
-rw-r--r--  src/backends/reference/RefLayerSupport.cpp         31
-rw-r--r--  src/backends/reference/workloads/Decoders.hpp      21
-rw-r--r--  src/backends/reference/workloads/Encoders.hpp      21
3 files changed, 56 insertions(+), 17 deletions(-)
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 491081dbac..ee6462dfa3 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -437,11 +437,14 @@ bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
const DataType inputType = input.GetDataType();
if (inputType == DataType::QAsymmU8)
{
- std::array<DataType, 2> supportedWeightTypes =
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
+ std::array<DataType, 3> supportedWeightTypes =
{
DataType::QAsymmU8,
- DataType::QuantizedSymm8PerAxis
+ DataType::QSymmS8,
+ DataType::QuantizedSymm8PerAxis // deprecated
};
+ ARMNN_NO_DEPRECATE_WARN_END
supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
"Reference convolution2d: weights type not supported for quantized input.");
@@ -554,14 +557,18 @@ bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
"Reference DepthwiseConvolution2d: input and output types mismatched.");
- const DataType inputType = input.GetDataType();
- if (inputType == DataType::QAsymmU8)
- {
- std::array<DataType, 2> supportedWeightTypes =
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
+ std::array<DataType, 3> supportedWeightTypes =
{
DataType::QAsymmU8,
- DataType::QuantizedSymm8PerAxis
+ DataType::QSymmS8,
+ DataType::QuantizedSymm8PerAxis // deprecated
};
+ ARMNN_NO_DEPRECATE_WARN_END
+
+ const DataType inputType = input.GetDataType();
+ if (inputType == DataType::QAsymmU8)
+ {
supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
"Reference convolution2d: weights type not supported for quantized input.");
@@ -607,6 +614,9 @@ bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
"Reference dequantize: input type not supported.");
+ supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
+ "Reference dequantize: per-axis quantized input not support .");
+
std::array<DataType,2> supportedOutputTypes = {
DataType::Float32,
DataType::Float16
@@ -1836,11 +1846,14 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
const DataType inputType = input.GetDataType();
if (inputType == DataType::QAsymmU8)
{
- std::array<DataType, 2> supportedWeightTypes =
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
+ std::array<DataType, 3> supportedWeightTypes =
{
DataType::QAsymmU8,
- DataType::QuantizedSymm8PerAxis
+ DataType::QSymmS8,
+ DataType::QuantizedSymm8PerAxis // deprecated
};
+ ARMNN_NO_DEPRECATE_WARN_END
supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
"Reference TransposeConvolution2d: weights type not supported for "
diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp
index faabdcdb3f..6f309787bd 100644
--- a/src/backends/reference/workloads/Decoders.hpp
+++ b/src/backends/reference/workloads/Decoders.hpp
@@ -71,6 +71,7 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const
{
switch(info.GetDataType())
{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
case armnn::DataType::QuantizedSymm8PerAxis:
{
std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
@@ -79,6 +80,7 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const
params.second,
params.first);
}
+ ARMNN_NO_DEPRECATE_WARN_END
case DataType::QAsymmU8:
{
return std::make_unique<QASymm8Decoder>(
@@ -107,10 +109,21 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const
}
case DataType::QSymmS8:
{
- return std::make_unique<QSymmS8Decoder>(
- static_cast<const int8_t*>(data),
- info.GetQuantizationScale(),
- info.GetQuantizationOffset());
+ if (info.HasPerAxisQuantization())
+ {
+ std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
+ return std::make_unique<QSymm8PerAxisDecoder>(
+ static_cast<const int8_t*>(data),
+ params.second,
+ params.first);
+ }
+ else
+ {
+ return std::make_unique<QSymmS8Decoder>(
+ static_cast<const int8_t*>(data),
+ info.GetQuantizationScale(),
+ info.GetQuantizationOffset());
+ }
}
default:
{
diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp
index 4fe202f0bf..8ddd559448 100644
--- a/src/backends/reference/workloads/Encoders.hpp
+++ b/src/backends/reference/workloads/Encoders.hpp
@@ -22,6 +22,7 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void*
{
switch(info.GetDataType())
{
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
case armnn::DataType::QuantizedSymm8PerAxis:
{
std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
@@ -30,6 +31,7 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void*
params.second,
params.first);
}
+ ARMNN_NO_DEPRECATE_WARN_END
case armnn::DataType::QAsymmU8:
{
return std::make_unique<QASymm8Encoder>(
@@ -39,10 +41,21 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void*
}
case DataType::QSymmS8:
{
- return std::make_unique<QSymmS8Encoder>(
- static_cast<int8_t*>(data),
- info.GetQuantizationScale(),
- info.GetQuantizationOffset());
+ if (info.HasPerAxisQuantization())
+ {
+ std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
+ return std::make_unique<QSymm8PerAxisEncoder>(
+ static_cast<int8_t*>(data),
+ params.second,
+ params.first);
+ }
+ else
+ {
+ return std::make_unique<QSymmS8Encoder>(
+ static_cast<int8_t*>(data),
+ info.GetQuantizationScale(),
+ info.GetQuantizationOffset());
+ }
}
case armnn::DataType::QSymmS16:
{
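Encoders.hpp mirrors Decoders.hpp: HasPerAxisQuantization() selects QSymm8PerAxisEncoder, otherwise QSymmS8Encoder. A round-trip sketch under the same assumptions as above (illustrative scales, code inside the reference backend):

    #include "Decoders.hpp"
    #include "Encoders.hpp"

    #include <armnn/Tensor.hpp>
    #include <vector>

    using namespace armnn;

    TensorInfo info(TensorShape({ 2, 2 }), DataType::QSymmS8,
                    std::vector<float>{ 0.5f, 0.25f }, 0);
    std::vector<int8_t> buffer(info.GetNumElements(), 0);

    auto encoder = MakeEncoder<float>(info, buffer.data());
    encoder->Set(1.5f); // channel-0 scale 0.5f: stores 1.5 / 0.5 == 3

    auto decoder = MakeDecoder<float>(info, buffer.data());
    float roundTripped = decoder->Get(); // 3 * 0.5f == 1.5f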