aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/ReorgLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/ReorgLayer.cpp')
-rw-r--r--tests/validation/reference/ReorgLayer.cpp31
1 files changed, 10 insertions, 21 deletions
diff --git a/tests/validation/reference/ReorgLayer.cpp b/tests/validation/reference/ReorgLayer.cpp
index cb13a737e0..2eb5d01926 100644
--- a/tests/validation/reference/ReorgLayer.cpp
+++ b/tests/validation/reference/ReorgLayer.cpp
@@ -24,6 +24,7 @@
#include "ReorgLayer.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
namespace arm_compute
{
@@ -33,29 +34,17 @@ namespace validation
{
namespace reference
{
-namespace
-{
-TensorShape compute_reorg_shape(const TensorShape &src_shape, int32_t stride)
-{
- ARM_COMPUTE_ERROR_ON(stride <= 0);
-
- TensorShape dst_shape = src_shape;
- dst_shape.set(0, src_shape.x() / stride);
- dst_shape.set(1, src_shape.y() / stride);
- dst_shape.set(2, src_shape.z() * stride * stride);
-
- return dst_shape;
-}
-} // namespace
-
template <typename T>
SimpleTensor<T> reorg_layer(const SimpleTensor<T> &src, int32_t stride)
{
- // Calculate output shape
- const TensorShape dst_shape = compute_reorg_shape(src.shape(), stride);
+ ARM_COMPUTE_ERROR_ON(src.shape().num_dimensions() > 4);
+ ARM_COMPUTE_ERROR_ON(src.data_layout() != DataLayout::NCHW);
+
+ TensorInfo input_info(src.shape(), 1, src.data_type());
+ const TensorShape output_shape = misc::shape_calculator::compute_reorg_output_shape(input_info, stride);
// Create destination tensor
- SimpleTensor<T> dst{ dst_shape, src.data_type() };
+ SimpleTensor<T> dst{ output_shape, src.data_type() };
const unsigned int W = dst.shape().x();
const unsigned int H = dst.shape().y();
@@ -88,9 +77,9 @@ SimpleTensor<T> reorg_layer(const SimpleTensor<T> &src, int32_t stride)
return dst;
}
-template SimpleTensor<uint8_t> reorg_layer(const SimpleTensor<uint8_t> &src, int32_t stride);
-template SimpleTensor<uint16_t> reorg_layer(const SimpleTensor<uint16_t> &src, int32_t stride);
-template SimpleTensor<uint32_t> reorg_layer(const SimpleTensor<uint32_t> &src, int32_t stride);
+template SimpleTensor<int32_t> reorg_layer(const SimpleTensor<int32_t> &src, int32_t stride);
+template SimpleTensor<int16_t> reorg_layer(const SimpleTensor<int16_t> &src, int32_t stride);
+template SimpleTensor<int8_t> reorg_layer(const SimpleTensor<int8_t> &src, int32_t stride);
} // namespace reference
} // namespace validation
} // namespace test