diff options
author | SiCong Li <sicong.li@arm.com> | 2023-03-13 15:02:23 +0000 |
---|---|---|
committer | SiCong Li <sicong.li@arm.com> | 2023-03-14 15:38:29 +0000 |
commit | 4ceb453b00185ded5ddbaf83d40eadeb2ed28ec4 (patch) | |
tree | 13d56b417d5c2b186bde627f4f5d0f05b7228a53 /tests/validation/fixtures | |
parent | aaa9da1efa83911c7a67d50811ad669a92a7d12f (diff) | |
download | ComputeLibrary-4ceb453b00185ded5ddbaf83d40eadeb2ed28ec4.tar.gz |
Add CropInfo to BatchToSpace reference and fixture
Partially resolves COMPMID-5918, COMPMID-5865
Signed-off-by: SiCong Li <sicong.li@arm.com>
Change-Id: Ib3b01e7dc1c944184a4c038045bf0469fbb9ff45
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9321
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/BatchToSpaceLayerFixture.h | 29 |
1 files changed, 20 insertions, 9 deletions
diff --git a/tests/validation/fixtures/BatchToSpaceLayerFixture.h b/tests/validation/fixtures/BatchToSpaceLayerFixture.h index 6554c09de4..5a23261a6e 100644 --- a/tests/validation/fixtures/BatchToSpaceLayerFixture.h +++ b/tests/validation/fixtures/BatchToSpaceLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -36,14 +36,14 @@ namespace test namespace validation { template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class BatchToSpaceLayerValidationFixture : public framework::Fixture +class BatchToSpaceLayerValidationGenericFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout) + void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout, const CropInfo &crop_info = CropInfo{}) { - _target = compute_target(input_shape, block_shape_shape, output_shape, data_type, data_layout); - _reference = compute_reference(input_shape, block_shape_shape, output_shape, data_type); + _target = compute_target(input_shape, block_shape_shape, output_shape, data_type, data_layout, crop_info); + _reference = compute_reference(input_shape, block_shape_shape, output_shape, data_type, crop_info); } protected: @@ -57,7 +57,7 @@ protected: library->fill(tensor, distribution, i); } TensorType compute_target(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, - DataType data_type, DataLayout data_layout) + DataType data_type, DataLayout data_layout, const CropInfo &crop_info) { if(data_layout == DataLayout::NHWC) { @@ -72,7 +72,7 @@ protected: // Create and configure function FunctionType batch_to_space; - batch_to_space.configure(&input, &block_shape, &output); + batch_to_space.configure(&input, &block_shape, &output, crop_info); ARM_COMPUTE_ASSERT(input.info()->is_resizable()); ARM_COMPUTE_ASSERT(block_shape.info()->is_resizable()); @@ -104,7 +104,7 @@ protected: } SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &block_shape_shape, - const TensorShape &output_shape, DataType data_type) + const TensorShape &output_shape, DataType data_type, const CropInfo &crop_info) { // Create reference SimpleTensor<T> input{ input_shape, data_type }; @@ -118,12 +118,23 @@ protected: } // Compute reference - return reference::batch_to_space(input, block_shape, output_shape); + return reference::batch_to_space(input, block_shape, output_shape, crop_info); } TensorType _target{}; SimpleTensor<T> _reference{}; }; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class BatchToSpaceLayerValidationFixture : public BatchToSpaceLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout) + { + BatchToSpaceLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, block_shape_shape, output_shape, data_type, data_layout, CropInfo{}); + } +}; } // namespace validation } // namespace test } // namespace arm_compute |