diff options
author | Luca Foschiani <luca.foschiani@arm.com> | 2020-02-13 15:07:36 +0000 |
---|---|---|
committer | Luca Foschiani <luca.foschiani@arm.com> | 2020-03-26 12:31:14 +0000 |
commit | 4b869532f8b2aa7f02aa55c4f4813e994ea2df68 (patch) | |
tree | 318506b8c5933165b1fe6d054fc7beec79c6a0f5 /src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp | |
parent | 1b14c75c0d591c4abe4d2d41b7e4e165fbf58382 (diff) | |
download | ComputeLibrary-4b869532f8b2aa7f02aa55c4f4813e994ea2df68.tar.gz |
COMPMID-2966 Add support for QASYMM8_SIGNED in NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel
Signed-off-by: Luca Foschiani <luca.foschiani@arm.com>
Change-Id: Ia8692f8fda16fa3b73f343e4b5b1b55e14403225
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2750
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp | 147 |
1 files changed, 87 insertions, 60 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp index 42d2ffce58..43ca7b3fbb 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp @@ -24,10 +24,10 @@ #include "arm_compute/runtime/NEON/functions/NEGEMMLowpOutputStage.h" #include "arm_compute/core/ITensor.h" +#include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ScaleKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.h" -#include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.h" #include "arm_compute/core/Validate.h" #include "support/MemorySupport.h" @@ -35,14 +35,25 @@ namespace arm_compute { void NEGEMMLowpQuantizeDownInt32ToUint8Scale::configure(const ITensor *input, const ITensor *bias, ITensor *output, int result_offset, int result_mult_int, int result_shift, int min, int max) { - auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel>(); - k->configure(input, bias, output, result_offset, result_mult_int, result_shift, min, max); + GEMMLowpOutputStageInfo info = GEMMLowpOutputStageInfo(); + info.gemmlowp_offset = result_offset; + info.gemmlowp_multiplier = result_mult_int; + info.gemmlowp_shift = result_shift; + info.gemmlowp_min_bound = min; + info.gemmlowp_max_bound = max; + + auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ScaleKernel>(); + k->configure(input, bias, output, &info); _kernel = std::move(k); } Status NEGEMMLowpQuantizeDownInt32ToUint8Scale::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min, int max) { - return NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::validate(input, bias, output, min, max); + GEMMLowpOutputStageInfo info = GEMMLowpOutputStageInfo(); + info.gemmlowp_min_bound = min; + info.gemmlowp_max_bound = max; + + return NEGEMMLowpQuantizeDownInt32ScaleKernel::validate(input, bias, output, &info); } void NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::configure(const ITensor *input, const ITensor *bias, ITensor *output, int result_fixedpoint_multiplier, int result_shift, @@ -89,53 +100,63 @@ void NEGEMMLowpOutputStage::configure(const ITensor *input, const ITensor *bias, ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); ARM_COMPUTE_ERROR_THROW_ON(NEGEMMLowpOutputStage::validate(input->info(), bias != nullptr ? bias->info() : nullptr, output->info(), info)); - if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN) + switch(info.type) { - switch(output->info()->data_type()) + case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT: { - case DataType::QASYMM8: + switch(info.output_data_type) { - auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel>(); - k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); - _kernel = std::move(k); - break; + case DataType::QASYMM8: + { + auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel>(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + case DataType::QASYMM8_SIGNED: + { + auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel>(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + case DataType::QSYMM16: + { + auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel>(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported output data type."); + break; + } } - default: - ARM_COMPUTE_ERROR("Unsupported output data type."); + break; } - } - else if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) - { - switch(output->info()->data_type()) + case GEMMLowpOutputStageType::QUANTIZE_DOWN: { - case DataType::QASYMM8: - { - auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel>(); - k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); - _kernel = std::move(k); - break; - } - case DataType::QASYMM8_SIGNED: - { - auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel>(); - k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); - _kernel = std::move(k); - break; - } - case DataType::QSYMM16: + switch(info.output_data_type) { - auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel>(); - k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_min_bound, info.gemmlowp_max_bound); - _kernel = std::move(k); - break; + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + { + auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ScaleKernel>(); + k->configure(input, bias, output, &info); + _kernel = std::move(k); + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported output data type."); + break; + } } - default: - ARM_COMPUTE_ERROR("Unsupported output data type."); + break; } - } - else - { - ARM_COMPUTE_ERROR("Unsupported output stage quantization type."); + default: + ARM_COMPUTE_ERROR("Unsupported GEMMLowpOutputStage type."); } } @@ -147,29 +168,35 @@ Status NEGEMMLowpOutputStage::validate(const ITensorInfo *input, const ITensorIn ARM_COMPUTE_RETURN_ERROR_ON((info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN) && (info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)); - if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN) + switch(info.type) { - switch(output->data_type()) + case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT: { - case DataType::QASYMM8: - return NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); - default: - return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type."); + switch(output->data_type()) + { + case DataType::QASYMM8: + return NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + case DataType::QASYMM8_SIGNED: + return NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + case DataType::QSYMM16: + return NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + default: + return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type."); + } } - } - else - { - switch(output->data_type()) + case GEMMLowpOutputStageType::QUANTIZE_DOWN: { - case DataType::QASYMM8: - return NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); - case DataType::QASYMM8_SIGNED: - return NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); - case DataType::QSYMM16: - return NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); - default: - return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type."); + switch(output->data_type()) + { + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + return NEGEMMLowpQuantizeDownInt32ScaleKernel::validate(input, bias, output, &info); + default: + return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type."); + } } + default: + return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported GEMMLowpOutputStage type."); } } } // namespace arm_compute |