aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp28
1 files changed, 24 insertions, 4 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index 9050427b34..df8eaacf47 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -47,6 +47,21 @@
namespace arm_compute
{
+namespace
+{
+AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
+{
+ AsmGemmInfo asm_info;
+ asm_info.method = AsmConvMethod::Im2Col;
+ asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d();
+ asm_info.depth_output_gemm3d = info.depth_output_gemm3d();
+ asm_info.activation_info = info.activation_info();
+ asm_info.output_stage = info.gemmlowp_output_stage();
+
+ return asm_info;
+}
+} // namespace
+
using namespace arm_compute::misc::shape_calculator;
NEGEMMLowpMatrixMultiplyCore::~NEGEMMLowpMatrixMultiplyCore() = default;
@@ -120,6 +135,8 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
_mm_result_s32.allocator()->init(info_mm_result_s32);
}
+ // Initialize assembly kernel meta-data
+ const AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
#ifdef __aarch64__
switch(a->info()->data_type())
{
@@ -130,12 +147,12 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
{
if(is_data_type_quantized_asymmetric(a_to_use->info()->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
{
- _asm_glue.configure(a_to_use, b, c, output, gemm_info);
+ _asm_glue.configure(a_to_use, b, c, output, asm_info);
_fused_assembly_path = _asm_glue.is_configured();
}
else
{
- _asm_glue.configure(a_to_use, b, nullptr, _fuse_output_stage ? &_mm_result_s32 : output, gemm_info);
+ _asm_glue.configure(a_to_use, b, nullptr, _fuse_output_stage ? &_mm_result_s32 : output, asm_info);
}
_assembly_path = _asm_glue.is_configured();
break;
@@ -346,17 +363,20 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
matrix_a_info = &signed_a;
}
+ // Initialize assembly kernel meta-data
+ const AsmGemmInfo asm_info = init_assembly_metadata(info);
+
// Check if we need to run the optimized assembly kernel
bool run_optimised = false;
bool run_optimised_requantized = false;
if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
{
- run_optimised = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, c, output, gemm_info));
+ run_optimised = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, c, output, asm_info));
run_optimised_requantized = run_optimised;
}
else
{
- run_optimised = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, gemm_info));
+ run_optimised = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, asm_info));
}
if(run_optimised)