From 689c968239180eda4263e34c3d450093d4a0450d Mon Sep 17 00:00:00 2001 From: Luca Foschiani Date: Wed, 26 Feb 2020 14:30:14 +0000 Subject: COMPMID-2967 Add support for QASYMM8_SIGNED in CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel Signed-off-by: Luca Foschiani Change-Id: I4f7918630ea95fc28597b3d7b189f3d8fd35aef8 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2890 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp | 107 ++++++++++++++++----- 1 file changed, 82 insertions(+), 25 deletions(-) (limited to 'src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp') diff --git a/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp b/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp index a1b7b23c62..e86f303ff4 100644 --- a/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp +++ b/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp @@ -24,25 +24,36 @@ #include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h" #include "arm_compute/core/CL/ICLTensor.h" +#include "arm_compute/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ScaleKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFloatKernel.h" -#include "arm_compute/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.h" #include "support/MemorySupport.h" namespace arm_compute { void CLGEMMLowpQuantizeDownInt32ToUint8Scale::configure(const ICLTensor *input, const ICLTensor *bias, ICLTensor *output, int result_offset, int result_mult_int, int result_shift, int min, int max) { - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(input, bias, output, result_offset, result_mult_int, result_shift, min, max); + GEMMLowpOutputStageInfo info = GEMMLowpOutputStageInfo(); + info.gemmlowp_offset = result_offset; + info.gemmlowp_multiplier = result_mult_int; + info.gemmlowp_shift = result_shift; + info.gemmlowp_min_bound = min; + info.gemmlowp_max_bound = max; + + auto k = arm_compute::support::cpp14::make_unique(); + k->configure(input, bias, output, &info); _kernel = std::move(k); } Status CLGEMMLowpQuantizeDownInt32ToUint8Scale::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min, int max) { - return CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::validate(input, bias, output, min, max); + GEMMLowpOutputStageInfo info = GEMMLowpOutputStageInfo(); + info.gemmlowp_min_bound = min; + info.gemmlowp_max_bound = max; + + return CLGEMMLowpQuantizeDownInt32ScaleKernel::validate(input, bias, output, &info); } void CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::configure(const ICLTensor *input, const ICLTensor *bias, ICLTensor *output, @@ -108,45 +119,91 @@ Status CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint::validate(const ITens void CLGEMMLowpOutputStage::configure(const ICLTensor *input, const ICLTensor *bias, ICLTensor *output, const GEMMLowpOutputStageInfo &info) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_ERROR_ON(info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT); - switch(info.output_data_type) + switch(info.type) { - case DataType::QASYMM8: + case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT: { - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); - _kernel = std::move(k); + switch(info.output_data_type) + { + case DataType::QASYMM8: + { + auto k = arm_compute::support::cpp14::make_unique(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + case DataType::QASYMM8_SIGNED: + { + auto k = arm_compute::support::cpp14::make_unique(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + default: + ARM_COMPUTE_ERROR("Unsupported output data type."); + } break; } - case DataType::QASYMM8_SIGNED: + case GEMMLowpOutputStageType::QUANTIZE_DOWN: { - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); - _kernel = std::move(k); + switch(info.output_data_type) + { + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + { + auto k = arm_compute::support::cpp14::make_unique(); + k->configure(input, bias, output, &info); + _kernel = std::move(k); + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported output data type."); + break; + } + } break; } default: - ARM_COMPUTE_ERROR("Unsupported output data type."); + ARM_COMPUTE_ERROR("Unsupported GEMMLowpOutputStage type."); } - } Status CLGEMMLowpOutputStage::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo &info) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED); - ARM_COMPUTE_RETURN_ERROR_ON(info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT); - switch(output->data_type()) + switch(info.type) { - case DataType::QASYMM8: - return CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); - case DataType::QASYMM8_SIGNED: - return CLGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT: + { + switch(output->data_type()) + { + case DataType::QASYMM8: + return CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + case DataType::QASYMM8_SIGNED: + return CLGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + default: + return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type."); + } + } + case GEMMLowpOutputStageType::QUANTIZE_DOWN: + { + switch(output->data_type()) + { + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + { + return CLGEMMLowpQuantizeDownInt32ScaleKernel::validate(input, bias, output, &info); + } + default: + return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type."); + } + } default: - return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type."); + return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported GEMMLowpOutputStage type."); } - } -} // namespace arm_compute +} // namespace arm_compute \ No newline at end of file -- cgit v1.2.1