aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp70
1 files changed, 61 insertions, 9 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 65d800cb0c..4e43d04446 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -164,6 +164,23 @@ public:
arm_gemm::GemmArgs args, const GEMMInfo &gemm_info,
MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
+ /** Set requantization shifts to be used
+ *
+ * @param[in] shifts Requantization shifts
+ *
+ * @return Pointer to the shift data
+ */
+ /** Set requantization data to be used
+ *
+ *
+ * @param shifts Requantization shifts
+ * @param multipliers Requantization multipliers
+ *
+ * @return A tuple with the pointers to the shift and multiplier data respectively
+ */
+ std::tuple<const int32_t *, const int32_t *> set_requantize_data(const std::vector<int32_t> &shifts,
+ const std::vector<int32_t> &multipliers);
+
// Inherited methods overridden:
void run() override;
void prepare() override;
@@ -212,9 +229,24 @@ private:
FallbackTransform<TypeInput, TypeOutput> _weights_transform{};
/** GEMM kernel description */
arm_gemm::KernelDescription _kernel_info{};
+ /** Per channel quantization shifts */
+ std::vector<int32_t> _shifts{};
+ /** Per channel quantization multipliers */
+ std::vector<int32_t> _multipliers{};
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
+std::tuple<const int32_t *, const int32_t *> Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
+ const std::vector<int32_t> &multipliers)
+{
+ _multipliers = multipliers;
+ _shifts = shifts;
+ std::transform(_shifts.begin(), _shifts.end(), _shifts.begin(),
+ std::bind1st(std::multiplies<int32_t>(), -1));
+ return std::make_tuple(_shifts.data(), _multipliers.data());
+}
+
+template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
arm_gemm::GemmArgs args, const GEMMInfo &gemm_info,
MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
@@ -435,18 +467,32 @@ void create_arm_gemm_quant(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &a
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, gemm_info.pretranpose_B());
+ // Create arm_gemm fallback
+ auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
+
// Configure requantization info
const int32_t a_offset = -a->info()->quantization_info().uniform().offset;
const int32_t b_offset = -b->info()->quantization_info().uniform().offset;
const GEMMLowpOutputStageInfo os_info = gemm_info.gemmlowp_output_stage();
- const arm_gemm::ARequantizeLayer32 gemm_requant_info(nullptr, 0,
- a_offset, b_offset, os_info.gemmlowp_offset,
- -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier,
- os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
+ arm_gemm::Requantize32 gemm_requant_info{};
+ if(os_info.gemmlowp_shifts.size() > 1)
+ {
+ const auto requantize_data = fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers);
+ gemm_requant_info = arm_gemm::Requantize32(nullptr, 0,
+ a_offset, b_offset, os_info.gemmlowp_offset,
+ std::get<0>(requantize_data), std::get<1>(requantize_data),
+ os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
+ }
+ else
+ {
+ gemm_requant_info = arm_gemm::Requantize32(nullptr, 0,
+ a_offset, b_offset, os_info.gemmlowp_offset,
+ -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier,
+ os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
+ }
- // Create arm_gemm fallback
- auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::ARequantizeLayer32>>();
+ // Configure fallback
fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info);
arm_gemm = std::move(fallback);
}
@@ -484,7 +530,6 @@ Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8_SIGNED && d->data_type() != DataType::S32, "Only S32 output supported for QASYMM8_SIGNED input");
return Status{};
}
@@ -524,7 +569,14 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const
break;
case DataType::S8:
case DataType::QASYMM8_SIGNED:
- create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+ if(d->info()->data_type() == DataType::S32)
+ {
+ create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+ }
+ else
+ {
+ create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+ }
break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC