diff options
Diffstat (limited to 'tests/validation/fixtures/BatchToSpaceLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/BatchToSpaceLayerFixture.h | 29 |
1 files changed, 20 insertions, 9 deletions
diff --git a/tests/validation/fixtures/BatchToSpaceLayerFixture.h b/tests/validation/fixtures/BatchToSpaceLayerFixture.h index 6554c09de4..5a23261a6e 100644 --- a/tests/validation/fixtures/BatchToSpaceLayerFixture.h +++ b/tests/validation/fixtures/BatchToSpaceLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -36,14 +36,14 @@ namespace test namespace validation { template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class BatchToSpaceLayerValidationFixture : public framework::Fixture +class BatchToSpaceLayerValidationGenericFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout) + void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout, const CropInfo &crop_info = CropInfo{}) { - _target = compute_target(input_shape, block_shape_shape, output_shape, data_type, data_layout); - _reference = compute_reference(input_shape, block_shape_shape, output_shape, data_type); + _target = compute_target(input_shape, block_shape_shape, output_shape, data_type, data_layout, crop_info); + _reference = compute_reference(input_shape, block_shape_shape, output_shape, data_type, crop_info); } protected: @@ -57,7 +57,7 @@ protected: library->fill(tensor, distribution, i); } TensorType compute_target(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, - DataType data_type, DataLayout data_layout) + DataType data_type, DataLayout data_layout, const CropInfo &crop_info) { if(data_layout == DataLayout::NHWC) { @@ -72,7 +72,7 @@ protected: // Create and configure function FunctionType batch_to_space; - batch_to_space.configure(&input, &block_shape, &output); + batch_to_space.configure(&input, &block_shape, &output, crop_info); ARM_COMPUTE_ASSERT(input.info()->is_resizable()); ARM_COMPUTE_ASSERT(block_shape.info()->is_resizable()); @@ -104,7 +104,7 @@ protected: } SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &block_shape_shape, - const TensorShape &output_shape, DataType data_type) + const TensorShape &output_shape, DataType data_type, const CropInfo &crop_info) { // Create reference SimpleTensor<T> input{ input_shape, data_type }; @@ -118,12 +118,23 @@ protected: } // Compute reference - return reference::batch_to_space(input, block_shape, output_shape); + return reference::batch_to_space(input, block_shape, output_shape, crop_info); } TensorType _target{}; SimpleTensor<T> _reference{}; }; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class BatchToSpaceLayerValidationFixture : public BatchToSpaceLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout) + { + BatchToSpaceLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, block_shape_shape, output_shape, data_type, data_layout, CropInfo{}); + } +}; } // namespace validation } // namespace test } // namespace arm_compute |