diff options
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r-- | ConversionUtils.hpp | 46 |
1 files changed, 34 insertions, 12 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index e4ac4a5a..1975434a 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -183,11 +183,12 @@ inline bool IsOperandTypeSupportedForTensors(V1_0::OperandType type) inline bool IsOperandTypeSupportedForTensors(V1_2::OperandType type) { - return type == V1_2::OperandType::BOOL || - type == V1_2::OperandType::TENSOR_FLOAT16 || - type == V1_2::OperandType::TENSOR_FLOAT32 || - type == V1_2::OperandType::TENSOR_QUANT8_ASYMM || - type == V1_2::OperandType::TENSOR_QUANT16_SYMM || + return type == V1_2::OperandType::BOOL || + type == V1_2::OperandType::TENSOR_FLOAT16 || + type == V1_2::OperandType::TENSOR_FLOAT32 || + type == V1_2::OperandType::TENSOR_QUANT8_ASYMM || + type == V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL || + type == V1_2::OperandType::TENSOR_QUANT16_SYMM || type == V1_2::OperandType::TENSOR_INT32; } @@ -384,16 +385,37 @@ Shape GetOperandShape(const V1_2::Operand& operand) // we accept some tolerance. We don't want ArmNN itself to accept these inconsistencies as it is up to the // user (us, in this case) to ensure they match. void SanitizeBiasQuantizationScale(armnn::TensorInfo& biasInfo, - const armnn::TensorInfo& weightInfo, const armnn::TensorInfo& inputInfo) + const armnn::TensorInfo& weightInfo, + const armnn::TensorInfo& inputInfo) { - const float expectedBiasScale = weightInfo.GetQuantizationScale() * inputInfo.GetQuantizationScale(); - if (biasInfo.GetQuantizationScale() != expectedBiasScale) + if (weightInfo.HasPerAxisQuantization()) { - boost::math::fpc::close_at_tolerance<float> comparer(boost::math::fpc::percent_tolerance(1.0f)); - if (comparer(biasInfo.GetQuantizationScale(), expectedBiasScale)) + // NOTE: Bias scale is always set to 0 for per-axis quantization and + // it needs to be calculated: scale[i] = input_scale * weight_scale[i] + auto UpdateBiasScaleValue = [&inputInfo](float biasScale) -> float { - ALOGW("Bias quantization scale has been modified to match input*weights"); - biasInfo.SetQuantizationScale(expectedBiasScale); + return biasScale * inputInfo.GetQuantizationScale(); + }; + + std::vector<float> biasScales(weightInfo.GetQuantizationScales()); + std::transform(biasScales.begin(), biasScales.end(), biasScales.begin(), UpdateBiasScaleValue); + + biasInfo.SetQuantizationScales(biasScales); + biasInfo.SetQuantizationDim(weightInfo.GetQuantizationDim()); + + ALOGV("Bias quantization params have been updated for per-axis quantization"); + } + else + { + const float expectedBiasScale = weightInfo.GetQuantizationScale() * inputInfo.GetQuantizationScale(); + if (biasInfo.GetQuantizationScale() != expectedBiasScale) + { + boost::math::fpc::close_at_tolerance<float> comparer(boost::math::fpc::percent_tolerance(1.0f)); + if (comparer(biasInfo.GetQuantizationScale(), expectedBiasScale)) + { + ALOGW("Bias quantization scale has been modified to match input * weights"); + biasInfo.SetQuantizationScale(expectedBiasScale); + } } } } |