From 9c700378f2227cb9d51455ed4a5086daaac5532a Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Wed, 8 Jan 2020 11:33:44 +0000 Subject: COMPMID-2769: Add support for QASYMM8_SIGNED in NEFullyConnectedLayer Change-Id: I4c35c522375ae5a5de78716e079ebb9ffad15956 Signed-off-by: Michele Di Giorgio Reviewed-on: https://review.mlplatform.org/c/2581 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- .../NEON/functions/NEFullyConnectedLayer.cpp | 21 +++-- .../NEON/functions/NEGEMMLowpOutputStage.cpp | 93 +++++++++++++++++++++- 2 files changed, 108 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp index 4c264e4832..92ccd5d1cc 100644 --- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp +++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp @@ -141,9 +141,8 @@ void NEFullyConnectedLayer::configure_fc_fc(const ITensor *input, const ITensor void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, FullyConnectedLayerInfo fc_info) { - ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); - // Perform validate step + ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); ARM_COMPUTE_ERROR_THROW_ON(NEFullyConnectedLayer::validate(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, @@ -260,7 +259,13 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh int32_t output_multiplier; int32_t output_shift; quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift); - _gemmlowp_output_stage.configure(&_gemmlowp_output, biases, output, output_multiplier, output_shift, oq_info.offset); + + GEMMLowpOutputStageInfo gemmlowp_output_stage_info; + gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier; + gemmlowp_output_stage_info.gemmlowp_shift = output_shift; + gemmlowp_output_stage_info.gemmlowp_offset = oq_info.offset; + gemmlowp_output_stage_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT; + _gemmlowp_output_stage.configure(&_gemmlowp_output, biases, output, gemmlowp_output_stage_info); _gemmlowp_output.allocator()->allocate(); } @@ -272,7 +277,7 @@ Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn { ARM_COMPUTE_UNUSED(fc_info.retain_internal_weights); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output); ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2); @@ -361,7 +366,13 @@ Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn int32_t output_multiplier; int32_t output_shift; ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift)); - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(&gemmlowp_output, biases, output)); + + GEMMLowpOutputStageInfo gemmlowp_output_stage_info; + gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier; + gemmlowp_output_stage_info.gemmlowp_shift = output_shift; + gemmlowp_output_stage_info.gemmlowp_offset = oq_info.offset; + gemmlowp_output_stage_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT; + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&gemmlowp_output, biases, output, gemmlowp_output_stage_info)); } return Status{}; diff --git a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp index 3ef9351b78..465dddaac2 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -28,6 +28,7 @@ #include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.h" +#include "arm_compute/core/Validate.h" #include "support/ToolchainSupport.h" namespace arm_compute @@ -81,4 +82,94 @@ Status NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint::validate(const ITens { return NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel::validate(input, bias, output, min, max); } + +void NEGEMMLowpOutputStage::configure(const ITensor *input, const ITensor *bias, ITensor *output, const GEMMLowpOutputStageInfo &info) +{ + // Perform validate step + ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); + ARM_COMPUTE_ERROR_THROW_ON(NEGEMMLowpOutputStage::validate(input->info(), bias != nullptr ? bias->info() : nullptr, output->info(), info)); + + if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN) + { + switch(output->info()->data_type()) + { + case DataType::QASYMM8: + { + auto k = arm_compute::support::cpp14::make_unique(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + default: + ARM_COMPUTE_ERROR("Unsupported output data type."); + } + } + else if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + { + switch(output->info()->data_type()) + { + case DataType::QASYMM8: + { + auto k = arm_compute::support::cpp14::make_unique(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + case DataType::QASYMM8_SIGNED: + { + auto k = arm_compute::support::cpp14::make_unique(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + case DataType::QSYMM16: + { + auto k = arm_compute::support::cpp14::make_unique(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + default: + ARM_COMPUTE_ERROR("Unsupported output data type."); + } + } + else + { + ARM_COMPUTE_ERROR("Unsupported output stage quantization type."); + } +} + +Status NEGEMMLowpOutputStage::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo &info) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::UNKNOWN, "NEGEMMLowpQuantizeDownScaleByFixedPoint cannot be used with UNKNOWN output data type."); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16); + + ARM_COMPUTE_RETURN_ERROR_ON((info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN) && (info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)); + + if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN) + { + switch(output->data_type()) + { + case DataType::QASYMM8: + return NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + default: + return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type."); + } + } + else + { + switch(output->data_type()) + { + case DataType::QASYMM8: + return NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + case DataType::QASYMM8_SIGNED: + return NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + case DataType::QSYMM16: + return NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + default: + return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type."); + } + } +} } // namespace arm_compute -- cgit v1.2.1