From 1d4f3853dfd16f55338d772ad757db0ee8710d78 Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Mon, 14 Jan 2019 15:14:43 +0000 Subject: COMPMID-1760: NEON: Implement Prod Change-Id: I8062f4ca5ef5cf1a8183ac0834f240bbaf8f695d Reviewed-on: https://review.mlplatform.org/541 Reviewed-by: Pablo Marquez Reviewed-by: Georgios Pinitas Tested-by: Arm Jenkins --- arm_compute/core/PixelValue.h | 10 +- .../NEON/kernels/NEReductionOperationKernel.cpp | 154 +++++++++++++++++++-- .../NEON/functions/NEReductionOperation.cpp | 18 +-- tests/datasets/ShapeDatasets.h | 2 +- tests/validation/NEON/ReductionOperation.cpp | 28 +++- tests/validation/reference/ReductionOperation.cpp | 30 +++- 6 files changed, 200 insertions(+), 42 deletions(-) diff --git a/arm_compute/core/PixelValue.h b/arm_compute/core/PixelValue.h index e86eeba121..0ead9db7b1 100644 --- a/arm_compute/core/PixelValue.h +++ b/arm_compute/core/PixelValue.h @@ -41,10 +41,11 @@ public: } /** Initialize the union with a pixel value of chosen datatype * - * @param[in] v int value. - * @param[in] datatype DataType that @p v have to be stored + * @param[in] v int value. + * @param[in] datatype DataType that @p v have to be stored + * @param[in] quant_info QuantizationInfo to apply in case of QASYMM8 datatype to @p v */ - PixelValue(uint64_t v, DataType datatype) + PixelValue(uint64_t v, DataType datatype, QuantizationInfo quant_info = QuantizationInfo()) : PixelValue() { switch(datatype) @@ -55,6 +56,9 @@ public: case DataType::S8: value.s8 = static_cast(v); break; + case DataType::QASYMM8: + value.u8 = sqcvt_qasymm8_f32(v, quant_info.scale, quant_info.offset); + break; case DataType::U16: value.u16 = static_cast(v); break; diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp index f0209a32da..506094f0c1 100644 --- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp +++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp @@ -362,6 +362,10 @@ struct RedOpX { init_res_value = *reinterpret_cast(input.ptr()); } + else if(op == ReductionOperation::PROD) + { + init_res_value = static_cast(1.f); + } auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{}); uint32x4x4_t vec_res_idx{ 0 }; @@ -379,6 +383,9 @@ struct RedOpX case ReductionOperation::SUM: vec_res_value = wrapper::vadd(vec_elements, vec_res_value); break; + case ReductionOperation::PROD: + vec_res_value = wrapper::vmul(vec_elements, vec_res_value); + break; case ReductionOperation::ARG_IDX_MIN: { auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); @@ -420,6 +427,17 @@ struct RedOpX *(reinterpret_cast(output.ptr())) = res; break; } + case ReductionOperation::PROD: + { + auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); + T res = 1; + for(int i = 0; i < S / 2; ++i) + { + res *= wrapper::vgetlane(carry_res, i); + } + *(reinterpret_cast(output.ptr())) = res; + break; + } case ReductionOperation::ARG_IDX_MIN: case ReductionOperation::ARG_IDX_MAX: { @@ -443,7 +461,13 @@ struct RedOpX_qasymm8 auto vec_res_value3 = vdupq_n_u32(static_cast(0.f)); auto vec_res_value4 = vdupq_n_u32(static_cast(0.f)); + auto vec_res_value1_f = vdupq_n_f32(static_cast(1.f)); + auto vec_res_value2_f = vdupq_n_f32(static_cast(1.f)); + auto vec_res_value3_f = vdupq_n_f32(static_cast(1.f)); + auto vec_res_value4_f = vdupq_n_f32(static_cast(1.f)); + uint8x16_t vec_res_value = { 0 }; + if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN) { vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{}); @@ -472,6 +496,36 @@ struct RedOpX_qasymm8 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4); break; } + case ReductionOperation::PROD: + { + const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset); + const auto scale32x4f_4 = vdupq_n_f32(in_info.quantization_info().scale); + + const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements)); + const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements)); + + const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1)); + const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1)); + const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2)); + const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2)); + + auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1); + auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2); + auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3); + auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4); + + //de-quantize vec_elements + temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4); + temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4); + temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4); + temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4); + + vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f); + vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f); + vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f); + vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f); + break; + } case ReductionOperation::ARG_IDX_MIN: { auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); @@ -497,6 +551,21 @@ struct RedOpX_qasymm8 auto res = calculate_vector_index(vec_res_idx, vec_res_value, op); *(reinterpret_cast(output.ptr())) = res; } + else if(op == ReductionOperation::PROD) + { + auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f); + carry_res = wrapper::vmul(carry_res, vec_res_value3_f); + carry_res = wrapper::vmul(carry_res, vec_res_value4_f); + + float res = wrapper::vgetlane(carry_res, 0); + res *= wrapper::vgetlane(carry_res, 1); + res *= wrapper::vgetlane(carry_res, 2); + res *= wrapper::vgetlane(carry_res, 3); + + //re-quantize result + res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset); + *(output.ptr()) = static_cast(res); + } else { auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2); @@ -535,6 +604,10 @@ struct RedOpYZW { vec_res_value = wrapper::vloadq(reinterpret_cast(input.ptr())); } + else if(op == ReductionOperation::PROD) + { + vec_res_value = wrapper::vdup_n(static_cast(1.f), ExactTagType{}); + } else { vec_res_value = wrapper::vdup_n(static_cast(0.f), ExactTagType{}); @@ -569,6 +642,9 @@ struct RedOpYZW case ReductionOperation::SUM_SQUARE: vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value); break; + case ReductionOperation::PROD: + vec_res_value = wrapper::vmul(vec_elements, vec_res_value); + break; case ReductionOperation::ARG_IDX_MIN: { auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); @@ -620,21 +696,27 @@ struct RedOpYZW_qasymm8 auto vec_res_value2 = vdupq_n_u32(0); auto vec_res_value3 = vdupq_n_u32(0); auto vec_res_value4 = vdupq_n_u32(0); - auto vec_res_value = wrapper::vloadq(input.ptr()); - for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + auto vec_res_value1_f = vdupq_n_f32(1); + auto vec_res_value2_f = vdupq_n_f32(1); + auto vec_res_value3_f = vdupq_n_f32(1); + auto vec_res_value4_f = vdupq_n_f32(1); + + auto vec_res_value = wrapper::vloadq(input.ptr()); + + for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim) { uint8_t *in_ptr; switch(axis) { case 1: - in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim)); + in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, index_dim)); break; case 2: - in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim)); + in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, index_dim)); break; case 3: - in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim)); + in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, index_dim)); break; default: ARM_COMPUTE_ERROR("Not supported"); @@ -660,17 +742,47 @@ struct RedOpYZW_qasymm8 vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4); break; } + case ReductionOperation::PROD: + { + const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset); + const auto scale32x4f_4 = vdupq_n_f32(in_info.quantization_info().scale); + + const auto temp16x8t_1 = vmovl_u8(vget_low_u8(vec_elements)); + const auto temp16x8t_2 = vmovl_u8(vget_high_u8(vec_elements)); + + const auto temp32x4t_1 = vmovl_u16(vget_low_u16(temp16x8t_1)); + const auto temp32x4t_2 = vmovl_u16(vget_high_u16(temp16x8t_1)); + const auto temp32x4t_3 = vmovl_u16(vget_low_u16(temp16x8t_2)); + const auto temp32x4t_4 = vmovl_u16(vget_high_u16(temp16x8t_2)); + + auto temp32x4f_1 = vcvtq_f32_u32(temp32x4t_1); + auto temp32x4f_2 = vcvtq_f32_u32(temp32x4t_2); + auto temp32x4f_3 = vcvtq_f32_u32(temp32x4t_3); + auto temp32x4f_4 = vcvtq_f32_u32(temp32x4t_4); + + //de-quantize vec_elements + temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4); + temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4); + temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4); + temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4); + + vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f); + vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f); + vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f); + vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f); + break; + } case ReductionOperation::ARG_IDX_MIN: { auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); + vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); vec_res_value = temp_vec_res_value; break; } case ReductionOperation::ARG_IDX_MAX: { auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - vec_res_idx = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); + vec_res_idx = calculate_index(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); vec_res_value = temp_vec_res_value; break; } @@ -681,17 +793,34 @@ struct RedOpYZW_qasymm8 if(op == ReductionOperation::MEAN_SUM) { - const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis))); - const auto vec_res_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv); - const auto vec_res_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv); - const auto vec_res_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv); - const auto vec_res_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv); + const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis))); + vec_res_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value1), vec_width_inv); + vec_res_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value2), vec_width_inv); + vec_res_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value3), vec_width_inv); + vec_res_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_res_value4), vec_width_inv); vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f); vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f); vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f); vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f); } + else if(op == ReductionOperation::PROD) + { + const auto offset32x4f_4 = vdupq_n_f32(in_info.quantization_info().offset); + const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(in_info.quantization_info().scale)); + + //re-quantize + vec_res_value1_f = vaddq_f32(vmulq_f32(vec_res_value1_f, iscale32x4f_4), offset32x4f_4); + vec_res_value2_f = vaddq_f32(vmulq_f32(vec_res_value2_f, iscale32x4f_4), offset32x4f_4); + vec_res_value3_f = vaddq_f32(vmulq_f32(vec_res_value3_f, iscale32x4f_4), offset32x4f_4); + vec_res_value4_f = vaddq_f32(vmulq_f32(vec_res_value4_f, iscale32x4f_4), offset32x4f_4); + + vec_res_value1 = vcvtq_u32_f32(vec_res_value1_f); + vec_res_value2 = vcvtq_u32_f32(vec_res_value2_f); + vec_res_value3 = vcvtq_u32_f32(vec_res_value3_f); + vec_res_value4 = vcvtq_u32_f32(vec_res_value4_f); + } + if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) { wrapper::vstore(reinterpret_cast(output.ptr()), vec_res_idx.val[0]); @@ -798,7 +927,6 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32); } - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output); const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis); diff --git a/src/runtime/NEON/functions/NEReductionOperation.cpp b/src/runtime/NEON/functions/NEReductionOperation.cpp index bb27b5d47a..3ec8ef145e 100644 --- a/src/runtime/NEON/functions/NEReductionOperation.cpp +++ b/src/runtime/NEON/functions/NEReductionOperation.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -50,16 +50,6 @@ size_t reduction_window_split_dimension(unsigned int axis) ARM_COMPUTE_ERROR("Unsupported reduction axis"); } } -BorderMode reduction_operation_border_mode(ReductionOperation op) -{ - switch(op) - { - case ReductionOperation::SUM_SQUARE: - return BorderMode::CONSTANT; - default: - return BorderMode::CONSTANT; - } -} } // namespace NEReductionOperation::NEReductionOperation() @@ -86,9 +76,9 @@ void NEReductionOperation::configure(ITensor *input, ITensor *output, unsigned i if(axis == 0) { // Configure fill border kernel - BorderSize fill_border_size = _reduction_kernel.border_size(); - BorderMode fill_border_mode = reduction_operation_border_mode(op); - _fill_border_kernel.configure(input, fill_border_size, fill_border_mode, PixelValue(static_cast(0.f))); + const BorderSize fill_border_size = _reduction_kernel.border_size(); + const PixelValue pixelValue = PixelValue((op == ReductionOperation::PROD) ? 1 : 0, input->info()->data_type(), input->info()->quantization_info()); + _fill_border_kernel.configure(input, fill_border_size, BorderMode::CONSTANT, pixelValue); } } diff --git a/tests/datasets/ShapeDatasets.h b/tests/datasets/ShapeDatasets.h index bd29fe649a..9ee89f43e0 100644 --- a/tests/datasets/ShapeDatasets.h +++ b/tests/datasets/ShapeDatasets.h @@ -148,7 +148,7 @@ public: Small4DShapes() : ShapeDataset("Shape", { - TensorShape{ 1U, 7U, 1U, 3U }, + TensorShape{ 2U, 7U, 1U, 3U }, TensorShape{ 7U, 7U, 5U, 3U }, TensorShape{ 27U, 13U, 37U, 2U }, TensorShape{ 128U, 64U, 21U, 3U } diff --git a/tests/validation/NEON/ReductionOperation.cpp b/tests/validation/NEON/ReductionOperation.cpp index e322947993..b9b4983ae6 100644 --- a/tests/validation/NEON/ReductionOperation.cpp +++ b/tests/validation/NEON/ReductionOperation.cpp @@ -49,7 +49,21 @@ RelativeTolerance rel_tolerance_f32(0.00001f); RelativeTolerance tolerance_qasymm8(1); const auto ReductionOperations = framework::dataset::make("ReductionOperation", -{ ReductionOperation::SUM }); +{ + ReductionOperation::SUM, + ReductionOperation::PROD +}); + +const auto QuantizationInfos = framework::dataset::make("QuantizationInfo", +{ + QuantizationInfo(1.f / 128, -10), + QuantizationInfo(1.f / 64, -5), + QuantizationInfo(1.f / 32, -2) +}); + +const auto Axises = framework::dataset::make("Axis", +{ 0, 1, 2, 3 }); + } // namespace TEST_SUITE(NEON) @@ -88,13 +102,13 @@ using NEReductionOperationFixture = ReductionOperationFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations)) + combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::F32)), Axises), ReductionOperations)) { // Validate output validate(Accessor(_target), _reference, tolerance_f32); } FIXTURE_DATA_TEST_CASE(RunLarge, NEReductionOperationFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations)) + combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::F32)), Axises), ReductionOperations)) { // Validate output validate(Accessor(_target), _reference, rel_tolerance_f32, 0, tolerance_f32); @@ -106,17 +120,17 @@ using NEReductionOperationQuantizedFixture = ReductionOperationQuantizedFixture< TEST_SUITE(QASYMM8) FIXTURE_DATA_TEST_CASE(RunSmall, NEReductionOperationQuantizedFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), + combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), Axises), ReductionOperations), - framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f / 255, 0) }))) + QuantizationInfos)) { // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); } FIXTURE_DATA_TEST_CASE(RunLarge, NEReductionOperationQuantizedFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), + combine(combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), Axises), ReductionOperations), - framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f / 255, 0) }))) + QuantizationInfos)) { // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp index 8e79c3bfb0..fb7a6d6997 100644 --- a/tests/validation/reference/ReductionOperation.cpp +++ b/tests/validation/reference/ReductionOperation.cpp @@ -128,7 +128,7 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in } // namespace template -SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +SimpleTensor compute_reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) { // Create reference const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX); @@ -213,12 +213,34 @@ SimpleTensor reduction_operation(const SimpleTensor &src, const TensorSha return dst; } +template +SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +{ + return compute_reduction_operation(src, dst_shape, axis, op); +} + +template <> +SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +{ + if(src.data_type() == DataType::QASYMM8 && op != ReductionOperation::MEAN_SUM) + { + SimpleTensor src_f = convert_from_asymmetric(src); + SimpleTensor dst_f = reference::reduction_operation(src_f, dst_shape, axis, op); + return convert_to_asymmetric(dst_f, src.quantization_info()); + } + else + { + return compute_reduction_operation(src, dst_shape, axis, op); + } +} + +template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); +template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); + template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); + } // namespace reference } // namespace validation } // namespace test -- cgit v1.2.1