aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/BatchToSpaceLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/BatchToSpaceLayer.cpp')
-rw-r--r--tests/validation/reference/BatchToSpaceLayer.cpp18
1 files changed, 11 insertions, 7 deletions
diff --git a/tests/validation/reference/BatchToSpaceLayer.cpp b/tests/validation/reference/BatchToSpaceLayer.cpp
index aeda733bb6..63d121f59b 100644
--- a/tests/validation/reference/BatchToSpaceLayer.cpp
+++ b/tests/validation/reference/BatchToSpaceLayer.cpp
@@ -23,8 +23,10 @@
*/
#include "BatchToSpaceLayer.h"
+#include "arm_compute/core/Validate.h"
#include "tests/validation/Helpers.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
namespace arm_compute
{
namespace test
@@ -35,18 +37,20 @@ namespace reference
{
// Batch to Space
template <typename T>
-SimpleTensor<T> batch_to_space(const SimpleTensor<T> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info)
+SimpleTensor<T> batch_to_space(const SimpleTensor<T> &src, const std::vector<int32_t> &block_shape, const CropInfo &crop_info, const TensorShape &dst_shape)
{
- ARM_COMPUTE_ERROR_ON(block_shape[0] <= 0);
- ARM_COMPUTE_ERROR_ON(block_shape[1] <= 0);
+ ARM_COMPUTE_ERROR_ON(block_shape[0] < 1);
+ ARM_COMPUTE_ERROR_ON(block_shape[1] < 1);
+ const auto expected_dst_shape = misc::shape_calculator::compute_batch_to_space_shape(DataLayout::NCHW, src.shape(), block_shape[0], block_shape[1], crop_info);
+ ARM_COMPUTE_ERROR_ON(arm_compute::detail::have_different_dimensions(expected_dst_shape, dst_shape, 0));
+ ARM_COMPUTE_UNUSED(expected_dst_shape);
+
SimpleTensor<T> result(dst_shape, src.data_type());
int out_pos = 0;
const auto width_out = static_cast<int>(dst_shape[0]);
const auto height_out = static_cast<int>(dst_shape[1]);
const auto z_out = static_cast<int>(dst_shape[2]);
const auto batch_out = static_cast<int>(dst_shape[3]);
- ARM_COMPUTE_ERROR_ON(width_out <= static_cast<int>(crop_info.left + crop_info.right));
- ARM_COMPUTE_ERROR_ON(height_out <= static_cast<int>(crop_info.top + crop_info.bottom));
for(int batch = 0; batch < batch_out; ++batch)
{
@@ -71,8 +75,8 @@ SimpleTensor<T> batch_to_space(const SimpleTensor<T> &src, const SimpleTensor<in
return result;
}
-template SimpleTensor<float> batch_to_space(const SimpleTensor<float> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info = CropInfo{});
-template SimpleTensor<half> batch_to_space(const SimpleTensor<half> &src, const SimpleTensor<int32_t> &block_shape, const TensorShape &dst_shape, const CropInfo &crop_info = CropInfo{});
+template SimpleTensor<float> batch_to_space(const SimpleTensor<float> &src, const std::vector<int32_t> &block_shape, const CropInfo &crop_info, const TensorShape &dst_shape);
+template SimpleTensor<half> batch_to_space(const SimpleTensor<half> &src, const std::vector<int32_t> &block_shape, const CropInfo &crop_info, const TensorShape &dst_shape);
} // namespace reference
} // namespace validation
} // namespace test