diff options
Diffstat (limited to 'tests/validation/CL')
-rw-r--r-- | tests/validation/CL/GEMMReshapeRHSMatrix.cpp | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/tests/validation/CL/GEMMReshapeRHSMatrix.cpp b/tests/validation/CL/GEMMReshapeRHSMatrix.cpp index 55688cf160..aa6667666c 100644 --- a/tests/validation/CL/GEMMReshapeRHSMatrix.cpp +++ b/tests/validation/CL/GEMMReshapeRHSMatrix.cpp @@ -123,6 +123,49 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip( bool has_error = bool(CLGEMMReshapeRHSMatrixKernel::validate(&input_info.clone()->set_is_resizable(false), (output_info.total_size() == 0) ? nullptr : &output_info.clone()->set_is_resizable(false), rhs_info)); ARM_COMPUTE_EXPECT(has_error == expected, framework::LogLevel::ERRORS); } + +DATA_TEST_CASE(ValidatePadding, framework::DatasetMode::ALL, combine(combine(combine( + framework::dataset::make("InputShape", { TensorShape(32U, 16U, 1U), + TensorShape(32U, 16U, 2U) + }), + framework::dataset::make("N0",{ 4 })), + framework::dataset::make("K0",{ 2, 4, 8 })), + framework::dataset::make("H0",{ 1, 2, 4 })), + input_shape, n0, k0, h0) +{ + CLTensor input; + CLTensor output; + + input.info()->init(input_shape, 1, DataType::F32); + + unsigned int padding = 0; + + GEMMRHSMatrixInfo rhs_info; + rhs_info.n0 = n0; + rhs_info.k0 = k0; + rhs_info.h0 = h0; + rhs_info.transpose = true; + rhs_info.interleave = true; + rhs_info.export_to_cl_image = image2d_from_buffer_supported(CLKernelLibrary::get().get_device()) && (get_cl_image_pitch_alignment(CLKernelLibrary::get().get_device()) != 0); + + if(rhs_info.export_to_cl_image) + { + TensorShape output_shape = compute_rhs_reshaped_shape(*input.info(), rhs_info); + constexpr unsigned int num_floats_per_pixel = 4; + + const unsigned int pixel_aligment = get_cl_image_pitch_alignment(CLKernelLibrary::get().get_device()); + const unsigned int row_pitch_alignment = pixel_aligment * num_floats_per_pixel; + const unsigned int round_up_width = ((output_shape[0] + row_pitch_alignment - 1) / row_pitch_alignment) * row_pitch_alignment; + + padding = round_up_width - output_shape[0]; + } + + CLGEMMReshapeRHSMatrixKernel kernel; + + kernel.configure(&input, &output, rhs_info); + + ARM_COMPUTE_EXPECT((output.info()->padding().right == padding), framework::LogLevel::ERRORS); +} // clang-format on // *INDENT-ON* |