author     Gian Marco Iodice <gianmarco.iodice@arm.com>    2020-07-09 08:41:10 +0100
committer  Gian Marco Iodice <gianmarco.iodice@arm.com>    2020-07-15 08:37:31 +0000
commit     ed5fe69b6612a5cf0dd52340f6781885d77afbc9 (patch)
tree       fc0ca94be92fe7c39e4c2047379b3bd301e9d67e
parent     4667dddc0ed403c636348294cd7f70261e5540cf (diff)
download   ComputeLibrary-ed5fe69b6612a5cf0dd52340f6781885d77afbc9.tar.gz
COMPMID-3326: Update heuristic for GEMMReshaped and GEMMReshapedOnlyRHS

- Update the heuristic for Arm Mali-G76 (F32) in order to use the OpenCL image2d object on GEMM
- Create utility function to validate the support for image2d

Change-Id: I0913ac5f27fd07992b0ac188af753a2abeb034ca
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3559
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  arm_compute/core/CL/gemm/CLGEMMHelpers.h                                                 | 37
-rw-r--r--  src/core/CL/gemm/CLGEMMHelpers.cpp                                                       | 48
-rw-r--r--  src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp                   | 44
-rw-r--r--  src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp   | 42
-rw-r--r--  src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp                               | 19
-rw-r--r--  src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp                        | 19
-rw-r--r--  src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.cpp                                     | 17

7 files changed, 144 insertions, 82 deletions
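
The sketch below (not part of the patch) shows how the two helpers touched by this change can be combined: configure_lhs_rhs_info() now accepts an optional export_to_cl_image flag, and the new validate_image2d_support_on_rhs() reports whether the reshaped RHS matrix can actually be read through an OpenCL image2d on the target device. The function name pick_gemm_config and the chosen block sizes (m0, n0, k0, v0, h0) are illustrative assumptions; the API calls mirror the pattern used by the Bifrost configuration files in the hunks below.

    // Illustrative sketch only: pick_gemm_config and the block sizes are assumptions;
    // the API usage follows the pattern of the Bifrost configurations in this patch.
    #include "arm_compute/core/CL/gemm/CLGEMMHelpers.h"
    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/TensorShape.h"
    #include "arm_compute/core/Types.h"
    #include "arm_compute/core/utils/misc/ShapeCalculator.h"

    #include <tuple>
    #include <utility>

    using namespace arm_compute;
    using namespace arm_compute::misc::shape_calculator;

    std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> pick_gemm_config(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
    {
        GEMMLHSMatrixInfo lhs_info_buf, lhs_info_img;
        GEMMRHSMatrixInfo rhs_info_buf, rhs_info_img;

        // Fallback configuration: reshaped RHS matrix kept in an OpenCL buffer
        std::tie(lhs_info_buf, rhs_info_buf) = cl_gemm::configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);

        // Candidate configuration: reshaped RHS matrix exported to cl_image
        // (the trailing argument is the new export_to_cl_image flag, default false)
        std::tie(lhs_info_img, rhs_info_img) = cl_gemm::configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 8, false, true, false, false, true);

        // Build the TensorInfo of the reshaped RHS matrix and check whether the image2d
        // path is usable (F32 only, n0/k0 in {4, 8, 16}, device support, size limits)
        const TensorInfo  tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
        const TensorShape reshaped_shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
        const TensorInfo  tensor_reshaped_info(reshaped_shape, 1, DataType::F32);

        if(bool(cl_gemm::validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)))
        {
            return std::make_pair(lhs_info_img, rhs_info_img);
        }
        return std::make_pair(lhs_info_buf, rhs_info_buf);
    }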
diff --git a/arm_compute/core/CL/gemm/CLGEMMHelpers.h b/arm_compute/core/CL/gemm/CLGEMMHelpers.h
index a370f9171a..013c068cf7 100644
--- a/arm_compute/core/CL/gemm/CLGEMMHelpers.h
+++ b/arm_compute/core/CL/gemm/CLGEMMHelpers.h
@@ -29,32 +29,45 @@
namespace arm_compute
{
+class ITensorInfo;
+struct GEMMRHSMatrixInfo;
+
namespace cl_gemm
{
/** Configure @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo
*
- * @param[in] m Number of rows (M) in the LHS matrix not reshaped
- * @param[in] n Number of columns (N) in the RHS matrix not reshaped
- * @param[in] m0 Number of rows processed by each thread/work-item
- * @param[in] n0 Number of columns processed by each thread/work-item
- * @param[in] k0 Number of inner accumulation performed by each thread/work-item
- * @param[in] v0 Number of vertical blocks of size (m0xk0) stored on the same output row
- * @param[in] h0 Number of horizontal blocks of size (k0xn0) stored on the same output row
- * @param[in] lhs_interleave True if the v0 (m0xk0) blocks have to be interleaved in the output row
- * @param[in] rhs_interleave True if the h0 (k0xn0) blocks have to be interleaved in the output row
- * @param[in] lhs_transpose True if the (m0xk0) block has to be transposed before been stored
- * @param[in] rhs_transpose True if the (k0xn0) block has to be transposed before been stored
+ * @param[in] m Number of rows (M) in the LHS matrix not reshaped
+ * @param[in] n Number of columns (N) in the RHS matrix not reshaped
+ * @param[in] m0 Number of rows processed by each thread/work-item
+ * @param[in] n0 Number of columns processed by each thread/work-item
+ * @param[in] k0 Number of inner accumulation performed by each thread/work-item
+ * @param[in] v0 Number of vertical blocks of size (m0xk0) stored on the same output row
+ * @param[in] h0 Number of horizontal blocks of size (k0xn0) stored on the same output row
+ * @param[in] lhs_interleave True if the v0 (m0xk0) blocks have to be interleaved in the output row
+ * @param[in] rhs_interleave True if the h0 (k0xn0) blocks have to be interleaved in the output row
+ * @param[in] lhs_transpose True if the (m0xk0) block has to be transposed before been stored
+ * @param[in] rhs_transpose True if the (k0xn0) block has to be transposed before been stored
+ * @param[in] export_to_cl_image (Optional) True if the RHS reshaped matrix has to be exported to cl_image
*
* @return @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo
*/
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0,
- bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose);
+ bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image = false);
/** Update padding required to export the OpenCL buffer to OpenCL image2d
*
* @param[in,out] tensor ITensorInfo of the tensor required to be exported to OpenCL image2d
*/
void update_padding_for_cl_image(ITensorInfo *tensor);
+
+/** Utility function to validate the image2d OpenCL object support on the RHS reshaped matrix
+ *
+ * @param[in] tensor_reshaped_info TensorInfo for the RHS reshaped matrix
+ * @param[in] rhs_info @ref GEMMRHSMatrixInfo
+ *
+ * @return Status reporting if we can use the image2d OpenCL object on the RHS reshaped matrix
+ */
+Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info);
} // namespace cl_gemm
} // namespace arm_compute
#endif /*ARM_COMPUTE_CLGEMMHELPERS_H */
diff --git a/src/core/CL/gemm/CLGEMMHelpers.cpp b/src/core/CL/gemm/CLGEMMHelpers.cpp
index 1165e70273..5734c93021 100644
--- a/src/core/CL/gemm/CLGEMMHelpers.cpp
+++ b/src/core/CL/gemm/CLGEMMHelpers.cpp
@@ -23,7 +23,10 @@
*/
#include "arm_compute/core/CL/gemm/CLGEMMHelpers.h"
+#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/OpenCL.h"
+#include "arm_compute/core/ITensorInfo.h"
#include <utility>
@@ -32,24 +35,13 @@ namespace arm_compute
namespace cl_gemm
{
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0,
- bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose)
+ bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
{
- GEMMLHSMatrixInfo lhs_info;
- GEMMRHSMatrixInfo rhs_info;
+ v0 = ((m / (m0 * v0)) == 0) ? 1 : v0;
+ h0 = ((n / (n0 * h0)) == 0) ? 1 : h0;
- // Configure GEMMLHSMatrixInfo
- lhs_info.m0 = m0;
- lhs_info.k0 = k0;
- lhs_info.v0 = ((m / (lhs_info.m0 * v0)) == 0) ? 1 : v0;
- lhs_info.interleave = lhs_interleave;
- lhs_info.transpose = lhs_transpose;
-
- // Configure GEMMRHSMatrixInfo
- rhs_info.n0 = n0;
- rhs_info.k0 = lhs_info.k0;
- rhs_info.h0 = ((n / (rhs_info.n0 * h0)) == 0) ? 1 : h0;
- rhs_info.interleave = rhs_interleave;
- rhs_info.transpose = rhs_transpose;
+ const GEMMLHSMatrixInfo lhs_info(m0, k0, v0, lhs_transpose, lhs_interleave);
+ const GEMMRHSMatrixInfo rhs_info(n0, k0, h0, rhs_transpose, rhs_interleave, export_to_cl_image);
return std::make_pair(lhs_info, rhs_info);
}
@@ -66,5 +58,27 @@ void update_padding_for_cl_image(ITensorInfo *tensor)
tensor->extend_padding(PaddingSize(0, padding, 0, 0));
}
+
+Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info)
+{
+ if(rhs_info.export_to_cl_image)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.n0 == 2) || (rhs_info.n0 == 3), "Export to cl_image only supported with n0 = 4, 8 or 16");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.k0 == 2) || (rhs_info.k0 == 3), "Export to cl_image only supported with k0 = 4, 8 or 16");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(tensor_reshaped_info.data_type() != DataType::F32, "Export to cl_image only supported with F32 data type");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!image2d_from_buffer_supported(CLKernelLibrary::get().get_device()), "The extension cl_khr_image2d_from_buffer is not supported on the target platform");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(get_cl_image_pitch_alignment(CLKernelLibrary::get().get_device()) == 0, "Impossible to retrieve the cl_image pitch alignment");
+
+ // Check the width and height of the output tensor.
+ // Since we cannot create a 3d image from a buffer, the third dimension is collapsed on the second dimension
+ const size_t max_image_w = CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_IMAGE2D_MAX_WIDTH>();
+ const size_t max_image_h = CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_IMAGE2D_MAX_HEIGHT>();
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(tensor_reshaped_info.tensor_shape()[0] > max_image_w * 4, "Not supported width for cl_image");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(tensor_reshaped_info.tensor_shape()[1] * tensor_reshaped_info.tensor_shape()[2] > max_image_h, "Not supported height for cl_image");
+ }
+
+ return Status{};
+}
} // namespace cl_gemm
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
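
A note on the two size checks in the helper above (my reading of the code, not stated in the patch): the width limit is compared against max_image_w * 4 because the reshaped RHS buffer is read through the image as 4-component (float4) texels, so each image pixel covers four F32 elements along dimension 0; the height check multiplies dimensions 1 and 2 because, as the comment says, a 3D buffer cannot be mapped to a 2D image directly, so the third dimension is collapsed onto the second before comparing against CL_DEVICE_IMAGE2D_MAX_HEIGHT.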
diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
index f2954be7d2..a533f14d02 100644
--- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
@@ -27,6 +27,9 @@
#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/CL/gemm/CLGEMMHelpers.h"
#include "arm_compute/core/GPUTarget.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/TensorShape.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include <map>
#include <utility>
@@ -35,6 +38,8 @@ namespace arm_compute
{
namespace cl_gemm
{
+using namespace arm_compute::misc::shape_calculator;
+
CLGEMMReshapedKernelConfigurationBifrost::CLGEMMReshapedKernelConfigurationBifrost(GPUTarget gpu)
: ICLGEMMKernelConfiguration(gpu)
{
@@ -153,13 +158,48 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
+ GEMMLHSMatrixInfo lhs_info_buf;
+ GEMMRHSMatrixInfo rhs_info_buf;
+ GEMMLHSMatrixInfo lhs_info_img;
+ GEMMRHSMatrixInfo rhs_info_img;
+
+ // Get lhs_info/rhs_info in case of OpenCL buffer
if(n <= 4)
{
- return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
+ std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
+ }
+ else
+ {
+ std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 2, 8, 16, false, false, false, true);
+ }
+
+ // Get lhs_info/rhs_info in case of OpenCL image
+ // Condition on the GPU workload
+ if((m / 4) * (n / 4) >= 2560)
+ {
+ // Big workload
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 8, true, true, true, false, true);
+ }
+ else
+ {
+ // Small workload
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 1, true, true, true, false, true);
+ }
+
+ const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
+ const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
+ const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32);
+
+ // In case of vector by matrix with few work-items, we use the OpenCL buffer rather than the OpenCL image2d
+ const bool use_cl_image2d = (n <= 4) ? false : true;
+
+ if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
+ {
+ return std::make_pair(lhs_info_img, rhs_info_img);
}
else
{
- return configure_lhs_rhs_info(m, n, 4, 4, 2, 8, 16, false, false, false, true);
+ return std::make_pair(lhs_info_buf, rhs_info_buf);
}
}
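
A worked illustration of the workload threshold used in the hunk above (the shapes are hypothetical, not taken from the patch): with m = 320 and n = 128, (m / 4) * (n / 4) = 80 * 32 = 2560, which meets the >= 2560 condition, so the big-workload image configuration (m0 = 4, n0 = 4, k0 = 4, v0 = 2, h0 = 8) is chosen; with m = 128 and n = 64, 32 * 16 = 512 falls below the threshold and the small-workload configuration (m0 = 2, n0 = 4, k0 = 4, v0 = 1, h0 = 1) is used instead. Either image configuration is only returned when n > 4 and validate_image2d_support_on_rhs() passes on the target device; otherwise the OpenCL buffer configuration is kept.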
diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
index f662089c77..581c2d2199 100644
--- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
@@ -27,6 +27,9 @@
#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/CL/gemm/CLGEMMHelpers.h"
#include "arm_compute/core/GPUTarget.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/TensorShape.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include <map>
#include <utility>
@@ -35,6 +38,8 @@ namespace arm_compute
{
namespace cl_gemm
{
+using namespace arm_compute::misc::shape_calculator;
+
CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::CLGEMMReshapedOnlyRHSKernelConfigurationBifrost(GPUTarget gpu)
: ICLGEMMKernelConfiguration(gpu)
{
@@ -139,14 +144,47 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
+ GEMMLHSMatrixInfo lhs_info_buf;
+ GEMMRHSMatrixInfo rhs_info_buf;
+ GEMMLHSMatrixInfo lhs_info_img;
+ GEMMRHSMatrixInfo rhs_info_img;
+
+ // Get lhs_info/rhs_info in case of OpenCL buffer
if(m == 1)
{
const unsigned int h0 = std::max(n / 2, 1U);
- return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
+ std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
}
else
{
- return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
+ std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
+ }
+
+ // Get lhs_info/rhs_info in case of OpenCL image
+ if(m == 1)
+ {
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 4, false, true, false, false, true);
+ }
+ else
+ {
+ const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true);
+ }
+
+ const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
+ const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
+ const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32);
+
+ // In case of vector by matrix with few work-items, we use the OpenCL buffer rather than the OpenCL image2d
+ const bool use_cl_image2d = (m == 1 && n <= 4096) ? false : true;
+
+ if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
+ {
+ return std::make_pair(lhs_info_img, rhs_info_img);
+ }
+ else
+ {
+ return std::make_pair(lhs_info_buf, rhs_info_buf);
}
}
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
index be73c85731..5a46a1e013 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
@@ -29,6 +29,7 @@
#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
+#include "arm_compute/core/CL/gemm/CLGEMMHelpers.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
@@ -79,23 +80,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
&& (!gemm_info.broadcast_bias),
"Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision && (input0->data_type() == DataType::F32), "Mixed precision only supported for F16 data type");
-
- if(rhs_info.export_to_cl_image)
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.n0 == 2) || (rhs_info.n0 == 3), "Export to cl_image only supported with n0 = 4, 8 or 16");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.k0 == 2) || (rhs_info.k0 == 3), "Export to cl_image only supported with k0 = 4, 8 or 16");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() != DataType::F32, "Export to cl_image only supported with F32 data type");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!image2d_from_buffer_supported(CLKernelLibrary::get().get_device()), "The extension cl_khr_image2d_from_buffer is not supported on the target platform");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(get_cl_image_pitch_alignment(CLKernelLibrary::get().get_device()) == 0, "Impossible to retrieve the cl_image pitch alignment");
-
- // Check the width and height of the output tensor.
- // Since we cannot create a 3d image from a buffer, the third dimension is collapsed with the second dimension
- size_t max_image_w = CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_IMAGE2D_MAX_WIDTH>();
- size_t max_image_h = CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_IMAGE2D_MAX_HEIGHT>();
-
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->tensor_shape()[0] > max_image_w * 4, "Not supported width for cl_image");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->tensor_shape()[1] * input1->tensor_shape()[2] > max_image_h, "Not supported height for cl_image");
- }
+ ARM_COMPUTE_RETURN_ON_ERROR(cl_gemm::validate_image2d_support_on_rhs(*input1, rhs_info));
const unsigned int m = gemm_info.m;
const unsigned int n = gemm_info.n;
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
index 8b4f930ea4..7d76ffd86c 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
@@ -26,6 +26,7 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/CL/gemm/CLGEMMHelpers.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
@@ -65,23 +66,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
&& (!gemm_info.broadcast_bias),
"Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision, "Mixed precision not supported");
-
- if(rhs_info.export_to_cl_image)
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.n0 == 2) || (rhs_info.n0 == 3), "Export to cl_image only supported with n0 = 4, 8 or 16");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.k0 == 2) || (rhs_info.k0 == 3), "Export to cl_image only supported with k0 = 4, 8 or 16");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() != DataType::F32, "Export to cl_image only supported with F32 data type");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!image2d_from_buffer_supported(CLKernelLibrary::get().get_device()), "The extension cl_khr_image2d_from_buffer is not supported on the target platform");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(get_cl_image_pitch_alignment(CLKernelLibrary::get().get_device()) == 0, "Impossible to retrieve the cl_image pitch alignment");
-
- // Check the width and height of the output tensor.
- // Since we cannot create a 3d image from a buffer, the third dimension is collapsed with the second dimension
- size_t max_image_w = CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_IMAGE2D_MAX_WIDTH>();
- size_t max_image_h = CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_IMAGE2D_MAX_HEIGHT>();
-
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->tensor_shape()[0] > max_image_w * 4, "Not supported width for cl_image");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->tensor_shape()[1] * input1->tensor_shape()[2] > max_image_h, "Not supported height for cl_image");
- }
+ ARM_COMPUTE_RETURN_ON_ERROR(cl_gemm::validate_image2d_support_on_rhs(*input1, rhs_info));
const unsigned int m = gemm_info.m;
const unsigned int n = gemm_info.n;
diff --git a/src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.cpp b/src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.cpp
index c57066ae03..c1993b72b9 100644
--- a/src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.cpp
@@ -58,21 +58,8 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
if(rhs_info.export_to_cl_image)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.n0 == 2) || (rhs_info.n0 == 3), "Export to cl_image only supported with n0 = 4, 8 or 16");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.k0 == 2) || (rhs_info.k0 == 3), "Export to cl_image only supported with k0 = 4, 8 or 16");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() != DataType::F32, "Export to cl_image only supported with F32 data type");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!image2d_from_buffer_supported(CLKernelLibrary::get().get_device()), "The extension cl_khr_image2d_from_buffer is not supported on the target platform");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(get_cl_image_pitch_alignment(CLKernelLibrary::get().get_device()) == 0, "Impossible to retrieve the cl_image pitch alignment");
-
- TensorShape output_shape = compute_rhs_reshaped_shape(*input, rhs_info);
-
- // Check the width and height of the output tensor.
- // Since we cannot create a 3d image from a buffer, the third dimension is collapsed with the second dimension
- size_t max_image_w = CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_IMAGE2D_MAX_WIDTH>();
- size_t max_image_h = CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_IMAGE2D_MAX_HEIGHT>();
-
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(output_shape[0] > max_image_w * 4, "Not supported width for cl_image");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(output_shape[1] * output_shape[2] > max_image_h, "Not supported height for cl_image");
+ const TensorInfo tensor_reshaped_info(compute_rhs_reshaped_shape(*input, rhs_info), 1, DataType::F32);
+ ARM_COMPUTE_RETURN_ON_ERROR(cl_gemm::validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info));
}
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);