aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLGEMM.cpp
diff options
context:
space:
mode:
author Gian Marco Iodice <gianmarco.iodice@arm.com> 2019-09-24 12:05:06 +0100
committer Gian Marco Iodice <gianmarco.iodice@arm.com> 2019-09-26 11:44:34 +0000
commit 05639f6b1ee3dcdd2c7923d0cf3a5d4712bd0071 (patch)
tree 8bd2f0f2b874ac7347777b40d563506348fff754 /src/runtime/CL/functions/CLGEMM.cpp
parent 1a569a30a2f456ff1a3e0a665201e1c3ab92df80 (diff)
download ComputeLibrary-05639f6b1ee3dcdd2c7923d0cf3a5d4712bd0071.tar.gz
COMPMID-2571: Add support for FP16 in CLGEMMReshaped - part 1
Change-Id: I8adb8850cc5ade49ebc1dbf63401f03d5ecad708
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1983
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL/functions/CLGEMM.cpp')
-rw-r--r-- src/runtime/CL/functions/CLGEMM.cpp 56
1 files changed, 36 insertions, 20 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index e78395f1de..762b00177c 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -65,37 +65,53 @@ CLGEMM::GEMMType CLGEMM::select_gemm_type(unsigned int m, unsigned int n, unsign
{
GEMMType gemm_type = GEMMType::RESHAPED_V1;
- if(gpu_target_is_in(gpu_target, GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72, GPUTarget::G76))
+ if(gpu_target_is_in(gpu_target, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT,
+ GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72,
+ GPUTarget::G76, GPUTarget::G77))
{
- if((m > 1) && (n < 16))
+ if(data_type == DataType::F32)
{
- gemm_type = GEMMType::RESHAPED_V1;
- }
- else if((m == 1) && (data_type == DataType::F32))
- {
- gemm_type = GEMMType::RESHAPED_ONLY_RHS;
+ if((m > 1) && (n < 16))
+ {
+ gemm_type = GEMMType::RESHAPED_V1;
+ }
+ else if(m == 1)
+ {
+ gemm_type = GEMMType::RESHAPED_ONLY_RHS;
+ }
+ else
+ {
+ // COMPMID-852
+ if((k > 256) && (m > 4) && reshape_b_only_on_first_run)
+ {
+ constexpr float alpha = 3.2f;
+ constexpr float fact0 = 1.51f;
+ constexpr float fact1 = 1.66f;
+ constexpr float ops = 12.0f;
+ const float scale = k > 1024 ? 1.07f : 1.0f;
+ gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE;
+ }
+ else
+ {
+ gemm_type = GEMMType::NATIVE;
+ }
+ }
+
+ const auto workload = static_cast<float>((m * n) / 20.0f);
+
+ gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type;
}
else
{
- // COMPMID-852
- if((k > 256) && (m > 4) && is_data_type_float(data_type) && reshape_b_only_on_first_run)
+ if((m == 1) || (!reshape_b_only_on_first_run))
{
- constexpr float alpha = 3.2f;
- constexpr float fact0 = 1.51f;
- constexpr float fact1 = 1.66f;
- constexpr float ops = 12.0f;
- const float scale = k > 1024 ? 1.07f : 1.0f;
- gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE;
+ gemm_type = GEMMType::NATIVE;
}
else
{
- gemm_type = GEMMType::NATIVE;
+ gemm_type = GEMMType::RESHAPED_V2;
}
}
-
- const auto workload = static_cast<float>((m * n) / 20.0f);
-
- gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type;
}
else
{