diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2020-01-08 11:33:44 +0000 |
---|---|---|
committer | Michele Di Giorgio <michele.digiorgio@arm.com> | 2020-01-14 10:55:58 +0000 |
commit | 9c700378f2227cb9d51455ed4a5086daaac5532a (patch) | |
tree | 53eb4acf5a9226e941e332b93db8ef260fb2d42b /src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp | |
parent | ab709a0ea0f7f5c8e02c315afffc300e09c783a8 (diff) | |
download | ComputeLibrary-9c700378f2227cb9d51455ed4a5086daaac5532a.tar.gz |
COMPMID-2769: Add support for QASYMM8_SIGNED in NEFullyConnectedLayer
Change-Id: I4c35c522375ae5a5de78716e079ebb9ffad15956
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2581
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp | 93 |
1 files changed, 92 insertions, 1 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp index 3ef9351b78..465dddaac2 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -28,6 +28,7 @@ #include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.h" +#include "arm_compute/core/Validate.h" #include "support/ToolchainSupport.h" namespace arm_compute @@ -81,4 +82,94 @@ Status NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint::validate(const ITens { return NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel::validate(input, bias, output, min, max); } + +void NEGEMMLowpOutputStage::configure(const ITensor *input, const ITensor *bias, ITensor *output, const GEMMLowpOutputStageInfo &info) +{ + // Perform validate step + ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); + ARM_COMPUTE_ERROR_THROW_ON(NEGEMMLowpOutputStage::validate(input->info(), bias != nullptr ? bias->info() : nullptr, output->info(), info)); + + if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN) + { + switch(output->info()->data_type()) + { + case DataType::QASYMM8: + { + auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel>(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + default: + ARM_COMPUTE_ERROR("Unsupported output data type."); + } + } + else if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + { + switch(output->info()->data_type()) + { + case DataType::QASYMM8: + { + auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel>(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + case DataType::QASYMM8_SIGNED: + { + auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel>(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + case DataType::QSYMM16: + { + auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel>(); + k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + _kernel = std::move(k); + break; + } + default: + ARM_COMPUTE_ERROR("Unsupported output data type."); + } + } + else + { + ARM_COMPUTE_ERROR("Unsupported output stage quantization type."); + } +} + +Status NEGEMMLowpOutputStage::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo &info) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::UNKNOWN, "NEGEMMLowpQuantizeDownScaleByFixedPoint cannot be used with UNKNOWN output data type."); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16); + + ARM_COMPUTE_RETURN_ERROR_ON((info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN) && (info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)); + + if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN) + { + switch(output->data_type()) + { + case DataType::QASYMM8: + return NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + default: + return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type."); + } + } + else + { + switch(output->data_type()) + { + case DataType::QASYMM8: + return NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + case DataType::QASYMM8_SIGNED: + return NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + case DataType::QSYMM16: + return NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound); + default: + return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type."); + } + } +} } // namespace arm_compute |