diff options
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/ReorgLayer.cpp | 31 |
1 files changed, 10 insertions, 21 deletions
diff --git a/tests/validation/reference/ReorgLayer.cpp b/tests/validation/reference/ReorgLayer.cpp index cb13a737e0..2eb5d01926 100644 --- a/tests/validation/reference/ReorgLayer.cpp +++ b/tests/validation/reference/ReorgLayer.cpp @@ -24,6 +24,7 @@ #include "ReorgLayer.h" #include "arm_compute/core/Types.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" namespace arm_compute { @@ -33,29 +34,17 @@ namespace validation { namespace reference { -namespace -{ -TensorShape compute_reorg_shape(const TensorShape &src_shape, int32_t stride) -{ - ARM_COMPUTE_ERROR_ON(stride <= 0); - - TensorShape dst_shape = src_shape; - dst_shape.set(0, src_shape.x() / stride); - dst_shape.set(1, src_shape.y() / stride); - dst_shape.set(2, src_shape.z() * stride * stride); - - return dst_shape; -} -} // namespace - template <typename T> SimpleTensor<T> reorg_layer(const SimpleTensor<T> &src, int32_t stride) { - // Calculate output shape - const TensorShape dst_shape = compute_reorg_shape(src.shape(), stride); + ARM_COMPUTE_ERROR_ON(src.shape().num_dimensions() > 4); + ARM_COMPUTE_ERROR_ON(src.data_layout() != DataLayout::NCHW); + + TensorInfo input_info(src.shape(), 1, src.data_type()); + const TensorShape output_shape = misc::shape_calculator::compute_reorg_output_shape(input_info, stride); // Create destination tensor - SimpleTensor<T> dst{ dst_shape, src.data_type() }; + SimpleTensor<T> dst{ output_shape, src.data_type() }; const unsigned int W = dst.shape().x(); const unsigned int H = dst.shape().y(); @@ -88,9 +77,9 @@ SimpleTensor<T> reorg_layer(const SimpleTensor<T> &src, int32_t stride) return dst; } -template SimpleTensor<uint8_t> reorg_layer(const SimpleTensor<uint8_t> &src, int32_t stride); -template SimpleTensor<uint16_t> reorg_layer(const SimpleTensor<uint16_t> &src, int32_t stride); -template SimpleTensor<uint32_t> reorg_layer(const SimpleTensor<uint32_t> &src, int32_t stride); +template SimpleTensor<int32_t> reorg_layer(const SimpleTensor<int32_t> &src, int32_t stride); +template SimpleTensor<int16_t> reorg_layer(const SimpleTensor<int16_t> &src, int32_t stride); +template SimpleTensor<int8_t> reorg_layer(const SimpleTensor<int8_t> &src, int32_t stride); } // namespace reference } // namespace validation } // namespace test |