From 4ceb453b00185ded5ddbaf83d40eadeb2ed28ec4 Mon Sep 17 00:00:00 2001 From: SiCong Li Date: Mon, 13 Mar 2023 15:02:23 +0000 Subject: Add CropInfo to BatchToSpace reference and fixture Partially resolves COMPMID-5918, COMPMID-5865 Signed-off-by: SiCong Li Change-Id: Ib3b01e7dc1c944184a4c038045bf0469fbb9ff45 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9321 Tested-by: Arm Jenkins Reviewed-by: Viet-Hoa Do Comments-Addressed: Arm Jenkins --- .../validation/fixtures/BatchToSpaceLayerFixture.h | 29 +++++++++++++++------- 1 file changed, 20 insertions(+), 9 deletions(-) (limited to 'tests/validation/fixtures/BatchToSpaceLayerFixture.h') diff --git a/tests/validation/fixtures/BatchToSpaceLayerFixture.h b/tests/validation/fixtures/BatchToSpaceLayerFixture.h index 6554c09de4..5a23261a6e 100644 --- a/tests/validation/fixtures/BatchToSpaceLayerFixture.h +++ b/tests/validation/fixtures/BatchToSpaceLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -36,14 +36,14 @@ namespace test namespace validation { template -class BatchToSpaceLayerValidationFixture : public framework::Fixture +class BatchToSpaceLayerValidationGenericFixture : public framework::Fixture { public: template - void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout) + void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout, const CropInfo &crop_info = CropInfo{}) { - _target = compute_target(input_shape, block_shape_shape, output_shape, data_type, data_layout); - _reference = compute_reference(input_shape, block_shape_shape, output_shape, data_type); + _target = compute_target(input_shape, block_shape_shape, output_shape, data_type, data_layout, crop_info); + _reference = compute_reference(input_shape, block_shape_shape, output_shape, data_type, crop_info); } protected: @@ -57,7 +57,7 @@ protected: library->fill(tensor, distribution, i); } TensorType compute_target(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, - DataType data_type, DataLayout data_layout) + DataType data_type, DataLayout data_layout, const CropInfo &crop_info) { if(data_layout == DataLayout::NHWC) { @@ -72,7 +72,7 @@ protected: // Create and configure function FunctionType batch_to_space; - batch_to_space.configure(&input, &block_shape, &output); + batch_to_space.configure(&input, &block_shape, &output, crop_info); ARM_COMPUTE_ASSERT(input.info()->is_resizable()); ARM_COMPUTE_ASSERT(block_shape.info()->is_resizable()); @@ -104,7 +104,7 @@ protected: } SimpleTensor compute_reference(const TensorShape &input_shape, const TensorShape &block_shape_shape, - const TensorShape &output_shape, DataType data_type) + const TensorShape &output_shape, DataType data_type, const CropInfo &crop_info) { // Create reference SimpleTensor input{ input_shape, data_type }; @@ -118,12 +118,23 @@ protected: } // Compute reference - return reference::batch_to_space(input, block_shape, output_shape); + return reference::batch_to_space(input, block_shape, output_shape, crop_info); } TensorType _target{}; SimpleTensor _reference{}; }; + +template +class BatchToSpaceLayerValidationFixture : public BatchToSpaceLayerValidationGenericFixture +{ +public: + template + void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout) + { + BatchToSpaceLayerValidationGenericFixture::setup(input_shape, block_shape_shape, output_shape, data_type, data_layout, CropInfo{}); + } +}; } // namespace validation } // namespace test } // namespace arm_compute -- cgit v1.2.1