diff options
author | Luca Foschiani <luca.foschiani@arm.com> | 2020-02-26 14:30:14 +0000 |
---|---|---|
committer | Luca Foschiani <luca.foschiani@arm.com> | 2020-03-23 17:16:22 +0000 |
commit | 689c968239180eda4263e34c3d450093d4a0450d (patch) | |
tree | 9ecc01efac6f59f05c862bf32d6e1ee3ce5a69ed /src/runtime/CL | |
parent | 3bb75d60ced0cefa503e90f5d0d8cfe3db3f8637 (diff) | |
download | ComputeLibrary-689c968239180eda4263e34c3d450093d4a0450d.tar.gz |
COMPMID-2967 Add support for QASYMM8_SIGNED in CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel
Signed-off-by: Luca Foschiani <luca.foschiani@arm.com>
Change-Id: I4f7918630ea95fc28597b3d7b189f3d8fd35aef8
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2890
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL')
-rw-r--r-- | src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp | 107 |
1 files changed, 82 insertions, 25 deletions
diff --git a/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp b/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp index a1b7b23c62..e86f303ff4 100644 --- a/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp +++ b/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp @@ -24,25 +24,36 @@ #include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h" #include "arm_compute/core/CL/ICLTensor.h" +#include "arm_compute/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ScaleKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFloatKernel.h" -#include "arm_compute/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.h" #include "support/MemorySupport.h" namespace arm_compute { void CLGEMMLowpQuantizeDownInt32ToUint8Scale::configure(const ICLTensor *input, const ICLTensor *bias, ICLTensor *output, int result_offset, int result_mult_int, int result_shift, int min, int max) { - auto k = arm_compute::support::cpp14::make_unique<CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel>(); - k->configure(input, bias, output, result_offset, result_mult_int, result_shift, min, max); + GEMMLowpOutputStageInfo info = GEMMLowpOutputStageInfo(); + info.gemmlowp_offset = result_offset; + info.gemmlowp_multiplier = result_mult_int; + info.gemmlowp_shift = result_shift; + info.gemmlowp_min_bound = min; + info.gemmlowp_max_bound = max; + + auto k = arm_compute::support::cpp14::make_unique<CLGEMMLowpQuantizeDownInt32ScaleKernel>(); + k->configure(input, bias, output, &info); _kernel = std::move(k); } Status CLGEMMLowpQuantizeDownInt32ToUint8Scale::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min, int max) { - return CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::validate(input, bias, output, min, max); + GEMMLowpOutputStageInfo info = GEMMLowpOutputStageInfo(); + info.gemmlowp_min_bound = min; + info.gemmlowp_max_bound = max; + + return CLGEMMLowpQuantizeDownInt32ScaleKernel::validate(input, bias, output, &info); } void CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::configure(const ICLTensor *input, const ICLTensor *bias, ICLTensor *output, @@ -108,45 +119,91 @@ Status CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint::validate(const ITens void CLGEMMLowpOutputStage::configure(const ICLTensor *input, const ICLTensor *bias, ICLTensor *output, const GEMMLowpOutputStageInfo &info) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_ERROR_ON(info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT); - switch(info.output_data_type) + switch(info.type) { - case DataType::QASYMM8: + case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT: { - auto k = arm_compute::support::cpp14::make_unique<CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel>(); - k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); - _kernel = std::move(k); + switch(info.output_data_type) + { + case DataType::QASYMM8: + { + auto k = arm_compute::support::cpp14::make_unique<CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel>(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + case DataType::QASYMM8_SIGNED: + { + auto k = arm_compute::support::cpp14::make_unique<CLGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel>(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + default: + ARM_COMPUTE_ERROR("Unsupported output data type."); + } break; } - case DataType::QASYMM8_SIGNED: + case GEMMLowpOutputStageType::QUANTIZE_DOWN: { - auto k = arm_compute::support::cpp14::make_unique<CLGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel>(); - k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); - _kernel = std::move(k); + switch(info.output_data_type) + { + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + { + auto k = arm_compute::support::cpp14::make_unique<CLGEMMLowpQuantizeDownInt32ScaleKernel>(); + k->configure(input, bias, output, &info); + _kernel = std::move(k); + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported output data type."); + break; + } + } break; } default: - ARM_COMPUTE_ERROR("Unsupported output data type."); + ARM_COMPUTE_ERROR("Unsupported GEMMLowpOutputStage type."); } - } Status CLGEMMLowpOutputStage::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo &info) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED); - ARM_COMPUTE_RETURN_ERROR_ON(info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT); - switch(output->data_type()) + switch(info.type) { - case DataType::QASYMM8: - return CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); - case DataType::QASYMM8_SIGNED: - return CLGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT: + { + switch(output->data_type()) + { + case DataType::QASYMM8: + return CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + case DataType::QASYMM8_SIGNED: + return CLGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + default: + return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type."); + } + } + case GEMMLowpOutputStageType::QUANTIZE_DOWN: + { + switch(output->data_type()) + { + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + { + return CLGEMMLowpQuantizeDownInt32ScaleKernel::validate(input, bias, output, &info); + } + default: + return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type."); + } + } default: - return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type."); + return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported GEMMLowpOutputStage type."); } - } -} // namespace arm_compute +} // namespace arm_compute
\ No newline at end of file |