diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-11-14 14:31:44 +0000 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2020-01-23 14:57:14 +0000 |
commit | 71ac9037abce1c6c4af42c485d5395dd6fd79a5a (patch) | |
tree | 7097d94d7760bf8e172fc4c3725a2eff90bea9a1 /src/runtime/NEON | |
parent | 19bd412fd044197726dbd8c756dbd74a9e33fd2b (diff) | |
download | ComputeLibrary-71ac9037abce1c6c4af42c485d5395dd6fd79a5a.tar.gz |
COMPMID-2923 Integrate arm_gemm per channel quantization
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Change-Id: I8667e75843fdd6ac75bd8272a86a348b830da28d
Reviewed-on: https://review.mlplatform.org/c/2548
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/NEON')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 70 | ||||
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp | 4 |
2 files changed, 63 insertions, 11 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index 65d800cb0c..4e43d04446 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -164,6 +164,23 @@ public: arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {}); + /** Set requantization shifts to be used + * + * @param[in] shifts Requantization shifts + * + * @return Pointer to the shift data + */ + /** Set requantization data to be used + * + * + * @param shifts Requantization shifts + * @param multipliers Requantization multipliers + * + * @return A tuple with the pointers to the shift and multiplier data respectively + */ + std::tuple<const int32_t *, const int32_t *> set_requantize_data(const std::vector<int32_t> &shifts, + const std::vector<int32_t> &multipliers); + // Inherited methods overridden: void run() override; void prepare() override; @@ -212,9 +229,24 @@ private: FallbackTransform<TypeInput, TypeOutput> _weights_transform{}; /** GEMM kernel description */ arm_gemm::KernelDescription _kernel_info{}; + /** Per channel quantization shifts */ + std::vector<int32_t> _shifts{}; + /** Per channel quantization multipliers */ + std::vector<int32_t> _multipliers{}; }; template <typename TypeInput, typename TypeOutput, class OutputStage> +std::tuple<const int32_t *, const int32_t *> Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, + const std::vector<int32_t> &multipliers) +{ + _multipliers = multipliers; + _shifts = shifts; + std::transform(_shifts.begin(), _shifts.end(), _shifts.begin(), + std::bind1st(std::multiplies<int32_t>(), -1)); + return std::make_tuple(_shifts.data(), _multipliers.data()); +} + +template <typename TypeInput, typename TypeOutput, class OutputStage> void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os) @@ -435,18 +467,32 @@ void create_arm_gemm_quant(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &a arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, gemm_info.pretranpose_B()); + // Create arm_gemm fallback + auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>(); + // Configure requantization info const int32_t a_offset = -a->info()->quantization_info().uniform().offset; const int32_t b_offset = -b->info()->quantization_info().uniform().offset; const GEMMLowpOutputStageInfo os_info = gemm_info.gemmlowp_output_stage(); - const arm_gemm::ARequantizeLayer32 gemm_requant_info(nullptr, 0, - a_offset, b_offset, os_info.gemmlowp_offset, - -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier, - os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); + arm_gemm::Requantize32 gemm_requant_info{}; + if(os_info.gemmlowp_shifts.size() > 1) + { + const auto requantize_data = fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers); + gemm_requant_info = arm_gemm::Requantize32(nullptr, 0, + a_offset, b_offset, os_info.gemmlowp_offset, + std::get<0>(requantize_data), std::get<1>(requantize_data), + os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); + } + else + { + gemm_requant_info = arm_gemm::Requantize32(nullptr, 0, + a_offset, b_offset, os_info.gemmlowp_offset, + -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier, + os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); + } - // Create arm_gemm fallback - auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::ARequantizeLayer32>>(); + // Configure fallback fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info); arm_gemm = std::move(fallback); } @@ -484,7 +530,6 @@ Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8_SIGNED && d->data_type() != DataType::S32, "Only S32 output supported for QASYMM8_SIGNED input"); return Status{}; } @@ -524,7 +569,14 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const break; case DataType::S8: case DataType::QASYMM8_SIGNED: - create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + if(d->info()->data_type() == DataType::S32) + { + create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + } + else + { + create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + } break; #endif /* __aarch64__ */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp index 440f043527..38481afe88 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -119,7 +119,7 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, case DataType::U8: case DataType::S8: { - if(a_to_use->info()->data_type() == DataType::QASYMM8 && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + if(is_data_type_quantized_asymmetric(a_to_use->info()->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) { _asm_glue.configure(a_to_use, b, c, output, gemm_info); _fused_assembly_path = _asm_glue.is_configured(); |