aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorManuel Bottini <manuel.bottini@arm.com>2019-01-14 15:14:43 +0000
committerManuel Bottini <manuel.bottini@arm.com>2019-01-23 14:11:17 +0000
commit1d4f3853dfd16f55338d772ad757db0ee8710d78 (patch)
tree344a48b5d296fef0e05d345d342ba0733f8db7a2
parent734151d20bef56cbedce2ae67945f42cb4e265c8 (diff)
downloadComputeLibrary-1d4f3853dfd16f55338d772ad757db0ee8710d78.tar.gz
COMPMID-1760: NEON: Implement Prod
Change-Id: I8062f4ca5ef5cf1a8183ac0834f240bbaf8f695d
Reviewed-on: https://review.mlplatform.org/541
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/core/PixelValue.h10
-rw-r--r--src/core/NEON/kernels/NEReductionOperationKernel.cpp154
-rw-r--r--src/runtime/NEON/functions/NEReductionOperation.cpp18
-rw-r--r--tests/datasets/ShapeDatasets.h2
-rw-r--r--tests/validation/NEON/ReductionOperation.cpp28
-rw-r--r--tests/validation/reference/ReductionOperation.cpp30
6 files changed, 200 insertions, 42 deletions
diff --git a/arm_compute/core/PixelValue.h b/arm_compute/core/PixelValue.h
index e86eeba121..0ead9db7b1 100644
--- a/arm_compute/core/PixelValue.h
+++ b/arm_compute/core/PixelValue.h
@@ -41,10 +41,11 @@ public:
}
/** Initialize the union with a pixel value of chosen datatype
*
- * @param[in] v int value.
- * @param[in] datatype DataType that @p v have to be stored
+ * @param[in] v int value.
+ * @param[in] datatype   DataType that @p v has to be stored in
+ * @param[in] quant_info QuantizationInfo to apply to @p v when the datatype is QASYMM8
*/
- PixelValue(uint64_t v, DataType datatype)
+ PixelValue(uint64_t v, DataType datatype, QuantizationInfo quant_info = QuantizationInfo())
: PixelValue()
{
switch(datatype)
@@ -55,6 +56,9 @@ public:
case DataType::S8:
value.s8 = static_cast<int8_t>(v);
break;
+ case DataType::QASYMM8:
+ value.u8 = sqcvt_qasymm8_f32(v, quant_info.scale, quant_info.offset);
+ break;
case DataType::U16:
value.u16 = static_cast<uint16_t>(v);
break;
diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
index f0209a32da..506094f0c1 100644
--- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
@@ -362,6 +362,10 @@ struct RedOpX
{
init_res_value = *reinterpret_cast<T *>(input.ptr());
}
+ else if(op == ReductionOperation::PROD)
+ {
+ init_res_value = static_cast<T>(1.f);
+ }
auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
uint32x4x4_t vec_res_idx{ 0 };
@@ -379,6 +383,9 @@ struct RedOpX
case ReductionOperation::SUM:
vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
break;
+ case ReductionOperation::PROD:
+ vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
+ break;
case ReductionOperation::ARG_IDX_MIN:
{
auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
@@ -420,6 +427,17 @@ struct RedOpX
*(reinterpret_cast<T *>(output.ptr())) = res;
break;
}
+ case ReductionOperation::PROD:
+ {
+ auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
+ T res = 1;
+ for(int i = 0; i < S / 2; ++i)
+ {
+ res *= wrapper::vgetlane(carry_res, i);
+ }
+ *(reinterpret_cast<T *>(output.ptr())) = res;
+ break;
+ }
case ReductionOperation::ARG_IDX_MIN:
case ReductionOperation::ARG_IDX_MAX:
{
@@ -443,7 +461,13 @@ struct RedOpX_qasymm8
auto vec_res_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
auto vec_res_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));
+ auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
+ auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
+ auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
+ auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));
+
uint8x16_t vec_res_value = { 0 };
+
if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
{
vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{});
@@ -472,6 +496,36 @@ struct RedOpX_qasymm8
vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
break;
}
+ case ReductionOperation::PROD:
+ {
+ const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
+ const auto scale32x4f_4 = vdupq_n_f32(in_info.quantization_info().scale);
+
+ const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
+ const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
+
+ const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
+ const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
+ const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
+ const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
+
+ auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
+ auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
+ auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
+ auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
+
+ //de-quantize vec_elements
+ temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
+ temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
+ temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
+ temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
+
+ vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
+ vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
+ vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
+ vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
+ break;
+ }
case ReductionOperation::ARG_IDX_MIN:
{
auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
@@ -497,6 +551,21 @@ struct RedOpX_qasymm8
auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
*(reinterpret_cast<uint32_t *>(output.ptr())) = res;
}
+ else if(op == ReductionOperation::PROD)
+ {
+ auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
+ carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
+ carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
+
+ float res = wrapper::vgetlane(carry_res, 0);
+ res *= wrapper::vgetlane(carry_res, 1);
+ res *= wrapper::vgetlane(carry_res, 2);
+ res *= wrapper::vgetlane(carry_res, 3);
+
+ //re-quantize result
+ res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset);
+ *(output.ptr()) = static_cast<uint8_t>(res);
+ }
else
{
auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
@@ -535,6 +604,10 @@ struct RedOpYZW
{
vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
}
+ else if(op == ReductionOperation::PROD)
+ {
+ vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
+ }
else
{
vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
@@ -569,6 +642,9 @@ struct RedOpYZW
case ReductionOperation::SUM_SQUARE:
vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
break;
+ case ReductionOperation::PROD:
+ vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
+ break;
case ReductionOperation::ARG_IDX_MIN:
{
auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
@@ -620,21 +696,27 @@ struct RedOpYZW_qasymm8
auto vec_res_value2 = vdupq_n_u32(0);
auto vec_res_value3 = vdupq_n_u32(0);
auto vec_res_value4 = vdupq_n_u32(0);
- auto vec_res_value = wrapper::vloadq(input.ptr());
- for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
+ auto vec_res_value1_f = vdupq_n_f32(1);
+ auto vec_res_value2_f = vdupq_n_f32(1);
+ auto vec_res_value3_f = vdupq_n_f32(1);
+ auto vec_res_value4_f = vdupq_n_f32(1);
+
+ auto vec_res_value = wrapper::vloadq(input.ptr());
+
+ for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
{
uint8_t *in_ptr;
switch(axis)
{
case 1:
- in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim));
+ in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, index_dim));
break;
case 2:
- in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim));
+ in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, index_dim));
break;
case 3:
- in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim));
+ in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, index_dim));
break;
default:
ARM_COMPUTE_ERROR("Not supported");
@@ -660,17 +742,47 @@ struct RedOpYZW_qasymm8
vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
break;
}
+ case ReductionOperation::PROD:
+ {
+ const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
+ const auto scale32x4f_4 = vdupq_n_f32(in_info.quantization_info().scale);
+
+ const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements));
+ const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements));
+
+ const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1));
+ const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1));
+ const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2));
+ const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2));
+
+ auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1);
+ auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2);
+ auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3);
+ auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4);
+
+ //de-quantize vec_elements
+ temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
+ temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
+ temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
+ temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
+
+ vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
+ vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
+ vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
+ vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
+ break;
+ }
case ReductionOperation::ARG_IDX_MIN:
{
auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
- vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
+ vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
vec_res_value = temp_vec_res_value;
break;
}
case ReductionOperation::ARG_IDX_MAX:
{
auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
- vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
+ vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
vec_res_value = temp_vec_res_value;
break;
}
@@ -681,17 +793,34 @@ struct RedOpYZW_qasymm8
if(op == ReductionOperation::MEAN_SUM)
{
- const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
- const auto vec_res_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv);
- const auto vec_res_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv);
- const auto vec_res_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv);
- const auto vec_res_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv);
+ const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
+ vec_res_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv);
+ vec_res_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv);
+ vec_res_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv);
+ vec_res_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv);
vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
}
+ else if(op == ReductionOperation::PROD)
+ {
+ const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset);
+ const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(in_info.quantization_info().scale));
+
+ //re-quantize
+ vec_res_value1_f = vaddq_f32(vmulq_f32(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
+ vec_res_value2_f = vaddq_f32(vmulq_f32(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
+ vec_res_value3_f = vaddq_f32(vmulq_f32(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
+ vec_res_value4_f = vaddq_f32(vmulq_f32(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
+
+ vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f);
+ vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f);
+ vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f);
+ vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f);
+ }
+
if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
{
wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()), vec_res_idx.val[0]);
@@ -798,7 +927,6 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32);
}
-
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
diff --git a/src/runtime/NEON/functions/NEReductionOperation.cpp b/src/runtime/NEON/functions/NEReductionOperation.cpp
index bb27b5d47a..3ec8ef145e 100644
--- a/src/runtime/NEON/functions/NEReductionOperation.cpp
+++ b/src/runtime/NEON/functions/NEReductionOperation.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -50,16 +50,6 @@ size_t reduction_window_split_dimension(unsigned int axis)
ARM_COMPUTE_ERROR("Unsupported reduction axis");
}
}
-BorderMode reduction_operation_border_mode(ReductionOperation op)
-{
- switch(op)
- {
- case ReductionOperation::SUM_SQUARE:
- return BorderMode::CONSTANT;
- default:
- return BorderMode::CONSTANT;
- }
-}
} // namespace
NEReductionOperation::NEReductionOperation()
@@ -86,9 +76,9 @@ void NEReductionOperation::configure(ITensor *input, ITensor *output, unsigned i
if(axis == 0)
{
// Configure fill border kernel
- BorderSize fill_border_size = _reduction_kernel.border_size();
- BorderMode fill_border_mode = reduction_operation_border_mode(op);
- _fill_border_kernel.configure(input, fill_border_size, fill_border_mode, PixelValue(static_cast<float>(0.f)));
+ const BorderSize fill_border_size = _reduction_kernel.border_size();
+ const PixelValue pixelValue = PixelValue((op == ReductionOperation::PROD) ? 1 : 0, input->info()->data_type(), input->info()->quantization_info());
+ _fill_border_kernel.configure(input, fill_border_size, BorderMode::CONSTANT, pixelValue);
}
}
diff --git a/tests/datasets/ShapeDatasets.h b/tests/datasets/ShapeDatasets.h
index bd29fe649a..9ee89f43e0 100644
--- a/tests/datasets/ShapeDatasets.h
+++ b/tests/datasets/ShapeDatasets.h
@@ -148,7 +148,7 @@ public:
Small4DShapes()
: ShapeDataset("Shape",
{
- TensorShape{ 1U, 7U, 1U, 3U },
+ TensorShape{ 2U, 7U, 1U, 3U },
TensorShape{ 7U, 7U, 5U, 3U },
TensorShape{ 27U, 13U, 37U, 2U },
TensorShape{ 128U, 64U, 21U, 3U }
diff --git a/tests/validation/NEON/ReductionOperation.cpp b/tests/validation/NEON/ReductionOperation.cpp
index e322947993..b9b4983ae6 100644
--- a/tests/validation/NEON/ReductionOperation.cpp
+++ b/tests/validation/NEON/ReductionOperation.cpp
@@ -49,7 +49,21 @@ RelativeTolerance<float> rel_tolerance_f32(0.00001f);
RelativeTolerance<float> tolerance_qasymm8(1);
const auto ReductionOperations = framework::dataset::make("ReductionOperation",
-{ ReductionOperation::SUM });
+{
+ ReductionOperation::SUM,
+ ReductionOperation::PROD
+});
+
+const auto QuantizationInfos = framework::dataset::make("QuantizationInfo",
+{
+ QuantizationInfo(1.f / 128, -10),
+ QuantizationInfo(1.f / 64, -5),
+ QuantizationInfo(1.f / 32, -2)
+});
+
+const auto Axises = framework::dataset::make("Axis",
+{ 0, 1, 2, 3 });
+
} // namespace
TEST_SUITE(NEON)
@@ -88,13 +102,13 @@ using NEReductionOperationFixture = ReductionOperationFixture<Tensor, Accessor,
TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, NEReductionOperationFixture<float>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations))
+ combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::F32)), Axises), ReductionOperations))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f32);
}
FIXTURE_DATA_TEST_CASE(RunLarge, NEReductionOperationFixture<float>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations))
+ combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::F32)), Axises), ReductionOperations))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f32, 0, tolerance_f32);
@@ -106,17 +120,17 @@ using NEReductionOperationQuantizedFixture = ReductionOperationQuantizedFixture<
TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall, NEReductionOperationQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
+ combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), Axises),
ReductionOperations),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f / 255, 0) })))
+ QuantizationInfos))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE(RunLarge, NEReductionOperationQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
+ combine(combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), Axises),
ReductionOperations),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f / 255, 0) })))
+ QuantizationInfos))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp
index 8e79c3bfb0..fb7a6d6997 100644
--- a/tests/validation/reference/ReductionOperation.cpp
+++ b/tests/validation/reference/ReductionOperation.cpp
@@ -128,7 +128,7 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in
} // namespace
template <typename T, typename OT>
-SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
+SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
{
// Create reference
const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
@@ -213,12 +213,34 @@ SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorSha
return dst;
}
+template <typename T, typename OT>
+SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
+{
+ return compute_reduction_operation<T, OT>(src, dst_shape, axis, op);
+}
+
+template <>
+SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
+{
+ if(src.data_type() == DataType::QASYMM8 && op != ReductionOperation::MEAN_SUM)
+ {
+ SimpleTensor<float> src_f = convert_from_asymmetric(src);
+ SimpleTensor<float> dst_f = reference::reduction_operation<float, float>(src_f, dst_shape, axis, op);
+ return convert_to_asymmetric(dst_f, src.quantization_info());
+ }
+ else
+ {
+ return compute_reduction_operation<uint8_t, uint8_t>(src, dst_shape, axis, op);
+ }
+}
+
+template SimpleTensor<float> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
+template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
+
template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
-template SimpleTensor<float> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
-template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
-template SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
+
} // namespace reference
} // namespace validation
} // namespace test