diff options
Diffstat (limited to 'tests/validation/fixtures/SpaceToDepthFixture.h')
-rw-r--r-- | tests/validation/fixtures/SpaceToDepthFixture.h | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/tests/validation/fixtures/SpaceToDepthFixture.h b/tests/validation/fixtures/SpaceToDepthFixture.h index 7448ec546b..0776e495eb 100644 --- a/tests/validation/fixtures/SpaceToDepthFixture.h +++ b/tests/validation/fixtures/SpaceToDepthFixture.h @@ -24,6 +24,7 @@ #ifndef ARM_COMPUTE_TEST_SPACE_TO_DEPTH_LAYER_FIXTURE #define ARM_COMPUTE_TEST_SPACE_TO_DEPTH_LAYER_FIXTURE +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "tests/Globals.h" #include "tests/framework/Asserts.h" #include "tests/framework/Fixture.h" @@ -69,6 +70,12 @@ protected: TensorType input = create_tensor<TensorType>(input_shape, data_type, 1, QuantizationInfo(), data_layout); TensorType output = create_tensor<TensorType>(output_shape, data_type, 1, QuantizationInfo(), data_layout); + auto calc_out_shape = misc::shape_calculator::compute_space_to_depth_shape(input.info(), block_shape); + ARM_COMPUTE_ASSERT(output_shape[0] == calc_out_shape[0]); + ARM_COMPUTE_ASSERT(output_shape[1] == calc_out_shape[1]); + ARM_COMPUTE_ASSERT(output_shape[2] == calc_out_shape[2]); + ARM_COMPUTE_ASSERT(output_shape[3] == calc_out_shape[3]); + // Create and configure function FunctionType space_to_depth; space_to_depth.configure(&input, &output, block_shape); |