diff options
Diffstat (limited to 'tests/validation/fixtures/PoolingLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/PoolingLayerFixture.h | 12 |
1 files changed, 5 insertions, 7 deletions
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h index 7f2d7ac225..eb40cea0c2 100644 --- a/tests/validation/fixtures/PoolingLayerFixture.h +++ b/tests/validation/fixtures/PoolingLayerFixture.h @@ -35,7 +35,6 @@ #include "tests/framework/Fixture.h" #include "tests/validation/reference/PoolingLayer.h" #include <random> - namespace arm_compute { namespace test @@ -59,7 +58,7 @@ public: _pool_info = pool_info; _target = compute_target(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices); - _reference = compute_reference(shape, pool_info, data_type, input_qinfo, output_qinfo, indices); + _reference = compute_reference(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices); } protected: @@ -92,7 +91,7 @@ protected: TensorType src = create_tensor<TensorType>(shape, data_type, 1, input_qinfo, data_layout); const TensorShape dst_shape = misc::shape_calculator::compute_pool_shape(*(src.info()), info); TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, output_qinfo, data_layout); - _target_indices = create_tensor<TensorType>(dst_shape, DataType::U32, 1); + _target_indices = create_tensor<TensorType>(dst_shape, DataType::U32, 1, output_qinfo, data_layout); // Create and configure function FunctionType pool_layer; @@ -120,15 +119,14 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &shape, PoolingLayerInfo info, DataType data_type, + SimpleTensor<T> compute_reference(TensorShape shape, PoolingLayerInfo info, DataType data_type, DataLayout data_layout, QuantizationInfo input_qinfo, QuantizationInfo output_qinfo, bool indices) { // Create reference - SimpleTensor<T> src{ shape, data_type, 1, input_qinfo }; + SimpleTensor<T> src(shape, data_type, 1, input_qinfo); // Fill reference fill(src); - - return reference::pooling_layer<T>(src, info, output_qinfo, indices ? &_ref_indices : nullptr); + return reference::pooling_layer<T>(src, info, output_qinfo, indices ? &_ref_indices : nullptr, data_layout); } TensorType _target{}; |