about | summary | refs | log | tree | commit | diff
path: root/arm_compute/runtime
diff options
context:
space:
mode:
author: Gian Marco Iodice <gianmarco.iodice@arm.com> 2020-08-28 13:52:12 +0100
committer: Gian Marco Iodice <gianmarco.iodice@arm.com> 2020-08-28 16:28:34 +0000
commit: 026d04576d3058e34f8f7e23f7a11514a04952dc (patch)
tree: cc743de9348479a754b3eee0bb4f71444bb5ed1e /arm_compute/runtime
parent: 1c76c1ddcd1294ee8149bd74ecf6f62963408286 (diff)
download: ComputeLibrary-026d04576d3058e34f8f7e23f7a11514a04952dc.tar.gz
COMPMID-3770: Add batch size in the OpenCL GEMM kernel selection
Change-Id: Ia3030ea701e9ceb2ef567e0258e8f478e18b8b55
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3871
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/runtime')
-rw-r--r-- arm_compute/runtime/CL/CLTypes.h 1
-rw-r--r-- arm_compute/runtime/CL/functions/CLGEMM.h 2
-rw-r--r-- arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h 10
-rw-r--r-- arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h 6
-rw-r--r-- arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h 6
5 files changed, 13 insertions, 12 deletions
diff --git a/arm_compute/runtime/CL/CLTypes.h b/arm_compute/runtime/CL/CLTypes.h
index cbc525308f..19095a5589 100644
--- a/arm_compute/runtime/CL/CLTypes.h
+++ b/arm_compute/runtime/CL/CLTypes.h
@@ -53,6 +53,7 @@ struct CLGEMMKernelSelectionParams
unsigned int m{ 0 }; /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */
unsigned int n{ 0 }; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */
unsigned int k{ 0 }; /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */
+ unsigned int b{ 0 }; /**< Batch size */
bool is_rhs_constant{ false }; /**< True if the content of the rhs matrix is constant */
DataType data_type{ DataType::UNKNOWN }; /**< Data type */
};
diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h
index 8e4d3906d1..6e9cf0e2ca 100644
--- a/arm_compute/runtime/CL/functions/CLGEMM.h
+++ b/arm_compute/runtime/CL/functions/CLGEMM.h
@@ -185,7 +185,7 @@ public:
void prepare() override;
private:
- static CLGEMMKernelType select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run);
+ static CLGEMMKernelType select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run);
void configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
void configure_reshaped_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
diff --git a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h
index 815c2c8cef..579bbe32ad 100644
--- a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h
+++ b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h
@@ -44,11 +44,11 @@ public:
CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams &params) override;
private:
- CLGEMMKernelType g76_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
- CLGEMMKernelType g71_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
- CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
- CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
- CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
+ CLGEMMKernelType g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+ CLGEMMKernelType g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+ CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+ CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+ CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
};
} // namespace cl_gemm
} // namespace arm_compute
diff --git a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h
index 4689f0c041..5547731821 100644
--- a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h
+++ b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h
@@ -44,9 +44,9 @@ public:
CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams &params) override;
private:
- CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
- CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
- CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
+ CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+ CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+ CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
};
} // namespace cl_gemm
} // namespace arm_compute
diff --git a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h
index 8712be7531..782ef7474d 100644
--- a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h
+++ b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h
@@ -44,9 +44,9 @@ public:
CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams &params) override;
private:
- CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
- CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
- CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
+ CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+ CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+ CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
};
} // namespace cl_gemm
} // namespace arm_compute