From 1c9efebf4344e8db97e6d9282b2bf48b52090b58 Mon Sep 17 00:00:00 2001 From: giuros01 Date: Fri, 11 Jan 2019 14:04:43 +0000 Subject: Issue COMPMID-1835: Remove CLGEMMInterleave4x4Kernel and replace with CLGEMMReshapeLHSMatrixKernel Change-Id: Id6a1bd78f9b1698b64a004e4adebc41002b15745 Reviewed-on: https://review.mlplatform.org/496 Tested-by: Arm Jenkins Reviewed-by: Gian Marco Iodice --- src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp') diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp index 69455cf419..89fe7a4650 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -40,7 +40,8 @@ #include #include -using namespace arm_compute; +namespace arm_compute +{ using namespace arm_compute::misc::shape_calculator; namespace @@ -67,6 +68,7 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i else { GEMMRHSMatrixInfo rhs_info; + GEMMLHSMatrixInfo lhs_info; const int m = reshape_info.m(); const int n = reshape_info.n(); const int k = reshape_info.k(); @@ -77,6 +79,11 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i rhs_info.h0 = mult_transpose1xW_width; rhs_info.interleave = false; rhs_info.transpose = false; + lhs_info.m0 = 4; + lhs_info.k0 = 4; + lhs_info.v0 = mult_interleave4x4_height; + lhs_info.interleave = true; + lhs_info.transpose = true; TensorShape tensor_shape0{ input0->tensor_shape() }; tensor_shape0.set(0, k); @@ -89,7 +96,7 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0); const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1); - const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(compute_interleaved_shape(tensor_info0, mult_interleave4x4_height)); + const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(compute_lhs_reshaped_shape(tensor_info0, lhs_info)); const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(compute_rhs_reshaped_shape(tensor_info1, rhs_info)); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0); @@ -439,3 +446,4 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que } while(window.slide_window_slice_3D(slice)); } +} // namespace arm_compute -- cgit v1.2.1