author      morgolock <pablo.tello@arm.com>    2020-08-10 16:44:18 +0100
committer   Pablo Marquez <pablo.tello@arm.com>    2020-08-19 10:53:52 +0000
commit      0bc80daf319ea3219ca6a6fa200118dc859ee460
tree        32d9294334247d62b20b347ffb01e37bd1d5edd1
parent      c58f0ad7ac6d91f2789a78049d3cec7355113f9a
download    ComputeLibrary-0bc80daf319ea3219ca6a6fa200118dc859ee460.tar.gz
MLCE-229: Support for negative shifts in asm kernels
Change-Id: I2c5e98aae7698963f106d7423df0e65cd00ee2a9
Signed-off-by: morgolock <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3710
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sheri Zhang <sheri.zhang@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
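In the Compute Library convention a positive result shift is a right shift and a negative one a left shift; the guard removed below in NEGEMMLowpMatrixMultiplyCore.cpp used to reject the negative case outright. This patch instead splits each signed shift into a non-negative left-shift amount and a non-positive right-shift amount for arm_gemm. The following standalone sketch (hypothetical names, mirroring set_requantize_data further down) illustrates that split; it is not part of the patch itself.

#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch only: split signed ACL result shifts into the two arrays consumed by the
// assembly path. In arm_gemm, left-shift amounts are stored as values >= 0 and
// right-shift amounts as values <= 0 (the form vrshlq_s32 expects).
void split_requant_shifts(const std::vector<int32_t> &acl_shifts,
                          std::vector<int32_t>       &left_shifts,
                          std::vector<int32_t>       &right_shifts,
                          bool                       &need_left_shift)
{
    need_left_shift = false;
    for(const int32_t s : acl_shifts)
    {
        left_shifts.push_back(std::max(-s, int32_t(0)));  // non-zero only for negative ACL shifts
        right_shifts.push_back(std::min(-s, int32_t(0)));
        need_left_shift |= (s < 0);
    }
}

For example, ACL shifts {2, -3} become left shifts {0, 3} and right shifts {-2, 0}, with need_left_shift set because of the second channel.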
Diffstat (limited to 'src')
-rw-r--r--   src/core/NEON/kernels/arm_gemm/quantized.cpp                   126
-rw-r--r--   src/core/NEON/kernels/assembly/arm_gemm.hpp                     41
-rw-r--r--   src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp           31
-rw-r--r--   src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp     29
4 files changed, 152 insertions, 75 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp
index e50dca7f1f..201bd9dc2c 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.cpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp
@@ -55,15 +55,16 @@ namespace {
* column is set up in any case (and it is hoped that the compiler can elide
* the needless movs in the per-layer case).
*/
-template<bool do_shift_correction, bool per_channel>
+template<bool do_shift_correction, bool per_channel, bool do_left_shift>
void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height,
const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
const int32_t *row_bias, const int32_t *col_bias, const unsigned int start_col) {
- const int32x4_t v_mul = vdupq_n_s32(qp.per_layer_mul);
- const int32x4_t v_shift = vdupq_n_s32(qp.per_layer_shift);
- const int32x4_t v_minval = vdupq_n_s32(qp.minval);
- const int32x4_t v_maxval = vdupq_n_s32(qp.maxval);
- const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset);
+ const int32x4_t v_mul = vdupq_n_s32(qp.per_layer_mul);
+ const int32x4_t v_right_shift = vdupq_n_s32(qp.per_layer_right_shift);
+ const int32x4_t v_left_shift = vdupq_n_s32(qp.per_layer_left_shift);
+ const int32x4_t v_minval = vdupq_n_s32(qp.minval);
+ const int32x4_t v_maxval = vdupq_n_s32(qp.maxval);
+ const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset);
/* To make sure we have plenty of accumulators, compute two rows at a
* time. If the number of rows is odd, compute the bottom row twice to
@@ -77,8 +78,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
unsigned int odds=(width % 4);
const int32_t *colptr = col_bias;
- const int32_t *perch_mul_ptr = qp.per_channel_muls + start_col;
- const int32_t *perch_shift_ptr = qp.per_channel_shifts + start_col;
+ const int32_t *perch_mul_ptr = qp.per_channel_muls + start_col;
+ const int32_t *perch_shift_ptr = qp.per_channel_right_shifts + start_col;
+ const int32_t *perch_shiftl_ptr = qp.per_channel_left_shifts + start_col;
const int32_t *in_ptr = input + (row * in_stride);
int8_t *out_ptr = output + (row * out_stride);
@@ -112,6 +114,11 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
int32x4_t v_shf2;
int32x4_t v_shf3;
+ int32x4_t v_shf0l;
+ int32x4_t v_shf1l;
+ int32x4_t v_shf2l;
+ int32x4_t v_shf3l;
+
if (per_channel) {
v_mul0 = vld1q_s32(perch_mul_ptr);
v_mul1 = vld1q_s32(perch_mul_ptr + 4);
@@ -124,9 +131,17 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
v_shf2 = vld1q_s32(perch_shift_ptr + 8);
v_shf3 = vld1q_s32(perch_shift_ptr + 12);
perch_shift_ptr += 16;
+
+ if (do_left_shift) {
+ v_shf0l = vld1q_s32(perch_shiftl_ptr);
+ v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
+ v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
+ v_shf3l = vld1q_s32(perch_shiftl_ptr + 12);
+ }
} else {
v_mul0=v_mul1=v_mul2=v_mul3=v_mul;
- v_shf0=v_shf1=v_shf2=v_shf3=v_shift;
+ v_shf0=v_shf1=v_shf2=v_shf3=v_right_shift;
+ v_shf0l=v_shf1l=v_shf2l=v_shf3l=v_left_shift;
}
// Load column pointers
@@ -171,7 +186,22 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
v_in12 = vaddq_s32(v_in12, v_col2);
v_in13 = vaddq_s32(v_in13, v_col3);
- // Quantize - start with multiply
+ // Quantize
+
+ // If a left shift is needed it needs to happen first.
+ if (do_left_shift) {
+ v_in00 = vrshlq_s32(v_in00, v_shf0l);
+ v_in01 = vrshlq_s32(v_in01, v_shf1l);
+ v_in02 = vrshlq_s32(v_in02, v_shf2l);
+ v_in03 = vrshlq_s32(v_in03, v_shf3l);
+
+ v_in10 = vrshlq_s32(v_in10, v_shf0l);
+ v_in11 = vrshlq_s32(v_in11, v_shf1l);
+ v_in12 = vrshlq_s32(v_in12, v_shf2l);
+ v_in13 = vrshlq_s32(v_in13, v_shf3l);
+ }
+
+ // Multiply
v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
@@ -273,6 +303,7 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
while (regs--) {
int32x4_t v_mul0;
int32x4_t v_shf0;
+ int32x4_t v_shf0l;
if (per_channel) {
v_mul0 = vld1q_s32(perch_mul_ptr);
@@ -280,9 +311,15 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
v_shf0 = vld1q_s32(perch_shift_ptr);
perch_shift_ptr += 4;
+
+ if (do_left_shift) {
+ v_shf0l = vld1q_s32(perch_shiftl_ptr);
+ perch_shiftl_ptr += 4;
+ }
} else {
v_mul0=v_mul;
- v_shf0=v_shift;
+ v_shf0=v_right_shift;
+ v_shf0l=v_left_shift;
}
// Load column pointers
int32x4_t v_col0 = vld1q_s32(colptr);
@@ -306,7 +343,14 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
v_in10 = vaddq_s32(v_in10, v_col0);
- // Quantize - start with multiply
+ // Quantize - start with (optional) left shift
+ if (do_left_shift) {
+ v_in00 = vrshlq_s32(v_in00, v_shf0l);
+
+ v_in10 = vrshlq_s32(v_in10, v_shf0l);
+ }
+
+ // Then multiply
v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
@@ -358,10 +402,12 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
int32x4_t v_in10 = vdupq_n_s32(0);
int32x4_t v_mul0 = vdupq_n_s32(0);
int32x4_t v_shf0 = vdupq_n_s32(0);
+ int32x4_t v_shf0l = vdupq_n_s32(0);
if (!per_channel) {
v_mul0 = v_mul;
- v_shf0 = v_shift;
+ v_shf0 = v_right_shift;
+ v_shf0l = v_left_shift;
}
do {
@@ -371,6 +417,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
if (per_channel) {
v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0);
v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0);
+ if (do_left_shift) {
+ v_shf0l = vld1q_lane_s32(perch_shiftl_ptr, v_shf0l, 0);
+ }
}
if (odds == 1) { break; }
@@ -380,6 +429,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
if (per_channel) {
v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1);
v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1);
+ if (do_left_shift) {
+ v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 1, v_shf0l, 1);
+ }
}
if (odds == 2) { break; }
@@ -389,6 +441,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
if (per_channel) {
v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2);
v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2);
+ if (do_left_shift) {
+ v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 2, v_shf0l, 2);
+ }
}
} while (0);
@@ -402,7 +457,14 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
v_in10 = vaddq_s32(v_in10, v_col0);
- // Quantize - start with multiply
+ // Quantize - start with (optional) left shift
+ if (do_left_shift) {
+ v_in00 = vrshlq_s32(v_in00, v_shf0l);
+
+ v_in10 = vrshlq_s32(v_in10, v_shf0l);
+ }
+
+ // Then multiply
v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
@@ -464,19 +526,39 @@ void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned in
const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col) {
if (qp.per_channel_requant) {
if (qp.minval >= qp.c_offset) {
- requantize_block_32_int<false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
- reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ if (qp.per_channel_left_shifts) {
+ requantize_block_32_int<false, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ } else {
+ requantize_block_32_int<false, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ }
} else {
- requantize_block_32_int<true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
- reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ if (qp.per_channel_left_shifts) {
+ requantize_block_32_int<true, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ } else {
+ requantize_block_32_int<true, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ }
}
} else {
if (qp.minval >= qp.c_offset) {
- requantize_block_32_int<false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
- reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ if (qp.per_layer_left_shift > 0) {
+ requantize_block_32_int<false, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ } else {
+ requantize_block_32_int<false, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ }
} else {
- requantize_block_32_int<true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
- reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ if (qp.per_layer_left_shift > 0) {
+ requantize_block_32_int<true, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ } else {
+ requantize_block_32_int<true, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ }
}
}
}
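For reference, the per-element arithmetic that the vectorized loops above implement can be written in scalar form roughly as follows. This is an illustrative model only (the function name is made up): it assumes the accumulator already contains the row/column bias and offset contributions, and it ignores the do_shift_correction variant and the saturation/overflow details of the NEON instructions.

#include <algorithm>
#include <cstdint>

// Rough scalar model of one requantized output element.
// left_shift is >= 0 and right_shift is <= 0, matching the arm_gemm convention.
int8_t requantize_element(int32_t acc, int32_t mul,
                          int32_t left_shift, int32_t right_shift,
                          int32_t c_offset, int32_t minval, int32_t maxval)
{
    // Optional left shift first (vrshlq_s32 with a positive shift count).
    int64_t v = static_cast<int64_t>(acc) << left_shift;

    // vqrdmulhq_s32: rounding doubling multiply returning the high 32 bits,
    // i.e. approximately round(v * mul / 2^31).
    int64_t prod = (2 * v * static_cast<int64_t>(mul) + (int64_t(1) << 31)) >> 32;

    // Rounding right shift (vrshlq_s32 with a negative shift count).
    const int32_t n = -right_shift;
    if(n > 0)
    {
        prod = (prod + (int64_t(1) << (n - 1))) >> n;
    }

    // Add the output offset and clamp to the output range before narrowing.
    const int32_t out = static_cast<int32_t>(prod) + c_offset;
    return static_cast<int8_t>(std::min(std::max(out, minval), maxval));
}

The left shift is applied to the raw accumulator rather than to the multiplied value, matching the order used by the kernel above.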
diff --git a/src/core/NEON/kernels/assembly/arm_gemm.hpp b/src/core/NEON/kernels/assembly/arm_gemm.hpp
index 58db511547..f6421c12ab 100644
--- a/src/core/NEON/kernels/assembly/arm_gemm.hpp
+++ b/src/core/NEON/kernels/assembly/arm_gemm.hpp
@@ -122,38 +122,41 @@ public:
struct Requantize32
{
public:
- const int32_t *bias = nullptr;
- size_t bias_multi_stride = 0;
- int32_t a_offset = 0;
- int32_t b_offset = 0;
- int32_t c_offset = 0;
- bool per_channel_requant = false;
- int32_t per_layer_shift = 0;
- int32_t per_layer_mul = 0;
- const int32_t *per_channel_shifts = nullptr;
- const int32_t *per_channel_muls = nullptr;
- int32_t minval = 0;
- int32_t maxval = 0;
+ const int32_t *bias = nullptr;
+ size_t bias_multi_stride = 0;
+ int32_t a_offset = 0;
+ int32_t b_offset = 0;
+ int32_t c_offset = 0;
+ bool per_channel_requant = false;
+ int32_t per_layer_left_shift = 0;
+ int32_t per_layer_right_shift = 0;
+ int32_t per_layer_mul = 0;
+ const int32_t *per_channel_left_shifts = nullptr;
+ const int32_t *per_channel_right_shifts = nullptr;
+ const int32_t *per_channel_muls = nullptr;
+ int32_t minval = 0;
+ int32_t maxval = 0;
Requantize32() = default;
// Constructor for per-tensor quantization
Requantize32(const int32_t *bias, size_t bias_multi_stride,
int32_t a_offset, int32_t b_offset, int32_t c_offset,
- int32_t requant_shift, int32_t requant_mul,
- int32_t minv, int32_t maxv)
- : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(false), per_layer_shift(requant_shift), per_layer_mul(requant_mul),
- minval(minv), maxval(maxv)
+ int32_t requant_shift, int32_t requant_mul, int32_t minv, int32_t maxv)
+ : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(false), per_layer_left_shift(std::max(requant_shift, int32_t(0))),
+ per_layer_right_shift(std::min(requant_shift, int32_t(0))), per_layer_mul(requant_mul), minval(minv), maxval(maxv)
{
}
// Constructor for per-channel quantization
Requantize32(const int32_t *bias, size_t bias_multi_stride,
int32_t a_offset, int32_t b_offset, int32_t c_offset,
- const int32_t *requant_shifts, const int32_t *requant_muls,
+ const int32_t *requant_left_shifts,
+ const int32_t *requant_right_shifts,
+ const int32_t *requant_muls,
int32_t minv, int32_t maxv)
- : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(true), per_channel_shifts(requant_shifts),
- per_channel_muls(requant_muls), minval(minv), maxval(maxv)
+ : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(true), per_channel_left_shifts(requant_left_shifts),
+ per_channel_right_shifts(requant_right_shifts), per_channel_muls(requant_muls), minval(minv), maxval(maxv)
{
}
};
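With the shift stored as two fields, the per-tensor constructor now routes a signed shift automatically: a positive requant_shift (a left shift in the arm_gemm convention) lands in per_layer_left_shift, a negative one in per_layer_right_shift. A hypothetical construction with placeholder values, assuming the header above is on the include path:

#include <cstdint>
#include "arm_gemm.hpp"  // the header shown above

arm_gemm::Requantize32 make_per_tensor_example()
{
    // Placeholder quantization parameters for illustration only.
    arm_gemm::Requantize32 qp(/* bias */ nullptr, /* bias_multi_stride */ 0,
                              /* a_offset */ 0, /* b_offset */ 0, /* c_offset */ 128,
                              /* requant_shift */ -8,   // negative: right shift by 8
                              /* requant_mul */ 1 << 30,
                              /* minv */ -128, /* maxv */ 127);
    // qp.per_layer_left_shift == 0 and qp.per_layer_right_shift == -8 after construction.
    return qp;
}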
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 3b9dde2bf7..eeea3a45ee 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -182,8 +182,8 @@ public:
*
* @return A tuple with the pointers to the shift and multiplier data respectively
*/
- std::tuple<const int32_t *, const int32_t *> set_requantize_data(const std::vector<int32_t> &shifts,
- const std::vector<int32_t> &multipliers);
+ std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> set_requantize_data(const std::vector<int32_t> &shifts,
+ const std::vector<int32_t> &multipliers);
// Inherited methods overridden:
void run() override;
@@ -235,18 +235,29 @@ private:
arm_gemm::KernelDescription _kernel_info{};
/** Per channel quantization shifts */
std::vector<int32_t> _shifts{};
+ std::vector<int32_t> right_shifts{};
+ std::vector<int32_t> left_shifts{};
/** Per channel quantization multipliers */
std::vector<int32_t> _multipliers{};
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
-std::tuple<const int32_t *, const int32_t *> Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
- const std::vector<int32_t> &multipliers)
+std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
+ const std::vector<int32_t> &multipliers)
{
- _multipliers = multipliers;
- _shifts = shifts;
- std::transform(_shifts.begin(), _shifts.end(), _shifts.begin(), std::negate<int32_t>());
- return std::make_tuple(_shifts.data(), _multipliers.data());
+ _multipliers = multipliers;
+ _shifts = shifts;
+ bool need_left = false;
+ for(const auto s : _shifts)
+ {
+ left_shifts.push_back(std::max(-s, int32_t(0)));
+ right_shifts.push_back(std::min(-s, int32_t(0)));
+ if(s < 0 && !need_left)
+ {
+ need_left = true;
+ }
+ }
+ return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data());
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -498,7 +509,9 @@ void create_arm_gemm_quant(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &a
const auto requantize_data = fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers);
gemm_requant_info = arm_gemm::Requantize32(nullptr, 0,
a_offset, b_offset, os_info.gemmlowp_offset,
- std::get<0>(requantize_data), std::get<1>(requantize_data),
+ (std::get<0>(requantize_data)) ? std::get<1>(requantize_data) : nullptr,
+ std::get<2>(requantize_data),
+ std::get<3>(requantize_data),
os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
}
else
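Note the tuple returned by set_requantize_data is ordered (need_left_shift, left_shifts, right_shifts, multipliers), and the left-shift pointer is forwarded only when a left shift is actually required. An equivalent C++17 sketch of this call site with named bindings (illustrative only, not the code in this tree):

// Illustrative C++17 alternative to the std::get<> indexing above.
const auto  requantize_data = fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers);
const auto &[need_left_shift, left_shifts, right_shifts, multipliers] = requantize_data;
gemm_requant_info = arm_gemm::Requantize32(nullptr, 0,
                                           a_offset, b_offset, os_info.gemmlowp_offset,
                                           need_left_shift ? left_shifts : nullptr,
                                           right_shifts, multipliers,
                                           os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);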
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index dada6d16da..83db146a8a 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -117,18 +117,8 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
{
if(is_data_type_quantized_asymmetric(a_to_use->info()->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
{
- // Result shifts < 0 are not supported by asm kernels
- const std::vector<int32_t> &shifts = info.gemmlowp_output_stage().gemmlowp_shifts;
- const bool is_asm_supported = info.gemmlowp_output_stage().gemmlowp_shift >= 0
- && std::all_of(shifts.cbegin(), shifts.cend(), [](int32_t val)
- {
- return val >= 0;
- });
- if(is_asm_supported)
- {
- _asm_glue.configure(a_to_use, b, c, output, gemm_info);
- _fused_assembly_path = _asm_glue.is_configured();
- }
+ _asm_glue.configure(a_to_use, b, c, output, gemm_info);
+ _fused_assembly_path = _asm_glue.is_configured();
}
else
{
@@ -339,19 +329,8 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
bool run_optimised_requantized = false;
if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
{
- // Result shifts < 0 are not supported by asm kernels
- const std::vector<int32_t> &shifts = info.gemmlowp_output_stage().gemmlowp_shifts;
- const bool is_asm_supported = info.gemmlowp_output_stage().gemmlowp_shift >= 0
- && std::all_of(shifts.cbegin(), shifts.cend(), [](int32_t val)
- {
- return val >= 0;
- });
-
- if(is_asm_supported)
- {
- run_optimised = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, c, output, gemm_info));
- run_optimised_requantized = run_optimised;
- }
+ run_optimised = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, c, output, gemm_info));
+ run_optimised_requantized = run_optimised;
}
else
{
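With the sign check gone, eligibility for the fused assembly path is decided solely by NEGEMMAssemblyDispatch::validate. As a hedged example (placeholder values; field names from GEMMLowpOutputStageInfo), an output stage like the following, which the removed guard used to reject because of its negative shift, can now be handled by the fused assembly path:

#include "arm_compute/core/Types.h"

// Placeholder values for illustration only.
arm_compute::GEMMLowpOutputStageInfo make_negative_shift_stage()
{
    arm_compute::GEMMLowpOutputStageInfo os_info{};
    os_info.type                = arm_compute::GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    os_info.gemmlowp_offset     = 10;
    os_info.gemmlowp_multiplier = 1 << 30;  // roughly 0.5 in Q0.31
    os_info.gemmlowp_shift      = -2;       // negative result shift, i.e. a left shift
    os_info.gemmlowp_min_bound  = 0;
    os_info.gemmlowp_max_bound  = 255;
    return os_info;
}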