diff options
author | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2018-04-19 12:05:08 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:51:17 +0000 |
commit | bb36a8efc1092f66798e3b880c55ec488021bb02 (patch) | |
tree | 62e0265d84575bc10496c84f4908ed27529166ea /src/runtime/CL/functions/CLGEMM.cpp | |
parent | 4dcb583c052e14f08809cc9ee420e690264e7bbe (diff) | |
download | ComputeLibrary-bb36a8efc1092f66798e3b880c55ec488021bb02.tar.gz |
COMPMID-922 - CLGEMM FP16 optimizations - part2
This patch improves of ~30 % GEMM fp16 when the reshape is required
The results have been reported at the following confluence page:
https://confluence.arm.com/display/MLENG/GEMM+FP16+performance%3A+ACL+18.05
Change-Id: I8233095a7e9ab06f1f915782a25dd41653b49140
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/128254
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL/functions/CLGEMM.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLGEMM.cpp | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index e735adba39..1ee51a0a48 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -32,6 +32,7 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" +#include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" #include "arm_compute/runtime/CL/CLScheduler.h" #include "arm_compute/runtime/ITensorAllocator.h" @@ -47,7 +48,7 @@ inline bool is_interleaved_transposed(int m, int n, int k, DataType data_type, b if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX)) { // COMPMID-852 - if(k > 256 && m > 4 && data_type == DataType::F32 && reshape_b_only_on_first_run) + if(k > 256 && m > 4 && is_data_type_float(data_type) && reshape_b_only_on_first_run) { const float scale = k < 1024 ? 2.0f : 2.5f; flag = (scale * n) > ((1.66f * n) + 38.4f); |