about | summary | refs | log | tree | commit | diff
path: root/src/core/CL/gemm/CLGEMMHelpers.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/gemm/CLGEMMHelpers.cpp')
-rw-r--r-- src/core/CL/gemm/CLGEMMHelpers.cpp 25
1 files changed, 23 insertions, 2 deletions
diff --git a/src/core/CL/gemm/CLGEMMHelpers.cpp b/src/core/CL/gemm/CLGEMMHelpers.cpp
index 877bf1e047..d60626b158 100644
--- a/src/core/CL/gemm/CLGEMMHelpers.cpp
+++ b/src/core/CL/gemm/CLGEMMHelpers.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/ITensorInfo.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include <utility>
@@ -34,11 +35,13 @@ namespace arm_compute
{
namespace cl_gemm
{
+using namespace arm_compute::misc::shape_calculator;
+
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0,
bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
{ // Builds the (LHS info, RHS info) pair from the given block sizes and flags, after limiting v0 and h0 below.
- v0 = ((m / (m0 * v0)) == 0) ? 1 : v0;
- h0 = ((n / (n0 * h0)) == 0) ? 1 : h0;
+ v0 = std::max(std::min(static_cast<int>(m / m0), static_cast<int>(v0)), static_cast<int>(1)); // Limit v0 to at most m / m0, with a floor of 1. NOTE(review): m0 == 0 would divide by zero, and the unsigned-to-int casts assume values fit in int - confirm with callers.
+ h0 = std::max(std::min(static_cast<int>(n / n0), static_cast<int>(h0)), static_cast<int>(1)); // Same limiting for h0 against n / n0. NOTE(review): std::max/std::min need <algorithm>; its include is not visible in this hunk - confirm.
const GEMMLHSMatrixInfo lhs_info(m0, k0, v0, lhs_transpose, lhs_interleave); // LHS info uses the limited v0.
const GEMMRHSMatrixInfo rhs_info(n0, k0, h0, rhs_transpose, rhs_interleave, export_to_cl_image); // RHS info uses the limited h0 and carries the export_to_cl_image flag.
@@ -46,6 +49,24 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_lhs_rhs_info(unsigned
return std::make_pair(lhs_info, rhs_info); // first = LHS info, second = RHS info.
}
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> select_lhs_rhs_info(std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> info_img,
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> info_buf,
+ unsigned int n, unsigned int k, unsigned int b, DataType data_type)
+{ // Returns info_img when the reshaped RHS tensor passes validate_image2d_support_on_rhs; otherwise returns info_buf.
+ const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, data_type); // RHS described with shape (n, k, b); the literal 1 is presumably the channel count - TODO confirm against TensorInfo.
+ const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, info_img.second); // Shape after reshaping with the first candidate's RHS info (info_img.second).
+ const TensorInfo tensor_reshaped_info(shape, 1, data_type); // Reshaped tensor description with the same data type.
+
+ if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, info_img.second))) // The validation result is converted to bool; true selects info_img.
+ {
+ return info_img;
+ }
+ else
+ {
+ return info_buf; // Validation did not succeed: fall back to the second candidate.
+ }
+}
+
void update_padding_for_cl_image(ITensorInfo *tensor)
{
constexpr unsigned int num_floats_per_pixel = 4;