diff options
Diffstat (limited to 'tests/validation/CL/GEMMReshapeRHSMatrix.cpp')
-rw-r--r-- | tests/validation/CL/GEMMReshapeRHSMatrix.cpp | 25 |
1 files changed, 12 insertions, 13 deletions
diff --git a/tests/validation/CL/GEMMReshapeRHSMatrix.cpp b/tests/validation/CL/GEMMReshapeRHSMatrix.cpp index aa6667666c..f8462058a6 100644 --- a/tests/validation/CL/GEMMReshapeRHSMatrix.cpp +++ b/tests/validation/CL/GEMMReshapeRHSMatrix.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 ARM Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,11 +21,11 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#include "arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/CL/CLTensor.h" #include "arm_compute/runtime/CL/CLTensorAllocator.h" +#include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h" #include "tests/CL/CLAccessor.h" #include "tests/CL/Helper.h" #include "tests/PaddingCalculator.h" @@ -46,9 +46,6 @@ namespace { // *INDENT-OFF* // clang-format off -/** Data types */ -const auto data_types = framework::dataset::make("DataType", { DataType::QASYMM8, DataType::F16, DataType::F32 }); - /** Batch size values to test */ const auto b_values = framework::dataset::make("batchsize", 1, 3); @@ -76,9 +73,10 @@ const auto i_values = framework::dataset::make("interleave", { true, false }); } // namespace using namespace arm_compute::misc::shape_calculator; +using namespace arm_compute::opencl::kernels; // Initialize the output tensor with zero and fill the border with zero -using CLGEMMReshapeRHSMatrix = CLSynthetizeFunctionInitOutputWithZeroAndWithZeroConstantBorder<CLGEMMReshapeRHSMatrixKernel, 16>; +using CLGEMMReshapeRHSMatrix = CLSynthetizeOperatorInitOutputWithZeroAndWithZeroConstantBorder<ClGemmReshapeRhsMatrixKernel, 16>; template <typename T> using CLGEMMReshapeRHSMatrixFixture = GEMMReshapeRHSMatrixValidationFixture<CLTensor, CLAccessor, CLGEMMReshapeRHSMatrix, T>; @@ -120,23 +118,24 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip( rhs_info.transpose = true; rhs_info.interleave = true; - bool has_error = bool(CLGEMMReshapeRHSMatrixKernel::validate(&input_info.clone()->set_is_resizable(false), (output_info.total_size() == 0) ? nullptr : &output_info.clone()->set_is_resizable(false), rhs_info)); + bool has_error = bool(ClGemmReshapeRhsMatrixKernel::validate(&input_info.clone()->set_is_resizable(false), (output_info.total_size() == 0) ? nullptr : &output_info.clone()->set_is_resizable(false), rhs_info)); ARM_COMPUTE_EXPECT(has_error == expected, framework::LogLevel::ERRORS); } -DATA_TEST_CASE(ValidatePadding, framework::DatasetMode::ALL, combine(combine(combine( +DATA_TEST_CASE(ValidatePadding, framework::DatasetMode::ALL, combine(combine(combine(combine( framework::dataset::make("InputShape", { TensorShape(32U, 16U, 1U), TensorShape(32U, 16U, 2U) }), framework::dataset::make("N0",{ 4 })), - framework::dataset::make("K0",{ 2, 4, 8 })), + framework::dataset::make("K0",{ 4, 8, 16 })), framework::dataset::make("H0",{ 1, 2, 4 })), - input_shape, n0, k0, h0) + framework::dataset::make("DataType",{ DataType::F32, DataType::F16 })), + input_shape, n0, k0, h0, data_type) { CLTensor input; CLTensor output; - input.info()->init(input_shape, 1, DataType::F32); + input.info()->init(input_shape, 1, data_type); unsigned int padding = 0; @@ -160,9 +159,9 @@ DATA_TEST_CASE(ValidatePadding, framework::DatasetMode::ALL, combine(combine(com padding = round_up_width - output_shape[0]; } - CLGEMMReshapeRHSMatrixKernel kernel; + ClGemmReshapeRhsMatrixKernel kernel; - kernel.configure(&input, &output, rhs_info); + kernel.configure(CLKernelLibrary::get().get_compile_context(), input.info(), output.info(), rhs_info); ARM_COMPUTE_EXPECT((output.info()->padding().right == padding), framework::LogLevel::ERRORS); } |