aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp35
1 files changed, 29 insertions, 6 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index a03ec108c6..617d66cf24 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -42,8 +42,9 @@ using namespace arm_compute::misc::shape_calculator;
NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_group(memory_manager), _asm_glue(memory_manager), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(),
- _offset_contribution_kernel(), _offset_contribution_output_stage_kernel(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _mm_result_s32(), _original_b(nullptr), _a_offset(0), _b_offset(0),
- _run_vector_matrix_multiplication(false), _assembly_path(false), _fused_assembly_path(false), _reshape_b_only_on_first_run(false), _is_prepared(false), _fuse_output_stage(false)
+ _offset_contribution_kernel(), _offset_contribution_output_stage_kernel(), _activation_func(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _mm_result_s32(), _original_b(nullptr),
+ _a_offset(0), _b_offset(0), _run_vector_matrix_multiplication(false), _assembly_path(false), _fused_assembly_path(false), _reshape_b_only_on_first_run(false), _is_prepared(false),
+ _fuse_output_stage(false), _run_activation(false)
{
}
@@ -87,12 +88,12 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
{
if(a->info()->data_type() == DataType::QASYMM8 && gemm_info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
{
- _asm_glue.configure(a, b, c, output, 1.f, 0.f, gemm_info);
+ _asm_glue.configure(a, b, c, output, gemm_info);
_fused_assembly_path = _asm_glue.is_configured();
}
else
{
- _asm_glue.configure(a, b, nullptr, _fuse_output_stage ? &_mm_result_s32 : output, 1.f, 0.f, gemm_info);
+ _asm_glue.configure(a, b, nullptr, _fuse_output_stage ? &_mm_result_s32 : output, gemm_info);
}
_assembly_path = _asm_glue.is_configured();
break;
@@ -192,6 +193,14 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
}
}
+ // Configure activation
+ const ActivationLayerInfo &activation = gemm_info.activation_info();
+ _run_activation = activation.enabled() && (!_assembly_path || (_assembly_path && !NEGEMMAssemblyDispatch::is_activation_supported(activation)));
+ if(_run_activation)
+ {
+ _activation_func.configure(output, nullptr, activation);
+ }
+
// Allocate tensors
if(!_assembly_path && !_run_vector_matrix_multiplication)
{
@@ -253,12 +262,12 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
bool run_optimised_requantized = false;
if(is_data_type_quantized_asymmetric(a->data_type()))
{
- run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, c, output, 1.f, 0.f, gemm_info));
+ run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, c, output, gemm_info));
run_optimised_requantized = run_optimised;
}
else
{
- run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, 1.f, 0.f, gemm_info));
+ run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, gemm_info));
}
if(run_optimised)
@@ -361,6 +370,14 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
a_offset, b_offset));
}
}
+
+ // Validate activation
+ const ActivationLayerInfo &activation = gemm_info.activation_info();
+ if(activation.enabled())
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output, nullptr, activation));
+ }
+
return Status{};
}
@@ -415,6 +432,12 @@ void NEGEMMLowpMatrixMultiplyCore::run()
NEScheduler::get().schedule(&_offset_contribution_kernel, Window::DimY);
}
}
+
+ // Run fused activation
+ if(_run_activation)
+ {
+ _activation_func.run();
+ }
}
void NEGEMMLowpMatrixMultiplyCore::prepare()