diff options
author | SiCong Li <sicong.li@arm.com> | 2023-03-13 15:02:23 +0000 |
---|---|---|
committer | SiCong Li <sicong.li@arm.com> | 2023-03-14 15:38:29 +0000 |
commit | 4ceb453b00185ded5ddbaf83d40eadeb2ed28ec4 (patch) | |
tree | 13d56b417d5c2b186bde627f4f5d0f05b7228a53 /tests/validation | |
parent | aaa9da1efa83911c7a67d50811ad669a92a7d12f (diff) | |
download | ComputeLibrary-4ceb453b00185ded5ddbaf83d40eadeb2ed28ec4.tar.gz |
Add CropInfo to BatchToSpace reference and fixture
Partially resolves COMPMID-5918, COMPMID-5865
Signed-off-by: SiCong Li <sicong.li@arm.com>
Change-Id: Ib3b01e7dc1c944184a4c038045bf0469fbb9ff45
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9321
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/fixtures/BatchToSpaceLayerFixture.h | 29 | ||||
-rw-r--r-- | tests/validation/reference/BatchToSpaceLayer.cpp | 43 | ||||
-rw-r--r-- | tests/validation/reference/BatchToSpaceLayer.h | 5 |
3 files changed, 46 insertions, 31 deletions
diff --git a/tests/validation/fixtures/BatchToSpaceLayerFixture.h b/tests/validation/fixtures/BatchToSpaceLayerFixture.h index 6554c09de4..5a23261a6e 100644 --- a/tests/validation/fixtures/BatchToSpaceLayerFixture.h +++ b/tests/validation/fixtures/BatchToSpaceLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -36,14 +36,14 @@ namespace test namespace validation { template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class BatchToSpaceLayerValidationFixture : public framework::Fixture +class BatchToSpaceLayerValidationGenericFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout) + void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout, const CropInfo &crop_info = CropInfo{}) { - _target = compute_target(input_shape, block_shape_shape, output_shape, data_type, data_layout); - _reference = compute_reference(input_shape, block_shape_shape, output_shape, data_type); + _target = compute_target(input_shape, block_shape_shape, output_shape, data_type, data_layout, crop_info); + _reference = compute_reference(input_shape, block_shape_shape, output_shape, data_type, crop_info); } protected: @@ -57,7 +57,7 @@ protected: library->fill(tensor, distribution, i); } TensorType compute_target(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, - DataType data_type, DataLayout data_layout) + DataType data_type, DataLayout data_layout, const CropInfo &crop_info) { if(data_layout == DataLayout::NHWC) { @@ -72,7 +72,7 @@ protected: // Create and configure function FunctionType batch_to_space; - batch_to_space.configure(&input, &block_shape, &output); + batch_to_space.configure(&input, &block_shape, &output, crop_info); ARM_COMPUTE_ASSERT(input.info()->is_resizable()); ARM_COMPUTE_ASSERT(block_shape.info()->is_resizable()); @@ -104,7 +104,7 @@ protected: } SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &block_shape_shape, - const TensorShape &output_shape, DataType data_type) + const TensorShape &output_shape, DataType data_type, const CropInfo &crop_info) { // Create reference SimpleTensor<T> input{ input_shape, data_type }; @@ -118,12 +118,23 @@ protected: } // Compute reference - return reference::batch_to_space(input, block_shape, output_shape); + return reference::batch_to_space(input, block_shape, output_shape, crop_info); } TensorType _target{}; SimpleTensor<T> _reference{}; }; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class BatchToSpaceLayerValidationFixture : public BatchToSpaceLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout) + { + BatchToSpaceLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, block_shape_shape, output_shape, data_type, data_layout, CropInfo{}); + } +}; } // namespace validation } // namespace test } // namespace arm_compute diff --git a/tests/validation/reference/BatchToSpaceLayer.cpp b/tests/validation/reference/BatchToSpaceLayer.cpp index 404ee73cac..aeda733bb6 100644 --- a/tests/validation/reference/BatchToSpaceLayer.cpp +++ b/tests/validation/reference/BatchToSpaceLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 Arm Limited. + * Copyright (c) 2018, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -35,32 +35,35 @@ namespace reference { // Batch to Space template <typename T> -SimpleTensor<T> batch_to_space(const SimpleTensor<T> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape) +SimpleTensor<T> batch_to_space(const SimpleTensor<T> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info) { ARM_COMPUTE_ERROR_ON(block_shape[0] <= 0); ARM_COMPUTE_ERROR_ON(block_shape[1] <= 0); SimpleTensor<T> result(dst_shape, src.data_type()); + int out_pos = 0; + const auto width_out = static_cast<int>(dst_shape[0]); + const auto height_out = static_cast<int>(dst_shape[1]); + const auto z_out = static_cast<int>(dst_shape[2]); + const auto batch_out = static_cast<int>(dst_shape[3]); + ARM_COMPUTE_ERROR_ON(width_out <= static_cast<int>(crop_info.left + crop_info.right)); + ARM_COMPUTE_ERROR_ON(height_out <= static_cast<int>(crop_info.top + crop_info.bottom)); - int in_pos = 0; - const auto width_in = static_cast<int>(src.shape()[0]); - const auto height_in = static_cast<int>(src.shape()[1]); - const auto z_in = static_cast<int>(src.shape()[2]); - const auto batch_in = static_cast<int>(src.shape()[3]); - - for(int batch = 0; batch < batch_in; ++batch) + for(int batch = 0; batch < batch_out; ++batch) { - for(int z = 0; z < z_in; ++z) + for(int z = 0; z < z_out; ++z) { - for(int y = 0; y < height_in; ++y) + for(int y = 0; y < height_out; ++y) { - for(int x = 0; x < width_in; ++x) + for(int x = 0; x < width_out; ++x) { - const int r = src.shape()[3] / (block_shape[0] * block_shape[1]); - const int out_x = (block_shape[0] * x + (batch / r) % block_shape[0]); - const int out_y = (block_shape[1] * y + (batch / r) / block_shape[0]); - const int out_pos = out_x + dst_shape[0] * out_y + z * dst_shape[0] * dst_shape[1] + (batch % r) * dst_shape[0] * dst_shape[1] * dst_shape[2]; - result[out_pos] = src[in_pos]; - ++in_pos; + const int x_c = x + crop_info.left; + const int y_c = y + crop_info.top; + const int in_batch = batch + ((x_c % block_shape[0]) + (y_c % block_shape[1]) * (block_shape[0])) * dst_shape[3]; + const int in_x = x_c / block_shape[0]; + const int in_y = y_c / block_shape[1]; + const int in_pos = in_x + src.shape()[0] * in_y + z * src.shape()[0] * src.shape()[1] + in_batch * src.shape()[0] * src.shape()[1] * src.shape()[2]; + result[out_pos] = src[in_pos]; + ++out_pos; } } } @@ -68,8 +71,8 @@ SimpleTensor<T> batch_to_space(const SimpleTensor<T> &src, const SimpleTensor<in return result; } -template SimpleTensor<float> batch_to_space(const SimpleTensor<float> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape); -template SimpleTensor<half> batch_to_space(const SimpleTensor<half> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape); +template SimpleTensor<float> batch_to_space(const SimpleTensor<float> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info = CropInfo{}); +template SimpleTensor<half> batch_to_space(const SimpleTensor<half> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info = CropInfo{}); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/BatchToSpaceLayer.h b/tests/validation/reference/BatchToSpaceLayer.h index 52556cb53f..18010f1885 100644 --- a/tests/validation/reference/BatchToSpaceLayer.h +++ b/tests/validation/reference/BatchToSpaceLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 Arm Limited. + * Copyright (c) 2018-2019, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -24,6 +24,7 @@ #ifndef ARM_COMPUTE_TEST_BATCH_TO_SPACE_LAYER_H #define ARM_COMPUTE_TEST_BATCH_TO_SPACE_LAYER_H +#include "arm_compute/core/Types.h" #include "tests/SimpleTensor.h" #include "tests/validation/Helpers.h" @@ -36,7 +37,7 @@ namespace validation namespace reference { template <typename T> -SimpleTensor<T> batch_to_space(const SimpleTensor<T> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape); +SimpleTensor<T> batch_to_space(const SimpleTensor<T> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info = CropInfo{}); } // namespace reference } // namespace validation } // namespace test |