From 932491f44d51940d82514417a82e43cb11b06bd4 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 21 Sep 2018 16:33:15 +0100 Subject: COMPMID-1519: Add support for 3D input/output in CLGEMMLowpOutputStage Change-Id: I637add70310d2da4d82b236a6352af9d33be17a1 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/149706 Reviewed-by: Isabella Gottardi Reviewed-by: Michele DiGiorgio Tested-by: bsgcomp --- src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp') diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp index 1d6f343cb2..62e7ee7ce6 100644 --- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp +++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp @@ -108,7 +108,7 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo // in order to know how the matrices have been reshaped bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); - const int m = a->info()->dimension(1); + const int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1); const int n = b->info()->dimension(0); const int k = a->info()->dimension(0); const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); @@ -206,12 +206,12 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso int32_t a_offset = a->quantization_info().offset; int32_t b_offset = b->quantization_info().offset; - const int m = a->dimension(1); + bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); + const int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); const int n = b->dimension(0); const int k = a->dimension(0); constexpr int mult_transpose1xW_width = 1; constexpr int mult_interleave4x4_height = 1; - bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); bool reshape_matrices = is_interleaved_transposed(m, n, k, gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target()); -- cgit v1.2.1