From 2697fd8fa42425f7bfdd60dd486d4c2132b06523 Mon Sep 17 00:00:00 2001 From: Sang-Hoon Park Date: Tue, 15 Oct 2019 16:49:24 +0100 Subject: COMPMID-2707: add keep_dims parameter to Reduction Operation The added parameter is used to decide whether or not to keep the target dimension of reduction operation. ArgMinMax operations will always remove the reduced dimension. Following things are updated to support the parameter. - [CL/NEON] functions and reference kernel - [CL/NEON] ArgMinMax function to use ReductionOperation function - [CL/NEON] validation test suite for Reduction and ArgMinMax operations to validate the added parameter - ReductionOperationFixture is modified NOT to pre-populate output tensor and now relies on underlying kernel/function. - Adjust CL validation test suite for Reduction operation to remove excessive test cases with axis values beyond input tensor's dimension. Change-Id: I3e24d276ed469a4201f323001708f0c525f11c4f Signed-off-by: Sang-Hoon Park Reviewed-on: https://review.mlplatform.org/c/2167 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Reviewed-by: Georgios Pinitas --- tests/validation/CL/ArgMinMax.cpp | 28 ++++++----- tests/validation/CL/ReductionOperation.cpp | 56 +++++++++++++++++----- tests/validation/NEON/ArgMinMax.cpp | 16 ++++--- tests/validation/NEON/ReductionOperation.cpp | 36 ++++++++------ tests/validation/fixtures/ArgMinMaxFixture.h | 4 +- .../fixtures/ReductionOperationFixture.h | 34 ++++++------- 6 files changed, 110 insertions(+), 64 deletions(-) (limited to 'tests') diff --git a/tests/validation/CL/ArgMinMax.cpp b/tests/validation/CL/ArgMinMax.cpp index 6de09bed25..845fdbf493 100644 --- a/tests/validation/CL/ArgMinMax.cpp +++ b/tests/validation/CL/ArgMinMax.cpp @@ -25,7 +25,9 @@ #include "arm_compute/runtime/CL/CLTensor.h" #include "arm_compute/runtime/CL/CLTensorAllocator.h" #include "arm_compute/runtime/CL/functions/CLArgMinMaxLayer.h" +#include "arm_compute/runtime/CL/functions/CLReductionOperation.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "tests/CL/CLAccessor.h" #include "tests/datasets/ShapeDatasets.h" #include "tests/datasets/SplitDataset.h" @@ -49,16 +51,18 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 3U, 16U, 2U), 1, DataType::F32), // Invalid axis TensorInfo(TensorShape(27U, 3U, 16U, 2U), 1, DataType::F32), // Invalid output shape TensorInfo(TensorShape(32U, 16U, 16U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(32U, 16U, 16U, 2U), 1, DataType::F32) // Invalid operation + TensorInfo(TensorShape(32U, 16U, 16U, 2U), 1, DataType::F32), // Invalid operation + TensorInfo(TensorShape(32U, 16U, 16U, 2U), 1, DataType::F32) // Not allowed keeping the dimension }), - framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(27U, 3U, 1U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(27U, 3U, 1U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(32U, 16U, 1U, 2U), 1, DataType::U32), - TensorInfo(TensorShape(32U, 16U, 1U, 2U), 1, DataType::F32) + framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(27U, 3U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(27U, 3U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::U32), + TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(32U, 16U, 1U, 2U), 1, DataType::U32) })), - framework::dataset::make("Axis", { 4, 0, 2, 0 })), - framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MAX, ReductionOperation::ARG_IDX_MAX, ReductionOperation::ARG_IDX_MAX, ReductionOperation::MEAN_SUM })), - framework::dataset::make("Expected", { false, false, true, false })), + framework::dataset::make("Axis", { 4, 0, 2, 0, 2 })), + framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MAX, ReductionOperation::ARG_IDX_MAX, ReductionOperation::ARG_IDX_MAX, ReductionOperation::MEAN_SUM, ReductionOperation::ARG_IDX_MAX })), + framework::dataset::make("Expected", { false, false, true, false, false })), input_info, output_info, axis, operation, expected) { const Status status = CLArgMinMaxLayer::validate(&input_info.clone()->set_is_resizable(false), axis, &output_info.clone()->set_is_resizable(false), operation); @@ -76,13 +80,13 @@ DATA_TEST_CASE(Configuration, CLTensor ref_src = create_tensor(shape, data_type); CLTensor dst; + constexpr int axis = 1; + // Create and Configure function CLArgMinMaxLayer arg_min_max_layer; - arg_min_max_layer.configure(&ref_src, 1, &dst, ReductionOperation::ARG_IDX_MAX); + arg_min_max_layer.configure(&ref_src, axis, &dst, ReductionOperation::ARG_IDX_MAX); - // Validate valid region - TensorShape output_shape = shape; - output_shape.set(1, 1); + const auto output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(shape, axis, false); const ValidRegion valid_region = shape_to_valid_region(output_shape); validate(dst.info()->valid_region(), valid_region); } diff --git a/tests/validation/CL/ReductionOperation.cpp b/tests/validation/CL/ReductionOperation.cpp index 9a3cd996fa..1dec020d18 100644 --- a/tests/validation/CL/ReductionOperation.cpp +++ b/tests/validation/CL/ReductionOperation.cpp @@ -57,6 +57,7 @@ const auto ReductionOperations = framework::dataset::make("ReductionOperation", }); +const auto KeepDimensions = framework::dataset::make("KeepDims", { true, false }); } // namespace TEST_SUITE(CL) @@ -64,29 +65,34 @@ TEST_SUITE(ReductionOperation) // *INDENT-OFF* // clang-format off -DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( +DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( framework::dataset::make("InputInfo", { TensorInfo(TensorShape(128U, 64U), 1, DataType::F32), // Mismatching data type input/output TensorInfo(TensorShape(128U, 64U), 3, DataType::F32), // Number of Input channels != 1 TensorInfo(TensorShape(128U, 64U), 1, DataType::S16), // DataType != QASYMM8/F16/F32 TensorInfo(TensorShape(128U, 64U), 1, DataType::F32), // Axis >= num_max_dimensions TensorInfo(TensorShape(128U, 64U), 1, DataType::QASYMM8), // Axis == 0 and SUM_SQUARE and QASYMM8 - TensorInfo(TensorShape(128U, 64U), 1, DataType::F32) + TensorInfo(TensorShape(128U, 64U), 1, DataType::F32), + TensorInfo(TensorShape(128U, 64U), 1, DataType::F32) // Kept Dimension when keep_dims = false + }), framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(1U, 64U), 1, DataType::F16), TensorInfo(TensorShape(1U, 64U), 1, DataType::F32), TensorInfo(TensorShape(1U, 64U), 1, DataType::S16), TensorInfo(TensorShape(1U, 64U), 1, DataType::F32), TensorInfo(TensorShape(1U, 64U), 1, DataType::QASYMM8), + TensorInfo(TensorShape(1U, 64U), 1, DataType::F32), TensorInfo(TensorShape(1U, 64U), 1, DataType::F32) })), - framework::dataset::make("Axis", { 0U, 0U, 0U, static_cast(TensorShape::num_max_dimensions), 1U, 0U })), - framework::dataset::make("Expected", { false, false, false, false, false, true })), - input_info, output_info, axis, expected) + framework::dataset::make("Axis", { 0U, 0U, 0U, static_cast(TensorShape::num_max_dimensions), 1U, 0U, 0U })), + framework::dataset::make("KeepDims", { true, true, true, true, true, true, false })), + framework::dataset::make("Expected", { false, false, false, false, false, true , false })), + input_info, output_info, axis, keep_dims, expected) { bool is_valid = bool(CLReductionOperation::validate(&input_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(true), axis, - ReductionOperation::SUM_SQUARE)); + ReductionOperation::SUM_SQUARE, + keep_dims)); ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS); } // clang-format on @@ -97,28 +103,54 @@ using CLReductionOperationFixture = ReductionOperationFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations)) +FIXTURE_DATA_TEST_CASE(RunSmall2D, CLReductionOperationFixture, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(datasets::Small2DShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1 })), ReductionOperations), KeepDimensions)) +{ + // Validate output + validate(CLAccessor(_target), _reference, tolerance_f16); +} +FIXTURE_DATA_TEST_CASE(RunSmall3D, CLReductionOperationFixture, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(datasets::Small3DShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1, 2 })), ReductionOperations), KeepDimensions)) +{ + // Validate output + validate(CLAccessor(_target), _reference, tolerance_f16); +} +FIXTURE_DATA_TEST_CASE(RunSmall4D, CLReductionOperationFixture, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations), + KeepDimensions)) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16); } FIXTURE_DATA_TEST_CASE(RunLarge, CLReductionOperationFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations)) + combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations), KeepDimensions)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0, tolerance_f16); } TEST_SUITE_END() // F16 TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, CLReductionOperationFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations)) +FIXTURE_DATA_TEST_CASE(RunSmall2D, CLReductionOperationFixture, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(datasets::Small2DShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1 })), ReductionOperations), KeepDimensions)) +{ + // Validate output + validate(CLAccessor(_target), _reference, tolerance_f32); +} +FIXTURE_DATA_TEST_CASE(RunSmall3D, CLReductionOperationFixture, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(datasets::Small3DShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2 })), ReductionOperations), KeepDimensions)) +{ + // Validate output + validate(CLAccessor(_target), _reference, tolerance_f32); +} +FIXTURE_DATA_TEST_CASE(RunSmall4D, CLReductionOperationFixture, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations), + KeepDimensions)) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32); } FIXTURE_DATA_TEST_CASE(RunLarge, CLReductionOperationFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations)) + combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations), KeepDimensions)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0, tolerance_f32); diff --git a/tests/validation/NEON/ArgMinMax.cpp b/tests/validation/NEON/ArgMinMax.cpp index 71fb39a30d..642a69ba5f 100644 --- a/tests/validation/NEON/ArgMinMax.cpp +++ b/tests/validation/NEON/ArgMinMax.cpp @@ -24,9 +24,11 @@ #include "arm_compute/core/Types.h" #include "arm_compute/core/utils/misc/Traits.h" #include "arm_compute/runtime/NEON/functions/NEArgMinMaxLayer.h" +#include "arm_compute/runtime/NEON/functions/NEReductionOperation.h" #include "arm_compute/runtime/Tensor.h" #include "arm_compute/runtime/TensorAllocator.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "tests/NEON/Accessor.h" #include "tests/datasets/ShapeDatasets.h" #include "tests/datasets/SplitDataset.h" @@ -54,7 +56,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( }), framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(27U, 3U, 1U, 2U), 1, DataType::F32), TensorInfo(TensorShape(27U, 3U, 1U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(32U, 16U, 1U, 2U), 1, DataType::U32), + TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::U32), TensorInfo(TensorShape(32U, 16U, 1U, 2U), 1, DataType::F32) })), framework::dataset::make("Axis", { 4, 0, 2, 0 })), @@ -74,17 +76,17 @@ DATA_TEST_CASE(Configuration, shape, data_type) { // Create tensors - Tensor ref_src = create_tensor(shape, data_type); - Tensor dst; + Tensor ref_src = create_tensor(shape, data_type); + Tensor dst; + const int axis = 1; // Create and Configure function NEArgMinMaxLayer arg_min_max_layer; - arg_min_max_layer.configure(&ref_src, 1, &dst, ReductionOperation::ARG_IDX_MAX); + arg_min_max_layer.configure(&ref_src, axis, &dst, ReductionOperation::ARG_IDX_MAX); // Validate valid region - TensorShape output_shape = shape; - output_shape.set(1, 1); - const ValidRegion valid_region = shape_to_valid_region(output_shape); + const auto expected_output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(shape, axis, false); + const ValidRegion valid_region = shape_to_valid_region(expected_output_shape); validate(dst.info()->valid_region(), valid_region); } diff --git a/tests/validation/NEON/ReductionOperation.cpp b/tests/validation/NEON/ReductionOperation.cpp index 5b697a5efa..3a7f707d23 100644 --- a/tests/validation/NEON/ReductionOperation.cpp +++ b/tests/validation/NEON/ReductionOperation.cpp @@ -66,6 +66,8 @@ const auto QuantizationInfos = framework::dataset::make("QuantizationInfo", const auto Axises = framework::dataset::make("Axis", { 0, 1, 2, 3 }); +const auto KeepDims = framework::dataset::make("KeepDims", { true, false }); + } // namespace TEST_SUITE(NEON) @@ -73,27 +75,31 @@ TEST_SUITE(ReductionOperation) // *INDENT-OFF* // clang-format off -DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( +DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( framework::dataset::make("InputInfo", { TensorInfo(TensorShape(128U, 64U), 1, DataType::F32), // Mismatching data type input/output TensorInfo(TensorShape(128U, 64U), 2, DataType::F32), // Number of Input channels != 1 TensorInfo(TensorShape(128U, 64U), 1, DataType::S16), // DataType != F32 TensorInfo(TensorShape(128U, 64U), 1, DataType::F32), // Axis >= num_max_dimensions - TensorInfo(TensorShape(128U, 64U), 1, DataType::F32) + TensorInfo(TensorShape(128U, 64U), 1, DataType::F32), + TensorInfo(TensorShape(128U, 64U), 1, DataType::F32) // Kept dimension when keep_dims = false }), framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(1U, 64U), 1, DataType::F16), TensorInfo(TensorShape(1U, 64U), 1, DataType::F32), TensorInfo(TensorShape(1U, 64U), 1, DataType::S16), TensorInfo(TensorShape(1U, 64U), 1, DataType::F32), + TensorInfo(TensorShape(1U, 64U), 1, DataType::F32), TensorInfo(TensorShape(1U, 64U), 1, DataType::F32) })), - framework::dataset::make("Axis", { 0U, 0U, 0U, static_cast(TensorShape::num_max_dimensions), 0U })), - framework::dataset::make("Expected", { false, false, false, false, true })), - input_info, output_info, axis, expected) + framework::dataset::make("Axis", { 0U, 0U, 0U, static_cast(TensorShape::num_max_dimensions), 0U, 0U })), + framework::dataset::make("KeepDims", { true, true, true, true, true, false})), + framework::dataset::make("Expected", { false, false, false, false, true, false })), + input_info, output_info, axis, keep_dims, expected) { bool is_valid = bool(NEReductionOperation::validate(&input_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(true), axis, - ReductionOperation::SUM_SQUARE)); + ReductionOperation::SUM_SQUARE, + keep_dims)); ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS); } // clang-format on @@ -104,13 +110,13 @@ using NEReductionOperationFixture = ReductionOperationFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::F32)), Axises), ReductionOperations)) + combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::F32)), Axises), ReductionOperations), KeepDims)) { // Validate output validate(Accessor(_target), _reference, tolerance_f32); } FIXTURE_DATA_TEST_CASE(RunLarge, NEReductionOperationFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::F32)), Axises), ReductionOperations)) + combine(combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::F32)), Axises), ReductionOperations), KeepDims)) { // Validate output validate(Accessor(_target), _reference, rel_tolerance_f32, 0, tolerance_f32); @@ -122,17 +128,19 @@ using NEReductionOperationQuantizedFixture = ReductionOperationQuantizedFixture< TEST_SUITE(QASYMM8) FIXTURE_DATA_TEST_CASE(RunSmall, NEReductionOperationQuantizedFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), Axises), - ReductionOperations), - QuantizationInfos)) + combine(combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), Axises), + ReductionOperations), + QuantizationInfos), + KeepDims)) { // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); } FIXTURE_DATA_TEST_CASE(RunLarge, NEReductionOperationQuantizedFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), Axises), - ReductionOperations), - QuantizationInfos)) + combine(combine(combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), Axises), + ReductionOperations), + QuantizationInfos), + KeepDims)) { // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); diff --git a/tests/validation/fixtures/ArgMinMaxFixture.h b/tests/validation/fixtures/ArgMinMaxFixture.h index ed6b51abe5..f8fe4ff1ee 100644 --- a/tests/validation/fixtures/ArgMinMaxFixture.h +++ b/tests/validation/fixtures/ArgMinMaxFixture.h @@ -26,6 +26,7 @@ #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/Tensor.h" #include "tests/AssetsLibrary.h" #include "tests/Globals.h" @@ -121,8 +122,7 @@ protected: // Fill reference fill(src); - TensorShape output_shape = src_shape; - output_shape.set(axis, 1); + TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(src_shape, axis, false); return reference::reduction_operation(src, output_shape, axis, op); } diff --git a/tests/validation/fixtures/ReductionOperationFixture.h b/tests/validation/fixtures/ReductionOperationFixture.h index d01f41abf0..867c08ec3a 100644 --- a/tests/validation/fixtures/ReductionOperationFixture.h +++ b/tests/validation/fixtures/ReductionOperationFixture.h @@ -26,6 +26,7 @@ #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/Tensor.h" #include "tests/AssetsLibrary.h" #include "tests/Globals.h" @@ -45,11 +46,15 @@ class ReductionOperationValidationFixture : public framework::Fixture { public: template - void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info) + void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info, bool keep_dims = false) { - const TensorShape output_shape = get_output_shape(shape, axis); - _target = compute_target(shape, output_shape, data_type, axis, op, quantization_info); - _reference = compute_reference(shape, output_shape, data_type, axis, op, quantization_info); + const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX) || (op == ReductionOperation::ARG_IDX_MIN); + _keep_dims = keep_dims && !is_arg_min_max; + + const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(shape, axis, _keep_dims); + + _target = compute_target(shape, data_type, axis, op, quantization_info); + _reference = compute_reference(shape, output_shape, data_type, axis, op, quantization_info); } protected: @@ -70,15 +75,15 @@ protected: } } - TensorType compute_target(const TensorShape &src_shape, const TensorShape &dst_shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info) + TensorType compute_target(const TensorShape &src_shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info) { // Create tensors TensorType src = create_tensor(src_shape, data_type, 1, quantization_info); - TensorType dst = create_tensor(dst_shape, data_type, 1, quantization_info); + TensorType dst; // Create and configure function FunctionType reduction_func; - reduction_func.configure(&src, &dst, axis, op); + reduction_func.configure(&src, &dst, axis, op, _keep_dims); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -114,12 +119,7 @@ protected: SimpleTensor _reference{}; private: - TensorShape get_output_shape(TensorShape shape, unsigned int axis) - { - TensorShape output_shape(shape); - output_shape.set(axis, 1); - return output_shape; - } + bool _keep_dims{ false }; }; template @@ -127,9 +127,9 @@ class ReductionOperationQuantizedFixture : public ReductionOperationValidationFi { public: template - void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info = QuantizationInfo()) + void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info = QuantizationInfo(), bool keep_dims = false) { - ReductionOperationValidationFixture::setup(shape, data_type, axis, op, quantization_info); + ReductionOperationValidationFixture::setup(shape, data_type, axis, op, quantization_info, keep_dims); } }; @@ -138,9 +138,9 @@ class ReductionOperationFixture : public ReductionOperationValidationFixture - void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op) + void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, bool keep_dims = false) { - ReductionOperationValidationFixture::setup(shape, data_type, axis, op, QuantizationInfo()); + ReductionOperationValidationFixture::setup(shape, data_type, axis, op, QuantizationInfo(), keep_dims); } }; } // namespace validation -- cgit v1.2.1