aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEReductionOperationKernel.cpp
diff options
context:
space:
mode:
authorManuel Bottini <manuel.bottini@arm.com>2019-01-14 15:14:43 +0000
committerManuel Bottini <manuel.bottini@arm.com>2019-01-23 14:11:17 +0000
commit1d4f3853dfd16f55338d772ad757db0ee8710d78 (patch)
tree344a48b5d296fef0e05d345d342ba0733f8db7a2 /src/core/NEON/kernels/NEReductionOperationKernel.cpp
parent734151d20bef56cbedce2ae67945f42cb4e265c8 (diff)
downloadComputeLibrary-1d4f3853dfd16f55338d772ad757db0ee8710d78.tar.gz
COMPMID-1760: NEON: Implement Prod
Change-Id: I8062f4ca5ef5cf1a8183ac0834f240bbaf8f695d Reviewed-on: https://review.mlplatform.org/541 Reviewed-by: Pablo Marquez <pablo.tello@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEReductionOperationKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEReductionOperationKernel.cpp154
1 file changed, 141 insertions, 13 deletions
diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
index f0209a32da..506094f0c1 100644
--- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
@@ -362,6 +362,10 @@ struct RedOpX
{
init_res_value = *reinterpret_cast<T *>(input.ptr());
}
+ else if(op == ReductionOperation::PROD)
+ {
+ init_res_value = static_cast<T>(1.f);
+ }
auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
uint32x4x4_t vec_res_idx{ 0 };
@@ -379,6 +383,9 @@ struct RedOpX
case ReductionOperation::SUM:
vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
break;
+ case ReductionOperation::PROD:
+ vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
+ break;
case ReductionOperation::ARG_IDX_MIN:
{
auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
@@ -420,6 +427,17 @@ struct RedOpX
*(reinterpret_cast<T *>(output.ptr())) = res;
break;
}
+ case ReductionOperation::PROD:
+ {
+ auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
+ T res = 1;
+ for(int i = 0; i < S / 2; ++i)
+ {
+ res *= wrapper::vgetlane(carry_res, i);
+ }
+ *(reinterpret_cast<T *>(output.ptr())) = res;
+ break;
+ }
case ReductionOperation::ARG_IDX_MIN:
case ReductionOperation::ARG_IDX_MAX:
{
@@ -443,7 +461,13 @@ struct RedOpX_qasymm8
auto vec_res_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
auto vec_res_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));
+ auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
+ auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
+ auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
+ auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));
+
uint8x16_t vec_res_value = { 0 };
+
if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
{
vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{});
@@ -472,6 +496,36 @@ struct RedOpX_qasymm8
vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
break;
}
+ case ReductionOperation::PROD:
+ {
+ const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
+ const auto scale32x4f_4 = vdupq_n_f32(in_info.quantization_info().scale);
+
+ const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
+ const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
+
+ const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
+ const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
+ const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
+ const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
+
+ auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
+ auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
+ auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
+ auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
+
+ //de-quantize vec_elements
+ temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
+ temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
+ temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
+ temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
+
+ vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
+ vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
+ vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
+ vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
+ break;
+ }
case ReductionOperation::ARG_IDX_MIN:
{
auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
@@ -497,6 +551,21 @@ struct RedOpX_qasymm8
auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
*(reinterpret_cast<uint32_t *>(output.ptr())) = res;
}
+ else if(op == ReductionOperation::PROD)
+ {
+ auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
+ carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
+ carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
+
+ float res = wrapper::vgetlane(carry_res, 0);
+ res *= wrapper::vgetlane(carry_res, 1);
+ res *= wrapper::vgetlane(carry_res, 2);
+ res *= wrapper::vgetlane(carry_res, 3);
+
+ //re-quantize result
+ res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset);
+ *(output.ptr()) = static_cast<uint8_t>(res);
+ }
else
{
auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
@@ -535,6 +604,10 @@ struct RedOpYZW
{
vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
}
+ else if(op == ReductionOperation::PROD)
+ {
+ vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
+ }
else
{
vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
@@ -569,6 +642,9 @@ struct RedOpYZW
case ReductionOperation::SUM_SQUARE:
vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
break;
+ case ReductionOperation::PROD:
+ vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
+ break;
case ReductionOperation::ARG_IDX_MIN:
{
auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
@@ -620,21 +696,27 @@ struct RedOpYZW_qasymm8
auto vec_res_value2 = vdupq_n_u32(0);
auto vec_res_value3 = vdupq_n_u32(0);
auto vec_res_value4 = vdupq_n_u32(0);
- auto vec_res_value = wrapper::vloadq(input.ptr());
- for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
+ auto vec_res_value1_f = vdupq_n_f32(1);
+ auto vec_res_value2_f = vdupq_n_f32(1);
+ auto vec_res_value3_f = vdupq_n_f32(1);
+ auto vec_res_value4_f = vdupq_n_f32(1);
+
+ auto vec_res_value = wrapper::vloadq(input.ptr());
+
+ for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
{
uint8_t *in_ptr;
switch(axis)
{
case 1:
- in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim));
+ in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, index_dim));
break;
case 2:
- in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim));
+ in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, index_dim));
break;
case 3:
- in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim));
+ in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, index_dim));
break;
default:
ARM_COMPUTE_ERROR("Not supported");
@@ -660,17 +742,47 @@ struct RedOpYZW_qasymm8
vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
break;
}
+ case ReductionOperation::PROD:
+ {
+ const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
+ const auto scale32x4f_4 = vdupq_n_f32(in_info.quantization_info().scale);
+
+ const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
+ const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
+
+ const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
+ const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
+ const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
+ const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
+
+ auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
+ auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
+ auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
+ auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
+
+ //de-quantize vec_elements
+ temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
+ temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
+ temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
+ temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
+
+ vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
+ vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
+ vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
+ vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
+ break;
+ }
case ReductionOperation::ARG_IDX_MIN:
{
auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
- vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
+ vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
vec_res_value = temp_vec_res_value;
break;
}
case ReductionOperation::ARG_IDX_MAX:
{
auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
- vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
+ vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
vec_res_value = temp_vec_res_value;
break;
}
@@ -681,17 +793,34 @@ struct RedOpYZW_qasymm8
if(op == ReductionOperation::MEAN_SUM)
{
- const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
- const auto vec_res_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv);
- const auto vec_res_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv);
- const auto vec_res_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv);
- const auto vec_res_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv);
+ const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
+ vec_res_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv);
+ vec_res_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv);
+ vec_res_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv);
+ vec_res_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv);
vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
}
+ else if(op == ReductionOperation::PROD)
+ {
+ const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
+ const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(in_info.quantization_info().scale));
+
+ //re-quantize
+ vec_res_value1_f = vaddq_f32(vmulq_f32(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
+ vec_res_value2_f = vaddq_f32(vmulq_f32(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
+ vec_res_value3_f = vaddq_f32(vmulq_f32(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
+ vec_res_value4_f = vaddq_f32(vmulq_f32(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
+
+ vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
+ vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
+ vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
+ vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
+ }
+
if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
{
wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
@@ -798,7 +927,6 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32);
}
-
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);