aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp43
1 files changed, 22 insertions, 21 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 558ff41a5c..c969c9f4f6 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -164,8 +164,8 @@ public:
{
if(!_gemm_kernel_asm)
return false;
- const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format;
- return wf != arm_gemm::WeightFormat::UNSPECIFIED && wf != arm_gemm::WeightFormat::ANY;
+ const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
+ return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY;
}
private:
@@ -428,7 +428,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
if(_gemm_kernel_asm->B_pretranspose_required())
{
// Fixed format kernels need no pretranspose.
- ARM_COMPUTE_ERROR_ON(arm_gemm::is_fixed_format(_gemm_kernel_asm->get_config().weight_format));
+ ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
const auto in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
@@ -492,8 +492,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Check if B is pre-tranposed and de-reference if not
if(!_gemm_kernel_asm->B_is_pretransposed())
{
- ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
- const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format;
+ ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
if(is_fixed_format(wf))
{
// The 4D tensor of dimension O'HWI' created for the
@@ -507,7 +507,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
const int Ip = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
- const int interleave_by = arm_gemm::interleave_by(wf);
+ const int interleave_by = arm_compute::interleave_by(wf);
ldb = (interleave_by * H * W * Ip);
}
multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
@@ -603,7 +603,7 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
unsigned int num_threads = NEScheduler::get().num_threads();
arm_gemm::GemmConfig cfg;
- cfg.weight_format = info.weight_format;
+ cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
// Create arm_gemm fallback
@@ -623,7 +623,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
const unsigned int num_threads = NEScheduler::get().num_threads();
arm_gemm::GemmConfig cfg;
- cfg.weight_format = info.weight_format;
+ cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
// Create arm_gemm fallback
@@ -665,7 +665,7 @@ CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch()
{
}
-Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
@@ -675,13 +675,13 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
arm_gemm::GemmConfig cfg;
- cfg.weight_format = info.weight_format;
-
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg);
+ cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
+ arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg);
switch(a->data_type())
{
case DataType::F32:
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for F32 input");
break;
#ifdef __aarch64__
@@ -689,12 +689,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
case DataType::QASYMM8:
if(d->data_type() == DataType::S32)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for U8/QASYMM8 input and S32 output");
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for U8 input and U8 output");
}
break;
@@ -702,12 +702,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
case DataType::QASYMM8_SIGNED:
if(d->data_type() == DataType::S32)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for S8 input and S32 output");
}
break;
@@ -722,7 +722,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for BFLOAT16 input and F32 output");
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
@@ -730,6 +730,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Usupported type. Could not find a kernel");
break;
}
+ expected_weight_format = assembly_utils::map_to_arm_compute_weight_format(arm_gemm_expected_wf);
return Status{};
}
@@ -762,9 +763,9 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
- arm_gemm::WeightFormat expected_weight_format;
- const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
- if((bool)ret && expected_weight_format != arm_gemm::WeightFormat::ANY)
+ arm_compute::WeightFormat expected_weight_format;
+ const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
+ if((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY)
{
// Correctness check: if the format expected by the kernel is
// not "any", make sure that the one found matches the format