diff options
Diffstat (limited to 'arm_compute')
5 files changed, 13 insertions, 12 deletions
diff --git a/arm_compute/runtime/CL/CLTypes.h b/arm_compute/runtime/CL/CLTypes.h index cbc525308f..19095a5589 100644 --- a/arm_compute/runtime/CL/CLTypes.h +++ b/arm_compute/runtime/CL/CLTypes.h @@ -53,6 +53,7 @@ struct CLGEMMKernelSelectionParams unsigned int m{ 0 }; /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */ unsigned int n{ 0 }; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */ unsigned int k{ 0 }; /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */ + unsigned int b{ 0 }; /**< Batch size */ bool is_rhs_constant{ false }; /**< True if the content of the rhs matrix is constant */ DataType data_type{ DataType::UNKNOWN }; /**< Data type */ }; diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h index 8e4d3906d1..6e9cf0e2ca 100644 --- a/arm_compute/runtime/CL/functions/CLGEMM.h +++ b/arm_compute/runtime/CL/functions/CLGEMM.h @@ -185,7 +185,7 @@ public: void prepare() override; private: - static CLGEMMKernelType select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run); + static CLGEMMKernelType select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run); void configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info); void configure_reshaped_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info); diff --git a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h index 815c2c8cef..579bbe32ad 100644 --- a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h +++ b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h @@ -44,11 +44,11 @@ public: CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams ¶ms) override; private: - CLGEMMKernelType g76_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType g71_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); + CLGEMMKernelType g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); }; } // namespace cl_gemm } // namespace arm_compute diff --git a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h index 4689f0c041..5547731821 100644 --- a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h +++ b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h @@ -44,9 +44,9 @@ public: CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams ¶ms) override; private: - CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); + CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); }; } // namespace cl_gemm } // namespace arm_compute diff --git a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h index 8712be7531..782ef7474d 100644 --- a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h +++ b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h @@ -44,9 +44,9 @@ public: CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams ¶ms) override; private: - CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); + CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); }; } // namespace cl_gemm } // namespace arm_compute |