diff options
Diffstat (limited to 'src/cpu/operators/CpuMatMul.cpp')
-rw-r--r-- | src/cpu/operators/CpuMatMul.cpp | 14 |
1 files changed, 6 insertions, 8 deletions
diff --git a/src/cpu/operators/CpuMatMul.cpp b/src/cpu/operators/CpuMatMul.cpp index 87cb6c6b54..515b511044 100644 --- a/src/cpu/operators/CpuMatMul.cpp +++ b/src/cpu/operators/CpuMatMul.cpp @@ -25,9 +25,9 @@ #include "src/cpu/operators/CpuMatMul.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" -#include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "arm_compute/core/experimental/Types.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "arm_compute/runtime/NEON/NEScheduler.h" #include "arm_compute/runtime/NEON/functions/NEMatMul.h" #include "src/common/utils/Log.h" @@ -45,7 +45,6 @@ namespace cpu { namespace { - Status get_gemmlowp_output_stage_info(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const ActivationLayerInfo &act, GEMMLowpOutputStageInfo &gemmlowp_output_stage_info) { @@ -74,15 +73,14 @@ Status get_gemmlowp_output_stage_info(const ITensorInfo *src, const ITensorInfo return Status{}; } - -} +} // namespace CpuMatMul::CpuMatMul() : _transpose_kernel_lhs(), _transpose_kernel_rhs(), _asm_glue(), _lhs_transposed(), _rhs_transposed(), _original_lhs_shape(), _original_rhs_shape(), _original_dst_shape() { } -Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings) +Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings, const ActivationLayerInfo &act_info) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs, dst); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::QASYMM8, DataType::QASYMM8_SIGNED); @@ -100,7 +98,7 @@ Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const TensorInfo rhs_transposed{}; auto gemm_info = AsmGemmInfo(); - gemm_info.activation_info = info.fused_activation(); + gemm_info.activation_info = act_info; gemm_info.fast_mode = settings.fast_math(); // Validate and then permute a/b @@ -139,7 +137,7 @@ Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const return Status{}; } -void CpuMatMul::configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings) +void CpuMatMul::configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings, const ActivationLayerInfo &act_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst); ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst, info, settings); @@ -189,7 +187,7 @@ void CpuMatMul::configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, // ----------------------------------------------------- // Use transposed tensors if the corresponding transpose flags are set // Fill AsmGemmInfo class object before configuration - _gemm_info.activation_info = info.fused_activation(); + _gemm_info.activation_info = act_info; _gemm_info.fast_mode = settings.fast_math(); _gemm_info.negated_offsets = false; |