From 71ac9037abce1c6c4af42c485d5395dd6fd79a5a Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Thu, 14 Nov 2019 14:31:44 +0000 Subject: COMPMID-2923 Integrate arm_gemm per channel quantization Signed-off-by: Michalis Spyrou Change-Id: I8667e75843fdd6ac75bd8272a86a348b830da28d Reviewed-on: https://review.mlplatform.org/c/2548 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- .../NEON/functions/NEGEMMAssemblyDispatch.cpp | 70 +++++++++++++++++++--- 1 file changed, 61 insertions(+), 9 deletions(-) (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp') diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index 65d800cb0c..4e43d04446 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -164,6 +164,23 @@ public: arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {}); + /** Set requantization shifts to be used + * + * @param[in] shifts Requantization shifts + * + * @return Pointer to the shift data + */ + /** Set requantization data to be used + * + * + * @param shifts Requantization shifts + * @param multipliers Requantization multipliers + * + * @return A tuple with the pointers to the shift and multiplier data respectively + */ + std::tuple set_requantize_data(const std::vector &shifts, + const std::vector &multipliers); + // Inherited methods overridden: void run() override; void prepare() override; @@ -212,8 +229,23 @@ private: FallbackTransform _weights_transform{}; /** GEMM kernel description */ arm_gemm::KernelDescription _kernel_info{}; + /** Per channel quantization shifts */ + std::vector _shifts{}; + /** Per channel quantization multipliers */ + std::vector _multipliers{}; }; +template +std::tuple Fallback::set_requantize_data(const std::vector &shifts, + const std::vector &multipliers) +{ + _multipliers = multipliers; + _shifts = shifts; + std::transform(_shifts.begin(), _shifts.end(), _shifts.begin(), + std::bind1st(std::multiplies(), -1)); + return std::make_tuple(_shifts.data(), _multipliers.data()); +} + template void Fallback::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, @@ -435,18 +467,32 @@ void create_arm_gemm_quant(std::unique_ptr &a arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, gemm_info.pretranpose_B()); + // Create arm_gemm fallback + auto fallback = support::cpp14::make_unique>(); + // Configure requantization info const int32_t a_offset = -a->info()->quantization_info().uniform().offset; const int32_t b_offset = -b->info()->quantization_info().uniform().offset; const GEMMLowpOutputStageInfo os_info = gemm_info.gemmlowp_output_stage(); - const arm_gemm::ARequantizeLayer32 gemm_requant_info(nullptr, 0, - a_offset, b_offset, os_info.gemmlowp_offset, - -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier, - os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); + arm_gemm::Requantize32 gemm_requant_info{}; + if(os_info.gemmlowp_shifts.size() > 1) + { + const auto requantize_data = fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers); + gemm_requant_info = arm_gemm::Requantize32(nullptr, 0, + a_offset, b_offset, os_info.gemmlowp_offset, + std::get<0>(requantize_data), std::get<1>(requantize_data), + os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); + } + else + { + gemm_requant_info = arm_gemm::Requantize32(nullptr, 0, + a_offset, b_offset, os_info.gemmlowp_offset, + -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier, + os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); + } - // Create arm_gemm fallback - auto fallback = support::cpp14::make_unique>(); + // Configure fallback fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info); arm_gemm = std::move(fallback); } @@ -484,7 +530,6 @@ Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8_SIGNED && d->data_type() != DataType::S32, "Only S32 output supported for QASYMM8_SIGNED input"); return Status{}; } @@ -524,7 +569,14 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const break; case DataType::S8: case DataType::QASYMM8_SIGNED: - create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + if(d->info()->data_type() == DataType::S32) + { + create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + } + else + { + create_arm_gemm_quant(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + } break; #endif /* __aarch64__ */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -- cgit v1.2.1