aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp')
-rw-r--r--src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp45
1 files changed, 36 insertions, 9 deletions
diff --git a/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp
index c4a9ccd703..c6b51c698a 100644
--- a/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp
@@ -42,9 +42,6 @@ CLGEMMNativeKernelConfigurationBifrost::CLGEMMNativeKernelConfigurationBifrost(G
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
{
- ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8);
- ARM_COMPUTE_UNUSED(data_type);
-
using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMNativeKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
unsigned int b);
@@ -52,31 +49,61 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMNativeKernelConfigurationB
static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G71 =
{
{ DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_f32 },
- { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 }
+ { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 },
+ { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 },
+ { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 },
+ { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 }
};
// Configurations for Mali-G76
static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
{
{ DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_f32 },
- { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 }
+ { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 },
+ { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 },
+ { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 },
+ { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 }
};
// Default configurations
static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_default =
{
{ DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_default_f32 },
- { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 }
+ { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 },
+ { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 },
+ { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 },
+ { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 }
};
switch(_target)
{
case GPUTarget::G71:
- return (this->*gemm_configs_G71[data_type])(m, n, k, b);
+ if(gemm_configs_G71.find(data_type) != gemm_configs_G71.end())
+ {
+ return (this->*gemm_configs_G71[data_type])(m, n, k, b);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not supported data type");
+ }
case GPUTarget::G76:
- return (this->*gemm_configs_G76[data_type])(m, n, k, b);
+ if(gemm_configs_G76.find(data_type) != gemm_configs_G76.end())
+ {
+ return (this->*gemm_configs_G76[data_type])(m, n, k, b);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not supported data type");
+ }
default:
- return (this->*gemm_configs_default[data_type])(m, n, k, b);
+ if(gemm_configs_default.find(data_type) != gemm_configs_default.end())
+ {
+ return (this->*gemm_configs_default[data_type])(m, n, k, b);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not supported data type");
+ }
}
}