From 4ceb453b00185ded5ddbaf83d40eadeb2ed28ec4 Mon Sep 17 00:00:00 2001 From: SiCong Li Date: Mon, 13 Mar 2023 15:02:23 +0000 Subject: Add CropInfo to BatchToSpace reference and fixture Partially resolves COMPMID-5918, COMPMID-5865 Signed-off-by: SiCong Li Change-Id: Ib3b01e7dc1c944184a4c038045bf0469fbb9ff45 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9321 Tested-by: Arm Jenkins Reviewed-by: Viet-Hoa Do Comments-Addressed: Arm Jenkins --- arm_compute/core/Types.h | 20 ++++++---- arm_compute/core/utils/misc/ShapeCalculator.h | 27 ++++++++++---- .../runtime/CL/functions/CLBatchToSpaceLayer.h | 20 ++++++---- .../runtime/NEON/functions/NEBatchToSpaceLayer.h | 14 ++++--- src/runtime/CL/functions/CLBatchToSpaceLayer.cpp | 22 ++++++----- src/runtime/NEON/functions/NEBatchToSpaceLayer.cpp | 14 ++++--- .../validation/fixtures/BatchToSpaceLayerFixture.h | 29 ++++++++++----- tests/validation/reference/BatchToSpaceLayer.cpp | 43 ++++++++++++---------- tests/validation/reference/BatchToSpaceLayer.h | 5 ++- 9 files changed, 121 insertions(+), 73 deletions(-) diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 946b8a6cb6..8434611f7a 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -151,9 +151,9 @@ enum class DepthwiseConvolutionFunction /** Available DeconvolutionMethod*/ enum class DeconvolutionMethod { - GEMM, /**< Deconvolution using GEMM */ - DIRECT, /**< Direct deconvolution */ - UPSCALE_CONV2D /**< Deconvolution with Upscaling */ + GEMM, /**< Deconvolution using GEMM */ + DIRECT, /**< Direct deconvolution */ + UPSCALE_CONV2D /**< Deconvolution with Upscaling */ }; /** Available FuseBatchNormalizationType*/ @@ -2734,27 +2734,31 @@ public: return _fused_act; } /* Set Adjoint LHS flag */ - MatMulInfo& adj_lhs(bool adj_lhs) + MatMulInfo &adj_lhs(bool adj_lhs) { _adj_lhs = adj_lhs; return *this; } /* Set Adjoint RHS flag */ - MatMulInfo& adj_rhs(bool adj_rhs) + MatMulInfo &adj_rhs(bool adj_rhs) { _adj_rhs = adj_rhs; return *this; } /* Set Fused Activation Layer Info */ - MatMulInfo& fused_activation(const ActivationLayerInfo& act_info) + MatMulInfo &fused_activation(const ActivationLayerInfo &act_info) { _fused_act = act_info; return *this; } + private: - bool _adj_lhs{false}; - bool _adj_rhs{false}; + bool _adj_lhs{ false }; + bool _adj_rhs{ false }; ActivationLayerInfo _fused_act{}; // disabled by default }; + +/** Class for holding information related to cropping */ +using CropInfo = Padding2D; } // namespace arm_compute #endif /* ARM_COMPUTE_TYPES_H */ diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 94bd3aca03..6655cc1439 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -1072,13 +1072,14 @@ inline TensorShape compute_slice_shape(const TensorShape &input_shape, const Coo /** Calculate the batch to space output shape of a tensor * - * @param[in] input Input tensor info - * @param[in] block_x Block shape x value - * @param[in] block_y Block shape y value + * @param[in] input Input tensor info + * @param[in] block_x Block shape x value + * @param[in] block_y Block shape y value + * @param[in] crop_info Information about how the output shape is cropped after batch to space is performed * * @return the calculated shape */ -inline TensorShape compute_batch_to_space_shape(const ITensorInfo *input, const int block_x, const int block_y) +inline TensorShape compute_batch_to_space_shape(const ITensorInfo *input, const int block_x, const int block_y, const CropInfo &crop_info = CropInfo{}) { ARM_COMPUTE_ERROR_ON(block_x <= 0 || block_y <= 0); @@ -1088,8 +1089,18 @@ inline TensorShape compute_batch_to_space_shape(const ITensorInfo *input, const const int idx_batch = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES); TensorShape output_shape{ input->tensor_shape() }; - output_shape.set(idx_width, input->tensor_shape()[idx_width] * block_x); - output_shape.set(idx_height, input->tensor_shape()[idx_height] * block_y); + + auto new_width = input->tensor_shape()[idx_width] * block_x; + auto new_height = input->tensor_shape()[idx_height] * block_y; + const auto width_crop = crop_info.left + crop_info.right; + const auto height_crop = crop_info.top + crop_info.bottom; + ARM_COMPUTE_ERROR_ON(new_width <= width_crop); + ARM_COMPUTE_ERROR_ON(new_height <= height_crop); + new_width -= width_crop; + new_height -= height_crop; + + output_shape.set(idx_width, new_width); + output_shape.set(idx_height, new_height); output_shape.set(idx_batch, input->tensor_shape()[idx_batch] / (block_x * block_y)); return output_shape; @@ -1537,14 +1548,14 @@ inline TensorShape compute_pool3d_shape(const TensorShape &src, Pooling3dLayerIn */ inline TensorShape compute_gather_shape(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis) { - const auto input_num_dims = input_shape.num_dimensions(); + const auto input_num_dims = input_shape.num_dimensions(); const auto indices_num_dims = indices_shape.num_dimensions(); ARM_COMPUTE_ERROR_ON(actual_axis >= input_num_dims); ARM_COMPUTE_ERROR_ON(input_num_dims + indices_num_dims - 1 > Coordinates::num_max_dimensions); TensorShape output_shape; - size_t dim_no = 0; + size_t dim_no = 0; for(; dim_no < actual_axis; ++dim_no) { diff --git a/arm_compute/runtime/CL/functions/CLBatchToSpaceLayer.h b/arm_compute/runtime/CL/functions/CLBatchToSpaceLayer.h index f6ba2b0b02..4b7cb60bc1 100644 --- a/arm_compute/runtime/CL/functions/CLBatchToSpaceLayer.h +++ b/arm_compute/runtime/CL/functions/CLBatchToSpaceLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -67,24 +67,27 @@ public: * @param[in] input Tensor input. Supported tensor rank: 4. Data types supported: All. * @param[in] block_shape 1-D tensor with shape [M]. Data types supported: S32 * @param[out] output Tensor output. Data types supported: same as @p input + * @param[in] crop_info Information about how the output shape is cropped after batch to space is performed */ - void configure(const ICLTensor *input, const ICLTensor *block_shape, ICLTensor *output); + void configure(const ICLTensor *input, const ICLTensor *block_shape, ICLTensor *output, const CropInfo &crop_info = CropInfo{}); /** Set the input and output tensors. * * @param[in] compile_context The compile context to be used. * @param[in] input Tensor input. Supported tensor rank: 4. Data types supported: All. * @param[in] block_shape 1-D tensor with shape [M]. Data types supported: S32 * @param[out] output Tensor output. Data types supported: same as @p input + * @param[in] crop_info Information about how the output shape is cropped after batch to space is performed */ - void configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *block_shape, ICLTensor *output); + void configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *block_shape, ICLTensor *output, const CropInfo &crop_info = CropInfo{}); /** Set the input and output tensors. (Static block shape). * * @param[in] input Tensor input. Supported tensor rank: 4. Data types supported: All. * @param[in] block_shape_x Block shape x value. * @param[in] block_shape_y Block shape y value. * @param[out] output Tensor output. Data types supported: same as @p input + * @param[in] crop_info Information about how the output shape is cropped after batch to space is performed */ - void configure(const ICLTensor *input, int32_t block_shape_x, int32_t block_shape_y, ICLTensor *output); + void configure(const ICLTensor *input, int32_t block_shape_x, int32_t block_shape_y, ICLTensor *output, const CropInfo &crop_info = CropInfo{}); /** Set the input and output tensors. (Static block shape). * * @param[in] compile_context The compile context to be used. @@ -92,27 +95,30 @@ public: * @param[in] block_shape_x Block shape x value. * @param[in] block_shape_y Block shape y value. * @param[out] output Tensor output. Data types supported: same as @p input + * @param[in] crop_info Information about how the output shape is cropped after batch to space is performed */ - void configure(const CLCompileContext &compile_context, const ICLTensor *input, int32_t block_shape_x, int32_t block_shape_y, ICLTensor *output); + void configure(const CLCompileContext &compile_context, const ICLTensor *input, int32_t block_shape_x, int32_t block_shape_y, ICLTensor *output, const CropInfo &crop_info = CropInfo{}); /** Static function to check if given info will lead to a valid configuration of @ref CLBatchToSpaceLayer * * @param[in] input Tensor input info. Supported tensor rank: 4. Data types supported: All. * @param[in] block_shape block shape tensor info with shape [M]. Data types supported: S32 * @param[out] output Tensor output info. Data types supported: same as @p input + * @param[in] crop_info Information about how the output shape is cropped after batch to space is performed * * @return a status */ - static Status validate(const ITensorInfo *input, const ITensorInfo *block_shape, const ITensorInfo *output); + static Status validate(const ITensorInfo *input, const ITensorInfo *block_shape, const ITensorInfo *output, const CropInfo &crop_info = CropInfo{}); /** Static function to check if given info will lead to a valid configuration of @ref CLBatchToSpaceLayer (Static block shape). * * @param[in] input Tensor input info. Supported tensor rank: 4. Data types supported: All. * @param[in] block_shape_x Block shape x value. * @param[in] block_shape_y Block shape y value. * @param[out] output Tensor output info. Data types supported: same as @p input + * @param[in] crop_info Information about how the output shape is cropped after batch to space is performed * * @return a status */ - static Status validate(const ITensorInfo *input, int32_t block_shape_x, int32_t block_shape_y, const ITensorInfo *output); + static Status validate(const ITensorInfo *input, int32_t block_shape_x, int32_t block_shape_y, const ITensorInfo *output, const CropInfo &crop_info = CropInfo{}); // Inherited methods overridden: void run() override; diff --git a/arm_compute/runtime/NEON/functions/NEBatchToSpaceLayer.h b/arm_compute/runtime/NEON/functions/NEBatchToSpaceLayer.h index 810bf81a22..92df8913ab 100644 --- a/arm_compute/runtime/NEON/functions/NEBatchToSpaceLayer.h +++ b/arm_compute/runtime/NEON/functions/NEBatchToSpaceLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -64,35 +64,39 @@ public: * @param[in] input Tensor input. Supported tensor rank: 4. Data types supported: All. * @param[in] block_shape 1-D tensor with shape [M]. Data types supported: S32 * @param[out] output Tensor output. Data types supported: same as @p input + * @param[in] crop_info Information about how the output shape is cropped after batch to space is performed */ - void configure(const ITensor *input, const ITensor *block_shape, ITensor *output); + void configure(const ITensor *input, const ITensor *block_shape, ITensor *output, const CropInfo &crop_info = CropInfo{}); /** Set the input and output tensors. (Static block shape). * * @param[in] input Tensor input. Supported tensor rank: 4. Data types supported: All. * @param[in] block_shape_x Block shape x value. * @param[in] block_shape_y Block shape y value. * @param[out] output Tensor output. Data types supported: same as @p input + * @param[in] crop_info Information about how the output shape is cropped after batch to space is performed */ - void configure(const ITensor *input, int32_t block_shape_x, int32_t block_shape_y, ITensor *output); + void configure(const ITensor *input, int32_t block_shape_x, int32_t block_shape_y, ITensor *output, const CropInfo &crop_info = CropInfo{}); /** Static function to check if given info will lead to a valid configuration of @ref CLBatchToSpaceLayer * * @param[in] input Tensor input info. Supported tensor rank: 4. Data types supported: All. * @param[in] block_shape block shape tensor info with shape [M]. Data types supported: S32 * @param[out] output Tensor output info. Data types supported: same as @p input + * @param[in] crop_info Information about how the output shape is cropped after batch to space is performed * * @return a status */ - static Status validate(const ITensorInfo *input, const ITensorInfo *block_shape, const ITensorInfo *output); + static Status validate(const ITensorInfo *input, const ITensorInfo *block_shape, const ITensorInfo *output, const CropInfo &crop_info = CropInfo{}); /** Static function to check if given info will lead to a valid configuration of @ref CLBatchToSpaceLayer (Static block shape). * * @param[in] input Tensor input info. Supported tensor rank: 4. Data types supported: All. * @param[in] block_shape_x Block shape x value. * @param[in] block_shape_y Block shape y value. * @param[out] output Tensor output info. Data types supported: same as @p input + * @param[in] crop_info Information about how the output shape is cropped after batch to space is performed * * @return a status */ - static Status validate(const ITensorInfo *input, int32_t block_shape_x, int32_t block_shape_y, const ITensorInfo *output); + static Status validate(const ITensorInfo *input, int32_t block_shape_x, int32_t block_shape_y, const ITensorInfo *output, const CropInfo &crop_info = CropInfo{}); }; } // namespace arm_compute #endif /* ARM_COMPUTE_NEBATCHTOSPACELAYER_H */ diff --git a/src/runtime/CL/functions/CLBatchToSpaceLayer.cpp b/src/runtime/CL/functions/CLBatchToSpaceLayer.cpp index a7691aa66b..d4342b456f 100644 --- a/src/runtime/CL/functions/CLBatchToSpaceLayer.cpp +++ b/src/runtime/CL/functions/CLBatchToSpaceLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -43,35 +43,39 @@ CLBatchToSpaceLayer::CLBatchToSpaceLayer() CLBatchToSpaceLayer::~CLBatchToSpaceLayer() = default; -void CLBatchToSpaceLayer::configure(const ICLTensor *input, const ICLTensor *block_shape, ICLTensor *output) +void CLBatchToSpaceLayer::configure(const ICLTensor *input, const ICLTensor *block_shape, ICLTensor *output, const CropInfo &crop_info) { - configure(CLKernelLibrary::get().get_compile_context(), input, block_shape, output); + configure(CLKernelLibrary::get().get_compile_context(), input, block_shape, output, crop_info); } -void CLBatchToSpaceLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *block_shape, ICLTensor *output) +void CLBatchToSpaceLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *block_shape, ICLTensor *output, const CropInfo &crop_info) { + ARM_COMPUTE_UNUSED(crop_info); ARM_COMPUTE_LOG_PARAMS(input, block_shape, output); _batch_to_space_kernel->configure(compile_context, input, block_shape, output); } -void CLBatchToSpaceLayer::configure(const ICLTensor *input, int32_t block_shape_x, int32_t block_shape_y, ICLTensor *output) +void CLBatchToSpaceLayer::configure(const ICLTensor *input, int32_t block_shape_x, int32_t block_shape_y, ICLTensor *output, const CropInfo &crop_info) { - configure(CLKernelLibrary::get().get_compile_context(), input, block_shape_x, block_shape_y, output); + configure(CLKernelLibrary::get().get_compile_context(), input, block_shape_x, block_shape_y, output, crop_info); } -void CLBatchToSpaceLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input, int32_t block_shape_x, int32_t block_shape_y, ICLTensor *output) +void CLBatchToSpaceLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input, int32_t block_shape_x, int32_t block_shape_y, ICLTensor *output, const CropInfo &crop_info) { + ARM_COMPUTE_UNUSED(crop_info); ARM_COMPUTE_LOG_PARAMS(input, block_shape_x, block_shape_y, output); _batch_to_space_kernel->configure(compile_context, input, block_shape_x, block_shape_y, output); } -Status CLBatchToSpaceLayer::validate(const ITensorInfo *input, const ITensorInfo *block_shape, const ITensorInfo *output) +Status CLBatchToSpaceLayer::validate(const ITensorInfo *input, const ITensorInfo *block_shape, const ITensorInfo *output, const CropInfo &crop_info) { + ARM_COMPUTE_UNUSED(crop_info); return CLBatchToSpaceLayerKernel::validate(input, block_shape, output); } -Status CLBatchToSpaceLayer::validate(const ITensorInfo *input, int32_t block_shape_x, int32_t block_shape_y, const ITensorInfo *output) +Status CLBatchToSpaceLayer::validate(const ITensorInfo *input, int32_t block_shape_x, int32_t block_shape_y, const ITensorInfo *output, const CropInfo &crop_info) { + ARM_COMPUTE_UNUSED(crop_info); return CLBatchToSpaceLayerKernel::validate(input, block_shape_x, block_shape_y, output); } diff --git a/src/runtime/NEON/functions/NEBatchToSpaceLayer.cpp b/src/runtime/NEON/functions/NEBatchToSpaceLayer.cpp index 5a2e37a517..b62fdad7a1 100644 --- a/src/runtime/NEON/functions/NEBatchToSpaceLayer.cpp +++ b/src/runtime/NEON/functions/NEBatchToSpaceLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -33,28 +33,32 @@ namespace arm_compute { -void NEBatchToSpaceLayer::configure(const ITensor *input, const ITensor *block_shape, ITensor *output) +void NEBatchToSpaceLayer::configure(const ITensor *input, const ITensor *block_shape, ITensor *output, const CropInfo &crop_info) { + ARM_COMPUTE_UNUSED(crop_info); ARM_COMPUTE_LOG_PARAMS(input, block_shape, output); auto k = std::make_unique(); k->configure(input, block_shape, output); _kernel = std::move(k); } -void NEBatchToSpaceLayer::configure(const ITensor *input, int32_t block_shape_x, int32_t block_shape_y, ITensor *output) +void NEBatchToSpaceLayer::configure(const ITensor *input, int32_t block_shape_x, int32_t block_shape_y, ITensor *output, const CropInfo &crop_info) { + ARM_COMPUTE_UNUSED(crop_info); auto k = std::make_unique(); k->configure(input, block_shape_x, block_shape_y, output); _kernel = std::move(k); } -Status NEBatchToSpaceLayer::validate(const ITensorInfo *input, const ITensorInfo *block_shape, const ITensorInfo *output) +Status NEBatchToSpaceLayer::validate(const ITensorInfo *input, const ITensorInfo *block_shape, const ITensorInfo *output, const CropInfo &crop_info) { + ARM_COMPUTE_UNUSED(crop_info); return NEBatchToSpaceLayerKernel::validate(input, block_shape, output); } -Status NEBatchToSpaceLayer::validate(const ITensorInfo *input, int32_t block_shape_x, int32_t block_shape_y, const ITensorInfo *output) +Status NEBatchToSpaceLayer::validate(const ITensorInfo *input, int32_t block_shape_x, int32_t block_shape_y, const ITensorInfo *output, const CropInfo &crop_info) { + ARM_COMPUTE_UNUSED(crop_info); return NEBatchToSpaceLayerKernel::validate(input, block_shape_x, block_shape_y, output); } } // namespace arm_compute diff --git a/tests/validation/fixtures/BatchToSpaceLayerFixture.h b/tests/validation/fixtures/BatchToSpaceLayerFixture.h index 6554c09de4..5a23261a6e 100644 --- a/tests/validation/fixtures/BatchToSpaceLayerFixture.h +++ b/tests/validation/fixtures/BatchToSpaceLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -36,14 +36,14 @@ namespace test namespace validation { template -class BatchToSpaceLayerValidationFixture : public framework::Fixture +class BatchToSpaceLayerValidationGenericFixture : public framework::Fixture { public: template - void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout) + void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout, const CropInfo &crop_info = CropInfo{}) { - _target = compute_target(input_shape, block_shape_shape, output_shape, data_type, data_layout); - _reference = compute_reference(input_shape, block_shape_shape, output_shape, data_type); + _target = compute_target(input_shape, block_shape_shape, output_shape, data_type, data_layout, crop_info); + _reference = compute_reference(input_shape, block_shape_shape, output_shape, data_type, crop_info); } protected: @@ -57,7 +57,7 @@ protected: library->fill(tensor, distribution, i); } TensorType compute_target(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, - DataType data_type, DataLayout data_layout) + DataType data_type, DataLayout data_layout, const CropInfo &crop_info) { if(data_layout == DataLayout::NHWC) { @@ -72,7 +72,7 @@ protected: // Create and configure function FunctionType batch_to_space; - batch_to_space.configure(&input, &block_shape, &output); + batch_to_space.configure(&input, &block_shape, &output, crop_info); ARM_COMPUTE_ASSERT(input.info()->is_resizable()); ARM_COMPUTE_ASSERT(block_shape.info()->is_resizable()); @@ -104,7 +104,7 @@ protected: } SimpleTensor compute_reference(const TensorShape &input_shape, const TensorShape &block_shape_shape, - const TensorShape &output_shape, DataType data_type) + const TensorShape &output_shape, DataType data_type, const CropInfo &crop_info) { // Create reference SimpleTensor input{ input_shape, data_type }; @@ -118,12 +118,23 @@ protected: } // Compute reference - return reference::batch_to_space(input, block_shape, output_shape); + return reference::batch_to_space(input, block_shape, output_shape, crop_info); } TensorType _target{}; SimpleTensor _reference{}; }; + +template +class BatchToSpaceLayerValidationFixture : public BatchToSpaceLayerValidationGenericFixture +{ +public: + template + void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout) + { + BatchToSpaceLayerValidationGenericFixture::setup(input_shape, block_shape_shape, output_shape, data_type, data_layout, CropInfo{}); + } +}; } // namespace validation } // namespace test } // namespace arm_compute diff --git a/tests/validation/reference/BatchToSpaceLayer.cpp b/tests/validation/reference/BatchToSpaceLayer.cpp index 404ee73cac..aeda733bb6 100644 --- a/tests/validation/reference/BatchToSpaceLayer.cpp +++ b/tests/validation/reference/BatchToSpaceLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 Arm Limited. + * Copyright (c) 2018, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -35,32 +35,35 @@ namespace reference { // Batch to Space template -SimpleTensor batch_to_space(const SimpleTensor &src, const SimpleTensor &block_shape, const TensorShape &dst_shape) +SimpleTensor batch_to_space(const SimpleTensor &src, const SimpleTensor &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info) { ARM_COMPUTE_ERROR_ON(block_shape[0] <= 0); ARM_COMPUTE_ERROR_ON(block_shape[1] <= 0); SimpleTensor result(dst_shape, src.data_type()); + int out_pos = 0; + const auto width_out = static_cast(dst_shape[0]); + const auto height_out = static_cast(dst_shape[1]); + const auto z_out = static_cast(dst_shape[2]); + const auto batch_out = static_cast(dst_shape[3]); + ARM_COMPUTE_ERROR_ON(width_out <= static_cast(crop_info.left + crop_info.right)); + ARM_COMPUTE_ERROR_ON(height_out <= static_cast(crop_info.top + crop_info.bottom)); - int in_pos = 0; - const auto width_in = static_cast(src.shape()[0]); - const auto height_in = static_cast(src.shape()[1]); - const auto z_in = static_cast(src.shape()[2]); - const auto batch_in = static_cast(src.shape()[3]); - - for(int batch = 0; batch < batch_in; ++batch) + for(int batch = 0; batch < batch_out; ++batch) { - for(int z = 0; z < z_in; ++z) + for(int z = 0; z < z_out; ++z) { - for(int y = 0; y < height_in; ++y) + for(int y = 0; y < height_out; ++y) { - for(int x = 0; x < width_in; ++x) + for(int x = 0; x < width_out; ++x) { - const int r = src.shape()[3] / (block_shape[0] * block_shape[1]); - const int out_x = (block_shape[0] * x + (batch / r) % block_shape[0]); - const int out_y = (block_shape[1] * y + (batch / r) / block_shape[0]); - const int out_pos = out_x + dst_shape[0] * out_y + z * dst_shape[0] * dst_shape[1] + (batch % r) * dst_shape[0] * dst_shape[1] * dst_shape[2]; - result[out_pos] = src[in_pos]; - ++in_pos; + const int x_c = x + crop_info.left; + const int y_c = y + crop_info.top; + const int in_batch = batch + ((x_c % block_shape[0]) + (y_c % block_shape[1]) * (block_shape[0])) * dst_shape[3]; + const int in_x = x_c / block_shape[0]; + const int in_y = y_c / block_shape[1]; + const int in_pos = in_x + src.shape()[0] * in_y + z * src.shape()[0] * src.shape()[1] + in_batch * src.shape()[0] * src.shape()[1] * src.shape()[2]; + result[out_pos] = src[in_pos]; + ++out_pos; } } } @@ -68,8 +71,8 @@ SimpleTensor batch_to_space(const SimpleTensor &src, const SimpleTensor batch_to_space(const SimpleTensor &src, const SimpleTensor &block_shape, const TensorShape &dst_shape); -template SimpleTensor batch_to_space(const SimpleTensor &src, const SimpleTensor &block_shape, const TensorShape &dst_shape); +template SimpleTensor batch_to_space(const SimpleTensor &src, const SimpleTensor &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info = CropInfo{}); +template SimpleTensor batch_to_space(const SimpleTensor &src, const SimpleTensor &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info = CropInfo{}); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/BatchToSpaceLayer.h b/tests/validation/reference/BatchToSpaceLayer.h index 52556cb53f..18010f1885 100644 --- a/tests/validation/reference/BatchToSpaceLayer.h +++ b/tests/validation/reference/BatchToSpaceLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 Arm Limited. + * Copyright (c) 2018-2019, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -24,6 +24,7 @@ #ifndef ARM_COMPUTE_TEST_BATCH_TO_SPACE_LAYER_H #define ARM_COMPUTE_TEST_BATCH_TO_SPACE_LAYER_H +#include "arm_compute/core/Types.h" #include "tests/SimpleTensor.h" #include "tests/validation/Helpers.h" @@ -36,7 +37,7 @@ namespace validation namespace reference { template -SimpleTensor batch_to_space(const SimpleTensor &src, const SimpleTensor &block_shape, const TensorShape &dst_shape); +SimpleTensor batch_to_space(const SimpleTensor &src, const SimpleTensor &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info = CropInfo{}); } // namespace reference } // namespace validation } // namespace test -- cgit v1.2.1