author    Gian Marco Iodice <gianmarco.iodice@arm.com>  2019-06-24 17:47:51 +0100
committer Gian Marco Iodice <gianmarco.iodice@arm.com>  2019-07-18 13:17:27 +0000
commit    06be6f8d2a316a307fa623150f8adf8f9c3416c5 (patch)
tree      0db15a25f2c306fab3d843236f878ec9479b7f57
parent    ff2719299ea76a95f20a35a7900875a8152e293a (diff)
download  ComputeLibrary-06be6f8d2a316a307fa623150f8adf8f9c3416c5.tar.gz

COMPMID-2096: Refactor the CLGEMMLowp function selection (heuristic)

Change-Id: I15a8b39e0354d3b6686ed4cc8c361782c0512037
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1410
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: VidhyaSudhan Loganathan <vidhyasudhan.loganathan@arm.com>
-rw-r--r--  SConscript                                                                |    1
-rw-r--r--  arm_compute/core/CL/gemm/native/CLGEMMNativeKernelConfiguration.h        |   59
-rw-r--r--  arm_compute/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.h |   64
-rw-r--r--  arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.h             |   31
-rw-r--r--  arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h          |    5
-rw-r--r--  src/core/CL/CLKernelLibrary.cpp                                          |    3
-rw-r--r--  src/core/CL/cl_kernels/gemmlowp.cl                                       | 1040
-rw-r--r--  src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp       |  245
-rw-r--r--  src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp                   |  201
-rw-r--r--  src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp             |    2
-rw-r--r--  src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp                |   82
11 files changed, 517 insertions(+), 1216 deletions(-)
diff --git a/SConscript b/SConscript
index a170a4a7c1..ed22f6eefe 100644
--- a/SConscript
+++ b/SConscript
@@ -188,6 +188,7 @@ if env['opencl']:
core_files += Glob('src/core/CL/*.cpp')
core_files += Glob('src/core/CL/kernels/*.cpp')
core_files += Glob('src/core/CL/gemm/*.cpp')
+ core_files += Glob('src/core/CL/gemm/native/*.cpp')
core_files += Glob('src/core/CL/gemm/reshaped/*.cpp')
core_files += Glob('src/core/CL/gemm/reshaped_only_rhs/*.cpp')
diff --git a/arm_compute/core/CL/gemm/native/CLGEMMNativeKernelConfiguration.h b/arm_compute/core/CL/gemm/native/CLGEMMNativeKernelConfiguration.h
new file mode 100644
index 0000000000..7d0e7c97d4
--- /dev/null
+++ b/arm_compute/core/CL/gemm/native/CLGEMMNativeKernelConfiguration.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_CLGEMMNATIVEKERNELCONFIGURATION_H__
+#define __ARM_COMPUTE_CLGEMMNATIVEKERNELCONFIGURATION_H__
+
+#include "arm_compute/core/CL/ICLGEMMKernelConfiguration.h"
+#include "arm_compute/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.h"
+
+#include <memory>
+
+namespace arm_compute
+{
+namespace cl_gemm
+{
+/** CLGEMMNative factory class */
+class CLGEMMNativeKernelConfigurationFactory final
+{
+public:
+ /** Static method to construct a CLGEMMNative kernel configuration object according to the GPU architecture
+ *
+ * @param[in] arch GPU target
+ *
+ * @return CLGEMMNative kernel configuration class
+ */
+ static std::unique_ptr<ICLGEMMKernelConfiguration> create(GPUTarget arch)
+ {
+ switch(get_arch_from_target(arch))
+ {
+ case GPUTarget::BIFROST:
+ return support::cpp14::make_unique<CLGEMMNativeKernelConfigurationBifrost>(arch);
+ default:
+ return nullptr;
+ }
+ }
+};
+} // namespace cl_gemm
+} // namespace arm_compute
+#endif /*__ARM_COMPUTE_CLGEMMNATIVEKERNELCONFIGURATION_H__ */
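
The factory above mirrors the pattern already used by the reshaped and reshaped-only-RHS configurations. A minimal usage sketch, not code from the patch (the GEMM dimensions and scheduler setup are assumed for illustration):

    #include "arm_compute/core/CL/gemm/native/CLGEMMNativeKernelConfiguration.h"
    #include "arm_compute/runtime/CL/CLScheduler.h"

    #include <tuple>

    using namespace arm_compute;
    using namespace arm_compute::cl_gemm;

    // Hypothetical GEMM shape: out(m x n) = lhs(m x k) * rhs(k x n), batch size b
    const unsigned int m = 64, n = 64, k = 128, b = 1;

    // Ask the factory for the heuristic matching the current GPU architecture
    const GPUTarget target = CLScheduler::get().target();
    auto gemm_config       = CLGEMMNativeKernelConfigurationFactory::create(target);

    if(gemm_config != nullptr) // nullptr for architectures without a native heuristic
    {
        GEMMLHSMatrixInfo lhs_info;
        GEMMRHSMatrixInfo rhs_info;
        // Recommended LHS/RHS block sizes for a QASYMM8 native GEMM of this shape
        std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, b, DataType::QASYMM8);
    }
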
diff --git a/arm_compute/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.h b/arm_compute/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.h
new file mode 100644
index 0000000000..ea46818750
--- /dev/null
+++ b/arm_compute/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_CLGEMMNATIVEKERNELCONFIGURATIONBIFROST_H__
+#define __ARM_COMPUTE_CLGEMMNATIVEKERNELCONFIGURATIONBIFROST_H__
+
+#include "arm_compute/core/CL/ICLGEMMKernelConfiguration.h"
+
+namespace arm_compute
+{
+namespace cl_gemm
+{
+/** Bifrost based OpenCL GEMMNative configuration */
+class CLGEMMNativeKernelConfigurationBifrost final : public ICLGEMMKernelConfiguration
+{
+public:
+ /** Constructor
+ *
+ * @param[in] arch GPU target
+ */
+ CLGEMMNativeKernelConfigurationBifrost(GPUTarget arch);
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CLGEMMNativeKernelConfigurationBifrost(const CLGEMMNativeKernelConfigurationBifrost &) = delete;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CLGEMMNativeKernelConfigurationBifrost &operator=(const CLGEMMNativeKernelConfigurationBifrost &) = delete;
+ /** Default move constructor */
+ CLGEMMNativeKernelConfigurationBifrost(CLGEMMNativeKernelConfigurationBifrost &&) = default;
+ /** Default move assignment operator */
+ CLGEMMNativeKernelConfigurationBifrost &operator=(CLGEMMNativeKernelConfigurationBifrost &&) = default;
+
+ // Inherited methods overridden:
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override;
+
+private:
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G71_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G71_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_default_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+};
+} // namespace cl_gemm
+} // namespace arm_compute
+#endif /*__ARM_COMPUTE_CLGEMMNATIVEKERNELCONFIGURATIONBIFROST_H__ */
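
The implementation file (src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp, 245 lines in the diffstat, not reproduced here) fills in the per-GPU methods declared above. A plausible shape for the dispatch, sketched as an assumption rather than copied from the patch (the _target member is assumed to be the GPUTarget stored by the base class):

    std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure(
        unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
    {
        // QASYMM8 maps to the _u8 variants here; the real code would reject
        // any data type it has no tuned configuration for.
        const bool is_f32 = (data_type == DataType::F32);

        switch(_target)
        {
            case GPUTarget::G71:
                return is_f32 ? configure_G71_f32(m, n, k, b) : configure_G71_u8(m, n, k, b);
            case GPUTarget::G76:
                return is_f32 ? configure_G76_f32(m, n, k, b) : configure_G76_u8(m, n, k, b);
            default: // any other Bifrost GPU falls back to the default heuristic
                return is_f32 ? configure_default_f32(m, n, k, b) : configure_default_u8(m, n, k, b);
        }
    }
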
diff --git a/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.h b/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.h
index e576271780..409ed1bec2 100644
--- a/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -32,8 +32,9 @@ class ICLTensor;
/** OpenCL kernel to multiply matrices
*
- * @note @ref CLGEMMLowpMatrixMultiplyKernel low precision matrix product kernel
- * This kernel performs the following computation:
+ * @note This kernel should be used ONLY for Midgard architectures
+ *
+ * This kernel performs the following computation:
*
* -# Convert a values from int8 to int32
* -# Convert b values from int8 to int32
@@ -55,24 +56,24 @@ public:
CLGEMMLowpMatrixMultiplyKernel &operator=(CLGEMMLowpMatrixMultiplyKernel &&) = default;
/** Initialise the kernel's input and output.
*
- * @param[in] input0 Input tensor containing the interleaved Matrix A. Data type supported: QASYMM8
- * @param[in] input1 Input tensor containing the transposed1xW Matrix B. Data type supported: same as @p input0
- * @param[out] output Output tensor to store the result of matrix multiplication. Data type supported: S32
- * @param[in] is_interleaved_transposed (Optional) True if input0 and input1 have been reshaped respectively using @ref CLGEMMReshapeLHSMatrixKernel and @ref CLGEMMReshapeRHSMatrixKernel
- * @param[in] reshape_info (Optional) GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped
+ * @note This kernel should be used ONLY for Midgard architectures
+ *
+ * @param[in] input0 Input tensor containing the LHS matrix. Data type supported: QASYMM8
+ * @param[in] input1 Input tensor containing the RHS matrix. Data type supported: same as @p input0
+ * @param[out] output Output tensor to store the result of matrix multiplication. Data type supported: S32
+ * @param[in] gemm_info (Optional) GEMM information used to retrieve the original dimensions of the input matrices
*/
- void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, bool is_interleaved_transposed = true, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo());
+ void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, const GEMMReshapeInfo &gemm_info = GEMMReshapeInfo());
/** Static function to check if given info will lead to a valid configuration of @ref CLGEMMLowpMatrixMultiplyKernel
*
- * @param[in] input0 Input tensor info containing the interleaved Matrix A. Data type supported: QASYMM8
- * @param[in] input1 Input tensor info containing the transposed Matrix B. Data type supported: same as @p input0
- * @param[in] output Output tensor info to store the result of matrix multiplication. Data type supported: S32
- * @param[in] is_interleaved_transposed True if input0 and input1 have been reshaped respectively using @ref CLGEMMReshapeLHSMatrixKernel and @ref CLGEMMReshapeRHSMatrixKernel
- * @param[in] reshape_info GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped
+ * @param[in] input0 Input tensor info containing the LHS matrix. Data type supported: QASYMM8
+ * @param[in] input1 Input tensor info containing the RHS matrix. Data type supported: same as @p input0
+ * @param[in] output Output tensor info to store the result of matrix multiplication. Data type supported: S32
+ * @param[in] gemm_info (Optional) GEMM information used to retrieve the original dimensions of the input matrices
*
* @return a status
*/
- static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info);
+ static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, const GEMMReshapeInfo &gemm_info = GEMMReshapeInfo());
// Inherited methods overridden:
void run(const Window &window, cl::CommandQueue &queue) override;
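
For callers, the migration is mechanical: drop the is_interleaved_transposed flag and pass only the original matrix dimensions. A before/after sketch under assumed tensor setup (a, b and out are placeholder CLTensor objects, configured and allocated elsewhere):

    CLTensor a;   // LHS input, QASYMM8
    CLTensor b;   // RHS input, QASYMM8
    CLTensor out; // result, S32
    const int m = 64, n = 64, k = 128;

    CLGEMMLowpMatrixMultiplyKernel mm_kernel;

    // Before this patch (flag + reshape info):
    //   mm_kernel.configure(&a, &b, &out, /* is_interleaved_transposed */ false,
    //                       GEMMReshapeInfo(m, n, k));

    // After this patch the kernel always consumes unreshaped LHS/RHS, so only
    // the original dimensions are needed:
    mm_kernel.configure(&a, &b, &out, GEMMReshapeInfo(m, n, k));
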
diff --git a/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h b/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h
index a07101c020..541985b50c 100644
--- a/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h
+++ b/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h
@@ -25,6 +25,7 @@
#define __ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYCORE_H__
#include "arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.h"
+#include "arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.h"
@@ -100,7 +101,8 @@ public:
private:
CLMemoryGroup _memory_group;
- CLGEMMLowpMatrixMultiplyKernel _mm_kernel;
+ CLGEMMLowpMatrixMultiplyKernel _mm_midgard_kernel;
+ CLGEMMLowpMatrixMultiplyNativeKernel _mm_native_kernel;
CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel _mm_reshaped_only_rhs_kernel;
CLGEMMReshapeRHSMatrixKernel _mtx_b_reshape_kernel;
CLGEMMLowpMatrixAReductionKernel _mtx_a_reduction_kernel;
@@ -115,6 +117,7 @@ private:
int32_t _a_offset;
int32_t _b_offset;
bool _is_gemm_reshaped;
+ bool _is_midgard;
bool _reshape_b_only_on_first_run;
bool _is_prepared;
bool _fuse_output_stage;
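
The two new members make the runtime choice explicit: _is_midgard routes to the legacy kernel, while the native versus reshaped-only-RHS split follows the existing _is_gemm_reshaped flag. A sketch of the dispatch this enables in run(), stated as an assumption (the actual logic lives in src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp, changed by this patch):

    // Enqueue exactly one of the three matrix-multiply kernels
    if(_is_gemm_reshaped)
    {
        CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, false);
    }
    else if(_is_midgard)
    {
        CLScheduler::get().enqueue(_mm_midgard_kernel, false);
    }
    else
    {
        CLScheduler::get().enqueue(_mm_native_kernel, false);
    }
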
diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp
index 8b64b1f20e..16bcd50d06 100644
--- a/src/core/CL/CLKernelLibrary.cpp
+++ b/src/core/CL/CLKernelLibrary.cpp
@@ -330,10 +330,7 @@ const std::map<std::string, std::string> CLKernelLibrary::_kernel_program_map =
{ "gemmlowp_matrix_a_reduction", "gemmlowp.cl" },
{ "gemmlowp_matrix_a_reduction_dot8", "gemmlowp.cl" },
{ "gemmlowp_matrix_b_reduction", "gemmlowp.cl" },
- { "gemmlowp_mm_bifrost", "gemmlowp.cl" },
- { "gemmlowp_mm_bifrost_dot8", "gemmlowp.cl" },
{ "gemmlowp_mm_midgard", "gemmlowp.cl" },
- { "gemmlowp_mm_interleaved_transposed_midgard", "gemmlowp.cl" },
{ "gemmlowp_mm_native", "gemmlowp.cl" },
{ "gemmlowp_mm_reshaped_lhs_nt_rhs_t", "gemmlowp.cl" },
{ "gemmlowp_mm_reshaped_only_rhs_t", "gemmlowp.cl" },
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index d6494fe380..fc90dbd16c 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -193,168 +193,6 @@
(n0, k0, a, b, c); \
})
-#if defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
-/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
- * Matrix A and matrix B must be reshaped respectively with @ref CLGEMMReshapeLHSMatrixKernel and @ref CLGEMMReshapeRHSMatrixKernel before running the matrix multiplication
- *
- * @note The number of matrix B columns needs to be passed at compile time using -DCOLS_B: e.g. -DCOLS_B=1024
- * @note The transposition width step (mult_transpose1xW_width * 4) must be passed at compile time using -DTRANSPOSE1XW_WIDTH_STEP (i.e. -DTRANSPOSE1XW_WIDTH_STEP=2)
- * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
- *
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
- * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
- * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
- * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
- * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
- *
- * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
- * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
- * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
- * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
- */
-__kernel void gemmlowp_mm_interleaved_transposed_midgard(IMAGE_DECLARATION(src0),
- IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst),
- uint src0_stride_z,
- uint src1_stride_z,
- uint dst_stride_z
-#if defined(REINTERPRET_OUTPUT_AS_3D)
- ,
- uint cross_plane_pad
-#endif // REINTERPRET_OUTPUT_AS_3D
- )
-{
- const int x = get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP;
- const int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
- const int z = get_global_id(2);
-
- // Offset
- const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
- const int offset_row_b = (get_global_id(0) % TRANSPOSE1XW_WIDTH_STEP) * 4;
-
- // src_addr_a = address of matrix A
- // src_addr_b = address of matrix B
- __global uchar *src_addr_a = (__global uchar *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
- __global uchar *src_addr_b = (__global uchar *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
-
-#if defined(MATRIX_B_DEPTH)
- // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
- src_addr_b += (z % MATRIX_B_DEPTH) * src1_stride_z;
-#else // defined(MATRIX_B_DEPTH)
- src_addr_b += z * src1_stride_z;
-#endif // defined(MATRIX_B_DEPTH)
-
- // Compute end row address for matrix B
- __global uchar *src_end_addr_b = src_addr_b + COLS_B;
-
- src_addr_a += offset_row_a;
- src_addr_b += offset_row_b;
-
- // Reset accumulators
- int4 c00 = 0;
- int4 c10 = 0;
- int4 c20 = 0;
- int4 c30 = 0;
-
- for(; src_addr_b <= (src_end_addr_b - (int)(8 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * TRANSPOSE1XW_WIDTH_STEP)
- {
- // Load values from matrix A (interleaved) and matrix B (transposed)
- int4 a0 = convert_int4(vload4(0, src_addr_a));
- int4 b0 = convert_int4(vload4(0, src_addr_b));
-
- c00 += (int4)a0.s0 * b0;
- c10 += (int4)a0.s1 * b0;
- c20 += (int4)a0.s2 * b0;
- c30 += (int4)a0.s3 * b0;
-
- a0 = convert_int4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
- b0 = convert_int4(vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP));
-
- c00 += (int4)a0.s0 * b0;
- c10 += (int4)a0.s1 * b0;
- c20 += (int4)a0.s2 * b0;
- c30 += (int4)a0.s3 * b0;
- }
-
- for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * TRANSPOSE1XW_WIDTH_STEP))
- {
- // Load values from matrix A (interleaved) and matrix B (transposed)
- int4 a0 = convert_int4(vload4(0, src_addr_a));
- int4 b0 = convert_int4(vload4(0, src_addr_b));
-
- c00 += (int4)a0.s0 * b0;
- c10 += (int4)a0.s1 * b0;
- c20 += (int4)a0.s2 * b0;
- c30 += (int4)a0.s3 * b0;
- }
-
- // Compute destination address
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-
-#if defined(REINTERPRET_OUTPUT_AS_3D)
- // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
- // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
- uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
- zout = min(DEPTH_GEMM3D - 1, zout);
-
- // Add offset due to the cross plane paddings
- zout *= (cross_plane_pad * dst_stride_y);
-
- // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
- // multiply dst_stride_z by DEPTH_GEMM3D
- dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;
-
- // Store 4x4 block
- vstore4(c00, 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
- vstore4(c10, 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
- vstore4(c20, 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
- vstore4(c30, 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));
-
-#else // defined(REINTERPRET_OUTPUT_AS_3D)
- // Add offset for batched GEMM
- dst.ptr += z * dst_stride_z;
-
- // Store 4x4 block
- vstore4(c00, 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
- vstore4(c10, 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
- vstore4(c20, 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
- vstore4(c30, 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
-}
-#endif // defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
-
#if defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
#define VECTOR_UCHAR VEC_DATA_TYPE(uchar, NUM_ELEMS_PROCESSED_PER_THREAD_X)
#define VECTOR_UINT VEC_DATA_TYPE(uint, NUM_ELEMS_PROCESSED_PER_THREAD_X)
@@ -631,884 +469,6 @@ __kernel void gemmlowp_mm_midgard(IMAGE_DECLARATION(src0),
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
-
-/** OpenCL kernel optimized for Bifrost architectures that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
- *
- * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
- *
- * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
- * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
- * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
- * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
- * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
- * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
- *
- * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
- * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
- * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
- * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
- * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
- */
-__kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
- IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst),
- uint src0_stride_z,
- uint src1_stride_z,
- uint dst_stride_z
-#if defined(REINTERPRET_INPUT_AS_3D)
- ,
- uint src_cross_plane_pad
-#endif // REINTERPRET_INPUT_AS_3D
-#if defined(REINTERPRET_OUTPUT_AS_3D)
- ,
- uint dst_cross_plane_pad
-#endif // REINTERPRET_OUTPUT_AS_3D
- )
-{
- int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
-
- // Compute starting address for matrix A and Matrix B
- int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
-
- // Update address for the matrix A
- src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
-
- // Update address for the matrix B
- src_addr.s1 += idx;
-
-#if defined(REINTERPRET_INPUT_AS_3D)
- // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
- // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
- uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
- zin = min(DEPTH_GEMM3D - 1, zin);
-
- // Add offset due to the cross plane paddings
- zin *= (src_cross_plane_pad * src0_stride_y);
-
- // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
- // multiply src0_stride_z by DEPTH_GEMM3D
- src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
-
-#else // defined(REINTERPRET_INPUT_AS_3D)
-
- // Add offset for batched GEMM
- src_addr.s0 += get_global_id(2) * src0_stride_z;
-
-#endif // defined(REINTERPRET_INPUT_AS_3D)
-
-#if defined(MATRIX_B_DEPTH)
- // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
- src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
-#else // defined(MATRIX_B_DEPTH)
- src_addr.s1 += get_global_id(2) * src1_stride_z;
-#endif // defined(MATRIX_B_DEPTH)
-
- int end_row_vec_a = src_addr.s0 + COLS_A;
-
- uint acc00 = 0;
- uint acc01 = 0;
- uint acc02 = 0;
- uint acc03 = 0;
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uint acc10 = 0;
- uint acc11 = 0;
- uint acc12 = 0;
- uint acc13 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uint acc20 = 0;
- uint acc21 = 0;
- uint acc22 = 0;
- uint acc23 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uint acc30 = 0;
- uint acc31 = 0;
- uint acc32 = 0;
- uint acc33 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- uint acc40 = 0;
- uint acc41 = 0;
- uint acc42 = 0;
- uint acc43 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
-
- for(; src_addr.s0 <= (end_row_vec_a - 4); src_addr += (int2)(4, 4 * src1_stride_y))
- {
- // Load values from matrix A
- uchar4 a0 = vload4(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y);
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar4 a1 = vload4(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar4 a2 = vload4(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar4 a3 = vload4(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- uchar4 a4 = vload4(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- // Load values from matrix B
- uchar4 b0 = vload4(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
- uchar4 b1 = vload4(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
- uchar4 b2 = vload4(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
- uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
-
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a0.s0;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a0.s0;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a0.s0;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a0.s0;
-
- ushort tmp4 = (ushort)b1.s0 * (ushort)a0.s1;
- ushort tmp5 = (ushort)b1.s1 * (ushort)a0.s1;
- ushort tmp6 = (ushort)b1.s2 * (ushort)a0.s1;
- ushort tmp7 = (ushort)b1.s3 * (ushort)a0.s1;
-
- ushort tmp8 = (ushort)b2.s0 * (ushort)a0.s2;
- ushort tmp9 = (ushort)b2.s1 * (ushort)a0.s2;
- ushort tmpA = (ushort)b2.s2 * (ushort)a0.s2;
- ushort tmpB = (ushort)b2.s3 * (ushort)a0.s2;
-
- ushort tmpC = (ushort)b3.s0 * (ushort)a0.s3;
- ushort tmpD = (ushort)b3.s1 * (ushort)a0.s3;
- ushort tmpE = (ushort)b3.s2 * (ushort)a0.s3;
- ushort tmpF = (ushort)b3.s3 * (ushort)a0.s3;
-
- acc00 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
- acc01 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
- acc02 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
- acc03 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
- }
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a1.s0;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a1.s0;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a1.s0;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a1.s0;
-
- ushort tmp4 = (ushort)b1.s0 * (ushort)a1.s1;
- ushort tmp5 = (ushort)b1.s1 * (ushort)a1.s1;
- ushort tmp6 = (ushort)b1.s2 * (ushort)a1.s1;
- ushort tmp7 = (ushort)b1.s3 * (ushort)a1.s1;
-
- ushort tmp8 = (ushort)b2.s0 * (ushort)a1.s2;
- ushort tmp9 = (ushort)b2.s1 * (ushort)a1.s2;
- ushort tmpA = (ushort)b2.s2 * (ushort)a1.s2;
- ushort tmpB = (ushort)b2.s3 * (ushort)a1.s2;
-
- ushort tmpC = (ushort)b3.s0 * (ushort)a1.s3;
- ushort tmpD = (ushort)b3.s1 * (ushort)a1.s3;
- ushort tmpE = (ushort)b3.s2 * (ushort)a1.s3;
- ushort tmpF = (ushort)b3.s3 * (ushort)a1.s3;
-
- acc10 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
- acc11 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
- acc12 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
- acc13 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a2.s0;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a2.s0;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a2.s0;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a2.s0;
-
- ushort tmp4 = (ushort)b1.s0 * (ushort)a2.s1;
- ushort tmp5 = (ushort)b1.s1 * (ushort)a2.s1;
- ushort tmp6 = (ushort)b1.s2 * (ushort)a2.s1;
- ushort tmp7 = (ushort)b1.s3 * (ushort)a2.s1;
-
- ushort tmp8 = (ushort)b2.s0 * (ushort)a2.s2;
- ushort tmp9 = (ushort)b2.s1 * (ushort)a2.s2;
- ushort tmpA = (ushort)b2.s2 * (ushort)a2.s2;
- ushort tmpB = (ushort)b2.s3 * (ushort)a2.s2;
-
- ushort tmpC = (ushort)b3.s0 * (ushort)a2.s3;
- ushort tmpD = (ushort)b3.s1 * (ushort)a2.s3;
- ushort tmpE = (ushort)b3.s2 * (ushort)a2.s3;
- ushort tmpF = (ushort)b3.s3 * (ushort)a2.s3;
-
- acc20 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
- acc21 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
- acc22 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
- acc23 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a3.s0;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a3.s0;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a3.s0;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a3.s0;
-
- ushort tmp4 = (ushort)b1.s0 * (ushort)a3.s1;
- ushort tmp5 = (ushort)b1.s1 * (ushort)a3.s1;
- ushort tmp6 = (ushort)b1.s2 * (ushort)a3.s1;
- ushort tmp7 = (ushort)b1.s3 * (ushort)a3.s1;
-
- ushort tmp8 = (ushort)b2.s0 * (ushort)a3.s2;
- ushort tmp9 = (ushort)b2.s1 * (ushort)a3.s2;
- ushort tmpA = (ushort)b2.s2 * (ushort)a3.s2;
- ushort tmpB = (ushort)b2.s3 * (ushort)a3.s2;
-
- ushort tmpC = (ushort)b3.s0 * (ushort)a3.s3;
- ushort tmpD = (ushort)b3.s1 * (ushort)a3.s3;
- ushort tmpE = (ushort)b3.s2 * (ushort)a3.s3;
- ushort tmpF = (ushort)b3.s3 * (ushort)a3.s3;
-
- acc30 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
- acc31 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
- acc32 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
- acc33 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a4.s0;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a4.s0;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a4.s0;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a4.s0;
-
- ushort tmp4 = (ushort)b1.s0 * (ushort)a4.s1;
- ushort tmp5 = (ushort)b1.s1 * (ushort)a4.s1;
- ushort tmp6 = (ushort)b1.s2 * (ushort)a4.s1;
- ushort tmp7 = (ushort)b1.s3 * (ushort)a4.s1;
-
- ushort tmp8 = (ushort)b2.s0 * (ushort)a4.s2;
- ushort tmp9 = (ushort)b2.s1 * (ushort)a4.s2;
- ushort tmpA = (ushort)b2.s2 * (ushort)a4.s2;
- ushort tmpB = (ushort)b2.s3 * (ushort)a4.s2;
-
- ushort tmpC = (ushort)b3.s0 * (ushort)a4.s3;
- ushort tmpD = (ushort)b3.s1 * (ushort)a4.s3;
- ushort tmpE = (ushort)b3.s2 * (ushort)a4.s3;
- ushort tmpF = (ushort)b3.s3 * (ushort)a4.s3;
-
- acc40 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
- acc41 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
- acc42 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
- acc43 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- }
-
- for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
- {
- // Load values from matrix A
- uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y);
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- // Load values from matrix B
- uchar4 b0 = vload4(0, src1_ptr + src_addr.s1);
-
- // Accumulate
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a0;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a0;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a0;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a0;
-
- acc00 += ((uint)tmp0);
- acc01 += ((uint)tmp1);
- acc02 += ((uint)tmp2);
- acc03 += ((uint)tmp3);
- }
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a1;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a1;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a1;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a1;
-
- acc10 += ((uint)tmp0);
- acc11 += ((uint)tmp1);
- acc12 += ((uint)tmp2);
- acc13 += ((uint)tmp3);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a2;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a2;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a2;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a2;
-
- acc20 += ((uint)tmp0);
- acc21 += ((uint)tmp1);
- acc22 += ((uint)tmp2);
- acc23 += ((uint)tmp3);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a3;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a3;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a3;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a3;
-
- acc30 += ((uint)tmp0);
- acc31 += ((uint)tmp1);
- acc32 += ((uint)tmp2);
- acc33 += ((uint)tmp3);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a4;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a4;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a4;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a4;
-
- acc40 += ((uint)tmp0);
- acc41 += ((uint)tmp1);
- acc42 += ((uint)tmp2);
- acc43 += ((uint)tmp3);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- }
-
- const int z = get_global_id(2);
-
- // Compute destination address
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-
-#if defined(REINTERPRET_OUTPUT_AS_3D)
- // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
- // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
- uint8 zout = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7) + (uint8)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint8)HEIGHT_GEMM3D;
- zout = min(DEPTH_GEMM3D - 1, zout);
-
- // Add offset due to the cross plane paddings
- zout *= (dst_cross_plane_pad * dst_stride_y);
-
- // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
- // multiply dst_stride_z by DEPTH_GEMM3D
- dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;
-
- // Store the result
- vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(dst.ptr + 4 * dst_stride_y + zout.s4));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
-
-#else // defined(REINTERPRET_OUTPUT_AS_3D)
- // Add offset for batched GEMM
- dst.ptr += z * dst_stride_z;
-
- // Store the result
- vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(dst.ptr + 4 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
-}
-
-#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
-/** OpenCL kernel optimized to use dot product that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
- *
- * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
- *
- * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
- * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
- * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
- * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
- * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
- * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
- *
- * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
- * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
- * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
- * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
- * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
- */
-__kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0),
- IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst),
- uint src0_stride_z,
- uint src1_stride_z,
- uint dst_stride_z
-#if defined(REINTERPRET_INPUT_AS_3D)
- ,
- uint src_cross_plane_pad
-#endif // REINTERPRET_INPUT_AS_3D
-#if defined(REINTERPRET_OUTPUT_AS_3D)
- ,
- uint dst_cross_plane_pad
-#endif // REINTERPRET_OUTPUT_AS_3D)
- )
-{
- int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
-
- // Compute starting address for matrix A and Matrix B
- int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
-
- // Update address for the matrix A
- src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
-
- // Update address for the matrix B
- src_addr.s1 += idx;
-
-#if defined(REINTERPRET_INPUT_AS_3D)
- // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
- // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
- uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
- zin = min(DEPTH_GEMM3D - 1, zin);
-
- // Add offset due to the cross plane paddings
- zin *= (src_cross_plane_pad * src0_stride_y);
-
- zin += ((uint4)(0, 1, 2, 3)) * src0_stride_y;
-
- // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
- // multiply src0_stride_z by DEPTH_GEMM3D
- src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
-
-#else // defined(REINTERPRET_INPUT_AS_3D)
-
- // Add offset for batched GEMM
- src_addr.s0 += get_global_id(2) * src0_stride_z;
-
-#endif // defined(REINTERPRET_INPUT_AS_3D)
-
-#if defined(MATRIX_B_DEPTH)
- // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
- src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
-#else // defined(MATRIX_B_DEPTH)
- src_addr.s1 += get_global_id(2) * src1_stride_z;
-#endif // defined(MATRIX_B_DEPTH)
-
- uint acc00 = 0;
- uint acc01 = 0;
- uint acc02 = 0;
- uint acc03 = 0;
- uint acc04 = 0;
- uint acc05 = 0;
- uint acc06 = 0;
- uint acc07 = 0;
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uint acc10 = 0;
- uint acc11 = 0;
- uint acc12 = 0;
- uint acc13 = 0;
- uint acc14 = 0;
- uint acc15 = 0;
- uint acc16 = 0;
- uint acc17 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uint acc20 = 0;
- uint acc21 = 0;
- uint acc22 = 0;
- uint acc23 = 0;
- uint acc24 = 0;
- uint acc25 = 0;
- uint acc26 = 0;
- uint acc27 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uint acc30 = 0;
- uint acc31 = 0;
- uint acc32 = 0;
- uint acc33 = 0;
- uint acc34 = 0;
- uint acc35 = 0;
- uint acc36 = 0;
- uint acc37 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
- // A and B src indices get incremented at the same time.
- int i = 0;
- for(; i <= ((int)COLS_A - 8); i += 8)
- {
-#if defined(REINTERPRET_INPUT_AS_3D)
- // Load values from matrix A and matrix B
- uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#else // defined(REINTERPRET_INPUT_AS_3D)
- // Load values from matrix A and matrix B
- uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif // defined(REINTERPRET_INPUT_AS_3D)
-
- uchar8 b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
- uchar8 b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
- uchar8 b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
- uchar8 b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
- src_addr.s1 += 4 * src1_stride_y;
-
- ARM_DOT(a0.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00);
- ARM_DOT(a0.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01);
- ARM_DOT(a0.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02);
- ARM_DOT(a0.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03);
- ARM_DOT(a0.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04);
- ARM_DOT(a0.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05);
- ARM_DOT(a0.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06);
- ARM_DOT(a0.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07);
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- ARM_DOT(a1.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10);
- ARM_DOT(a1.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11);
- ARM_DOT(a1.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12);
- ARM_DOT(a1.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13);
- ARM_DOT(a1.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14);
- ARM_DOT(a1.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15);
- ARM_DOT(a1.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16);
- ARM_DOT(a1.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- ARM_DOT(a2.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20);
- ARM_DOT(a2.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21);
- ARM_DOT(a2.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22);
- ARM_DOT(a2.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23);
- ARM_DOT(a2.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24);
- ARM_DOT(a2.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25);
- ARM_DOT(a2.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26);
- ARM_DOT(a2.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- ARM_DOT(a3.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30);
- ARM_DOT(a3.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31);
- ARM_DOT(a3.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32);
- ARM_DOT(a3.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33);
- ARM_DOT(a3.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34);
- ARM_DOT(a3.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35);
- ARM_DOT(a3.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36);
- ARM_DOT(a3.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
- b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
- b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
- b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
- b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
- src_addr.s1 += 4 * src1_stride_y;
-
- ARM_DOT(a0.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00);
- ARM_DOT(a0.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01);
- ARM_DOT(a0.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02);
- ARM_DOT(a0.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03);
- ARM_DOT(a0.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04);
- ARM_DOT(a0.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05);
- ARM_DOT(a0.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06);
- ARM_DOT(a0.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07);
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- ARM_DOT(a1.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10);
- ARM_DOT(a1.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11);
- ARM_DOT(a1.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12);
- ARM_DOT(a1.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13);
- ARM_DOT(a1.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14);
- ARM_DOT(a1.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15);
- ARM_DOT(a1.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16);
- ARM_DOT(a1.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- ARM_DOT(a2.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20);
- ARM_DOT(a2.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21);
- ARM_DOT(a2.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22);
- ARM_DOT(a2.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23);
- ARM_DOT(a2.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24);
- ARM_DOT(a2.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25);
- ARM_DOT(a2.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26);
- ARM_DOT(a2.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- ARM_DOT(a3.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30);
- ARM_DOT(a3.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31);
- ARM_DOT(a3.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32);
- ARM_DOT(a3.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33);
- ARM_DOT(a3.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34);
- ARM_DOT(a3.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35);
- ARM_DOT(a3.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36);
- ARM_DOT(a3.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
- src_addr.s0 += 8;
- }
-
- for(; i < (int)COLS_A; ++i)
- {
-#if defined(REINTERPRET_INPUT_AS_3D)
- // Load values from matrix A
- uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#else // defined(REINTERPRET_INPUT_AS_3D)
- // Load values from matrix A
- uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif // defined(REINTERPRET_INPUT_AS_3D)
-
- // Load values from matrix B
- uchar8 b0 = vload8(0, src1_ptr + src_addr.s1);
- src_addr.s1 += src1_stride_y;
-
- acc00 += (uint)a0 * b0.s0;
- acc01 += (uint)a0 * b0.s1;
- acc02 += (uint)a0 * b0.s2;
- acc03 += (uint)a0 * b0.s3;
- acc04 += (uint)a0 * b0.s4;
- acc05 += (uint)a0 * b0.s5;
- acc06 += (uint)a0 * b0.s6;
- acc07 += (uint)a0 * b0.s7;
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- acc10 += (uint)a1 * b0.s0;
- acc11 += (uint)a1 * b0.s1;
- acc12 += (uint)a1 * b0.s2;
- acc13 += (uint)a1 * b0.s3;
- acc14 += (uint)a1 * b0.s4;
- acc15 += (uint)a1 * b0.s5;
- acc16 += (uint)a1 * b0.s6;
- acc17 += (uint)a1 * b0.s7;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- acc20 += (uint)a2 * b0.s0;
- acc21 += (uint)a2 * b0.s1;
- acc22 += (uint)a2 * b0.s2;
- acc23 += (uint)a2 * b0.s3;
- acc24 += (uint)a2 * b0.s4;
- acc25 += (uint)a2 * b0.s5;
- acc26 += (uint)a2 * b0.s6;
- acc27 += (uint)a2 * b0.s7;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- acc30 += (uint)a3 * b0.s0;
- acc31 += (uint)a3 * b0.s1;
- acc32 += (uint)a3 * b0.s2;
- acc33 += (uint)a3 * b0.s3;
- acc34 += (uint)a3 * b0.s4;
- acc35 += (uint)a3 * b0.s5;
- acc36 += (uint)a3 * b0.s6;
- acc37 += (uint)a3 * b0.s7;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
- src_addr.s0 += 1;
- }
-
- int z = get_global_id(2);
-
- // Compute destination address
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-
- // Compute dst address
- __global uchar *dst_addr = dst.ptr;
-
-#if defined(REINTERPRET_OUTPUT_AS_3D)
- // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
- // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
- uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
- zout = min(DEPTH_GEMM3D - 1, zout);
-
- // Add offset due to the cross plane paddings
- zout *= (dst_cross_plane_pad * dst_stride_y);
-
- // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
- // multiply dst_stride_z by DEPTH_GEMM3D
- dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
-
- // Store the result
- vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0));
- vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1));
- vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2));
- vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3));
- vstore4((int4)(acc34, acc35, acc36, acc37), 1, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
-#else // defined(REINTERPRET_OUTPUT_AS_3D)
- // Add offset for batched GEMM
- dst_addr += z * dst_stride_z;
-
- // Store the result
- vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y));
- vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y));
- vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y));
- vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y));
- vstore4((int4)(acc34, acc35, acc36, acc37), 1, (__global int *)(dst_addr + 3 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
-}
-#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
#endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(M) && defined(N)
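For reference, the deleted dot8 kernel above (per-thread uint accumulators over QASYMM8 inputs, plus the scalar leftover loop over the remaining columns of A) computes a plain lowp GEMM with 32-bit accumulation. A minimal C++ sketch of the same arithmetic, with a hypothetical helper name and row-major layouts assumed:

#include <cstdint>
#include <vector>

// Hypothetical reference, not library code: C[m x n] = A[m x k] * B[k x n],
// accumulating uchar products into 32-bit integers as the kernel does.
std::vector<int32_t> gemmlowp_reference(const std::vector<uint8_t> &a,
                                        const std::vector<uint8_t> &b,
                                        int m, int n, int k)
{
    std::vector<int32_t> c(m * n, 0);
    for(int y = 0; y < m; ++y)
    {
        for(int x = 0; x < n; ++x)
        {
            uint32_t acc = 0; // one accumulator per (y, x), like the accYX variables above
            for(int i = 0; i < k; ++i)
            {
                acc += static_cast<uint32_t>(a[y * k + i]) * static_cast<uint32_t>(b[i * n + x]);
            }
            c[y * n + x] = static_cast<int32_t>(acc);
        }
    }
    return c;
}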
diff --git a/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp
new file mode 100644
index 0000000000..e6423175a5
--- /dev/null
+++ b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp
@@ -0,0 +1,245 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.h"
+
+#include "arm_compute/core/CL/CLHelpers.h"
+#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/gemm/CLGEMMHelpers.h"
+#include "arm_compute/core/GPUTarget.h"
+
+#include <map>
+#include <utility>
+
+namespace arm_compute
+{
+namespace cl_gemm
+{
+CLGEMMNativeKernelConfigurationBifrost::CLGEMMNativeKernelConfigurationBifrost(GPUTarget arch)
+ : ICLGEMMKernelConfiguration(arch)
+{
+}
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
+{
+ ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8);
+ ARM_COMPUTE_UNUSED(data_type);
+
+ using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMNativeKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
+ unsigned int b);
+
+ // Configurations for Mali-G71
+ static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G71 =
+ {
+ { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_f32 },
+ { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 }
+ };
+
+ // Configurations for Mali-G76
+ static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
+ {
+ { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_f32 },
+ { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 }
+ };
+
+ // Default configurations
+ static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_default =
+ {
+ { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_default_f32 },
+ { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 }
+ };
+
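+    // Dispatch on the concrete GPU: G71 and G76 have tuned tables; other Bifrost parts use the defaults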
+ switch(_target)
+ {
+ case GPUTarget::G71:
+ return (this->*gemm_configs_G71[data_type])(m, n, k, b);
+ case GPUTarget::G76:
+ return (this->*gemm_configs_G76[data_type])(m, n, k, b);
+ default:
+ return (this->*gemm_configs_default[data_type])(m, n, k, b);
+ }
+}
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_G71_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
+ if(m == 1)
+ {
+ if(n < 2048)
+ {
+ return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1, false, false, false, false);
+ }
+ else if(n >= 2048 && n < 8192)
+ {
+ return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 1, false, false, false, false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 1, false, false, false, false);
+ }
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 5, 4, 2, 1, 1, false, false, false, false);
+ }
+}
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
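+    // The tuned tiles differ depending on whether the device exposes the Arm 8-bit dot product (DOT8) extension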
+ if(dot8_supported(CLKernelLibrary::get().get_device()))
+ {
+ if(m == 1)
+ {
+ if(n < 2048)
+ {
+ return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 1, false, false, false, false);
+ }
+ else if(n >= 2048 && n < 16384)
+ {
+ return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false);
+ }
+ }
+ else
+ {
+ if(m < 64)
+ {
+ return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 1, false, false, false, false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false);
+ }
+ }
+ }
+ else
+ {
+ if(m == 1)
+ {
+ if(n < 8192)
+ {
+ return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false);
+ }
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 2, 8, 16, 1, 1, false, false, false, false);
+ }
+ }
+}
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
+ if(m == 1)
+ {
+ if(n > 4196)
+ {
+ return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 1, false, false, false, false);
+ }
+ else
+ {
+ if(k < 2048)
+ {
+ return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 1, false, false, false, false);
+ }
+ else if(k >= 2048 && k < 16384)
+ {
+ return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1, false, false, false, false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 1, false, false, false, false);
+ }
+ }
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 2, 8, 2, 1, 1, false, false, false, false);
+ }
+}
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
+ if(m == 1)
+ {
+ if(n < 2048)
+ {
+ return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 1, false, false, false, false);
+ }
+ else if(n >= 2048 && n < 16384)
+ {
+ return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false);
+ }
+ }
+ else
+ {
+ if(m < 64)
+ {
+ return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 1, false, false, false, false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false);
+ }
+ }
+}
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
+ return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 1, false, false, false, false);
+}
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure_default_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
+ return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false);
+}
+} // namespace cl_gemm
+} // namespace arm_compute
\ No newline at end of file
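A short usage sketch of the new heuristic, assuming the CLGEMMNativeKernelConfigurationFactory wiring added later in this patch; the shapes are illustrative only:

// Query the native-kernel heuristic for a hypothetical G76 QASYMM8 GEMM.
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(GPUTarget::G76)->configure(64, 512, 128, 1, DataType::QASYMM8);
// m = 64 falls into the last branch of configure_G76_u8: a 5x2 tile with k0 = 16.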
diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
index 1a1a4b7c3d..cda7a83de7 100644
--- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
@@ -55,63 +55,38 @@ namespace
{
using ElementsProcessed = Steps;
-Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
+Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, const GEMMReshapeInfo &gemm_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 2 && reshape_info.reinterpret_input_as_3d(), "The input1 tensor cannot have more than 2 dimensions if input0 has to be reinterpreted as 3D");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 2 && gemm_info.reinterpret_input_as_3d(), "The input1 tensor cannot have more than 2 dimensions if input0 has to be reinterpreted as 3D");
- if(!is_interleaved_transposed)
+ const int m = gemm_info.m();
+ const int n = gemm_info.n();
+ const int k = gemm_info.k();
+
+ ARM_COMPUTE_UNUSED(m);
+ ARM_COMPUTE_UNUSED(n);
+ ARM_COMPUTE_UNUSED(k);
+
+ ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != static_cast<unsigned int>(k));
+ ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(0) != static_cast<unsigned int>(n));
+ ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(1) != static_cast<unsigned int>(k));
+ if(gemm_info.reinterpret_input_as_3d())
{
- ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));
+ ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) * input0->dimension(2) != static_cast<unsigned int>(m));
}
else
{
- GEMMRHSMatrixInfo rhs_info;
- GEMMLHSMatrixInfo lhs_info;
- const int m = reshape_info.m();
- const int n = reshape_info.n();
- const int k = reshape_info.k();
- const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width();
- const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();
- const bool unroll_block = dot8_supported(CLKernelLibrary::get().get_device());
-
- rhs_info.n0 = 16 / input1->element_size();
- rhs_info.k0 = 1;
- rhs_info.h0 = mult_transpose1xW_width;
- rhs_info.interleave = false;
- rhs_info.transpose = false;
- lhs_info.m0 = 4;
- lhs_info.k0 = 4;
- lhs_info.v0 = mult_interleave4x4_height;
- lhs_info.interleave = true;
- lhs_info.transpose = !unroll_block;
-
- TensorShape tensor_shape0{ input0->tensor_shape() };
- tensor_shape0.set(0, k);
- tensor_shape0.set(1, m);
-
- TensorShape tensor_shape1{ input1->tensor_shape() };
- tensor_shape1.set(0, n);
- tensor_shape1.set(1, k);
-
- const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0);
- const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1);
-
- const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(compute_lhs_reshaped_shape(tensor_info0, lhs_info));
- const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(compute_rhs_reshaped_shape(tensor_info1, rhs_info));
-
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1);
+ ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != static_cast<unsigned int>(m));
}
if(output->total_size() != 0)
{
- const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info));
+ const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, false, gemm_info));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
}
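The reworked checks above pin input0 to [k, m] (or [k, H, W] with H * W == m when reinterpreted as 3D) and input1 to [n, k]. A small self-contained sketch with hypothetical shapes, not library code:

#include <cassert>

int main()
{
    // Hypothetical 3D-reinterpreted case: input0 is [K, H, W] = [8, 4, 6].
    const unsigned int m = 24, n = 64, k = 8;
    const unsigned int in0[3] = { 8, 4, 6 }; // dimension(0..2) of input0
    const unsigned int in1[2] = { 64, 8 };   // dimension(0..1) of input1
    assert(in0[0] == k);                     // columns of A must match k
    assert(in0[1] * in0[2] == m);            // 3D case: H * W must match m
    assert(in1[0] == n && in1[1] == k);      // B must be [N, K]
    return 0;
}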
@@ -119,14 +94,12 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output, bool is_interleaved_transposed,
- const GEMMReshapeInfo &reshape_info, ElementsProcessed &num_elements_processed)
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output, const GEMMReshapeInfo &gemm_info, ElementsProcessed &num_elements_processed)
{
- const bool is_dot8_supported = dot8_supported(CLKernelLibrary::get().get_device());
unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0];
unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1];
- bool reinterpret_input_as_3d = reshape_info.reinterpret_input_as_3d();
- bool reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 0);
+ bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+ bool reinterpret_output_as_3d = (gemm_info.depth_output_gemm3d() != 0);
Window win{};
Window win_out{};
@@ -141,7 +114,7 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe
}
    // Output tensor auto initialization if not yet initialized
- auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info)).set_data_type(DataType::S32));
+ auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, false, gemm_info)).set_data_type(DataType::S32));
TensorInfo tmp_info(*output);
@@ -154,66 +127,32 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe
tmp_info.set_tensor_shape(tmp_shape);
}
- // Check if the output tensor is a vector. If so,the kernel runs the vector-matrix multiplication
- if(is_interleaved_transposed)
- {
- // reinterpret_input_as_3d is not supported if is_interleaved_transposed is set
- ARM_COMPUTE_ERROR_ON(reshape_info.reinterpret_input_as_3d());
+ // Special case for 1xN, 2xN, 3xN and 4xN input0 tensors: clamp the tile height to the number of output rows
+ // Note: this kernel no longer has a dot product variant, so the tile is fixed to 4x4
+ num_elems_processed_per_iteration_x = 4;
+ num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->dimension(1)), 4);
- // Configure kernel window
- num_elems_processed_per_iteration_x = 4;
- num_elems_processed_per_iteration_y = 4;
+ // Note: bottom paddings are calculated manually as the output can be reinterpreted as a 3D tensor
+ // The only way to set the paddings properly is to set them explicitly through AccessWindowStatic
+ const int m = reinterpret_input_as_3d ? input0->tensor_shape()[1] * input0->tensor_shape()[2] : input0->tensor_shape()[1];
+ const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
- // Note: bottom paddings are calculated manually as the output can be reinterpreted as 3D tensor
- // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic
- const int m = reshape_info.m();
- const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
+ // Configure window
+ win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
- win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
- win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ AccessWindowStatic input0_access(input0, 0, 0, input0->dimension(0), input0->dimension(1) + bottom_pad);
+ AccessWindowStatic input1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
+ AccessWindowStatic output_access(output, 0, 0,
+ ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x),
+ output->dimension(1) + bottom_pad);
- AccessWindowRectangle input0_access(input0, 0, 0, num_elems_processed_per_iteration_y, 1, 1.f, 0.25f);
- AccessWindowStatic input1_access(input1, 0, 0,
- ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x),
- ceil_to_multiple(input1->dimension(1), num_elems_processed_per_iteration_y));
- AccessWindowStatic output_access(output, 0, 0,
- ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x),
- output->dimension(1) + bottom_pad);
+ window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
+ update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor
- window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
- update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor
-
- output_access.set_valid_region(win_out, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
- }
- else
- {
- // Special case for 1xN, 2xN, 3xN and 4xN input0 tensor. num_elems_processed_per_iteration_x
- // Note: if the dot product instruction is available, the 8x2 tile has to be used
- num_elems_processed_per_iteration_x = is_dot8_supported ? 8 : 4;
- num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->dimension(1)), is_dot8_supported ? 2 : 4);
-
- // Note: bottom paddings are calculated manually as the output can be reinterpreted as 3D tensor
- // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic
- const int m = reinterpret_input_as_3d ? input0->tensor_shape()[1] * input0->tensor_shape()[2] : input0->tensor_shape()[1];
- const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
-
- // Configure window
- win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
- win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
-
- AccessWindowStatic input0_access(input0, 0, 0, input0->dimension(0), input0->dimension(1) + bottom_pad);
- AccessWindowStatic input1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
- AccessWindowStatic output_access(output, 0, 0,
- ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x),
- output->dimension(1) + bottom_pad);
-
- window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
- update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor
-
- Coordinates coord;
- coord.set_num_dimensions(output->num_dimensions());
- output_access.set_valid_region(win_out, ValidRegion(coord, output->tensor_shape()));
- }
+ Coordinates coord;
+ coord.set_num_dimensions(output->num_dimensions());
+ output_access.set_valid_region(win_out, ValidRegion(coord, output->tensor_shape()));
// Collapse along the Z direction
// This collapse needs to be here in order to tune the Z dimension of LWS
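The bottom_pad computed above rounds the number of processed rows up to a multiple of the tile height, so the static access windows cover the rows the last tile touches. A self-checking sketch of the formula, using a hypothetical helper name:

// Hypothetical helper mirroring the bottom_pad expression in the hunk above.
constexpr int bottom_pad(int m, int rows_per_iter)
{
    return (rows_per_iter - (m % rows_per_iter)) % rows_per_iter;
}
static_assert(bottom_pad(10, 4) == 2, "two extra rows complete the last 4-row tile");
static_assert(bottom_pad(8, 4) == 0, "no padding when m is already a multiple of 4");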
@@ -231,17 +170,17 @@ CLGEMMLowpMatrixMultiplyKernel::CLGEMMLowpMatrixMultiplyKernel()
{
}
-void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
+void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, const GEMMReshapeInfo &gemm_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), gemm_info));
_input0 = input0;
_input1 = input1;
_output = output;
- _reinterpret_input_as_3d = reshape_info.reinterpret_input_as_3d();
- _reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 0);
+ _reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+ _reinterpret_output_as_3d = (gemm_info.depth_output_gemm3d() != 0);
// In case both input and output have to be reinterpreted as 3D tensors,
// force reinterpret_input_as_3d and reinterpret_output_as_3d to be false.
@@ -257,16 +196,11 @@ void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const IC
ElementsProcessed num_elements_processed{};
- // Get target architecture
- GPUTarget arch_target = get_arch_from_target(get_target());
-
// Configure kernel window
- auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info, num_elements_processed);
+ auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info(), gemm_info, num_elements_processed);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
ICLKernel::configure_internal(win_config.second);
- const bool is_dot8_supported = dot8_supported(CLKernelLibrary::get().get_device());
-
// Create build options
std::string kernel_name(" ");
CLBuildOptions build_opts;
@@ -275,38 +209,18 @@ void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const IC
build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2)));
build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2)));
+ build_opts.add_option("-DCOLS_A=" + support::cpp11::to_string(input0->info()->dimension(0)));
+ build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elements_processed.x()));
+ build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=" + support::cpp11::to_string(num_elements_processed.y()));
- if(is_interleaved_transposed)
- {
- const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width();
- const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();
-
- // Note: The computation tile has the x dimension equal to 4 which is less than the transpose_width (16)
- // In order to access correctly the elements from the transposed matrix B, we need to pass
- // the correct step which is calculated as (16 * mult_transpose1xW_width) / 4)
-
- build_opts.add_option("-DCOLS_B=" + support::cpp11::to_string(input1->info()->dimension(0)));
- build_opts.add_option("-DMULT_TRANSPOSE1XW_WIDTH=" + support::cpp11::to_string(mult_transpose1xW_width));
- build_opts.add_option("-DTRANSPOSE1XW_WIDTH_STEP=" + support::cpp11::to_string(4 * mult_transpose1xW_width));
- build_opts.add_option("-DMULT_INTERLEAVE4X4_HEIGHT=" + support::cpp11::to_string(mult_interleave4x4_height));
-
- kernel_name = "gemmlowp_mm_interleaved_transposed_" + string_from_target(arch_target) + (is_dot8_supported ? "_dot8" : "");
- }
- else
- {
- build_opts.add_option("-DCOLS_A=" + support::cpp11::to_string(input0->info()->dimension(0)));
- build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elements_processed.x()));
- build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=" + support::cpp11::to_string(num_elements_processed.y()));
-
- kernel_name = "gemmlowp_mm_" + string_from_target(arch_target) + (is_dot8_supported ? "_dot8" : "");
- }
+ kernel_name = "gemmlowp_mm_midgard";
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
// Set config_id for enabling LWS tuning
- _config_id = "gemmlowp_";
- _config_id += (is_interleaved_transposed ? "reshaped_" : "");
+ _config_id = kernel_name;
+ _config_id += "_";
_config_id += (_reinterpret_input_as_3d ? "3di_" : "");
_config_id += (_reinterpret_output_as_3d ? "3do_" : "");
_config_id += lower_string(string_from_data_type(input0->info()->data_type()));
@@ -314,19 +228,16 @@ void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const IC
_config_id += support::cpp11::to_string(output->info()->dimension(1));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(0));
- _config_id += "_";
- _config_id += (is_interleaved_transposed ? support::cpp11::to_string(input1->info()->dimension(0)) : support::cpp11::to_string(input1->info()->dimension(1)));
}
-Status CLGEMMLowpMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
+Status CLGEMMLowpMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, const GEMMReshapeInfo &gemm_info)
{
ElementsProcessed num_elements_processed{};
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, is_interleaved_transposed, reshape_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, gemm_info));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(),
input1->clone().get(),
output->clone().get(),
- is_interleaved_transposed,
- reshape_info,
+ gemm_info,
num_elements_processed)
.first);
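With the interleaved-transposed path gone, the program built by configure() is fully described by the fixed kernel name and three defines. A hedged sketch with example values (a 128-column matrix A and the fixed 4x4 tile); the API calls are the ones used in the hunk above:

// Illustrative values only; configure() derives them from the tensors at runtime.
CLBuildOptions build_opts;
build_opts.add_option("-DCOLS_A=128");
build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=4");
build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=4");
cl::Kernel kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemmlowp_mm_midgard", build_opts.options()));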
diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp
index fa2c544899..4bcfa82ca7 100644
--- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp
@@ -63,7 +63,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.k0 != rhs_info.k0);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(((lhs_info.k0 & (lhs_info.k0 - 1)) && lhs_info.k0 != 3), "Only 2,3,4,8,16 are supported for k0");
ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.k0 > 16);
- ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.m0 < 2 || lhs_info.m0 > 8);
+ ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.m0 < 1 || lhs_info.m0 > 8);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(((rhs_info.n0 & (rhs_info.n0 - 1)) && rhs_info.n0 != 3), "Only 2,3,4,8,16 are supported for n0");
const int m = gemm_info.m();
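Relaxing the lower bound to m0 >= 1 matters because the new native heuristic emits one-row tiles for matrix-vector (m == 1) cases; for example, configure_G71_u8 above returns m0 = 1, n0 = 2, k0 = 16 for m == 1 and n < 2048. Illustrative values the check must now accept:

GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
lhs_info.m0 = 1;  // one output row per thread; rejected before this change
lhs_info.k0 = 16;
rhs_info.n0 = 2;
rhs_info.k0 = 16; // k0 must match between LHS and RHS, as checked above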
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 875e3a2a00..0286cb3d6d 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -24,6 +24,7 @@
#include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h"
#include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/CL/gemm/native/CLGEMMNativeKernelConfiguration.h"
#include "arm_compute/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
@@ -48,7 +49,8 @@ inline bool is_gemm_reshaped(bool reshape_b_only_on_first_run, GPUTarget gpu_tar
CLGEMMLowpMatrixMultiplyCore::CLGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_group(std::move(memory_manager)),
- _mm_kernel(),
+ _mm_midgard_kernel(),
+ _mm_native_kernel(),
_mm_reshaped_only_rhs_kernel(),
_mtx_b_reshape_kernel(),
_mtx_a_reduction_kernel(),
@@ -63,6 +65,7 @@ CLGEMMLowpMatrixMultiplyCore::CLGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemo
_a_offset(0),
_b_offset(0),
_is_gemm_reshaped(true),
+ _is_midgard(false),
_reshape_b_only_on_first_run(false),
_is_prepared(false),
_fuse_output_stage(false)
@@ -84,7 +87,9 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
const GPUTarget gpu_target = CLScheduler::get().target();
// Set the target for the kernels
- _mm_kernel.set_target(gpu_target);
+ _mm_midgard_kernel.set_target(gpu_target);
+ _mm_native_kernel.set_target(gpu_target);
+ _mm_reshaped_only_rhs_kernel.set_target(gpu_target);
const ICLTensor *matrix_a = a;
const ICLTensor *matrix_b = b;
@@ -103,6 +108,7 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
// Check if we need to reshape the matrix A and matrix B
_is_gemm_reshaped = is_gemm_reshaped(_reshape_b_only_on_first_run, gpu_target);
+ _is_midgard = gpu_target == GPUTarget::MIDGARD;
if(_is_gemm_reshaped)
{
@@ -159,8 +165,19 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
}
else
{
- // Configure matrix multiply kernel
- _mm_kernel.configure(matrix_a, matrix_b, &_mm_result_s32, false, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+ if(_is_midgard)
+ {
+ // Configure matrix multiply kernel
+ _mm_midgard_kernel.configure(matrix_a, matrix_b, &_mm_result_s32, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+ }
+ else
+ {
+ // Pick up the GEMM configuration
+ std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
+
+ // Configure matrix multiply kernel
+ _mm_native_kernel.configure(matrix_a, matrix_b, &_mm_result_s32, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+ }
}
// Configure offset contribution kernel
@@ -178,8 +195,19 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
}
else
{
- // Configure matrix multiply kernel
- _mm_kernel.configure(matrix_a, matrix_b, output, false, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+ if(_is_midgard)
+ {
+ // Configure matrix multiply kernel
+ _mm_midgard_kernel.configure(matrix_a, matrix_b, output, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+ }
+ else
+ {
+ // Pick up the GEMM configuration
+ std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
+
+ // Configure matrix multiply kernel
+ _mm_native_kernel.configure(matrix_a, matrix_b, output, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+ }
}
// Configure offset contribution kernel
@@ -232,6 +260,7 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
const unsigned int k = a->dimension(0);
const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
+ const bool is_midgard = gpu_target == GPUTarget::MIDGARD;
bool reshape_matrix_b = is_gemm_reshaped(gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target());
@@ -287,9 +316,21 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
    // Output tensor auto initialization if not yet initialized
auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, false, reshape_info)).set_data_type(DataType::S32));
- // Validate matrix multiply
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, false, reshape_info));
+ if(is_midgard)
+ {
+ // Validate matrix multiply
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, reshape_info));
+ }
+ else
+ {
+ // Pick up the GEMM configuration
+ std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
+
+ // Validate matrix multiply
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyNativeKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, lhs_info, rhs_info, reshape_info));
+ }
}
+
// Validate offset contribution kernel
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOffsetContributionOutputStageKernel::validate(&mm_result_s32_info,
a_offset == 0 ? nullptr : &info_vector_sum_col,
@@ -308,9 +349,21 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
}
else
{
- // Validate matrix multiply
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, false, reshape_info));
+ if(is_midgard)
+ {
+ // Validate matrix multiply
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, reshape_info));
+ }
+ else
+ {
+ // Pick up the GEMM configuration
+ std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
+
+ // Validate matrix multiply
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyNativeKernel::validate(matrix_a_info, matrix_b_info, output, lhs_info, rhs_info, reshape_info));
+ }
}
+
if(output->total_size() != 0)
{
// Validate offset contribution kernel
@@ -353,7 +406,14 @@ void CLGEMMLowpMatrixMultiplyCore::run()
}
else
{
- CLScheduler::get().enqueue(_mm_kernel, false);
+ if(_is_midgard)
+ {
+ CLScheduler::get().enqueue(_mm_midgard_kernel, false);
+ }
+ else
+ {
+ CLScheduler::get().enqueue(_mm_native_kernel, false);
+ }
}
// Run matrix A reduction kernel only if _b_offset is not equal to 0
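Taken together, the lowp core now dispatches one of three matrix-multiply kernels at run time; a condensed C++ sketch of the selection implemented across this patch (not a verbatim copy of run()):

if(_is_gemm_reshaped)
{
    CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, false); // targets where is_gemm_reshaped() holds
}
else if(_is_midgard)
{
    CLScheduler::get().enqueue(_mm_midgard_kernel, false);           // legacy Midgard kernel
}
else
{
    CLScheduler::get().enqueue(_mm_native_kernel, false);            // native kernel, tiles from the new heuristic
}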