diff options
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/Col2ImFixture.h | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/tests/validation/fixtures/Col2ImFixture.h b/tests/validation/fixtures/Col2ImFixture.h index ddc78a5032..5488f8a3ea 100644 --- a/tests/validation/fixtures/Col2ImFixture.h +++ b/tests/validation/fixtures/Col2ImFixture.h @@ -44,16 +44,16 @@ namespace validation { using namespace arm_compute::misc::shape_calculator; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool batch_size_on_z> class Col2ImValidationFixture : public framework::Fixture { public: template <typename...> void setup(TensorShape input_shape, const unsigned int convolved_width, unsigned int convolved_height, unsigned int num_groups, DataType data_type) { - const std::pair<unsigned int, unsigned int> convolved_dims(convolved_width, convolved_height); + const Size2D convolved_dims(convolved_width, convolved_height); - const TensorShape output_shape = compute_col2im_shape(TensorInfo(input_shape, 1, data_type), convolved_dims, num_groups); + const TensorShape output_shape = compute_col2im_shape(TensorInfo(input_shape, 1, data_type), convolved_dims, batch_size_on_z, num_groups); _target = compute_target(input_shape, output_shape, convolved_dims, num_groups, data_type); _reference = compute_reference(input_shape, output_shape, num_groups, data_type); @@ -66,7 +66,7 @@ protected: library->fill_tensor_uniform(tensor, seed); } - TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, std::pair<unsigned int, unsigned int> convolved_dims, unsigned int num_groups, DataType data_type) + TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, const Size2D &convolved_dims, unsigned int num_groups, DataType data_type) { // Create tensors TensorType src = create_tensor<TensorType>(input_shape, data_type); |