aboutsummaryrefslogtreecommitdiff
path: root/tests/validation
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation')
-rw-r--r--tests/validation/fixtures/PoolingLayerFixture.h3
-rw-r--r--tests/validation/reference/PoolingLayer.cpp37
2 files changed, 15 insertions, 25 deletions
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h
index 3c28b3b64d..a40baf415a 100644
--- a/tests/validation/fixtures/PoolingLayerFixture.h
+++ b/tests/validation/fixtures/PoolingLayerFixture.h
@@ -175,9 +175,8 @@ class SpecialPoolingLayerValidationFixture : public PoolingLayerValidationGeneri
{
public:
template <typename...>
- void setup(TensorShape src_shape, TensorShape dst_shape, PoolingLayerInfo pool_info, DataType data_type)
+ void setup(TensorShape src_shape, PoolingLayerInfo pool_info, DataType data_type)
{
- ARM_COMPUTE_UNUSED(dst_shape);
PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, DataLayout::NCHW, 0, QuantizationInfo());
}
};
diff --git a/tests/validation/reference/PoolingLayer.cpp b/tests/validation/reference/PoolingLayer.cpp
index 071c20ed56..f3f456b26e 100644
--- a/tests/validation/reference/PoolingLayer.cpp
+++ b/tests/validation/reference/PoolingLayer.cpp
@@ -25,6 +25,7 @@
#include "Permute.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "tests/validation/FixedPoint.h"
#include "tests/validation/Helpers.h"
@@ -36,24 +37,7 @@ namespace validation
{
namespace reference
{
-namespace
-{
-TensorShape calculate_output_shape(TensorShape shape, const PoolingLayerInfo &info)
-{
- TensorShape dst_shape = shape;
- const int pool_size_x = info.is_global_pooling() ? shape.x() : info.pool_size().width;
- const int pool_size_y = info.is_global_pooling() ? shape.y() : info.pool_size().height;
- const std::pair<unsigned int, unsigned int> scaled_dims = arm_compute::scaled_dimensions(shape.x(),
- shape.y(),
- pool_size_x,
- pool_size_y,
- info.pad_stride_info());
- dst_shape.set(0, scaled_dims.first);
- dst_shape.set(1, scaled_dims.second);
-
- return dst_shape;
-}
-} // namespace
+using namespace arm_compute::misc::shape_calculator;
template <typename T>
SimpleTensor<T> pooling_layer_nchw(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const PoolingLayerInfo &info)
@@ -187,7 +171,10 @@ SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo
bool exclude_padding = info.exclude_padding();
// Create reference
- SimpleTensor<T> dst{ calculate_output_shape(src.shape(), info), src.data_type(), 1, src.fixed_point_position() };
+ TensorInfo src_info(src.shape(), 1, src.data_type(), src.fixed_point_position());
+ src_info.set_data_layout(src.data_layout());
+
+ SimpleTensor<T> dst{ compute_pool_shape(src_info, info), src.data_type(), 1, src.fixed_point_position() };
const auto w_dst = static_cast<int>(dst.shape()[0]);
const auto h_dst = static_cast<int>(dst.shape()[1]);
@@ -300,18 +287,22 @@ SimpleTensor<uint8_t> pooling_layer<uint8_t>(const SimpleTensor<uint8_t> &src, c
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info)
{
+ TensorInfo src_info(src.shape(), 1, src.data_type(), src.fixed_point_position());
+ src_info.set_data_layout(src.data_layout());
+
+ SimpleTensor<T> dst{ compute_pool_shape(src_info, info), src.data_type(), 1, src.fixed_point_position() };
+
if(src.data_layout() == DataLayout::NHWC)
{
SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U));
- SimpleTensor<T> dst{ calculate_output_shape(src_nchw.shape(), info), src_nchw.data_type(), 1, src_nchw.fixed_point_position() };
+ SimpleTensor<T> dst_nchw = reference::permute<T>(dst, PermutationVector(1U, 2U, 0U));
- pooling_layer_nchw<T>(src_nchw, dst, info);
+ pooling_layer_nchw<T>(src_nchw, dst_nchw, info);
- return reference::permute<T>(dst, PermutationVector(2U, 0U, 1U));
+ return reference::permute<T>(dst_nchw, PermutationVector(2U, 0U, 1U));
}
else
{
- SimpleTensor<T> dst{ calculate_output_shape(src.shape(), info), src.data_type(), 1, src.fixed_point_position() };
return pooling_layer_nchw<T>(src, dst, info);
}
}