aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2020-09-15 14:17:41 +0100
committerGian Marco Iodice <gianmarco.iodice@arm.com>2020-09-18 09:53:42 +0000
commit6f9313477f6a158210479996523c210452d4f07a (patch)
treea64f8b97f3e95b6d084955cf675fa5e6d19205a7
parent82c1a1fc63d6a49c0b4be39529412c7f7bc8ea64 (diff)
downloadComputeLibrary-6f9313477f6a158210479996523c210452d4f07a.tar.gz
COMPMID-3671: Extend cl image support to f16 in CLGEMMMatrixMultiplyReshapedKernel
Resolves: COMPMID-3671, COMPMID-3672 - Extend cl image support to f16 in CLGEMMMatrixMultiplyReshapedKernel - Extend cl image support to f16 in CLGEMMMatrixMultiplyReshapedOnlyRHSKernel - Change the interface of create_image2d_from_buffer - Extend test Change-Id: I27363be71fa515fbf71aa4be5ed0d6c730f38f34 Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3992 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--src/core/CL/CLUtils.cpp18
-rw-r--r--src/core/CL/CLUtils.h4
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp2
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp2
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp273
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp112
6 files changed, 399 insertions, 12 deletions
diff --git a/src/core/CL/CLUtils.cpp b/src/core/CL/CLUtils.cpp
index 5d0cdf7f46..67af240044 100644
--- a/src/core/CL/CLUtils.cpp
+++ b/src/core/CL/CLUtils.cpp
@@ -26,12 +26,26 @@
#include "src/core/CL/CLUtils.h"
-cl::Image2D arm_compute::create_image2d_from_buffer(const cl::Context &ctx, const cl::Buffer &buffer, const TensorShape &shape2d, cl_channel_type data_type, size_t image_row_pitch)
+cl::Image2D arm_compute::create_image2d_from_buffer(const cl::Context &ctx, const cl::Buffer &buffer, const TensorShape &shape2d, DataType data_type, size_t image_row_pitch)
{
+ cl_channel_type cl_data_type;
+
+ switch(data_type)
+ {
+ case DataType::F32:
+ cl_data_type = CL_FLOAT;
+ break;
+ case DataType::F16:
+ cl_data_type = CL_HALF_FLOAT;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Data type not support with OpenCL image2d");
+ }
+
cl_mem cl_image;
cl_int err = CL_SUCCESS;
- const cl_image_format format = { CL_RGBA, data_type };
+ const cl_image_format format = { CL_RGBA, cl_data_type };
cl_image_desc desc;
memset(&desc, 0, sizeof(desc));
diff --git a/src/core/CL/CLUtils.h b/src/core/CL/CLUtils.h
index 8f1c58bcba..b65d547756 100644
--- a/src/core/CL/CLUtils.h
+++ b/src/core/CL/CLUtils.h
@@ -44,12 +44,12 @@ class TensorShape;
* @param[in] ctx cl::Context object
* @param[in] buffer cl::Buffer object from which the OpenCL image2d object is created
* @param[in] shape2d 2D tensor shape
- * @param[in] data_type cl_channel_type to use. Only supported CL_FLOAT
+ * @param[in] data_type DataType to use. Only supported: F32,F16
* @param[in] image_row_pitch Image row pitch (a.k.a. stride Y) to be used in the image2d object
*
* @return cl::Image2D object
*/
-cl::Image2D create_image2d_from_buffer(const cl::Context &ctx, const cl::Buffer &buffer, const TensorShape &shape2d, cl_channel_type data_type, size_t image_row_pitch);
+cl::Image2D create_image2d_from_buffer(const cl::Context &ctx, const cl::Buffer &buffer, const TensorShape &shape2d, DataType data_type, size_t image_row_pitch);
} // arm_compute
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
index 8f20de1ea1..b0f0e8a81f 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
@@ -376,7 +376,7 @@ void CLGEMMMatrixMultiplyReshapedKernel::run(const Window &window, cl::CommandQu
const TensorShape shape2d(_input1->info()->dimension(0) / 4, _input1->info()->dimension(1) * _input1->info()->dimension(2));
const size_t image_row_pitch = _input1->info()->strides_in_bytes()[1];
- input1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), _input1->cl_buffer(), shape2d, CL_FLOAT, image_row_pitch);
+ input1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), _input1->cl_buffer(), shape2d, _input1->info()->data_type(), image_row_pitch);
}
do
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
index cf77c70bfa..0ae30ed30e 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
@@ -378,7 +378,7 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::run(const Window &window, cl::Co
const TensorShape shape2d(_input1->info()->dimension(0) / 4, _input1->info()->dimension(1) * _input1->info()->dimension(2));
const size_t image_row_pitch = _input1->info()->strides_in_bytes()[1];
- input1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), _input1->cl_buffer(), shape2d, CL_FLOAT, image_row_pitch);
+ input1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), _input1->cl_buffer(), shape2d, _input1->info()->data_type(), image_row_pitch);
}
do
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
index d7853f3ea7..98149ce149 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
@@ -139,13 +139,13 @@ const auto a_values_nightly = framework::dataset::make("alpha", {1.0f} );
const auto beta_values_nightly = framework::dataset::make("beta", {1.0f} );
/** M0 values to test - Nightly */
-const auto m0_values_nightly = framework::dataset::make("M0", { 2, 3, 4, 8 });
+const auto m0_values_nightly = framework::dataset::make("M0", { 8 });
/** N0 values to test - Nightly */
-const auto n0_values_nightly = framework::dataset::make("N0", { 2, 3, 4, 8 });
+const auto n0_values_nightly = framework::dataset::make("N0", { 8 });
/** K0 values to test - Nightly */
-const auto k0_values_nightly = framework::dataset::make("K0", { 2, 3, 4, 8 });
+const auto k0_values_nightly = framework::dataset::make("K0", { 4 });
/** N0 values to test with export to OpenCL image object - Nightly */
const auto n0_export_to_cl_image_values_nightly = framework::dataset::make("N0", { 4, 8, 16 });
@@ -154,10 +154,10 @@ const auto n0_export_to_cl_image_values_nightly = framework::dataset::make("N0",
const auto k0_export_to_cl_image_values_nightly = framework::dataset::make("K0", { 4, 8, 16 });
/** V0 values to test - Nightly */
-const auto v0_values_nightly = framework::dataset::make("V0", 1, 4);
+const auto v0_values_nightly = framework::dataset::make("V0", 1, 3);
/** H0 values to test - Nightly */
-const auto h0_values_nightly = framework::dataset::make("H0", 1, 4);
+const auto h0_values_nightly = framework::dataset::make("H0", 1, 3);
/** Interleave values to test with LHS matrix */
const auto i_values_lhs = framework::dataset::make("interleave_lhs", { true, false });
@@ -886,6 +886,269 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>,
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
+
+TEST_SUITE(ExportToCLImage)
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
+ framework::dataset::make("Input0Info", { TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // OK or incorrect if cl_khr_image2d_from_buffer not supported
+ TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // OK or incorrect if cl_khr_image2d_from_buffer not supported
+ TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // OK or incorrect if cl_khr_image2d_from_buffer not supported
+ TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // Incorrect k0
+ TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // Incorrect n0
+
+ }),
+ framework::dataset::make("Input1Info",{ TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(512U, 8U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(128U, 32U, 2U), 1, DataType::F16),
+
+ })),
+ framework::dataset::make("Input2Info", { TensorInfo(TensorShape(64U), 1, DataType::F16),
+ TensorInfo(TensorShape(64U), 1, DataType::F16),
+ TensorInfo(TensorShape(64U), 1, DataType::F16),
+ TensorInfo(TensorShape(64U), 1, DataType::F16),
+ TensorInfo(TensorShape(64U), 1, DataType::F16),
+
+ })),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
+
+ })),
+ framework::dataset::make("LHSMInfo",{
+ GEMMLHSMatrixInfo(4, 4, 1, false, true),
+ GEMMLHSMatrixInfo(4, 8, 1, false, true),
+ GEMMLHSMatrixInfo(4, 4, 1, false, true),
+ GEMMLHSMatrixInfo(4, 2, 1, false, false),
+ GEMMLHSMatrixInfo(4, 4, 1, false, false),
+
+ })),
+ framework::dataset::make("RHSMInfo",{
+ GEMMRHSMatrixInfo(4, 4, 1, true, true, true),
+ GEMMRHSMatrixInfo(4, 8, 1, true, true, true),
+ GEMMRHSMatrixInfo(8, 4, 1, true, true, true),
+ GEMMRHSMatrixInfo(4, 2, 1, true, false, true),
+ GEMMRHSMatrixInfo(2, 4, 1, true, false, true),
+ })),
+ framework::dataset::make("GEMMInfo",{GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
+ 64 /**<N Number of RHS columns*/,
+ 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
+ false /**< reinterpret the input as 3D */,
+ true /**< Flag used to broadcast the bias addition */,
+ false /**< wider accumm */,
+ ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
+ 1 /**< Multiplication factor for the width of the 1xW transposed block */,
+ 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
+ GEMMLHSMatrixInfo(),
+ GEMMRHSMatrixInfo(),
+ 0 /**< Offset to be added to each element of the matrix A */,
+ 0 /**< Offset to be added to each element of the matrix B */),
+ GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
+ 64 /**<N Number of RHS columns*/,
+ 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
+ false /**< reinterpret the input as 3D */,
+ true /**< Flag used to broadcast the bias addition */,
+ false /**< wider accumm */,
+ ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
+ 1 /**< Multiplication factor for the width of the 1xW transposed block */,
+ 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
+ GEMMLHSMatrixInfo(),
+ GEMMRHSMatrixInfo(),
+ 0 /**< Offset to be added to each element of the matrix A */,
+ 0 /**< Offset to be added to each element of the matrix B */),
+ GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
+ 64 /**<N Number of RHS columns*/,
+ 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
+ false /**< reinterpret the input as 3D */,
+ true /**< Flag used to broadcast the bias addition */,
+ false /**< wider accumm */,
+ ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
+ 1 /**< Multiplication factor for the width of the 1xW transposed block */,
+ 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
+ GEMMLHSMatrixInfo(),
+ GEMMRHSMatrixInfo(),
+ 0 /**< Offset to be added to each element of the matrix A */,
+ 0 /**< Offset to be added to each element of the matrix B */),
+
+ GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
+ 64 /**<N Number of RHS columns*/,
+ 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
+ false /**< reinterpret the input as 3D */,
+ true /**< Flag used to broadcast the bias addition */,
+ false /**< wider accumm */,
+ ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
+ 1 /**< Multiplication factor for the width of the 1xW transposed block */,
+ 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
+ GEMMLHSMatrixInfo(),
+ GEMMRHSMatrixInfo(),
+ 0 /**< Offset to be added to each element of the matrix A */,
+ 0 /**< Offset to be added to each element of the matrix B */),
+ GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
+ 64 /**<N Number of RHS columns*/,
+ 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
+ false /**< reinterpret the input as 3D */,
+ true /**< Flag used to broadcast the bias addition */,
+ false /**< wider accumm */,
+ ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
+ 1 /**< Multiplication factor for the width of the 1xW transposed block */,
+ 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
+ GEMMLHSMatrixInfo(),
+ GEMMRHSMatrixInfo(),
+ 0 /**< Offset to be added to each element of the matrix A */,
+ 0 /**< Offset to be added to each element of the matrix B */)
+ })),
+ framework::dataset::make("Expected", { true,
+ true,
+ true,
+ false,
+ false})),
+ input0_info ,input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
+{
+ ARM_COMPUTE_EXPECT(bool(CLGEMMMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
+ &input1_info.clone()->set_is_resizable(true),
+ &input2_info.clone()->set_is_resizable(true),
+ &output_info.clone()->set_is_resizable(true),1.f,1.f,
+ lhs_info,
+ rhs_info,
+ gemm_info)) == (expected && image2d_from_buffer_supported(CLKernelLibrary::get().get_device())), framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_values,
+ n_values),
+ k_values),
+ b_values),
+ m0_values_precommit),
+ n0_values_precommit),
+ k0_values_precommit),
+ v0_values_precommit),
+ h0_values_precommit),
+ i_values_lhs),
+ i_values_rhs),
+ framework::dataset::make("export_to_cl_image_rhs", true)),
+ framework::dataset::make("DataType", DataType::F16)),
+ a_values_precommit),
+ beta_values_precommit),
+ broadcast_bias_values),
+ lhs_transpose_values),
+ act_values))
+{
+ // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
+ if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ {
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+ }
+ else
+ {
+ ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+ framework::ARM_COMPUTE_PRINT_INFO();
+ }
+
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_values,
+ n_values),
+ k_values),
+ b_values),
+ m0_values_nightly),
+ n0_export_to_cl_image_values_nightly),
+ k0_export_to_cl_image_values_nightly),
+ v0_values_nightly),
+ h0_values_nightly),
+ i_values_lhs),
+ i_values_rhs),
+ framework::dataset::make("export_to_cl_image_rhs", true)),
+ framework::dataset::make("DataType", DataType::F16)),
+ a_values_nightly),
+ beta_values_nightly),
+ broadcast_bias_values),
+ lhs_transpose_values),
+ act_values))
+{
+ // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
+ if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ {
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+ }
+ else
+ {
+ ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+ framework::ARM_COMPUTE_PRINT_INFO();
+ }
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_w_values,
+ m_h_values),
+ n_values),
+ k_values),
+ b_values),
+ m0_values_precommit),
+ n0_values_precommit),
+ k0_values_precommit),
+ v0_values_precommit),
+ h0_values_precommit),
+ i_values_lhs),
+ i_values_rhs),
+ framework::dataset::make("export_to_cl_image_rhs", true)),
+ framework::dataset::make("DataType", DataType::F16)),
+ a_values_precommit),
+ beta_values_precommit),
+ lhs_transpose_values),
+ act_values))
+{
+ // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
+ if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ {
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+ }
+ else
+ {
+ ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+ framework::ARM_COMPUTE_PRINT_INFO();
+ }
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_w_values,
+ m_h_values),
+ n_values),
+ k_values),
+ b_values),
+ m0_values_nightly),
+ n0_export_to_cl_image_values_nightly),
+ k0_export_to_cl_image_values_nightly),
+ v0_values_nightly),
+ h0_values_nightly),
+ i_values_lhs),
+ i_values_rhs),
+ framework::dataset::make("export_to_cl_image_rhs", true)),
+ framework::dataset::make("DataType", DataType::F16)),
+ a_values_nightly),
+ beta_values_nightly),
+ lhs_transpose_values),
+ act_values))
+{
+ // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
+ if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ {
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+ }
+ else
+ {
+ ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+ framework::ARM_COMPUTE_PRINT_INFO();
+ }
+}
+TEST_SUITE_END() // ExportToCLImage
TEST_SUITE_END() // FP16
TEST_SUITE(MixedPrecision)
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
index afb2807d01..d792afac1d 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
@@ -74,7 +74,7 @@ constexpr float abs_tolerance_f16(0.01f);
const auto a_values = framework::dataset::make("alpha", {-0.75f} );
/** Beta values to test */
-const auto beta_values = framework::dataset::make("beta", {-0.35f, 0.0f} );
+const auto beta_values = framework::dataset::make("beta", {-0.35f} );
/** M values to test */
const auto m_values = framework::dataset::make("M", 37);
@@ -692,6 +692,116 @@ FIXTURE_DATA_TEST_CASE(RunNightly3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixtur
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
+TEST_SUITE(ExportToCLImage)
+FIXTURE_DATA_TEST_CASE(RunPrecommit, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_values,
+ n_values),
+ k_values),
+ b_values),
+ m0_values_precommit),
+ n0_values_precommit),
+ k0_values_precommit),
+ h0_values),
+ i_values_rhs),
+ t_values_rhs),
+ framework::dataset::make("export_to_cl_image_rhs", true)),
+ framework::dataset::make("DataType", DataType::F16)),
+ a_values),
+ beta_values),
+ broadcast_bias_values),
+ act_values))
+{
+ // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
+ if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ {
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+ }
+ else
+ {
+ ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+ framework::ARM_COMPUTE_PRINT_INFO();
+ }
+}
+
+FIXTURE_DATA_TEST_CASE(RunNightly, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_values,
+ n_values),
+ k_values),
+ b_values),
+ m0_values_nightly),
+ n0_values_nightly),
+ k0_values_nightly),
+ h0_values),
+ i_values_rhs),
+ t_values_rhs),
+ framework::dataset::make("export_to_cl_image_rhs", true)),
+ framework::dataset::make("DataType", DataType::F16)),
+ a_values),
+ beta_values),
+ broadcast_bias_values),
+ act_values))
+{
+ // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
+ if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ {
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+ }
+ else
+ {
+ ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+ framework::ARM_COMPUTE_PRINT_INFO();
+ }
+}
+
+FIXTURE_DATA_TEST_CASE(RunPrecommit3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<half>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_w_values,
+ m_h_values),
+ n_values),
+ k_values),
+ b_values),
+ m0_values_precommit),
+ n0_values_precommit),
+ k0_values_precommit),
+ h0_values),
+ i_values_rhs),
+ t_values_rhs),
+ framework::dataset::make("export_to_cl_image_rhs", true)),
+ framework::dataset::make("DataType", DataType::F16)),
+ a_values),
+ beta_values),
+ act_values))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+}
+
+FIXTURE_DATA_TEST_CASE(RunNightly3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<half>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_w_values,
+ m_h_values),
+ n_values),
+ k_values),
+ b_values),
+ m0_values_nightly),
+ n0_values_nightly),
+ k0_values_nightly),
+ h0_values),
+ i_values_rhs),
+ t_values_rhs),
+ framework::dataset::make("export_to_cl_image_rhs", true)),
+ framework::dataset::make("DataType", DataType::F16)),
+ a_values),
+ beta_values),
+ act_values))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+}
+TEST_SUITE_END() // ExportToCLImage
+
TEST_SUITE_END() // FP16
TEST_SUITE_END() // Float