Diffstat (limited to 'src/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.cpp')
-rw-r--r--  src/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.cpp | 438
1 file changed, 371 insertions(+), 67 deletions(-)
diff --git a/src/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.cpp b/src/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.cpp
index 46e53cec12..3ada3a3c4f 100644
--- a/src/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.cpp
@@ -72,6 +72,58 @@ inline int32x4x4_t load(const int32_t *ptr, int32_t x)
};
}
+inline int32x4x4_t add_s32(int32x4x4_t a, int32x4_t b)
+{
+ return
+ {
+ {
+ vaddq_s32(a.val[0], b),
+ vaddq_s32(a.val[1], b),
+ vaddq_s32(a.val[2], b),
+ vaddq_s32(a.val[3], b)
+ }
+ };
+}
+
+inline int32x4x4_t add_s32(int32x4x4_t a, int32x4x4_t b)
+{
+ return
+ {
+ {
+ vaddq_s32(a.val[0], b.val[0]),
+ vaddq_s32(a.val[1], b.val[1]),
+ vaddq_s32(a.val[2], b.val[2]),
+ vaddq_s32(a.val[3], b.val[3])
+ }
+ };
+}
+
+inline int32x4x4_t mul_s32(int32x4x4_t &a, int32_t mul_scalar)
+{
+ return
+ {
+ {
+ vmulq_n_s32(a.val[0], mul_scalar),
+ vmulq_n_s32(a.val[1], mul_scalar),
+ vmulq_n_s32(a.val[2], mul_scalar),
+ vmulq_n_s32(a.val[3], mul_scalar)
+ }
+ };
+}
+
+inline int32x4x4_t mul_s32(int32x4x4_t &a, const int32_t *multiplier)
+{
+ return
+ {
+ {
+ vmulq_s32(a.val[0], vld1q_s32(multiplier)),
+ vmulq_s32(a.val[1], vld1q_s32(multiplier + 4)),
+ vmulq_s32(a.val[2], vld1q_s32(multiplier + 8)),
+ vmulq_s32(a.val[3], vld1q_s32(multiplier + 12))
+ }
+ };
+}
+
inline int32x4x4_t get_a_offset(const int32_t *vector_sum_col_ptr, int32_t a_offset, int32_t x)
{
int32x4x4_t a_offset_term_s32 = load(vector_sum_col_ptr, x);
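Note: the add_s32 and mul_s32 helpers above each process a 16-lane block (four int32x4_t registers), matching the kernel's window_step_x of 16. A minimal scalar sketch of the same arithmetic, illustrative only and not part of the patch (the _ref names are made up):

#include <cstdint>

// Scalar model of add_s32(a, b) with a broadcast int32x4_t addend:
// lane i of the 16-wide block receives b[i % 4].
inline void add_s32_ref(int32_t (&a)[16], const int32_t (&b)[4])
{
    for(int i = 0; i < 16; ++i)
    {
        a[i] += b[i % 4];
    }
}

// Scalar model of mul_s32(a, multiplier) with a per-lane multiplier pointer,
// as used for the per-channel requantization multipliers.
inline void mul_s32_ref(int32_t (&a)[16], const int32_t *multiplier)
{
    for(int i = 0; i < 16; ++i)
    {
        a[i] *= multiplier[i];
    }
}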
@@ -141,6 +193,82 @@ inline uint8x16_t finalize_quantization_floating_point(int32x4x4_t &in_s32, int3
return out_u8;
}
+template <bool is_bounded_relu>
+inline int8x16_t finalize_quantization_floating_point(int32x4x4_t &in_s32, int32x4_t result_shift_s32, int8x16_t min_s8, int8x16_t max_s8)
+{
+ const static int32x4_t zero_s32 = vdupq_n_s32(0);
+
+ // Shift final result (negative value shift right)
+ in_s32.val[0] = vshlq_s32(in_s32.val[0], result_shift_s32);
+ in_s32.val[1] = vshlq_s32(in_s32.val[1], result_shift_s32);
+ in_s32.val[2] = vshlq_s32(in_s32.val[2], result_shift_s32);
+ in_s32.val[3] = vshlq_s32(in_s32.val[3], result_shift_s32);
+
+ // Saturate negative values
+ in_s32.val[0] = vmaxq_s32(in_s32.val[0], zero_s32);
+ in_s32.val[1] = vmaxq_s32(in_s32.val[1], zero_s32);
+ in_s32.val[2] = vmaxq_s32(in_s32.val[2], zero_s32);
+ in_s32.val[3] = vmaxq_s32(in_s32.val[3], zero_s32);
+
+ // Convert S32 to S16
+ const int16x8x2_t in_s16 =
+ {
+ {
+ vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1])),
+ vcombine_s16(vqmovn_s32(in_s32.val[2]), vqmovn_s32(in_s32.val[3]))
+ }
+ };
+
+ // Convert S16 to S8
+ int8x16_t out_s8 = vcombine_s8(vqmovn_s16(in_s16.val[0]), vqmovn_s16(in_s16.val[1]));
+
+ if(is_bounded_relu)
+ {
+ out_s8 = vmaxq_s8(out_s8, min_s8);
+ out_s8 = vminq_s8(out_s8, max_s8);
+ }
+
+ return out_s8;
+}
+
+template <bool is_bounded_relu>
+inline int8x16_t finalize_quantization_floating_point(int32x4x4_t &in_s32, int32x4x4_t result_shift_s32, int8x16_t min_s8, int8x16_t max_s8)
+{
+ const static int32x4_t zero_s32 = vdupq_n_s32(0);
+
+ // Shift final result (negative value shift right)
+ in_s32.val[0] = vshlq_s32(in_s32.val[0], vnegq_s32(result_shift_s32.val[0]));
+ in_s32.val[1] = vshlq_s32(in_s32.val[1], vnegq_s32(result_shift_s32.val[1]));
+ in_s32.val[2] = vshlq_s32(in_s32.val[2], vnegq_s32(result_shift_s32.val[2]));
+ in_s32.val[3] = vshlq_s32(in_s32.val[3], vnegq_s32(result_shift_s32.val[3]));
+
+ // Saturate negative values
+ in_s32.val[0] = vmaxq_s32(in_s32.val[0], zero_s32);
+ in_s32.val[1] = vmaxq_s32(in_s32.val[1], zero_s32);
+ in_s32.val[2] = vmaxq_s32(in_s32.val[2], zero_s32);
+ in_s32.val[3] = vmaxq_s32(in_s32.val[3], zero_s32);
+
+ // Convert S32 to S16
+ const int16x8x2_t in_s16 =
+ {
+ {
+ vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1])),
+ vcombine_s16(vqmovn_s32(in_s32.val[2]), vqmovn_s32(in_s32.val[3]))
+ }
+ };
+
+ // Convert S16 to S8
+ int8x16_t out_s8 = vcombine_s8(vqmovn_s16(in_s16.val[0]), vqmovn_s16(in_s16.val[1]));
+
+ if(is_bounded_relu)
+ {
+ out_s8 = vmaxq_s8(out_s8, min_s8);
+ out_s8 = vminq_s8(out_s8, max_s8);
+ }
+
+ return out_s8;
+}
+
inline Window get_win_vector_sum(const Window &window)
{
Window win_vector_sum(window);
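Note: the two signed finalize_quantization_floating_point overloads added above differ only in how the shift is supplied: the first applies the given shift vector directly (so callers are expected to pass negative values for right shifts, as the comment notes), while the second loads per-channel shifts and negates them before vshlq_s32. Per lane, both reduce to the following scalar model, illustrative only and not part of the patch:

#include <algorithm>
#include <cstdint>

// Scalar model of the signed non-fixed-point output stage for one lane.
// 'right_shift' is the effective right-shift amount after negation.
inline int8_t finalize_fp_ref(int32_t v, int32_t right_shift, int8_t min_s8, int8_t max_s8, bool bounded_relu)
{
    v = v >> right_shift;   // vshlq_s32 with a negative shift value
    v = std::max(v, 0);     // saturate negative values to zero
    v = std::min(v, 127);   // vqmovn_s32 followed by vqmovn_s16 saturating narrow
    int8_t out = static_cast<int8_t>(v);
    if(bounded_relu)
    {
        out = std::max(out, min_s8);
        out = std::min(out, max_s8);
    }
    return out;
}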
@@ -172,50 +300,12 @@ inline Iterator get_bias_it(const Window &window, const ITensor *bias)
return bias_it;
}
-inline int32x4x4_t add_s32(int32x4x4_t a, int32x4_t b)
-{
- return
- {
- {
- vaddq_s32(a.val[0], b),
- vaddq_s32(a.val[1], b),
- vaddq_s32(a.val[2], b),
- vaddq_s32(a.val[3], b)
- }
- };
-}
-
-inline int32x4x4_t add_s32(int32x4x4_t a, int32x4x4_t b)
-{
- return
- {
- {
- vaddq_s32(a.val[0], b.val[0]),
- vaddq_s32(a.val[1], b.val[1]),
- vaddq_s32(a.val[2], b.val[2]),
- vaddq_s32(a.val[3], b.val[3])
- }
- };
-}
-
-inline int32x4x4_t mul_s32(int32x4x4_t &a, int32_t mul_scalar)
-{
- return
- {
- {
- vmulq_n_s32(a.val[0], mul_scalar),
- vmulq_n_s32(a.val[1], mul_scalar),
- vmulq_n_s32(a.val[2], mul_scalar),
- vmulq_n_s32(a.val[3], mul_scalar)
- }
- };
-}
-
template <bool has_a_offset, bool has_b_offset, bool has_bias, bool is_bounded_relu, bool is_fixed_point>
inline void run_offset_contribution_output_stage_window(const int32_t *vector_sum_col_ptr, const int32_t *vector_sum_row_ptr, const int32_t *bias_ptr, Iterator mm_result_it, Iterator out_it,
const int32x4_t result_offset_s32, const int32x4_t result_shift_s32, uint8x16_t min_u8, uint8x16_t max_u8,
int32_t a_offset, int32_t b_offset, int32_t k_offset,
- GEMMLowpOutputStageInfo output_stage, int window_step_x, int window_start_x, int window_end_x)
+ int32_t multiplier, int32_t shift, int32_t offset, int32_t min_bound, int32_t max_bound,
+ int window_step_x, int window_start_x, int window_end_x)
{
int32x4x4_t offset_term_s32 = { 0, 0, 0, 0 };
if(!is_fixed_point)
@@ -251,12 +341,12 @@ inline void run_offset_contribution_output_stage_window(const int32_t *vector_su
}
if(!is_fixed_point)
{
- in_s32 = mul_s32(in_s32, output_stage.gemmlowp_multiplier);
+ in_s32 = mul_s32(in_s32, multiplier);
}
if(is_fixed_point)
{
- vst1q_u8(out_it.ptr() + x, finalize_quantization<is_bounded_relu>(in_s32, output_stage.gemmlowp_multiplier, output_stage.gemmlowp_shift, result_offset_s32, min_u8, max_u8));
+ vst1q_u8(out_it.ptr() + x, finalize_quantization<is_bounded_relu>(in_s32, multiplier, shift, result_offset_s32, min_u8, max_u8));
}
else
{
@@ -280,24 +370,99 @@ inline void run_offset_contribution_output_stage_window(const int32_t *vector_su
if(is_fixed_point)
{
// Finalize and store the result
- *(out_it.ptr() + x) = finalize_quantization<is_bounded_relu>(in_value, output_stage.gemmlowp_multiplier, output_stage.gemmlowp_shift,
- output_stage.gemmlowp_offset, static_cast<uint8_t>(output_stage.gemmlowp_min_bound), static_cast<uint8_t>(output_stage.gemmlowp_max_bound));
+ *(out_it.ptr() + x) = finalize_quantization<is_bounded_relu>(in_value, multiplier, shift, offset, static_cast<uint8_t>(min_bound), static_cast<uint8_t>(max_bound));
}
else
{
// Finalize quantization
- in_value = (in_value * output_stage.gemmlowp_multiplier) >> output_stage.gemmlowp_shift;
+ in_value = (in_value * multiplier) >> shift;
// Bound and store the result
if(is_bounded_relu)
{
- in_value = static_cast<uint8_t>(std::max<int32_t>(output_stage.gemmlowp_min_bound, std::min<int32_t>(output_stage.gemmlowp_max_bound, in_value)));
+ in_value = static_cast<uint8_t>(std::max<int32_t>(min_bound, std::min<int32_t>(max_bound, in_value)));
}
*(out_it.ptr() + x) = static_cast<uint8_t>(std::max<int32_t>(0, std::min<int32_t>(255, in_value)));
}
}
}
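Note: with the GEMMLowpOutputStageInfo fields now passed as plain scalars, the non-fixed-point left-over path above is just an integer multiply, a right shift by gemmlowp_shift, and a clamp to the unsigned 8-bit range. A small worked example with made-up numbers, not taken from the library:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    int32_t in_value   = 12300; // accumulator after offset and bias contributions
    int32_t multiplier = 3;     // output_stage.gemmlowp_multiplier
    int32_t shift      = 8;     // output_stage.gemmlowp_shift (right-shift amount)

    in_value = (in_value * multiplier) >> shift; // 36900 >> 8 = 144

    const uint8_t out = static_cast<uint8_t>(std::max<int32_t>(0, std::min<int32_t>(255, in_value)));
    std::printf("%d\n", out); // prints 144

    return 0;
}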
+template <bool has_a_offset, bool has_bias, bool is_bounded_relu, bool is_fixed_point>
+inline void run_offset_contribution_output_stage_window_symm(const int32_t *vector_sum_col_ptr, const int32_t *bias_ptr, Iterator mm_result_it, Iterator out_it,
+ const int32_t *result_multipliers, const int32_t *result_shifts,
+ const int32x4_t result_offset, int8x16_t min_s8, int8x16_t max_s8,
+ int32_t a_offset, int32_t offset, int32_t min_bound, int32_t max_bound,
+ int window_step_x, int window_start_x, int window_end_x)
+{
+ int32x4x4_t offset_term_s32 = { 0, 0, 0, 0 };
+ if(!is_fixed_point)
+ {
+ // Combine quantization offset with other offsets.
+ offset_term_s32 = add_s32(offset_term_s32, result_offset);
+ }
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ int32x4x4_t in_s32 = load_results_input(mm_result_it, x);
+
+ if(has_a_offset)
+ {
+ in_s32 = add_s32(in_s32, get_a_offset(vector_sum_col_ptr, a_offset, x));
+ }
+ if(has_bias)
+ {
+ in_s32 = add_s32(in_s32, load(bias_ptr, x));
+ }
+ if(!is_fixed_point)
+ {
+ in_s32 = add_s32(in_s32, offset_term_s32);
+ in_s32 = mul_s32(in_s32, result_multipliers + x);
+ }
+
+ if(is_fixed_point)
+ {
+ vst1q_s8(reinterpret_cast<int8_t *>(out_it.ptr() + x), finalize_quantization_symm<is_bounded_relu>(in_s32, load(result_multipliers, x), load(result_shifts, x), result_offset, min_s8, max_s8));
+ }
+ else
+ {
+ vst1q_s8(reinterpret_cast<int8_t *>(out_it.ptr() + x), finalize_quantization_floating_point<is_bounded_relu>(in_s32, load(result_shifts, x), min_s8, max_s8));
+ }
+ }
+ // Compute left-over elements
+ for(; x < window_end_x; ++x)
+ {
+ int32_t in_value = *(reinterpret_cast<const int32_t *>(mm_result_it.ptr()) + x) + wrapper::vgetlane(offset_term_s32.val[0], 0);
+
+ if(has_a_offset)
+ {
+ in_value += (*(vector_sum_col_ptr + x) * a_offset);
+ }
+ if(has_bias)
+ {
+ in_value += *(bias_ptr + x);
+ }
+
+ if(is_fixed_point)
+ {
+ // Finalize and store the result
+ *(out_it.ptr() + x) = finalize_quantization<is_bounded_relu>(in_value, result_multipliers[x], result_shifts[x], offset, static_cast<int8_t>(min_bound), static_cast<int8_t>(max_bound));
+ }
+ else
+ {
+ // Finalize quantization
+ in_value = (in_value * result_multipliers[x]) >> (-result_shifts[x]);
+
+ // Bound and store the result
+ if(is_bounded_relu)
+ {
+ in_value = static_cast<int8_t>(std::max<int32_t>(min_bound, std::min<int32_t>(max_bound, in_value)));
+ }
+ *(out_it.ptr() + x) = static_cast<int8_t>(std::max<int32_t>(-128, std::min<int32_t>(127, in_value)));
+ }
+ }
+}
+
template <bool is_gemm3d, bool is_bounded_relu, bool is_fixed_point>
void run_offset_contribution_output_stage(const Window &window,
const ITensor *mm_result, const ITensor *vector_sum_col, const ITensor *vector_sum_row, const ITensor *bias, ITensor *output,
@@ -307,10 +472,16 @@ void run_offset_contribution_output_stage(const Window &window,
const int height_input = is_gemm3d ? mm_result->info()->dimension(1) : 0;
const int depth_input = is_gemm3d ? mm_result->info()->dimension(2) : 1;
- const int32x4_t result_offset_s32 = vdupq_n_s32(output_stage.gemmlowp_offset);
- const int32x4_t result_shift_s32 = vdupq_n_s32(is_fixed_point ? output_stage.gemmlowp_shift : -output_stage.gemmlowp_shift);
- const uint8x16_t min_u8 = vdupq_n_u8(static_cast<uint8_t>(output_stage.gemmlowp_min_bound));
- const uint8x16_t max_u8 = vdupq_n_u8(static_cast<uint8_t>(output_stage.gemmlowp_max_bound));
+ const int32_t multiplier = output_stage.gemmlowp_multiplier;
+ const int32_t shift = output_stage.gemmlowp_shift;
+ const int32_t offset = output_stage.gemmlowp_offset;
+ const int32_t min_bound = output_stage.gemmlowp_min_bound;
+ const int32_t max_bound = output_stage.gemmlowp_max_bound;
+
+ const int32x4_t result_offset_s32 = vdupq_n_s32(offset);
+ const int32x4_t result_shift_s32 = vdupq_n_s32(is_fixed_point ? shift : -shift);
+ const uint8x16_t min_u8 = vdupq_n_u8(static_cast<uint8_t>(min_bound));
+ const uint8x16_t max_u8 = vdupq_n_u8(static_cast<uint8_t>(max_bound));
const int window_step_x = 16;
const auto window_start_x = static_cast<int>(window.x().start());
@@ -349,7 +520,8 @@ void run_offset_contribution_output_stage(const Window &window,
run_offset_contribution_output_stage_window<true, true, true, is_bounded_relu, is_fixed_point>(vector_sum_col_ptr, vector_sum_row_ptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it,
out_it,
result_offset_s32, result_shift_s32, min_u8, max_u8, a_offset, b_offset, k_offset,
- output_stage, window_step_x, window_start_x, window_end_x);
+ multiplier, shift, offset, min_bound, max_bound,
+ window_step_x, window_start_x, window_end_x);
},
vector_sum_col_it, vector_sum_row_it, bias_it, mm_result_it, out_it);
}
@@ -363,7 +535,8 @@ void run_offset_contribution_output_stage(const Window &window,
+ id.y() + (id.z() % depth_input) * height_input;
run_offset_contribution_output_stage_window<true, true, false, is_bounded_relu, is_fixed_point>(vector_sum_col_ptr, vector_sum_row_ptr, nullptr, mm_result_it, out_it,
result_offset_s32, result_shift_s32, min_u8, max_u8, a_offset, b_offset, k_offset,
- output_stage, window_step_x, window_start_x, window_end_x);
+ multiplier, shift, offset, min_bound, max_bound,
+ window_step_x, window_start_x, window_end_x);
},
vector_sum_col_it, vector_sum_row_it, mm_result_it, out_it);
}
@@ -386,7 +559,8 @@ void run_offset_contribution_output_stage(const Window &window,
+ id.y() + (id.z() % depth_input) * height_input;
run_offset_contribution_output_stage_window<false, true, true, is_bounded_relu, is_fixed_point>(nullptr, vector_sum_row_ptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it, out_it,
result_offset_s32, result_shift_s32, min_u8, max_u8, a_offset, b_offset, k_offset,
- output_stage, window_step_x, window_start_x, window_end_x);
+ multiplier, shift, offset, min_bound, max_bound,
+ window_step_x, window_start_x, window_end_x);
},
vector_sum_row_it, bias_it, mm_result_it, out_it);
}
@@ -399,7 +573,8 @@ void run_offset_contribution_output_stage(const Window &window,
+ id.y() + (id.z() % depth_input) * height_input;
run_offset_contribution_output_stage_window<false, true, false, is_bounded_relu, is_fixed_point>(nullptr, vector_sum_row_ptr, nullptr, mm_result_it, out_it,
result_offset_s32, result_shift_s32, min_u8, max_u8, a_offset, b_offset, k_offset,
- output_stage, window_step_x, window_start_x, window_end_x);
+ multiplier, shift, offset, min_bound, max_bound,
+ window_step_x, window_start_x, window_end_x);
},
vector_sum_row_it, mm_result_it, out_it);
}
@@ -422,7 +597,8 @@ void run_offset_contribution_output_stage(const Window &window,
const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset);
run_offset_contribution_output_stage_window<true, false, true, is_bounded_relu, is_fixed_point>(vector_sum_col_ptr, nullptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it, out_it,
result_offset_s32, result_shift_s32, min_u8, max_u8, a_offset, b_offset, k_offset,
- output_stage, window_step_x, window_start_x, window_end_x);
+ multiplier, shift, offset, min_bound, max_bound,
+ window_step_x, window_start_x, window_end_x);
},
vector_sum_col_it, bias_it, mm_result_it, out_it);
}
@@ -434,7 +610,8 @@ void run_offset_contribution_output_stage(const Window &window,
const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset);
run_offset_contribution_output_stage_window<true, false, false, is_bounded_relu, is_fixed_point>(vector_sum_col_ptr, nullptr, nullptr, mm_result_it, out_it,
result_offset_s32, result_shift_s32, min_u8, max_u8, a_offset, b_offset, k_offset,
- output_stage, window_step_x, window_start_x, window_end_x);
+ multiplier, shift, offset, min_bound, max_bound,
+ window_step_x, window_start_x, window_end_x);
},
vector_sum_col_it, mm_result_it, out_it);
}
@@ -448,7 +625,8 @@ void run_offset_contribution_output_stage(const Window &window,
{
run_offset_contribution_output_stage_window<false, false, true, is_bounded_relu, is_fixed_point>(nullptr, nullptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it, out_it,
result_offset_s32, result_shift_s32, min_u8, max_u8, a_offset, b_offset, k_offset,
- output_stage, window_step_x, window_start_x, window_end_x);
+ multiplier, shift, offset, min_bound, max_bound,
+ window_step_x, window_start_x, window_end_x);
},
bias_it, mm_result_it, out_it);
}
@@ -458,7 +636,110 @@ void run_offset_contribution_output_stage(const Window &window,
{
run_offset_contribution_output_stage_window<false, false, false, is_bounded_relu, is_fixed_point>(nullptr, nullptr, nullptr, mm_result_it, out_it,
result_offset_s32, result_shift_s32, min_u8, max_u8, a_offset, b_offset, k_offset,
- output_stage, window_step_x, window_start_x, window_end_x);
+ multiplier, shift, offset, min_bound, max_bound,
+ window_step_x, window_start_x, window_end_x);
+ },
+ mm_result_it, out_it);
+ }
+ return;
+ }
+}
+
+template <bool is_gemm3d, bool is_bounded_relu, bool is_fixed_point>
+void run_offset_contribution_output_stage_symm(const Window &window,
+ const ITensor *mm_result, const ITensor *vector_sum_col, const ITensor *vector_sum_row, const ITensor *bias, ITensor *output,
+ int32_t a_offset, int32_t b_offset, int32_t k_offset, bool slide_vector_sum_col,
+ GEMMLowpOutputStageInfo output_stage)
+{
+ ARM_COMPUTE_UNUSED(vector_sum_row, b_offset, k_offset);
+
+ const int depth_input = is_gemm3d ? mm_result->info()->dimension(2) : 1;
+
+ const int32_t offset = output_stage.gemmlowp_offset;
+ const int32_t min_bound = output_stage.gemmlowp_min_bound;
+ const int32_t max_bound = output_stage.gemmlowp_max_bound;
+
+ const int32_t *result_multipliers = output_stage.gemmlowp_multipliers.data();
+ const int32_t *result_shifts = output_stage.gemmlowp_shifts.data();
+ const int32x4_t result_offset_s32 = vdupq_n_s32(offset);
+ const int8x16_t min_s8 = vdupq_n_s8(static_cast<int8_t>(min_bound));
+ const int8x16_t max_s8 = vdupq_n_s8(static_cast<int8_t>(max_bound));
+
+ const int window_step_x = 16;
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ Window win(window);
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Window collapsed_window = win.collapse_if_possible(win, Window::DimZ);
+
+ Iterator mm_result_it(mm_result, win);
+ Iterator out_it(output, win);
+
+ if(a_offset != 0)
+ {
+ ARM_COMPUTE_ERROR_ON_NULLPTR(vector_sum_col);
+
+ Iterator vector_sum_col_it = get_vector_sum_col_it(collapsed_window, vector_sum_col);
+
+ // Offset in case vector_sum_col is batched
+ const int vector_sum_col_batch_offset = slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
+
+ if(bias != nullptr)
+ {
+ Iterator bias_it = get_bias_it(collapsed_window, bias);
+ execute_window_loop(collapsed_window, [&](const Coordinates & id)
+ {
+ const int batch_id = id.z() / depth_input;
+ const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset);
+ run_offset_contribution_output_stage_window_symm<true, true, is_bounded_relu, is_fixed_point>(vector_sum_col_ptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it, out_it,
+ result_multipliers, result_shifts,
+ result_offset_s32, min_s8, max_s8,
+ a_offset, offset, min_bound, max_bound,
+ window_step_x, window_start_x, window_end_x);
+ },
+ vector_sum_col_it, bias_it, mm_result_it, out_it);
+ }
+ else
+ {
+ execute_window_loop(collapsed_window, [&](const Coordinates & id)
+ {
+ const int batch_id = id.z() / depth_input;
+ const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset);
+ run_offset_contribution_output_stage_window_symm<true, false, is_bounded_relu, is_fixed_point>(vector_sum_col_ptr, nullptr, mm_result_it, out_it,
+ result_multipliers, result_shifts,
+ result_offset_s32, min_s8, max_s8,
+ a_offset, offset, min_bound, max_bound,
+ window_step_x, window_start_x, window_end_x);
+ },
+ vector_sum_col_it, mm_result_it, out_it);
+ }
+ }
+ else
+ {
+ if(bias != nullptr)
+ {
+ Iterator bias_it = get_bias_it(collapsed_window, bias);
+ execute_window_loop(collapsed_window, [&](const Coordinates &)
+ {
+ run_offset_contribution_output_stage_window_symm<false, true, is_bounded_relu, is_fixed_point>(nullptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it, out_it,
+ result_multipliers, result_shifts,
+ result_offset_s32, min_s8, max_s8,
+ a_offset, offset, min_bound, max_bound,
+ window_step_x, window_start_x, window_end_x);
+ },
+ bias_it, mm_result_it, out_it);
+ }
+ else
+ {
+ execute_window_loop(collapsed_window, [&](const Coordinates &)
+ {
+ run_offset_contribution_output_stage_window_symm<false, false, is_bounded_relu, is_fixed_point>(nullptr, nullptr, mm_result_it, out_it,
+ result_multipliers, result_shifts,
+ result_offset_s32, min_s8, max_s8,
+ a_offset, offset, min_bound, max_bound,
+ window_step_x, window_start_x, window_end_x);
},
mm_result_it, out_it);
}
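Note: the run_offset_contribution_output_stage_symm path reads its per-channel multipliers and shifts from output_stage.gemmlowp_multipliers / gemmlowp_shifts (see the .data() calls above). A minimal configuration sketch, assuming these members are std::vector<int32_t> holding one entry per output channel; the values are illustrative only:

#include "arm_compute/core/Types.h"

arm_compute::GEMMLowpOutputStageInfo make_output_stage_info()
{
    arm_compute::GEMMLowpOutputStageInfo info{};
    info.type                 = arm_compute::GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    info.gemmlowp_offset      = 10;                         // output zero point
    info.gemmlowp_min_bound   = -128;                       // QASYMM8_SIGNED lower bound
    info.gemmlowp_max_bound   = 127;                        // QASYMM8_SIGNED upper bound
    info.gemmlowp_multipliers = { 1073741824, 1518500250 }; // one fixed-point multiplier per channel (assumed std::vector<int32_t>)
    info.gemmlowp_shifts      = { 1, 2 };                   // one result shift per channel (assumed std::vector<int32_t>)
    return info;
}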
@@ -470,8 +751,18 @@ Status validate_arguments(const ITensorInfo *mm_result, const ITensorInfo *vecto
int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32);
- ARM_COMPUTE_RETURN_ERROR_ON(output_stage.gemmlowp_max_bound > 255);
- ARM_COMPUTE_RETURN_ERROR_ON(output_stage.gemmlowp_min_bound < 0 || output_stage.gemmlowp_min_bound > output_stage.gemmlowp_max_bound);
+ if(output->data_type() == DataType::QASYMM8)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(output_stage.gemmlowp_max_bound > 255);
+ ARM_COMPUTE_RETURN_ERROR_ON(output_stage.gemmlowp_min_bound < 0);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(output_stage.gemmlowp_max_bound > 127);
+ ARM_COMPUTE_RETURN_ERROR_ON(output_stage.gemmlowp_min_bound < -128);
+ ARM_COMPUTE_RETURN_ERROR_ON(mm_result->dimension(0) > 1 && output_stage.gemmlowp_multipliers.size() > 1 && b_offset != 0);
+ }
+ ARM_COMPUTE_RETURN_ERROR_ON(output_stage.gemmlowp_min_bound > output_stage.gemmlowp_max_bound);
ARM_COMPUTE_RETURN_ERROR_ON(output_stage.type != GEMMLowpOutputStageType::QUANTIZE_DOWN && output_stage.type != GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT);
if(bias != nullptr)
@@ -525,7 +816,7 @@ Status validate_arguments(const ITensorInfo *mm_result, const ITensorInfo *vecto
if(output->total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mm_result, output);
}
@@ -551,7 +842,7 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *mm_result,
}
NEGEMMLowpOffsetContributionOutputStageKernel::NEGEMMLowpOffsetContributionOutputStageFunction
-get_configured_function(const ITensor *mm_result, const ITensor *vector_sum_row, GEMMLowpOutputStageInfo output_stage)
+get_configured_function(const ITensor *mm_result, const ITensor *vector_sum_row, const ITensor *output, GEMMLowpOutputStageInfo output_stage)
{
static std::map<uint8_t, NEGEMMLowpOffsetContributionOutputStageKernel::NEGEMMLowpOffsetContributionOutputStageFunction> map_function =
{
@@ -562,7 +853,15 @@ get_configured_function(const ITensor *mm_result, const ITensor *vector_sum_row,
{ 4, &run_offset_contribution_output_stage<false, false, true> },
{ 5, &run_offset_contribution_output_stage<true, false, true> },
{ 6, &run_offset_contribution_output_stage<false, true, true> },
- { 7, &run_offset_contribution_output_stage<true, true, true> }
+ { 7, &run_offset_contribution_output_stage<true, true, true> },
+ { 8, &run_offset_contribution_output_stage_symm<false, false, false> },
+ { 9, &run_offset_contribution_output_stage_symm<true, false, false> },
+ { 10, &run_offset_contribution_output_stage_symm<false, true, false> },
+ { 11, &run_offset_contribution_output_stage_symm<true, true, false> },
+ { 12, &run_offset_contribution_output_stage_symm<false, false, true> },
+ { 13, &run_offset_contribution_output_stage_symm<true, false, true> },
+ { 14, &run_offset_contribution_output_stage_symm<false, true, true> },
+ { 15, &run_offset_contribution_output_stage_symm<true, true, true> }
};
// Check if input is a 3D reinterpretation
@@ -574,11 +873,15 @@ get_configured_function(const ITensor *mm_result, const ITensor *vector_sum_row,
const bool is_bounded_relu = ((output_stage.gemmlowp_min_bound != output_stage.gemmlowp_max_bound)
&& !(output_stage.gemmlowp_min_bound == 0 && output_stage.gemmlowp_max_bound == 255));
+ // Check if we need to perform fixed point requantization
const bool is_fixed_point = output_stage.type != GEMMLowpOutputStageType::QUANTIZE_DOWN;
+ // Check if symmetric per-channel execution
+ const bool is_symm = output->info()->data_type() == DataType::QASYMM8_SIGNED;
+
// key acts as a bitset, setting the first bit on reinterpret_as_3d,
// the second on is_bounded_relu, the third on is_fixed_point, and the fourth on is_symm.
- uint8_t key = (reinterpret_as_3d ? 1UL : 0UL) | ((is_bounded_relu ? 1UL : 0UL) << 1) | ((is_fixed_point ? 1UL : 0UL) << 2);
+ uint8_t key = (reinterpret_as_3d ? 1UL : 0UL) | ((is_bounded_relu ? 1UL : 0UL) << 1) | ((is_fixed_point ? 1UL : 0UL) << 2) | ((is_symm ? 1UL : 0UL) << 3);
return map_function.find(key)->second;
}
} // namespace
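Note: with the new fourth bit, the dispatch key selects between the asymmetric (QASYMM8) and symmetric per-channel (QASYMM8_SIGNED) drivers. A worked example: QASYMM8_SIGNED output with fixed-point requantization, a bounded relu and no 3D reinterpretation gives

// Illustrative key computation: reinterpret_as_3d = false, is_bounded_relu = true,
// is_fixed_point = true, is_symm = true.
const uint8_t key = 0U | (1U << 1) | (1U << 2) | (1U << 3); // = 14

which maps to run_offset_contribution_output_stage_symm<false, true, true> in the table above.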
@@ -591,8 +894,9 @@ NEGEMMLowpOffsetContributionOutputStageKernel::NEGEMMLowpOffsetContributionOutpu
}
void NEGEMMLowpOffsetContributionOutputStageKernel::configure(const ITensor *mm_result, const ITensor *vector_sum_col,
- const ITensor *vector_sum_row, const ITensor *bias, ITensor *output, int32_t k,
- int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage)
+ const ITensor *vector_sum_row, const ITensor *bias, ITensor *output,
+ int32_t k, int32_t a_offset, int32_t b_offset,
+ GEMMLowpOutputStageInfo output_stage)
{
// Perform validate step
ARM_COMPUTE_ERROR_ON_NULLPTR(mm_result, output);
@@ -627,7 +931,7 @@ void NEGEMMLowpOffsetContributionOutputStageKernel::configure(const ITensor *mm_
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
INEKernel::configure(win_config.second);
- _function = get_configured_function(mm_result, vector_sum_row, output_stage);
+ _function = get_configured_function(mm_result, vector_sum_row, output, output_stage);
}
Status NEGEMMLowpOffsetContributionOutputStageKernel::validate(const ITensorInfo *mm_result, const ITensorInfo *vector_sum_col,