From 5bdde8509542e942e908e9d508dd39c73194abfb Mon Sep 17 00:00:00 2001 From: SiCong Li Date: Wed, 26 Aug 2020 13:55:15 +0100 Subject: COMPMID-3670 Extend cl image support to f16 in CLGEMMReshapeRHSMatrixKernel Change-Id: Ic0569fe9ed99e61084b601ce84ddc7ef288d1899 Signed-off-by: SiCong Li Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3852 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h | 6 +++--- src/core/CL/gemm/CLGEMMHelpers.cpp | 2 +- src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.cpp | 8 ++++---- tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp | 4 ++-- tests/validation/CL/GEMMReshapeRHSMatrix.cpp | 10 ++++------ 5 files changed, 14 insertions(+), 16 deletions(-) diff --git a/arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h b/arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h index 5f953ddf8d..557f71b07d 100644 --- a/arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h +++ b/arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h @@ -53,7 +53,7 @@ public: * Since the OpenCL image object is created importing the OpenCL buffer, the following conditions are required: * -# rhs_info.n0 can only be 4, 8 and 16 * -# rhs_info.k0 can only be 4, 8 and 16 - * -# Data type can only be F32 + * -# Data type can only be F32, F16 * -# The platform should support the OpenCL cl_khr_image2d_from_buffer extension * -# output width should be less or equal to (CL_DEVICE_IMAGE2D_MAX_WIDTH * 4) * -# output (height * depth) should be less or equal to CL_DEVICE_IMAGE2D_MAX_HEIGHT @@ -77,7 +77,7 @@ public: * Since the OpenCL image object is created importing the OpenCL buffer, the following conditions are required: * -# rhs_info.n0 can only be 4, 8 and 16 * -# rhs_info.k0 can only be 4, 8 and 16 - * -# Data type can only be F32 + * -# Data type can only be F32, F16 * -# The platform should support the OpenCL cl_khr_image2d_from_buffer extension * -# output width should be less or equal to (CL_DEVICE_IMAGE2D_MAX_WIDTH * 4) * -# output (height * depth) should be less or equal to CL_DEVICE_IMAGE2D_MAX_HEIGHT @@ -102,7 +102,7 @@ public: * Since the OpenCL image object is created importing the OpenCL buffer, the following conditions are required: * -# rhs_info.n0 can only be 4, 8 and 16 * -# rhs_info.k0 can only be 4, 8 and 16 - * -# Data type can only be F32 + * -# Data type can only be F32, F16 * -# The platform should support the OpenCL cl_khr_image2d_from_buffer extension * -# output width should be less or equal to (CL_DEVICE_IMAGE2D_MAX_WIDTH * 4) * -# output (height * depth) should be less or equal to CL_DEVICE_IMAGE2D_MAX_HEIGHT diff --git a/src/core/CL/gemm/CLGEMMHelpers.cpp b/src/core/CL/gemm/CLGEMMHelpers.cpp index 5734c93021..0a4a4adc31 100644 --- a/src/core/CL/gemm/CLGEMMHelpers.cpp +++ b/src/core/CL/gemm/CLGEMMHelpers.cpp @@ -65,7 +65,7 @@ Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, { ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.n0 == 2) || (rhs_info.n0 == 3), "Export to cl_image only supported with n0 = 4, 8 or 16"); ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.k0 == 2) || (rhs_info.k0 == 3), "Export to cl_image only supported with k0 = 4, 8 or 16"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(tensor_reshaped_info.data_type() != DataType::F32, "Export to cl_image only supported with F32 data type"); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(&tensor_reshaped_info, DataType::F32, DataType::F16); ARM_COMPUTE_RETURN_ERROR_ON_MSG(!image2d_from_buffer_supported(CLKernelLibrary::get().get_device()), "The extension cl_khr_image2d_from_buffer is not supported on the target platform"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(get_cl_image_pitch_alignment(CLKernelLibrary::get().get_device()) == 0, "Impossible to retrieve the cl_image pitch alignment"); diff --git a/src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.cpp b/src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.cpp index c1993b72b9..ce294646a0 100644 --- a/src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.cpp +++ b/src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.cpp @@ -56,15 +56,15 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c ARM_COMPUTE_RETURN_ERROR_ON(rhs_info.k0 > 16); ARM_COMPUTE_RETURN_ERROR_ON((rhs_info.k0 == 1) && (rhs_info.transpose)); + ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input); + ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN); + if(rhs_info.export_to_cl_image) { - const TensorInfo tensor_reshaped_info(compute_rhs_reshaped_shape(*input, rhs_info), 1, DataType::F32); + const TensorInfo tensor_reshaped_info(compute_rhs_reshaped_shape(*input, rhs_info), 1, input->data_type()); ARM_COMPUTE_RETURN_ON_ERROR(cl_gemm::validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info)); } - ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input); - ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN); - if(output->total_size() != 0) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_rhs_reshaped_shape(*input, rhs_info)); diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp index bd0cd03ca7..afb2807d01 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp @@ -295,7 +295,7 @@ TEST_SUITE(GEMMMatrixMultiplyReshapedOnlyRHS) * - Incorrect input0 dimension when input is reinterpreted as 3D: input0->dimension(1) * input0->dimension(2) != m * - Correct support for creating an OpenCL image object from buffer * - Incorrect support for creating an OpenCL image object from buffer. N0 is 2 but it can only be 4,8 and 16 - * - Incorrect support for creating an OpenCL image object from buffer. Data type is F16 but it can only be F32 + * - Correct F16 support for creating an OpenCL image object from buffer. */ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(zip(zip(zip(zip(zip(zip( framework::dataset::make("batch_size", { 1, 1, 1, 1, 1, 1, 2, 1, 1, 1 }), @@ -311,7 +311,7 @@ framework::dataset::make("data_type_input1", { DataType::F32, DataType::F32, framework::dataset::make("data_type_input2", { DataType::F32, DataType::F32, DataType::F32, DataType::F32, DataType::F32, DataType::F32, DataType::F32, DataType::F32, DataType::F32, DataType::F16})), framework::dataset::make("data_type_output", { DataType::F16, DataType::F32, DataType::F32, DataType::F32, DataType::F32, DataType::F32, DataType::F32, DataType::F32, DataType::F32, DataType::F16})), framework::dataset::make("Beta", { 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f , 1.0f})), -framework::dataset::make("Expected", { false, false, false, false, false, false, false, true, false, false })), +framework::dataset::make("Expected", { false, false, false, false, false, false, false, true, false, true })), b_value, m0_value, n0_value, k0_value, broadcast_bias, input_as_3d, depth_output_gemm3d, export_to_cl_image, dt_input0, dt_intpu1, dt_input2, dt_output, beta, expected) { bool expected_value = expected; diff --git a/tests/validation/CL/GEMMReshapeRHSMatrix.cpp b/tests/validation/CL/GEMMReshapeRHSMatrix.cpp index c7b0752cc8..579ed32afe 100644 --- a/tests/validation/CL/GEMMReshapeRHSMatrix.cpp +++ b/tests/validation/CL/GEMMReshapeRHSMatrix.cpp @@ -46,9 +46,6 @@ namespace { // *INDENT-OFF* // clang-format off -/** Data types */ -const auto data_types = framework::dataset::make("DataType", { DataType::QASYMM8, DataType::F16, DataType::F32 }); - /** Batch size values to test */ const auto b_values = framework::dataset::make("batchsize", 1, 3); @@ -124,19 +121,20 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip( ARM_COMPUTE_EXPECT(has_error == expected, framework::LogLevel::ERRORS); } -DATA_TEST_CASE(ValidatePadding, framework::DatasetMode::ALL, combine(combine(combine( +DATA_TEST_CASE(ValidatePadding, framework::DatasetMode::ALL, combine(combine(combine(combine( framework::dataset::make("InputShape", { TensorShape(32U, 16U, 1U), TensorShape(32U, 16U, 2U) }), framework::dataset::make("N0",{ 4 })), framework::dataset::make("K0",{ 4, 8, 16 })), framework::dataset::make("H0",{ 1, 2, 4 })), - input_shape, n0, k0, h0) + framework::dataset::make("DataType",{ DataType::F32, DataType::F16 })), + input_shape, n0, k0, h0, data_type) { CLTensor input; CLTensor output; - input.info()->init(input_shape, 1, DataType::F32); + input.info()->init(input_shape, 1, data_type); unsigned int padding = 0; -- cgit v1.2.1