diff options
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/BatchToSpaceLayer.cpp | 43 | ||||
-rw-r--r-- | tests/validation/reference/BatchToSpaceLayer.h | 5 |
2 files changed, 26 insertions, 22 deletions
diff --git a/tests/validation/reference/BatchToSpaceLayer.cpp b/tests/validation/reference/BatchToSpaceLayer.cpp index 404ee73cac..aeda733bb6 100644 --- a/tests/validation/reference/BatchToSpaceLayer.cpp +++ b/tests/validation/reference/BatchToSpaceLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 Arm Limited. + * Copyright (c) 2018, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -35,32 +35,35 @@ namespace reference { // Batch to Space template <typename T> -SimpleTensor<T> batch_to_space(const SimpleTensor<T> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape) +SimpleTensor<T> batch_to_space(const SimpleTensor<T> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info) { ARM_COMPUTE_ERROR_ON(block_shape[0] <= 0); ARM_COMPUTE_ERROR_ON(block_shape[1] <= 0); SimpleTensor<T> result(dst_shape, src.data_type()); + int out_pos = 0; + const auto width_out = static_cast<int>(dst_shape[0]); + const auto height_out = static_cast<int>(dst_shape[1]); + const auto z_out = static_cast<int>(dst_shape[2]); + const auto batch_out = static_cast<int>(dst_shape[3]); + ARM_COMPUTE_ERROR_ON(width_out <= static_cast<int>(crop_info.left + crop_info.right)); + ARM_COMPUTE_ERROR_ON(height_out <= static_cast<int>(crop_info.top + crop_info.bottom)); - int in_pos = 0; - const auto width_in = static_cast<int>(src.shape()[0]); - const auto height_in = static_cast<int>(src.shape()[1]); - const auto z_in = static_cast<int>(src.shape()[2]); - const auto batch_in = static_cast<int>(src.shape()[3]); - - for(int batch = 0; batch < batch_in; ++batch) + for(int batch = 0; batch < batch_out; ++batch) { - for(int z = 0; z < z_in; ++z) + for(int z = 0; z < z_out; ++z) { - for(int y = 0; y < height_in; ++y) + for(int y = 0; y < height_out; ++y) { - for(int x = 0; x < width_in; ++x) + for(int x = 0; x < width_out; ++x) { - const int r = src.shape()[3] / (block_shape[0] * block_shape[1]); - const int out_x = (block_shape[0] * x + (batch / r) % block_shape[0]); - const int out_y = (block_shape[1] * y + (batch / r) / block_shape[0]); - const int out_pos = out_x + dst_shape[0] * out_y + z * dst_shape[0] * dst_shape[1] + (batch % r) * dst_shape[0] * dst_shape[1] * dst_shape[2]; - result[out_pos] = src[in_pos]; - ++in_pos; + const int x_c = x + crop_info.left; + const int y_c = y + crop_info.top; + const int in_batch = batch + ((x_c % block_shape[0]) + (y_c % block_shape[1]) * (block_shape[0])) * dst_shape[3]; + const int in_x = x_c / block_shape[0]; + const int in_y = y_c / block_shape[1]; + const int in_pos = in_x + src.shape()[0] * in_y + z * src.shape()[0] * src.shape()[1] + in_batch * src.shape()[0] * src.shape()[1] * src.shape()[2]; + result[out_pos] = src[in_pos]; + ++out_pos; } } } @@ -68,8 +71,8 @@ SimpleTensor<T> batch_to_space(const SimpleTensor<T> &src, const SimpleTensor<in return result; } -template SimpleTensor<float> batch_to_space(const SimpleTensor<float> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape); -template SimpleTensor<half> batch_to_space(const SimpleTensor<half> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape); +template SimpleTensor<float> batch_to_space(const SimpleTensor<float> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info = CropInfo{}); +template SimpleTensor<half> batch_to_space(const SimpleTensor<half> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info = CropInfo{}); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/BatchToSpaceLayer.h b/tests/validation/reference/BatchToSpaceLayer.h index 52556cb53f..18010f1885 100644 --- a/tests/validation/reference/BatchToSpaceLayer.h +++ b/tests/validation/reference/BatchToSpaceLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 Arm Limited. + * Copyright (c) 2018-2019, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -24,6 +24,7 @@ #ifndef ARM_COMPUTE_TEST_BATCH_TO_SPACE_LAYER_H #define ARM_COMPUTE_TEST_BATCH_TO_SPACE_LAYER_H +#include "arm_compute/core/Types.h" #include "tests/SimpleTensor.h" #include "tests/validation/Helpers.h" @@ -36,7 +37,7 @@ namespace validation namespace reference { template <typename T> -SimpleTensor<T> batch_to_space(const SimpleTensor<T> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape); +SimpleTensor<T> batch_to_space(const SimpleTensor<T> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info = CropInfo{}); } // namespace reference } // namespace validation } // namespace test |