aboutsummaryrefslogtreecommitdiff
path: root/src/core/GLES_COMPUTE/kernels
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2018-04-26 10:24:30 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:52:35 +0000
commitf6f08dac6d57770c191d1bc77123f0ddd2363d3f (patch)
tree3409794e82c069398fb6eaf74f5fbce645adc2c9 /src/core/GLES_COMPUTE/kernels
parenta4244190b6c7dc7d30d6adc621ca9a8b84b677ee (diff)
downloadComputeLibrary-f6f08dac6d57770c191d1bc77123f0ddd2363d3f.tar.gz
COMPMID-1044: Optimizing GCGEMM - Support for not reshaped GEMM on GLES
Change-Id: I22fe80393ec70e4501a4f9f9cad14014029d035d Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129134 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'src/core/GLES_COMPUTE/kernels')
-rw-r--r--src/core/GLES_COMPUTE/kernels/GCGEMMMatrixMultiplyKernel.cpp25
1 files changed, 17 insertions, 8 deletions
diff --git a/src/core/GLES_COMPUTE/kernels/GCGEMMMatrixMultiplyKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCGEMMMatrixMultiplyKernel.cpp
index 2bd769cac4..d576c30f80 100644
--- a/src/core/GLES_COMPUTE/kernels/GCGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCGEMMMatrixMultiplyKernel.cpp
@@ -52,12 +52,13 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i
{
ARM_COMPUTE_UNUSED(reshape_info);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3");
if(!is_interleaved_transposed)
{
- ARM_COMPUTE_ERROR_ON(input0->dimension(0) != input1->dimension(1));
+ ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));
if(output->total_size() != 0)
{
@@ -141,18 +142,17 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
}
else // The input tensors have not been reshaped
{
- // Special case for 1xN, 2xN, 3xN and 4xN input0 tensor
+ // Special case for 1xN, 2xN, 3xN and 4xN input0 tensor.
+ num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->dimension(1)), 4);
switch(data_type)
{
case DataType::F16:
num_elems_processed_per_iteration_x = 4;
- num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->dimension(1)), 4);
break;
case DataType::F32:
num_elems_processed_per_iteration_x = max_gc_vector_width / data_size_from_type(data_type);
- num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->dimension(1)), 4);
break;
default:
@@ -207,7 +207,6 @@ void GCGEMMMatrixMultiplyKernel::configure(const IGCTensor *input0, const IGCTen
// Create build options
std::set<std::string> build_opts;
std::string kernel_name;
- Window win;
build_opts.emplace("#define LOCAL_SIZE_X " + support::cpp11::to_string(1));
build_opts.emplace("#define LOCAL_SIZE_Y " + support::cpp11::to_string(1));
@@ -248,15 +247,26 @@ void GCGEMMMatrixMultiplyKernel::configure(const IGCTensor *input0, const IGCTen
{
// Special case for 1xN, 2xN, 3xN and 4xN input0 tensor
+ GPUTarget arch_target = get_arch_from_target(gpu_target);
switch(input0->info()->data_type())
{
case DataType::F16:
build_opts.emplace("#define DATA_TYPE_FP16");
build_opts.emplace("#define MM_PROCESS_4X_OPTIMIZED");
+ build_opts.emplace("#define GEMM_MM_FLOATING_POINT");
break;
case DataType::F32:
build_opts.emplace("#define DATA_TYPE_FP32");
+
+ if(arch_target == GPUTarget::BIFROST && input0->info()->num_dimensions() != 1)
+ {
+ build_opts.emplace("#define GEMM_MM_FLOATING_POINT_BIFROST");
+ }
+ else
+ {
+ build_opts.emplace("#define GEMM_MM_FLOATING_POINT");
+ }
break;
default:
@@ -264,7 +274,6 @@ void GCGEMMMatrixMultiplyKernel::configure(const IGCTensor *input0, const IGCTen
break;
}
- build_opts.emplace("#define GEMM_MM_FLOATING_POINT");
build_opts.emplace("#define NUM_ELEMS_PROCESSED_PER_THREAD_X " + support::cpp11::to_string(num_elements_processed.x()));
build_opts.emplace("#define NUM_ELEMS_PROCESSED_PER_THREAD_Y " + support::cpp11::to_string(num_elements_processed.y()));